# Train U-Net on multiclass problems

This is very similar code to `11_unet_carvana.ipynb` but we have 2 classes instead of 1.

With approximately 400 images in the training set, I trained for 100 epochs and get very good results for face tracking.


In [1]:
import torch
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
from torch import optim
import torch.nn as nn
from datetime import datetime
import albumentations as A
import cv2

from unetTracker.trackingProject import TrackingProject
from unetTracker.dataset import UNetDataset
from unetTracker.unet import Unet

In [2]:
project = TrackingProject(name="faceTrack",root_folder = "/home/kevin/Documents/trackingProjects/")

Project directory: /home/kevin/Documents/trackingProjects/faceTrack
Loading /home/kevin/Documents/trackingProjects/faceTrack/config.yalm
{'augmentation_HorizontalFlipProb': 0.0, 'augmentation_RandomBrightnessContrastProb': 0.2, 'augmentation_RandomSizedCropProb': 1.0, 'augmentation_RotateProb': 0.3, 'image_size': [480, 640], 'labeling_ImageEnlargeFactor': 2.0, 'name': 'faceTrack', 'normalization_values': {'means': [0.5110162496566772, 0.4608974754810333, 0.4772901237010956], 'stds': [0.2727729380130768, 0.2578601539134979, 0.256255567073822]}, 'object_colors': [(0.0, 0.0, 255.0), (255.0, 0.0, 0.0), (255.0, 255.0, 0.0), (128.0, 0.0, 128.0)], 'objects': ['nose', 'chin', 'rEye', 'lEye'], 'target_radius': 10}


## Hyperparameters

