In [None]:
import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt
import pickle

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
import torchvision
import torch.nn.functional as F
from torch.autograd import Variable

from PIL import Image
import cv2
import albumentations as albu

import time
import os
from tqdm import tqdm

# from torchsummary import summary
# import segmentation_models_pytorch as smp



In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Directory prefixes for the training images and their ground-truth masks.
# They are concatenated directly with file names, so the trailing '/' matters.
Train_IMAGE_PATH = '../input/dataset-new/Data_Benign_plus_ET/Train/'
Train_MASK_PATH = '../input/dataset-new/Data_Benign_plus_ET/Train_gt/'

In [None]:
# Number of segmentation classes.
# NOTE(review): not referenced by the visible code (mIoU has its own n_classes default) — confirm it is still needed.
n_classes = 4 

def create_df(IMAGE_PATH):
    """Collect the extension-less names of all files under ``IMAGE_PATH``.

    The directory tree is walked recursively. Directories and files are
    visited in sorted order so that row indices are reproducible between
    runs (``os.walk`` itself gives no ordering guarantee).

    Args:
        IMAGE_PATH: Root directory to scan.

    Returns:
        pandas.DataFrame with a single ``'id'`` column holding each file name
        without its final extension, indexed ``0 .. n-1``. Empty when the
        directory does not exist or contains no files.
    """
    name = []
    for _dirname, dirnames, filenames in os.walk(IMAGE_PATH):
        # Sorting in place makes os.walk descend sub-directories in a stable order.
        dirnames.sort()
        for filename in sorted(filenames):
            # splitext strips only the last extension, so 'a.v2.png' -> 'a.v2'
            # (split('.')[0] would truncate it to 'a' and break later lookups).
            name.append(os.path.splitext(filename)[0])

    return pd.DataFrame({'id': name}, index=np.arange(0, len(name)))

# Index every training image by id and report how many were found.
df = create_df(Train_IMAGE_PATH)
print('Total Patches: ', len(df))

In [None]:
# Directory prefixes for the validation images and their ground-truth masks.
Val_IMAGE_PATH = "../input/dataset-new/Data_Benign_plus_ET/Validation/"
Val_MASK_PATH = "../input/dataset-new/Data_Benign_plus_ET/Validation_gt/"

In [None]:
# Index every validation image by id, then print the count and the full table.
df_val = create_df(Val_IMAGE_PATH)
print('Total Val Images: ', len(df_val))
print(df_val)

In [None]:
# Spot-check a single id (row label 50) from the training table.
print(df['id'][50])

In [None]:
# Print the whole training id table.
print(df)

In [None]:
# Load one training patch (row label 80) and its ground-truth mask; the mask
# file name is the image id followed by '_gt'.
img = Image.open(Train_IMAGE_PATH  + df['id'][80] + '.png')

mask = Image.open(Train_MASK_PATH  + df['id'][80] + '_gt.png')
print('Image Size', np.asarray(img).shape)
print('Mask Size', np.asarray(mask).shape)


# Draw the mask semi-transparently over the image as a visual sanity check.
plt.imshow(img)
plt.imshow(mask, alpha=0.6)
plt.title('Picture with Mask Appplied')
plt.show()

In [None]:
# RGB colour of each class; the list position is the integer class label.
class_map = [[ 0. , 0. , 0.], [0. , 255. , 0.], [255. , 0., 0.] ,[0. , 0. , 255.] ] #Black[0] , Green[1] , Red[2] , Blue[3]

In [None]:
def form_2D_label(mask, class_map):
    """Convert an RGB mask into a 2-D image of integer class labels.

    Args:
        mask: Array of shape (H, W, 3); it is cast to uint8 before matching.
        class_map: Sequence of RGB triples; position ``i`` is the colour of class ``i``.

    Returns:
        uint8 array of shape (H, W). Pixels whose colour matches no entry of
        ``class_map`` keep the initial label 0.
    """
    rgb_mask = mask.astype("uint8")
    height_width = rgb_mask.shape[:2]
    label = np.zeros(height_width, dtype=np.uint8)

    for class_index, colour in enumerate(class_map):
        # A pixel belongs to the class only when all three channels match.
        matches = np.all(rgb_mask == colour, axis=2)
        label[matches] = class_index

    return label

In [None]:
# NOTE(review): this only binds a Python variable named CUDA_LAUNCH_BLOCKING;
# it does not set the environment variable of the same name. If synchronous
# CUDA launches were intended, os.environ would have to be set instead.
CUDA_LAUNCH_BLOCKING=1

In [None]:
# Per-class pixel counts; each class weight is the fraction of pixels that do
# NOT belong to that class, so rarer classes get weights closer to 1.
# NOTE(review): the last two counts are identical and the original comment asks
# for them to be updated — confirm the real counts.
class_count = np.asarray([560201119,  95301323,  29131557, 29131557]) #change these pixel counts for four classes
tot = np.sum(class_count)

class_weights = (tot - class_count)/tot
print(class_weights)
# NOTE(review): `weights` is not referenced anywhere else in the visible notebook.
weights = torch.tensor(list(class_weights)).to(device, dtype = torch.float)

