In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

from siamese.data import SiameseDataSource
from siamese.model import (
    resnet_baseline, ModelRunner, LossCallback)

In [3]:
TRAIN_FOLDER = "/home/ivb/nvme/data/raw/train"
TRAIN_CSV = "/home/ivb/nvme/data/train.csv"

INFER_FOLDER = "/home/ivb/nvme/data/raw/test"

In [4]:
loaders = SiameseDataSource.prepare_loaders(
    mode="train",
    n_workers=8,
    batch_size=512,
    train_folder=TRAIN_FOLDER,
    train_csv=TRAIN_CSV,
    train_folds=[1, 2, 3, 4], 
    valid_folds=[5],
    #infer_folder=INFER_FOLDER,
)

Train samples: 11264
Train batches: 22
Valid samples: 2560
Valid batches: 5


In [5]:
loaders.keys()

odict_keys(['train', 'valid'])

In [6]:
A, B = [], []
for b in loaders["train"]:
    A.extend(b["ImageFile0"])
    B.extend(b["ImageFile1"])

In [8]:
set(A).intersection(B)

{'7f3c05db4.jpg',
 '9c64d69c7.jpg',
 '43bb255c1.jpg',
 '82233a173.jpg',
 '7bb50baa6.jpg',
 'fab732129.jpg',
 'e9f68cf84.jpg',
 '9c8d437fc.jpg',
 '02156d140.jpg',
 'c943b92f0.jpg',
 '7eda4937f.jpg',
 'fe107366e.jpg',
 '5c6b15b4a.jpg',
 'b2bd65f41.jpg',
 '63492795c.jpg',
 '00dcd026f.jpg',
 'ec8c8d156.jpg',
 'a6549f63d.jpg',
 'd2f485da1.jpg',
 'e546ff196.jpg',
 '95ddfca91.jpg',
 '74eb2c291.jpg',
 '997aa3f23.jpg',
 '837cfabb3.jpg',
 'e62cd8b58.jpg',
 '4c1eaaf70.jpg',
 'f3c755869.jpg',
 'abe9c2389.jpg',
 'a43099f3b.jpg',
 'c94aa3f69.jpg',
 '04a74df45.jpg',
 '30c6c2701.jpg',
 '9984d6765.jpg',
 'ba46b75b0.jpg',
 '6c2b3fe29.jpg',
 '2cd81bee2.jpg',
 '47467f73e.jpg',
 '32ca51864.jpg',
 '8e13d0b4a.jpg',
 '3cd9242f5.jpg',
 'b58806652.jpg',
 'da83595c8.jpg',
 '46cfdedfa.jpg',
 '61d1fe1af.jpg',
 'aa69cda83.jpg',
 '0f78db9bd.jpg',
 'f399a8736.jpg',
 'efae7b997.jpg',
 '57f179c66.jpg',
 '89d28407e.jpg',
 '456890b76.jpg',
 '837efbd83.jpg',
 '496ceba9a.jpg',
 'c86235d01.jpg',
 '98ab97a0c.jpg',
 '03ef2929

In [6]:
model = resnet_baseline(resnet=dict(arch="resnet18", pooling="GlobalAvgPool2d"))
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=2, verbose=True)

In [7]:
import collections
from catalyst.dl.callbacks import (
    ClassificationLossCallback, 
    Logger, TensorboardLogger,
    OptimizerCallback, SchedulerCallback, CheckpointCallback, 
    PrecisionCallback, OneCycleLR)

# the only tricky part
n_epochs = 50
logdir = "./logs"

callbacks = collections.OrderedDict()

callbacks["loss"] = LossCallback()
callbacks["optimizer"] = OptimizerCallback()

# OneCylce custom scheduler callback
# callbacks["scheduler"] = OneCycleLR(
#     cycle_len=n_epochs,
#     div=3, cut_div=4, momentum_range=(0.95, 0.85))

# Pytorch scheduler callback
callbacks["saver"] = CheckpointCallback()
callbacks["logger"] = Logger()
callbacks["tflogger"] = TensorboardLogger()

In [8]:
runner = ModelRunner(
    model=model, 
    criterion=criterion, 
    optimizer=optimizer, 
    scheduler=scheduler)

In [9]:
runner.train(
    loaders=loaders, 
    callbacks=callbacks, 
    logdir=logdir,
    epochs=n_epochs, verbose=True)

0 * Epoch (train): 100% 22/22 [00:23<00:00,  1.00it/s, base/batch_time=0.92321, base/data_time=0.04499, base/sample_per_second=554.58763, loss=0.49286, lr=0.00100, momentum=0.90000]
0 * Epoch (valid): 100% 5/5 [00:06<00:00,  1.49s/it, base/batch_time=0.87097, base/data_time=0.05896, base/sample_per_second=587.85177, loss=0.49542, lr=0.00100, momentum=0.90000]
[2019-01-20 23:17:08,868] 0 * Epoch (train) metrics: base/data_time: 0.43108 | base/batch_time: 0.52970 | base/sample_per_second: 5660.64805 | lr: 0.00100 | momentum: 0.90000 | loss: 0.59272
[2019-01-20 23:17:08,869] 0 * Epoch (valid) metrics: base/data_time: 0.95702 | base/batch_time: 1.12536 | base/sample_per_second: 4672.03188 | lr: 0.00100 | momentum: 0.90000 | loss: 0.46534
[2019-01-20 23:17:08,870] 

1 * Epoch (train): 100% 22/22 [00:19<00:00,  1.95it/s, base/batch_time=0.05072, base/data_time=0.04409, base/sample_per_second=10095.30628, loss=0.40692, lr=0.00100, momentum=0.90000]
1 * Epoch (valid): 100% 5/5 [00:05<00:00,  1

Top best models:
./logs/checkpoint.None.49.pth.tar	0.1054
./logs/checkpoint.None.48.pth.tar	0.1079
./logs/checkpoint.None.47.pth.tar	0.1268
./logs/checkpoint.None.45.pth.tar	0.1269
./logs/checkpoint.None.42.pth.tar	0.1323
