In [45]:
import os
import sys
import logging
import tempfile
from glob import glob
from tqdm.notebook import tqdm

import numpy as np
import torch
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
import scipy
import matplotlib.pyplot as plt
import pandas as pd
import cv2
from skimage import io
import tqdm as notebook_tqdm

import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

In [46]:
MRI_FOLDER ="./data/dMRI_Pre"
MASK_FOLDER ="./data/RK_Pre"

TRAIN_SPLIT_PERC = 0.65
TEST_SPLIT_PERC = 0.20
VAL_SPLIT_PERC = 1 - TEST_SPLIT_PERC - TRAIN_SPLIT_PERC
TRAIN_BATCH_SIZE = 32

images = sorted(glob(os.path.join(MRI_FOLDER, "*.jpg")))
segs = sorted(glob(os.path.join(MASK_FOLDER, "*.jpg")))

In [47]:
MRIdata = pd.DataFrame({"dMRIpath": images,'MASKpath':segs})

def get_Wholes_NoWholes(dMRIpath):
    value = np.max(cv2.imread(dMRIpath))
    if value > 0 : 
        return 1
    else:
        return 0

print("Here")  
MRIdata['mask'] = MRIdata['MASKpath'].apply(lambda x: get_Wholes_NoWholes(x))
#MRIdata['MASKpath'] = MRIdata['MASKpath'].apply(lambda x: str(x))
print(MRIdata.shape)
############################################################################

#Prepare Data Loaders
image_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(), ])

mask_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(), ])
#--------------------------------------------------------------------------
#Normalize the data
def NormalizeData(img, mask):
    img = img / 255.
    mask = mask / 255.
    mask[mask > 0.5] = 1.0
    mask[mask <= 0.5] = 0.0
    
    return (img, mask)
