In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from datetime import datetime
from datetime import date 
from datetime import time
import cv2
import glob
import os 
from pathlib import Path
import random

from PIL import Image
import numpy as np
import pandas  as pd
import matplotlib.pyplot as plt

import SimpleITK as sitk
from torchsummary import summary

from sklearn.model_selection import train_test_split 
from sklearn.utils import shuffle
import skimage
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import datasets,transforms, models
import torchvision.transforms.functional as TF
import albumentations as A
from albumentations.pytorch import ToTensorV2

## Helper Functions & Dataset Preparations

In [33]:
# Dataset that serves pre-loaded (image, mask) pairs for segmentation training
class SegDataset(Dataset):
    """Segmentation dataset over pre-loaded slices.

    Each entry of ``filename_slices`` is a 5-tuple
    ``(label_path, label_img, gt_path, gt_img, idx)`` as produced by
    ``read_slices``; only ``gt_img`` (network input) and ``label_img`` (mask)
    are used here.

    Args:
        filename_slices: sequence of the 5-tuples described above.
        transform: ``None`` or a callable invoked as
            ``transform(image=..., mask=...)`` returning a dict with
            ``'image'`` and ``'mask'`` keys (albumentations-style).
    """

    def __init__(self, filename_slices, transform):
        super().__init__()
        self._filename_slices = filename_slices
        self._transform = transform

    def __getitem__(self, index):
        """Return ``(image, mask)`` for slice ``index``.

        The image is min-max rescaled to [0, 255] and cast to float32 before
        the transform. A constant image (max == min) becomes all zeros instead
        of the NaNs a 0/0 division would produce.
        """
        _, label_img, _, gt_img, _ = self._filename_slices[index]

        # Convert once, then rescale intensities to [0, 255].
        raw = np.asarray(gt_img)
        lo, hi = raw.min(), raw.max()
        if hi > lo:
            image = (raw - lo) / (hi - lo) * 255
        else:
            image = np.zeros(raw.shape)
        image = np.array(image, dtype=np.float32)

        mask = np.array(label_img)

        # Data augmentation (applied jointly to image and mask)
        if self._transform is not None:
            augmented = self._transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return image, mask

    def __len__(self):
        return len(self._filename_slices)

# Dataset that serves pre-loaded samples for classification training
class ClassDataset(Dataset):
    """Classification dataset over pre-loaded samples.

    Each entry of ``filename_slices`` is a 4-tuple
    ``(gt_img, label_img, gt_path, labelvalue)`` as built by the data-loading
    cell.

    Args:
        filename_slices: sequence of the 4-tuples described above.
        transform: ``None`` or a callable invoked as ``transform(image=...)``
            returning a dict with an ``'image'`` key (albumentations-style).
    """

    def __init__(self, filename_slices, transform):
        super().__init__()
        self._filename_slices = filename_slices
        self._transform = transform

    def __getitem__(self, index):
        """Return ``(gt_img, label_img, gt_path, labelvalue)`` for sample ``index``."""
        entry = self._filename_slices[index]
        gt_img, label_img, gt_path, labelvalue = entry[0], entry[1], entry[2], entry[3]

        # Data augmentation. NOTE: the two images go through separate transform
        # calls, so any random augmentation is drawn independently for each.
        if self._transform is not None:
            label_img = self._transform(image=np.array(label_img))['image']
            gt_img = self._transform(image=np.array(gt_img))['image']

        return gt_img, label_img, gt_path, labelvalue

    def __len__(self):
        return len(self._filename_slices)

def dice_coef(y_true, y_pred, smooth=0.0001):
    """Smoothed Dice coefficient between two masks.

    Dice = (2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth), computed on
    the flattened inputs.

    Args:
        y_true, y_pred: numpy arrays or torch tensors (any device) with the same
            number of elements; intended for binary (0/1) masks.
        smooth: additive constant keeping the ratio defined (equal to 1) when
            both masks are empty.

    Returns:
        numpy float64 scalar.
    """
    def _to_flat_float(x):
        # Move tensors to host; promote to float64 so integer inputs
        # (e.g. uint8) cannot wrap around in the element-wise product.
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
        return np.asarray(x, dtype=np.float64).ravel()

    t = _to_flat_float(y_true)
    p = _to_flat_float(y_pred)
    intersection = np.sum(t * p)
    return (2.0 * intersection + smooth) / (np.sum(t) + np.sum(p) + smooth)