In [None]:
# Training-time augmentation pipeline. Each step fires with probability 0.5:
# horizontal flip, vertical flip, a crop (percent=-0.2) resized back to the
# input size, and a small shift/scale/rotate with replicated borders.
# Nearest-neighbour interpolation is requested for the resampling steps.
transform = albu.Compose([
    albu.HorizontalFlip(p=0.5),
    albu.VerticalFlip(p=0.5),
    albu.CropAndPad (percent = -0.2, keep_size=True, interpolation=cv2.INTER_NEAREST, p=0.5),
    albu.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit= 15, interpolation= cv2.INTER_NEAREST,
                                          border_mode= cv2.BORDER_REPLICATE, p=0.5)
])

In [None]:
from torch._C import NoneType
class CytoDataset(Dataset):
    """Training dataset yielding (image tensor, 2-D class-label tensor) pairs.

    Images are read from ``img_path + id + '.png'`` and masks from
    ``mask_path + id + '_gt.png'``.

    Args:
        img_path: Directory prefix of the images (concatenated as-is).
        mask_path: Directory prefix of the RGB ground-truth masks.
        X: Indexable collection of image ids.
        transform: Callable invoked as ``transform(image=..., mask=...)`` that
            returns a mapping with ``'image'`` and ``'mask'`` keys, or ``None``
            to disable augmentation. Defaults to the module-level ``transform``.
    """

    def __init__(self, img_path, mask_path, X , transform=transform):
        self.img_path = img_path
        self.mask_path = mask_path
        self.X = X
        self.transform = transform

    def __len__(self):
        return len(self.X)

    @staticmethod
    def _load_rgb(path):
        """Read an image file and return it as an RGB array."""
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError("Could not read image: " + str(path))
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        """Return the (optionally augmented) sample at position ``idx``.

        Returns:
            Tuple ``(img, mask_2)``: ``img`` is the min-max normalised image
            converted by ``T.ToTensor``; ``mask_2`` is a long tensor of class
            labels built with ``form_2D_label``.
        """
        img = self._load_rgb(self.img_path + self.X[idx] + '.png')
        mask = self._load_rgb(self.mask_path + self.X[idx] + '_gt.png')

        if self.transform is not None:
            # Use the instance's transform, not whatever the global currently is.
            augmentations = self.transform(image=img, mask=mask)
            img = augmentations['image']
            mask = augmentations['mask']

        mask_2 = form_2D_label(mask, class_map)

        # Min-max normalise to [0, 1]; a constant image would divide by zero.
        lo, hi = img.min(), img.max()
        if hi > lo:
            img = (img - lo) / (hi - lo)
        else:
            img = np.zeros(img.shape, dtype=np.float64)

        img = T.ToTensor()(img)
        mask_2 = torch.from_numpy(mask_2).long()

        return img, mask_2

In [None]:
class CytoDataset_val(Dataset):
    """Validation dataset yielding (image tensor, 2-D class-label tensor) pairs.

    Images are read from ``img_path + id + '.png'`` and masks from
    ``mask_path + id + '_gt.png'``. No augmentation is applied.

    Args:
        img_path: Directory prefix of the images (concatenated as-is).
        mask_path: Directory prefix of the RGB ground-truth masks.
        X: Indexable collection of image ids.
        transform: Stored on the instance for interface parity with the
            training dataset; it is never applied by ``__getitem__``.
    """

    def __init__(self, img_path, mask_path, X , transform=None):
        self.img_path = img_path
        self.mask_path = mask_path
        self.X = X
        self.transform = transform

    def __len__(self):
        return len(self.X)

    @staticmethod
    def _load_rgb(path):
        """Read an image file and return it as an RGB array."""
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError("Could not read image: " + str(path))
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        """Return ``(img, mask_2)`` for the sample at position ``idx``."""
        img = self._load_rgb(self.img_path + self.X[idx] + '.png')
        mask = self._load_rgb(self.mask_path + self.X[idx] + '_gt.png')
        mask_2 = form_2D_label(mask, class_map)

        # Min-max normalise to [0, 1]; a constant image would divide by zero.
        lo, hi = img.min(), img.max()
        if hi > lo:
            img = (img - lo) / (hi - lo)
        else:
            img = np.zeros(img.shape, dtype=np.float64)

        img = T.ToTensor()(img)
        mask_2 = torch.from_numpy(mask_2).long()

        return img, mask_2

In [None]:
# Build the datasets from the id tables; the training set uses its default
# augmentation, the validation set applies none.
train_set = CytoDataset(Train_IMAGE_PATH, Train_MASK_PATH, df['id'].values)
val_set = CytoDataset_val(Val_IMAGE_PATH, Val_MASK_PATH, df_val['id'].values)

#dataloader
batch_size= 1

# NOTE(review): the validation loader is shuffled too, so the sample that the
# training loop plots (batch index 1) differs from epoch to epoch.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True)  

In [None]:
 # Number of training samples.
 print( train_set.__len__())

In [None]:
# Number of validation samples.
print(val_set.__len__())

In [None]:
# ---------------------------------- Res_Unet++ model --------------------------------------
import torch.nn as nn
import torch


