In [None]:
import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt
import pickle

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
import torchvision
import torch.nn.functional as F
from torch.autograd import Variable

from PIL import Image
import cv2
import albumentations as albu

import time
import os
from tqdm.notebook import tqdm

# from torchsummary import summary
# import segmentation_models_pytorch as smp

# Prefer a CUDA GPU when one is available, otherwise run on the CPU. Later cells move
# the class-weight tensor and the model onto this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

### Preprocessing

In [None]:
# Directories with the training image patches and their RGB ground-truth masks.
# Both must end with "/" because file names are concatenated onto them directly;
# the mask for "<id>.png" is read from "<id>_gt.png".
Train_IMAGE_PATH = '../input/final-datasetbenignetpvmf/Data_Benign+ET+PV+MF/Train/'
Train_MASK_PATH = '../input/final-datasetbenignetpvmf/Data_Benign+ET+PV+MF/Train_gt/'

In [None]:
# Number of segmentation classes: class_map lists six RGB colours (background + 5
# foreground classes) and the model is built with six output channels.
n_classes = 6

def create_df(IMAGE_PATH):
    """List the image ids (file names without extension) found under a directory.

    Args:
        IMAGE_PATH: directory to scan; ``os.walk`` also descends into subdirectories.

    Returns:
        pandas.DataFrame with one ``id`` column and an integer index 0..n-1. Ids are
        sorted, so row order (and lookups such as ``df['id'][50]``) is reproducible;
        the raw ``os.walk`` order depends on the filesystem.
    """
    name = []
    for _dirname, _, filenames in os.walk(IMAGE_PATH):
        for filename in filenames:
            # splitext strips only the final extension ("a.b.png" -> "a.b"), so
            # rebuilding the path as id + ".png" finds the original file again.
            name.append(os.path.splitext(filename)[0])
    name.sort()
    return pd.DataFrame({'id': name}, index=np.arange(0, len(name)))

# One row per training patch found under Train_IMAGE_PATH.
df = create_df(Train_IMAGE_PATH)
print('Total Patches: ', df.shape[0])

In [None]:
# Validation images and their RGB ground-truth masks (mask of "<id>.png" is "<id>_gt.png").
Val_IMAGE_PATH = "../input/final-datasetbenignetpvmf/Data_Benign+ET+PV+MF/Validation/"
Val_MASK_PATH = "../input/final-datasetbenignetpvmf/Data_Benign+ET+PV+MF/Validation_gt/"
# Test_IMAGE_PATH = "../input/new-data/val_full-20211130T184212Z-001/val_full/images/"
# Test_MASK_PATH = "../input/new-data/val_full-20211130T184212Z-001/val_full/mask/"

In [None]:
# One row per validation image.
df_val = create_df(Val_IMAGE_PATH)
print('Total Val Images: ', df_val.shape[0])
print(df_val)

In [None]:
print(df_val.loc[50, 'id'])

In [None]:
# df_test = create_df(Test_IMAGE_PATH)
# print('Total Test Images: ', len(df_test))

In [None]:
print(df.loc[50, 'id'])

In [None]:
print(str(df))

In [None]:
# Sanity check: overlay the ground-truth mask of one training patch on its image.
sample_id = df['id'][50]
# Copy inside `with` so the files are closed right away; the copies stay usable.
with Image.open(Train_IMAGE_PATH + sample_id + '.png') as img_file:
    img = img_file.copy()
with Image.open(Train_MASK_PATH + sample_id + '_gt.png') as mask_file:
    mask = mask_file.copy()
print('Image Size', np.asarray(img).shape)
print('Mask Size', np.asarray(mask).shape)


plt.imshow(img)
plt.imshow(mask, alpha=0.6)
plt.title('Picture with Mask Applied')
plt.show()

In [None]:
# RGB colour of each class in the ground-truth masks; the list index is the class id
# that form_2D_label assigns. Entries are floats, while form_2D_label compares them
# against a uint8 copy of the mask.
class_map = [[ 0. , 0. , 0.], [0. , 255. , 0.], [255. , 0., 0.] ,[0. , 0. , 255.],[0. , 255. , 255.],[255. , 255. , 0.] ] # Black[0], Green[1], Red[2], Blue[3], Turquoise[4], Yellow[5]