def read_slices(all_data):
    """Load (mask, image) pairs for segmentation training.

    For every mask path in ``all_data`` this reads the mask, binarises it to
    {0, 1} (values >= 128 become 1), and reads the matching image from
    ``240618_data_3ch/``.

    Args:
        all_data: iterable of mask paths (str or Path).

    Returns:
        List of ``(label_path, label_img, gt_path, gt_img, idx)`` tuples, where
        ``label_path`` is the full (normalised) input path string and ``idx`` its
        position in ``all_data``.
    """
    filename_slices = []
    for idx, fn_mask in enumerate(all_data):
        label_path = str(Path(fn_mask))

        # The first 5 characters of the stored path are dropped before opening
        # (presumably a root prefix -- TODO confirm). `with` closes the file
        # handle; np.array() copies the pixels out first.
        with Image.open(Path(label_path[5:])) as mask_file:
            label_img = np.array(mask_file)
        # Binarise to {0, 1}: used as class indices for CrossEntropyLoss.
        label_img[label_img < 128] = 0
        label_img[label_img >= 128] = 1

        # NOTE(review): the data-loading cell derives the image name by dropping
        # 5 + 17 = 22 characters, while this drops 21 -- confirm which is right.
        img_need = label_path[21:]
        gt_path = '240618_data_3ch/' + img_need
        with Image.open(gt_path) as gt_file:
            gt_img = np.array(gt_file)

        filename_slices.append((label_path, label_img, gt_path, gt_img, idx))
    return filename_slices

In [None]:
def _load_class_samples(filenames, labels_files, labels_co):
    """Load ``(gt_img, label_img, gt_path, label)`` tuples for the classifier.

    Each filename is stringified and its first 5 characters dropped to get the
    mask path. The label is looked up by that path in ``labels_files`` /
    ``labels_co`` and shifted by -1 (presumably 1-based in the CSV -- TODO
    confirm). The mask is binarised to {0, 255}. The image is read from
    ``240618_data_3ch/`` plus the mask path with its first 17 characters removed.

    Raises:
        KeyError: if a mask path has no row in the label table.
    """
    samples = []
    for filename in filenames:
        label_path = str(filename)[5:]
        labels = labels_co[labels_files[labels_files == label_path].index] - 1
        if len(labels) == 0:
            raise KeyError(f'no label found for {label_path!r}')

        # `with` closes the files even if reading fails.
        with Image.open(label_path) as label_file:
            label_img = np.array(label_file)
        label_img[label_img < 128] = 0
        label_img[label_img >= 128] = 255

        gt_path = '240618_data_3ch/' + label_path[17:]
        with Image.open(gt_path) as gt_file:
            gt_img = np.array(gt_file)

        samples.append((gt_img, label_img, gt_path, labels.values[0]))
    return samples


# NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True; if this
# trusted file stores non-tensor objects (e.g. Path), pass weights_only=False.
file_ck = torch.load('filenames.pth')
tt = file_ck['tt']
test_filenames = file_ck['testfiles']

# Set to an int for a reproducible train/validation split (None = new split each run).
SPLIT_SEED = None
train_filenames, val_filenames = train_test_split(tt, test_size=0.3, random_state=SPLIT_SEED)
print("Number of images (Before Image Augmentation) in Training set : ",len(train_filenames),"  Number of images in Testing set : ",len(test_filenames))
print("Number of images (Before Image Augmentation) in Validation set : ",len(val_filenames))

# Column 0: mask path, column 1: class label.
label_all = pd.read_csv('test_1ch.csv', header=None)
labels_files = label_all[0]
labels_co = label_all[1]

# Classification samples (gt_img, label_img, gt_path, label).
filename_slices = _load_class_samples(test_filenames, labels_files, labels_co)

train_filenamesnew = [str(f)[5:] for f in train_filenames]
val_filenamesnew = [str(f)[5:] for f in val_filenames]