class ResidualConv(nn.Module):
    """Pre-activation residual block with a convolutional shortcut.

    The main path is BN -> ReLU -> 3x3 conv (given stride/padding) -> BN ->
    ReLU -> 3x3 conv; the shortcut is a 3x3 conv (same stride, padding 1)
    followed by BN. The two results are summed.

    Args:
        input_dim: Number of input channels.
        output_dim: Number of output channels.
        stride: Stride of the first main-path conv and of the shortcut conv.
        padding: Padding of the first main-path conv.
    """

    def __init__(self, input_dim, output_dim, stride, padding):
        super().__init__()

        main_path = [
            nn.BatchNorm2d(input_dim),
            nn.ReLU(),
            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=padding),
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
        ]
        self.conv_block = nn.Sequential(*main_path)

        shortcut = [
            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(output_dim),
        ]
        self.conv_skip = nn.Sequential(*shortcut)

    def forward(self, x):
        transformed = self.conv_block(x)
        skipped = self.conv_skip(x)
        return transformed + skipped


class Upsample(nn.Module):
    """Learned upsampling: a thin wrapper around ``nn.ConvTranspose2d``.

    Args:
        input_dim: Number of input channels.
        output_dim: Number of output channels.
        kernel: Kernel size of the transposed convolution.
        stride: Stride of the transposed convolution.
    """

    def __init__(self, input_dim, output_dim, kernel, stride):
        super().__init__()
        self.upsample = nn.ConvTranspose2d(
            in_channels=input_dim,
            out_channels=output_dim,
            kernel_size=kernel,
            stride=stride,
        )

    def forward(self, x):
        upsampled = self.upsample(x)
        return upsampled


class Squeeze_Excite_Block(nn.Module):
    """Channel re-weighting (squeeze-and-excitation) block.

    Each channel is globally average-pooled, passed through a two-layer
    bias-free bottleneck MLP ending in a sigmoid, and the resulting per-channel
    gate multiplies the input.

    Args:
        channel: Number of input (and output) channels.
        reduction: Bottleneck ratio; the hidden width is ``channel // reduction``.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        gate_layers = [
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*gate_layers)

    def forward(self, x):
        batch, channels, _height, _width = x.size()
        squeezed = self.avg_pool(x).view(batch, channels)
        gates = self.fc(squeezed).view(batch, channels, 1, 1)
        return x * gates.expand_as(x)


class ASPP(nn.Module):
    """Three parallel dilated 3x3 conv branches fused by a 1x1 conv.

    Each branch is conv (dilation = padding = one entry of ``rate``) -> ReLU ->
    BatchNorm. The branch outputs are concatenated on the channel axis and
    projected back to ``out_dims`` channels.

    Args:
        in_dims: Number of input channels.
        out_dims: Number of output channels of every branch and of the block.
        rate: Dilation rates; only the first three entries are used by the
            branches, while the fusion conv expects ``len(rate) * out_dims``
            input channels.
    """

    def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
        super().__init__()

        self.aspp_block1 = self._dilated_branch(in_dims, out_dims, rate[0])
        self.aspp_block2 = self._dilated_branch(in_dims, out_dims, rate[1])
        self.aspp_block3 = self._dilated_branch(in_dims, out_dims, rate[2])

        fused_channels = len(rate) * out_dims
        self.output = nn.Conv2d(fused_channels, out_dims, 1)
        self._init_weights()

    @staticmethod
    def _dilated_branch(in_dims, out_dims, dilation):
        """Build one conv -> ReLU -> BN branch with the given dilation."""
        return nn.Sequential(
            nn.Conv2d(in_dims, out_dims, 3, stride=1, padding=dilation, dilation=dilation),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_dims),
        )

    def forward(self, x):
        branches = (self.aspp_block1, self.aspp_block2, self.aspp_block3)
        stacked = torch.cat([branch(x) for branch in branches], dim=1)
        return self.output(stacked)

    def _init_weights(self):
        """Kaiming-normal init for conv weights; BN weight set to 1 and bias to 0."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                continue
            if isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()


class Upsample_(nn.Module):
    """Parameter-free bilinear upsampling by a fixed scale factor.

    Args:
        scale: Multiplier applied to the spatial dimensions (default 2).
    """

    def __init__(self, scale=2):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=scale, mode="bilinear")

    def forward(self, x):
        resized = self.upsample(x)
        return resized


class AttentionBlock(nn.Module):
    """Spatial attention that gates decoder features using encoder features.

    The encoder input goes through BN -> ReLU -> 3x3 conv -> 2x2 max-pool, the
    decoder input through BN -> ReLU -> 3x3 conv. Their sum is reduced to a
    single-channel map (BN -> ReLU -> 1x1 conv) which multiplies the decoder
    input.

    Args:
        input_encoder: Channels of the encoder feature map ``x1``.
        input_decoder: Channels of the decoder feature map ``x2``.
        output_dim: Channels of the intermediate attention features.
    """

    def __init__(self, input_encoder, input_decoder, output_dim):
        super().__init__()

        encoder_layers = [
            nn.BatchNorm2d(input_encoder),
            nn.ReLU(),
            nn.Conv2d(input_encoder, output_dim, 3, padding=1),
            nn.MaxPool2d(2, 2),
        ]
        self.conv_encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.BatchNorm2d(input_decoder),
            nn.ReLU(),
            nn.Conv2d(input_decoder, output_dim, 3, padding=1),
        ]
        self.conv_decoder = nn.Sequential(*decoder_layers)

        attention_layers = [
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, 1, 1),
        ]
        self.conv_attn = nn.Sequential(*attention_layers)

    def forward(self, x1, x2):
        encoder_feat = self.conv_encoder(x1)
        decoder_feat = self.conv_decoder(x2)
        attention = self.conv_attn(encoder_feat + decoder_feat)
        return attention * x2

