# Train U-Net

This notebook will guide you through the steps to train your unet-tracker model on Google Colab.

With approximately 200-400 images in the training set, I trained for 200 epochs and got reasonable results for face tracking.


## Copy your unetTracker project to Google Drive

The first step is to copy your unetTracker project to your Google Drive and keep track of where you saved it.

You also need to copy the Jupyter notebooks to Google Drive. You could save them in the unetTracker project directory, in a subdirectory called `notebooks`.

## GPU access

You will need a runtime with GPU access. Click on Runtime/Change runtime type and select GPU as the hardware accelerator.

## Install unet-tracker

We need to install the unet-tracker Python package into your Colab workspace.

In [1]:
!pip install albumentations==1.3.0
!git clone https://github.com/kevin-allen/unetTracker
!pip install -r unetTracker/requirements.txt
!pip install -e unetTracker

Collecting albumentations==1.3.0
  Downloading albumentations-1.3.0-py3-none-any.whl (123 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 123.5/123.5 kB 2.6 MB/s eta 0:00:00
Installing collected packages: albumentations
  Attempting uninstall: albumentations
    Found existing installation: albumentations 1.3.1
    Uninstalling albumentations-1.3.1:
      Successfully uninstalled albumentations-1.3.1
Successfully installed albumentations-1.3.0
Cloning into 'unetTracker'...
remote: Enumerating objects: 825, done.
remote: Counting objects: 100% (204/204), done.
remote: Compressing objects: 100% (162/162), done.
remote: Total 825 (delta 110), reused 100 (delta 41), pack-reused 621
Receiving objects: 100% (825/825), 122.30 MiB | 29.95 MiB/s, done.
Resolving deltas: 100% (500/500), done.
Collecting jupyterlab (from -r unetTracker/requirements.txt (line 2))
  Downloading jupyterlab-4.0.9-py3-none-any.whl (9.2 MB)

Obtaining file:///content/unetTracker
  Installing build dependencies ... done
  Checking if build backend supports build_editable ... done
  Getting requirements to build editable ... done
  Installing backend dependencies ... done
  Preparing editable metadata (pyproject.toml) ... done
Building wheels for collected packages: unetTracker
  Building editable for unetTracker (pyproject.toml) ... done
  Created wheel for unetTracker: filename=unetTracker-0.0.1-0.editable-py3-none-any.whl size=15897 sha256=098e2b1dcf7532e4afee301e1598922da931d91fff37da3f4dc059cfedb87d7d
  Stored in directory: /tmp/pip-ephem-wheel-cache-iqtsvgv6/wheels/62/9b/5a/0cb547490a9187d698861d98e1e803c5e64f31a9d899a8e84c
Successfully built unetTracker
Installing collected packages: unetTracker
Successfully installed unetTracker-0.0.1


You now need to restart your runtime to be able to use unet-tracker. You can press the button above.

We are now going to mount your Google Drive so that we can access the unet-tracker project directory. Running the code below should open a window in your browser, where you will need to give Google Colab permission to access Google Drive.

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import os
fn = "/content/drive/My Drive/teaching_thesis_taq/Data_science_neuroscience/master_neuroscience_2023/unetTracker/finger_tracker/config.yalm"
if os.path.exists(fn):
  print("We can access the unet-tracker project directory.")
else:
  raise IOError("Problem accessing the unet-tracker project directory.")

We can access the unet-tracker project directory.


In [3]:
import torch
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm
from torch import optim
import torch.nn as nn
from datetime import datetime
import albumentations as A
import cv2
import os
import pickle

from unetTracker.trackingProject import TrackingProject
from unetTracker.dataset import UNetDataset
from unetTracker.unet import Unet
from unetTracker.coordinatesFromSegmentationMask import CoordinatesFromSegmentationMask
from unetTracker.utils import check_accuracy

In [4]:
project = TrackingProject(name="finger_tracker",
                          root_folder = "/content/drive/My Drive/teaching_thesis_taq/Data_science_neuroscience/master_neuroscience_2023/unetTracker/")

Project directory: /content/drive/My Drive/teaching_thesis_taq/Data_science_neuroscience/master_neuroscience_2023/unetTracker/finger_tracker
Getting configuration from config file. Values from config file will be used.
Loading /content/drive/My Drive/teaching_thesis_taq/Data_science_neuroscience/master_neuroscience_2023/unetTracker/finger_tracker/config.yalm
{'augmentation_HorizontalFlipProb': 0.0, 'augmentation_RandomBrightnessContrastProb': 0.2, 'augmentation_RandomSizedCropProb': 1.0, 'augmentation_RotateProb': 0.3, 'image_extension': '.png', 'image_size': [270, 480], 'labeling_ImageEnlargeFactor': 2.0, 'name': 'finger_tracker', 'normalization_values': {'means': [0.40835028886795044, 0.4549056589603424, 0.51627117395401], 'stds': [0.23996737599372864, 0.251758873462677, 0.26929107308387756]}, 'object_colors': [(0.0, 0.0, 255.0), (255.0, 0.0, 0.0), (255.0, 255.0, 0.0), (240.0, 255.0, 255.0)], 'objects': ['f1', 'f2', 'f3', 'f4'], 'target_radius': 6, 'unet_features': [64, 128, 256, 512

We can check if torch has access to a GPU.

In [5]:
torch.cuda.is_available(), torch.cuda.get_device_name(0) if torch.cuda.is_available() else None

(True, 'Tesla T4')

## Hyperparameters

You can probably leave most of these parameters as they are.

If this is the first time you are training your model, set `LOAD_MODEL` to False. If you only want to refine an existing model quickly, set it to True.

As a starting point, you can use ~100 epochs if you have between 400 and 1000 images.
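
The number of epochs and the size of the training set together determine how long training takes. As a rough sketch, you can estimate the duration from the number of batches per epoch and the speed reported by the progress bar. The helper `estimate_training_minutes` is hypothetical, and the figures are taken from the run shown later in this notebook (42 batches per epoch, about 4.4 batches per second); the number of training images is an assumption chosen to match 42 batches.

```python
import math

def estimate_training_minutes(n_images, batch_size, n_epochs, batches_per_second):
    """Rough training time, ignoring validation and checkpoint saving (hypothetical helper)."""
    batches_per_epoch = math.ceil(n_images / batch_size)
    return batches_per_epoch * n_epochs / batches_per_second / 60

# 168 images with a batch size of 4 give 42 batches per epoch, as in the log.
# The exact number of training images is an assumption.
minutes = estimate_training_minutes(
    n_images=168, batch_size=4, n_epochs=150, batches_per_second=4.4
)
print(round(minutes, 1), "minutes")
```

The estimate is a little under 24 minutes. The recorded run took about 25 minutes, which also includes the validation checks and checkpoint saving every 5 epochs.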

In [6]:
LEARNING_RATE=1e-4
DEVICE = (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
BATCH_SIZE=4
NUM_EPOCHS = 150
NUM_WORKERS = 2
OUTPUT_CHANNELS = len(project.object_list)
IMAGE_HEIGHT = project.image_size[0]
IMAGE_WIDTH = project.image_size[1]
PIN_MEMORY = True
LOAD_MODEL = False
TRAIN_IMAGE_DIR = os.path.join(project.dataset_dir,"train_images")
TRAIN_MASK_DIR =  os.path.join(project.dataset_dir,"train_masks")
TRAIN_COORDINATE_DIR = os.path.join(project.dataset_dir,"train_coordinates")
VAL_IMAGE_DIR = os.path.join(project.dataset_dir,"val_images")
VAL_MASK_DIR =  os.path.join(project.dataset_dir,"val_masks")
VAL_COORDINATE_DIR = os.path.join(project.dataset_dir,"val_coordinates")

## Model, loss, and optimizer

In [7]:
model = Unet(in_channels=3, out_channels=OUTPUT_CHANNELS).to(DEVICE)
if LOAD_MODEL:
    project.load_model(model)

# set the model in train mode
model.train()

# The model outputs raw logits (no sigmoid), so we use BCEWithLogitsLoss, which applies the sigmoid internally.
# Each output channel is an independent binary mask, one per tracked object.
loss_fn = nn.BCEWithLogitsLoss()
optimizer= optim.Adam(model.parameters(),lr=LEARNING_RATE)
scaler = torch.cuda.amp.GradScaler()

## Data augmentation and normalization



In [8]:
fileName = os.path.join(project.augmentation_dir,"trainTransform")
print("Loading trainTransform from", fileName)
with open(fileName, "rb") as f:  # close the file once the transform is loaded
    trainTransform = pickle.load(f)

fileName = os.path.join(project.augmentation_dir,"valTransform")
print("Loading valTransform from", fileName)
with open(fileName, "rb") as f:
    valTransform = pickle.load(f)

Loading trainTransform from /content/drive/My Drive/teaching_thesis_taq/Data_science_neuroscience/master_neuroscience_2023/unetTracker/finger_tracker/augmentation/trainTransform
Loading valTransform from /content/drive/My Drive/teaching_thesis_taq/Data_science_neuroscience/master_neuroscience_2023/unetTracker/finger_tracker/augmentation/valTransform


## Datasets and DataLoaders

In [9]:
trainDataset = UNetDataset(TRAIN_IMAGE_DIR, TRAIN_MASK_DIR,TRAIN_COORDINATE_DIR, transform=trainTransform)
valDataset = UNetDataset(VAL_IMAGE_DIR, VAL_MASK_DIR,VAL_COORDINATE_DIR, transform=valTransform)
trainLoader = DataLoader(trainDataset,shuffle=True,batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,pin_memory=PIN_MEMORY)
valLoader = DataLoader(valDataset,shuffle=False,batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,pin_memory = PIN_MEMORY)

In [10]:
imgs, masks, _ = next(iter(trainLoader))
imgs.shape, masks.shape

(torch.Size([4, 3, 270, 480]), torch.Size([4, 4, 270, 480]))

When these normalized images are displayed, there is a lot of black because, on average, half of the pixel values are below 0 after normalization.
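
To see why normalized images contain so many negative values, here is a minimal sketch of per-channel normalization. It assumes that pixel values are first scaled to the range 0 to 1 and then transformed as `(value - mean) / std`; the exact steps performed by the saved transforms may differ. The means and standard deviations are rounded from the `normalization_values` printed with the project configuration, and `normalize_pixel` is a hypothetical helper.

```python
# Per-channel means and standard deviations, rounded from the project's
# normalization_values printed earlier in this notebook.
MEANS = [0.408, 0.455, 0.516]
STDS = [0.240, 0.252, 0.269]

def normalize_pixel(pixel, means=MEANS, stds=STDS):
    """Normalize one pixel given as 3 channel values in [0, 1] (hypothetical helper)."""
    return [(v - m) / s for v, m, s in zip(pixel, means, stds)]

dark = normalize_pixel([0.2, 0.2, 0.2])    # darker than the channel means
bright = normalize_pixel([0.8, 0.8, 0.8])  # brighter than the channel means

print(dark)    # all values are negative
print(bright)  # all values are positive
```

Any pixel darker than the channel mean becomes negative. Plotting functions that clip float images to the range 0 to 1, as matplotlib's `imshow` does for RGB data, show all of these negative values as black.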


## Save and load checkpoint

In [11]:
def save_checkpoint(state, filename = "my_checkpoint.pth.tar"):
    #print("Saving checkpoint")
    torch.save(state,filename)

## Training loop

This is where we train our model. Every few epochs, the performance of the model will be evaluated on the validation set.

The task of the model is to learn to predict your masks. For each image in the dataset, you created a set of masks when you labeled the image. For one image, there is one mask per body part. Each mask has the same size as your image. Most pixels in the mask are set to 0, but the pixels inside the circle centered on the body part are set to 1.
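
As an illustration, the sketch below builds one such mask in plain Python. The image size (270 x 480) and the circle radius (6) come from the project configuration printed earlier; the body-part position and the helper `make_mask` are made up for this example. The masks in your dataset were created by unetTracker during labeling and may be drawn slightly differently.

```python
def make_mask(height, width, center_x, center_y, radius):
    """Return a height x width mask (list of rows): 1 inside the circle, 0 elsewhere."""
    return [
        [1 if (x - center_x) ** 2 + (y - center_y) ** 2 <= radius ** 2 else 0
         for x in range(width)]
        for y in range(height)
    ]

# Image size and radius from the project configuration;
# the body-part position (x=240, y=135) is made up for this example.
mask = make_mask(height=270, width=480, center_x=240, center_y=135, radius=6)

n_positive = sum(sum(row) for row in mask)
fraction = n_positive / (270 * 480)
print(n_positive, "positive pixels,", round(100 * fraction, 3), "% of the mask")
```

Fewer than 0.1% of the pixels are positive. This is why the accuracy reported during training is already above 99.9% when the model predicts no positive pixel at all, and why the other metrics are more informative.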

The model outputs values for each mask that should resemble the masks you created while labeling the images. The loss function measures how much the output of the model differs from your masks.
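
The loss used in this notebook is `nn.BCEWithLogitsLoss`, a binary cross-entropy computed from the raw output of the model. To give an intuition for what it measures, the sketch below computes a per-pixel binary cross-entropy in plain Python. It is not PyTorch's implementation, the function names are hypothetical, and the pixel values are made up.

```python
import math

def bce_with_logits(logit, target):
    """Binary cross-entropy for one pixel, computed from the raw model output (logit)."""
    # Numerically stable form of -[t*log(sigmoid(x)) + (1-t)*log(1-sigmoid(x))]
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))

def mean_loss(logits, targets):
    """Average the per-pixel loss over all pixels."""
    return sum(bce_with_logits(x, t) for x, t in zip(logits, targets)) / len(logits)

targets = [0, 0, 0, 1]  # one positive pixel among negative pixels

undecided = mean_loss([0.0, 0.0, 0.0, 0.0], targets)  # sigmoid(0) = 0.5 everywhere
good = mean_loss([-4.0, -4.0, -4.0, 4.0], targets)    # confident and correct
bad = mean_loss([4.0, 4.0, 4.0, -4.0], targets)       # confident and wrong

print(undecided, good, bad)
```

The loss is small when the output agrees with the mask and large when it confidently disagrees. An undecided model, with an output of 0 for every pixel, has a loss of ln(2), about 0.69.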

It usually takes 40-60 epochs before the model starts to mark the labeled positive pixels of your masks as positive. In the run shown below, this happened at around epoch 75.

During training, you should have a look at the printed output. If your model is learning, the loss will decrease over time.
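
The progress bar printed by `train_fn` shows the loss of the most recent batch only, which can fluctuate from one batch to the next. One possible way to get a smoother signal is to average the batch losses within each epoch. The `RunningMean` class below is a hypothetical helper, not part of unetTracker, and the loss values are made up. In the training function, you could call `update(loss.item())` after each batch.

```python
class RunningMean:
    """Accumulate values and report their mean (hypothetical helper)."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.total += value
        self.count += 1

    @property
    def mean(self):
        # Return NaN instead of raising an error if nothing was recorded yet.
        return self.total / self.count if self.count else float("nan")

# Made-up batch losses for two epochs, for illustration only.
epoch_losses = []
for batch_losses in ([0.60, 0.50, 0.55], [0.40, 0.45, 0.35]):
    tracker = RunningMean()
    for loss in batch_losses:
        tracker.update(loss)
    epoch_losses.append(tracker.mean)

# The model is learning if the mean loss goes down from epoch to epoch.
is_decreasing = all(b < a for a, b in zip(epoch_losses, epoch_losses[1:]))
print(epoch_losses, is_decreasing)
```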

In [12]:
def train_fn(loader,model,optimizer,loss_fn,scaler,epoch,total_epochs):
    """
    One epoch of training
    """
    loop = tqdm(loader)
    for batch_idx, (data,targets,_) in enumerate(loop):
        data = data.to(device=DEVICE)
        targets = targets.to(device=DEVICE)

        # forward
        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions,targets)


        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop
        loop.set_postfix_str("loss: {:.7f}, epoch: {:d}/{:d}".format(loss.item(),epoch,total_epochs))


In [13]:
startTime = datetime.now()
print("Starting time:",startTime)
for epoch in range(NUM_EPOCHS):

    train_fn(trainLoader,model,optimizer,loss_fn,scaler,epoch,NUM_EPOCHS)

    if epoch % 5 == 0 :
        # save model
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict()}
        save_checkpoint(checkpoint,filename=os.path.join(project.models_dir,"my_checkpoint.pth.tar"))

        # check accuracy
        print("Performance on validation set")
        check_accuracy(model,valLoader,DEVICE)

endTime=datetime.now()
print("End time:",endTime)
print("{} epochs, duration:".format(NUM_EPOCHS), endTime-startTime)

Starting time: 2023-11-22 11:42:55.710790


100%|██████████| 42/42 [00:11<00:00,  3.61it/s, loss: 0.5265929, epoch: 0/150]


Performance on validation set
Number of positive pixels predicted: 181051.0
Number of positive pixels in masks: 6431.0
Percentage of positive pixels predicted: 2.183
Percentage of positive pixels in masks: 0.078
Accuracy: 97.740
Dice score: 0.000
Mask pixels detected (True positives): 0.078%
False positives: 99.997%
Mean distance: 132.01212023162395


100%|██████████| 42/42 [00:09<00:00,  4.32it/s, loss: 0.4740380, epoch: 1/150]
100%|██████████| 42/42 [00:10<00:00,  4.14it/s, loss: 0.4363039, epoch: 2/150]
100%|██████████| 42/42 [00:09<00:00,  4.25it/s, loss: 0.4041386, epoch: 3/150]
100%|██████████| 42/42 [00:09<00:00,  4.34it/s, loss: 0.3769300, epoch: 4/150]
100%|██████████| 42/42 [00:09<00:00,  4.37it/s, loss: 0.3473141, epoch: 5/150]


Performance on validation set
Number of positive pixels predicted: 0.0
Number of positive pixels in masks: 6431.0
Percentage of positive pixels predicted: 0.000
Percentage of positive pixels in masks: 0.078
Accuracy: 99.922
Dice score: 0.000
Mask pixels detected (True positives): 0.000%
False positives: nan%
Mean distance: 110.27126238564456


100%|██████████| 42/42 [00:09<00:00,  4.42it/s, loss: 0.3220396, epoch: 6/150]
100%|██████████| 42/42 [00:09<00:00,  4.33it/s, loss: 0.3014618, epoch: 7/150]
100%|██████████| 42/42 [00:09<00:00,  4.46it/s, loss: 0.2813196, epoch: 8/150]
100%|██████████| 42/42 [00:09<00:00,  4.36it/s, loss: 0.2580007, epoch: 9/150]
100%|██████████| 42/42 [00:09<00:00,  4.37it/s, loss: 0.2459028, epoch: 10/150]


Performance on validation set
Number of positive pixels predicted: 21.0
Number of positive pixels in masks: 6431.0
Percentage of positive pixels predicted: 0.000
Percentage of positive pixels in masks: 0.078
Accuracy: 99.922
Dice score: 0.000
Mask pixels detected (True positives): 0.000%
False positives: 100.000%
Mean distance: 157.27107486691747


100%|██████████| 42/42 [00:09<00:00,  4.40it/s, loss: 0.2240371, epoch: 11/150]
100%|██████████| 42/42 [00:10<00:00,  4.16it/s, loss: 0.2099213, epoch: 12/150]
100%|██████████| 42/42 [00:09<00:00,  4.41it/s, loss: 0.1940752, epoch: 13/150]
100%|██████████| 42/42 [00:09<00:00,  4.35it/s, loss: 0.1793021, epoch: 14/150]
100%|██████████| 42/42 [00:09<00:00,  4.45it/s, loss: 0.1659796, epoch: 15/150]


Performance on validation set
Number of positive pixels predicted: 0.0
Number of positive pixels in masks: 6431.0
Percentage of positive pixels predicted: 0.000
Percentage of positive pixels in masks: 0.078
Accuracy: 99.922
Dice score: 0.000
Mask pixels detected (True positives): 0.000%
False positives: nan%
Mean distance: 116.91932582002268


100%|██████████| 42/42 [00:09<00:00,  4.50it/s, loss: 0.1725814, epoch: 16/150]
100%|██████████| 42/42 [00:09<00:00,  4.30it/s, loss: 0.1426741, epoch: 17/150]
100%|██████████| 42/42 [00:09<00:00,  4.45it/s, loss: 0.1348113, epoch: 18/150]
100%|██████████| 42/42 [00:09<00:00,  4.35it/s, loss: 0.1239873, epoch: 19/150]
100%|██████████| 42/42 [00:09<00:00,  4.42it/s, loss: 0.1171092, epoch: 20/150]


Performance on validation set


Number of positive pixels predicted: 0.0
Number of positive pixels in masks: 6431.0
Percentage of positive pixels predicted: 0.000
Percentage of positive pixels in masks: 0.078
Accuracy: 99.922
Dice score: 0.000
Mask pixels detected (True positives): 0.000%
False positives: nan%
Mean distance: nan


100%|██████████| 42/42 [00:09<00:00,  4.44it/s, loss: 0.1086498, epoch: 21/150]
100%|██████████| 42/42 [00:09<00:00,  4.35it/s, loss: 0.1021502, epoch: 22/150]
100%|██████████| 42/42 [00:09<00:00,  4.46it/s, loss: 0.0961197, epoch: 23/150]
100%|██████████| 42/42 [00:09<00:00,  4.39it/s, loss: 0.0898558, epoch: 24/150]
100%|██████████| 42/42 [00:09<00:00,  4.44it/s, loss: 0.0843235, epoch: 25/150]


Performance on validation set
Number of positive pixels predicted: 0.0
Number of positive pixels in masks: 6431.0
Percentage of positive pixels predicted: 0.000
Percentage of positive pixels in masks: 0.078
Accuracy: 99.922
Dice score: 0.000
Mask pixels detected (True positives): 0.000%
False positives: nan%
Mean distance: nan


100%|██████████| 42/42 [00:09<00:00,  4.42it/s, loss: 0.0797953, epoch: 26/150]
100%|██████████| 42/42 [00:09<00:00,  4.36it/s, loss: 0.0754770, epoch: 27/150]
100%|██████████| 42/42 [00:09<00:00,  4.39it/s, loss: 0.0708839, epoch: 28/150]
100%|██████████| 42/42 [00:09<00:00,  4.42it/s, loss: 0.0677123, epoch: 29/150]
100%|██████████| 42/42 [00:09<00:00,  4.44it/s, loss: 0.0637476, epoch: 30/150]


Performance on validation set
Number of positive pixels predicted: 0.0
Number of positive pixels in masks: 6431.0
Percentage of positive pixels predicted: 0.000
Percentage of positive pixels in masks: 0.078
Accuracy: 99.922
Dice score: 0.000
Mask pixels detected (True positives): 0.000%
False positives: nan%
Mean distance: nan


100%|██████████| 42/42 [00:09<00:00,  4.39it/s, loss: 0.0607551, epoch: 31/150]
100%|██████████| 42/42 [00:09<00:00,  4.41it/s, loss: 0.0567937, epoch: 32/150]
100%|██████████| 42/42 [00:09<00:00,  4.39it/s, loss: 0.0545856, epoch: 33/150]
100%|██████████| 42/42 [00:09<00:00,  4.43it/s, loss: 0.0517872, epoch: 34/150]
100%|██████████| 42/42 [00:10<00:00,  4.17it/s, loss: 0.0494270, epoch: 35/150]


Performance on validation set
Number of positive pixels predicted: 0.0
Number of positive pixels in masks: 6431.0
Percentage of positive pixels predicted: 0.000
Percentage of positive pixels in masks: 0.078
Accuracy: 99.922
Dice score: 0.000
Mask pixels detected (True positives): 0.000%
False positives: nan%
Mean distance: nan


100%|██████████| 42/42 [00:09<00:00,  4.35it/s, loss: 0.0469796, epoch: 36/150]
100%|██████████| 42/42 [00:09<00:00,  4.36it/s, loss: 0.0446418, epoch: 37/150]
100%|██████████| 42/42 [00:09<00:00,  4.38it/s, loss: 0.0428893, epoch: 38/150]
100%|██████████| 42/42 [00:09<00:00,  4.40it/s, loss: 0.0412544, epoch: 39/150]
100%|██████████| 42/42 [00:09<00:00,  4.41it/s, loss: 0.0390896, epoch: 40/150]


Performance on validation set
Number of positive pixels predicted: 0.0
Number of positive pixels in masks: 6431.0
Percentage of positive pixels predicted: 0.000
Percentage of positive pixels in masks: 0.078
Accuracy: 99.922
Dice score: 0.000
Mask pixels detected (True positives): 0.000%
False positives: nan%
Mean distance: 58.36565597158375


100%|██████████| 42/42 [00:09<00:00,  4.36it/s, loss: 0.0379482, epoch: 41/150]
100%|██████████| 42/42 [00:09<00:00,  4.45it/s, loss: 0.0361732, epoch: 42/150]
100%|██████████| 42/42 [00:09<00:00,  4.40it/s, loss: 0.0342847, epoch: 43/150]
100%|██████████| 42/42 [00:09<00:00,  4.44it/s, loss: 0.0331604, epoch: 44/150]
100%|██████████| 42/42 [00:09<00:00,  4.40it/s, loss: 0.0319200, epoch: 45/150]


Performance on validation set
Number of positive pixels predicted: 0.0
Number of positive pixels in masks: 6431.0
Percentage of positive pixels predicted: 0.000
Percentage of positive pixels in masks: 0.078
Accuracy: 99.922
Dice score: 0.000
Mask pixels detected (True positives): 0.000%
False positives: nan%
Mean distance: 44.50960960356789


100%|██████████| 42/42 [00:09<00:00,  4.43it/s, loss: 0.0298649, epoch: 46/150]
100%|██████████| 42/42 [00:09<00:00,  4.46it/s, loss: 0.0292846, epoch: 47/150]
100%|██████████| 42/42 [00:09<00:00,  4.32it/s, loss: 0.0282292, epoch: 48/150]
100%|██████████| 42/42 [00:09<00:00,  4.43it/s, loss: 0.0270904, epoch: 49/150]
100%|██████████| 42/42 [00:09<00:00,  4.36it/s, loss: 0.0262407, epoch: 50/150]


Performance on validation set
Number of positive pixels predicted: 0.0
Number of positive pixels in masks: 6431.0
Percentage of positive pixels predicted: 0.000
Percentage of positive pixels in masks: 0.078
Accuracy: 99.922
Dice score: 0.000
Mask pixels detected (True positives): 0.000%
False positives: nan%
Mean distance: 44.42736100952057


100%|██████████| 42/42 [00:09<00:00,  4.35it/s, loss: 0.0247730, epoch: 51/150]
100%|██████████| 42/42 [00:09<00:00,  4.35it/s, loss: 0.0245716, epoch: 52/150]
100%|██████████| 42/42 [00:09<00:00,  4.40it/s, loss: 0.0234710, epoch: 53/150]
100%|██████████| 42/42 [00:09<00:00,  4.46it/s, loss: 0.0226474, epoch: 54/150]
100%|██████████| 42/42 [00:09<00:00,  4.34it/s, loss: 0.0212981, epoch: 55/150]


Performance on validation set
Number of positive pixels predicted: 0.0
Number of positive pixels in masks: 6431.0
Percentage of positive pixels predicted: 0.000
Percentage of positive pixels in masks: 0.078
Accuracy: 99.922
Dice score: 0.000
Mask pixels detected (True positives): 0.000%
False positives: nan%
Mean distance: 37.784359111086545


100%|██████████| 42/42 [00:09<00:00,  4.44it/s, loss: 0.0215081, epoch: 56/150]
100%|██████████| 42/42 [00:09<00:00,  4.30it/s, loss: 0.0208197, epoch: 57/150]
100%|██████████| 42/42 [00:09<00:00,  4.41it/s, loss: 0.0201795, epoch: 58/150]
100%|██████████| 42/42 [00:09<00:00,  4.37it/s, loss: 0.0194588, epoch: 59/150]
100%|██████████| 42/42 [00:09<00:00,  4.37it/s, loss: 0.0181127, epoch: 60/150]


Performance on validation set
Number of positive pixels predicted: 0.0
Number of positive pixels in masks: 6431.0
Percentage of positive pixels predicted: 0.000
Percentage of positive pixels in masks: 0.078
Accuracy: 99.922
Dice score: 0.000
Mask pixels detected (True positives): 0.000%
False positives: nan%
Mean distance: 25.518588907410376


100%|██████████| 42/42 [00:09<00:00,  4.40it/s, loss: 0.0188483, epoch: 61/150]
100%|██████████| 42/42 [00:09<00:00,  4.23it/s, loss: 0.0172683, epoch: 62/150]
100%|██████████| 42/42 [00:09<00:00,  4.41it/s, loss: 0.0170629, epoch: 63/150]
100%|██████████| 42/42 [00:09<00:00,  4.37it/s, loss: 0.0157686, epoch: 64/150]
100%|██████████| 42/42 [00:09<00:00,  4.40it/s, loss: 0.0156963, epoch: 65/150]


Performance on validation set
Number of positive pixels predicted: 0.0
Number of positive pixels in masks: 6431.0
Percentage of positive pixels predicted: 0.000
Percentage of positive pixels in masks: 0.078
Accuracy: 99.922
Dice score: 0.000
Mask pixels detected (True positives): 0.000%
False positives: nan%
Mean distance: 23.01128029656136


100%|██████████| 42/42 [00:09<00:00,  4.47it/s, loss: 0.0148371, epoch: 66/150]
100%|██████████| 42/42 [00:10<00:00,  4.20it/s, loss: 0.0146150, epoch: 67/150]
100%|██████████| 42/42 [00:09<00:00,  4.47it/s, loss: 0.0145792, epoch: 68/150]
100%|██████████| 42/42 [00:09<00:00,  4.36it/s, loss: 0.0143634, epoch: 69/150]
100%|██████████| 42/42 [00:09<00:00,  4.42it/s, loss: 0.0129827, epoch: 70/150]


Performance on validation set
Number of positive pixels predicted: 0.0
Number of positive pixels in masks: 6431.0
Percentage of positive pixels predicted: 0.000
Percentage of positive pixels in masks: 0.078
Accuracy: 99.922
Dice score: 0.000
Mask pixels detected (True positives): 0.000%
False positives: nan%
Mean distance: 21.379839928619198


100%|██████████| 42/42 [00:09<00:00,  4.39it/s, loss: 0.0130161, epoch: 71/150]
100%|██████████| 42/42 [00:09<00:00,  4.40it/s, loss: 0.0134169, epoch: 72/150]
100%|██████████| 42/42 [00:09<00:00,  4.47it/s, loss: 0.0126736, epoch: 73/150]
100%|██████████| 42/42 [00:09<00:00,  4.38it/s, loss: 0.0122297, epoch: 74/150]
100%|██████████| 42/42 [00:09<00:00,  4.46it/s, loss: 0.0115494, epoch: 75/150]


Performance on validation set
Number of positive pixels predicted: 1073.0
Number of positive pixels in masks: 6431.0
Percentage of positive pixels predicted: 0.013
Percentage of positive pixels in masks: 0.078
Accuracy: 99.931
Dice score: 0.234
Mask pixels detected (True positives): 14.072%
False positives: 15.657%
Mean distance: 14.296537693600879


100%|██████████| 42/42 [00:09<00:00,  4.38it/s, loss: 0.0116865, epoch: 76/150]
100%|██████████| 42/42 [00:09<00:00,  4.33it/s, loss: 0.0124094, epoch: 77/150]
100%|██████████| 42/42 [00:09<00:00,  4.46it/s, loss: 0.0105907, epoch: 78/150]
100%|██████████| 42/42 [00:09<00:00,  4.39it/s, loss: 0.0100952, epoch: 79/150]
100%|██████████| 42/42 [00:09<00:00,  4.44it/s, loss: 0.0097110, epoch: 80/150]


Performance on validation set
Number of positive pixels predicted: 2802.0
Number of positive pixels in masks: 6431.0
Percentage of positive pixels predicted: 0.034
Percentage of positive pixels in masks: 0.078
Accuracy: 99.939
Dice score: 0.450
Mask pixels detected (True positives): 32.406%
False positives: 25.625%
Mean distance: 11.666134345097937


100%|██████████| 42/42 [00:09<00:00,  4.36it/s, loss: 0.0099572, epoch: 81/150]
100%|██████████| 42/42 [00:09<00:00,  4.37it/s, loss: 0.0095464, epoch: 82/150]
100%|██████████| 42/42 [00:09<00:00,  4.40it/s, loss: 0.0100266, epoch: 83/150]
100%|██████████| 42/42 [00:09<00:00,  4.40it/s, loss: 0.0095338, epoch: 84/150]
100%|██████████| 42/42 [00:09<00:00,  4.43it/s, loss: 0.0081001, epoch: 85/150]


Performance on validation set
Number of positive pixels predicted: 3633.0
Number of positive pixels in masks: 6431.0
Percentage of positive pixels predicted: 0.044
Percentage of positive pixels in masks: 0.078
Accuracy: 99.942
Dice score: 0.515
Mask pixels detected (True positives): 40.662%
False positives: 28.021%
Mean distance: 6.783632700882176


100%|██████████| 42/42 [00:09<00:00,  4.39it/s, loss: 0.0088720, epoch: 86/150]
100%|██████████| 42/42 [00:09<00:00,  4.40it/s, loss: 0.0089303, epoch: 87/150]
100%|██████████| 42/42 [00:09<00:00,  4.38it/s, loss: 0.0087263, epoch: 88/150]
100%|██████████| 42/42 [00:09<00:00,  4.42it/s, loss: 0.0082529, epoch: 89/150]
100%|██████████| 42/42 [00:09<00:00,  4.43it/s, loss: 0.0076453, epoch: 90/150]


Performance on validation set
Number of positive pixels predicted: 4040.0
Number of positive pixels in masks: 6431.0
Percentage of positive pixels predicted: 0.049
Percentage of positive pixels in masks: 0.078
Accuracy: 99.946
Dice score: 0.564
Mask pixels detected (True positives): 46.556%
False positives: 25.891%
Mean distance: 6.496012755127067


100%|██████████| 42/42 [00:09<00:00,  4.34it/s, loss: 0.0088673, epoch: 91/150]
100%|██████████| 42/42 [00:09<00:00,  4.45it/s, loss: 0.0076746, epoch: 92/150]
100%|██████████| 42/42 [00:09<00:00,  4.35it/s, loss: 0.0067772, epoch: 93/150]
100%|██████████| 42/42 [00:09<00:00,  4.43it/s, loss: 0.0068030, epoch: 94/150]
100%|██████████| 42/42 [00:09<00:00,  4.38it/s, loss: 0.0074146, epoch: 95/150]


Performance on validation set
Number of positive pixels predicted: 5368.0
Number of positive pixels in masks: 6431.0
Percentage of positive pixels predicted: 0.065
Percentage of positive pixels in masks: 0.078
Accuracy: 99.948
Dice score: 0.630
Mask pixels detected (True positives): 58.234%
False positives: 30.235%
Mean distance: 7.097593199287884


100%|██████████| 42/42 [00:09<00:00,  4.34it/s, loss: 0.0081646, epoch: 96/150]
100%|██████████| 42/42 [00:09<00:00,  4.41it/s, loss: 0.0079309, epoch: 97/150]
100%|██████████| 42/42 [00:09<00:00,  4.37it/s, loss: 0.0068000, epoch: 98/150]
100%|██████████| 42/42 [00:09<00:00,  4.43it/s, loss: 0.0081863, epoch: 99/150]
100%|██████████| 42/42 [00:09<00:00,  4.36it/s, loss: 0.0056331, epoch: 100/150]


Performance on validation set
Number of positive pixels predicted: 4887.0
Number of positive pixels in masks: 6431.0
Percentage of positive pixels predicted: 0.059
Percentage of positive pixels in masks: 0.078
Accuracy: 99.952
Dice score: 0.641
Mask pixels detected (True positives): 56.943%
False positives: 25.067%
Mean distance: 4.953919032119479


100%|██████████| 42/42 [00:09<00:00,  4.39it/s, loss: 0.0061217, epoch: 101/150]
100%|██████████| 42/42 [00:09<00:00,  4.27it/s, loss: 0.0052526, epoch: 102/150]
100%|██████████| 42/42 [00:09<00:00,  4.40it/s, loss: 0.0062576, epoch: 103/150]
100%|██████████| 42/42 [00:09<00:00,  4.37it/s, loss: 0.0076817, epoch: 104/150]
100%|██████████| 42/42 [00:09<00:00,  4.37it/s, loss: 0.0051745, epoch: 105/150]


Performance on validation set
Number of positive pixels predicted: 3714.0
Number of positive pixels in masks: 6431.0
Percentage of positive pixels predicted: 0.045
Percentage of positive pixels in masks: 0.078
Accuracy: 99.947
Dice score: 0.561
Mask pixels detected (True positives): 44.674%
False positives: 22.644%
Mean distance: 5.599696983243346


100%|██████████| 42/42 [00:09<00:00,  4.40it/s, loss: 0.0071389, epoch: 106/150]
100%|██████████| 42/42 [00:09<00:00,  4.29it/s, loss: 0.0058118, epoch: 107/150]
100%|██████████| 42/42 [00:09<00:00,  4.39it/s, loss: 0.0057276, epoch: 108/150]
100%|██████████| 42/42 [00:09<00:00,  4.40it/s, loss: 0.0051255, epoch: 109/150]
100%|██████████| 42/42 [00:09<00:00,  4.38it/s, loss: 0.0059947, epoch: 110/150]


Performance on validation set
Number of positive pixels predicted: 5018.0
Number of positive pixels in masks: 6431.0
Percentage of positive pixels predicted: 0.060
Percentage of positive pixels in masks: 0.078
Accuracy: 99.949
Dice score: 0.625
Mask pixels detected (True positives): 56.414%
False positives: 27.700%
Mean distance: 6.712045800106003


100%|██████████| 42/42 [00:09<00:00,  4.41it/s, loss: 0.0071418, epoch: 111/150]
100%|██████████| 42/42 [00:09<00:00,  4.23it/s, loss: 0.0041973, epoch: 112/150]
100%|██████████| 42/42 [00:09<00:00,  4.43it/s, loss: 0.0054228, epoch: 113/150]
100%|██████████| 42/42 [00:09<00:00,  4.35it/s, loss: 0.0053845, epoch: 114/150]
100%|██████████| 42/42 [00:09<00:00,  4.42it/s, loss: 0.0051153, epoch: 115/150]


Performance on validation set
Number of positive pixels predicted: 6010.0
Number of positive pixels in masks: 6431.0
Percentage of positive pixels predicted: 0.072
Percentage of positive pixels in masks: 0.078
Accuracy: 99.953
Dice score: 0.683
Mask pixels detected (True positives): 66.708%
False positives: 28.619%
Mean distance: 3.107967681234871


100%|██████████| 42/42 [00:09<00:00,  4.47it/s, loss: 0.0048278, epoch: 116/150]
100%|██████████| 42/42 [00:09<00:00,  4.24it/s, loss: 0.0039610, epoch: 117/150]
100%|██████████| 42/42 [00:09<00:00,  4.44it/s, loss: 0.0038859, epoch: 118/150]
100%|██████████| 42/42 [00:09<00:00,  4.31it/s, loss: 0.0041898, epoch: 119/150]
100%|██████████| 42/42 [00:09<00:00,  4.40it/s, loss: 0.0050018, epoch: 120/150]


Performance on validation set
Number of positive pixels predicted: 5621.0
Number of positive pixels in masks: 6431.0
Percentage of positive pixels predicted: 0.068
Percentage of positive pixels in masks: 0.078
Accuracy: 99.948
Dice score: 0.636
Mask pixels detected (True positives): 60.177%
False positives: 31.151%
Mean distance: 4.016812526348529


100%|██████████| 42/42 [00:09<00:00,  4.38it/s, loss: 0.0043903, epoch: 121/150]
100%|██████████| 42/42 [00:09<00:00,  4.33it/s, loss: 0.0034968, epoch: 122/150]
100%|██████████| 42/42 [00:09<00:00,  4.48it/s, loss: 0.0045519, epoch: 123/150]
100%|██████████| 42/42 [00:09<00:00,  4.40it/s, loss: 0.0046303, epoch: 124/150]
100%|██████████| 42/42 [00:09<00:00,  4.44it/s, loss: 0.0053233, epoch: 125/150]


Performance on validation set
Number of positive pixels predicted: 5982.0
Number of positive pixels in masks: 6431.0
Percentage of positive pixels predicted: 0.072
Percentage of positive pixels in masks: 0.078
Accuracy: 99.947
Dice score: 0.640
Mask pixels detected (True positives): 62.479%
False positives: 32.832%
Mean distance: 4.2592931261621345


100%|██████████| 42/42 [00:09<00:00,  4.35it/s, loss: 0.0034374, epoch: 126/150]
100%|██████████| 42/42 [00:09<00:00,  4.35it/s, loss: 0.0052954, epoch: 127/150]
100%|██████████| 42/42 [00:09<00:00,  4.41it/s, loss: 0.0044431, epoch: 128/150]
100%|██████████| 42/42 [00:09<00:00,  4.38it/s, loss: 0.0053893, epoch: 129/150]
100%|██████████| 42/42 [00:09<00:00,  4.43it/s, loss: 0.0041845, epoch: 130/150]


Performance on validation set
Number of positive pixels predicted: 5469.0
Number of positive pixels in masks: 6431.0
Percentage of positive pixels predicted: 0.066
Percentage of positive pixels in masks: 0.078
Accuracy: 99.946
Dice score: 0.616
Mask pixels detected (True positives): 57.907%
False positives: 31.907%
Mean distance: 5.414879157962554


100%|██████████| 42/42 [00:09<00:00,  4.39it/s, loss: 0.0051739, epoch: 131/150]
100%|██████████| 42/42 [00:09<00:00,  4.38it/s, loss: 0.0030495, epoch: 132/150]
100%|██████████| 42/42 [00:09<00:00,  4.36it/s, loss: 0.0031706, epoch: 133/150]
100%|██████████| 42/42 [00:09<00:00,  4.40it/s, loss: 0.0039381, epoch: 134/150]
100%|██████████| 42/42 [00:09<00:00,  4.39it/s, loss: 0.0043495, epoch: 135/150]


Performance on validation set
Number of positive pixels predicted: 5236.0
Number of positive pixels in masks: 6431.0
Percentage of positive pixels predicted: 0.063
Percentage of positive pixels in masks: 0.078
Accuracy: 99.950
Dice score: 0.641
Mask pixels detected (True positives): 58.451%
False positives: 28.209%
Mean distance: 4.669968085948394


100%|██████████| 42/42 [00:09<00:00,  4.34it/s, loss: 0.0052769, epoch: 136/150]
100%|██████████| 42/42 [00:09<00:00,  4.40it/s, loss: 0.0038736, epoch: 137/150]
100%|██████████| 42/42 [00:09<00:00,  4.32it/s, loss: 0.0051593, epoch: 138/150]
100%|██████████| 42/42 [00:09<00:00,  4.38it/s, loss: 0.0037289, epoch: 139/150]
100%|██████████| 42/42 [00:09<00:00,  4.34it/s, loss: 0.0025650, epoch: 140/150]


Performance on validation set
Number of positive pixels predicted: 6526.0
Number of positive pixels in masks: 6431.0
Percentage of positive pixels predicted: 0.079
Percentage of positive pixels in masks: 0.078
Accuracy: 99.950
Dice score: 0.670
Mask pixels detected (True positives): 68.201%
False positives: 32.792%
Mean distance: 4.6673087297221105


100%|██████████| 42/42 [00:09<00:00,  4.37it/s, loss: 0.0041481, epoch: 141/150]
100%|██████████| 42/42 [00:09<00:00,  4.42it/s, loss: 0.0035599, epoch: 142/150]
100%|██████████| 42/42 [00:09<00:00,  4.24it/s, loss: 0.0036225, epoch: 143/150]
100%|██████████| 42/42 [00:09<00:00,  4.28it/s, loss: 0.0023492, epoch: 144/150]
100%|██████████| 42/42 [00:09<00:00,  4.32it/s, loss: 0.0035684, epoch: 145/150]


Performance on validation set
Number of positive pixels predicted: 5557.0
Number of positive pixels in masks: 6431.0
Percentage of positive pixels predicted: 0.067
Percentage of positive pixels in masks: 0.078
Accuracy: 99.949
Dice score: 0.644
Mask pixels detected (True positives): 60.582%
False positives: 29.890%
Mean distance: 4.050035555009737


100%|██████████| 42/42 [00:09<00:00,  4.34it/s, loss: 0.0023686, epoch: 146/150]
100%|██████████| 42/42 [00:09<00:00,  4.42it/s, loss: 0.0022719, epoch: 147/150]
100%|██████████| 42/42 [00:09<00:00,  4.35it/s, loss: 0.0024798, epoch: 148/150]
100%|██████████| 42/42 [00:09<00:00,  4.41it/s, loss: 0.0022060, epoch: 149/150]

End time: 2023-11-22 12:08:21.539659
150 epochs, duration: 0:25:25.828869





Once your model starts to predict positive pixels, the output will contain a few values that help you track how your model is doing.

* Dice score: a common measure of the performance of image segmentation models. [External link](https://medium.com/mlearning-ai/understanding-evaluation-metrics-in-medical-image-segmentation-d289a373a3f)
* Mask pixels detected: percentage of the positive pixels in the masks that are predicted as positive by the model.
* False positives: percentage of the pixels predicted as positive by the model that are negative pixels in your masks.
* Mean distance: mean distance in pixels between the coordinates you labeled and the coordinates calculated from the model output.
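
The definitions above can be sketched in a few lines of plain Python. This is an illustration on made-up data that follows the descriptions in the list; `check_accuracy` in unetTracker may compute the values differently (for example, by combining batches in another way), so do not expect this sketch to reproduce the printed numbers exactly. The function names are hypothetical.

```python
import math

def segmentation_metrics(pred, mask):
    """Compute the metrics from flat lists of 0/1 pixel values (hypothetical helper)."""
    tp = sum(1 for p, m in zip(pred, mask) if p == 1 and m == 1)  # true positives
    n_pred = sum(pred)  # pixels predicted as positive
    n_mask = sum(mask)  # positive pixels in the mask
    nan = float("nan")
    return {
        "dice": 2 * tp / (n_pred + n_mask) if (n_pred + n_mask) else nan,
        "mask_pixels_detected_pct": 100 * tp / n_mask if n_mask else nan,
        "false_positives_pct": 100 * (n_pred - tp) / n_pred if n_pred else nan,
    }

def mean_distance(labeled, predicted):
    """Mean Euclidean distance in pixels between pairs of (x, y) coordinates."""
    distances = [math.hypot(px - lx, py - ly)
                 for (lx, ly), (px, py) in zip(labeled, predicted)]
    return sum(distances) / len(distances)

mask = [0, 0, 1, 1, 1, 1, 0, 0, 0, 0]
pred = [0, 1, 1, 1, 1, 0, 0, 0, 0, 0]
print(segmentation_metrics(pred, mask))

# With no positive prediction, the false-positive percentage is undefined (NaN).
print(segmentation_metrics([0] * 10, mask))

print(mean_distance([(10, 10), (20, 20)], [(13, 14), (20, 20)]))
```

In the second call the model predicts no positive pixel, so the false-positive percentage is a division of zero by zero. This is a likely explanation for the `nan%` values printed during the first epochs of training.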



In [14]:
project.save_model(model)

saving model state dict to /content/drive/My Drive/teaching_thesis_taq/Data_science_neuroscience/master_neuroscience_2023/unetTracker/finger_tracker/models/UNet.pt
2023-11-22 12:08:21.554515


## Shutdown kernels

If you want to continue with the next notebook, you might want to shut down the kernel associated with this notebook to ensure that the GPU memory is free for other notebooks.