In [None]:
import os
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader

import argus
from argus import Model
from argus.callbacks import MonitorCheckpoint, EarlyStopping, LoggingToFile

from src.datasets import DrawDataset, get_train_val_samples
from src.transforms import ImageTransform, DrawTransform
from src.argus_models import CnnFinetune
from src import config
%matplotlib inline

In [None]:
def imshow(image, figsize=(3, 3), title=None, cmap=None):
    """Display a single image in its own figure.

    Args:
        image: array-like accepted by ``Axes.imshow`` (e.g. an H x W array).
        figsize: figure size in inches, forwarded to ``plt.subplots``.
        title: optional axes title, e.g. the class label of the sample.
        cmap: optional colormap name/object; ``None`` keeps matplotlib's
            default. A colormap only affects single-channel images.
    """
    # Explicit Figure/Axes interface instead of the pyplot state machine, so the
    # helper never draws into some unrelated "current" figure.
    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(image, cmap=cmap)
    if title is not None:
        ax.set_title(title)
    plt.show()

In [None]:
# Rendering parameters, passed positionally to DrawTransform in the data cell.
image_size = 128
image_pad = 8
image_line_width = 3
time_color = True  # forwarded to DrawTransform; exact effect defined in src.transforms

# Loader settings.
train_batch_size = 128
val_batch_size = 128
train_epoch_size = 1_000_000  # passed as DrawDataset(size=...) for the train split
val_key_id_path = '/workdir/data/val_key_ids_001.json'

# Seed every RNG the pipeline may touch (Python, NumPy, torch — the latter also
# drives DataLoader shuffling) so a Restart & Run All is reproducible.
random_seed = 42
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

# Data

In [None]:
# Split the samples into train / validation. val_key_id_path presumably lists the
# key ids held out for validation — TODO confirm in src.datasets.
samples_split = get_train_val_samples(val_key_id_path)
train_samples, val_samples = samples_split

In [None]:
# Both splits share one stroke-drawing transform; only the image-level transform
# differs (ImageTransform(True) for train vs False for val — presumably toggling
# augmentation; confirm in src.transforms).
draw_transform = DrawTransform(image_size, image_pad, image_line_width, time_color)
train_trns = ImageTransform(True)
train_dataset = DrawDataset(train_samples, draw_transform,
                            size=train_epoch_size, image_transform=train_trns)
val_trns = ImageTransform(False)
val_dataset = DrawDataset(val_samples, draw_transform, image_transform=val_trns)

num_workers = 8


def seed_numpy_worker(worker_id):
    """Give each DataLoader worker its own NumPy RNG stream.

    Forked workers start with an identical copy of the parent's NumPy RNG
    state, and older PyTorch releases do not reseed NumPy per worker. If the
    dataset/transforms draw from ``np.random`` (not visible here — TODO
    confirm), every worker would then produce the same "random" samples. Torch
    gives each worker a distinct seed, so NumPy's seed is derived from that.

    Args:
        worker_id: index of the worker (unused; the torch seed already differs).
    """
    np.random.seed(torch.initial_seed() % 2 ** 32)


train_loader = DataLoader(train_dataset, batch_size=train_batch_size,
                          num_workers=num_workers, shuffle=True,
                          worker_init_fn=seed_numpy_worker)
val_loader = DataLoader(val_dataset, batch_size=val_batch_size,
                        num_workers=num_workers, shuffle=False,
                        worker_init_fn=seed_numpy_worker)

In [None]:
n_images_to_draw = 3

# Fetch a single training batch and eyeball a few samples with their labels.
img, trg = next(iter(train_loader))

# Guard against a batch that holds fewer samples than we want to draw.
for i in range(min(n_images_to_draw, img.shape[0])):
    # Only channel 0 is shown; other channels (if any, e.g. with time_color=True)
    # are not visualised — TODO confirm channel layout in DrawTransform.
    img_i = img[i, 0, :, :].numpy()
    print(config.IDX_TO_CLASS[trg[i].item()])
    imshow(img_i)

# Model

In [None]:
# Network definition for the project's CnnFinetune argus model: a
# 'se_resnext50_32x4d' backbone (pretrained=True) with one output per class
# in config.CLASSES.
nn_module_params = dict(
    model_name='se_resnext50_32x4d',
    num_classes=len(config.CLASSES),
    pretrained=True,
    dropout_p=0.2,
)

# Optimiser, loss and device are given by name and resolved inside argus.
params = dict(
    nn_module=nn_module_params,
    optimizer=('Adam', {'lr': 0.001}),
    loss='CrossEntropyLoss',
    device='cuda',
)

model = CnnFinetune(params)

In [None]:
experiment_name = 'cnn_fine_se_resnext50_001'
experiment_dir = f'/workdir/data/experiments/{experiment_name}'
monitor_metric = 'val_accuracy'

# Checkpoints (at most 3 kept) and early stopping (patience 50) both track
# validation accuracy; the training log is written next to the checkpoints.
callbacks = [
    MonitorCheckpoint(experiment_dir, monitor=monitor_metric, max_saves=3),
    EarlyStopping(monitor=monitor_metric, patience=50),
    LoggingToFile(f'{experiment_dir}/log.txt'),
]

model.fit(
    train_loader,
    val_loader=val_loader,
    max_epochs=1000,
    callbacks=callbacks,
    metrics=['accuracy'],
)