In [None]:
import argparse
import datetime
import gluoncv
import mxnet as mx
import numpy as np
import os
import pandas as pd
import time

from tqdm import tqdm as tqdm
from mxnet import nd, autograd, gluon, image, init
from mxnet.gluon.data.vision import datasets

import rsna_heme as rsna

In [None]:
# Run configuration. Attributes are created in the same order as a field-by-field
# assignment would produce, so anything that dumps vars(args) sees the same layout.
args = argparse.Namespace(
    base_dir='/mnt/Data2/datasets/rsna_heme',
    n_splits=4,
    fold=0,
    model_name='resnet50_v2',
    pretrained=True,
    classes=6,
    ctx=[mx.gpu(0)],
    batch_size=20,
    num_workers=4,
    start_epoch=0,
    epochs=10,
    optimizer='adam',
    optimizer_params={'learning_rate': 0.0001},
)

# Checkpoint / validation cadence: save_interval equals the epoch count,
# validation runs every epoch.
args.save_interval = args.epochs
args.save_dir = '../params'
args.val_interval = 1

# Run name = <model>_<fold>_<timestamp>; it is used as both the output
# sub-directory and the file prefix.
fold_str = 'foldAll'
if hasattr(args, 'fold'):
    fold_str = 'fold' + str(args.fold)
time_str = rsna.util.get_time()
net_name = '_'.join((args.model_name, fold_str, time_str))
args.save_prefix = os.path.join(args.save_dir, net_name, net_name)

In [None]:
# Record-file dataset for all training images (flag=1 requests colour decoding);
# common_transform is applied to every sample read from the record file.
train_dataset = datasets.ImageRecordDataset(
    os.path.join(args.base_dir, 'train.rec'),
    flag=1,
    transform=rsna.transforms.common_transform,
)

In [None]:
# Training label table; its 'cv_group' column is what the CV samplers split on.
# NOTE(review): assumes row order matches the record order in train.rec — confirm.
# NOTE(review): pd.read_pickle unpickles arbitrary objects; only load trusted files.
# The path is relative to the current working directory, not args.base_dir.
df = pd.read_pickle('labels.pkl')

In [None]:
# Build the network from the run configuration.
# Presumably driven by args.model_name / args.pretrained / args.classes / args.ctx — TODO confirm.
net = rsna.cnn.get_model(args)

In [None]:
# Both samplers partition the same 'cv_group' labels into args.n_splits folds
# around fold args.fold. The training sampler keeps CVSampler's default mode;
# the validation sampler uses mode='test' with shuffling disabled.
# Presumably: training folds vs. the held-out fold — TODO confirm in rsna.io.
train_sampler = rsna.io.CVSampler(
    groups=df['cv_group'],
    n_splits=args.n_splits,
    i_fold=args.fold,
)
val_sampler = rsna.io.CVSampler(
    groups=df['cv_group'],
    n_splits=args.n_splits,
    i_fold=args.fold,
    mode='test',
    shuffle=False,
)

In [None]:
# Two loaders over the same record dataset: each applies its own per-split
# image transform, and its sampler decides which records it yields.
train_data = mx.gluon.data.DataLoader(
    train_dataset.transform_first(rsna.transforms.train_transform),
    batch_size=args.batch_size,
    sampler=train_sampler,
    num_workers=args.num_workers,
)
val_data = mx.gluon.data.DataLoader(
    train_dataset.transform_first(rsna.transforms.val_transform),
    batch_size=args.batch_size,
    sampler=val_sampler,
    num_workers=args.num_workers,
)

In [None]:
# Optimiser over every network parameter, plus the loss and running metrics
# used by the training loop.
trainer = gluon.Trainer(
    net.collect_params(),
    args.optimizer,
    args.optimizer_params,
)
loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
metric_loss = mx.metric.Loss()
metric_acc = mx.metric.Accuracy()

