In [1]:
import cv2
import torch
import torch.nn as nn
import numpy as np
import segmentation_models_pytorch as smp
import os
import glob
import matplotlib.pyplot as plt
import random
import time
from torch.utils.data import DataLoader
from sklearn import metrics
from sklearn.metrics import confusion_matrix
import datetime

os.environ['KMP_DUPLICATE_LIB_OK']='True'

In [2]:
train_folder_imgs = np.array(glob.glob(os.path.join("p2psmb128", "images", "*")))
train_folder_labs = np.array(glob.glob(os.path.join("p2psmb128", "labels", "*")))

print(len(train_folder_imgs))
np.random.seed(1)
val_idx = np.random.choice(len(train_folder_imgs), int(2*len(train_folder_imgs)/10), replace=False).astype(int)
val_paths = []
train_paths = []
for i in range(len(train_folder_imgs)):
    if i in val_idx:
        val_paths.append([train_folder_imgs[i], train_folder_labs[i]])
    else:
        train_paths.append([train_folder_imgs[i], train_folder_labs[i]])

1536


In [3]:
len(val_paths), len(train_paths)

(307, 1229)

In [4]:
def WoundTransformation(im, mask, p):

    # Horizontal flip
    if np.random.rand() < p:
        im = cv2.flip(im,1)
        mask = cv2.flip(mask,1)
    
    # Vertical flip
    if np.random.rand() < p:
        im = cv2.flip(im,0)
        mask = cv2.flip(mask,0)
        
    # Gaussian noise
    if np.random.rand() < p: # Add Gaussian noise
        stdv = np.random.uniform(3, 12)
        noise = np.random.normal(0, stdv, im.shape)
        im = np.uint8(np.round(np.clip(im + noise,0,255)))

    return im, mask

class WoundData(torch.utils.data.Dataset):
    
    def __init__(self, data, transform):
        self.data = data
        self.transform = transform
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        # Read image and mask
        im = cv2.imread(self.data[idx][0],-1)
        mask = cv2.imread(self.data[idx][1],0)
        
        if self.transform:
            im, mask = WoundTransformation(im, mask, 0.5)
        
        # From np.array (HxWxC) to torch.tensor (CxHxW). From [0,255] to [0,1]
        im = torch.from_numpy(np.float32(im/255).transpose(2,0,1))
        mask = torch.from_numpy(np.float32(mask/255)).unsqueeze(0)
        
        return im, mask

In [5]:
train_ds = WoundData(train_paths, True)
val_ds = WoundData(val_paths, False)

train_dl = torch.utils.data.DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=0)
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=0)

In [6]:
len(train_ds), len(val_ds)

(1229, 307)

In [7]:
# create segmentation model with pretrained encoder

ENCODER = 'timm-regnetx_006'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['vein']
DEVICE = 'cuda'

model = smp.Unet(
    encoder_name=ENCODER, 
    encoder_weights=ENCODER_WEIGHTS, 
    classes=len(CLASSES), 
    in_channels=3,
)

model.to(DEVICE);

In [8]:
#PyTorch
ALPHA = 0.4
BETA = 0.6
GAMMA = 2

class FocalTverskyLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(FocalTverskyLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1, alpha=ALPHA, beta=BETA, gamma=GAMMA):
        
        #comment out if your model contains a sigmoid or equivalent activation layer
        inputs = torch.sigmoid(inputs)       
        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
        #True Positives, False Positives & False Negatives
        TP = (inputs * targets).sum()    
        FP = ((1-targets) * inputs).sum()
        FN = (targets * (1-inputs)).sum()
        
        Tversky = (TP + smooth) / (TP + alpha*FP + beta*FN + smooth)  
        FocalTversky = (1 - Tversky)**gamma
                       
        return FocalTversky

In [9]:
##Loss
loss_func = FocalTverskyLoss()

# Optimizer
opt = torch.optim.Adam(model.parameters(), lr=4e-4)

# Learning rate schedule
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt,mode='min',factor=0.1,patience=10,verbose=1)

# freeze encoder
#for param in model.encoder.parameters():
#    param.requires_grad = False

In [10]:
# Training loop
n_epochs = 100
early_stop = 10
#unfreeze_epochs = 3

train_losses = []
val_losses = []

val_loss_min = np.Inf
stagnant = 0

folder_name = "trainrun_" + datetime.datetime.now().strftime("%d%m%Y-%H%M%S")
os.mkdir(folder_name)