trainfiles = _load_class_samples(train_filenames, labels_files, labels_co)
valfiles = _load_class_samples(val_filenames, labels_files, labels_co)

# Segmentation samples (label_path, label_img, gt_path, gt_img, idx).
train_file_labels = read_slices(train_filenames)
val_file_labels = read_slices(val_filenames)
test_file_labels = read_slices(test_filenames)

## Segmentation Network Architecture

### UNet

In [29]:
class contracting(nn.Module):
    """U-Net encoder: five double-conv stages separated by 2x2 max-pooling.

    Takes (N, 3, H, W) input. ``forward`` returns the stage outputs
    deepest-first as ``(X5, X4, X3, X2, X1)``, with 1024/512/256/128/64 channels
    at H/16, H/8, H/4, H/2 and H respectively.
    """

    def __init__(self):
        super().__init__()

        def double_conv(c_in, c_out):
            # conv3x3 -> ReLU -> conv3x3 -> ReLU; padding=1 keeps the spatial size.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, 3, stride=1, padding=1), nn.ReLU(),
                nn.Conv2d(c_out, c_out, 3, stride=1, padding=1), nn.ReLU(),
            )

        self.layer1 = double_conv(3, 64)
        self.layer2 = double_conv(64, 128)
        self.layer3 = double_conv(128, 256)
        self.layer4 = double_conv(256, 512)
        self.layer5 = double_conv(512, 1024)
        self.down_sample = nn.MaxPool2d(2, stride=2)

    def forward(self, X):
        features = [self.layer1(X)]
        for stage in (self.layer2, self.layer3, self.layer4, self.layer5):
            features.append(stage(self.down_sample(features[-1])))
        return tuple(reversed(features))

class expansive(nn.Module):
    """U-Net decoder matching ``contracting``.

    Each of four stages upsamples with a 2x2 transposed conv, concatenates the
    skip feature map on the channel axis, and applies a double conv. A final
    3x3 conv maps 64 channels to 2 output channels.
    """

    def __init__(self):
        super().__init__()

        def double_conv(c_in, c_out):
            # conv3x3 -> ReLU -> conv3x3 -> ReLU; padding=1 keeps the spatial size.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, 3, stride=1, padding=1), nn.ReLU(),
                nn.Conv2d(c_out, c_out, 3, stride=1, padding=1), nn.ReLU(),
            )

        self.layer1 = nn.Conv2d(64, 2, 3, stride=1, padding=1)
        self.layer2 = double_conv(128, 64)
        self.layer3 = double_conv(256, 128)
        self.layer4 = double_conv(512, 256)
        self.layer5 = double_conv(1024, 512)
        self.up_sample_54 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.up_sample_43 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.up_sample_32 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.up_sample_21 = nn.ConvTranspose2d(128, 64, 2, stride=2)

    def forward(self, X5, X4, X3, X2, X1):
        stages = (
            (self.up_sample_54, X4, self.layer5),
            (self.up_sample_43, X3, self.layer4),
            (self.up_sample_32, X2, self.layer3),
            (self.up_sample_21, X1, self.layer2),
        )
        out = X5
        for upsample, skip, fuse in stages:
            out = fuse(torch.cat([upsample(out), skip], dim=1))
        return self.layer1(out)

class unet(nn.Module):
    """Plain U-Net: a ``contracting`` encoder feeding an ``expansive`` decoder.

    Maps (N, 3, H, W) images to a 2-channel (N, 2, H, W) output.
    """

    def __init__(self):
        super().__init__()
        self.down = contracting()
        self.up = expansive()

    def forward(self, X):
        # The encoder returns (X5, X4, X3, X2, X1), exactly the decoder's arguments.
        return self.up(*self.down(X))

### MnUV3 (MobileNetV3-style encoder + U-Net-style decoder)

In [30]:
class hswish(nn.Module):
    """Hard-swish activation: ``x * relu6(x + 3) / 6``."""

    def forward(self, x):
        # relu6 runs in place on the temporary (x + 3), never on x itself.
        gate = F.relu6(x + 3, inplace=True)
        return x * gate / 6