In [None]:
import torch.nn as nn
import torch
# from core.modules import (
#     ResidualConv,
#     ASPP,
#     AttentionBlock,
#     Upsample_,
#     Squeeze_Excite_Block,
# )


class ResUnetPlusPlus(nn.Module):
    """ResUNet++ segmentation network returning raw per-class logits.

    Encoder: a stem followed by three squeeze-excite + stride-2 residual
    stages. Bridge: ASPP. Decoder: three attention -> bilinear upsample ->
    skip-concatenate -> residual stages, then an ASPP and a 1x1 classifier.

    Args:
        channel: Number of input image channels.
        filters: Channel widths of the five levels of the network.
        num_classes: Number of output channels (classes). Defaults to 4,
            the value that was previously hard-coded.

    Note:
        Assumes the input height and width are divisible by 8 so that the
        upsampled decoder maps line up with the skip connections — TODO confirm.
    """

    def __init__(self, channel, filters=[32, 64, 128, 256, 512], num_classes=4):
        super(ResUnetPlusPlus, self).__init__()

        # Stem: two 3x3 convs plus a single-conv shortcut that are summed in forward().
        self.input_layer = nn.Sequential(
            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
            nn.BatchNorm2d(filters[0]),
            nn.ReLU(),
            nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
        )
        self.input_skip = nn.Sequential(
            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
        )

        # Encoder stages (each residual conv uses stride 2).
        self.squeeze_excite1 = Squeeze_Excite_Block(filters[0])

        self.residual_conv1 = ResidualConv(filters[0], filters[1], 2, 1)

        self.squeeze_excite2 = Squeeze_Excite_Block(filters[1])

        self.residual_conv2 = ResidualConv(filters[1], filters[2], 2, 1)

        self.squeeze_excite3 = Squeeze_Excite_Block(filters[2])

        self.residual_conv3 = ResidualConv(filters[2], filters[3], 2, 1)

        self.aspp_bridge = ASPP(filters[3], filters[4])

        # Decoder stages: attention, x2 upsample, concat with skip, residual conv.
        self.attn1 = AttentionBlock(filters[2], filters[4], filters[4])
        self.upsample1 = Upsample_(2)
        self.up_residual_conv1 = ResidualConv(filters[4] + filters[2], filters[3], 1, 1)

        self.attn2 = AttentionBlock(filters[1], filters[3], filters[3])
        self.upsample2 = Upsample_(2)
        self.up_residual_conv2 = ResidualConv(filters[3] + filters[1], filters[2], 1, 1)

        self.attn3 = AttentionBlock(filters[0], filters[2], filters[2])
        self.upsample3 = Upsample_(2)
        self.up_residual_conv3 = ResidualConv(filters[2] + filters[0], filters[1], 1, 1)

        self.aspp_out = ASPP(filters[1], filters[0])

        # 1x1 classifier producing logits; no activation is applied here.
        self.output_layer = nn.Conv2d(filters[0], num_classes, 1)

    def forward(self, x):
        """Map an image batch ``x`` to a logit map with ``num_classes`` channels."""
        x1 = self.input_layer(x) + self.input_skip(x)

        x2 = self.squeeze_excite1(x1)
        x2 = self.residual_conv1(x2)

        x3 = self.squeeze_excite2(x2)
        x3 = self.residual_conv2(x3)

        x4 = self.squeeze_excite3(x3)
        x4 = self.residual_conv3(x4)

        x5 = self.aspp_bridge(x4)

        x6 = self.attn1(x3, x5)
        x6 = self.upsample1(x6)
        x6 = torch.cat([x6, x3], dim=1)
        x6 = self.up_residual_conv1(x6)

        x7 = self.attn2(x2, x6)
        x7 = self.upsample2(x7)
        x7 = torch.cat([x7, x2], dim=1)
        x7 = self.up_residual_conv2(x7)

        x8 = self.attn3(x1, x7)
        x8 = self.upsample3(x8)
        x8 = torch.cat([x8, x1], dim=1)
        x8 = self.up_residual_conv3(x8)

        x9 = self.aspp_out(x8)
        out = self.output_layer(x9)

        return out

In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Variable
from packaging import version

