# Problem 4

In [28]:
import numpy as np
import pandas as pd
from torch import nn 
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms.functional as TF
from sklearn import metrics


import os
import cv2
from torch.nn import ConvTranspose2d
from torch.nn import Conv2d
from torch.nn import MaxPool2d
from torch.nn import Module
from torch.nn import ModuleList
from torch.nn import ReLU
from torchvision.transforms import CenterCrop
from torch.nn import functional as F
import torch

from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from sklearn.model_selection import train_test_split
from torchvision import transforms
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import time
import os

In [2]:
BATCH_SIZE = 32

### Data input and train_test_split

In [3]:
data_path_4 = '/Users/charliejiang/Documents/Stanford/Machine-Learning-for-Neuroimaging/HW2/data_assignment_2/lgg-mri-segmentation/kaggle_3m'

metadata = pd.read_csv(data_path_4 + '/data.csv')
patient_id = metadata['Patient']

In [4]:
folder_lst = []
for i in os.listdir(data_path_4):
    if i[0:12] in patient_id.values:
        folder_lst.append(os.path.join(data_path_4,i))
print(len(folder_lst))

110


In [5]:
train_folders, test_folders = train_test_split(folder_lst,test_size=0.2, random_state=42)
train_folders,val_folders = train_test_split(train_folders,test_size=0.125, random_state=42)

In [6]:
print(len(train_folders),len(val_folders),len(test_folders))

77 11 22


In [41]:
train_imgs = []
val_imgs = []
test_imgs = []
train_masks = []
val_masks = []
test_masks = []


for i in train_folders:
    for j in os.listdir(i):
        if j[-8:] != 'mask.tif':
            train_imgs.append(os.path.join(data_path_4,i,j))
            train_masks.append(os.path.join(data_path_4,i,j[:-4])+'_mask.tif')
            
for i in val_folders:
    for j in os.listdir(i):
        if j[-8:] != 'mask.tif':
            val_imgs.append(os.path.join(data_path_4,i,j))
            val_masks.append(os.path.join(data_path_4,i,j[:-4])+'_mask.tif')
            
            
for i in test_folders:
    for j in os.listdir(i):
        if j[-8:] != 'mask.tif':
            test_imgs.append(os.path.join(data_path_4,i,j))
            test_masks.append(os.path.join(data_path_4,i,j[:-4])+'_mask.tif')

In [42]:
class SegmentationDataset(Dataset):
    def __init__(self, imagePaths, maskPaths,transforms):
        self.imagePaths = imagePaths
        self.maskPaths = maskPaths
        self.transforms = transforms
    def __len__(self):
        return len(self.imagePaths)
    def __getitem__(self, idx):
        imagePath = self.imagePaths[idx]
        image = cv2.imread(imagePath)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        mask = cv2.imread(self.maskPaths[idx], 0)
        
        if self.transforms is not None:
            image = self.transforms(image)
            mask = self.transforms(mask)
        
        return (image, mask)

In [43]:
transform = transforms.Compose([transforms.ToPILImage(),
    transforms.ToTensor()])


trainDS = SegmentationDataset(imagePaths=train_imgs, maskPaths=train_masks,transforms = transform)
valDS = SegmentationDataset(imagePaths=val_imgs, maskPaths=val_masks, transforms = transform)
testDS = SegmentationDataset(imagePaths=test_imgs, maskPaths=test_masks, transforms = transform)
print(f"[INFO] found {len(trainDS)} examples in the training set...")
print(f"[INFO] found {len(valDS)} examples in the val set...")
print(f"[INFO] found {len(testDS)} examples in the test set...")

trainLoader = DataLoader(trainDS, shuffle=True,batch_size = BATCH_SIZE, pin_memory=True)

valLoader = DataLoader(valDS, shuffle=False,batch_size = BATCH_SIZE,pin_memory=True)

testLoader = DataLoader(testDS, shuffle=False,batch_size = BATCH_SIZE, pin_memory=True)


[INFO] found 2833 examples in the training set...
[INFO] found 378 examples in the val set...
[INFO] found 718 examples in the test set...


In [44]:
initial_shape = (next(iter(trainDS))[1].shape[1],next(iter(trainDS))[1].shape[2])

In [45]:
class DConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3),
        )

    def forward(self, x):
        return self.double_conv(x)