class hsigmoid(nn.Module):
    """Hard-sigmoid activation: ``relu6(x + 3) / 6``, with output in [0, 1]."""

    def forward(self, x):
        # relu6 runs in place on the temporary (x + 3), never on x itself.
        return F.relu6(x + 3, inplace=True).div(6)

class SeModule(nn.Module):
    """Squeeze-and-excitation channel gate.

    The input is global-average-pooled and passed through a 1x1-conv bottleneck
    with ``max(in_size // reduction, 8)`` channels (BatchNorm + ReLU). It is
    projected back to ``in_size`` channels with a hard-sigmoid, and the result
    rescales the input channel-wise.
    """

    def __init__(self, in_size, reduction=4):
        super(SeModule, self).__init__()
        hidden = max(in_size // reduction, 8)
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),                                   # squeeze to (N, C, 1, 1)
            nn.Conv2d(in_size, hidden, kernel_size=1, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, in_size, kernel_size=1, bias=False),
            nn.Hardsigmoid(),                                          # gate in [0, 1]
        )

    def forward(self, x):
        gate = self.se(x)
        return x * gate

class Block(nn.Module):
    """Inverted-residual bottleneck: 1x1 expand -> depthwise kxk -> optional SE -> 1x1 project.

    Args:
        kernel_size: depthwise kernel size (padding ``kernel_size // 2``).
        in_size, expand_size, out_size: input / hidden / output channel counts.
        act: activation class, instantiated as ``act(inplace=True)``.
        se: if truthy, apply ``SeModule(expand_size)`` after the depthwise conv.
        stride: stride of the depthwise conv.

    The shortcut is the identity when ``stride == 1`` and the channel counts
    match. For ``stride == 1`` with different channels, or ``stride == 2``, a
    projection is built. No projection is built for other strides.
    """

    def __init__(self, kernel_size, in_size, expand_size, out_size, act, se, stride):
        super(Block, self).__init__()
        self.stride = stride

        # 1x1 pointwise expansion: in_size -> expand_size
        self.conv1 = nn.Conv2d(in_size, expand_size, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(expand_size)
        self.act1 = act(inplace=True)

        # depthwise kxk conv (groups == channels); carries the block's stride
        self.conv2 = nn.Conv2d(expand_size, expand_size, kernel_size=kernel_size, stride=stride,
                               padding=kernel_size // 2, groups=expand_size, bias=False)
        self.bn2 = nn.BatchNorm2d(expand_size)
        self.act2 = act(inplace=True)
        self.se = SeModule(expand_size) if se else nn.Identity()

        # 1x1 pointwise projection: expand_size -> out_size (activated after the residual add)
        self.conv3 = nn.Conv2d(expand_size, out_size, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_size)
        self.act3 = act(inplace=True)

        self.skip = None
        if stride == 1:
            if in_size != out_size:
                # channel projection only
                self.skip = nn.Sequential(
                    nn.Conv2d(in_size, out_size, kernel_size=1, bias=False),
                    nn.BatchNorm2d(out_size)
                )
        elif stride == 2:
            if in_size != out_size:
                # depthwise 3x3 stride-2 downsample, then 1x1 channel projection
                self.skip = nn.Sequential(
                    nn.Conv2d(in_channels=in_size, out_channels=in_size, kernel_size=3, groups=in_size, stride=2, padding=1, bias=False),
                    nn.BatchNorm2d(in_size),
                    nn.Conv2d(in_size, out_size, kernel_size=1, bias=True),
                    nn.BatchNorm2d(out_size)
                )
            else:
                # depthwise 3x3 stride-2 downsample only
                self.skip = nn.Sequential(
                    nn.Conv2d(in_channels=in_size, out_channels=out_size, kernel_size=3, groups=in_size, stride=2, padding=1, bias=False),
                    nn.BatchNorm2d(out_size)
                )

    def forward(self, x):
        h = self.act1(self.bn1(self.conv1(x)))
        h = self.act2(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(self.se(h)))
        shortcut = x if self.skip is None else self.skip(x)
        return self.act3(h + shortcut)

class expansive1(nn.Module):
    """Decoder for ``MnUV3``: four (upsample -> concat skip -> double conv) stages + 3x3 head.

    The channel counts imply the expected inputs: X5=960 (deepest), X4=476,
    X3=64, X2=16, X1=3, each at twice the spatial size of the previous one.
    Output: 2 channels at the spatial size of X1.
    """

    def __init__(self):
        super().__init__()

        def double_conv(c_in, c_out):
            # conv3x3 -> ReLU -> conv3x3 -> ReLU; padding=1 keeps the spatial size.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, 3, stride=1, padding=1), nn.ReLU(),
                nn.Conv2d(c_out, c_out, 3, stride=1, padding=1), nn.ReLU(),
            )

        # Upsample output channels + skip channels = next double_conv input (e.g. 484 + 476 = 960).
        self.up_sample_54 = nn.ConvTranspose2d(960, 484, 2, stride=2)
        self.layer5 = double_conv(960, 484)

        self.up_sample_43 = nn.ConvTranspose2d(484, 420, 2, stride=2)
        self.layer4 = double_conv(484, 420)

        self.up_sample_32 = nn.ConvTranspose2d(420, 404, 2, stride=2)
        self.layer3 = double_conv(420, 404)

        self.up_sample_21 = nn.ConvTranspose2d(404, 401, 2, stride=2)
        self.layer2 = double_conv(404, 401)
        self.layer1 = nn.Conv2d(401, 2, 3, stride=1, padding=1)

    def forward(self, X5, X4, X3, X2, X1):
        stages = (
            (self.up_sample_54, X4, self.layer5),
            (self.up_sample_43, X3, self.layer4),
            (self.up_sample_32, X2, self.layer3),
            (self.up_sample_21, X1, self.layer2),
        )
        out = X5
        for upsample, skip, fuse in stages:
            out = fuse(torch.cat([upsample(out), skip], dim=1))
        return self.layer1(out)

class MnUV3(nn.Module):
    """MobileNetV3-style encoder with a U-Net-like decoder (``expansive1``).

    Maps (N, 3, H, W) images to a 2-channel (N, 2, H, W) output. There are four
    stride-2 stages, so H and W should be divisible by 16 for the decoder's
    skip concatenations to line up.
    """

    def __init__(self):
        super(MnUV3, self).__init__()
        # Encoding: stem conv (stride 2) -> two bottleneck blocks (stride 2 each)
        # -> 1x1 conv (stride 2)
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.hs1 = nn.Hardswish(inplace=True)
        self.bneck1 = Block(3, 16, 64, 64, nn.ReLU, False, 2)
        self.bneck2 = Block(5, 64, 672, 476, nn.Hardswish, True, 2)
        self.conv2 = nn.Conv2d(476, 960, kernel_size=1, stride=2, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(960)
        self.hs2 = nn.Hardswish(inplace=True)
        # NOTE(review): linear4 is never used in forward(); it is kept so the
        # state_dict keys of existing checkpoints still load with strict=True.
        self.linear4 = nn.Linear(960, 2)
        # Decoding
        self.up = expansive1()

    def forward(self, x):
        stem = self.hs1(self.bn1(self.conv1(x)))          # 16 ch,  H/2
        stage1 = self.bneck1(stem)                         # 64 ch,  H/4
        stage2 = self.bneck2(stage1)                       # 476 ch, H/8
        bottom = self.hs2(self.bn2(self.conv2(stage2)))    # 960 ch, H/16

        # Skips are passed deepest-first; the raw input is the finest skip.
        return self.up(bottom, stage2, stage1, stem, x)

## Classification Network Architecture

### ClassNet (the `Network` class)

In [31]:
def conv2dout(hin,win,conv,pool=2):
    """Spatial (h, w) after applying ``conv`` to an (hin, win) input, then optional pooling.

    Uses the Conv2d output-size formula
    ``floor((n + 2*padding - dilation*(kernel-1) - 1) / stride + 1)`` per axis.
    When ``pool`` is truthy, the result is divided by ``pool`` and truncated
    toward zero. ``conv.padding`` must be numeric per axis.

    Returns:
        Tuple of two ints ``(h, w)``.
    """
    def _axis(n, i):
        size = np.floor((n + 2 * conv.padding[i] - conv.dilation[i] * (conv.kernel_size[i] - 1) - 1) / conv.stride[i] + 1)
        return int(size / pool) if pool else int(size)

    return _axis(hin, 0), _axis(win, 1)

class Network(nn.Module):
    """Small CNN classifier on single-channel inputs.

    Architecture:
    - Four stages of 3x3 conv (no padding) + ReLU + 2x2 max-pool, with
      64/128/256/512 channels.
    - FC(num_flatten -> 50) + ReLU + dropout(0.1).
    - FC(50 -> 2) + softmax.

    ``forward`` returns only the class-0 probability, shape (N,).

    Args:
        input_size: spatial size of the input, an int (square) or an (H, W)
            tuple; used to size ``fc1``. Defaults to 256, matching the
            ``A.Resize(256, 256)`` transforms.
    """

    def __init__(self, input_size=256):
        super(Network, self).__init__()
        if isinstance(input_size, int):
            h, w = input_size, input_size
        else:
            h, w = input_size

        # Convolution layers; track (h, w) after each conv + 2x2 pool to size fc1.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3)
        h, w = conv2dout(h, w, self.conv1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3)
        h, w = conv2dout(h, w, self.conv2)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3)
        h, w = conv2dout(h, w, self.conv3)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=3)
        h, w = conv2dout(h, w, self.conv4)
        self.num_flatten = 512 * h * w
        self.fc1 = nn.Linear(self.num_flatten, 50)
        self.fc2 = nn.Linear(50, 2)

    def forward(self, X):
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            X = F.max_pool2d(F.relu(conv(X)), 2, 2)
        X = X.view(-1, self.num_flatten)
        X = F.relu(self.fc1(X))
        # F.dropout defaults to training=True, which kept dropout active under
        # model.eval(); tie it to the module's mode instead.
        X = F.dropout(X, 0.1, training=self.training)
        X = self.fc2(X)
        X = F.softmax(X, dim=1)
        return X[:, 0]