In [None]:
'''Convert an RGB mask label (used for training) into a 2D array holding the class label of each pixel.'''
def form_2D_label(mask, class_map):
    """Map an RGB mask to a 2D array of per-pixel class indices.

    Args:
        mask: H x W x 3 array of RGB colours; a uint8 copy is used for matching.
        class_map: sequence of RGB triples; entry ``i`` is the colour of class ``i``.

    Returns:
        H x W ``np.uint8`` array. A pixel whose colour equals ``class_map[i]`` on all
        three channels gets ``i``; pixels matching no entry stay 0.
    """
    rgb = mask.astype("uint8")
    label = np.zeros(rgb.shape[:2], dtype=np.uint8)
    for class_idx in range(len(class_map)):
        # Only an exact match on every channel assigns the class.
        matches = np.all(rgb == class_map[class_idx], axis=2)
        label[matches] = class_idx
    return label

In [None]:
# Pixel count per class (index = class id, same order as class_map), presumably measured
# on the training masks — TODO confirm. NOTE(review): classes 2-5 share the exact same
# count, which looks like a placeholder; verify before relying on these weights.
class_count = np.asarray([560201119,  95301323,  29131557 , 29131557, 29131557, 29131557]) #change these pixel counts for six classes
tot = np.sum(class_count)

# w_c = (total - count_c) / total: rarer classes get weights closer to 1.
class_weights = (tot - class_count)/tot
print(class_weights)
# NOTE(review): `weights` is not passed to the loss used in training (loss_calc builds
# CrossEntropy2d without a weight argument) — confirm whether class weighting is intended.
weights = torch.tensor(list(class_weights)).to(device, dtype = torch.float)

In [None]:
# Training-time augmentation, applied jointly to an image and its mask (the dataset
# calls it as transform(image=..., mask=...)). Each step fires with probability 0.2.
# CropAndPad with a negative percent crops 20%; keep_size=True resizes back to the input
# size. Geometric steps request nearest-neighbour interpolation, presumably so mask
# colours stay exact for form_2D_label's colour matching.
transform = albu.Compose([
    albu.HorizontalFlip(p=0.2),
    albu.VerticalFlip(p=0.2),
    albu.CropAndPad (percent = -0.2, keep_size=True, interpolation=cv2.INTER_NEAREST, p=0.2),
    albu.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit= 15, interpolation= cv2.INTER_NEAREST,
                                          border_mode= cv2.BORDER_REPLICATE, p=0.2)
])

