In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('../..')

In [3]:
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

In [4]:
from stack_segmentation.stack import Stack

In [5]:
from stack_segmentation.io import make_dataloader, collate_fn_basic

from stack_segmentation.training import (
    handle_stacks_data, 
    make_optimization_task, 
    train_loop
)

from stack_segmentation.pipeline_config import (
    dataloaders_conf,
    train_conf,
    model_config, 
    aug_config,
    optimizer_config,
    loss_config,
    scheduler_config,
)

In [6]:
from exp_config import data_conf

## Parameters to tune

In [7]:
train_conf['device'] = 'cuda:1'
# train_conf['device'] = 'cpu'
train_conf['num_epochs'] = 500
train_conf

{'num_epochs': 500, 'device': 'cuda:1'}

In [8]:
data_conf['conf_name'] = 'exp_basic_adamw_lr1e-3_epoch_300_resnet50_encoder_soft_aug_k_1_weight10_v2'
data_conf

{'conf_name': 'exp_basic_adamw_lr1e-3_epoch_300_resnet50_encoder_soft_aug_k_1_weight10_v2',
 'stacks': [{'path': '../../data/carb96558',
   'slice_train': (slice(None, None, None),
    slice(None, None, None),
    slice(None, 230, None)),
   'slice_val': (slice(None, None, None),
    slice(None, None, None),
    slice(250, 470, None))},
  {'path': '../../data/SoilB-2',
   'slice_train': (slice(None, None, None),
    slice(None, None, None),
    slice(None, 230, None)),
   'slice_val': (slice(None, None, None),
    slice(None, None, None),
    slice(240, 460, None))},
  {'path': '../../data/Urna_22',
   'slice_train': (slice(None, None, None),
    slice(None, None, None),
    slice(None, 220, None)),
   'slice_val': (slice(None, None, None),
    slice(None, None, None),
    slice(245, 455, None))},
  {'path': '../../data/carb96558',
   'slice_test': (slice(None, None, None),
    slice(None, None, None),
    slice(490, None, None))},
  {'path': '../../data/carb71',
   'slice_test': (slic

In [9]:
model_config

{'source': 'qubvel',
 'model_type': 'Unet',
 'encoder_name': 'resnet50',
 'encoder_weights': 'imagenet'}

In [10]:
optimizer_config['opt_type'] = 'AdamW'
optimizer_config['lr'] = 5e-4
optimizer_config['weight_decay'] = 1e-4
optimizer_config['amsgrad'] = True
# optimizer_config['nesterov'] = True
optimizer_config

{'opt_type': 'AdamW',
 'lr': 0.0005,
 'weight_decay': 0.0001,
 'amsgrad': True,
 'nesterov': False,
 'momentum': 0.9,
 'centered': False}

In [11]:
aug_config['aug_type'] = 'soft'
aug_config['k'] = 1
aug_config

{'aug_type': 'soft', 'original_height': 128, 'original_width': 128, 'k': 1}

In [12]:
loss_config =[
    {
        'loss': 'BCE', 
        'weight': 0.6, 
        'params': {'weight': [1, 10]}},
    {
        'loss': 'Dice',
        'weight': 0.4, 
        'params': {
            'mode': 'multiclass',
            'classes': [1], # может быть, этот параметр не нужен
            'log_loss': True,
            'from_logits': True,
            'smooth': 1,
            'eps': 1e-7
        }
    }
]
loss_config

[{'loss': 'BCE', 'weight': 0.6, 'params': {'weight': [1, 10]}},
 {'loss': 'Dice',
  'weight': 0.4,
  'params': {'mode': 'multiclass',
   'classes': [1],
   'log_loss': True,
   'from_logits': True,
   'smooth': 1,
   'eps': 1e-07}}]

In [13]:
data_conf['patches']

{'train': (128, 128, 1), 'val': (128, 128, 1), 'test': (128, 128, 1)}

## Prepare train, validation and test data

In [14]:
data_train, data_val, data_test = handle_stacks_data(**data_conf)

720it [00:01, 660.35it/s]
100%|██████████| 720/720 [00:07<00:00, 92.57it/s] 
8280it [00:00, 242497.80it/s]
7920it [00:00, 246315.80it/s]
700it [00:00, 852.73it/s]
100%|██████████| 700/700 [00:07<00:00, 99.56it/s] 
8280it [00:00, 260400.53it/s]
7920it [00:00, 250048.08it/s]
710it [00:00, 850.45it/s]
100%|██████████| 710/710 [00:07<00:00, 96.56it/s] 
7920it [00:00, 84723.47it/s]
7560it [00:00, 236117.58it/s]
720it [00:00, 815.91it/s]
100%|██████████| 720/720 [00:07<00:00, 94.58it/s] 
8280it [00:00, 258896.07it/s]
720it [00:00, 838.62it/s]
100%|██████████| 720/720 [00:07<00:00, 94.83it/s] 
25920it [00:00, 135762.23it/s]
700it [00:01, 657.24it/s]
100%|██████████| 700/700 [00:06<00:00, 100.89it/s]
25200it [00:00, 106392.27it/s]
509it [00:00, 1112.55it/s]
100%|██████████| 509/509 [00:02<00:00, 195.80it/s]
8144it [00:00, 230745.51it/s]
700it [00:00, 808.00it/s]
100%|██████████| 700/700 [00:07<00:00, 98.91it/s] 
25200it [00:00, 227957.20it/s]
700it [00:00, 807.19it/s]
100%|██████████| 700/700 

In [15]:
len(data_train), len(data_val), len(data_test)

(24480, 23400, 11)

In [16]:
dataloader_train = make_dataloader(
    samples=data_train, 
    collate_fn=collate_fn_basic,
    model_config=model_config,
    aug_config=aug_config,
    **dataloaders_conf['train']
)

dataloader_val = make_dataloader(
    samples=data_val, 
    collate_fn=collate_fn_basic,
    model_config=model_config,
    **dataloaders_conf['val']
)

dataloaders_test = {
    name: make_dataloader(
        samples=data, 
        collate_fn=collate_fn_basic,
        model_config=model_config,
        **dataloaders_conf['test']
    ) for name, data in data_test.items()}

## Create model and metrics

In [17]:
device = 'cuda:1'
# device = 'cpu'

In [18]:
model, criterion, optimizer, scheduler = make_optimization_task(
    device,
    model_config=model_config,
    loss_config=loss_config, 
    optimizer_config=optimizer_config,
    scheduler_config=scheduler_config)

## Run experiment

In [19]:
from stack_segmentation.metrics import accuracy, precision, recall, f1, pr_auc, iou

In [20]:
metrics = {
    'accuracy': accuracy, 
    'precision': precision, 
    'recall': recall, 
    'f1': f1,
    'pr_auc': pr_auc, 
    'iou': iou,
}

In [None]:
results = train_loop(
    model=model,
    dataloader_train=dataloader_train, 
    dataloader_val=dataloader_val,
    dataloaders_test=dataloaders_test,
    criterion=criterion, 
    optimizer=optimizer, 
    scheduler=scheduler,
    metrics=metrics,
    exp_name=data_conf['conf_name'],
    **train_conf)

  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 0...


100%|██████████| 765/765 [01:35<00:00,  8.04it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.047239


100%|██████████| 732/732 [00:31<00:00, 23.26it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.020733
Epoch 1...


100%|██████████| 765/765 [01:36<00:00,  7.96it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.017561


100%|██████████| 732/732 [00:31<00:00, 23.20it/s]


Mean val loss: 0.016408


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 2...


100%|██████████| 765/765 [01:36<00:00,  7.95it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.015112


100%|██████████| 732/732 [00:31<00:00, 23.16it/s]


Mean val loss: 0.015011


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 3...


100%|██████████| 765/765 [01:36<00:00,  7.93it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.014127


100%|██████████| 732/732 [00:31<00:00, 23.23it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.017855
EarlyStopping counter: 1 out of 15
Epoch 4...


100%|██████████| 765/765 [01:36<00:00,  7.91it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.013572


100%|██████████| 732/732 [00:31<00:00, 23.18it/s]


Mean val loss: 0.013282


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 5...


100%|██████████| 765/765 [01:37<00:00,  7.88it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.013235


100%|██████████| 732/732 [00:31<00:00, 23.34it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.014513
EarlyStopping counter: 1 out of 15
Epoch 6...


100%|██████████| 765/765 [01:37<00:00,  7.88it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.012771


100%|██████████| 732/732 [00:31<00:00, 23.21it/s]


Mean val loss: 0.012462


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 7...


100%|██████████| 765/765 [01:37<00:00,  7.88it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.012529


100%|██████████| 732/732 [00:31<00:00, 23.32it/s]


Mean val loss: 0.012228


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 8...


100%|██████████| 765/765 [01:37<00:00,  7.87it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.012488


100%|██████████| 732/732 [00:32<00:00, 22.81it/s]


Mean val loss: 0.01204


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 9...


100%|██████████| 765/765 [01:37<00:00,  7.86it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.012288


100%|██████████| 732/732 [00:31<00:00, 23.05it/s]


Mean val loss: 0.011913


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 10...


100%|██████████| 765/765 [01:39<00:00,  7.72it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.01217


100%|██████████| 732/732 [00:31<00:00, 23.21it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.012011
EarlyStopping counter: 1 out of 15
Epoch 11...


100%|██████████| 765/765 [01:39<00:00,  7.68it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.012117


100%|██████████| 732/732 [00:33<00:00, 22.16it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.012507
EarlyStopping counter: 2 out of 15
Epoch 12...


100%|██████████| 765/765 [01:37<00:00,  7.85it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.011964


100%|██████████| 732/732 [00:31<00:00, 23.28it/s]


Mean val loss: 0.011905


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 13...


100%|██████████| 765/765 [01:38<00:00,  7.81it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.011908


100%|██████████| 732/732 [00:31<00:00, 23.13it/s]


Mean val loss: 0.011643


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 14...


100%|██████████| 765/765 [01:37<00:00,  7.84it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.011783


100%|██████████| 732/732 [00:33<00:00, 21.86it/s]


Mean val loss: 0.011612


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 15...


100%|██████████| 765/765 [01:38<00:00,  7.80it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.011754


100%|██████████| 732/732 [00:31<00:00, 23.11it/s]


Mean val loss: 0.01152


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 16...


100%|██████████| 765/765 [01:37<00:00,  7.86it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.011696


100%|██████████| 732/732 [00:31<00:00, 22.94it/s]


Mean val loss: 0.011519


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 17...


100%|██████████| 765/765 [01:38<00:00,  7.76it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.011635


100%|██████████| 732/732 [00:31<00:00, 23.29it/s]


Mean val loss: 0.01145


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 18...


100%|██████████| 765/765 [01:38<00:00,  7.80it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.011569


100%|██████████| 732/732 [00:31<00:00, 23.07it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.011938
EarlyStopping counter: 1 out of 15
Epoch 19...


100%|██████████| 765/765 [01:36<00:00,  7.91it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.011537


100%|██████████| 732/732 [00:31<00:00, 22.97it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.011465
EarlyStopping counter: 2 out of 15
Epoch 20...


100%|██████████| 765/765 [01:37<00:00,  7.85it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.011525


100%|██████████| 732/732 [00:31<00:00, 23.01it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.011485
EarlyStopping counter: 3 out of 15
Epoch 21...


100%|██████████| 765/765 [01:37<00:00,  7.83it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.011438


100%|██████████| 732/732 [00:31<00:00, 23.00it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.011603
EarlyStopping counter: 4 out of 15
Epoch 22...


100%|██████████| 765/765 [01:37<00:00,  7.84it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.01134


100%|██████████| 732/732 [00:31<00:00, 23.06it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.011879
EarlyStopping counter: 5 out of 15
Epoch 23...


100%|██████████| 765/765 [01:37<00:00,  7.85it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.011393


100%|██████████| 732/732 [00:32<00:00, 22.77it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.011662
EarlyStopping counter: 6 out of 15
Epoch 24...


100%|██████████| 765/765 [01:37<00:00,  7.83it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.011307


100%|██████████| 732/732 [00:31<00:00, 22.91it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.012287
EarlyStopping counter: 7 out of 15
Epoch 25...


100%|██████████| 765/765 [01:37<00:00,  7.85it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.01126


100%|██████████| 732/732 [00:31<00:00, 22.95it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.011295
Epoch 26...


100%|██████████| 765/765 [01:37<00:00,  7.87it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.011227


100%|██████████| 732/732 [00:31<00:00, 23.02it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.011194
Epoch 27...


100%|██████████| 765/765 [01:37<00:00,  7.88it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.011185


100%|██████████| 732/732 [00:31<00:00, 22.93it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.011358
EarlyStopping counter: 1 out of 15
Epoch 28...


100%|██████████| 765/765 [01:36<00:00,  7.89it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.011182


100%|██████████| 732/732 [00:31<00:00, 23.22it/s]


Mean val loss: 0.011133


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 29...


100%|██████████| 765/765 [01:37<00:00,  7.88it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.011151


100%|██████████| 732/732 [00:31<00:00, 23.01it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.011157
EarlyStopping counter: 1 out of 15
Epoch 30...


100%|██████████| 765/765 [01:37<00:00,  7.89it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.011114


100%|██████████| 732/732 [00:31<00:00, 23.01it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.011123
Epoch 31...


100%|██████████| 765/765 [01:36<00:00,  7.89it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.011097


100%|██████████| 732/732 [00:31<00:00, 23.19it/s]


Mean val loss: 0.011


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 32...


100%|██████████| 765/765 [01:37<00:00,  7.88it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.01104


100%|██████████| 732/732 [00:31<00:00, 23.17it/s]


Mean val loss: 0.010983


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 33...


100%|██████████| 765/765 [01:37<00:00,  7.87it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.011017


100%|██████████| 732/732 [00:31<00:00, 23.21it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.011191
EarlyStopping counter: 1 out of 15
Epoch 34...


100%|██████████| 765/765 [01:37<00:00,  7.88it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.01096


100%|██████████| 732/732 [00:31<00:00, 23.00it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.011024
EarlyStopping counter: 2 out of 15
Epoch 35...


100%|██████████| 765/765 [01:37<00:00,  7.87it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010976


100%|██████████| 732/732 [00:31<00:00, 23.05it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.011015
EarlyStopping counter: 3 out of 15
Epoch 36...


100%|██████████| 765/765 [01:38<00:00,  7.81it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010946


100%|██████████| 732/732 [00:31<00:00, 23.17it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.011189
EarlyStopping counter: 4 out of 15
Epoch 37...


100%|██████████| 765/765 [01:37<00:00,  7.87it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010927


100%|██████████| 732/732 [00:32<00:00, 22.70it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.01104
EarlyStopping counter: 5 out of 15
Epoch 38...


100%|██████████| 765/765 [01:39<00:00,  7.69it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010875


100%|██████████| 732/732 [00:31<00:00, 23.10it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.010947
Epoch 39...


100%|██████████| 765/765 [01:36<00:00,  7.89it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010845


100%|██████████| 732/732 [00:31<00:00, 23.13it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.0111
EarlyStopping counter: 1 out of 15
Epoch 40...


100%|██████████| 765/765 [01:37<00:00,  7.82it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010838


100%|██████████| 732/732 [00:31<00:00, 23.26it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.010852
Epoch 41...


100%|██████████| 765/765 [01:36<00:00,  7.90it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010822


100%|██████████| 732/732 [00:31<00:00, 23.30it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.011834
EarlyStopping counter: 1 out of 15
Epoch 42...


100%|██████████| 765/765 [01:37<00:00,  7.88it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010779


100%|██████████| 732/732 [00:31<00:00, 23.14it/s]


Mean val loss: 0.010774


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 43...


100%|██████████| 765/765 [01:37<00:00,  7.81it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010768


100%|██████████| 732/732 [00:32<00:00, 22.53it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.011
EarlyStopping counter: 1 out of 15
Epoch 44...


100%|██████████| 765/765 [01:37<00:00,  7.84it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010724


100%|██████████| 732/732 [00:31<00:00, 23.23it/s]


Mean val loss: 0.010772


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 45...


100%|██████████| 765/765 [01:37<00:00,  7.86it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.01073


100%|██████████| 732/732 [00:32<00:00, 22.63it/s]


Mean val loss: 0.010773


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 46...


100%|██████████| 765/765 [01:36<00:00,  7.90it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010684


100%|██████████| 732/732 [00:31<00:00, 23.02it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.010874
EarlyStopping counter: 1 out of 15
Epoch 47...


100%|██████████| 765/765 [01:37<00:00,  7.88it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010696


100%|██████████| 732/732 [00:31<00:00, 23.07it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.011026
EarlyStopping counter: 2 out of 15
Epoch 48...


100%|██████████| 765/765 [01:36<00:00,  7.89it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010711


100%|██████████| 732/732 [00:32<00:00, 22.55it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.010947
EarlyStopping counter: 3 out of 15
Epoch 49...


100%|██████████| 765/765 [01:37<00:00,  7.85it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010634


100%|██████████| 732/732 [00:31<00:00, 23.17it/s]


Mean val loss: 0.010669


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 50...


100%|██████████| 765/765 [01:38<00:00,  7.73it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010645


100%|██████████| 732/732 [00:32<00:00, 22.18it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.010899
EarlyStopping counter: 1 out of 15
Epoch 51...


100%|██████████| 765/765 [01:36<00:00,  7.89it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010602


100%|██████████| 732/732 [00:31<00:00, 23.06it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.0108
EarlyStopping counter: 2 out of 15
Epoch 52...


100%|██████████| 765/765 [01:37<00:00,  7.86it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010581


100%|██████████| 732/732 [00:31<00:00, 23.16it/s]


Mean val loss: 0.010639


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 53...


100%|██████████| 765/765 [01:37<00:00,  7.88it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010552


100%|██████████| 732/732 [00:31<00:00, 22.89it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.010661
EarlyStopping counter: 1 out of 15
Epoch 54...


100%|██████████| 765/765 [01:37<00:00,  7.86it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010553


100%|██████████| 732/732 [00:31<00:00, 23.23it/s]


Mean val loss: 0.010631


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 55...


100%|██████████| 765/765 [01:37<00:00,  7.87it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010522


100%|██████████| 732/732 [00:31<00:00, 23.14it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.010885
EarlyStopping counter: 1 out of 15
Epoch 56...


100%|██████████| 765/765 [01:37<00:00,  7.85it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010516


100%|██████████| 732/732 [00:31<00:00, 23.12it/s]


Mean val loss: 0.010567


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 57...


100%|██████████| 765/765 [01:36<00:00,  7.90it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010513


100%|██████████| 732/732 [00:32<00:00, 22.44it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.011044
EarlyStopping counter: 1 out of 15
Epoch 58...


100%|██████████| 765/765 [01:37<00:00,  7.87it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.01046


100%|██████████| 732/732 [00:32<00:00, 22.73it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.010693
EarlyStopping counter: 2 out of 15
Epoch 59...


100%|██████████| 765/765 [01:38<00:00,  7.75it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010446


100%|██████████| 732/732 [00:32<00:00, 22.39it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.010708
EarlyStopping counter: 3 out of 15
Epoch 60...


100%|██████████| 765/765 [01:38<00:00,  7.79it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010406


100%|██████████| 732/732 [00:31<00:00, 22.94it/s]


Mean val loss: 0.010535


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 61...


100%|██████████| 765/765 [01:37<00:00,  7.85it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010419


100%|██████████| 732/732 [00:32<00:00, 22.84it/s]


Mean val loss: 0.01052


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 62...


100%|██████████| 765/765 [01:37<00:00,  7.88it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010388


100%|██████████| 732/732 [00:32<00:00, 22.71it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.010603
EarlyStopping counter: 1 out of 15
Epoch 63...


100%|██████████| 765/765 [01:38<00:00,  7.76it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010356


100%|██████████| 732/732 [00:31<00:00, 23.07it/s]


Mean val loss: 0.010521


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 64...


100%|██████████| 765/765 [01:36<00:00,  7.90it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010381


100%|██████████| 732/732 [00:31<00:00, 23.11it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.01067
EarlyStopping counter: 1 out of 15
Epoch 65...


100%|██████████| 765/765 [01:37<00:00,  7.87it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010321


100%|██████████| 732/732 [00:31<00:00, 23.18it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.01137
EarlyStopping counter: 2 out of 15
Epoch 66...


100%|██████████| 765/765 [01:37<00:00,  7.83it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010317


100%|██████████| 732/732 [00:32<00:00, 22.86it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.010629
EarlyStopping counter: 3 out of 15
Epoch 67...


100%|██████████| 765/765 [01:37<00:00,  7.88it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010308


100%|██████████| 732/732 [00:31<00:00, 23.09it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.011208
EarlyStopping counter: 4 out of 15
Epoch 68...


100%|██████████| 765/765 [01:38<00:00,  7.77it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010291


100%|██████████| 732/732 [00:31<00:00, 23.26it/s]


Mean val loss: 0.010482


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 69...


100%|██████████| 765/765 [01:37<00:00,  7.85it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010255


100%|██████████| 732/732 [00:31<00:00, 23.14it/s]


Mean val loss: 0.010423


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 70...


100%|██████████| 765/765 [01:37<00:00,  7.85it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010263


100%|██████████| 732/732 [00:32<00:00, 22.48it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.010442
EarlyStopping counter: 1 out of 15
Epoch 71...


100%|██████████| 765/765 [01:36<00:00,  7.89it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010217


100%|██████████| 732/732 [00:31<00:00, 23.13it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.010542
EarlyStopping counter: 2 out of 15
Epoch 72...


100%|██████████| 765/765 [01:37<00:00,  7.86it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010209


100%|██████████| 732/732 [00:31<00:00, 23.11it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.010562
EarlyStopping counter: 3 out of 15
Epoch 73...


100%|██████████| 765/765 [01:37<00:00,  7.85it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010181


100%|██████████| 732/732 [00:31<00:00, 23.17it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.010533
EarlyStopping counter: 4 out of 15
Epoch 74...


100%|██████████| 765/765 [01:36<00:00,  7.89it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010188


100%|██████████| 732/732 [00:31<00:00, 23.14it/s]


Mean val loss: 0.010389


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 75...


100%|██████████| 765/765 [01:37<00:00,  7.88it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010139


100%|██████████| 732/732 [00:31<00:00, 23.04it/s]


Mean val loss: 0.010378


  0%|          | 0/765 [00:00<?, ?it/s]

Epoch 76...


100%|██████████| 765/765 [01:36<00:00,  7.89it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010155


100%|██████████| 732/732 [00:31<00:00, 23.26it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.011104
EarlyStopping counter: 1 out of 15
Epoch 77...


100%|██████████| 765/765 [01:37<00:00,  7.83it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010149


100%|██████████| 732/732 [00:32<00:00, 22.65it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.0104
EarlyStopping counter: 2 out of 15
Epoch 78...


100%|██████████| 765/765 [01:38<00:00,  7.77it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010101


100%|██████████| 732/732 [00:33<00:00, 22.03it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.010679
EarlyStopping counter: 3 out of 15
Epoch 79...


100%|██████████| 765/765 [01:37<00:00,  7.89it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.01011


100%|██████████| 732/732 [00:31<00:00, 23.16it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.010492
EarlyStopping counter: 4 out of 15
Epoch 80...


100%|██████████| 765/765 [01:37<00:00,  7.86it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010068


100%|██████████| 732/732 [00:31<00:00, 23.12it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.010398
EarlyStopping counter: 5 out of 15
Epoch 81...


100%|██████████| 765/765 [01:37<00:00,  7.85it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010085


100%|██████████| 732/732 [00:31<00:00, 23.24it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.010586
EarlyStopping counter: 6 out of 15
Epoch 82...


100%|██████████| 765/765 [01:37<00:00,  7.84it/s]
  0%|          | 0/732 [00:00<?, ?it/s]

Mean train loss: 0.010075


100%|██████████| 732/732 [00:32<00:00, 22.20it/s]
  0%|          | 0/765 [00:00<?, ?it/s]

Mean val loss: 0.010538
EarlyStopping counter: 7 out of 15
Epoch 83...


 74%|███████▍  | 567/765 [01:12<00:24,  8.03it/s]

## Dump experiment results

In [None]:
import pickle
import json

In [None]:
p = './{}_exp_results.pkl'.format(data_conf['conf_name'])
with open(p, 'wb') as f:
    pickle.dump(results, f)

In [28]:
# p = './{}_exp_results.pkl'.format(data_conf['conf_name'])
# with open(p, 'rb') as f:
#     results = pickle.load(f)

In [11]:
import torch
model.load_state_dict(torch.load('./{}.pt'.format(data_conf['conf_name'])))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

## Train and validation losses

In [None]:
from itertools import chain

In [None]:
train_losses = list(chain(*[item for item in results['train_losses']]))
val_losses = list(chain(*[item for item in results['val_losses']]))

In [None]:
def moving_average(a, n=5) :
    ret = np.cumsum([a[0]] * (n - 1) + a, dtype=float)
    ret[n:] = ret[n:] - ret[:-n]
    return ret[n - 1:] / n

In [None]:
plt.figure(figsize=(10, 10))
plt.title('Moving-averaged batch losses')
plt.plot(np.arange(len(train_losses)), moving_average(train_losses), label='train')
plt.plot(np.arange(len(val_losses)), moving_average(val_losses), label='validation')

plt.legend(loc='best')
plt.yscale('log')

# plt.ylim([1e-2, 1])
plt.show()

In [None]:
mean_train_loss = [np.mean(item) for item in results['train_losses']]
mean_val_loss = [np.mean(item) for item in results['val_losses']]

In [None]:
plt.figure(figsize=(10, 10))
plt.title('Epoch losses')
plt.plot(np.arange(len(mean_train_loss)) + 1, mean_train_loss, label='train')
plt.plot(np.arange(len(mean_val_loss)) + 1, mean_val_loss, label='val')

plt.yscale('log')
plt.legend(loc='best')

plt.xlim([1, len(mean_train_loss) + 1])
plt.show()

## Results

In [None]:
import pandas as pd

In [None]:
from visualization_utils import make_df

In [None]:
df = make_df(results, model_name='basic')
df

In [None]:
print('Mean   IOU: {:.5}'.format(df['iou'].mean()))
print('Std    IOU: {:.5}'.format(df['iou'].std()))
print('Min    IOU: {:.5}'.format(df['iou'].min()))
print('Median IOU: {:.5}'.format(df['iou'].median()))

## Check loss to loss ratio

In [40]:
criterion.losses

(WeightedLoss(
   (loss): CrossEntropyLoss()
 ), WeightedLoss(
   (loss): DiceLoss()
 ))

In [3]:
from stack_segmentation.training import make_joint_loss

In [46]:
import torch

In [70]:
# loss_config[0] = {'loss': 'BCE',
#   'weight': 0.5,
#   'params': {}}
# loss_config
# crit = make_joint_loss(device=device, loss_config=loss_config)

[{'loss': 'BCE', 'weight': 0.5, 'params': {}},
 {'loss': 'Dice',
  'weight': 0.5,
  'params': {'mode': 'multiclass',
   'log_loss': True,
   'from_logits': True,
   'smooth': 1,
   'eps': 1e-07}}]

In [67]:
a = []
b = []
for i, (x, y) in enumerate(dataloader_val):
    if i > 100:
        break
    pred = model(torch.from_numpy(x).to(device))
    y = torch.from_numpy(y).to(device)
    a.append(crit.losses[0](pred, y).cpu().data.numpy())
    b.append(crit.losses[1](pred, y).cpu().data.numpy())

In [71]:
print('Mean 0 loss: {:.4f}\nMEan 1 loss: {:.4f}\nMean of ratios: {:.4f}\nRatio of means: {:.4f}'
      .format(np.mean(a), np.mean(b), np.mean(np.array(a) / np.array(b)), np.mean(a) / np.mean(b)))

Mean 0 loss: 0.0050
MEan 1 loss: 0.0088
Mean of ratios: 0.6099
Ratio of means: 0.5694
