In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('../..')

In [3]:
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

In [4]:
from stack_segmentation.stack import Stack

In [5]:
from stack_segmentation.aug_pipelines import medium_aug
from stack_segmentation.io import make_dataloader, collate_fn_basic
from stack_segmentation.training import handle_stacks_data, make_model, train_loop
from stack_segmentation.unet import UNet
from stack_segmentation.pipeline_config import dataloaders_conf, model_conf, train_conf, loss_config

In [6]:
from exp_config import data_conf

## Parameters to tune

In [7]:
# Run both the training loop and the model on the CPU.
train_conf.update(device='cpu')
model_conf.update(device='cpu')

In [8]:
# Experiment name: reused further down as `exp_name` for train_loop and in the
# file names of the pickled results and the model checkpoint.
data_conf['conf_name'] = 'basic_lr1e-3_epoch300'
# Display the full data configuration so it is recorded with the run.
data_conf

{'conf_name': 'basic_lr1e-3_epoch300',
 'stacks': [{'path': '../../data/carb96558',
   'slice_train': (slice(None, None, None),
    slice(None, None, None),
    slice(None, 230, None)),
   'slice_val': (slice(None, None, None),
    slice(None, None, None),
    slice(250, 470, None))},
  {'path': '../../data/SoilB-2',
   'slice_train': (slice(None, None, None),
    slice(None, None, None),
    slice(None, 230, None)),
   'slice_val': (slice(None, None, None),
    slice(None, None, None),
    slice(240, 460, None))},
  {'path': '../../data/Urna_22',
   'slice_train': (slice(None, None, None),
    slice(None, None, None),
    slice(None, 220, None)),
   'slice_val': (slice(None, None, None),
    slice(None, None, None),
    slice(245, 455, None))},
  {'path': '../../data/carb96558',
   'slice_test': (slice(None, None, None),
    slice(None, None, None),
    slice(490, None, None))},
  {'path': '../../data/carb71',
   'slice_test': (slice(None, None, None),
    slice(None, None, None),
   

In [9]:
# Set the train augmentation pipeline to None (presumably this disables
# augmentation — confirm in make_dataloader).
# NOTE(review): `medium_aug` is imported above but not used in any visible cell.
# NOTE(review): the displayed config has 'shuffle': True for the test loaders —
# confirm this is intended.
dataloaders_conf['train']['augmentation_pipeline'] = None
dataloaders_conf

{'train': {'batch_size': 32,
  'num_workers': 8,
  'shuffle': True,
  'augmentation_pipeline': None},
 'val': {'batch_size': 32,
  'num_workers': 8,
  'shuffle': False,
  'augmentation_pipeline': None},
 'test': {'batch_size': 32,
  'num_workers': 8,
  'shuffle': True,
  'augmentation_pipeline': None}}

In [10]:
# Number of training epochs; keep in sync with the 'epoch300' suffix of conf_name.
train_conf['num_epochs'] = 300
train_conf

{'num_epochs': 300, 'device': 'cpu'}

In [11]:
# Optimizer settings; the whole dict is passed to make_model via **model_conf.
model_conf['opt_type'] = 'SGD'
model_conf['lr'] = 1e-3  # matches the 'lr1e-3' part of conf_name
model_conf['weight'] = None  # presumably per-class loss weights (cf. the "'weight': [1, 10]" note near the end) — TODO confirm
model_conf['nesterov'] = True
model_conf

{'device': 'cpu',
 'opt_type': 'SGD',
 'lr': 0.001,
 'weight_decay': 0.0001,
 'amsgrad': False,
 'nesterov': True,
 'momentum': 0.9,
 'centered': False,
 'min_lr': 1e-06,
 'factor': 0.5,
 'patience': 5,
 'weight': None}

In [12]:
data_conf['patches']

{'train': (128, 128, 1), 'val': (128, 128, 1), 'test': (128, 128, 1)}

## Prepare train, validation and test data

In [13]:
# Produce the train / validation / test data from the stack configuration.
# data_test is iterated with .items() when the test loaders are built, i.e. it
# is a mapping rather than a flat sequence.
data_train, data_val, data_test = handle_stacks_data(**data_conf)

720it [00:01, 694.81it/s]
100%|██████████| 720/720 [00:07<00:00, 93.11it/s] 
8280it [00:00, 238881.54it/s]
7920it [00:00, 249424.75it/s]
700it [00:00, 863.35it/s]
100%|██████████| 700/700 [00:07<00:00, 98.70it/s] 
8280it [00:00, 250881.96it/s]
7920it [00:00, 81817.91it/s]
710it [00:00, 829.65it/s]
100%|██████████| 710/710 [00:07<00:00, 94.78it/s] 
7920it [00:00, 246969.56it/s]
7560it [00:00, 236762.85it/s]
720it [00:00, 799.16it/s]
100%|██████████| 720/720 [00:07<00:00, 90.82it/s] 
8280it [00:00, 256378.96it/s]
720it [00:00, 820.24it/s]
100%|██████████| 720/720 [00:07<00:00, 93.26it/s] 
25920it [00:00, 118542.46it/s]
700it [00:00, 850.66it/s]
100%|██████████| 700/700 [00:07<00:00, 99.06it/s] 
25200it [00:00, 133520.54it/s]
509it [00:00, 1163.95it/s]
100%|██████████| 509/509 [00:02<00:00, 195.62it/s]
8144it [00:00, 189609.89it/s]
700it [00:00, 833.01it/s]
100%|██████████| 700/700 [00:07<00:00, 98.91it/s] 
25200it [00:00, 231519.36it/s]
700it [00:00, 814.58it/s]
100%|██████████| 700/700 

In [14]:
# data_test is a mapping (see the .items() call when the test loaders are
# built), so its length counts test entries rather than individual samples.
len(data_train), len(data_val), len(data_test)

(24480, 23400, 11)

In [15]:
# Train and validation loaders: same collate function, per-split samples and
# per-split loader settings.
dataloader_train = make_dataloader(
    **dataloaders_conf['train'],
    samples=data_train,
    collate_fn=collate_fn_basic,
)

dataloader_val = make_dataloader(
    **dataloaders_conf['val'],
    samples=data_val,
    collate_fn=collate_fn_basic,
)

# One test loader per entry of data_test, keyed by the entry's name.
dataloaders_test = {
    stack_name: make_dataloader(
        **dataloaders_conf['test'],
        samples=stack_samples,
        collate_fn=collate_fn_basic,
    )
    for stack_name, stack_samples in data_test.items()
}

## Create model and metrics

In [16]:
from stack_segmentation.metrics import accuracy, precision, recall, f1, pr_auc, iou

In [17]:
# Metric name -> metric function, passed to train_loop.
metrics = dict(
    accuracy=accuracy,
    precision=precision,
    recall=recall,
    f1=f1,
    pr_auc=pr_auc,
    iou=iou,
)

In [18]:
# Take the device from the model configuration instead of repeating the
# literal, so this name cannot drift from the value the model is built with.
device = model_conf['device']

In [19]:
# Build the model together with its loss, optimizer and LR scheduler from the
# loss configuration and the model_conf settings chosen above.
model, criterion, optimizer, scheduler = make_model(loss_config=loss_config, **model_conf)

## Run experiment

In [None]:
# Run the experiment. `results` is consumed by the dump, loss-plot and
# metrics-table cells below.
results = train_loop(
    exp_name=data_conf['conf_name'],
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    dataloader_train=dataloader_train,
    dataloader_val=dataloader_val,
    dataloaders_test=dataloaders_test,
    metrics=metrics,
    **train_conf
)

## Dump experiment results

In [None]:
import pickle
import json

In [None]:
# Persist the training results next to the notebook, named after the experiment.
p = './{}_exp_results.pkl'.format(data_conf['conf_name'])
with open(p, 'wb') as f:
    pickle.dump(results, f)

In [17]:
# Optional: uncomment to reload previously dumped results instead of re-running training.
# p = './{}_exp_results.pkl'.format(data_conf['conf_name'])
# with open(p, 'rb') as f:
#     results = pickle.load(f)

In [11]:
import torch

# Restore the weights stored under the experiment name (presumably written by
# train_loop — confirm). map_location remaps the stored tensors onto the
# configured device, so a checkpoint saved on a GPU machine still loads when
# the notebook runs on CPU.
checkpoint_path = './{}.pt'.format(data_conf['conf_name'])
model.load_state_dict(torch.load(checkpoint_path, map_location=model_conf['device']))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

## Train and validation losses

In [None]:
from itertools import chain

In [None]:
# Flatten the per-epoch loss lists into one long per-batch sequence each.
train_losses = list(chain.from_iterable(results['train_losses']))
val_losses = list(chain.from_iterable(results['val_losses']))

In [None]:
def moving_average(a, n=5):
    """Return the trailing moving average of ``a`` over a window of ``n`` values.

    The series is left-padded with ``n - 1`` copies of its first value, so the
    result has the same length as ``a``; the first outputs therefore average
    over the repeated first value rather than over a shorter window.

    Parameters
    ----------
    a : sequence or array of numbers
        One-dimensional series (list, tuple or NumPy array).
    n : int, default 5
        Window length; must be at least 1.

    Returns
    -------
    numpy.ndarray
        Float array with ``len(a)`` smoothed values (empty for empty input).

    Raises
    ------
    ValueError
        If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError('n must be a positive integer, got {}'.format(n))

    # Convert up front: with a NumPy array, ``list + array`` would broadcast an
    # element-wise addition instead of concatenating the padding.
    values = np.asarray(a, dtype=float).ravel()
    if values.size == 0:
        return values

    padded = np.concatenate([np.full(n - 1, values[0]), values])
    ret = np.cumsum(padded)
    # Difference of cumulative sums n apart gives the sum of each window.
    ret[n:] = ret[n:] - ret[:-n]
    return ret[n - 1:] / n

In [None]:
# Per-batch train and validation losses over the whole run, smoothed with
# moving_average and drawn on a log-scaled y axis.
fig, ax = plt.subplots(figsize=(10, 10))
ax.set_title('Moving-averaged batch losses')
ax.plot(np.arange(len(train_losses)), moving_average(train_losses), label='train')
ax.plot(np.arange(len(val_losses)), moving_average(val_losses), label='validation')

# Axis labels so the figure can be read on its own.
ax.set_xlabel('batch')
ax.set_ylabel('loss')
ax.legend(loc='best')
ax.set_yscale('log')

# ax.set_ylim([1e-2, 1])
plt.show()

In [None]:
# Average the batch losses within every epoch.
mean_train_loss = list(map(np.mean, results['train_losses']))
mean_val_loss = list(map(np.mean, results['val_losses']))

In [None]:
# Mean train and validation loss per epoch (epochs numbered from 1), drawn on
# a log-scaled y axis.
fig, ax = plt.subplots(figsize=(10, 10))
ax.set_title('Epoch losses')
ax.plot(np.arange(len(mean_train_loss)) + 1, mean_train_loss, label='train')
ax.plot(np.arange(len(mean_val_loss)) + 1, mean_val_loss, label='val')

# Axis labels so the figure can be read on its own.
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.set_yscale('log')
ax.legend(loc='best')

ax.set_xlim([1, len(mean_train_loss) + 1])
plt.show()

## Results

In [None]:
import pandas as pd

In [None]:
from visualization_utils import make_df

In [None]:
# Turn the training results into a table; the IOU summary below reads its
# 'iou' column.
df = make_df(results, model_name='basic')
df

In [None]:
# Summary statistics of the IOU column.
iou_scores = df['iou']
print('Mean   IOU: {:.5}'.format(iou_scores.mean()))
print('Std    IOU: {:.5}'.format(iou_scores.std()))
print('Min    IOU: {:.5}'.format(iou_scores.min()))
print('Median IOU: {:.5}'.format(iou_scores.median()))

In [29]:
# Grab a single batch for the loss experiments below. next() raises
# StopIteration on an empty loader instead of silently leaving x and y
# undefined (or holding values from an earlier run).
x, y = next(iter(dataloader_train))

In [30]:
x.shape, y.shape

((32, 1, 128, 128), (32, 128, 128))

In [31]:
import torch
from torch import nn

In [32]:
# Convert the sampled batch to tensors. as_tensor accepts both NumPy arrays
# and tensors, so re-running this cell does not fail the way from_numpy does
# once x and y are already tensors.
x = torch.as_tensor(x)
y = torch.as_tensor(y)

In [33]:
# Forward pass on the sampled batch; per the recorded shape checks, a
# (32, 1, 128, 128) input yields a (32, 2, 128, 128) prediction.
pred = model(x)

In [34]:
criterion(pred, y)

tensor(0.8451, grad_fn=<AddBackward0>)

In [34]:
# Imported here because this cell is the first to use DiceLoss; the notebook's
# own pytorch_toolbelt import cell comes after it, so a top-to-bottom run
# would otherwise stop with a NameError.
from pytorch_toolbelt.losses import DiceLoss, JointLoss

# Individual loss terms, evaluated separately on the sampled batch below.
ce = nn.CrossEntropyLoss()
dice = DiceLoss(mode='multiclass', log_loss=True, smooth=1)

In [None]:
from pytorch_toolbelt.losses import DiceLoss, JointLoss

In [42]:
ce(pred, y)

tensor(0.7135, grad_fn=<NllLoss2DBackward>)

In [43]:
dice(pred, y)

tensor(1.1725, grad_fn=<MeanBackward0>)

In [40]:
y.size(), pred.size()

(torch.Size([32, 128, 128]), torch.Size([32, 2, 128, 128]))

In [48]:
# Combine cross-entropy and Dice. The trailing 1, 1 are presumably the two
# loss weights — TODO confirm; the recorded value 1.8860 equals
# 0.7135 (CE) + 1.1725 (Dice).
loss = JointLoss(ce, dice, 1, 1)

In [49]:
loss(pred, y)

tensor(1.8860, grad_fn=<AddBackward0>)

In [100]:
# 'weight': [1, 10],


In [103]:
loss.losses

(WeightedLoss(
   (loss): CrossEntropyLoss()
 ), WeightedLoss(
   (loss): DiceLoss()
 ))

In [104]:
# Re-assigns `device`, which an earlier cell already sets; no visible cell reads it.
device = 'cpu'

In [106]:
# NOTE(review): the recorded value here (0.9430) is half of the one from the
# identical call above (1.8860), so `loss` was presumably rebuilt with
# different weights in a cell that is no longer part of the notebook.
loss(pred, y)

tensor(0.9430, grad_fn=<AddBackward0>)