class CrossEntropy2d(nn.Module):
    """Pixel-wise cross entropy that skips pixels carrying an ignore label.

    Args:
        size_average: If True the loss is averaged over the kept pixels,
            otherwise it is summed.
        ignore_label: Target value marking pixels to exclude from the loss.
    """

    def __init__(self, size_average=True, ignore_label=255):
        super(CrossEntropy2d, self).__init__()
        self.size_average = size_average
        self.ignore_label = ignore_label

    def forward(self, predict, target, weight= None):
        """
            Args:
                predict:(n, c, h, w)
                target:(n, h, w)
                weight (Tensor, optional): a manual rescaling weight given to each class.
                                           If given, has to be a Tensor of size "nclasses"
            Returns:
                Scalar loss tensor. When every pixel is ignored (or negative)
                a zero that is still attached to ``predict``'s graph is
                returned, so ``backward()`` remains valid.
        """
        assert not target.requires_grad
        assert predict.dim() == 4
        assert target.dim() == 3
        assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0))
        assert predict.size(2) == target.size(1), "{0} vs {1} ".format(predict.size(2), target.size(1))
        # target is 3-D, so its width is size(2); size(3) would raise IndexError
        # while building the message and hide the real assertion failure.
        assert predict.size(3) == target.size(2), "{0} vs {1} ".format(predict.size(3), target.size(2))

        target_mask = (target >= 0) & (target != self.ignore_label)
        target = target[target_mask]
        if target.numel() == 0:
            # Nothing to score: mean-reduced cross entropy over an empty set is NaN.
            return predict.sum() * 0.0

        # Move classes last, then keep only the rows of valid pixels -> (k, c).
        predict = predict.permute(0, 2, 3, 1)[target_mask]
        reduction = "mean" if self.size_average else "sum"
        loss = F.cross_entropy(predict, target, weight=weight, reduction=reduction)
        return loss


class BCEWithLogitsLoss2d(nn.Module):
    """Binary cross entropy on logits that skips pixels carrying an ignore label.

    Args:
        size_average: If True the loss is averaged over the kept pixels,
            otherwise it is summed.
        ignore_label: Target value marking pixels to exclude from the loss.
    """

    def __init__(self, size_average=True, ignore_label=255):
        super(BCEWithLogitsLoss2d, self).__init__()
        self.size_average = size_average
        self.ignore_label = ignore_label

    def forward(self, predict, target, weight=None):
        """
            Args:
                predict:(n, 1, h, w)
                target:(n, 1, h, w)
                weight (Tensor, optional): a manual rescaling weight given to each class.
                                           If given, has to be a Tensor of size "nclasses"
            Returns:
                Scalar loss tensor. When every pixel is ignored (or negative)
                a zero that is still attached to ``predict``'s graph is
                returned, so ``backward()`` remains valid.
        """
        assert not target.requires_grad
        assert predict.dim() == 4
        assert target.dim() == 4
        assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0))
        assert predict.size(2) == target.size(2), "{0} vs {1} ".format(predict.size(2), target.size(2))
        assert predict.size(3) == target.size(3), "{0} vs {1} ".format(predict.size(3), target.size(3))

        target_mask = (target >= 0) & (target != self.ignore_label)
        target = target[target_mask]
        if target.numel() == 0:
            # Nothing to score: a mean over zero elements would be NaN.
            return predict.sum() * 0.0

        predict = predict[target_mask]
        reduction = "mean" if self.size_average else "sum"
        loss = F.binary_cross_entropy_with_logits(predict, target, weight=weight, reduction=reduction)
        return loss

In [None]:
# Hyper-parameter dictionary.
# In the visible notebook only LEARNING_RATE, LEARNING_RATE_D, NUM_STEPS, POWER
# and NUM_CLASSES are read (by the lr-schedule helpers and one_hot).
# NOTE(review): the remaining keys are not referenced by the visible code —
# confirm whether they are still needed.
args = {'IMG_MEAN' : np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32),  #calculate the mean of RGB in real images of train data
        'MODEL' : 'ResUnetPlusPlus' , 'BATCH_SIZE' : 1,
'ITER_SIZE' : 1,  'NUM_WORKERS' : 4, 
'IGNORE_LABEL' : 255 , 'INPUT_SIZE' : '512,512' , 'LEARNING_RATE' : 2.5e-4 , 'MOMENTUM' : 0.9,  'NUM_CLASSES' : 4, 
'NUM_STEPS' : 20000, 
'POWER' : 0.9,
'RANDOM_SEED' : 1234,
'SAVE_NUM_IMAGES' : 2,
'SAVE_PRED_EVERY' : 5000,
'SNAPSHOT_DIR' : './snapshots/',
'WEIGHT_DECAY' : 0.0005,

'LEARNING_RATE_D' : 1e-4,
'LAMBDA_ADV_PRED' : 0.1,
'PARTIAL_DATA' : None, #0.5,

'SEMI_START' : 5000,
'LAMBDA_SEMI' : 0.1,
'MASK_T' : 0.2,

'LAMBDA_SEMI_ADV':0.001,
'SEMI_START_ADV' : 0,
'D_REMAIN' : False, 'GPU':True,
       'RESTORE_FROM' : 'http://vllab1.ucmerced.edu/~whung/adv-semi-seg/resnet101COCO-41f33a49.pth'}

In [None]:
# `copy_reg` is the Python 2 name of the module; Python 3 ships it as `copyreg`.
# Catch only ImportError so unrelated failures (e.g. KeyboardInterrupt) are not swallowed.
try:
    import copy_reg
except ImportError:
    import copyreg as copy_reg

In [None]:
def pixel_accuracy(output, mask):
    """Fraction of pixels whose arg-max class equals the label.

    Args:
        output: Logits with the class scores on dimension 1.
        mask: Labels with the shape of ``output`` minus the class dimension.

    Returns:
        Python float in [0, 1].
    """
    with torch.no_grad():
        predicted = F.softmax(output, dim=1).argmax(dim=1)
        hits = torch.eq(predicted, mask).int()
        n_correct = float(hits.sum())
        n_total = float(hits.numel())
        accuracy = n_correct / n_total
    return accuracy