#--------------------------------------------------------------------------
class MyDataset(Dataset):
    def __init__(self, df= MRIdata, 
                 NormalizeData = NormalizeData, 
                 image_transform=image_transform, mask_transform=mask_transform):
        self.df = df
        self.image_transform = image_transform
        self.mask_transform = mask_transform
        self.NormalizeData= NormalizeData

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        MRIpath = self.df.loc[idx, 'dMRIpath']
        MASKpath = self.df.loc[idx, 'MASKpath']

        mri = cv2.imread(MRIpath)
        mri = cv2.cvtColor(mri, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(MASKpath)
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
        mri, mask = self.NormalizeData(mri, mask)

        if self.image_transform:
            mri = self.image_transform(mri).float()

        if self.mask_transform:
            mask = self.mask_transform(mask)
        return mri, mask
    
################################################################

def prepare_loaders(df= MRIdata,
                    TrainNum= int(MRIdata.shape[0] * .6), 
                    ValidNum= int(MRIdata.shape[0] * .8), 
                    bs = 32):
#     shuffled = df.sample(frac=1) # TODO TM SHOULD WE SHUFFLE HERE 
    shuffled = df
    Train = shuffled[:TrainNum].reset_index(drop=True)
    Valid = shuffled[TrainNum : ValidNum].reset_index(drop=True)    
    Test  = shuffled[ValidNum:].reset_index(drop=True)
    print({"train":Train.shape, "test":Test.shape, "ValidLoad":Valid.shape, "orig": MRIdata.shape[0], "TrainNum": TrainNum, "ValidNum": ValidNum})

    TrainSet = MyDataset(df = Train)
    ValidSet = MyDataset(df = Valid)
    TestSet = MyDataset(df = Test)

    TrainLoad = DataLoader(TrainSet, batch_size = bs, shuffle = True)
    ValidLoad = DataLoader(ValidSet, batch_size = bs, shuffle = False)
    TestLoad = DataLoader(TestSet, batch_size = 4, shuffle = True)
    
    print("DataLoader Completed")
    
    return TrainLoad, ValidLoad, TestLoad
#--------------------------------------------------------------------------
TrainLoad, ValidLoad, TestLoad = prepare_loaders(df= MRIdata,
                                                            TrainNum= int(MRIdata.shape[0] * TRAIN_SPLIT_PERC), 
                                                            ValidNum= int(MRIdata.shape[0] * (TRAIN_SPLIT_PERC + TEST_SPLIT_PERC)), 
                                                            bs = TRAIN_BATCH_SIZE)

data = next(iter(TrainLoad))
# dataTe = next(iter(TestLoad))
# dataVa = next(iter(ValidLoad))
#data[0].shape, data[1].shape
# print({"train":data, "test":dataTe, "ValidLoad":dataVa})



Here
(135, 3)
{'train': (87, 3), 'test': (21, 3), 'ValidLoad': (27, 3), 'orig': 135, 'TrainNum': 87, 'ValidNum': 114}
DataLoader Completed


In [48]:
#------From Here We Build the UNet Model
device = torch.device("cuda:0")
if torch.cuda.is_available():
    print("GPU - {}\n".format(torch.cuda.get_device_name()))
else:
    device = torch.device("cpu")
    print("NO GPU: Using CPU")
print(device)
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

class Block(nn.Module):
    def __init__(self, inputs = 3, middles = 64, outs = 64):
        super().__init__()
        
        self.conv1 = nn.Conv2d(inputs, middles, 3, 1, 1)
        self.conv2 = nn.Conv2d(middles, outs, 3, 1, 1)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(outs)
        self.pool = nn.MaxPool2d(2, 2)
        
    def forward(self, x):       
        x = self.relu(self.conv1(x))
        x = self.relu(self.bn(self.conv2(x)))    
        return self.pool(x), x
 
###################################################################

class UNet(nn.Module):
    def __init__(self,):
        super().__init__()

        self.en1 = Block(3, 64, 64)
        self.en2 = Block(64, 128, 128)
        self.en3 = Block(128, 256, 256)
        self.en4 = Block(256, 512, 512)
        self.en5 = Block(512, 1024, 512)
        
        self.upsample4 = nn.ConvTranspose2d(512, 512, 2, stride = 2)
        self.de4 = Block(1024, 512, 256)
        
        self.upsample3 = nn.ConvTranspose2d(256, 256, 2, stride = 2)
        self.de3 = Block(512, 256, 128)
        
        self.upsample2 = nn.ConvTranspose2d(128, 128, 2, stride = 2)
        self.de2 = Block(256, 128, 64)
        
        self.upsample1 = nn.ConvTranspose2d(64, 64, 2, stride = 2)
        self.de1 = Block(128, 64, 64)
        
        self.conv_last = nn.Conv2d(64, 1, kernel_size=1, stride = 1, padding = 0)
        
    def forward(self, x):
        # x: [bs, 3, 256, 256]
        
        x, e1 = self.en1(x)
        # x: [bs, 64, 128, 128]
        # e1: [bs, 64, 256, 256]
        
        x, e2 = self.en2(x)
        # x: [bs, 128, 64, 64]
        # e2: [bs, 128, 128, 128]
        
        x, e3 = self.en3(x)
        # x: [bs, 256, 32, 32]
        # e3: [bs, 256, 64, 64]
        
        x, e4 = self.en4(x)
        # x: [bs, 512, 16, 16]
        # e4: [bs, 512, 32, 32]
        
        _, x = self.en5(x)
        # x: [bs, 512, 16, 16]
        
        x = self.upsample4(x)
        # x: [bs, 512, 32, 32]
        x = torch.cat([x, e4], dim=1)
        # x: [bs, 1024, 32, 32]
        _,  x = self.de4(x)
        # x: [bs, 256, 32, 32]
        
        x = self.upsample3(x)
        # x: [bs, 256, 64, 64]
        x = torch.cat([x, e3], dim=1)
        # x: [bs, 512, 64, 64]
        _, x = self.de3(x)
        # x: [bs, 128, 64, 64]
        
        x = self.upsample2(x)
        # x: [bs, 128, 128, 128]
        x = torch.cat([x, e2], dim=1)
        # x: [bs, 256, 128, 128]
        _, x = self.de2(x)
        # x: [bs, 64, 128, 128]
        
        x = self.upsample1(x)
        # x: [bs, 64, 256, 256]
        x = torch.cat([x, e1], dim=1)
        # x: [bs, 128, 256,256, 256
        _, x = self.de1(x)
        # x: [bs, 64, 256, 256]
        
        x = self.conv_last(x)
       
        return x
          
model = UNet().to(device)
#model
##########################################################


NO GPU: Using CPU
cpu


In [49]:
#-------Define Loss and Optimizer
import segmentation_models_pytorch as smp
from torch.optim import lr_scheduler

# loss_fn = nn.BCELoss().to(device)
loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
optimizer = torch.optim.Adam(model.parameters(), )

# Scheduler
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max = 200,eta_min = 1e-6)

