In [1]:
import copy
import os
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchsummary import summary
from torch.utils.tensorboard import SummaryWriter

from unet_model import UNet
from seep_dataset import SeepImageDataset
from utils import ImageHelperFunctions, ImageTransformFunctions


KeyboardInterrupt



In [None]:
# --- Input locations ---
IMAGE_DIR = "./seep_detection/images_256/"
MASK_DIR = "./seep_detection/masks_256/"

# --- Model input/output sizes ---
NUM_CHANNELS = 1
NUM_CLASSES = 8

# --- Optimisation hyper-parameters ---
BATCH_SIZE = 8
N_EPOCHS = 400
LR = 1e-2
PATIENCE = 50

# Hand-picked per-class loss weights (position = class id); turned into a tensor
# in the data-loading cell before being handed to the loss.
class_weights = [1, 20, 40, 40, 40, 20, 20, 80]

TRAIN = True

# Run name: used as the TensorBoard log directory and the checkpoint file name.
SUMMARY_WRITER_NAME = "_".join([
    f"FocalLoss_{class_weights}",
    f"lr={LR}",
    f"batch_size={BATCH_SIZE}",
    f"p={PATIENCE}",
])

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
helper = ImageHelperFunctions()
list_img_paths = helper.list_files_in_dir(IMAGE_DIR, '.tif')
list_mask_paths = helper.list_files_in_dir(MASK_DIR, '.tif')
# Later cells pair images and masks by list position, so the listings must at
# least have the same length.
# NOTE(review): pairing also assumes list_files_in_dir returns both listings in a
# matching order (e.g. sorted by filename) — confirm in utils.
assert len(list_img_paths) == len(list_mask_paths), (
    f"{len(list_img_paths)} images vs {len(list_mask_paths)} masks"
)

# NOTE(review): `images` is not used by any later cell shown here; the mean/std
# are computed from train_loader further down.
images = [np.array(helper.read_image(image)) for image in list_img_paths]

masks = [np.array(helper.read_image(mask)) for mask in list_mask_paths]
classes, counts = helper.count_classes_in_arr(masks, show=True)

# The counts show a huge imbalance between the classes. Class 0 is the
# background and is left out of the plot (done inside count_classes_in_arr — confirm).
#
# Inverse-frequency weights (sum(counts) / count for each class) were considered.
# Instead, the hand-picked `class_weights` from the config cell are used, tuned
# like a hyper-parameter and passed to the loss function.
#
# torch.as_tensor keeps this cell re-runnable: if class_weights is already a
# float32 tensor on `device` (from a previous run) it is returned unchanged. The
# legacy torch.FloatTensor(...) constructor fails when handed a CUDA tensor.
class_weights = torch.as_tensor(class_weights, dtype=torch.float32, device=device)

In [None]:
transform = ImageTransformFunctions()

N_PREVIEW = 5  # number of image/mask pairs to display

# Slicing + zip tolerates directories with fewer than N_PREVIEW files
# (range(5) with indexing raised IndexError).
for img_path, mask_path in zip(list_img_paths[:N_PREVIEW], list_mask_paths[:N_PREVIEW]):
    image = helper.read_image(img_path)
    mask = helper.read_mask(mask_path)
    mask_to_pal = transform.mask_to_palette(mask)
    helper.preview_images((image, mask, mask_to_pal))

# Properties of the last previewed pair. The loop ran at least once exactly when
# both listings are non-empty; otherwise image/mask were never assigned here.
if list_img_paths and list_mask_paths:
    helper.image_properties(image)
    helper.image_properties(mask)

In [None]:
dataset = SeepImageDataset(IMAGE_DIR, MASK_DIR)

TEST_FRACTION = 0.1
# Fixed seed so the held-out split is identical across kernel restarts. Without
# it, re-training after a restart would see former test images in training.
SPLIT_SEED = 42

n_test = int(np.floor(TEST_FRACTION * len(dataset)))
n_train = len(dataset) - n_test
# An empty test split would make validation meaningless (and divide by zero).
assert n_test > 0, f"dataset of {len(dataset)} samples is too small to hold out a test split"