## Segmentation Models Training

In [None]:
# Training-time augmentation: normalisation, random flips / small rotations, resize to 256x256.
train_transform_seg = A.Compose(
    [
        A.Normalize(mean=0.0, std=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Rotate(limit=10, p=0.5),

        # A.RandomBrightnessContrast(p=0.5),
        # A.Defocus(p = 0.5),
        # A.Downscale(scale_min = 0.3,scale_max = 0.55,p = 0.5,interpolation = cv2.INTER_CUBIC),
        # A.GaussNoise(var_limit=0.4,p = 0.5),

        A.Resize(256, 256),
        ToTensorV2(),
    ]
)

# Evaluation transform: deterministic (no random flips/rotations), so the
# validation loss/Dice are comparable across epochs and runs.
test_transform_seg = A.Compose(
    [
        A.Normalize(mean=0.0, std=1.0),
        A.Resize(256, 256),
        ToTensorV2(),
    ]
)


torch.cuda.empty_cache()
epochs = 300

lrnew = 1e-4
batch_size = 16
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Datasets, dataloaders and the loss are the same for every trial.
train_dataset = SegDataset(train_file_labels, train_transform_seg)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=0, shuffle=True)
val_dataset = SegDataset(val_file_labels, test_transform_seg)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, num_workers=0, shuffle=False)
loss = torch.nn.CrossEntropyLoss()