In [48]:
class Encoder(Module):
    def __init__(self, channels=(3, 16, 32, 64)):
        super().__init__()
        self.encBlocks = ModuleList(
            [DConvBlock(channels[i], channels[i + 1])
                 for i in range(len(channels) - 1)])
        self.pool = MaxPool2d(2)
    def forward(self, x):
        blockOutputs = []
        for block in self.encBlocks:
            x = block(x)
            blockOutputs.append(x)
            x = self.pool(x)
        return blockOutputs

In [49]:
class Decoder(Module):
    def __init__(self, channels=(64, 32, 16)):
        super().__init__()
        self.channels = channels
        self.upconvs = ModuleList(
            [ConvTranspose2d(channels[i], channels[i + 1], 2, 2)
                 for i in range(len(channels) - 1)])
        self.dec_blocks = ModuleList(
            [DConvBlock(channels[i], channels[i + 1])
                for i in range(len(channels) - 1)])
    def forward(self, x, encFeatures):
        for i in range(len(self.channels) - 1):
            x = self.upconvs[i](x)
            encFeat = self.crop(encFeatures[i], x)
            x = torch.cat([x, encFeat], dim=1)
            x = self.dec_blocks[i](x)
        return x
    def crop(self, encFeatures, x):
        (_, _, H, W) = x.shape
        encFeatures = CenterCrop([H, W])(encFeatures)
        return encFeatures


In [50]:
class UNet(Module):
    def __init__(self, encChannels=(3, 16, 32, 64),
                 decChannels=(64, 32, 16),
                 nbClasses=1, retainDim=True,
                 outSize=initial_shape):
        super().__init__()

        self.encoder = Encoder(encChannels)
        self.decoder = Decoder(decChannels)

        self.head = Conv2d(decChannels[-1], nbClasses, 1)
        self.retainDim = retainDim
        self.outSize = outSize
    def forward(self, x):
        encFeatures = self.encoder(x)
        decFeatures = self.decoder(encFeatures[::-1][0],encFeatures[::-1][1:])
        map = self.head(decFeatures)
        if self.retainDim:
            map = F.interpolate(map, self.outSize)
        return map

In [51]:
u_net=UNet()
print(u_net)

