In [1]:
# Load IPython's autoreload extension and (mode 2) reload all modules before each
# cell executes, so edits to imported project code take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
# Make the directory two levels up importable (presumably the repository root
# that holds the `stack_segmentation` package -- confirm against the repo layout).
# Guarded so that re-running this cell does not pile up duplicate entries.
if '../..' not in sys.path:
    sys.path.append('../..')

In [3]:
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

In [4]:
from stack_segmentation.stack import Stack

In [5]:
from stack_segmentation.aug_pipelines import medium_aug
from stack_segmentation.io import make_dataloader, collate_fn_basic
from stack_segmentation.training import handle_stacks_data, make_model, train_loop
from stack_segmentation.unet import UNet
from stack_segmentation.pipeline_config import dataloaders_conf, model_conf, train_conf, loss_config

In [6]:
from exp_config import data_conf

## Parameters to tune

In [7]:
# `train_conf` is forwarded to train_loop and `model_conf` to make_model further
# down, so the same device string is written into each of them.
model_conf.update(device='cuda:1')
train_conf.update(device='cuda:1')

In [8]:
# Experiment identifier: later cells pass it to train_loop as `exp_name` and use
# it to build the file names of the dumped results and the model checkpoint.
data_conf.update(conf_name='basic_lr1e-2_epoch300')
data_conf