train_ds, test_ds = random_split(
    dataset,
    [n_train, n_test],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
train_loader = DataLoader(
    dataset = train_ds,
    batch_size = BATCH_SIZE,
    shuffle = True
)

# The evaluation loss is an average over all batches, so shuffling buys nothing.
# A fixed order keeps evaluation deterministic.
test_loader = DataLoader(
    dataset = test_ds,
    batch_size = BATCH_SIZE,
    shuffle = False
)

# Sanity-check one batch: shapes and dtypes rather than a full tensor dump.
# (Cross-entropy style losses need integer class-id targets.)
x, y = next(iter(train_loader))
print(x.shape, x.dtype, y.shape, y.dtype)

# NOTE(review): mu/std are only printed; nothing visible in this notebook applies
# them as input normalisation — confirm whether SeepImageDataset should.
mu, std = helper.find_mu_and_std(train_loader)
print(mu, std)

In [None]:
def _value_stats(arr):
    """Return (min, max, std, mean) of a numpy array, in that order."""
    return arr.min(), arr.max(), arr.std(), arr.mean()


# First sample of the training subset. A Subset is map-style, so index 0 is
# exactly the sample that next(iter(train_ds)) yields.
image, mask = train_ds[0]

# Image: tensor type/shape, value statistics, then the shape with size-1 axes dropped.
print(type(image), image.shape)
image = image.numpy()
print(*_value_stats(image))
image = np.squeeze(image)
print(image.shape)
print('-----')

# Mask: tensor type/shape, numpy shape, then statistics of the squeezed array.
print(type(mask), mask.shape)
mask = mask.numpy()
print(mask.shape)
mask = np.squeeze(mask)
print(*_value_stats(mask))

helper.preview_images((image, mask))

# TRAINING

In [None]:
# Build the U-Net; arguments are presumably (in_channels, n_classes) — confirm in
# unet_model. nn.Module.to moves parameters and returns the module itself.
model = UNet(NUM_CHANNELS, NUM_CLASSES)
model = model.to(device)

In [None]:
class FocalLoss(nn.modules.loss._WeightedLoss):
    """Multi-class focal loss (Lin et al., 2017) with optional per-class alpha weights.

    For every element, with p_t the softmax probability of the true class:

        FL = alpha[target] * (1 - p_t) ** gamma * (-log p_t)

    where alpha is ``weight`` (all ones when ``weight`` is None). With ``gamma=0``
    this equals ``F.cross_entropy(input, target, weight=weight, reduction=reduction)``.

    Args:
        weight: optional 1-D tensor of per-class weights, indexed by class id.
            Acts as the alpha term that balances class frequencies.
        gamma: focusing parameter; larger values down-weight well-classified elements.
        reduction: 'none' (per-element losses, shaped like ``target``), 'sum', or
            'mean'. With weights, 'mean' divides by the summed alpha of the targets,
            as nn.CrossEntropyLoss does.

    Shapes:
        input: (N, C, ...) raw logits, with the class axis at dim 1.
        target: (N, ...) integer class ids.
    """

    def __init__(self, weight=None, gamma=2, reduction='mean'):
        # _WeightedLoss registers `weight` as a buffer, so it follows .to(device).
        super(FocalLoss, self).__init__(weight, reduction=reduction)
        self.gamma = gamma

    def forward(self, input, target):
        if self.reduction not in ('none', 'mean', 'sum'):
            raise ValueError(f"{self.reduction} is not a valid value for reduction")

        # Unweighted, unreduced -log p_t for every element. p_t must come from the
        # *unweighted* CE; folding alpha in here would distort the focal term.
        log_probs = F.log_softmax(input, dim=1)
        ce_loss = F.nll_loss(log_probs, target, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss

        alpha = None
        if self.weight is not None:
            alpha = self.weight[target]  # per-element class weight
            focal_loss = alpha * focal_loss

        if self.reduction == 'none':
            return focal_loss
        if self.reduction == 'sum':
            return focal_loss.sum()
        if alpha is not None:
            return focal_loss.sum() / alpha.sum()
        return focal_loss.mean()

In [None]:
def train_overfit_single_batch(train_loader, model, loss_fn, opt, n_epochs=N_EPOCHS, lr=LR, patience=PATIENCE):
    """Sanity check: repeatedly optimise on one fixed batch and print the loss.

    If the model, loss and optimiser are wired correctly, the printed loss is
    expected to fall steadily towards zero.

    Args:
        train_loader: iterable of (inputs, targets); only its first batch is used.
        model: module to optimise (updated in place).
        loss_fn: callable ``loss_fn(preds, targets)`` returning a scalar tensor.
        opt: optimiser over ``model``'s parameters.
        n_epochs: number of optimisation steps on the batch.
        lr, patience: unused; kept so the signature mirrors ``train``.

    Returns:
        list of float losses, one per step, in order.
    """
    x, y = next(iter(train_loader))
    # The batch never changes, so move it to the device and set train mode once.
    x = x.to(device)
    targets = y.to(device)
    model.train()

    losses = []
    for epoch in tqdm(range(n_epochs)):
        preds = model(x)
        loss = loss_fn(preds, targets)

        loss_value = loss.item()
        losses.append(loss_value)
        print(f"epoch: {epoch}, loss: {loss_value}")

        opt.zero_grad()
        loss.backward()
        opt.step()
    return losses

In [None]:
def _run_train_epoch(loader, model, loss_fn, opt):
    """Run one optimisation pass over `loader`.

    Returns:
        (mean batch loss, number of optimiser steps taken). The mean is 0.0 for an
        empty loader instead of raising ZeroDivisionError.
    """
    model.train()
    running_loss = 0.0
    n_steps = 0
    for x, y in loader:
        x = x.to(device)
        targets = y.to(device)

        preds = model(x)
        loss = loss_fn(preds, targets)

        opt.zero_grad()
        loss.backward()
        opt.step()

        running_loss += loss.item()
        n_steps += 1
    return (running_loss / n_steps if n_steps else 0.0), n_steps


def _evaluate(loader, model, loss_fn):
    """Return the mean batch loss of `model` over `loader`, or None if `loader` is empty.

    Runs in eval mode without building autograd graphs.
    """
    model.eval()
    running_loss = 0.0
    n_batches = 0
    with torch.no_grad():
        for x, y in loader:
            preds = model(x.to(device))
            running_loss += loss_fn(preds, y.to(device)).item()
            n_batches += 1
    return running_loss / n_batches if n_batches else None


def train(train_loader, val_loader, model, loss_fn, opt, n_epochs=N_EPOCHS, lr=LR, batch_size=BATCH_SIZE, patience=PATIENCE, fn=SUMMARY_WRITER_NAME, counter=0, val_every=10):
    """Train `model`, validate periodically, checkpoint the best weights and log to TensorBoard.

    Args:
        train_loader, val_loader: iterables of (inputs, targets) batches.
        model: module to train (updated in place).
        loss_fn: callable ``loss_fn(preds, targets)`` returning a scalar tensor.
        opt: optimiser over ``model``'s parameters.
        n_epochs: maximum number of passes over ``train_loader``.
        lr, batch_size: unused; kept for backward compatibility (the optimiser
            and the loaders already carry these).
        patience: number of *consecutive* validation rounds (not epochs) without
            improvement before stopping early. With val_every=10 and patience=50,
            stopping needs 500 epochs without improvement.
        fn: run name; logs go to ./runs/{fn} and the best weights to ./models/{fn}.
        counter: initial number of non-improving validation rounds.
        val_every: validate every this many epochs (epoch 0 included).

    Returns:
        The best validation loss seen (inf if validation never ran). If val_loader
        is empty, validation, checkpointing and early stopping are skipped.
    """
    # torch.save does not create missing directories.
    os.makedirs("./models", exist_ok=True)
    checkpoint_path = f"./models/{fn}"
    writer = SummaryWriter(f"./runs/{fn}")

    best_val_loss = float("inf")
    total_steps = 0  # optimiser steps so far; TensorBoard x-axis for both curves
    try:
        for epoch in tqdm(range(n_epochs)):
            avg_train_loss, n_steps = _run_train_epoch(train_loader, model, loss_fn, opt)
            total_steps += n_steps
            writer.add_scalar('train_loss', avg_train_loss, total_steps)
            print(f"epoch {epoch} loss {avg_train_loss}")

            if epoch % val_every != 0:
                continue
            avg_val_loss = _evaluate(val_loader, model, loss_fn)
            if avg_val_loss is None:
                continue
            writer.add_scalar('val_loss', avg_val_loss, total_steps)
            print(f"epoch {epoch}, val_loss: {avg_val_loss}")

            if avg_val_loss <= best_val_loss:
                best_val_loss = avg_val_loss
                counter = 0  # an improvement restarts the patience window
                torch.save(model.state_dict(), checkpoint_path)
            else:
                counter += 1
                if counter >= patience:
                    print(f"early stopping at epoch {epoch}")
                    break
    finally:
        # Flush TensorBoard events even if training is interrupted.
        writer.close()
    return best_val_loss

In [None]:
if TRAIN:
    # Alternative that was tried: loss_fn = nn.CrossEntropyLoss(weight=class_weights)
    loss_fn = FocalLoss(weight=class_weights)

    # Overfit-one-batch sanity check on a throwaway copy. Running it on `model`
    # itself would pre-train the real weights and Adam state on a single batch
    # before the actual training run.
    N_SANITY_EPOCHS = 20
    probe_model = copy.deepcopy(model)
    probe_opt = optim.Adam(probe_model.parameters(), lr=LR)
    train_overfit_single_batch(train_loader, probe_model, loss_fn, probe_opt, n_epochs=N_SANITY_EPOCHS, lr=LR, patience=PATIENCE)
    del probe_model, probe_opt  # release the copy's (possibly GPU) memory

    opt = optim.Adam(model.parameters(), lr=LR)
    train(train_loader, test_loader, model, loss_fn, opt)