for trial in range(10):
    resume = 0 # 1: from (half) trained models, 0: start a new training
    model = MnUV3() # model = unet()
    model.to(device)
    print(f'Device is {device}')
    optimizer = torch.optim.Adam(model.parameters(), lr=lrnew)
    train_accuracy = np.zeros(epochs)
    train_loss = np.zeros(epochs)
    validation_accuracy = np.zeros(epochs)
    validation_loss = np.zeros(epochs)

    if resume == 0:
        now = datetime.now()
        current_time = date(now.year, now.month,now.day).strftime('%y%m%d')
        current_new = time(now.hour, now.minute).strftime('%H%M')
        foldername = os.getcwd() +'/'+ current_time+'_'+current_new+'_epoch'+str(epochs)+'_batch'+str(batch_size)+'_img'+'_adam'+str(lrnew)+'_v3_train_no_test_no/'
        if not os.path.isdir(foldername):
            os.mkdir(foldername)
        print(foldername)
        start_epoch = 1

    else:
        # checkpoint_fpath must be defined beforehand: a '<foldername><epoch>.pth' file saved below.
        foldername = os.path.dirname(checkpoint_fpath) + '/'
        checkpoint = torch.load(checkpoint_fpath)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # The stored epoch was already completed; continue with the next one.
        start_epoch = checkpoint['epoch'] + 1
        # Older periodic checkpoints were saved without metric histories.
        for key, history in (('train_acc', train_accuracy), ('train_loss', train_loss),
                             ('val_acc', validation_accuracy), ('val_loss', validation_loss)):
            if key in checkpoint:
                history[:len(checkpoint[key])] = checkpoint[key]

    for epoch in range(start_epoch, epochs):
        # Training Loop
        print('='*30)
        print('Epoch {} / {}'.format(epoch, epochs))

        epoch_loss = 0
        epoch_accuracy = 0
        epoch_counter = 0
        model.train()

        # Batch Training Loop (Loop over the batches)
        for index, (X, Y) in enumerate(train_dataloader):
            X = X.to(device)
            Y = Y.to(device)

            # Call the model (image -> per-pixel class scores) and compute the loss
            R = model(X)
            L = loss(R, Y.long())
            optimizer.zero_grad()
            L.backward()
            optimizer.step()

            pred = R.detach().argmax(dim=1)
            batch_loss = L.item()

            # Batch metrics, weighted by batch size
            epoch_accuracy += dice_coef(pred, Y) * X.shape[0]
            epoch_loss += batch_loss * X.shape[0]
            print(f'\tBatch {index} loss {batch_loss:3.3f}')
            epoch_counter += X.shape[0]

        train_accuracy[epoch] = epoch_accuracy/epoch_counter
        train_loss[epoch] = epoch_loss/epoch_counter
        print(f"Loss: {train_loss[epoch]:3.3f}, Accuracy: {train_accuracy[epoch]:3.3f}")

        # Validation Loop
        epoch_loss = 0
        epoch_accuracy = 0
        epoch_counter = 0
        model.eval()

        with torch.no_grad():
            for index, (X, Y) in enumerate(val_dataloader):
                X = X.to(device)
                Y = Y.to(device)

                R = model(X)
                L = loss(R, Y.long())
                pred = R.argmax(dim=1)
                batch_loss = L.item()

                epoch_accuracy += dice_coef(pred, Y) * X.shape[0]
                epoch_loss += batch_loss * X.shape[0]
                print('\tBatch ', index, batch_loss)
                epoch_counter += X.shape[0]

        validation_accuracy[epoch] = epoch_accuracy/epoch_counter
        validation_loss[epoch] = epoch_loss/epoch_counter
        print(f"Validation Loss: {validation_loss[epoch]:3.3f}, Accuracy: {validation_accuracy[epoch]:3.3f}")

        # Periodic checkpoints after epoch 70, plus the final epoch. All include
        # the metric histories so any of them can be resumed from.
        if (epoch % 10 == 9 and epoch > 70) or epoch == epochs - 1:
            checkpoint = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'train_acc': train_accuracy,
                'train_loss': train_loss,
                'val_acc': validation_accuracy,
                'val_loss': validation_loss
                }
            torch.save(checkpoint, foldername+str(epoch)+'.pth')


