In [None]:
import os
import random

import cv2
import numpy as np

import argus
from argus import Model, load_model
from argus.callbacks import MonitorCheckpoint, EarlyStopping, LoggingToFile

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

from src.dataset import SaltDataset
from src.transforms import SimpleDepthTransform, DepthTransform, SaltTransform
from src.argus_models import SaltProbModel, SaltMetaModel
from src import config

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Experiment configuration: image size handed to SaltTransform, fold split and
# batch sizes for the train / validation loaders.
image_size = (128, 128)
val_folds = [0]
train_folds = [1, 2, 3, 4]
train_batch_size = 64
val_batch_size = 64

# A fold must never be used for both training and validation.
assert not set(train_folds) & set(val_folds), "train and val folds overlap"

# Seed every RNG the pipeline may draw from (loader shuffling, augmentations,
# weight init) so runs are repeatable. torch.manual_seed also seeds CUDA;
# bit-exact GPU determinism would additionally require cuDNN settings.
random_seed = 42
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

In [None]:
# Build the data pipelines. The depth transform is shared by both splits;
# the SaltTransform instances differ only in their second argument
# (True for train, False for validation).
folds_path = config.TRAIN_FOLDS_PATH
depth_trns = SimpleDepthTransform()

# Training pipeline: shuffled every epoch.
train_trns = SaltTransform(image_size, True, 'pad')
val_trns = SaltTransform(image_size, False, 'pad')

train_dataset = SaltDataset(folds_path, train_folds, train_trns, depth_trns)
val_dataset = SaltDataset(folds_path, val_folds, val_trns, depth_trns)

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=train_batch_size,
                          shuffle=True)
# Validation pipeline: fixed order so metrics are comparable across epochs.
val_loader = DataLoader(dataset=val_dataset,
                        batch_size=val_batch_size,
                        shuffle=False)

In [None]:
# Draw a list of images side by side in a single row.
def draw(imgs, size=7):
    """Show ``imgs`` side by side in one row of a single matplotlib figure.

    Args:
        imgs: sequence of images accepted by ``plt.imshow`` (e.g. 2-D arrays).
        size: width and height, in inches, of each image's panel. The figure
            is ``len(imgs) * size`` inches wide and ``size`` inches tall.
    """
    n = len(imgs)
    if n == 0:
        return  # nothing to draw; avoid showing an empty figure
    # figsize is (width, height): one row of n panels grows in width only.
    plt.figure(figsize=(n * size, size))
    for position, img in enumerate(imgs, start=1):
        plt.subplot(1, n, position)
        plt.axis('off')
        plt.imshow(img)
    plt.show()

In [None]:
n_images_to_draw = 3

# Preview one batch from the training loader (train-time transforms applied):
# for each sample show input channel 0, input channel 1 and target channel 0.
# NOTE(review): channel 1 is labelled "cumsum" here; it is presumably produced
# by the dataset / depth transform — confirm against src.transforms.
img, trg = next(iter(train_loader))
# The batch may contain fewer samples than requested (small dataset).
for i in range(min(n_images_to_draw, len(img))):
    img_i = img[i, 0, :, :].numpy()
    cumsum_i = img[i, 1, :, :].numpy()
    trg_i = trg[i, 0, :, :].numpy()
    draw([img_i, cumsum_i, trg_i])

# Train the probability model (`DPNProbUnet`)

In [None]:
# Hyper-parameters for SaltMetaModel: each component is a (name, kwargs) pair
# that the model presumably resolves by name, plus the target device.
params = dict(
    nn_module=('DPNProbUnet', dict(
        num_classes=1,
        num_channels=3,
        encoder_name='dpn92',
        dropout=0,
    )),
    loss=('FbBceProbLoss', dict(
        fb_weight=0.95,
        fb_beta=2,
        bce_weight=0.9,
        prob_weight=0.85,
    )),
    prediction_transform=('ProbOutputTransform', dict(
        segm_thresh=0.5,
        prob_thresh=0.5,
    )),
    optimizer=('Adam', dict(lr=1e-4)),
    device='cuda',
)

# Instantiate the model wrapper from src.argus_models using the params above
# (network, loss, prediction transform, optimizer, device).
model = SaltMetaModel(params)

In [None]:
# Single-stage training run: checkpoints (top 3 by val_crop_iout) and the text
# log are written to one experiment directory.
# NOTE(review): the following cell builds a fresh model and trains again under a
# different experiment name; under Run All both trainings execute.
experiment_dir = '/workdir/data/experiments/test_022'

callbacks = [
    MonitorCheckpoint(experiment_dir, monitor='val_crop_iout', max_saves=3),
    EarlyStopping(monitor='val_crop_iout', patience=50),
    LoggingToFile(f'{experiment_dir}/log.txt'),
]

model.fit(
    train_loader,
    val_loader=val_loader,
    max_epochs=1000,
    callbacks=callbacks,
    metrics=['crop_iout'],
)

In [None]:
# Staged learning-rate schedule. Stage 0 builds a fresh model from `params`;
# every later stage reloads model-last.pth from the experiment directory,
# sets that stage's learning rate and trains for up to `epochs` epochs.
experiment_name = 'test_025'
experiment_dir = f'/workdir/data/experiments/{experiment_name}'

# (max_epochs, learning_rate) per stage.
# NOTE(review): the last stage drops to 1e-7, skipping the 1e-6 that the
# 1e-4 / 3e-5 / 1e-5 / 3e-6 progression suggests — confirm this is intended.
lr_steps = [
    (300, 0.0001),
    (300, 0.00003),
    (300, 0.00001),
    (300, 0.000003),
    (1000, 0.0000001),
]

for i, (epochs, lr) in enumerate(lr_steps):
    print(i, epochs, lr)
    if i == 0:
        model = SaltMetaModel(params)
    else:
        # model-last.pth is presumably written by the previous stage's
        # MonitorCheckpoint — verify.
        # NOTE(review): each stage creates a new MonitorCheckpoint whose
        # best-score tracking presumably starts over, so this file may not be
        # the best checkpoint seen across all stages.
        model = load_model(f'{experiment_dir}/model-last.pth')

    callbacks = [
        MonitorCheckpoint(experiment_dir, monitor='val_crop_iout', max_saves=2),
        EarlyStopping(monitor='val_crop_iout', patience=50),
        LoggingToFile(f'{experiment_dir}/log.txt'),
    ]

    model.set_lr(lr)
    model.fit(train_loader,
              val_loader=val_loader,
              max_epochs=epochs,
              callbacks=callbacks,
              metrics=['crop_iout'])