In [None]:
from torch._C import NoneType
class CytoDataset(Dataset):
    """Segmentation dataset yielding ``(image, label)`` pairs read from PNG files.

    Image ``<id>.png`` is read from ``img_path`` and its RGB mask ``<id>_gt.png`` from
    ``mask_path``. The mask is converted to class indices with
    ``form_2D_label(mask, class_map)``.

    Args:
        img_path: directory prefix for images (concatenated directly with the id).
        mask_path: directory prefix for masks.
        X: sequence of image ids (file names without extension).
        transform: callable used as ``transform(image=..., mask=...)`` and returning a
            dict with ``'image'`` and ``'mask'``, or ``None`` for no augmentation.
            Defaults to the module-level ``transform`` pipeline.
    """

    def __init__(self, img_path, mask_path, X , transform=transform):
        self.img_path = img_path
        self.mask_path = mask_path
        self.X = X
        self.transform = transform

    def __len__(self):
        return len(self.X)

    @staticmethod
    def _read_rgb(path):
        """Read an image file with OpenCV and return it as an RGB array."""
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread reports a missing or unreadable file by returning None.
            raise FileNotFoundError(f"Could not read image: {path}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        """Return ``(img, mask_2)``: a (3, H, W) tensor min-max scaled to [0, 1] and a
        (H, W) long tensor of class ids."""
        img = self._read_rgb(self.img_path + self.X[idx] + '.png')
        mask = self._read_rgb(self.mask_path + self.X[idx] + '_gt.png')

        if self.transform is not None:
            # Use the instance's pipeline so the transform given to __init__ is honoured.
            augmentations = self.transform(image=img, mask=mask)
            img = augmentations['image']
            mask = augmentations['mask']

        mask_2 = form_2D_label(mask, class_map)

        # Min-max scale to [0, 1]; a constant image would divide by zero, so map it to 0.
        lo, hi = img.min(), img.max()
        value_range = float(hi) - float(lo)
        if value_range > 0:
            img = (img - lo) / value_range
        else:
            img = np.zeros(img.shape, dtype=np.float64)

        img = T.ToTensor()(img)
        mask_2 = torch.from_numpy(mask_2).long()
        return img, mask_2

In [None]:
class CytoDataset_val(Dataset):
    """Validation dataset yielding ``(image, label)`` pairs read from PNG files.

    Image ``<id>.png`` is read from ``img_path`` and its RGB mask ``<id>_gt.png`` from
    ``mask_path``; the mask is converted to class indices with
    ``form_2D_label(mask, class_map)``.

    Args:
        img_path: directory prefix for images (concatenated directly with the id).
        mask_path: directory prefix for masks.
        X: sequence of image ids (file names without extension).
        transform: optional callable used as ``transform(image=..., mask=...)``;
            ``None`` (the default) leaves samples un-augmented.
    """

    def __init__(self, img_path, mask_path, X , transform=None):
        self.img_path = img_path
        self.mask_path = mask_path
        self.X = X
        self.transform = transform

    def __len__(self):
        return len(self.X)

    @staticmethod
    def _read_rgb(path):
        """Read an image file with OpenCV and return it as an RGB array."""
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread reports a missing or unreadable file by returning None.
            raise FileNotFoundError(f"Could not read image: {path}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        """Return ``(img, mask_2)``: a (3, H, W) tensor min-max scaled to [0, 1] and a
        (H, W) long tensor of class ids."""
        img = self._read_rgb(self.img_path + self.X[idx] + '.png')
        mask = self._read_rgb(self.mask_path + self.X[idx] + '_gt.png')

        if self.transform is not None:
            augmented = self.transform(image=img, mask=mask)
            img, mask = augmented['image'], augmented['mask']

        mask_2 = form_2D_label(mask, class_map)

        # Min-max scale to [0, 1]; a constant image would divide by zero, so map it to 0.
        lo, hi = img.min(), img.max()
        value_range = float(hi) - float(lo)
        if value_range > 0:
            img = (img - lo) / value_range
        else:
            img = np.zeros(img.shape, dtype=np.float64)

        img = T.ToTensor()(img)
        mask_2 = torch.from_numpy(mask_2).long()
        return img, mask_2

In [None]:
# Training samples use the default augmentation pipeline; validation samples are not augmented.
train_set = CytoDataset(Train_IMAGE_PATH, Train_MASK_PATH, df['id'].values)
val_set = CytoDataset(Val_IMAGE_PATH, Val_MASK_PATH, df_val['id'].values, transform=None)
# To train without augmentation instead:
# train_set = CytoDataset(Train_IMAGE_PATH, Train_MASK_PATH, df['id'].values, transform=None)

batch_size = 1

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
# Validation metrics are averages, so order does not matter; a fixed order makes the
# batch visualised inside fit() the same image every epoch, so predictions are comparable.
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)

In [None]:
 print(len(train_set))

In [None]:
print(len(val_set))

In [None]:
print(str(val_loader))

In [None]:
# ---------------------------------- Unet model --------------------------------------

class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    padding=1 keeps the spatial size; channels go ``in_channels`` -> ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        for stage_in in (in_channels, out_channels):
            layers.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # A single Sequential named `double_conv` keeps state_dict keys double_conv.0..5.
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Halve the spatial size with 2x2 max-pooling, then apply a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        out = self.maxpool_conv(x)
        return out


class Up(nn.Module):
    """Upsample a decoder map, pad it to the skip map's size, concatenate, then DoubleConv.

    Args:
        in_channels: channels after concatenation (upsampled map + skip map).
        out_channels: channels produced by the final DoubleConv.
        bilinear: if True, use parameter-free bilinear upsampling (channels unchanged);
            otherwise a learned 2x2 transposed convolution on ``in_channels // 2`` channels.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """x1: decoder map to upsample; x2: encoder skip map (both N x C x H x W)."""
        x1 = self.up(x1)
        # F.pad takes a sequence of ints; plain ints avoid allocating tensors on
        # every forward pass and keep the module friendly to tracing/export.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)

        # Pad (left, right, top, bottom) so odd input sizes still align with the skip map.
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    """1x1 convolution mapping feature channels to per-class logits."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        logits = self.conv(x)
        return logits


class UNet(nn.Module):
    """U-Net with four down and four up stages producing per-pixel class logits.

    Args:
        n_channels: number of input image channels.
        n_classes: number of output classes (channels of the returned logits).
        bilinear: forwarded to every ``Up`` stage (bilinear upsampling vs transposed conv).
        base_channels: width of the first level; deeper levels use 2x, 4x and 8x this.
            The default 64 gives the architecture trained in this notebook; 16 and 128
            give lighter / heavier variants.
    """

    def __init__(self, n_channels, n_classes, bilinear=True, base_channels=64):
        super(UNet, self).__init__()
        if base_channels <= 0:
            raise ValueError(f"base_channels must be positive, got {base_channels}")
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.base_channels = base_channels

        c = base_channels
        self.inc = DoubleConv(n_channels, c)
        self.down1 = Down(c, 2 * c)
        self.down2 = Down(2 * c, 4 * c)
        self.down3 = Down(4 * c, 8 * c)
        # The bottleneck keeps 8c channels, so each Up input is (upsampled + skip) channels.
        self.down4 = Down(8 * c, 8 * c)
        self.up1 = Up(16 * c, 4 * c, bilinear)
        self.up2 = Up(8 * c, 2 * c, bilinear)
        self.up3 = Up(4 * c, c, bilinear)
        self.up4 = Up(2 * c, c, bilinear)
        self.outc = OutConv(c, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Variable
from packaging import version

class CrossEntropy2d(nn.Module):
    """Pixel-wise cross-entropy that skips negative pixels and pixels equal to ``ignore_label``.

    Args:
        size_average: if True, average the loss over valid pixels; otherwise sum it.
        ignore_label: label value excluded from the loss.
    """

    def __init__(self, size_average=True, ignore_label=255):
        super(CrossEntropy2d, self).__init__()
        self.size_average = size_average
        self.ignore_label = ignore_label

    def forward(self, predict, target, weight= None):
        """
            Args:
                predict:(n, c, h, w) logits
                target:(n, h, w) integer class labels
                weight (Tensor, optional): a manual rescaling weight given to each class.
                                           If given, has to be a Tensor of size "nclasses"
            Returns:
                Scalar loss tensor. If every pixel is ignored, a zero that stays
                connected to ``predict`` is returned, so ``backward()`` still works.
        """
        assert not target.requires_grad
        assert predict.dim() == 4
        assert target.dim() == 3
        assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0))
        assert predict.size(2) == target.size(1), "{0} vs {1} ".format(predict.size(2), target.size(1))
        assert predict.size(3) == target.size(2), "{0} vs {1} ".format(predict.size(3), target.size(2))
        target_mask = (target >= 0) & (target != self.ignore_label)
        target = target[target_mask]
        # Boolean indexing always yields a 1-D tensor, so test emptiness with numel();
        # a mean over zero elements would otherwise be NaN.
        if target.numel() == 0:
            return predict.sum() * 0.0
        # (n, c, h, w) -> (n, h, w, c); indexing with the (n, h, w) mask gives (k, c).
        predict = predict.permute(0, 2, 3, 1)[target_mask]
        # `size_average` is deprecated in F.cross_entropy; express it via `reduction`.
        reduction = 'mean' if self.size_average else 'sum'
        return F.cross_entropy(predict, target, weight=weight, reduction=reduction)


class BCEWithLogitsLoss2d(nn.Module):
    """Pixel-wise binary cross-entropy on logits, skipping negative and ``ignore_label`` targets.

    Args:
        size_average: if True, average the loss over valid pixels; otherwise sum it.
        ignore_label: target value excluded from the loss.
    """

    def __init__(self, size_average=True, ignore_label=255):
        super(BCEWithLogitsLoss2d, self).__init__()
        self.size_average = size_average
        self.ignore_label = ignore_label

    def forward(self, predict, target, weight=None):
        """
            Args:
                predict:(n, 1, h, w) logits
                target:(n, 1, h, w) float targets in {0, 1} (or ignore_label)
                weight (Tensor, optional): rescaling weight. NOTE(review): after masking,
                    predictions are 1-D, so `weight` must broadcast against that shape.
            Returns:
                Scalar loss tensor. If every pixel is ignored, a zero that stays
                connected to ``predict`` is returned, so ``backward()`` still works.
        """
        assert not target.requires_grad
        assert predict.dim() == 4
        assert target.dim() == 4
        assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0))
        assert predict.size(2) == target.size(2), "{0} vs {1} ".format(predict.size(2), target.size(2))
        assert predict.size(3) == target.size(3), "{0} vs {1} ".format(predict.size(3), target.size(3))
        target_mask = (target >= 0) & (target != self.ignore_label)
        target = target[target_mask]
        # Boolean indexing always yields a 1-D tensor; numel() is the real emptiness test.
        if target.numel() == 0:
            return predict.sum() * 0.0
        predict = predict[target_mask]
        # `size_average` is deprecated; express it via `reduction`.
        reduction = 'mean' if self.size_average else 'sum'
        return F.binary_cross_entropy_with_logits(predict, target, weight=weight, reduction=reduction)

In [None]:
# Hyper-parameter dictionary. Keys read by helpers in this notebook include LEARNING_RATE,
# LEARNING_RATE_D, NUM_STEPS and POWER (poly LR schedule) and NUM_CLASSES (one_hot).
# NOTE(review): many other keys (adversarial / semi-supervised settings, RESTORE_FROM,
# SNAPSHOT_DIR, ...) look carried over from another training script — confirm before use.
# NOTE(review): the dataset classes min-max scale images and do not subtract IMG_MEAN;
# its values resemble common BGR ImageNet means rather than this dataset's statistics.
args = {'IMG_MEAN' : np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32),  #calculate the mean of RGB in real images of train data
        'MODEL' : 'U_net' , 'BATCH_SIZE' : 1,
'ITER_SIZE' : 1,  'NUM_WORKERS' : 4, 
'IGNORE_LABEL' : 255 , 'INPUT_SIZE' : '512,512' , 'LEARNING_RATE' : 2.5e-4 , 'MOMENTUM' : 0.9,  'NUM_CLASSES' : 6, 
'NUM_STEPS' : 20000, 
'POWER' : 0.9,
'RANDOM_SEED' : 1234,
'SAVE_NUM_IMAGES' : 2,
'SAVE_PRED_EVERY' : 5000,
'SNAPSHOT_DIR' : './snapshots/',
'WEIGHT_DECAY' : 0.0005,

'LEARNING_RATE_D' : 1e-4,
'LAMBDA_ADV_PRED' : 0.1,
'PARTIAL_DATA' : None, #0.5,

'SEMI_START' : 5000,
'LAMBDA_SEMI' : 0.1,
'MASK_T' : 0.2,

'LAMBDA_SEMI_ADV':0.001,
'SEMI_START_ADV' : 0,
'D_REMAIN' : False, 'GPU':True,
       'RESTORE_FROM' : 'http://vllab1.ucmerced.edu/~whung/adv-semi-seg/resnet101COCO-41f33a49.pth'}

In [None]:
# Python 2 called this module `copy_reg`; Python 3 renamed it `copyreg`.
# Catch only ImportError so unrelated failures are not silently swallowed.
try:
    import copy_reg
except ImportError:
    import copyreg as copy_reg

In [None]:
def pixel_accuracy(output, mask):
    """Fraction of pixels whose argmax class over dim 1 of ``output`` equals ``mask``.

    Args:
        output: (N, C, H, W) logits.
        mask: (N, H, W) ground-truth class labels.

    Returns:
        Python float in [0, 1].
    """
    with torch.no_grad():
        predicted = F.softmax(output, dim=1).argmax(dim=1)
        hits = (predicted == mask).int()
        return float(hits.sum()) / float(hits.numel())

def mIoU(pred_mask, mask, smooth=1e-10, n_classes=5):
    """Mean intersection-over-union over classes ``0..n_classes-1``.

    Args:
        pred_mask: (N, C, H, W) logits; the predicted class is the argmax over dim 1.
        mask: ground-truth labels with the same number of pixels.
        smooth: added to numerator and denominator of each class IoU.
        n_classes: number of class ids to score. Classes absent from ``mask`` are
            skipped (recorded as NaN and ignored by ``np.nanmean``).

    Returns:
        Mean IoU as a numpy float (NaN if no scored class appears in ``mask``).
    """
    with torch.no_grad():
        flat_pred = torch.argmax(F.softmax(pred_mask, dim=1), dim=1).contiguous().view(-1)
        flat_true = mask.contiguous().view(-1)

        scores = []
        for cls in range(n_classes):
            in_pred = flat_pred == cls
            in_true = flat_true == cls
            if not in_true.long().sum().item():
                # Class not present in the ground truth: leave it out of the mean.
                scores.append(np.nan)
                continue
            inter = torch.logical_and(in_pred, in_true).sum().float().item()
            uni = torch.logical_or(in_pred, in_true).sum().float().item()
            scores.append((inter + smooth) / (uni + smooth))
        return np.nanmean(scores)
    
def map_thiss(y_pred,class_map):
    """Colour a batch of class-index maps with the RGB triples in ``class_map``.

    Args:
        y_pred: (N, H, W) class indices as a numpy array or torch tensor (any device).
            Float values are truncated toward zero, as ``int()`` would.
        class_map: sequence of RGB triples indexed by class id.

    Returns:
        float64 numpy array of shape (N, H, W, 3).
    """
    if isinstance(y_pred, torch.Tensor):
        # Copy to host once instead of reading (possibly GPU) elements one by one.
        y_pred = y_pred.detach().cpu().numpy()
    indices = np.asarray(y_pred).astype(np.int64)
    palette = np.asarray(class_map, dtype=np.float64)
    # A single fancy-indexing lookup replaces the per-pixel Python triple loop.
    return palette[indices]

def plot_result(img, title):
    """Display the first image of a batch in a 12x6-inch figure titled ``title``."""
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.set_title(title)
    ax.imshow(img[0])
    plt.show()
    
def export_model(model, optimizer=None, name=None, step=None):
    """Save a checkpoint dict to ``./<name>[_step_<step>].pth``.

    The dict always has ``model_state_dict``; ``step`` and ``optimizer_state_dict``
    are added when those arguments are given. ``name`` defaults to "checkpoint".
    """
    stem = "checkpoint" if name is None else name
    if step is not None:
        stem += "_step_" + str(step)
    out_file = os.path.join("./", stem + ".pth")

    checkpoint = {"model_state_dict": model.state_dict()}
    if step is not None:
        checkpoint["step"] = step
    if optimizer is not None:
        checkpoint["optimizer_state_dict"] = optimizer.state_dict()
    torch.save(checkpoint, out_file)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def loss_calc(pred, label, gpu):
    """Cross-entropy segmentation loss via CrossEntropy2d (label 255 is ignored).

    Args:
        pred: (batch, classes, h, w) logits.
        label: (batch, h, w) labels of any numeric dtype; cast to long.
        gpu: unused; kept for call-site compatibility. The label is moved to
            ``pred``'s device instead.
    """
    # `.cuda(device)` raises when `device` is the CPU; following pred.device works on both.
    label = label.long().to(pred.device)
    criterion = CrossEntropy2d()
    return criterion(pred, label)


def lr_poly(base_lr, iter, max_iter, power):
    """Polynomial ("poly") decay: ``base_lr * (1 - iter / max_iter) ** power``.

    Remaining progress is clamped at 0 once ``iter`` passes ``max_iter``. In Python, a
    negative base raised to a fractional power is a complex number, which would break
    the optimizer.
    """
    remaining = max(0.0, 1 - float(iter) / max_iter)
    return base_lr * (remaining ** power)


def adjust_learning_rate(optimizer, i_iter):
    """Set the poly-decayed args["LEARNING_RATE"] on the first param group; a second
    group, if present, gets 10x that rate."""
    lr = lr_poly(args["LEARNING_RATE"], i_iter, args["NUM_STEPS"], args["POWER"])
    groups = optimizer.param_groups
    groups[0]['lr'] = lr
    if len(groups) > 1:
        groups[1]['lr'] = 10 * lr

def adjust_learning_rate_D(optimizer, i_iter):
    """Set the poly-decayed args["LEARNING_RATE_D"] on the first param group; a second
    group, if present, gets 10x that rate."""
    lr = lr_poly(args["LEARNING_RATE_D"], i_iter, args["NUM_STEPS"], args["POWER"])
    groups = optimizer.param_groups
    groups[0]['lr'] = lr
    if len(groups) > 1:
        groups[1]['lr'] = 10 * lr

def one_hot(label):
    """Expand an (N, H, W) label tensor into an (N, NUM_CLASSES, H, W) FloatTensor of 0/1 planes.

    Values outside ``range(args["NUM_CLASSES"])`` (such as the 255 ignore label) get
    all-zero columns; they are not otherwise treated specially. ``label.numpy()`` is
    called, so ``label`` must be a CPU tensor.
    """
    label_np = label.numpy()
    n_cls = args["NUM_CLASSES"]
    encoded = np.zeros((label_np.shape[0], n_cls, label_np.shape[1], label_np.shape[2]),
                       dtype=label_np.dtype)
    for cls in range(n_cls):
        encoded[:, cls, ...] = (label_np == cls)
    return torch.FloatTensor(encoded)

def make_D_label(label, ignore_mask):
    """Build a discriminator target map filled with ``label``, with ignored pixels set to 255.

    Args:
        label: scalar value written to every non-ignored pixel.
        ignore_mask: (N, H, W) numpy mask of pixels to ignore — assumed boolean; an
            integer array would be treated as fancy indices instead.

    Returns:
        float32 tensor of shape (N, 1, H, W) on the module-level ``device``.
    """
    ignore_mask = np.expand_dims(ignore_mask, axis=1)
    D_label = np.ones(ignore_mask.shape)*label
    D_label[ignore_mask] = 255
    # `.to(device)` works for CPU and CUDA alike; `.cuda(device)` raises on CPU-only hosts.
    return torch.as_tensor(D_label, dtype=torch.float32).to(device)

In [None]:
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first param group (None if it has none)."""
    return next((group['lr'] for group in optimizer.param_groups), None)

def fit(epochs, model, train_loader, val_loader, criterion, optimizer, scheduler, patch=False):
    """Train ``model`` with a per-batch LR schedule and validate after every epoch.

    The loss comes from ``loss_calc`` (pixel-wise cross-entropy, label 255 ignored).
    Whenever the mean validation loss reaches a new minimum, a checkpoint is written
    with ``export_model(model, optimizer=optimizer, name="final", step=epoch_index)``.
    During validation, the ground truth and prediction of the batch at index 1 are plotted.

    Args:
        epochs: number of passes over ``train_loader``.
        model: network returning (N, C, H, W) logits.
        train_loader, val_loader: iterables of ``(image, mask)`` batches.
        criterion: NOTE(review): not used — the loss is computed by ``loss_calc``;
            kept for call-site compatibility.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler stepped once per training batch (e.g. OneCycleLR).
        patch: unused; kept for call-site compatibility.

    Returns:
        dict of per-epoch lists ``train_loss``, ``val_loss``, ``train_miou``, ``val_miou``,
        ``train_acc``, ``val_acc``, plus ``lrs`` (the LR before every scheduler step).
    """
    torch.cuda.empty_cache()
    train_losses, test_losses = [], []
    train_iou, val_iou = [], []
    train_acc, val_acc = [], []
    lrs = []
    min_loss = np.inf
    not_improve = 0

    model.to(device)
    fit_time = time.time()
    for e in range(epochs):
        since = time.time()
        running_loss = 0
        iou_score = 0
        accuracy = 0

        # ---- training ----
        model.train()
        for image_tiles, mask_tiles in tqdm(train_loader):
            image = image_tiles.to(device, dtype=torch.float)
            mask = mask_tiles.to(device, dtype=torch.float)

            output = model(image)
            loss = loss_calc(output, mask, device)
            # Score every class the model outputs; mIoU's default n_classes=5 would
            # silently drop the last class of the 6-class model.
            iou_score += mIoU(output, mask, n_classes=output.shape[1])
            accuracy += pixel_accuracy(output, mask)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Record the LR used for this step, then advance the per-batch schedule.
            lrs.append(get_lr(optimizer))
            scheduler.step()

            running_loss += loss.item()

        # ---- validation ----
        model.eval()
        test_loss = 0
        test_accuracy = 0
        val_iou_score = 0
        with torch.no_grad():
            for i, (image_tiles, mask_tiles) in enumerate(tqdm(val_loader)):
                image = image_tiles.to(device, dtype=torch.float)
                mask = mask_tiles.to(device, dtype=torch.float)
                output = model(image)
                if i == 1:
                    # Move to host memory before colouring; map_thiss works on numpy data.
                    output_num = F.softmax(output, dim=1).cpu().numpy()
                    pred_mask = np.argmax(output_num, axis=1)
                    y_pred_rgb = map_thiss(pred_mask, class_map)
                    y_test_rgb = map_thiss(mask.cpu().numpy(), class_map)
                    plot_result(y_test_rgb, "Original Masks")
                    plot_result(y_pred_rgb, "Predicted Masks")
                val_iou_score += mIoU(output, mask, n_classes=output.shape[1])
                test_accuracy += pixel_accuracy(output, mask)
                test_loss += loss_calc(output, mask, device).item()

        # ---- per-epoch bookkeeping (means over batches) ----
        n_train, n_val = len(train_loader), len(val_loader)
        epoch_val_loss = test_loss / n_val
        train_losses.append(running_loss / n_train)
        test_losses.append(epoch_val_loss)

        if epoch_val_loss < min_loss:
            print('Loss Decreasing.. {:.3f} >> {:.3f} '.format(min_loss, epoch_val_loss))
            min_loss = epoch_val_loss
            print('saving model...')
            export_model(model, optimizer=optimizer, name="final", step=e)
        elif epoch_val_loss > min_loss:
            # Only count the stall: min_loss must keep tracking the best loss seen so far,
            # otherwise a later, still-worse epoch would be reported (and saved) as an improvement.
            not_improve += 1
            print(f'Loss Not Decrease for {not_improve} time')

        train_iou.append(iou_score / n_train)
        val_iou.append(val_iou_score / n_val)
        train_acc.append(accuracy / n_train)
        val_acc.append(test_accuracy / n_val)
        print("Epoch:{}/{}..".format(e+1, epochs),
              "Train Loss: {:.3f}..".format(running_loss/n_train),
              "Val Loss: {:.3f}..".format(epoch_val_loss),
              "Train mIoU:{:.3f}..".format(iou_score/n_train),
              "Val mIoU: {:.3f}..".format(val_iou_score/n_val),
              "Train Acc:{:.3f}..".format(accuracy/n_train),
              "Val Acc:{:.3f}..".format(test_accuracy/n_val),
              "Time: {:.2f}m".format((time.time()-since)/60))

    history = {'train_loss' : train_losses, 'val_loss': test_losses,
               'train_miou' :train_iou, 'val_miou':val_iou,
               'train_acc' :train_acc, 'val_acc':val_acc,
               'lrs': lrs}
    print('Total time: {:.2f} m' .format((time.time()- fit_time)/60))
    return history

In [None]:
max_lr = 1e-3
epoch = 50
weight_decay = 1e-4

# Seed the Python, NumPy and PyTorch RNGs (weight init, DataLoader shuffling and,
# depending on the albumentations version, augmentation draws) so runs are reproducible.
import random
random.seed(args['RANDOM_SEED'])
np.random.seed(args['RANDOM_SEED'])
torch.manual_seed(args['RANDOM_SEED'])

# One output channel per colour in class_map.
model = UNet(n_classes=len(class_map),
             n_channels=3)
# NOTE: fit() computes its loss with loss_calc; this criterion is only passed through.
criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, weight_decay=weight_decay)
# fit() calls scheduler.step() once per training batch, hence steps_per_epoch=len(train_loader).
sched = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, epochs=epoch,
                                            steps_per_epoch=len(train_loader))