##############################################################=======train_one_epoch
def train_one_epoch(model = model, 
                    dataloader = TrainLoad, 
                    loss_fn = loss_fn, 
                    optimizer = optimizer,
                    scheduler = None,
                    device = device, 
                    epoch = 1):
    model.train() 
    train_loss, dataset_size = 0,  0
    
    bar = tqdm(dataloader, total = len(dataloader))
    tp_l, fp_l, fn_l, tn_l = [], [], [], []
    
    for data in bar:
        x = data[0].to(device)     
        y_true = data[1].to(device) 
        y_pred = model(x)          
        
        loss = loss_fn(y_pred, y_true)
        
        pred_mask = (y_pred > 0.5).float()
        btp, bfp, bfn, btn = smp.metrics.get_stats(pred_mask.long(), y_true.long(), mode="binary")

        # 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if scheduler is not None:
            scheduler.step()
        
        # train_epoch_loss
        bs = x.shape[0]
        dataset_size += bs
        train_loss += (loss.item() * bs)
        train_epoch_loss = train_loss / dataset_size
        
        tp_l.append(btp)
        fp_l.append(bfp)
        fn_l.append(bfn)
        tn_l.append(btn)
        
        tp = torch.cat(tp_l)
        fp = torch.cat(fp_l)
        fn = torch.cat(fn_l)
        tn = torch.cat(tn_l)
        
        recall = smp.metrics.recall(tp, fp, fn, tn, reduction="micro")
        precision = smp.metrics.precision(tp, fp, fn, tn, reduction="micro")
        
        f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro")
        accuracy = smp.metrics.accuracy(tp, fp, fn, tn, reduction="macro")
        
        # per image IoU means that we first calculate IoU score for each image 
        # and then compute mean over these scores
        per_image_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")
        
        # dataset IoU means that we aggregate intersection and union over whole dataset
        # and then compute IoU score. The difference between dataset_iou and per_image_iou scores
        # in this particular case will not be much, however for dataset 
        # with "empty" images (images without target class) a large gap could be observed. 
        # Empty images influence a lot on per_image_iou and much less on dataset_iou.
        dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")

        bar.set_description(f"EP:{epoch} | TL:{train_epoch_loss:.3e} | ACC: {accuracy:.2f} | F1: {f1_score:.3f} ")
        
    metrics =  dict()
    
    metrics['f1_score'] = f1_score.detach().cpu().item()
    metrics['accuracy'] = accuracy.detach().cpu().item()
    
    metrics['recall'] = recall.detach().cpu().item()
    metrics['precision'] = precision.detach().cpu().item()
    
    metrics['dataset_iou'] = dataset_iou.detach().cpu().item()
    metrics['per_iou'] = per_image_iou.detach().cpu().item()
    
    metrics['loss'] = train_epoch_loss

    return metrics

###################################################################################-------run one_epoch