## Classification Models Training

In [None]:
# Training-time augmentation for the classifier inputs.
train_transform_class = A.Compose(
    [
        A.Normalize(mean=0.0, std=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Rotate(limit=10, p=0.5),

        A.Resize(256, 256),
        ToTensorV2(),
    ]
)

# Deterministic evaluation transform.
test_transform_class = A.Compose(
    [
        A.Normalize(mean=0.0, std=1.0),
        A.Resize(256, 256),
        ToTensorV2()]
)


def _count_correct(probs, targets):
    """Number of samples whose thresholded (> 0.5) probability equals the 0/1 target.

    ``probs`` is the model output fed to BCELoss, i.e. the predicted
    probability that the target is 1.
    """
    return ((probs.detach() > 0.5).long() == targets.long()).sum().item()


epochs = 300
lrnew = 1e-3
batch_size = 128
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Datasets, dataloaders and the loss are the same for every trial.
train_dataset = ClassDataset(trainfiles, train_transform_class)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=0, shuffle=True)
val_dataset = ClassDataset(valfiles, test_transform_class)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, num_workers=0, shuffle=False)
loss = torch.nn.BCELoss()

for trial in range(10):
    torch.cuda.empty_cache()
    resume = 0

    model = Network()
    model.to(device)
    print(f'Device is {device}')
    # lr=lrnew up to epoch 20, then a separate Adam at 1e-4. The second
    # optimizer's moment estimates start from zero at the switch.
    optimizer1 = torch.optim.Adam(model.parameters(), lr=lrnew)
    optimizer2 = torch.optim.Adam(model.parameters(), lr=1e-4)

    train_accuracy = np.zeros(epochs)
    train_loss = np.zeros(epochs)
    validation_accuracy = np.zeros(epochs)
    validation_loss = np.zeros(epochs)

    if resume != 0:
        # Without this guard, start_epoch/foldername would silently be taken
        # from whatever an earlier cell left in the kernel.
        raise NotImplementedError('resume is not supported for the classification run')
    now = datetime.now()
    current_time = date(now.year, now.month,now.day).strftime('%y%m%d')
    current_new = time(now.hour, now.minute).strftime('%H%M')
    foldername = os.getcwd() + '/'+ current_time+'_'+current_new+'_epoch'+str(epochs)+'_batch'+str(batch_size)+'_img'+'_adam'+str(lrnew)+'_classify/'
    if not os.path.isdir(foldername):
        os.mkdir(foldername)
    print(foldername)
    start_epoch = 1

    for epoch in range(start_epoch, epochs):
        op = optimizer2 if epoch > 20 else optimizer1
        # Training Loop
        print('='*30)
        print('Epoch {} / {}'.format(epoch, epochs))

        epoch_loss = 0
        epoch_correct = 0
        epoch_counter = 0
        model.train()

        # Batch Training Loop; the dataset also yields the image and its path (unused here)
        for index, (_, X, _, Y) in enumerate(train_dataloader):
            X = X.to(device)
            Y = Y.to(device)

            R = model(X)
            L = loss(R.float(), Y.float())
            op.zero_grad()
            L.backward()
            op.step()

            batch_loss = L.item()
            epoch_correct += _count_correct(R, Y)
            epoch_loss += batch_loss * X.shape[0]
            print(f'\tBatch {index} loss {batch_loss:3.3f}')
            epoch_counter += X.shape[0]

        train_accuracy[epoch] = epoch_correct/epoch_counter
        train_loss[epoch] = epoch_loss/epoch_counter
        print(f"Loss: {train_loss[epoch]:3.3f}, Accuracy: {train_accuracy[epoch]:3.3f}")

        # Validation Loop
        epoch_loss = 0
        epoch_correct = 0
        epoch_counter = 0
        model.eval()

        with torch.no_grad():
            for index, (_, X, _, Y) in enumerate(val_dataloader):
                X = X.to(device)
                Y = Y.to(device)

                R = model(X)
                L = loss(R.float(), Y.float())

                batch_loss = L.item()
                epoch_correct += _count_correct(R, Y)
                epoch_loss += batch_loss * X.shape[0]
                print('\tBatch ', index, batch_loss)
                epoch_counter += X.shape[0]

        validation_accuracy[epoch] = epoch_correct/epoch_counter
        validation_loss[epoch] = epoch_loss/epoch_counter
        print(f"Validation Loss: {validation_loss[epoch]:3.3f}, Accuracy: {validation_accuracy[epoch]:3.3f}")

        if epoch % 10 == 9:
            checkpoint = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': op.state_dict(),
                'train_acc': train_accuracy,
                'train_loss': train_loss,
                'val_acc': validation_accuracy,
                'val_loss': validation_loss
                }
            torch.save(checkpoint, foldername+str(epoch)+'.pth')