# Task 3 Unet

In [1]:
from task3.utils.config import *
from tqdm import tqdm
from task3.utils.data_utils import evaluate, save_zipped_pickle, load_zipped_pickle
from task3.utils.img_utils import show_img_batch, plot_pred_on_frame
from task3.utils.utils import upscale, get_img_dims
import pickle
import gzip
import importlib
import sys
import time
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from datetime import datetime
import segmentation_models_pytorch as smp
from torchmetrics import IoU
import torchvision
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from loguru import logger
from torch.utils.tensorboard import SummaryWriter

In [2]:
# Select the run configuration (alternative: 'configs/train_tg.yaml').
config = init(
    config='configs/raphaela.yaml',
)

In [3]:
# Device requested by the config vs. the device actually available on this machine.
device = torch.device(config['device'])
sys_device = 'cuda' if torch.cuda.is_available() else 'cpu'

logger.info(f'You are using {device}.')

# Compare the device *type* string instead of a torch.device against a plain str:
# this is unambiguous and lets e.g. 'cuda:0' count as matching 'cuda'.
if device.type != sys_device:
    logger.warning(f'You are using {device} but system device was found to be {sys_device}. Check your device choice in config.py.')

2022-01-03T14:43:52.219467+0100 INFO You are using cpu.


In [4]:
# Build the segmentation model from the config; the logged parameters show a
# single-channel-in / single-class-out network with a sigmoid activation.
model = get_model(config)

2022-01-03T14:43:52.487357+0100 INFO model params set to: {'encoder_name': 'resnet34', 'encoder_weights': None, 'in_channels': 1, 'classes': 1, 'encoder_depth': 5, 'decoder_use_batchnorm': True, 'decoder_channels': [256, 128, 64, 32, 16], 'decoder_attention_type': 'scse', 'activation': 'sigmoid', 'aux_params': None}


In [5]:
# In 'train' mode the helper returns three loaders: train, validation and test.
(training_loader,
 validation_loader,
 test_loader) = get_data_loader(config, mode='train', get_subset=False)

2022-01-03T14:43:52.493401+0100 DEBUG selected device: cpu
2022-01-03T14:43:52.493903+0100 DEBUG Dataset creation: train
2022-01-03T14:43:58.618561+0100 DEBUG dict_keys(['train', 'val', 'test', 'submission'])


 13%|█▎        | 201/1560 [00:00<00:00, 68939.01it/s]

2022-01-03T14:43:58.655144+0100 DEBUG Dataset creation: validation





2022-01-03T14:44:04.196510+0100 DEBUG dict_keys(['train', 'val', 'test', 'submission'])


100%|██████████| 390/390 [00:00<00:00, 122502.70it/s]


In [6]:
# Optimizer, loss function and learning-rate scheduler are all built from the config.
optimizer = get_optimizer(model, config)
criterion = get_loss(model, config)
lr_scheduler = get_lrscheduler(optimizer, config)
# Hand-built alternatives kept for reference:
#   optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
#   optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10,
#                                        threshold=0.0001, threshold_mode='rel', cooldown=0,
#                                        min_lr=0, eps=1e-08, verbose=False)

# Training length and output directory, with fallbacks when absent from the config.
num_epochs, save_path = (
    config['training'].get('epochs', 1),
    config['training'].get('save_path', 'outputs'),
)

2022-01-03T14:44:04.234437+0100 INFO Using BCELoss() as loss function.


In [7]:
# initialize metric
# The U-Net predicts a single output channel (binary segmentation); torchmetrics'
# IoU treats that as 2 classes (background / foreground), hence num_classes=2.
# NOTE(review): torchmetrics metrics typically also accumulate internal state on
# every call, and this metric is never reset in the notebook — confirm that only
# the per-call return values are relied on.
metric = IoU(num_classes=2) # num classes in Unet=1 for binary segmentation, corresponds to 2 in IoU score

## Set up tensorboard and logging

In [8]:
# Creating the writer also creates the `save_path` log directory if it is missing
# (SummaryWriter creates its log_dir), which the config.txt dump that follows writes into.
# NOTE: this `writer` is later replaced by a timestamped writer in the training cell
# without being closed.
writer = SummaryWriter(save_path)

