In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('../..')

In [3]:
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

In [4]:
from stack_segmentation.stack import Stack

In [5]:
from stack_segmentation.io import make_dataloader, collate_fn_basic

from stack_segmentation.training import (
    handle_stacks_data, 
    make_optimization_task, 
    train_loop
)

from stack_segmentation.pipeline_config import (
    dataloaders_conf,
    train_conf,
    model_config, 
    aug_config,
    optimizer_config,
    loss_config,
    scheduler_config,
)

In [6]:
from exp_config import data_conf

## Parameters to tune

In [7]:
# Training run settings: second GPU ('cuda:1'; use 'cpu' to run without a GPU), 500 epochs.
train_conf.update(device='cuda:1', num_epochs=500)
train_conf

{'num_epochs': 500, 'device': 'cuda:1'}

In [8]:
# Experiment name and patch shapes. The name is also the file stem of the
# results pickle and the checkpoint loaded further down.
# NOTE(review): the name says "epoch_300" but num_epochs is set to 500 above.
# Renaming would change the artifact file names, so it is left unchanged —
# confirm which value is intended.
_patch_shape = (64, 64, 1)
data_conf.update(
    conf_name='exp_basic_adamw_lr5e-3_epoch_300_se_resnet101_encoder_soft_aug_k_1_weight10_patch64',
    patches={split: _patch_shape for split in ('train', 'val', 'test')},
)
data_conf