def mIoU(pred_mask, mask, smooth=1e-10, n_classes=5):
    """Mean intersection-over-union across the classes present in ``mask``.

    Args:
        pred_mask: Logits with the class scores on dimension 1.
        mask: Ground-truth labels.
        smooth: Small constant added to numerator and denominator.
        n_classes: Class ids ``0 .. n_classes-1`` are evaluated.

    Returns:
        The NaN-ignoring mean of the per-class IoU values; classes with no
        ground-truth pixel contribute NaN and are therefore skipped.
    """
    with torch.no_grad():
        predicted = torch.argmax(F.softmax(pred_mask, dim=1), dim=1).contiguous().view(-1)
        target = mask.contiguous().view(-1)

        scores = []
        for class_id in range(n_classes):
            in_prediction = predicted == class_id
            in_target = target == class_id

            if in_target.long().sum().item() == 0:
                # Class absent from the labels: excluded from the mean.
                scores.append(np.nan)
                continue

            overlap = torch.logical_and(in_prediction, in_target).sum().float().item()
            combined = torch.logical_or(in_prediction, in_target).sum().float().item()
            scores.append((overlap + smooth) / (combined + smooth))
        return np.nanmean(scores)
    
def map_thiss(y_pred,class_map):
    """Colourise a batch of class-index maps.

    Replaces the former per-pixel Python loop with a single palette lookup.

    Args:
        y_pred: Array-like or torch tensor of class indices, e.g. shape
            (batch, H, W). Values are truncated to integers. Tensors on any
            device are accepted and copied to the CPU.
        class_map: Sequence of colours; position ``i`` is the colour of class ``i``.

    Returns:
        float64 array shaped like ``y_pred`` with a trailing colour axis
        (length 3 for RGB triples).

    Raises:
        IndexError: If an index lies outside ``class_map``.
    """
    if isinstance(y_pred, torch.Tensor):
        y_pred = y_pred.detach().cpu().numpy()
    indices = np.asarray(y_pred).astype(np.int64)  # truncates like int()
    palette = np.asarray(class_map, dtype=np.float64)
    # Integer-array indexing maps every class index to its colour row at once.
    return palette[indices]

def plot_result(img, title):
    """Display the first image of a batch in a 12x6 figure.

    Args:
        img: Indexable batch of images; only ``img[0]`` is shown.
        title: Figure title.
    """
    first_image = img[0]
    plt.figure(figsize=(12, 6))
    plt.title(title)
    plt.imshow(first_image)
    plt.show()
    
def export_model(model, optimizer=None, name=None, step=None):
    """Save a checkpoint as ``./<name>[_step_<step>].pth``.

    Args:
        model: Module whose ``state_dict()`` is stored under ``"model_state_dict"``.
        optimizer: Optional optimizer; its state is stored under
            ``"optimizer_state_dict"`` when given.
        name: File name stem; ``"checkpoint"`` when omitted.
        step: Optional step/epoch number; appended to the file name and stored
            under ``"step"`` when given.
    """
    # Build the output file name.
    stem = "checkpoint" if name is None else name
    if step is not None:
        stem += "_step_" + str(step)
    out_file = os.path.join("./", stem + ".pth")

    # Assemble the checkpoint payload and write it.
    payload = {"model_state_dict": model.state_dict()}
    if step is not None:
        payload["step"] = step
    if optimizer is not None:
        payload["optimizer_state_dict"] = optimizer.state_dict()
    torch.save(payload, out_file)

In [None]:
# Re-declared here so this cell can run on its own: GPU when available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def loss_calc(pred, label, gpu):
    """
    This function returns cross entropy loss for semantic segmentation

    Args:
        pred: Logits of shape (batch, classes, h, w).
        label: Class labels of shape (batch, h, w); cast to long here.
        gpu: Unused; kept so existing callers keep working.

    Returns:
        The ``CrossEntropy2d`` loss with its default settings.
    """
    # Follow the device of the prediction instead of calling .cuda(), which
    # raises when the notebook has fallen back to the CPU.
    label = label.long().to(pred.device)
    # CrossEntropy2d holds no parameters or buffers, so it needs no device move.
    criterion = CrossEntropy2d()

    return criterion(pred, label)


def lr_poly(base_lr, iter, max_iter, power):
    """Polynomial decay: ``base_lr * (1 - iter / max_iter) ** power``."""
    remaining_fraction = 1 - float(iter) / max_iter
    return base_lr * (remaining_fraction ** power)


def adjust_learning_rate(optimizer, i_iter):
    """Set the poly-decayed ``LEARNING_RATE`` for iteration ``i_iter``.

    The first parameter group receives the decayed rate; when a second group
    exists it receives ten times that rate.
    """
    groups = optimizer.param_groups
    new_lr = lr_poly(args["LEARNING_RATE"], i_iter, args["NUM_STEPS"], args["POWER"])
    groups[0]['lr'] = new_lr
    if len(optimizer.param_groups) > 1:
        groups[1]['lr'] = new_lr * 10