for epoch in range(1, n_epochs+1):
    
    #if epoch == unfreeze_epochs:
    #    for param in model.encoder.parameters():
    #        param.requires_grad = True        
    
    e_time = time.time()
    # Get value of the current learning rate
    current_lr = opt.param_groups[0]['lr']
    
    # keep track of training and validation loss
    train_loss = 0.0
    val_loss = 0.0

    # Train the model
    model.train()
    for bid, (xb, yb) in enumerate(train_dl):
        if bid % 10 == 0:
            print(epoch, bid)
        xb = xb.to(DEVICE)
        yb = yb.to(DEVICE)

        # forward pass: compute predicted outputs by passing input to the model
        output = model(xb)

        # calculate the batch losses
        loss = loss_func(output, yb)

        # clear the gradients of all optimized variables
        opt.zero_grad()
        # backward pass: compute gradient of the loss with respect to model parameters
        loss.backward()
        # perform a single optimization step (parameter update)
        opt.step()

        # Update train loss
        train_loss += loss.item()

    
    # Validate the model
    model.eval() # Activate dropout and BatchNorm in eval mode
    with torch.no_grad(): # Save memory bc gradients are not calculated
        for xb, yb in val_dl:
            xb = xb.to(DEVICE) #(n,1,120,120)
            yb = yb.to(DEVICE) #(n,1,120,120)

            # forward pass: compute predicted outputs by passing input to the model
            output = model(xb) #(n,1,120,120)

            # calculate the batch losses
            loss = loss_func(output, yb)

            # Update validation loss
            val_loss += loss.item()
    
    # Calculate average losses of the epoch
    train_loss /= len(train_ds)
    val_loss /= len(val_ds)
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    
    # Store best model
    if val_loss < val_loss_min:
        print(f'Validation loss decreased ({val_loss_min:.6} --> {val_loss:.6}). Saving model ...')

        torch.save(model.state_dict(), os.path.join(folder_name, 'ganwoundmodel.pt'))
        val_loss_min = val_loss
        stagnant = 0
    else:
        stagnant += 1

    # learning rate schedule
    lr_scheduler.step(val_loss)
    
    print(f"Epoch {epoch}/{n_epochs}, lr = {current_lr:.2e}, "
    f"train loss: {train_loss:.6}, val loss: {val_loss:.6}, ")
    print("time taken for epoch: ", time.time() - e_time)
    print("trained parameters: ", sum(p.numel() for p in model.parameters() if p.requires_grad))
    print("-"*10)
    
    if stagnant >= early_stop:
        print('val loss stagnant for too long, stopping training')
        break

1 0
1 10
1 20
1 30
1 40
1 50
1 60
1 70
Validation loss decreased (inf --> 0.00516384). Saving model ...
Epoch 1/100, lr = 4.00e-04, train loss: 0.010675, val loss: 0.00516384, 
time taken for epoch:  13.369653940200806
trained parameters:  8763441
----------
2 0
2 10
2 20
2 30
2 40
2 50
2 60
2 70
Validation loss decreased (0.00516384 --> 0.0032792). Saving model ...
Epoch 2/100, lr = 4.00e-04, train loss: 0.00451164, val loss: 0.0032792, 
time taken for epoch:  11.019538879394531
trained parameters:  8763441
----------
3 0
3 10
3 20
3 30
3 40
3 50
3 60
3 70
Validation loss decreased (0.0032792 --> 0.00187524). Saving model ...
Epoch 3/100, lr = 4.00e-04, train loss: 0.00268874, val loss: 0.00187524, 
time taken for epoch:  9.684211015701294
trained parameters:  8763441
----------
4 0
4 10
4 20
4 30
4 40
4 50
4 60
4 70
Validation loss decreased (0.00187524 --> 0.00128423). Saving model ...
Epoch 4/100, lr = 4.00e-04, train loss: 0.00173319, val loss: 0.00128423, 
time taken for epoch:  

35 10
35 20
35 30
35 40
35 50
35 60
35 70
Epoch 35/100, lr = 4.00e-04, train loss: 0.000161352, val loss: 0.000425049, 
time taken for epoch:  9.615506172180176
trained parameters:  8763441
----------
36 0
36 10
36 20
36 30
36 40
36 50
36 60
36 70
Epoch 36/100, lr = 4.00e-04, train loss: 0.000162178, val loss: 0.000462768, 
time taken for epoch:  9.308945655822754
trained parameters:  8763441
----------
37 0
37 10
37 20
37 30
37 40
37 50
37 60
37 70
Epoch 37/100, lr = 4.00e-04, train loss: 0.000149256, val loss: 0.00043116, 
time taken for epoch:  9.279958486557007
trained parameters:  8763441
----------
38 0
38 10
38 20
38 30
38 40
38 50
38 60
38 70
Epoch 38/100, lr = 4.00e-04, train loss: 0.00017493, val loss: 0.000463633, 
time taken for epoch:  8.83243441581726
trained parameters:  8763441
----------
39 0
39 10
39 20
39 30
39 40
39 50
39 60
39 70
Epoch 39/100, lr = 4.00e-04, train loss: 0.000156399, val loss: 0.000435705, 
time taken for epoch:  9.123726606369019
trained parameters