# UNET

In [41]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [42]:
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from model import UNET
from utils import (
    check_env,
    load_checkpoint,
    save_checkpoint,
    get_loaders,
    evaluate,
    check_dice_score,
    save_predictions_to_tensorboard,
    save_results
)

In [43]:
#os.environ["CUDA_VISIBLE_DEVICES"] = "0"
check_env()

CUDA: True - PyTorch: 1.8.1+cu111


## Hyperparameters etc

In [48]:
# Hyperparameters etc.
LEARNING_RATE = 1e-5
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 4
NUM_EPOCHS = 100
NUM_WORKERS = 2
IMAGE_HEIGHT = 512  # 1280 originally
IMAGE_WIDTH = 512  # 1918 originally
PIN_MEMORY = True
POS_WEIGHTS = torch.tensor([100])
DATASET = "mobility"
TRAIN_IMG_DIR = f"data/{DATASET}/train_images/"
TRAIN_MASK_DIR = f"data/{DATASET}/train_masks/"
VAL_IMG_DIR = f"data/{DATASET}/val_images/"
VAL_MASK_DIR = f"data/{DATASET}/val_masks/"

## Augmentation

In [49]:
train_transform = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)

val_transforms = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)

## Entraînement

In [50]:
# Chargement d'un modèle pré-entraîné
LOAD_MODEL = True
CHECKPOINT_FILE = "MODEL_400_BEST.pth"

In [51]:
# Setup Tensorboard
writer = SummaryWriter(comment=f'LR_{LEARNING_RATE}_BS_{BATCH_SIZE}_SIZE_{IMAGE_WIDTH}_{IMAGE_HEIGHT}')

In [52]:
# Instanciation du modèle
model = UNET(in_channels=3, out_channels=1).to(DEVICE)
torch.backends.cudnn.benchmark = True # faster convolutions, but more memory

In [53]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
    loss_acc = 0
    loop = tqdm(loader)
    
    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=DEVICE)
        targets = targets.float().unsqueeze(1).to(device=DEVICE)

        # forward
        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    
        loss_acc = loss_acc + loss.item()
    
        # update tqdm loop
        loop.set_postfix(loss=loss.item())
        
    return loss_acc / len(loader)

In [54]:
loss_fn = nn.BCEWithLogitsLoss(pos_weight=POS_WEIGHTS.to(DEVICE))
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

train_loader, val_loader = get_loaders(
    TRAIN_IMG_DIR,
    TRAIN_MASK_DIR,
    VAL_IMG_DIR,
    VAL_MASK_DIR,
    BATCH_SIZE,
    train_transform,
    val_transforms,
    NUM_WORKERS,
    PIN_MEMORY,
)

if LOAD_MODEL:
    load_checkpoint(CHECKPOINT_FILE, model)
    dice_score = check_dice_score(val_loader, model, device=DEVICE)
    print(f'Pre-trained model dice score: {dice_score}')

scaler = torch.cuda.amp.GradScaler()

best_loss = 10000

for epoch in range(NUM_EPOCHS):
    
    # train model
    loss_train = train_fn(train_loader, model, optimizer, loss_fn, scaler)
    writer.add_scalar('loss/train', loss_train, epoch)
    
    # evaluate
    loss_val = evaluate(train_loader, model, loss_fn)
    writer.add_scalar('loss/val', loss_val, epoch)
    
    # save results
    save_results(val_loader, "val", model, epoch, writer, device=DEVICE)
    save_results(train_loader, "train", model, epoch, writer, device=DEVICE)
   
    # save model
    if loss_val < best_loss:
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer":optimizer.state_dict(),
        }
        save_checkpoint(checkpoint, epoch, verbose=1)
        best_loss = loss_val

Pre-trained model dice score: 0.8509801030158997


  0%|          | 0/28 [00:00<?, ?it/s]

Best checkpoint 1 saved !


  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

Best checkpoint 3 saved !


  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

Best checkpoint 7 saved !


  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

Best checkpoint 12 saved !


  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

Best checkpoint 29 saved !


  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

Best checkpoint 34 saved !


  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/28 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Fermeture du writer pour TensorBoard
writer.close()

--
## TensorBoard

Depuis un terminal, démarrez Tensorboard avec la commande:<br/><br/>
`tensorboard --host 0.0.0.0 --logdir=runs`<br/><br/>
**Pensez à ouvrir le port 6006 (Firewall Rules)**