In [9]:
# Record the run notes and the full configuration next to the TensorBoard logs.
# The with-block guarantees the file is flushed and closed (it previously stayed open).
with open(save_path+'/config.txt', 'w', encoding='utf-8') as f:
    f.write(config['run_notes'])
    f.write('\n\n')
    f.write(json.dumps(config))

1413

## Training loop

In [10]:
def train_one_epoch(epoch_index, tb_writer, report_every=5):
    """Run one training pass over ``training_loader`` and log windowed metrics.

    Relies on notebook-level globals defined in earlier cells: ``model``,
    ``optimizer``, ``criterion``, ``metric`` and ``training_loader``.

    Args:
        epoch_index: zero-based epoch counter, used to place logged points on a
            global step axis (``epoch_index * len(training_loader) + i + 1``).
        tb_writer: TensorBoard ``SummaryWriter``; receives the ``'Loss/train'``
            and ``'IoU/train'`` scalars.
        report_every: number of consecutive batches averaged into one logged
            point. Defaults to 5, the previously hard-coded value.

    Returns:
        Tuple ``(last_loss, last_score)``: mean loss and mean IoU of the most
        recent complete window of ``report_every`` batches, or ``(0.0, 0.0)``
        if the epoch had fewer batches than that.

    Raises:
        ValueError: if ``report_every`` is smaller than 1.
    """
    if report_every < 1:
        raise ValueError(f'report_every must be >= 1, got {report_every}')

    running_loss = 0.
    running_score = 0.
    # Initialise both so the return is defined even if no window completes
    # (last_score used to be unbound in that case -> UnboundLocalError).
    last_loss = 0.
    last_score = 0.

    for i, batch in enumerate(tqdm(training_loader)):
        inputs, labels = batch['frame_cropped'], batch['label_cropped']

        # Zero the gradients for every batch
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss; BCE needs float labels (0.0 or 1.0)
        loss = criterion(outputs, labels.float())
        # The sigmoid output lies in [0, 1] but IoU needs a binary mask, so
        # threshold at 0.5; the metric takes the original (non-float) labels.
        outputs_thr = outputs > 0.5
        score = metric(outputs_thr, labels)

        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        running_score += score.item()

        # Every `report_every` batches log the window mean and start a new window
        if i % report_every == report_every - 1:
            last_loss = running_loss / report_every  # loss per batch
            last_score = running_score / report_every  # IoU per batch
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            tb_writer.add_scalar('IoU/train', last_score, tb_x)
            running_loss = 0.
            running_score = 0.

    return last_loss, last_score

In [None]:
# Create a fresh, timestamped TensorBoard run directory for this training run.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter(save_path+'/mitrial_valve_{}'.format(timestamp))

print(f'training for {num_epochs} epochs')
print('batch size: ', config['data'].get('batch_size', None))
print(f'saving results and models to {save_path}')
print('training model...')

start = time.time()
epoch_number = 0
best_vloss = 1_000_000.
best_vscore = 1_000_000.  # not used below; checkpoints are selected on validation loss