history = fit(epoch, model, train_loader, val_loader, criterion, optimizer, sched)
# history = fit(epoch, model, val_loader, val_loader, criterion, optimizer, sched)

In [None]:
def _plot_history_curves(history, series, title, ylabel):
    """Plot each (key, label, marker) series of ``history`` against the epoch index, then show."""
    for key, label, marker in series:
        plt.plot(history[key], label=label, marker=marker)
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('epoch')
    plt.legend()
    plt.grid()
    plt.show()


def plot_loss(history):
    """Plot validation and training loss per epoch."""
    _plot_history_curves(history,
                         [('val_loss', 'val', 'o'), ('train_loss', 'train', 'o')],
                         'Loss per epoch', 'loss')


def plot_score(history):
    """Plot training and validation mean IoU per epoch."""
    _plot_history_curves(history,
                         [('train_miou', 'train_mIoU', '*'), ('val_miou', 'val_mIoU', '*')],
                         'Score per epoch', 'mean IoU')


def plot_acc(history):
    """Plot training and validation pixel accuracy per epoch."""
    _plot_history_curves(history,
                         [('train_acc', 'train_accuracy', '*'), ('val_acc', 'val_accuracy', '*')],
                         'Accuracy per epoch', 'Accuracy')

In [None]:
for plot_fn in (plot_loss, plot_score, plot_acc):
    plot_fn(history)

In [None]:
# Persist the training history so the curves can be re-plotted without retraining.
with open('./history.pickle', mode='wb') as handle:
    pickle.dump(history, handle, pickle.HIGHEST_PROTOCOL)