In [None]:
import copy
import os
from os import getcwd as cwd
from os.path import join as pj

import cv2
import h5py
import imgaug.augmenters as iaa
import numpy as np
from sklearn.model_selection import KFold
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import visdom

# Logger
from IO.logger import Logger
# model
from model.segnet.segnet import SegNet_
from model.unet.unet import Unet_
from model.optimizer import AdamW

In [None]:
class args:
    # Experiment configuration, read as class attributes (args.bs, args.lr, ...) and
    # handed to Logger(args); `args.cuda` is attached at runtime by the device cell.
    experiment_name = "segnet_b20_lr1e-6_aug_pretrain"
    # paths
    # HDF5 file loaded by the cross-validation cell (datasets "X" and "Y")
    all_data_path = pj(cwd(), "data/all_classification_data", "classify_insect_std_20200806_size_seg_step")
    # output directory for saved weights, one sub-directory per experiment_name
    model_root = pj(cwd(), "output_model/size_segmentation", experiment_name)
    # train config
    model_name = "segnet" # select in ["segnet", "unet"]
    bs = 20 # training batch size (the evaluation loader uses batch size 1)
    lr = 1e-6 # AdamW learning rate
    nepoch = 100 # epochs per cross-validation fold
    plus_distance = False # use if loss = segmentation_loss + distance_loss
    alpha = 1e-2 # use if plus_distance = True
    pretrain = True # forwarded as `pretrained=` to the model constructor
    # visdom
    visdom = True # plot per-epoch curves on a visdom server
    port = 8097 # port of that visdom server

In [None]:
# Use the GPU whenever one is visible, and make tensors created via torch.Tensor(...)
# default to the same device so they match the model's.
args.cuda = torch.cuda.is_available()
torch.set_default_tensor_type(
    'torch.cuda.FloatTensor' if args.cuda else 'torch.FloatTensor'
)

### Save args

In [None]:
# Persist the experiment configuration (the `args` class) through the project Logger.
# NOTE(review): the on-disk location/format is decided by IO.logger.Logger, not visible here.
args_logger = Logger(args)
args_logger.save()

### visdom

In [None]:
if args.visdom:
    # One visdom line window per tracked curve. Each starts with a single (0, 0)
    # point and is extended later through `visualize(..., update='append')`.
    vis = visdom.Visdom(port=args.port)

    def _new_line_window(title):
        """Create an 800x400 line window named `title` with epoch/loss axis labels."""
        return vis.line(
            X=np.array([0]),
            Y=np.array([0]),
            opts=dict(
                title=title,
                xlabel='epoch',
                ylabel='loss',
                width=800,
                height=400,
            ),
        )

    win_train_loss = _new_line_window('train_loss')
    win_train_dist_diff = _new_line_window('train_dist_diff')
    win_test_loss = _new_line_window('test_loss')
    win_test_dist_diff = _new_line_window('test_dist_diff')

In [None]:
def visualize(vis, phase, visualized_data, window):
    """Append a single point to an existing visdom line window.

    Args:
        vis: visdom client whose `line` method performs the update.
        phase: x coordinate of the new point (the training loop passes epoch + 1).
        visualized_data: y value of the new point.
        window: window handle previously returned by `vis.line`.
    """
    point_x = np.array([phase])
    point_y = np.array([visualized_data])
    vis.line(X=point_x, Y=point_y, win=window, update='append')

### dataset

In [None]:
class size_segmentation_dataset(data.Dataset):
    """Map-style dataset over in-memory images and optional labels.

    What an item looks like:
        - `training is True`:   (augmented image tensor, float32 label array)
        - `evaluation is True`: (image tensor, float32 label array)
        - otherwise:            image tensor only
    Each image is cast to uint8, optionally augmented, min-max scaled into [0, 1]
    over the whole array, and transposed with (2, 0, 1) (presumably HxWxC -> CxHxW).

    Args:
        images: indexable array exposing `.shape`; the dataset length is `shape[0]`.
        labels: indexable array; required whenever `training` or `evaluation` is True.
        training: enables random augmentation and label output (compared with `is True`).
        evaluation: enables label output without augmentation (compared with `is True`).
    """

    def __init__(self, images, labels=None, training=False, evaluation=False):
        self.images, self.labels = images, labels
        self.training, self.evaluation = training, evaluation

        # Between zero and two of these, in random order, per image. All of them are
        # intensity/colour operations, which is why labels are returned untransformed.
        self.aug_seq = iaa.SomeOf(
            (0, 2),
            [
                iaa.pillike.Autocontrast(),
                iaa.Invert(0.5),
                iaa.pillike.Equalize(),
                iaa.Solarize(0.5, threshold=(32, 128)),
                iaa.color.Posterize(),
                iaa.pillike.EnhanceContrast(),
                iaa.pillike.EnhanceColor(),
                iaa.pillike.EnhanceBrightness(),
                iaa.pillike.EnhanceSharpness(),
            ],
            random_order=True,
        )

    def __getitem__(self, index):
        img = self.images[index].astype("uint8")
        if self.training is True:
            img = self.aug_seq(image=img)

        # Min-max scale into [0, 1] (written into the float32 copy), then channels first.
        img = img.astype("float32")
        img = cv2.normalize(img, img, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX)
        img_tensor = torch.from_numpy(img.transpose(2, 0, 1).astype("float32"))

        wants_label = self.training is True or self.evaluation is True
        if not wants_label:
            return img_tensor
        return img_tensor, self.labels[index].astype("float32")

    def __len__(self):
        num_items = self.images.shape[0]
        return num_items

