In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch import optim, nn
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
from torchvision import models
from torch.nn.functional import relu


## Loss Function


In [1]:

def dice_loss(y_pred, y_true, smooth = 1e-6):
    """Soft Dice loss (``1 - Dice``) pooled over every element of the inputs.

    Both tensors are flattened, so a single Dice value is computed for the
    whole batch rather than one per sample.

    Args:
        y_pred: Prediction tensor. No activation is applied here, so the
            values should already be in [0, 1] (e.g. sigmoid probabilities)
            for the result to be a meaningful soft Dice.
        y_true: Target tensor with the same number of elements as ``y_pred``.
        smooth: Constant added to numerator and denominator so that two
            all-zero inputs give a Dice of 1 instead of dividing by zero.

    Returns:
        A 0-dim tensor equal to ``1 - dice``.
    """
    pred_flat = y_pred.flatten()
    true_flat = y_true.flatten()

    numerator = 2.0 * (pred_flat * true_flat).sum() + smooth
    denominator = pred_flat.sum() + true_flat.sum() + smooth

    return 1 - numerator / denominator

## Dice Score

In [2]:
def dice_score(y_pred, y_true, threshold=0.5, smooth=1e-6):
    """Hard Dice coefficient between thresholded predictions and a target mask.

    Args:
        y_pred: Raw model outputs (logits). A sigmoid is applied here, so do
            not pass probabilities.
        y_true: Target tensor; must have exactly the same shape as ``y_pred``.
        threshold: Probability above which a pixel counts as positive
            (default 0.5).
        smooth: Constant added to numerator and denominator so that an empty
            prediction and an empty target give a score of 1 rather than NaN.

    Returns:
        A 0-dim tensor; 1.0 means perfect overlap.

    Raises:
        ValueError: If ``y_pred`` and ``y_true`` differ in shape. Without this
            check, shapes such as (N, 1, H, W) and (N, H, W) would silently
            broadcast to (N, N, H, W) and give a meaningless score.
    """
    if y_pred.shape != y_true.shape:
        raise ValueError(
            f"y_pred and y_true must have the same shape, "
            f"got {tuple(y_pred.shape)} and {tuple(y_true.shape)}"
        )

    pred_mask = (torch.sigmoid(y_pred) > threshold).float()
    intersection = (pred_mask * y_true).sum()
    return (2.0 * intersection + smooth) / (pred_mask.sum() + y_true.sum() + smooth)

## Dataset


In [4]:

class hrgldd_dataset(Dataset):
    """Map-style dataset pairing image arrays with their label (mask) arrays.

    Items are converted with ``torch.from_numpy`` and cast to float32. No axis
    reordering happens here; callers are responsible for any layout change.

    Args:
        x: Indexable sequence of images whose items are NumPy arrays
            (e.g. an array returned by ``np.load``).
        y: Indexable sequence of labels aligned with ``x`` by index.

    Raises:
        ValueError: If ``x`` and ``y`` have different lengths. Without this
            check, a mismatch would only surface mid-epoch as an IndexError
            or silently ignore the extra labels.
    """

    def __init__(self, x, y):
        if len(x) != len(y):
            raise ValueError(
                f"x and y must have the same length, got {len(x)} and {len(y)}"
            )
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        """Return ``(image, label)`` at ``index`` as float32 tensors."""
        img = torch.from_numpy(self.x[index]).float()
        label = torch.from_numpy(self.y[index]).float()
        return img, label

## Model

In [7]:


class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by an in-place ReLU.

    ``padding=1`` keeps the spatial size unchanged; only the channel count
    changes (``in_channels`` -> ``out_channels``).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class DownSample(nn.Module):
    """Encoder stage: DoubleConv followed by 2x2 max pooling.

    ``forward`` returns ``(features, pooled)``. The full-resolution features
    are kept for the skip connection; the pooled tensor feeds the next stage.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        down = self.conv(x)
        p = self.pool(down)
        return down, p


class UpSample(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, skip concatenation, DoubleConv.

    The transposed conv maps ``in_channels`` -> ``in_channels // 2``. The
    skip tensor must supply the other ``in_channels // 2`` channels, so the
    concatenation has exactly ``in_channels`` channels, which is what
    ``self.conv`` expects.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # in_channels // 2 (not out_channels // 2): after concatenating with the
        # skip tensor, the DoubleConv below must receive in_channels channels.
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2,
                                     kernel_size=2,
                                     stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1`` and fuse it with the encoder skip tensor ``x2``.

        Args:
            x1: Decoder input with ``in_channels`` channels.
            x2: Skip tensor with ``in_channels // 2`` channels.

        Returns:
            Tensor with ``out_channels`` channels and ``x2``'s spatial size.
        """
        x1 = self.up(x1)
        # MaxPool2d floors odd sizes, so the upsampled map can be smaller than
        # the skip tensor; zero-pad it to match before concatenating.
        diff_h = x2.size(2) - x1.size(2)
        diff_w = x2.size(3) - x1.size(3)
        if diff_h or diff_w:
            x1 = F.pad(x1, [diff_w // 2, diff_w - diff_w // 2,
                            diff_h // 2, diff_h - diff_h // 2])
        x = torch.cat([x1, x2], dim=1)
        return self.conv(x)


class UNet(nn.Module):
    """Four-level U-Net that returns raw logits (no sigmoid is applied).

    Args:
        in_channels: Number of input channels (default 4).
        out_channels: Number of output channels (default 1).
        base_channels: Width of the first encoder stage. Each deeper stage
            doubles it, so the default 64 gives 64/128/256/512 with a
            1024-channel bottleneck.
    """

    def __init__(self, in_channels=4, out_channels=1, base_channels=64):
        super().__init__()
        c = base_channels

        self.down_conv_1 = DownSample(in_channels, c)
        self.down_conv_2 = DownSample(c, 2 * c)
        self.down_conv_3 = DownSample(2 * c, 4 * c)
        self.down_conv_4 = DownSample(4 * c, 8 * c)

        self.bottleneck = DoubleConv(8 * c, 16 * c)

        self.up_conv_1 = UpSample(16 * c, 8 * c)
        self.up_conv_2 = UpSample(8 * c, 4 * c)
        self.up_conv_3 = UpSample(4 * c, 2 * c)
        self.up_conv_4 = UpSample(2 * c, c)

        # 1x1 projection to the output channels. An unpadded 3x3 kernel would
        # shrink the map by 2 px so it no longer matches the target mask.
        self.out = nn.Conv2d(in_channels=c, out_channels=out_channels,
                             kernel_size=1)

    def forward(self, x):
        """Run the network.

        Args:
            x: Input batch of shape (N, in_channels, H, W) with H, W >= 16
                (four 2x poolings must leave at least 1 px).

        Returns:
            Logits of shape (N, out_channels, H, W).
        """
        down1, p1 = self.down_conv_1(x)
        down2, p2 = self.down_conv_2(p1)
        down3, p3 = self.down_conv_3(p2)
        down4, p4 = self.down_conv_4(p3)

        b = self.bottleneck(p4)

        up_1 = self.up_conv_1(b, down4)
        up_2 = self.up_conv_2(up_1, down3)
        up_3 = self.up_conv_3(up_2, down2)
        up_4 = self.up_conv_4(up_3, down1)

        return self.out(up_4)

    


## Train

Before running the cell below, set the six `.npy` paths and `MODEL_SAVE_PATH`.

In [None]:

if __name__ == "__main__":
    # ---------------------------- Configuration ---------------------------- #
    LEARNING_RATE = 5e-4
    BATCH_SIZE = 16
    EPOCHS = 1
    SEED = 42
    DATA_PATH = ""        # not read anywhere below; placeholder only
    MODEL_SAVE_PATH = ""  # destination of the best model's state_dict

    # Pre-split .npy arrays; fill these in before running.
    path_to_testX = ''
    path_to_testY = ''
    path_to_trainX = ''
    path_to_trainY = ''
    path_to_valX = ''
    path_to_valY = ''

    # Fail fast: an empty save path would otherwise only surface at the first
    # torch.save, after a full epoch of training.
    if not MODEL_SAVE_PATH:
        raise ValueError("MODEL_SAVE_PATH must be set before training")

    torch.manual_seed(SEED)

    # ------------------------------- Data ---------------------------------- #
    data_testY = np.load(path_to_testY)
    data_testX = np.load(path_to_testX)
    data_trainX = np.load(path_to_trainX)
    data_trainY = np.load(path_to_trainY)
    data_valX = np.load(path_to_valX)
    data_valY = np.load(path_to_valY)

    train_dataset = hrgldd_dataset(data_trainX, data_trainY)
    val_dataset = hrgldd_dataset(data_valX, data_valY)
    test_dataset = hrgldd_dataset(data_testX, data_testY)  # not used by the loop below

    device = "cuda" if torch.cuda.is_available() else "cpu"
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=BATCH_SIZE,
                                  shuffle=True)
    # Validation metrics are averaged over all batches, so order is irrelevant.
    val_dataloader = DataLoader(dataset=val_dataset,
                                batch_size=BATCH_SIZE,
                                shuffle=False)

    # ------------------------------- Model --------------------------------- #
    # 4 input channels, 1 output (mask) channel.
    model = UNet(in_channels=4, out_channels=1).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    # `verbose` is deprecated on LR schedulers; the current LR is printed each epoch.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                     factor=0.1, patience=3)

    # Lowest validation loss seen so far; inf so the first epoch always saves.
    best_val_loss = float("inf")

    for epoch in tqdm(range(EPOCHS)):
        # ----------------------------- Training ---------------------------- #
        model.train()
        train_running_loss = 0.0
        for img, mask in tqdm(train_dataloader):
            # Channels-last -> channels-first for Conv2d.
            # NOTE(review): assumes both arrays are stored as (N, H, W, C) — confirm.
            img = img.float().permute(0, 3, 1, 2).to(device)
            mask = mask.float().permute(0, 3, 1, 2).to(device)

            optimizer.zero_grad()
            logits = model(img)
            # The model returns logits; the soft Dice loss needs values in [0, 1].
            loss = dice_loss(torch.sigmoid(logits), mask)
            loss.backward()   # backward pass: compute gradients
            optimizer.step()  # update the weights

            train_running_loss += loss.item()

        train_loss = train_running_loss / len(train_dataloader)

        # ---------------------------- Validation --------------------------- #
        model.eval()
        val_running_loss = 0.0
        val_dice = 0.0
        with torch.no_grad():
            for img, mask in tqdm(val_dataloader):
                img = img.float().permute(0, 3, 1, 2).to(device)
                mask = mask.float().permute(0, 3, 1, 2).to(device)

                y_pred = model(img)
                val_running_loss += dice_loss(torch.sigmoid(y_pred), mask).item()
                # dice_score applies its own sigmoid, so it takes raw logits.
                val_dice += dice_score(y_pred, mask).item()

        val_loss = val_running_loss / len(val_dataloader)
        val_dice = val_dice / len(val_dataloader)

        print("--" * 30)
        print(f"Train Loss EPOCH {epoch +1}: {train_loss:.4f}")
        print(f"Val Loss EPOCH {epoch +1} : {val_loss:.4f}")
        print(f"Val Dice Score EPOCH {epoch + 1} : {val_dice:.4f}")
        print(f"Current Learning Rate : {optimizer.param_groups[0]['lr']:.6f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            print(f"Saved best model with val loss : {val_loss : .4f}")

        scheduler.step(val_loss)