def adjust_learning_rate_D(optimizer, i_iter):
    """Set the poly-decayed ``LEARNING_RATE_D`` for iteration ``i_iter``.

    The first parameter group receives the decayed rate; when a second group
    exists it receives ten times that rate.
    """
    groups = optimizer.param_groups
    new_lr = lr_poly(args["LEARNING_RATE_D"], i_iter, args["NUM_STEPS"], args["POWER"])
    groups[0]['lr'] = new_lr
    if len(optimizer.param_groups) > 1:
        groups[1]['lr'] = new_lr * 10

def one_hot(label):
    """One-hot encode a (batch, h, w) label tensor into (batch, NUM_CLASSES, h, w).

    Labels outside ``0 .. NUM_CLASSES-1`` (such as an ignore value) yield an
    all-zero vector for that pixel.

    Returns:
        ``torch.FloatTensor`` holding the encoding.
    """
    label = label.numpy()
    class_total = args["NUM_CLASSES"]
    encoded_shape = (label.shape[0], class_total, label.shape[1], label.shape[2])
    encoded = np.zeros(encoded_shape, dtype=label.dtype)
    for class_id in range(class_total):
        encoded[:, class_id, ...] = (label == class_id)
    #handle ignore labels
    return torch.FloatTensor(encoded)

def make_D_label(label, ignore_mask):
    """Build a target map filled with ``label``, with 255 at ignored positions.

    Args:
        label: Scalar value to fill the map with.
        ignore_mask: Array of shape (batch, h, w) used to index the map;
            presumably boolean, True where the pixel is ignored — TODO confirm.

    Returns:
        Float tensor of shape (batch, 1, h, w) on the module-level ``device``.
    """
    ignore_mask = np.expand_dims(ignore_mask, axis=1)
    D_label = np.ones(ignore_mask.shape)*label
    D_label[ignore_mask] = 255
    # .to(device) also works on the CPU fallback, where .cuda(device) raises.
    D_label = torch.FloatTensor(D_label).to(device)

    return D_label

In [None]:
def get_lr(optimizer):
    """Return the learning rate of the first parameter group (None if there is none)."""
    _missing = object()
    first_group = next(iter(optimizer.param_groups), _missing)
    if first_group is _missing:
        return None
    return first_group['lr']

def fit(epochs, model, train_loader, val_loader, criterion, optimizer, scheduler, patch=False):
    """Train ``model`` and validate it after every epoch.

    A checkpoint is exported (via ``export_model``) whenever the mean
    validation loss improves on the best value seen so far.

    Args:
        epochs: Number of epochs to run.
        model: Network to train; moved to the module-level ``device``.
        train_loader: Iterable of ``(image, mask)`` training batches.
        val_loader: Iterable of ``(image, mask)`` validation batches.
        criterion: Unused — the loss is computed by ``loss_calc``; kept for
            interface compatibility.
        optimizer: Optimizer stepped once per training batch.
        scheduler: LR scheduler stepped once per training batch.
        patch: Unused; kept for interface compatibility.

    Returns:
        dict with per-epoch lists ``train_loss``, ``val_loss``, ``train_miou``,
        ``val_miou``, ``train_acc``, ``val_acc`` and the per-step list ``lrs``.
    """
    torch.cuda.empty_cache()
    train_losses = []
    test_losses = []
    val_iou = []; val_acc = []
    train_iou = []; train_acc = []
    lrs = []
    min_loss = np.inf  # best (lowest) mean validation loss seen so far
    not_improve = 0

    model.to(device)
    fit_time = time.time()
    for e in range(epochs):
        since = time.time()
        running_loss = 0
        iou_score = 0
        accuracy = 0
        #training loop
        model.train()
        for i, data in enumerate(tqdm(train_loader)):
            #training phase
            image_tiles, mask_tiles = data
            image = image_tiles.to(device, dtype = torch.float)
            mask = mask_tiles.to(device, dtype= torch.float)
            #forward
            output = model(image)
            loss = loss_calc(output, mask, device)
            #evaluation metrics
            iou_score += mIoU(output, mask)
            accuracy += pixel_accuracy(output, mask)
            #backward
            loss.backward()
            optimizer.step() #update weight
            optimizer.zero_grad() #reset gradient

            #step the learning rate
            lrs.append(get_lr(optimizer))
            scheduler.step()

            running_loss += loss.item()

        model.eval()
        test_loss = 0
        test_accuracy = 0
        val_iou_score = 0
        #validation loop
        with torch.no_grad():
            for i, data in enumerate(tqdm(val_loader)):
                image_tiles, mask_tiles = data

                image = image_tiles.to(device, dtype = torch.float)
                mask = mask_tiles.to(device, dtype = torch.float)
                output = model(image)
                #evaluation metrics
                if(i==1):
                    # Visual check on the second validation batch of the epoch.
                    output_soft = F.softmax(output, dim=1)
                    output_num = output_soft.cpu().detach().numpy()
                    pred_mask = np.argmax(output_num, axis = 1)

                    y_pred_rgb = map_thiss(pred_mask,class_map)
                    # Hand map_thiss a CPU array rather than a device tensor.
                    y_test_rgb = map_thiss(mask.cpu().numpy(),class_map)
                    plot_result(y_test_rgb,"Original Masks")
                    plot_result(y_pred_rgb,"Predicted Masks")
                val_iou_score +=  mIoU(output, mask)
                test_accuracy += pixel_accuracy(output, mask)
                #loss
                loss = loss_calc(output, mask, device)
                test_loss += loss.item()

        # Mean loss per batch for this epoch.
        epoch_train_loss = running_loss/len(train_loader)
        epoch_val_loss = test_loss/len(val_loader)
        train_losses.append(epoch_train_loss)
        test_losses.append(epoch_val_loss)

        if min_loss > epoch_val_loss:
            print('Loss Decreasing.. {:.3f} >> {:.3f} '.format(min_loss, epoch_val_loss))
            min_loss = epoch_val_loss
            print('saving model...')
            export_model(model, optimizer=optimizer, name="final", step = e)
        elif epoch_val_loss > min_loss:
            # Keep min_loss untouched: overwriting it with a worse value would
            # let a later, still-worse-than-best epoch count as an improvement
            # and overwrite the checkpoint selection.
            not_improve += 1
            print(f'Loss Not Decrease for {not_improve} time')

        #iou
        val_iou.append(val_iou_score/len(val_loader))
        train_iou.append(iou_score/len(train_loader))
        train_acc.append(accuracy/len(train_loader))
        val_acc.append(test_accuracy/ len(val_loader))
        print("Epoch:{}/{}..".format(e+1, epochs),
              "Train Loss: {:.3f}..".format(epoch_train_loss),
              "Val Loss: {:.3f}..".format(epoch_val_loss),
              "Train mIoU:{:.3f}..".format(iou_score/len(train_loader)),
              "Val mIoU: {:.3f}..".format(val_iou_score/len(val_loader)),
              "Train Acc:{:.3f}..".format(accuracy/len(train_loader)),
              "Val Acc:{:.3f}..".format(test_accuracy/len(val_loader)),
              "Time: {:.2f}m".format((time.time()-since)/60))

    history = {'train_loss' : train_losses, 'val_loss': test_losses,
               'train_miou' :train_iou, 'val_miou':val_iou,
               'train_acc' :train_acc, 'val_acc':val_acc,
               'lrs': lrs}
    print('Total time: {:.2f} m' .format((time.time()- fit_time)/60))
    return history