for epoch in range(num_epochs):
    print('\nEPOCH {}:'.format(epoch_number + 1))
    start_epoch = time.time()
    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss, avg_score = train_one_epoch(epoch_number, writer)

    # Switch to evaluation mode; gradients are disabled separately via no_grad() below.
    model.train(False)

    running_vloss = 0.0
    running_vscore = 0.0
    num_vbatches = 0

    # Validate every epoch. no_grad() stops autograd from building graphs, and
    # .item() keeps plain floats in the running sums so no tensors pile up.
    with torch.no_grad():
        for vbatch in validation_loader:
            vinputs, vlabels = vbatch['frame_cropped'], vbatch['label_cropped']
            voutputs = model(vinputs)
            vloss = criterion(voutputs, vlabels.float())  # BCE needs float labels
            voutputs_thr = voutputs > 0.5
            vscore = metric(voutputs_thr, vlabels)

            running_vloss += vloss.item()
            running_vscore += vscore.item()
            num_vbatches += 1

    if num_vbatches == 0:
        raise RuntimeError('validation_loader yielded no batches; cannot compute validation metrics')

    avg_vloss = running_vloss / num_vbatches
    avg_vscore = running_vscore / num_vbatches
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
    print('IOU train {} valid {}'.format(avg_score, avg_vscore))

    # Log the per-batch averages for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.add_scalars('Training vs. Validation IoU',
                    { 'Training' : avg_score, 'Validation' : avg_vscore },
                    epoch_number + 1)

    # Write the last validation batch (frames, labels, thresholded predictions) to TensorBoard
    fimg_grid = torchvision.utils.make_grid(vinputs)
    limg_grid = torchvision.utils.make_grid(vlabels)
    pimg_grid = torchvision.utils.make_grid(voutputs_thr.detach())

    writer.add_image(f'{save_path}_epoch_{epoch_number}_frame_valiou_{vscore}', fimg_grid)
    writer.add_image(f'{save_path}_epoch_{epoch_number}_label_valiou_{vscore}', limg_grid)
    writer.add_image(f'{save_path}_epoch_{epoch_number}_pred_valiou_{vscore}', pimg_grid)

    writer.flush()

    # Track best performance and save the model's state; IoU (or a Jaccard loss,
    # a direct proxy for IoU) could be used for selection instead of BCE loss.
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = save_path + '/model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    # Advance the LR schedule once per epoch; without this call the scheduler
    # from get_lrscheduler never changes the learning rate.
    # ReduceLROnPlateau needs the monitored value, other schedulers take no argument.
    if lr_scheduler is not None:
        if isinstance(lr_scheduler, optim.lr_scheduler.ReduceLROnPlateau):
            lr_scheduler.step(avg_vloss)
        else:
            lr_scheduler.step()

    epoch_duration = (time.time()-start_epoch)/60
    print(f'Epoch {epoch_number} finished in {epoch_duration} min')

    epoch_number += 1

duration = (time.time()-start)/60
print(f'\nTraining finished in {duration} min')

training for 400 epochs
batch size:  8
saving results and models to runs/mv_training_15
training model...

EPOCH 1:


 64%|██████▍   | 16/25 [01:01<00:34,  3.88s/it]

## Restore model and do inference

In [None]:
# Path of a saved state_dict (a file, despite the variable name).
load_dir = 'runs/mv_training_7/model_20211230_133722_20'
submission_loader = get_data_loader(config, mode='submission', get_subset=False)
model = get_model(config)
model.to(device) # if not called wrong input shape (unclear why)
# map_location remaps stored tensors onto `device`, so a checkpoint saved on GPU
# can be restored on a CPU-only machine (and vice versa).
model.load_state_dict(torch.load(load_dir, map_location=device))
model.eval()

In [None]:
lookup_test = pd.read_csv('data/lookup_test.csv')

predictions_jodok = []  # only filled if the commented-out "raws for jodok" block below is re-enabled
predictions = []


def _parse_int_tuple(text):
    """Turn a tuple-like string such as '(12, 34)' into a tuple of ints, e.g. (12, 34)."""
    return tuple(map(int, text.replace('(', '').replace(')', '').split(', ')))


# Inference only: no_grad() avoids building an autograd graph for every frame.
with torch.no_grad():
    for i, batch in enumerate(submission_loader):
        frame = batch['frame_cropped'].to(device)
        name = batch['name']

        # Everything below uses name[0] and squeezes dim 0, i.e. assumes one frame
        # per batch; fail loudly instead of silently dropping predictions.
        if frame.shape[0] != 1:
            raise ValueError(f'submission_loader must use batch_size=1, got a batch of {frame.shape[0]}')

        prediction = model(frame)
        prediction_thr = prediction > 0.5

        # Original image size and ROI of this video, used to map the cropped
        # prediction back onto the full frame.
        img_dims = get_img_dims(lookup_test, name[0])
        lookup_sample = lookup_test[lookup_test.name == name[0]]
        roi_coord_final = _parse_int_tuple(lookup_sample.roi_coord.values[0])
        roi_dims_final = _parse_int_tuple(lookup_sample.roi_dims.values[0])

        # Drop the (presumably batch and channel) leading dims; move to CPU before .numpy().
        pred_squeezed = prediction_thr.squeeze(0).squeeze(0).cpu()
        pred_scaled = upscale(pred_squeezed.numpy().astype(np.uint8), img_dims, roi_coord_final, roi_dims_final)

        # name needs to be a string, prediction a 2D numpy bool array

        ## raws for jodok
        #predictions_jodok.append({
        #    'name': name[0],
        #    'prediction_upscaled': pred_scaled.astype(bool),
        #    'prediction_raw': pred_squeezed.to(bool)
        #    }
        #)

        predictions.append({
            'name': name[0],
            'prediction': pred_scaled.astype(bool),
            }
        )

        print(i)

