In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

from siamese.data import SiameseDataSource
from siamese.model import (
    resnet_baseline, ModelRunner, LossCallback)

In [3]:
TRAIN_FOLDER = "/home/ivb/nvme/data/raw/train"
TRAIN_CSV = "/home/ivb/nvme/data/train.csv"

INFER_FOLDER = "/home/ivb/nvme/data/raw/test"

In [4]:
loaders = SiameseDataSource.prepare_loaders(
    mode="train",
    n_workers=8,
    batch_size=512,
    train_folder=TRAIN_FOLDER,
    train_csv=TRAIN_CSV,
    train_folds=[1, 2, 3, 4], 
    valid_folds=[5],
    #infer_folder=INFER_FOLDER,
)

Train samples: 11264
Train batches: 22
Valid samples: 2560
Valid batches: 5


In [5]:
loaders.keys()

odict_keys(['train', 'valid'])

In [6]:
model = resnet_baseline(resnet=dict(arch="resnet18", pooling="GlobalAvgPool2d"))
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=2, verbose=True)

In [7]:
import collections
from catalyst.dl.callbacks import (
    ClassificationLossCallback, 
    Logger, TensorboardLogger,
    OptimizerCallback, SchedulerCallback, CheckpointCallback, 
    PrecisionCallback, OneCycleLR)

# the only tricky part
n_epochs = 50
logdir = "./logs"

callbacks = collections.OrderedDict()

callbacks["loss"] = LossCallback()
callbacks["optimizer"] = OptimizerCallback()

# OneCylce custom scheduler callback
# callbacks["scheduler"] = OneCycleLR(
#     cycle_len=n_epochs,
#     div=3, cut_div=4, momentum_range=(0.95, 0.85))

# Pytorch scheduler callback
callbacks["saver"] = CheckpointCallback()
callbacks["logger"] = Logger()
callbacks["tflogger"] = TensorboardLogger()

In [8]:
runner = ModelRunner(
    model=model, 
    criterion=criterion, 
    optimizer=optimizer, 
    scheduler=scheduler)

In [9]:
runner.train(
    loaders=loaders, 
    callbacks=callbacks, 
    logdir=logdir,
    epochs=n_epochs, verbose=True)

0 * Epoch (train): 100% 22/22 [00:23<00:00,  1.00s/it, base/batch_time=0.90072, base/data_time=0.04743, base/sample_per_second=568.43392, loss=0.44308, lr=0.00100, momentum=0.90000]
0 * Epoch (valid): 100% 5/5 [00:06<00:00,  1.60s/it, base/batch_time=0.90288, base/data_time=0.05757, base/sample_per_second=567.07233, loss=0.46691, lr=0.00100, momentum=0.90000]
[2019-01-21 00:47:42,481] 0 * Epoch (train) metrics: base/data_time: 0.39314 | base/batch_time: 0.51225 | base/sample_per_second: 5844.61411 | lr: 0.00100 | momentum: 0.90000 | loss: 0.57532
[2019-01-21 00:47:42,483] 0 * Epoch (valid) metrics: base/data_time: 1.03359 | base/batch_time: 1.20845 | base/sample_per_second: 4493.06403 | lr: 0.00100 | momentum: 0.90000 | loss: 0.47719
[2019-01-21 00:47:42,483] 

1 * Epoch (train): 100% 22/22 [00:21<00:00,  1.92it/s, base/batch_time=0.05345, base/data_time=0.04519, base/sample_per_second=9578.59930, loss=0.36165, lr=0.00100, momentum=0.90000]
1 * Epoch (valid): 100% 5/5 [00:05<00:00,  1.

Top best models:
./logs/checkpoint.None.44.pth.tar	0.1058
./logs/checkpoint.None.48.pth.tar	0.1208
./logs/checkpoint.None.45.pth.tar	0.1209
./logs/checkpoint.None.46.pth.tar	0.1211
./logs/checkpoint.None.43.pth.tar	0.1294


In [10]:
runner.train(
    loaders=loaders, 
    callbacks=callbacks, 
    logdir=logdir,
    epochs=n_epochs, start_epoch=50, verbose=True)

50 * Epoch (train): 100% 22/22 [00:21<00:00,  2.00it/s, base/batch_time=0.04987, base/data_time=0.04316, base/sample_per_second=10267.62314, loss=0.09246, lr=0.00100, momentum=0.90000]
50 * Epoch (valid): 100% 5/5 [00:05<00:00,  1.31s/it, base/batch_time=0.06541, base/data_time=0.05820, base/sample_per_second=7827.39024, loss=0.10896, lr=0.00100, momentum=0.90000]
[2019-01-21 01:19:28,610] 50 * Epoch (train) metrics: base/data_time: 0.49386 | base/batch_time: 0.50620 | base/sample_per_second: 5226.20491 | lr: 0.00100 | momentum: 0.90000 | loss: 0.11182
[2019-01-21 01:19:28,610] 50 * Epoch (train) metrics: base/data_time: 0.49386 | base/batch_time: 0.50620 | base/sample_per_second: 5226.20491 | lr: 0.00100 | momentum: 0.90000 | loss: 0.11182
[2019-01-21 01:19:28,612] 50 * Epoch (valid) metrics: base/data_time: 0.98657 | base/batch_time: 0.99419 | base/sample_per_second: 5945.64743 | lr: 0.00100 | momentum: 0.90000 | loss: 0.08765
[2019-01-21 01:19:28,612] 50 * Epoch (valid) metrics: bas

Top best models:
./logs/checkpoint.None.98.pth.tar	0.0675
./logs/checkpoint.None.88.pth.tar	0.0689
./logs/checkpoint.None.95.pth.tar	0.0714
./logs/checkpoint.None.91.pth.tar	0.0758
./logs/checkpoint.None.75.pth.tar	0.0760


In [None]:
runner.train(
    loaders=loaders, 
    callbacks=callbacks, 
    logdir=logdir,
    epochs=200, start_epoch=100, verbose=True)

100 * Epoch (train): 100% 22/22 [00:21<00:00,  1.91it/s, base/batch_time=0.05327, base/data_time=0.04593, base/sample_per_second=9611.99751, loss=0.08934, lr=0.00100, momentum=0.90000]
100 * Epoch (valid): 100% 5/5 [00:05<00:00,  1.31s/it, base/batch_time=0.06276, base/data_time=0.05770, base/sample_per_second=8157.95458, loss=0.05788, lr=0.00100, momentum=0.90000]
[2019-01-21 01:47:41,494] 100 * Epoch (train) metrics: base/data_time: 0.50590 | base/batch_time: 0.51914 | base/sample_per_second: 5773.93849 | lr: 0.00100 | momentum: 0.90000 | loss: 0.07215
[2019-01-21 01:47:41,494] 100 * Epoch (train) metrics: base/data_time: 0.50590 | base/batch_time: 0.51914 | base/sample_per_second: 5773.93849 | lr: 0.00100 | momentum: 0.90000 | loss: 0.07215
[2019-01-21 01:47:41,494] 100 * Epoch (train) metrics: base/data_time: 0.50590 | base/batch_time: 0.51914 | base/sample_per_second: 5773.93849 | lr: 0.00100 | momentum: 0.90000 | loss: 0.07215
[2019-01-21 01:47:41,497] 100 * Epoch (valid) metrics