In [None]:
def test(net, val_data, ctx, class_weights=(2, 1, 1, 1, 1, 1)):
    """Evaluate ``net`` on ``val_data`` and return ``[val_loss, val_acc]``.

    Runs the network outside ``autograd.record()`` (no gradients recorded).

    Parameters
    ----------
    net : gluon.Block
        Network whose output has one logit column per class.
    val_data : iterable of (data, label) batches
        Typically a ``gluon.data.DataLoader``; each batch is split across ``ctx``.
    ctx : list of mx.Context
        Devices to evaluate on.
    class_weights : sequence of float, optional
        Per-class weights for the sigmoid BCE loss, broadcast along the batch
        axis. The default doubles the weight of label column 0 and matches the
        weights used in the training loop.
        NOTE(review): length must equal the number of label columns.

    Returns
    -------
    list
        ``[val_loss, val_acc]``. ``val_loss`` is the mean weighted BCE loss.
        ``val_acc`` is the accuracy of column 0, with the logit thresholded via
        ``(sign(logit) + 1) / 2``.
    """
    loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
    metric_loss = mx.metric.Loss()
    metric_acc = mx.metric.Accuracy()
    # Built once (loop-invariant) and broadcast against every label batch.
    weight_row = nd.array(class_weights)
    for batch in val_data:
        weights = nd.ones_like(batch[1]) * weight_row
        weights = gluon.utils.split_and_load(weights,  ctx_list=ctx, batch_axis=0, even_split=False)
        data    = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0, even_split=False)
        label   = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0, even_split=False)

        outputs = [net(X) for X in data]
        losses = [loss(yhat, y, w) for yhat, y, w in zip(outputs, label, weights)]

        metric_loss.update(label, losses)
        # Accuracy is measured on label/logit column 0 only.
        metric_acc.update([y[:, 0] for y in label], [(nd.sign(o[:, 0]) + 1) / 2 for o in outputs])

    _, val_loss = metric_loss.get()
    _, val_acc = metric_acc.get()

    return [val_loss, val_acc]

In [None]:
logger, sw = rsna.logging.start_logger(args)

lr_counter = 0
num_batch = len(train_data)
# Best validation loss seen so far. Starting from +inf (instead of a fixed 1)
# lets the first validated epoch be recorded as best even if its loss is >= 1.
# NOTE(review): assumes rsna.io.save_params keeps the lower of the two losses — confirm.
best_loss = float('inf')
# Per-class loss weights (label column 0 counts double); built once and broadcast
# against each label batch. Assumes labels have args.classes (6) columns.
class_weights = nd.array([2, 1, 1, 1, 1, 1])

try:
    # Begin at args.start_epoch (previously ignored) so a resumed run keeps its
    # epoch numbering for the validation / checkpoint intervals.
    for epoch in range(args.start_epoch, args.epochs):
        tbar = tqdm(train_data)

        metric_acc.reset()
        metric_loss.reset()
        for i, batch in enumerate(tbar):
            weights = nd.ones_like(batch[1]) * class_weights
            weights = gluon.utils.split_and_load(weights,  ctx_list=args.ctx, batch_axis=0, even_split=False)
            data    = gluon.utils.split_and_load(batch[0], ctx_list=args.ctx, batch_axis=0, even_split=False)
            label   = gluon.utils.split_and_load(batch[1], ctx_list=args.ctx, batch_axis=0, even_split=False)
            with autograd.record():
                outputs = [net(X) for X in data]
                losses = [loss(yhat, y, w) for yhat, y, w in zip(outputs, label, weights)]
            for l in losses:
                l.backward()

            # Normalise gradients by the full (pre-split) batch size.
            trainer.step(len(batch[0]))
            # Accuracy is measured on label/logit column 0 only.
            metric_acc.update([y[:, 0] for y in label], [(nd.sign(o[:, 0]) + 1) / 2 for o in outputs])
            metric_loss.update(label, losses)
            # The progress-bar description is not refreshed on the final batch.
            if i < (len(tbar) - 1):
                tbar.set_description('E %d | loss %.4f'%(epoch, metric_loss.get()[1]))

        _, train_loss = metric_loss.get()
        _, train_acc = metric_acc.get()

        # Epoch logging: every val_interval epochs, validate, log and hand off to save_params.
        if (epoch + 1) % args.val_interval == 0:
            mx.nd.waitall()
            val_loss, val_acc = test(net, val_data, args.ctx)
            metrics = [train_acc, train_loss, val_acc, val_loss]
            rsna.logging.log_epoch_hooks(epoch, metrics, logger, sw)
            best_loss = rsna.io.save_params(net, best_loss, val_loss, epoch, args.save_interval, args.save_prefix)
finally:
    # Always export logged scalars and close the writer, even if training is
    # interrupted (e.g. KeyboardInterrupt in the notebook) or raises.
    sw.export_scalars(args.save_prefix + '_scalar_dict.json')
    sw.close()