{'conf_name': 'exp_basic_adamw_lr5e-3_epoch_300_se_resnet101_encoder_soft_aug_k_1_weight10_patch64',
 'stacks': [{'path': '../../data/carb96558',
   'slice_train': (slice(None, None, None),
    slice(None, None, None),
    slice(None, 230, None)),
   'slice_val': (slice(None, None, None),
    slice(None, None, None),
    slice(250, 470, None))},
  {'path': '../../data/SoilB-2',
   'slice_train': (slice(None, None, None),
    slice(None, None, None),
    slice(None, 230, None)),
   'slice_val': (slice(None, None, None),
    slice(None, None, None),
    slice(240, 460, None))},
  {'path': '../../data/Urna_22',
   'slice_train': (slice(None, None, None),
    slice(None, None, None),
    slice(None, 220, None)),
   'slice_val': (slice(None, None, None),
    slice(None, None, None),
    slice(245, 455, None))},
  {'path': '../../data/carb96558',
   'slice_test': (slice(None, None, None),
    slice(None, None, None),
    slice(490, None, None))},
  {'path': '../../data/carb71',
   'slice_tes

In [9]:
# Use an SE-ResNet-101 encoder; every other model setting keeps its default.
model_config.update(encoder_name='se_resnet101')
model_config

{'source': 'qubvel',
 'model_type': 'Unet',
 'encoder_name': 'se_resnet101',
 'encoder_weights': 'imagenet'}

In [10]:
# AdamW (with the AMSGrad variant enabled) at lr=5e-3 and weight decay 5e-5.
# NOTE(review): the remaining default keys (nesterov, momentum, centered) are
# presumably ignored for AdamW — confirm in make_optimization_task.
optimizer_config.update(
    opt_type='AdamW',
    lr=5e-3,
    weight_decay=5e-5,
    amsgrad=True,
)
optimizer_config

{'opt_type': 'AdamW',
 'lr': 0.005,
 'weight_decay': 5e-05,
 'amsgrad': True,
 'nesterov': False,
 'momentum': 0.9,
 'centered': False}

In [11]:
# Augmentation settings. original_height/original_width are 64, the same as the
# 64x64 patch size set in data_conf above — keep the two in sync.
aug_config.update(
    aug_type='soft',
    k=1,
    original_height=64,
    original_width=64,
)
aug_config

{'aug_type': 'soft', 'original_height': 64, 'original_width': 64, 'k': 1}

In [12]:
# Joint loss: BCE and Dice terms, each with weight 1.
# BCE class weights [1, 10] presumably up-weight class 1 (foreground) — TODO confirm.
bce_term = {
    'loss': 'BCE',
    'weight': 1,
    'params': {'weight': [1, 10]},
}
dice_term = {
    'loss': 'Dice',
    'weight': 1,
    'params': dict(
        mode='multiclass',
        classes=[1],  # this parameter may not be needed
        log_loss=True,
        from_logits=True,
        smooth=1,
        eps=1e-7,
    ),
}
loss_config = [bce_term, dice_term]
loss_config

[{'loss': 'BCE', 'weight': 1, 'params': {'weight': [1, 10]}},
 {'loss': 'Dice',
  'weight': 1,
  'params': {'mode': 'multiclass',
   'classes': [1],
   'log_loss': True,
   'from_logits': True,
   'smooth': 1,
   'eps': 1e-07}}]

In [13]:
# Use the same batch size for every split.
# NOTE(review): the default test config has shuffle=True — harmless only if the
# test metrics are order-independent; confirm.
for split_name in ('train', 'val', 'test'):
    dataloaders_conf[split_name]['batch_size'] = 96
dataloaders_conf

{'train': {'batch_size': 96, 'num_workers': 16, 'shuffle': True},
 'val': {'batch_size': 96, 'num_workers': 16, 'shuffle': False},
 'test': {'batch_size': 96, 'num_workers': 16, 'shuffle': True}}

## Prepare train, validation and test data

In [14]:
# Slice the configured stacks into train and validation samples plus a mapping of test sets.
stack_splits = handle_stacks_data(**data_conf)
data_train, data_val, data_test = stack_splits

720it [00:03, 213.63it/s]
100%|██████████| 720/720 [00:07<00:00, 95.14it/s] 
33120it [00:00, 242060.03it/s]
31680it [00:00, 159577.49it/s]
700it [00:02, 235.04it/s]
100%|██████████| 700/700 [00:06<00:00, 100.96it/s]
27830it [00:00, 147192.87it/s]
26620it [00:00, 252299.51it/s]
710it [00:03, 227.39it/s]
100%|██████████| 710/710 [00:06<00:00, 104.77it/s]
31680it [00:00, 141873.45it/s]
30240it [00:00, 129331.33it/s]
720it [00:00, 824.24it/s]
100%|██████████| 720/720 [00:07<00:00, 96.04it/s] 
33120it [00:00, 250892.38it/s]
720it [00:01, 481.73it/s]
100%|██████████| 720/720 [00:07<00:00, 96.36it/s] 
103680it [00:00, 180188.56it/s]
700it [00:01, 552.44it/s]
100%|██████████| 700/700 [00:06<00:00, 101.16it/s]
84700it [00:00, 229942.39it/s]
509it [00:00, 779.09it/s] 
100%|██████████| 509/509 [00:02<00:00, 197.75it/s]
32576it [00:00, 93369.68it/s]
700it [00:01, 355.83it/s]
100%|██████████| 700/700 [00:07<00:00, 99.28it/s] 
84700it [00:00, 232001.29it/s]
700it [00:00, 806.43it/s]
100%|██████████|

In [15]:
# Sizes of the train and validation sample sets and the number of test sets.
tuple(len(split) for split in (data_train, data_val, data_test))

(92630, 88540, 11)

In [16]:
def _build_loader(samples, split, **extra):
    """Wrap `samples` in a dataloader configured by dataloaders_conf[split]."""
    return make_dataloader(
        samples=samples,
        collate_fn=collate_fn_basic,
        model_config=model_config,
        **extra,
        **dataloaders_conf[split]
    )


# Only the training loader is given aug_config.
dataloader_train = _build_loader(data_train, 'train', aug_config=aug_config)
dataloader_val = _build_loader(data_val, 'val')
dataloaders_test = {name: _build_loader(data, 'test') for name, data in data_test.items()}

## Create model, loss, optimizer and scheduler

In [17]:
# Reuse the device chosen in train_conf so the model/loss built below and the
# training loop (which receives **train_conf) cannot end up on different devices.
# To run on CPU, change train_conf['device'] instead.
device = train_conf['device']

In [18]:
# Build the model, the joint criterion, the optimizer and the scheduler on `device`.
task_configs = dict(
    model_config=model_config,
    loss_config=loss_config,
    optimizer_config=optimizer_config,
    scheduler_config=scheduler_config,
)
model, criterion, optimizer, scheduler = make_optimization_task(device, **task_configs)

Downloading: "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth" to /home/evlavrukhin/.cache/torch/checkpoints/se_resnet101-7e38fcc6.pth
100%|██████████| 189M/189M [09:16<00:00, 356kB/s]    


## Run experiment

In [19]:
from stack_segmentation.metrics import accuracy, precision, recall, f1, pr_auc, iou

In [20]:
# Evaluation metrics passed to train_loop, keyed by display name.
metrics = dict(
    accuracy=accuracy,
    precision=precision,
    recall=recall,
    f1=f1,
    pr_auc=pr_auc,
    iou=iou,
)

In [None]:
# Long-running: roughly 4.5 minutes per epoch in the recorded log. The per-batch
# progress bars exceeded the notebook's IOPub rate limit in that run.
results = train_loop(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    dataloader_train=dataloader_train,
    dataloader_val=dataloader_val,
    dataloaders_test=dataloaders_test,
    metrics=metrics,
    exp_name=data_conf['conf_name'],
    **train_conf
)

  0%|          | 0/965 [00:00<?, ?it/s]

Epoch 0...


100%|██████████| 965/965 [03:29<00:00,  4.61it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.041601


100%|██████████| 923/923 [00:53<00:00, 17.35it/s]
  0%|          | 0/965 [00:00<?, ?it/s]

Mean val loss: 0.02439
Epoch 1...


100%|██████████| 965/965 [03:30<00:00,  4.59it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.025287


100%|██████████| 923/923 [00:53<00:00, 17.21it/s]
  0%|          | 0/965 [00:00<?, ?it/s]

Mean val loss: 0.024477
EarlyStopping counter: 1 out of 15
Epoch 2...


100%|██████████| 965/965 [03:30<00:00,  4.58it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.023257


100%|██████████| 923/923 [00:53<00:00, 17.27it/s]


Mean val loss: 0.021591


  0%|          | 0/965 [00:00<?, ?it/s]

Epoch 3...


100%|██████████| 965/965 [03:31<00:00,  4.56it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.02207


100%|██████████| 923/923 [00:53<00:00, 17.24it/s]
  0%|          | 0/965 [00:00<?, ?it/s]

Mean val loss: 0.021636
EarlyStopping counter: 1 out of 15
Epoch 4...


100%|██████████| 965/965 [03:30<00:00,  4.58it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.021462


100%|██████████| 923/923 [00:53<00:00, 17.27it/s]
  0%|          | 0/965 [00:00<?, ?it/s]

Mean val loss: 0.023987
EarlyStopping counter: 2 out of 15
Epoch 5...


100%|██████████| 965/965 [03:30<00:00,  4.58it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.020821


100%|██████████| 923/923 [00:53<00:00, 17.25it/s]


Mean val loss: 0.021149


  0%|          | 0/965 [00:00<?, ?it/s]

Epoch 6...


100%|██████████| 965/965 [03:31<00:00,  4.57it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.020634


100%|██████████| 923/923 [00:53<00:00, 17.23it/s]


Mean val loss: 0.020306


  0%|          | 0/965 [00:00<?, ?it/s]

Epoch 7...


100%|██████████| 965/965 [03:31<00:00,  4.57it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.020119


100%|██████████| 923/923 [00:53<00:00, 17.30it/s]


Mean val loss: 0.019889


  0%|          | 0/965 [00:00<?, ?it/s]

Epoch 8...


100%|██████████| 965/965 [03:31<00:00,  4.57it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.020029


100%|██████████| 923/923 [00:53<00:00, 17.21it/s]


Mean val loss: 0.019853


  0%|          | 0/965 [00:00<?, ?it/s]

Epoch 9...


100%|██████████| 965/965 [03:30<00:00,  4.58it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.019839


100%|██████████| 923/923 [00:53<00:00, 17.23it/s]


Mean val loss: 0.01986


  0%|          | 0/965 [00:00<?, ?it/s]

Epoch 10...


100%|██████████| 965/965 [03:31<00:00,  4.57it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.01967


100%|██████████| 923/923 [00:53<00:00, 17.17it/s]
  0%|          | 0/965 [00:00<?, ?it/s]

Mean val loss: 0.02022
EarlyStopping counter: 1 out of 15
Epoch 11...


100%|██████████| 965/965 [03:31<00:00,  4.56it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.019571


100%|██████████| 923/923 [00:53<00:00, 17.22it/s]
  0%|          | 0/965 [00:00<?, ?it/s]

Mean val loss: 0.022035
EarlyStopping counter: 2 out of 15
Epoch 12...


100%|██████████| 965/965 [03:30<00:00,  4.57it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.019307


100%|██████████| 923/923 [00:53<00:00, 17.19it/s]


Mean val loss: 0.019502


  0%|          | 0/965 [00:00<?, ?it/s]

Epoch 13...


100%|██████████| 965/965 [03:31<00:00,  4.56it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.019206


100%|██████████| 923/923 [00:53<00:00, 17.24it/s]


Mean val loss: 0.019188


  0%|          | 0/965 [00:00<?, ?it/s]

Epoch 14...


100%|██████████| 965/965 [03:31<00:00,  4.56it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.019149


100%|██████████| 923/923 [00:53<00:00, 17.27it/s]
  0%|          | 0/965 [00:00<?, ?it/s]

Mean val loss: 0.02025
EarlyStopping counter: 1 out of 15
Epoch 15...


100%|██████████| 965/965 [03:31<00:00,  4.56it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.01897


100%|██████████| 923/923 [00:54<00:00, 17.08it/s]


Mean val loss: 0.019018


  0%|          | 0/965 [00:00<?, ?it/s]

Epoch 16...


100%|██████████| 965/965 [03:31<00:00,  4.57it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.018936


100%|██████████| 923/923 [00:53<00:00, 17.26it/s]
  0%|          | 0/965 [00:00<?, ?it/s]

Mean val loss: 0.02154
EarlyStopping counter: 1 out of 15
Epoch 17...


100%|██████████| 965/965 [03:31<00:00,  4.56it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.01888


 25%|██▍       | 227/923 [00:13<00:39, 17.57it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 923/923 [00:53<00:00, 17.22it/s]
  0%|          | 0/965 [00:00<?, ?it/s]

Mean val loss: 0.019651
EarlyStopping counter: 2 out of 15
Epoch 18...


100%|██████████| 965/965 [03:31<00:00,  4.56it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.018818


100%|██████████| 923/923 [00:53<00:00, 17.19it/s]
  0%|          | 0/965 [00:00<?, ?it/s]

Mean val loss: 0.019424
EarlyStopping counter: 3 out of 15
Epoch 19...


100%|██████████| 965/965 [03:31<00:00,  4.57it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.01872


100%|██████████| 923/923 [00:53<00:00, 17.23it/s]


Mean val loss: 0.019027


  0%|          | 0/965 [00:00<?, ?it/s]

Epoch 20...


100%|██████████| 965/965 [03:31<00:00,  4.56it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.018615


100%|██████████| 965/965 [03:31<00:00,  4.56it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.018517


100%|██████████| 923/923 [00:53<00:00, 17.23it/s]
  0%|          | 0/965 [00:00<?, ?it/s]

Mean val loss: 0.018896
EarlyStopping counter: 1 out of 15
Epoch 22...


100%|██████████| 965/965 [03:32<00:00,  4.55it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.01838


100%|██████████| 923/923 [00:53<00:00, 17.18it/s]
  0%|          | 0/965 [00:00<?, ?it/s]

Mean val loss: 0.019895
EarlyStopping counter: 2 out of 15
Epoch 23...


 52%|█████▏    | 498/965 [01:50<01:45,  4.42it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 923/923 [00:53<00:00, 17.17it/s]


Mean val loss: 0.018336


  0%|          | 0/965 [00:00<?, ?it/s]

Epoch 29...


100%|██████████| 965/965 [03:32<00:00,  4.54it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.018077


100%|██████████| 923/923 [00:53<00:00, 17.18it/s]
  0%|          | 0/965 [00:00<?, ?it/s]

Mean val loss: 0.020398
EarlyStopping counter: 1 out of 15
Epoch 30...


100%|██████████| 965/965 [03:32<00:00,  4.54it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.018017


100%|██████████| 923/923 [00:53<00:00, 17.24it/s]
  0%|          | 0/965 [00:00<?, ?it/s]

Mean val loss: 0.018495
EarlyStopping counter: 2 out of 15
Epoch 31...


  6%|▌         | 55/965 [00:13<03:17,  4.61it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 965/965 [03:32<00:00,  4.54it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.017687


100%|██████████| 923/923 [00:54<00:00, 17.05it/s]
  0%|          | 0/965 [00:00<?, ?it/s]

Mean val loss: 0.019154
EarlyStopping counter: 3 out of 15
Epoch 37...


100%|██████████| 965/965 [03:32<00:00,  4.55it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.017754


100%|██████████| 923/923 [00:53<00:00, 17.17it/s]


Mean val loss: 0.018134


  0%|          | 0/965 [00:00<?, ?it/s]

Epoch 38...


 68%|██████▊   | 655/965 [02:24<01:07,  4.58it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 923/923 [00:53<00:00, 17.12it/s]


Mean val loss: 0.017879


  0%|          | 0/965 [00:00<?, ?it/s]

Epoch 44...


100%|██████████| 965/965 [03:32<00:00,  4.54it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.017549


100%|██████████| 923/923 [00:54<00:00, 16.97it/s]


Mean val loss: 0.017872


  0%|          | 0/965 [00:00<?, ?it/s]

Epoch 45...


100%|██████████| 965/965 [03:31<00:00,  4.56it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.017527


 71%|███████▏  | 659/923 [00:38<00:14, 17.69it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 965/965 [03:32<00:00,  4.55it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.01733


100%|██████████| 923/923 [00:53<00:00, 17.12it/s]
  0%|          | 0/965 [00:00<?, ?it/s]

Mean val loss: 0.018142
EarlyStopping counter: 3 out of 15
Epoch 52...


100%|██████████| 965/965 [03:32<00:00,  4.54it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.01733


100%|██████████| 923/923 [00:53<00:00, 17.19it/s]
  0%|          | 0/965 [00:00<?, ?it/s]

Mean val loss: 0.01784
EarlyStopping counter: 4 out of 15
Epoch 53...


 40%|████      | 388/965 [01:25<02:07,  4.51it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 923/923 [00:54<00:00, 16.91it/s]
  0%|          | 0/965 [00:00<?, ?it/s]

Mean val loss: 0.017851
EarlyStopping counter: 5 out of 15
Epoch 59...


100%|██████████| 965/965 [03:31<00:00,  4.55it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.017184


100%|██████████| 923/923 [00:53<00:00, 17.18it/s]
  0%|          | 0/965 [00:00<?, ?it/s]

Mean val loss: 0.018484
EarlyStopping counter: 6 out of 15
Epoch 60...


100%|██████████| 965/965 [03:31<00:00,  4.56it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.017167


100%|██████████| 923/923 [00:53<00:00, 17.10it/s]
  0%|          | 0/965 [00:00<?, ?it/s]

Mean val loss: 0.019737
EarlyStopping counter: 7 out of 15
Epoch 61...


  1%|▏         | 13/965 [00:04<03:40,  4.32it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 965/965 [03:32<00:00,  4.55it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.017084


100%|██████████| 923/923 [00:53<00:00, 17.23it/s]
  0%|          | 0/965 [00:00<?, ?it/s]

Mean val loss: 0.017904
EarlyStopping counter: 2 out of 15
Epoch 67...


100%|██████████| 965/965 [03:32<00:00,  4.54it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.016993


100%|██████████| 923/923 [00:53<00:00, 17.18it/s]


Mean val loss: 0.017537


  0%|          | 0/965 [00:00<?, ?it/s]

Epoch 68...


 43%|████▎     | 413/965 [01:31<01:59,  4.63it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 923/923 [00:53<00:00, 17.19it/s]
  0%|          | 0/965 [00:00<?, ?it/s]

Mean val loss: 0.017629
EarlyStopping counter: 3 out of 15
Epoch 74...


100%|██████████| 965/965 [03:32<00:00,  4.54it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.016946


100%|██████████| 923/923 [00:53<00:00, 17.12it/s]
  0%|          | 0/965 [00:00<?, ?it/s]

Mean val loss: 0.01771
EarlyStopping counter: 4 out of 15
Epoch 75...


100%|██████████| 965/965 [03:32<00:00,  4.54it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.016949


 86%|████████▌ | 795/923 [00:46<00:07, 18.02it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 965/965 [03:32<00:00,  4.55it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.016884


100%|██████████| 923/923 [00:53<00:00, 17.11it/s]


Mean val loss: 0.017379


  0%|          | 0/965 [00:00<?, ?it/s]

Epoch 82...


100%|██████████| 965/965 [03:32<00:00,  4.55it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.016831


100%|██████████| 923/923 [00:53<00:00, 17.21it/s]
  0%|          | 0/965 [00:00<?, ?it/s]

Mean val loss: 0.01937
EarlyStopping counter: 1 out of 15
Epoch 83...


 96%|█████████▌| 925/965 [03:23<00:08,  4.61it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 923/923 [00:53<00:00, 17.12it/s]
  0%|          | 0/965 [00:00<?, ?it/s]

Mean val loss: 0.017759
EarlyStopping counter: 7 out of 15
Epoch 89...


100%|██████████| 965/965 [03:32<00:00,  4.55it/s]
  0%|          | 0/923 [00:00<?, ?it/s]

Mean train loss: 0.016672


 82%|████████▏ | 755/923 [00:44<00:09, 17.27it/s]

## Dump experiment results

In [None]:
import pickle
import json

In [None]:
# Persist the training results next to the notebook, named after the experiment.
results_path = './{}_exp_results.pkl'.format(data_conf['conf_name'])
with open(results_path, mode='wb') as results_file:
    pickle.dump(results, results_file)

In [28]:
# p = './{}_exp_results.pkl'.format(data_conf['conf_name'])
# with open(p, 'rb') as f:
#     results = pickle.load(f)

In [11]:
import torch

# Restore the weights stored under the experiment name (presumably written by
# train_loop — confirm). map_location remaps tensors saved on another device
# (e.g. 'cuda:1') onto `device`, so loading also works after switching to 'cpu'.
checkpoint_path = './{}.pt'.format(data_conf['conf_name'])
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

## Train and validation losses

In [None]:
from itertools import chain

In [None]:
# results[...] holds one list of batch losses per epoch; flatten each split into
# a single chronological sequence of batch losses.
train_losses = list(chain.from_iterable(results['train_losses']))
val_losses = list(chain.from_iterable(results['val_losses']))

In [None]:
def moving_average(a, n=5):
    """Trailing moving average of a 1-D sequence, same length as the input.

    The series is left-padded with ``n - 1`` copies of its first value, so output
    element ``i`` is the mean of the ``n`` padded values ending at position ``i``.
    The first outputs therefore average over the repeated first value.

    Parameters
    ----------
    a : sequence of numbers or 1-D array-like
        Values to smooth (e.g. per-batch losses). Lists and numpy arrays both work.
    n : int, default 5
        Window length; must be >= 1.

    Returns
    -------
    numpy.ndarray of float with shape ``(len(a),)``; empty for empty input.

    Raises
    ------
    ValueError
        If ``n < 1`` or ``a`` is not one-dimensional.
    """
    if n < 1:
        raise ValueError('window length n must be >= 1, got {}'.format(n))
    values = np.asarray(a, dtype=float)
    if values.ndim != 1:
        raise ValueError('expected a 1-D sequence, got shape {}'.format(values.shape))
    if values.size == 0:
        # Nothing to smooth; avoids indexing values[0] on empty input.
        return np.empty(0)
    # np.concatenate instead of list `+`, which would broadcast-add on ndarrays.
    padded = np.concatenate([np.full(n - 1, values[0]), values])
    csum = np.cumsum(padded)
    # Window sums via the cumulative-sum difference trick.
    csum[n:] = csum[n:] - csum[:-n]
    return csum[n - 1:] / n

In [None]:
# Smoothed batch losses on a shared epoch axis. Train and validation have
# different numbers of batches per epoch (965 vs 923 in the log), so raw batch
# indices misalign the curves. Dividing by each split's batches-per-epoch fixes
# this. Assumes the first epoch is complete.
train_batches_per_epoch = len(results['train_losses'][0])
val_batches_per_epoch = len(results['val_losses'][0])

fig, ax = plt.subplots(figsize=(10, 10))
ax.set_title('Moving-averaged batch losses')
ax.plot(np.arange(len(train_losses)) / train_batches_per_epoch,
        moving_average(train_losses), label='train')
ax.plot(np.arange(len(val_losses)) / val_batches_per_epoch,
        moving_average(val_losses), label='validation')
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.set_yscale('log')
ax.legend(loc='best')
plt.show()

In [None]:
# Per-epoch mean of the batch losses for each split.
mean_train_loss = list(map(np.mean, results['train_losses']))
mean_val_loss = list(map(np.mean, results['val_losses']))

In [None]:
# Mean loss per epoch, with epochs numbered from 1.
n_epochs = len(mean_train_loss)

fig, ax = plt.subplots(figsize=(10, 10))
ax.set_title('Epoch losses')
ax.plot(np.arange(1, n_epochs + 1), mean_train_loss, label='train')
ax.plot(np.arange(1, len(mean_val_loss) + 1), mean_val_loss, label='val')
ax.set_xlabel('epoch')
ax.set_ylabel('mean loss')
ax.set_yscale('log')
ax.legend(loc='best')
# The last plotted epoch is n_epochs; an upper limit of n_epochs + 1 would
# leave an empty one-epoch margin on the right.
ax.set_xlim(1, n_epochs)
plt.show()

## Results

In [None]:
import pandas as pd

In [None]:
from visualization_utils import make_df

In [None]:
# Tabulate the experiment results with the local `make_df` helper. Presumably
# one row per test stack — confirm in visualization_utils. The 'iou' column is
# summarized in the next cell.
df = make_df(results, model_name='basic')
df

In [None]:
# Summary statistics of the IoU column, printed with 5 significant digits.
iou_values = df['iou']
for template, statistic in (
    ('Mean   IOU: {:.5}', iou_values.mean),
    ('Std    IOU: {:.5}', iou_values.std),
    ('Min    IOU: {:.5}', iou_values.min),
    ('Median IOU: {:.5}', iou_values.median),
):
    print(template.format(statistic()))

## Compare the magnitudes of the two loss terms

In [None]:
# Show the individual loss terms of the joint criterion; later cells index them as losses[0] and losses[1].
criterion.losses

In [None]:
from stack_segmentation.training import make_joint_loss

In [None]:
import torch

In [None]:
# loss_config[0] = {'loss': 'BCE',
#   'weight': 0.5,
#   'params': {}}
# loss_config
# crit = make_joint_loss(device=device, loss_config=loss_config)

In [None]:
# Compare the magnitudes of the two joint-loss terms on the first 101 validation
# batches. The model is put in eval mode and autograd is disabled. In train mode
# these forward passes would update any BatchNorm running statistics with
# validation data and keep dropout active. Autograd would also build graphs that
# are never used. The previous mode is restored afterwards.
crit = criterion
n_batches = 101  # same count as the original `if i > 100: break`
a = []  # per-batch values of loss term 0
b = []  # per-batch values of loss term 1
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        for i, (x, y) in enumerate(dataloader_val):
            if i >= n_batches:
                break
            pred = model(torch.from_numpy(x).to(device))
            y = torch.from_numpy(y).to(device)
            a.append(crit.losses[0](pred, y).cpu().numpy())
            b.append(crit.losses[1](pred, y).cpu().numpy())
finally:
    model.train(was_training)

In [None]:
# Summarize the two collected loss terms: each mean, the mean of the per-batch
# ratios, and the ratio of the means.
mean_loss_0 = np.mean(a)
mean_loss_1 = np.mean(b)
mean_of_ratios = np.mean(np.array(a) / np.array(b))
print('Mean 0 loss: {:.4f}\nMean 1 loss: {:.4f}\nMean of ratios: {:.4f}\nRatio of means: {:.4f}'
      .format(mean_loss_0, mean_loss_1, mean_of_ratios, mean_loss_0 / mean_loss_1))