### training

In [None]:
def _forward_batch(model, image, label, l1_loss):
    """Forward one batch and compute its L1 segmentation loss and peak-distance error.

    Args:
        model: network called as `model(image)`.
        image: batch of input images from a dataloader.
        label: batch of target maps without a channel axis; a channel axis is added
            here so it lines up with the model output (assumed (B, 1, H, W) — TODO confirm).
        l1_loss: loss module applied to both the maps and the peak distances.

    Returns:
        (seg_loss, dist_diff, target_distances, output_distances). Only `seg_loss`
        carries gradients; the distances come from argmax peaks (calc_distance).
    """
    label = label[:, None, :, :]
    if args.cuda is True:
        image = image.cuda()
        label = label.cuda()
    out = model(image)
    seg_loss = l1_loss(out, label)
    target_distances = calc_distance(label.squeeze(1))
    # detach(): the distance is argmax-based and non-differentiable, so tracking the
    # autograd graph through it would only cost memory.
    output_distances = calc_distance(out.detach().squeeze(1))
    if args.cuda is True:
        target_distances = target_distances.cuda()
        output_distances = output_distances.cuda()
    dist_diff = l1_loss(target_distances, output_distances)
    return seg_loss, dist_diff, target_distances, output_distances


def train(model, train_dataloader, test_dataloader, lr=1e-5, nepoch=100, visdom=False):
    """Train `model` with AdamW on an L1 loss and evaluate it after every epoch.

    Args:
        model: segmentation network; it is left in train mode on return.
        train_dataloader: yields (image, label) batches used for optimisation.
        test_dataloader: yields (image, label) batches used for evaluation only.
        lr: AdamW learning rate.
        nepoch: number of epochs.
        visdom: when True, append the per-epoch metrics to the global visdom
            windows (`vis`, `win_*`), which must already exist.

    Per epoch, the printed/plotted `*_loss` and `*_dist_diff` values are means over
    the batches of the respective loader, so train and test curves share one scale
    regardless of batch size or split size.

    NOTE: with `args.plus_distance`, `args.alpha * dist_diff` is added to the training
    loss, but that term is computed from argmax peaks and carries no gradient, so it
    changes the reported loss value, not the parameter updates.
    """
    l1_loss = nn.L1Loss(reduction='mean')
    opt = AdamW(model.parameters(), lr=lr)
    model.train()

    for epoch in range(nepoch):
        # ---- train ----
        total_train_loss = 0.
        total_train_dist_diff = 0.
        train_batches = 0
        target_distances = output_distances = None
        for image, label in train_dataloader:
            train_batches += 1
            opt.zero_grad()
            train_loss, train_dist_diff, target_distances, output_distances = \
                _forward_batch(model, image, label, l1_loss)
            if args.plus_distance is True:
                train_loss = train_loss + args.alpha * train_dist_diff
            total_train_loss += train_loss.item()
            total_train_dist_diff += train_dist_diff.item()
            train_loss.backward()
            opt.step()

        if target_distances is not None:
            print("train: target_dist = {}, output_dist = {}".format(target_distances[0], output_distances[0]))
        # max(..., 1) keeps an empty loader from raising ZeroDivisionError
        avg_train_loss = total_train_loss / max(train_batches, 1)
        total_train_avg_dist_diff = total_train_dist_diff / max(train_batches, 1)

        # ---- valid ----
        model.eval()
        total_test_loss = 0.
        total_test_dist_diff = 0.
        test_batches = 0
        target_distances = output_distances = None
        with torch.no_grad():  # evaluation only: no autograd graph is needed
            for image, label in test_dataloader:
                test_batches += 1
                test_loss, test_dist_diff, target_distances, output_distances = \
                    _forward_batch(model, image, label, l1_loss)
                total_test_loss += test_loss.item()
                total_test_dist_diff += test_dist_diff.item()
        model.train()

        if target_distances is not None:
            print("test: target_dist = {}, output_dist = {}".format(target_distances[0], output_distances[0]))
        avg_test_loss = total_test_loss / max(test_batches, 1)
        total_test_avg_dist_diff = total_test_dist_diff / max(test_batches, 1)

        if visdom:
            visualize(vis, epoch+1, avg_train_loss, win_train_loss)
            visualize(vis, epoch+1, total_train_avg_dist_diff, win_train_dist_diff)
            visualize(vis, epoch+1, avg_test_loss, win_test_loss)
            visualize(vis, epoch+1, total_test_avg_dist_diff, win_test_dist_diff)
        print("epoch=%s: train_loss=%f, train_dist_diff=%f, test_loss=%f, test_dist_diff=%f" %
              (epoch, avg_train_loss, total_train_avg_dist_diff,
               avg_test_loss, total_test_avg_dist_diff))
        print("---------------")