{'conf_name': 'basic_lr1e-2_epoch300',
 'stacks': [{'path': '../../data/carb96558',
   'slice_train': (slice(None, None, None),
    slice(None, None, None),
    slice(None, 230, None)),
   'slice_val': (slice(None, None, None),
    slice(None, None, None),
    slice(250, 470, None))},
  {'path': '../../data/SoilB-2',
   'slice_train': (slice(None, None, None),
    slice(None, None, None),
    slice(None, 230, None)),
   'slice_val': (slice(None, None, None),
    slice(None, None, None),
    slice(240, 460, None))},
  {'path': '../../data/Urna_22',
   'slice_train': (slice(None, None, None),
    slice(None, None, None),
    slice(None, 220, None)),
   'slice_val': (slice(None, None, None),
    slice(None, None, None),
    slice(245, 455, None))},
  {'path': '../../data/carb96558',
   'slice_test': (slice(None, None, None),
    slice(None, None, None),
    slice(490, None, None))},
  {'path': '../../data/carb71',
   'slice_test': (slice(None, None, None),
    slice(None, None, None),
   

In [9]:
# Switch off augmentation for the training split, then echo the loader settings.
dataloaders_conf['train'].update(augmentation_pipeline=None)
dataloaders_conf

{'train': {'batch_size': 32,
  'num_workers': 8,
  'shuffle': True,
  'augmentation_pipeline': None},
 'val': {'batch_size': 32,
  'num_workers': 8,
  'shuffle': False,
  'augmentation_pipeline': None},
 'test': {'batch_size': 32,
  'num_workers': 8,
  'shuffle': True,
  'augmentation_pipeline': None}}

In [10]:
# Number of epochs requested for this run; echo the resulting training config.
train_conf.update(num_epochs=300)
train_conf

{'num_epochs': 300, 'device': 'cuda:1'}

In [11]:
# Optimizer settings for this experiment: SGD with Nesterov enabled at lr=1e-2.
# NOTE(review): `weight` is set to None; presumably a loss/class weight consumed
# by make_model -- confirm there. Keys are written in the original order.
model_conf.update(
    opt_type='SGD',
    lr=1e-2,
    weight=None,
    nesterov=True,
)
model_conf

{'device': 'cuda:1',
 'opt_type': 'SGD',
 'lr': 0.01,
 'weight_decay': 0.0001,
 'amsgrad': False,
 'nesterov': True,
 'momentum': 0.9,
 'centered': False,
 'min_lr': 1e-06,
 'factor': 0.5,
 'patience': 5,
 'weight': None}

In [12]:
# Loss specification handed to make_model: a single 'BCE' term with weight 1 and
# no extra parameters. This rebinding replaces the `loss_config` imported from
# stack_segmentation.pipeline_config.
loss_config = [dict(loss='BCE', weight=1, params=dict())]
loss_config

[{'loss': 'BCE', 'weight': 1, 'params': {}}]

In [13]:
# Display the per-split 'patches' entry of the experiment config
# (presumably the patch size used when cutting samples -- confirm in handle_stacks_data).
data_conf['patches']

{'train': (128, 128, 1), 'val': (128, 128, 1), 'test': (128, 128, 1)}

## Prepare train, validation and test data

In [14]:
# Build the train / validation / test sample collections from the experiment config.
# `data_test` is later iterated with `.items()`, i.e. it is a mapping of name -> samples,
# while `data_train` and `data_val` are passed directly as `samples` to make_dataloader.
data_train, data_val, data_test = handle_stacks_data(**data_conf)

720it [00:03, 197.58it/s]
100%|██████████| 720/720 [00:08<00:00, 81.56it/s]
8280it [00:00, 199268.07it/s]
7920it [00:00, 225934.26it/s]
700it [00:03, 225.96it/s]
100%|██████████| 700/700 [00:07<00:00, 88.36it/s] 
8280it [00:00, 229427.09it/s]
7920it [00:00, 74432.41it/s]
710it [00:03, 208.77it/s]
100%|██████████| 710/710 [00:08<00:00, 84.30it/s]
7920it [00:00, 245796.37it/s]
7560it [00:00, 227216.26it/s]
720it [00:01, 699.08it/s]
100%|██████████| 720/720 [00:08<00:00, 81.37it/s]
8280it [00:00, 190904.85it/s]
720it [00:03, 217.03it/s]
100%|██████████| 720/720 [00:08<00:00, 81.05it/s]
25920it [00:00, 108520.26it/s]
700it [00:03, 218.77it/s]
100%|██████████| 700/700 [00:08<00:00, 86.36it/s]
25200it [00:00, 99796.30it/s] 
509it [00:01, 361.38it/s]
100%|██████████| 509/509 [00:03<00:00, 166.82it/s]
8144it [00:00, 187964.50it/s]
700it [00:03, 210.11it/s]
100%|██████████| 700/700 [00:08<00:00, 86.95it/s]
25200it [00:00, 191659.30it/s]
700it [00:01, 694.39it/s]
100%|██████████| 700/700 [00:08<

In [15]:
# Sizes of the train / validation / test collections, in that order.
tuple(map(len, (data_train, data_val, data_test)))

(24480, 23400, 11)

In [16]:
# Train and validation loaders are built the same way; they differ only in the
# samples and in the per-split settings taken from `dataloaders_conf`.
dataloader_train, dataloader_val = [
    make_dataloader(
        samples=split_samples,
        collate_fn=collate_fn_basic,
        **dataloaders_conf[split_name]
    )
    for split_name, split_samples in (('train', data_train), ('val', data_val))
]

# One loader per entry of `data_test`, keyed by the same names.
dataloaders_test = dict(
    (
        test_name,
        make_dataloader(
            samples=test_samples,
            collate_fn=collate_fn_basic,
            **dataloaders_conf['test']
        ),
    )
    for test_name, test_samples in data_test.items()
)

## Create model and metrics

In [17]:
from stack_segmentation.metrics import accuracy, precision, recall, f1, pr_auc, iou

In [18]:
# Metric name -> metric function, as passed to train_loop via `metrics=`.
metrics = dict(
    accuracy=accuracy,
    precision=precision,
    recall=recall,
    f1=f1,
    pr_auc=pr_auc,
    iou=iou,
)

In [19]:
# NOTE(review): `device` is not referenced by any other cell visible in this notebook;
# the device actually handed to make_model / train_loop comes from
# model_conf['device'] and train_conf['device'] set in the parameters section.
device = 'cuda:1'

In [20]:
# Build the model together with its loss, optimizer and scheduler from the tuned
# config. `loss_config` here is the list defined in this notebook, not the one
# imported from stack_segmentation.pipeline_config.
model, criterion, optimizer, scheduler = make_model(loss_config=loss_config, **model_conf)

## Run experiment

In [None]:
# Launch training. The loop receives the model, the train/val loaders, the mapping
# of test loaders, the loss, optimizer and scheduler built by make_model, and the
# metric functions; the remaining settings (num_epochs, device) come from train_conf.
# NOTE(review): the logged "EarlyStopping counter: k out of 10" lines suggest early
# stopping with a patience of 10 inside train_loop -- confirm in stack_segmentation.training.
results = train_loop(
    model=model,
    dataloader_train=dataloader_train, 
    dataloader_val=dataloader_val,
    dataloaders_test=dataloaders_test,
    criterion=criterion, 
    optimizer=optimizer, 
    scheduler=scheduler,
    metrics=metrics,
    exp_name=data_conf['conf_name'],
    **train_conf)

  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 0...


100%|██████████| 765/765 [02:41<00:00,  4.88it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.048393


100%|██████████| 732/732 [00:59<00:00, 12.38it/s]


Mean val loss: 0.019424


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 1...


100%|██████████| 765/765 [02:42<00:00,  4.87it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.018001


100%|██████████| 732/732 [00:59<00:00, 12.28it/s]


Mean val loss: 0.016344


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 2...


100%|██████████| 765/765 [02:43<00:00,  4.87it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.015546


100%|██████████| 732/732 [00:59<00:00, 12.33it/s]


Mean val loss: 0.016338


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 3...


100%|██████████| 765/765 [02:43<00:00,  4.87it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.014396


100%|██████████| 732/732 [00:59<00:00, 12.32it/s]


Mean val loss: 0.014038


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 4...


100%|██████████| 765/765 [02:43<00:00,  4.85it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.013666


100%|██████████| 732/732 [00:59<00:00, 12.37it/s]


Mean val loss: 0.01366


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 5...


100%|██████████| 765/765 [02:43<00:00,  4.84it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.013538


100%|██████████| 732/732 [00:59<00:00, 12.33it/s]


Mean val loss: 0.013391


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 6...


100%|██████████| 765/765 [02:43<00:00,  4.85it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.013025


100%|██████████| 732/732 [00:59<00:00, 12.32it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.013583
EarlyStopping counter: 1 out of 10
Epoch 7...


100%|██████████| 765/765 [02:43<00:00,  4.82it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.012923


100%|██████████| 732/732 [00:59<00:00, 12.31it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.013606
EarlyStopping counter: 2 out of 10
Epoch 8...


100%|██████████| 765/765 [02:44<00:00,  4.88it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.012732


100%|██████████| 732/732 [00:59<00:00, 12.33it/s]


Mean val loss: 0.013092


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 9...


100%|██████████| 765/765 [02:43<00:00,  4.82it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.012621


100%|██████████| 732/732 [00:59<00:00, 12.31it/s]


Mean val loss: 0.012971


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 10...


100%|██████████| 765/765 [02:43<00:00,  4.82it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.012484


100%|██████████| 732/732 [00:59<00:00, 12.36it/s]


Mean val loss: 0.012875


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 11...


100%|██████████| 765/765 [02:43<00:00,  4.86it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.012453


100%|██████████| 732/732 [00:59<00:00, 12.28it/s]


Mean val loss: 0.012751


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 12...


100%|██████████| 765/765 [02:44<00:00,  4.83it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.012404


100%|██████████| 732/732 [00:59<00:00, 12.32it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.013316
EarlyStopping counter: 1 out of 10
Epoch 13...


100%|██████████| 765/765 [02:43<00:00,  4.82it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.012333


100%|██████████| 732/732 [00:59<00:00, 12.35it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.012905
EarlyStopping counter: 2 out of 10
Epoch 14...


100%|██████████| 765/765 [02:43<00:00,  4.86it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.012264


100%|██████████| 732/732 [00:59<00:00, 12.31it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.013103
EarlyStopping counter: 3 out of 10
Epoch 15...


100%|██████████| 765/765 [02:43<00:00,  4.82it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.012244


100%|██████████| 732/732 [00:59<00:00, 12.30it/s]


Mean val loss: 0.012575


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 16...


100%|██████████| 765/765 [02:43<00:00,  4.87it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.012182


100%|██████████| 732/732 [00:59<00:00, 12.35it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.012852
EarlyStopping counter: 1 out of 10
Epoch 17...


100%|██████████| 765/765 [02:43<00:00,  4.85it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.012177


100%|██████████| 732/732 [00:59<00:00, 12.32it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.012943
EarlyStopping counter: 2 out of 10
Epoch 18...


100%|██████████| 765/765 [02:43<00:00,  4.83it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.012117


100%|██████████| 732/732 [00:59<00:00, 12.29it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.012653
EarlyStopping counter: 3 out of 10
Epoch 19...


100%|██████████| 765/765 [02:43<00:00,  4.88it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.012094


100%|██████████| 732/732 [00:59<00:00, 12.30it/s]


Mean val loss: 0.01249


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 20...


100%|██████████| 765/765 [02:43<00:00,  4.83it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.012073


100%|██████████| 732/732 [00:59<00:00, 12.30it/s]


Mean val loss: 0.012481


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 21...


100%|██████████| 765/765 [02:43<00:00,  4.82it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.012068


100%|██████████| 732/732 [00:58<00:00, 12.45it/s]


Mean val loss: 0.012464


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 22...


100%|██████████| 765/765 [02:43<00:00,  4.81it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.012003


100%|██████████| 732/732 [00:59<00:00, 12.26it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.01297
EarlyStopping counter: 1 out of 10
Epoch 23...


100%|██████████| 765/765 [02:46<00:00,  4.77it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.012018


100%|██████████| 732/732 [01:02<00:00, 11.73it/s]


Mean val loss: 0.0124


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 24...


100%|██████████| 765/765 [02:48<00:00,  4.67it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.011991


## Dump experiment results

In [None]:
import pickle
import json

In [None]:
# Persist the object returned by train_loop next to the notebook, in a pickle
# file named after the experiment.
p = './{}_exp_results.pkl'.format(data_conf['conf_name'])
with open(p, mode='wb') as results_file:
    pickle.dump(results, results_file)

In [17]:
# p = './{}_exp_results.pkl'.format(data_conf['conf_name'])
# with open(p, 'rb') as f:
#     results = pickle.load(f)

In [11]:
import torch
# Restore the weights stored under this experiment's name (presumably written by
# train_loop, which receives the same name as `exp_name` -- confirm).
# `map_location` remaps the stored tensors to the configured device, so loading
# does not depend on the GPU the checkpoint was saved from being available.
model.load_state_dict(
    torch.load(
        './{}.pt'.format(data_conf['conf_name']),
        map_location=model_conf['device'],
    )
)

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

## Train and validation losses

In [None]:
from itertools import chain

In [None]:
# Flatten the per-epoch lists of batch losses into one continuous sequence each.
train_losses = list(chain.from_iterable(results['train_losses']))
val_losses = list(chain.from_iterable(results['val_losses']))

In [None]:
def moving_average(a, n=5):
    """Return the trailing moving average of ``a`` with window ``n``.

    The sequence is left-padded with ``n - 1`` copies of its first element, so
    the result has the same length as the input and the first values are
    averaged against that repeated first element.

    Parameters
    ----------
    a : array-like of numbers
        Input values (list, tuple, range, ndarray, ...); flattened to 1-D.
    n : int, default 5
        Window length, must be at least 1.

    Returns
    -------
    numpy.ndarray
        Float array of ``len(a)`` averaged values; empty if ``a`` is empty.

    Raises
    ------
    ValueError
        If ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError('window size n must be >= 1, got {}'.format(n))

    values = np.asarray(a, dtype=float).ravel()
    if values.size == 0:
        # Nothing to average (and no first element to pad with).
        return values

    # Explicit concatenation works for any array-like, unlike `list + a`.
    padded = np.concatenate((np.full(n - 1, values[0]), values))
    window_sums = np.cumsum(padded)
    # Difference of cumulative sums n apart gives the sum over each window.
    window_sums[n:] = window_sums[n:] - window_sums[:-n]
    return window_sums[n - 1:] / n

In [None]:
# Smoothed per-batch loss curves on a log scale. The x positions default to
# 0..len-1, which matches the explicit arange because moving_average keeps the
# input length.
plt.figure(figsize=(10, 10))
plt.plot(moving_average(train_losses), label='train')
plt.plot(moving_average(val_losses), label='validation')
plt.title('Moving-averaged batch losses')

plt.yscale('log')
plt.legend(loc='best')

# plt.ylim([1e-2, 1])
plt.show()

In [None]:
# Collapse each epoch's batch losses into a single mean value.
mean_train_loss = list(map(np.mean, results['train_losses']))
mean_val_loss = list(map(np.mean, results['val_losses']))

In [None]:
# Mean loss per epoch on a log scale, with epochs numbered from 1.
plt.figure(figsize=(10, 10))
plt.plot(np.arange(1, len(mean_train_loss) + 1), mean_train_loss, label='train')
plt.plot(np.arange(1, len(mean_val_loss) + 1), mean_val_loss, label='val')
plt.title('Epoch losses')

plt.xlim([1, len(mean_train_loss) + 1])
plt.legend(loc='best')
plt.yscale('log')

plt.show()

## Results

In [None]:
import pandas as pd

In [None]:
from visualization_utils import make_df

In [None]:
# Convert the training results into a table via the local `make_df` helper,
# labelled with model name 'basic', and display it. The next cell reads its 'iou' column.
df = make_df(results, model_name='basic')
df

In [None]:
# Summary statistics of the 'iou' column, one per line.
print('\n'.join((
    'Mean   IOU: {:.5}'.format(df['iou'].mean()),
    'Std    IOU: {:.5}'.format(df['iou'].std()),
    'Min    IOU: {:.5}'.format(df['iou'].min()),
    'Median IOU: {:.5}'.format(df['iou'].median()),
)))