In [None]:
# Persist the per-frame predictions. predictions_jodok is empty unless the
# commented-out "raws for jodok" block in the inference loop is re-enabled.
for _payload, _target in (
    (predictions, 'submissions/mv_run7_unet++_resnet34_attention_predictions.pkl'),
    (predictions_jodok, 'submissions/mv_run7_unet++_resnet34_attention_predictions_jodok.pkl'),
):
    save_zipped_pickle(_payload, _target)

In [None]:
# One row per predicted frame, with columns 'name' and 'prediction'.
df = pd.DataFrame(predictions)

def mySum(dataframe):
    """Stack the arrays of one group along a new third axis.

    Despite its name nothing is summed: every element of ``dataframe`` (any
    iterable of equally shaped arrays, e.g. a grouped column) becomes one slice
    ``[:, :, k]`` of the result, giving (image_x, image_y, num_frames) for 2-D frames.
    """
    frames = list(dataframe)
    return np.stack(frames, axis=2)

# Collapse the per-frame rows into one record per video whose 'prediction'
# holds all frames stacked along the last axis.
res = df.groupby('name').agg(mySum).reset_index()
submissions_corrected = res.to_dict(orient='records')

# Rank of every test video in the output; expected to match the video order of
# data/test.pkl (TODO confirm if the test set changes).
sort_order = {
    video_name: rank
    for rank, video_name in enumerate([
        'RZ9W7OK2EO', '401JD35E1A', 'O7WUJ71C15', '7UXIXUBK2G', 'JQX264DTZ0',
        'NHC30J31YN', 'CD4RIAOCHG', 'QJTAVYCG6M', '3WOQKZBVRN', 'UB7LFQKZT5',
        'SZKYOVQ4ZP', 'ESY800XYMN', '1QSFD8ORNM', '0MVRNDWR1G', 'VODEK84RH4',
        '1EKDG3M9L1', 'QQW12K1U3U', 'D271IBSMUW', 'TYM0IJW004', '8FKMSXTPSJ',
    ])
}
# Final submission format: all frames aggregated into 20 videos, in sort_order.
submission_sorted = sorted(submissions_corrected, key=lambda record: sort_order[record['name']])

In [None]:
# Write the per-video (stacked and sorted) predictions as the submission file.
save_zipped_pickle(submission_sorted, 'submissions/mv_run7_unet++_resnet34_attention_predictions_sorted.pkl')

In [None]:
# Raw test set; its 'name' and 'video' entries are compared against the submission in the check cell.
testset = load_zipped_pickle('data/test.pkl')

In [None]:
#submission = load_zipped_pickle('submissions/mv_run7_unet++_resnet34_attention_sorted_padding_corrected.pkl')
#submission_old = load_zipped_pickle('submissions/mv_run7_unet++_resnet34_attention_correctupscaling_sorted_togroup.pkl')

In [None]:
# Sanity-check the submission against the raw test videos and overlay every
# predicted mask on its frame for visual inspection.
# zip() would silently drop unmatched videos, so check the counts first.
if len(submission_sorted) != len(testset):
    raise ValueError(f'{len(submission_sorted)} submission videos but {len(testset)} test videos')

for subm, test in zip(submission_sorted, testset):

    sname = subm['name']
    spred = subm['prediction']

    tname = test['name']
    tframe = test['video']

    print(spred.shape)
    print(tframe.shape)

    # Explicit errors instead of assert, which is stripped under `python -O`.
    if sname != tname:
        raise ValueError(f'video order mismatch: submission {sname!r} vs test {tname!r}')
    if spred.shape != tframe.shape:
        raise ValueError(f'shape mismatch for {sname!r}: prediction {spred.shape} vs video {tframe.shape}')

    for frame in range(tframe.shape[-1]):
        plt.figure(figsize=(12,8))
        plt.imshow(tframe[:,:,frame], alpha=1)
        plt.imshow(spred[:,:,frame], alpha=0.6)
        # Render and release each figure right away instead of keeping one
        # open figure per frame until the end of the cell.
        plt.show()