### utils

In [None]:
def unravel_index(index, shape):
    """Convert a flat row-major index into a tuple of per-dimension coordinates.

    Works for plain ints and for 0-d tensors (e.g. the result of torch.argmax);
    every coordinate is returned as a Python int.
    """
    coords = [0] * len(shape)
    remaining = index
    for axis in range(len(shape) - 1, -1, -1):
        size = shape[axis]
        coords[axis] = int(remaining % size)
        remaining = remaining // size
    return tuple(coords)

    
def calc_distance(label):
    """Pixel distance between the two strongest peaks of each map in a batch.

    For every map `elem` in `label` (iterated along dim 0): the arg-max pixel is the
    first peak; that single pixel is zeroed in a private copy; the arg-max of the
    result is the second peak. The Euclidean distance between the two coordinate
    tuples is returned.

    Args:
        label: tensor iterated along its first dimension; the training loop passes
            (batch, H, W) maps. The input is not modified.

    Returns:
        torch.Tensor of shape (len(label),), built with `torch.Tensor` (so it uses
        the current default tensor type). Not differentiable (argmax-based).

    NOTE(review): if peaks are smooth blobs rather than single pixels, the second
    arg-max will typically land right next to the first; suppressing a neighbourhood
    would then be required.
    """
    maps = label.detach().clone()  # private copy: peaks are zeroed in place below
    distances = []
    for elem in maps:
        shape = tuple(elem.shape)
        first = tuple(int(i) for i in np.unravel_index(int(torch.argmax(elem)), shape))
        # Index with a tuple so exactly the peak pixel is cleared; indexing with an
        # index *array* would instead select whole slices along dim 0.
        elem[first] = 0.
        second = tuple(int(i) for i in np.unravel_index(int(torch.argmax(elem)), shape))
        distances.append(np.linalg.norm(np.array(first) - np.array(second)))
    return torch.Tensor(distances)

In [None]:
# load data
with h5py.File(args.all_data_path) as f:
    images = f["X"][:]
    labels = f["Y"][:]
    
# define kfold
kf = KFold(n_splits=5)
valid_count = 0

# cross validation
for train_index, test_index in kf.split(images):
    print("")
    valid_count += 1
    print("----- valid {} -----".format(valid_count))
    print("")
    # create validation data
    image_train, image_test = images[train_index], images[test_index]
    label_train, label_test = labels[train_index], labels[test_index]
    # create dataloader
    train_dataset = size_segmentation_dataset(image_train, label_train, training=True)
    train_dataloader = data.DataLoader(train_dataset, args.bs, num_workers=0, shuffle=True)
    test_dataset = size_segmentation_dataset(image_test, label_test, training=False, evaluation=True)
    test_dataloader = data.DataLoader(test_dataset, 1, num_workers=0, shuffle=False)
    
    # create model
    print("model = {}".format(args.model_name))
    if args.model_name == "unet":
        model = Unet_(3, 1, pretrained=args.pretrain).cuda()
    else:
        model = SegNet_(3, 1, pretrained=args.pretrain).cuda()
    
    # training
    train(model, train_dataloader, test_dataloader, lr=args.lr, nepoch=args.nepoch, visdom=args.visdom)
    torch.save(model.state_dict(), pj(args.model_root, "final.pth"))