In [3]:
LEARNING_RATE=1e-4
DEVICE = (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")) 
BATCH_SIZE=4
NUM_EPOCHS = 10
NUM_WORKERS = 4
OUTPUT_CHANNELS = len(project.object_list)
IMAGE_HEIGHT = project.image_size[0]
IMAGE_WIDTH = project.image_size[1]
PIN_MEMORY = True
LOAD_MODEL = True
TRAIN_IMG_DIR ="/home/kevin/Documents/trackingProjects/faceTrack/dataset/train_images"
TRAIN_MASK_DIR ="/home/kevin/Documents/trackingProjects/faceTrack/dataset/train_masks" 
TRAIN_COORDINATE_DIR ="/home/kevin/Documents/trackingProjects/faceTrack/dataset/train_coordinates" 
VAL_IMG_DIR = "/home/kevin/Documents/trackingProjects/faceTrack/dataset/val_images"
VAL_MASK_DIR ="/home/kevin/Documents/trackingProjects/faceTrack/dataset/val_masks"
VAL_COORDINATE_DIR ="/home/kevin/Documents/trackingProjects/faceTrack/dataset/val_coordinates"

## Model, loss, and optimizer

In [4]:
model = Unet(in_channels=3, out_channels=OUTPUT_CHANNELS).to(DEVICE)
if LOAD_MODEL:
    project.load_model(model)

loss_fn = nn.BCEWithLogitsLoss() # not doing sigmoid on the output of the model, so use this, if we had more classes (objects) we would use change out_chan and cross_entropy_loss as loss_fn
optimizer= optim.Adam(model.parameters(),lr=LEARNING_RATE)
scaler = torch.cuda.amp.GradScaler()

## Data augmentation and normalization

I am using the [albumentations](https://albumentations.ai/) package to do data augmentation.

We might want to do some data augmentation when training so that the images are modified slightly between epochs. This improves generalization of the model and prevent overfitting.

We also want to perform data normalization so that the mean of each channel is 0 and the std is 1. This facilitate learning. See the notebook on data normalization.

Here I am using 4 transformations. We can set the probability that this transformation is applied using the `p` argument. You can set it in the project configuration file. Alternatively, you can edit the code below.

Tips

* If you are tracking left/right body parts, you probably don't want to flip your images.


In [5]:
original_height = project.image_size[0]
original_width = project.image_size[1]
means = project.normalization_values["means"]
stds = project.normalization_values["stds"]


trainTransform = A.Compose([   
                    A.RandomSizedCrop(min_max_height=(original_height-50, original_height),w2h_ratio=original_width/original_height,height=original_height, width=original_width, p=project.augmentation_RandomSizedCropProb),
                    A.HorizontalFlip(p=project.augmentation_HorizontalFlipProb),
                    A.Rotate (limit=30,border_mode=cv2.BORDER_CONSTANT,p=project.augmentation_RotateProb),
                    A.RandomBrightnessContrast(p=project.augmentation_RandomBrightnessContrastProb),
                    A.Normalize(mean=means, std=stds)
])

valTransform = A.Compose([   
                    A.Normalize(mean=means, std=stds)
])


print(trainTransform)
print(valTransform)

Compose([
  RandomSizedCrop(always_apply=False, p=1.0, min_max_height=(430, 480), height=480, width=640, w2h_ratio=1.3333333333333333, interpolation=1),
  HorizontalFlip(always_apply=False, p=0.0),
  Rotate(always_apply=False, p=0.3, limit=(-30, 30), interpolation=1, border_mode=0, value=None, mask_value=None, rotate_method='largest_box', crop_border=False),
  RandomBrightnessContrast(always_apply=False, p=0.2, brightness_limit=(-0.2, 0.2), contrast_limit=(-0.2, 0.2), brightness_by_max=True),
  Normalize(always_apply=False, p=1.0, mean=[0.5110162496566772, 0.4608974754810333, 0.4772901237010956], std=[0.2727729380130768, 0.2578601539134979, 0.256255567073822], max_pixel_value=255.0),
], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={})
Compose([
  Normalize(always_apply=False, p=1.0, mean=[0.5110162496566772, 0.4608974754810333, 0.4772901237010956], std=[0.2727729380130768, 0.2578601539134979, 0.256255567073822], max_pixel_value=255.0),
], p=1.0, bbox_params=None, k

## Datasets and DataLoaders

In [6]:
trainDataset = UNetDataset(TRAIN_IMG_DIR, TRAIN_MASK_DIR,TRAIN_COORDINATE_DIR, transform=trainTransform)
valDataset = UNetDataset(VAL_IMG_DIR, VAL_MASK_DIR,VAL_COORDINATE_DIR, transform=valTransform)
trainLoader = DataLoader(trainDataset,shuffle=True,batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,pin_memory=PIN_MEMORY)
valLoader = DataLoader(valDataset,shuffle=False,batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,pin_memory = PIN_MEMORY)

In [7]:
BATCH_SIZE=2
trainLoader = DataLoader(trainDataset,
                          shuffle=True,
                          batch_size=BATCH_SIZE,
                          num_workers=4)
valLoader = DataLoader(valDataset,
                          shuffle=False,
                          batch_size=BATCH_SIZE,
                          num_workers=4)

In [8]:
imgs, masks, _ = next(iter(trainLoader))
imgs.shape, masks.shape

(torch.Size([2, 3, 480, 640]), torch.Size([2, 4, 480, 640]))

There is a lot of black because half of our pixels are below 0, on average.


# Save and load checkpoint

In [9]:
def save_checkpoint(state, filename = "my_checkpoint.pth.tar"):
    #print("Saving checkpoint")
    torch.save(state,filename)

## Check accuracy

In [10]:
def check_accuracy(loader,model,device="cuda"):
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    model.eval()
    with torch.no_grad():
        for x,y, _ in loader:
            x = x.to(device)
            y = y.to(device)
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()
            num_correct += (preds == y).sum()
            num_pixels += torch.numel(preds)
            dice_score += (2*(preds * y).sum() / ((preds+y).sum() + 1e-8)) # work only for binary
            
            # proportion of the mask detected
            num_mask = y.sum()
            num_mask_detected = preds[y==1.0].sum()
            num_detected = preds.sum()
                        
    print(f"Accuracy: {num_correct/num_pixels*100:.2f}")
    print(f"Dice score: {dice_score/len(loader):.2f}")
    print(f"Mask pixels detected: {num_mask_detected/num_mask*100:.2f}%")
    print(f"False positives: {(num_detected-num_mask_detected)/num_detected*100:.2f}%")
    model.train()     

## Training loop

In [11]:
def train_fn(loader,model,optimizer,loss_fn,scaler,epoch,total_epochs):
    """
    One epoch of training
    """
    loop = tqdm(loader)
    for batch_idx, (data,targets,_) in enumerate(loop):
        data = data.to(device=DEVICE)
        targets = targets.to(device=DEVICE)
        
        # forward
        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions,targets)
            
        
        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        # update tqdm loop
        loop.set_postfix_str("loss: {:.7f}, epoch: {:d}/{:d}".format(loss.item(),epoch,total_epochs))


In [None]:
startTime = datetime.now()
print("Starting time:",startTime)
for epoch in range(NUM_EPOCHS):
    
    train_fn(trainLoader,model,optimizer,loss_fn,scaler,epoch,NUM_EPOCHS)
    
    if epoch % 5 == 0 :
        # save model
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict()}
        save_checkpoint(checkpoint,filename=os.path.join(project.models_dir,"my_checkpoint.pth.tar"))

        # check accuracy
        check_accuracy(valLoader,model,device=DEVICE)

endTime=datetime.now()
print("End time:",endTime)
print("{} epochs, duration:".format(NUM_EPOCHS), endTime-startTime)

Starting time: 2022-11-26 14:19:31.146527


100%|███████████| 232/232 [00:30<00:00,  7.60it/s, loss: 0.0005777, epoch: 0/10]


Accuracy: 99.98
Dice score: 0.89
Mask pixels detected: 92.17%
False positives: 11.04%


100%|███████████| 232/232 [00:29<00:00,  7.78it/s, loss: 0.0007374, epoch: 1/10]
100%|███████████| 232/232 [00:29<00:00,  7.80it/s, loss: 0.0006644, epoch: 2/10]
 34%|████        | 78/232 [00:09<00:20,  7.66it/s, loss: 0.0004306, epoch: 3/10]

In [None]:
project.save_model(model)