UNet(
  (encoder): Encoder(
    (encBlocks): ModuleList(
      (0): DConvBlock(
        (double_conv): Sequential(
          (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
          (1): ReLU(inplace=True)
          (2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
        )
      )
      (1): DConvBlock(
        (double_conv): Sequential(
          (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
          (1): ReLU(inplace=True)
          (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
        )
      )
      (2): DConvBlock(
        (double_conv): Sequential(
          (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
          (1): ReLU(inplace=True)
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
        )
      )
    )
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (decoder): Decoder(
    (upconvs): ModuleList(
      (0): ConvTranspose2d(64, 32, kernel_size=(2, 2), stride=(2, 2))
      (1

In [52]:
lr = 0.001
lossfun = BCEWithLogitsLoss()
optimizer = Adam(u_net.parameters(), lr)

In [53]:
from sklearn.metrics import jaccard_score
#from sklearn.metrics import dice_score

In [54]:
trainSteps = len(trainDS) // BATCH_SIZE
valSteps  = len(valDS) // BATCH_SIZE
testSteps = len(testDS) // BATCH_SIZE

In [55]:
# loop over epochs
def train_unet(unet,trainLoader,valLoader, opt,lossFunc,epoch_cnt = 100):
    # Training model
    startTime = time.time()
    train_loss = []
    val_loss = []
    traindice_lst = []
    #JI_lst = []
    
    for e in tqdm(range(epoch_cnt)):
        unet.train()
        total_train_loss = 0
        total_val_loss = 0
        train_JI = 0
        val_JI = 0
        
        for (i, (x, y)) in enumerate(trainLoader):
            pred = unet(x)
            loss = lossFunc(pred, y)
            
            opt.zero_grad()
            loss.backward()
            opt.step()
            
            #JI = jaccard_score(pred.detach().numpy(),y)
            total_train_loss += loss
            #train_JI+= JI
            
        
        with torch.no_grad():
            unet.eval()
            for (x, y) in valLoader:
                pred = unet(x)
                total_val_loss += lossFunc(pred,y)
                #val_JI += jaccard_score(pred.detach().numpy(),y)
            

        avgTrainLoss = total_train_loss/trainSteps
        avgValLoss = total_val_loss/valSteps
        
        #avgTrainJI = train_JI/trainSteps
        #avgValJI = val_JI/valSteps
        
        print(f"Epoch {e}: Training Loss: {avgTrainLoss} mIOU: ")
        print(f"           Validation Loss: {avgValLoss} mIOU: ")

    endTime = time.time()
    print("[INFO] total time taken to train the model: {:.2f}s".format(
        endTime - startTime))
    
    #Validation

In [None]:
train_unet(u_net,trainLoader,valLoader,optimizer,lossfun,epoch_cnt = 100)

  1%|▍                                      | 1/100 [06:08<10:08:17, 368.66s/it]

Epoch 0: Training Loss: 0.10560795664787292 mIOU: 
           Validation Loss: 0.030888782814145088 mIOU: 


  2%|▊                                       | 2/100 [11:50<9:36:01, 352.67s/it]

Epoch 1: Training Loss: 0.03485293686389923 mIOU: 
           Validation Loss: 0.027261190116405487 mIOU: 


  3%|█▏                                      | 3/100 [18:08<9:49:25, 364.60s/it]

Epoch 2: Training Loss: 0.03301425278186798 mIOU: 
           Validation Loss: 0.02686992473900318 mIOU: 


  4%|█▌                                      | 4/100 [24:16<9:45:10, 365.73s/it]

Epoch 3: Training Loss: 0.03152129054069519 mIOU: 
           Validation Loss: 0.030174775049090385 mIOU: 


  5%|██                                      | 5/100 [30:29<9:43:16, 368.39s/it]

Epoch 4: Training Loss: 0.02983386442065239 mIOU: 
           Validation Loss: 0.02277674525976181 mIOU: 


  6%|██▍                                     | 6/100 [36:44<9:40:53, 370.78s/it]

Epoch 5: Training Loss: 0.02806568704545498 mIOU: 
           Validation Loss: 0.02509339340031147 mIOU: 


  7%|██▊                                     | 7/100 [42:51<9:32:47, 369.54s/it]

Epoch 6: Training Loss: 0.02616105228662491 mIOU: 
           Validation Loss: 0.020944060757756233 mIOU: 


  8%|██▉                                  | 8/100 [1:14:21<21:48:28, 853.35s/it]

Epoch 7: Training Loss: 0.02616053819656372 mIOU: 
           Validation Loss: 0.0206635482609272 mIOU: 


  9%|███▎                                 | 9/100 [1:19:53<17:27:21, 690.57s/it]

Epoch 8: Training Loss: 0.024519948288798332 mIOU: 
           Validation Loss: 0.02189638838171959 mIOU: 


 10%|███▌                                | 10/100 [1:26:09<14:50:10, 593.45s/it]

Epoch 9: Training Loss: 0.024398623034358025 mIOU: 
           Validation Loss: 0.018476592376828194 mIOU: 


 11%|███▉                                | 11/100 [1:32:24<13:00:58, 526.50s/it]

Epoch 10: Training Loss: 0.02347376197576523 mIOU: 
           Validation Loss: 0.01847982592880726 mIOU: 


 12%|████▎                               | 12/100 [1:38:23<11:37:16, 475.42s/it]

Epoch 11: Training Loss: 0.023629145696759224 mIOU: 
           Validation Loss: 0.019028041511774063 mIOU: 


 13%|████▋                               | 13/100 [1:52:31<14:13:08, 588.37s/it]

Epoch 12: Training Loss: 0.022552764043211937 mIOU: 
           Validation Loss: 0.018341796472668648 mIOU: 


 14%|█████                               | 14/100 [1:59:05<12:39:21, 529.79s/it]

Epoch 13: Training Loss: 0.023123174905776978 mIOU: 
           Validation Loss: 0.01834752783179283 mIOU: 


 15%|█████▍                              | 15/100 [2:05:17<11:23:00, 482.12s/it]

Epoch 14: Training Loss: 0.022049862891435623 mIOU: 
           Validation Loss: 0.019373415037989616 mIOU: 


 16%|█████▊                              | 16/100 [2:12:55<11:04:37, 474.74s/it]

Epoch 15: Training Loss: 0.0221430491656065 mIOU: 
           Validation Loss: 0.017744479700922966 mIOU: 


 17%|██████                              | 17/100 [2:20:02<10:37:01, 460.50s/it]

Epoch 16: Training Loss: 0.02178223989903927 mIOU: 
           Validation Loss: 0.018432416021823883 mIOU: 