@torch.no_grad()
def valid_one_epoch(model = model, 
                    dataloader = ValidLoad, 
                    loss_fn = loss_fn,
                    device = device, 
                    epoch = 0):
    model.eval() 
    valid_loss, dataset_size = 0,  0
    bar = tqdm(dataloader, total = len(dataloader))
    tp_l, fp_l, fn_l, tn_l = [], [], [], []
    
    with torch.no_grad():
        for data in bar:
            x = data[0].to(device)     
            y_true = data[1].to(device) 
            y_pred = model(x)        
            
            loss = loss_fn(y_pred, y_true)
            
            pred_mask = (y_pred > 0.5).float()
            btp, bfp, bfn, btn = smp.metrics.get_stats(pred_mask.long(), y_true.long(), mode="binary")

            tp_l.append(btp)
            fp_l.append(bfp)
            fn_l.append(bfn)
            tn_l.append(btn)

            tp = torch.cat(tp_l)
            fp = torch.cat(fp_l)
            fn = torch.cat(fn_l)
            tn = torch.cat(tn_l)

            recall = smp.metrics.recall(tp, fp, fn, tn, reduction="micro")
            precision = smp.metrics.precision(tp, fp, fn, tn, reduction="micro")

            f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro")
            accuracy = smp.metrics.accuracy(tp, fp, fn, tn, reduction="macro")

            # per image IoU means that we first calculate IoU score for each image 
            # and then compute mean over these scores
            per_image_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")

            # dataset IoU means that we aggregate intersection and union over whole dataset
            # and then compute IoU score. The difference between dataset_iou and per_image_iou scores
            # in this particular case will not be much, however for dataset 
            # with "empty" images (images without target class) a large gap could be observed. 
            # Empty images influence a lot on per_image_iou and much less on dataset_iou.
            dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")

            # valid_epoch_loss 
            bs = x.shape[0]
            dataset_size += bs
            valid_loss += (loss.item() * bs)
            valid_epoch_loss = valid_loss / dataset_size

            bar.set_description(f"EP:{epoch} | VL:{valid_epoch_loss:.3e} | ACC: {accuracy:.2f} | F1: {f1_score:.3f} ")

    metrics =  dict()
    
    metrics['f1_score'] = f1_score.detach().cpu().item()
    metrics['accuracy'] = accuracy.detach().cpu().item()
    
    metrics['recall'] = recall.detach().cpu().item()
    metrics['precision'] = precision.detach().cpu().item()
    
    metrics['dataset_iou'] = dataset_iou.detach().cpu().item()
    metrics['per_iou'] = per_image_iou.detach().cpu().item()
    
    metrics['loss'] = valid_epoch_loss

    return metrics
##########################################################################

In [50]:
import copy
import time
import gc
from tqdm import tqdm
import json