In [None]:
# CUDA_LAUNCH_BLOCKING=1

In [None]:
# Training configuration: peak learning rate for the one-cycle schedule,
# number of epochs, and AdamW weight decay.
max_lr = 1e-3
epoch = 50
weight_decay = 1e-4

# 3-channel (RGB) input network.
model = ResUnetPlusPlus(3)
# NOTE(review): `criterion` is passed to fit(), whose visible definition
# computes the loss through loss_calc() instead — confirm which loss is intended.
criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.AdamW(model.parameters(), lr=max_lr, weight_decay=weight_decay)
# The scheduler is stepped once per training batch inside fit().
sched = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, epochs=epoch,
                                            steps_per_epoch=len(train_loader))



history = fit(epoch, model, train_loader, val_loader, criterion, optimizer, sched)
# history = fit(epoch, model, val_loader, val_loader, criterion, optimizer, sched)

In [None]:
# max_lr = 1e-3
# epoch = 5
# weight_decay = 1e-4

# model = ResUnetPlusPlus(3)
# criterion = nn.CrossEntropyLoss()

# state = torch.load('../input/weights-file/final_step_45.pth', map_location = 'cpu')
# model.load_state_dict(state['model_state_dict'])

# model.cuda()


# optimizer = torch.optim.AdamW(model.parameters())

# optimizer.load_state_dict(state['optimizer_state_dict'])
# sched = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, epochs=epoch,
#                                             steps_per_epoch=len(train_loader))



# history = fit(epoch, model, train_loader, val_loader, criterion, optimizer, sched)

In [None]:
def plot_loss(history):
    """Plot validation and training loss per epoch from a ``fit`` history dict."""
    for key, series_label in (('val_loss', 'val'), ('train_loss', 'train')):
        plt.plot(history[key], label=series_label, marker='o')
    plt.title('Loss per epoch')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend()
    plt.grid()
    plt.show()
    
def plot_score(history):
    """Plot training and validation mean IoU per epoch from a ``fit`` history dict."""
    for key, series_label in (('train_miou', 'train_mIoU'), ('val_miou', 'val_mIoU')):
        plt.plot(history[key], label=series_label, marker='*')
    plt.title('Score per epoch')
    plt.ylabel('mean IoU')
    plt.xlabel('epoch')
    plt.legend()
    plt.grid()
    plt.show()
    
def plot_acc(history):
    """Plot training and validation pixel accuracy per epoch from a ``fit`` history dict."""
    for key, series_label in (('train_acc', 'train_accuracy'), ('val_acc', 'val_accuracy')):
        plt.plot(history[key], label=series_label, marker='*')
    plt.title('Accuracy per epoch')
    plt.ylabel('Accuracy')
    plt.xlabel('epoch')
    plt.legend()
    plt.grid()
    plt.show()

In [None]:
# Visualise the training run: loss, mean IoU and pixel accuracy per epoch.
plot_loss(history)
plot_score(history)
plot_acc(history)

In [None]:
# Persist the metric history so the curves can be re-plotted without retraining.
with open('./history.pickle', 'wb') as handle:
    pickle.dump(history, handle, protocol=pickle.HIGHEST_PROTOCOL)