In [11]:
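# Export the dataset and model notebooks to DataSet.py and SegModel.py so that
# CaravanaDataset and other things can be imported from them in this notebook.
# /home/onkar/MyLearnings/UNet
# <-- Keep the commands below commented out once the files have been created -->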

# !jupyter nbconvert --to script 02_Dataset.ipynb --output DataSet
# !jupyter nbconvert --to script 03_Model.ipynb --output SegModel

In [12]:


import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
import torch.optim as optim
from torch.cuda.amp import GradScaler

import random
import pandas as pd
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

from SegModel import UNet

from DataSet import PROCESSED_IMAGES_DIR

from utils import (
    save_checkpoint,
    load_checkpoint,
    check_accuracy,
    save_preds_as_imgs,
    create_dataset,
    get_dataloader
)


In [13]:
# -- Test Model if working properly --

def test():
    """Smoke-test UNet with a single random batch and verify the output shape.

    Builds a fresh ``UNet(in_channels=3, out_channels=1)``, runs one forward
    pass on a random ``(1, 3, 448, 240)`` tensor and prints the input and
    output shapes.

    Raises:
        AssertionError: If the prediction is not a one-channel map with the
            same batch size and spatial size as the input.
    """
    x = torch.randn((1, 3, 448, 240))
    model = UNet(in_channels=3, out_channels=1)

    # Only the shapes are inspected, so there is no need to record gradients.
    with torch.no_grad():
        pred = model(x)

    print(x.shape)
    print(pred.shape)

    # out_channels=1 was requested above; batch and spatial dims should be kept.
    expected_shape = (x.shape[0], 1, *x.shape[2:])
    assert tuple(pred.shape) == expected_shape, (
        f"UNet output shape {tuple(pred.shape)} != expected {expected_shape}"
    )


test()

torch.Size([1, 3, 448, 240])
torch.Size([1, 1, 448, 240])


In [14]:
# ---- Training configuration ----

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Optimisation settings.
LEARNING_RATE = 1e-4  # step size handed to the optimizer
NUM_EPOCHS = 3        # number of passes over the training loader

# Data-loading settings.
BATCH_SIZE = 16
NUM_WORKERS = 1
PIN_MEMORY = True

# When True, weights are restored from an existing checkpoint before training.
LOAD_MODEL = True

# ---- Dataset locations ----
VALIDATION_DATASET_PATH = PROCESSED_IMAGES_DIR / 'val'
TRAINING_DATASET_PATH = PROCESSED_IMAGES_DIR / 'train'


In [15]:
# Runs one full pass (one epoch) over `loader` and returns the mean training loss.
def train_func(loader, model, optimizer, loss_fn, scaler):
    """Train ``model`` for one epoch using automatic mixed precision.

    Args:
        loader: Iterable yielding ``(images, targets)`` batches. Each image
            batch has its last axis moved to position 1 before the forward
            pass (assumes channels-last input -- confirm against the dataset),
            and each target batch gets a new axis inserted at dim 1.
        model: Network to optimise; it is switched to training mode here.
        optimizer: Optimizer that updates ``model``'s parameters.
        loss_fn: Callable ``loss_fn(preds, targets)`` returning a scalar loss.
        scaler: Gradient scaler used to scale the loss, step the optimizer
            and update the scale factor.

    Returns:
        float: Mean of the per-batch loss values, or ``nan`` if ``loader``
        produced no batches.
    """
    # Sometimes, the GPU memory is not properly cleared. Free the cached memory
    # before starting the epoch.
    torch.cuda.empty_cache()

    # Do not rely on the evaluation helpers to put the model back into training
    # mode; layers such as dropout/batch-norm behave differently in eval mode.
    model.train()

    running_loss = 0.0
    num_batches = 0

    loop = tqdm(loader)
    for images, targets in loop:
        images = images.permute([0, 3, 1, 2]).to(device=DEVICE)
        targets = targets.unsqueeze(1).to(device=DEVICE)

        # Forward pass under autocast so eligible ops run in reduced precision.
        with torch.cuda.amp.autocast():
            preds = model(images)
            loss = loss_fn(preds, targets)

        # Backward pass: scale the loss, then step and update through the scaler.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Read the loss once and reuse it for both the running mean and tqdm.
        batch_loss = loss.item()
        running_loss += batch_loss
        num_batches += 1
        loop.set_postfix(loss=batch_loss)

    return running_loss / num_batches if num_batches else float("nan")
        
def main():
    """Build the data loaders and model, then train for ``NUM_EPOCHS`` epochs.

    Steps:
        1. Create the training and validation datasets and their loaders.
        2. Instantiate ``UNet`` on ``DEVICE`` with a BCE-with-logits loss and
           an Adam optimizer.
        3. If ``LOAD_MODEL`` is set, load ``checkpoints/my_checkpoint.pth``
           and hand it to ``load_checkpoint``.
        4. Run ``check_accuracy`` once before training.
        5. After every epoch, pass the model/optimizer state to
           ``save_checkpoint``, run ``check_accuracy`` again and call
           ``save_preds_as_imgs`` with ``save_dir='pred_images'``.
    """
    # get dataset
    training_ds = create_dataset(TRAINING_DATASET_PATH)
    validation_ds = create_dataset(VALIDATION_DATASET_PATH)

    # Create a dataloader for loading data in batches
    # NOTE(review): PIN_MEMORY is defined in the configuration but is not
    # forwarded here; get_dataloader's signature is not visible -- confirm
    # whether it accepts a pin_memory argument.
    training_loader = get_dataloader(
        dataset=training_ds,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        shuffle=True,
    )
    validation_loader = get_dataloader(
        dataset=validation_ds,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        shuffle=False,
    )

    # Set a model
    model = UNet(in_channels=3, out_channels=1).to(device=DEVICE)

    # 2 classes 'Car' or 'Not car'; if more than one class use cross entropy loss
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    if LOAD_MODEL:
        # map_location remaps the stored tensors onto the active device, so a
        # checkpoint written on a GPU can still be loaded on a CPU-only machine.
        checkpoint = torch.load('checkpoints/my_checkpoint.pth', map_location=DEVICE)
        load_checkpoint(checkpoint=checkpoint, model=model)
        # NOTE(review): the optimizer state is saved with every checkpoint below
        # but is not restored here -- confirm whether resuming it is wanted.

    # Baseline accuracy before any training; no gradients are needed for it.
    with torch.no_grad():
        check_accuracy(loader=validation_loader, model=model, device=DEVICE)

    # Start the training
    scaler = GradScaler()
    for epoch in range(NUM_EPOCHS):
        train_func(training_loader, model, optimizer, loss_fn, scaler)

        # -- Save model --
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }

        save_checkpoint(checkpoint)

        with torch.no_grad():
            # -- Check accuracy --
            check_accuracy(loader=validation_loader, model=model, device=DEVICE)

            # print some examples to folder
            save_preds_as_imgs(validation_loader, model=model, save_dir='pred_images', device=DEVICE)


# Run the training pipeline only when this is the entry-point module (which is
# also the case when the cell is executed inside the notebook kernel).
if __name__ == '__main__':
    main()

=> Loading Checkpoint


  0%|          | 0/96 [00:00<?, ?it/s]

Got 163102813/164183040 with accuracy of 99.342%, Dice score : 95.075


  0%|          | 0/223 [00:00<?, ?it/s]

=> Saving Checkpoint


  0%|          | 0/96 [00:00<?, ?it/s]

Got 163286253/164183040 with accuracy of 99.454%, Dice score : 95.233


  0%|          | 0/223 [00:00<?, ?it/s]

=> Saving Checkpoint


  0%|          | 0/96 [00:00<?, ?it/s]

Got 163340751/164183040 with accuracy of 99.487%, Dice score : 95.277


  0%|          | 0/223 [00:00<?, ?it/s]

=> Saving Checkpoint


  0%|          | 0/96 [00:00<?, ?it/s]

Got 163425317/164183040 with accuracy of 99.538%, Dice score : 95.351