def run_training(model = model, 
                 loss_fn = loss_fn, 
                 TrainLoad = TrainLoad,
                 ValidLoad = ValidLoad,
                 optimizer = optimizer, 
                 device = device, 
                 n_epochs=100, 
                 early_stop = 20,
                 scheduler = None):

    start = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())

    lowest_epoch, lowest_loss = np.inf, np.inf
    
    train_history, valid_history = [],  []
    train_recalls, valid_recalls = [],  []
    
    train_pres, valid_pres = [],  []
    train_accs, valid_accs = [],  []
    
    train_f1s, valid_f1s = [],  []
    
    train_per_ious, valid_per_ious = [], []
    train_dataset_ious, valid_dataset_ious = [], []
    
    print_iter = 5

    best_score = 0
    best_model = "None"

    for epoch in range(0, n_epochs):
        gc.collect()

        train_metrics = train_one_epoch(model= model,
                                       dataloader = TrainLoad,
                                       optimizer = optimizer,
                                       scheduler = scheduler,
                                       device = device,
                                       epoch = epoch + 1
                                       )
        
        valid_metrics = valid_one_epoch(model,
                                       dataloader = ValidLoad,
                                       device = device,
                                       epoch = epoch + 1)
        
        
        train_history += [train_metrics['loss']]
        valid_history += [valid_metrics['loss']]
        
        train_recalls += [train_metrics['recall']]
        valid_recalls += [valid_metrics['recall']]
        
        train_pres += [train_metrics['precision']]
        valid_pres += [valid_metrics['precision']]
        
        train_accs += [train_metrics['accuracy']]
        valid_accs += [valid_metrics['accuracy']]
        
        train_f1s += [train_metrics['f1_score']]
        valid_f1s += [valid_metrics['f1_score']]
        
        train_per_ious += [train_metrics['per_iou']]
        valid_per_ious += [valid_metrics['per_iou']]
        
        train_dataset_ious += [train_metrics['dataset_iou']]
        valid_dataset_ious += [valid_metrics['dataset_iou']]
        
        
        print()
        if (epoch + 1) % print_iter == 0:
            print(f"Epoch:{epoch + 1}|TL:{train_metrics['loss']:.3e}|VL:{valid_metrics['loss']:.3e}|F1:{valid_metrics['f1_score']:.4f}|Dataset IOU:{valid_metrics['dataset_iou']:.4f}|Per Img IOU:{valid_metrics['per_iou']:.4f}|")
            print()
            
        if best_score < valid_metrics['f1_score']:
            print(f"Validation F1 Improved({best_score:.2f}) --> ({ valid_metrics['f1_score']:.2f})")
            best_model = model
            best_score = valid_metrics['f1_score']
            best_model = copy.deepcopy(model.state_dict())
            PATH2 =  f"model_f1.bin"
            torch.save(model.state_dict(), PATH2)
            print(f"Better_F1_Model Saved")
            print()

        if valid_metrics['loss']< lowest_loss:
            print(f"Validation Loss Improved({lowest_loss:.4e}) --> ({ valid_metrics['loss']:.4e})")
            lowest_loss = valid_metrics['loss']
            lowest_epoch = epoch
            best_model_wts = copy.deepcopy(model.state_dict())
            PATH = f"model.bin"
            torch.save(model.state_dict(), PATH)
            print(f"Better Loss Model Saved")
            print()
        else:
            if early_stop > 0 and lowest_epoch + early_stop < epoch + 1:
                print("no improvement") 
                break
                
    print()
    end = time.time()
    time_elapsed = end - start
    print('Training complete in {:.0f}h {:.0f}m {:.0f}s'.format(time_elapsed // 3600, (time_elapsed % 3600) // 60, (time_elapsed % 3600) % 60))
    print("Best Loss: %.4e at %d th Epoch" % (lowest_loss, lowest_epoch))

    # load best model weights
    model.load_state_dict(torch.load('./model_f1.bin'))

    result = dict()
    result["Train Loss"] = train_history
    result["Valid Loss"] = valid_history
    
    result["Train Recall"] = train_recalls
    result["Valid Recall"] = valid_recalls
    
    result["Train Precision"] = train_pres
    result["Valid Precision"] = valid_pres
    
    result["Train Accuracy"] = train_accs
    result["Valid Accuracy"] = valid_accs
    
    result["Train F1 Score"] = train_f1s
    result["Valid F1 Score"] = valid_f1s
    
    result["Train per Image IOU"] = train_per_ious
    result["Valid per Image IOU"] = valid_per_ious
    
    result["Train Dataset IOU"] = train_dataset_ious
    result["Valid Dataset IOU"] = valid_dataset_ious
    
    return model, result

##############################################################

In [None]:
# Run Training
model, result = run_training(model = model, 
                             loss_fn = loss_fn, 
                             optimizer = optimizer, 
                             device = device, 
                             scheduler = scheduler,
                             n_epochs = 10)

torch.save(model.state_dict(), "/home1/cliog/MLProject/Test1/OLD_FILE/MyModel.pth")
FILE_NAME = "/home1/cliog/MLProject/Test1/OLD_FILE/Result.json"
with open(FILE_NAME, 'w') as convert_file:
     convert_file.write(json.dumps(result))

  0%|                                                                                            | 0/3 [00:00<?, ?it/s]

In [None]:
## Train/Valid Loss History
plot_from = 0
plt.figure(figsize=(20, 10))
plt.title("Train/Valid Loss History", fontsize = 20)
plt.plot(
    range(0, len(result['Train Loss'][plot_from:])), 
    result['Train Loss'][plot_from:], 
    label = 'Train Loss'
    )

plt.plot(
    range(0, len(result['Valid Loss'][plot_from:])), 
    result['Valid Loss'][plot_from:], 
    label = 'Valid Loss'
    )

plt.legend()
# plt.yscale('log')
plt.grid(True)
plt.show()



In [None]:
## Train/Valid Accuracy History
plot_from = 0
plt.figure(figsize=(20, 10))
plt.title("Train/Valid Accuracy History", fontsize = 20)
plt.plot(
    range(0, len(result['Train Accuracy'][plot_from:])), 
    result['Train Accuracy'][plot_from:], 
    label = 'Train Accuracy'
    )

plt.plot(
    range(0, len(result['Valid Accuracy'][plot_from:])), 
    result['Valid Accuracy'][plot_from:], 
    label = 'Valid Accuracy'
    )

plt.legend()
# plt.yscale('log')
plt.grid(True)