In [None]:
import argparse
import mxnet as mx
import os
import pandas as pd

from tqdm import tqdm as tqdm

import rsna_heme as rsna

In [None]:
# Run configuration, collected in a single namespace so the same object can be
# handed to the model factory and written to the log.
args = argparse.Namespace()

# Dataset root. Set the RSNA_HEME_DIR environment variable to run on a machine
# with a different layout; the original location remains the default.
args.base_dir = os.environ.get('RSNA_HEME_DIR', '/mnt/Data2/datasets/rsna_heme')

# Cross-validation: number of splits and the index of the fold used here.
args.n_splits = 4
args.fold = 0

# Model
args.model_name = 'resnet50_v2'
args.pretrained = True
args.classes = 6
args.ctx = [mx.gpu(0)]

# Data loading
args.batch_size = 20
args.num_workers = 4

# Optimisation schedule
args.start_epoch = 0
args.epochs = 10
args.optimizer = 'adam'
args.optimizer_params = {'learning_rate': 0.0001}

# Checkpoint
# The save interval equals the total number of epochs.
args.save_interval = args.epochs
args.save_dir = '../params'
# Validate every `val_interval` epochs.
args.val_interval = 1
# Falls back to 'foldAll' when the `fold` attribute is not set above.
fold_str = 'fold' + str(args.fold) if hasattr(args, 'fold') else 'foldAll'
time_str = rsna.util.get_time()
# Run name: <model>_<fold>_<timestamp>; used for both the run directory and the file prefix.
net_name = '_'.join((args.model_name, fold_str, time_str))
args.save_prefix = os.path.join(args.save_dir, net_name, net_name)

In [None]:
# Record-file dataset under the configured data root. `flag=1` and the shared
# `common_transform` are passed straight through to the dataset constructor.
train_dataset = mx.gluon.data.vision.datasets.ImageRecordDataset(
    os.path.join(args.base_dir, 'train.rec'),
    flag=1,
    transform=rsna.transforms.common_transform,
)

In [None]:
# Load the label table from a pickle in the current working directory (the path
# is relative, unlike the record file, which is read from args.base_dir).
# NOTE(review): unpickling can execute arbitrary code -- only load a
# 'labels.pkl' that comes from a trusted source.
df = pd.read_pickle('labels.pkl')

In [None]:
# Build the network from the run configuration; the whole `args` namespace is
# passed to the project's model factory.
net = rsna.cnn.get_model(args)

In [None]:
# Cross-validation samplers for the selected fold, both driven by the
# 'cv_group' column of the label table.
train_sampler = rsna.io.CVSampler(
    groups=df['cv_group'],
    n_splits=args.n_splits,
    i_fold=args.fold,
)
# Same grouping and fold, but in 'test' mode and without shuffling.
val_sampler = rsna.io.CVSampler(
    groups=df['cv_group'],
    n_splits=args.n_splits,
    i_fold=args.fold,
    mode='test',
    shuffle=False,
)

In [None]:
# Both loaders draw from the same record dataset; they differ only in the
# transform applied to the first element of each sample and in the sampler.
train_data = mx.gluon.data.DataLoader(
    train_dataset.transform_first(rsna.transforms.train_transform),
    batch_size=args.batch_size,
    sampler=train_sampler,
    num_workers=args.num_workers,
)
val_data = mx.gluon.data.DataLoader(
    train_dataset.transform_first(rsna.transforms.val_transform),
    batch_size=args.batch_size,
    sampler=val_sampler,
    num_workers=args.num_workers,
)

In [None]:
# Optimiser over all network parameters, using the optimiser name and
# hyper-parameters from the run configuration.
trainer = mx.gluon.Trainer(
    net.collect_params(),
    args.optimizer,
    args.optimizer_params,
)
# Training criterion: sigmoid binary cross-entropy.
loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()

In [None]:
# Train the network, validating and checkpointing every `args.val_interval` epochs.
# The logger writes into the run directory <save_dir>/<net_name>.
logger = rsna.logger.Logger(os.path.join(args.save_dir, net_name), net_name)
logger.tb_setup()

# Start from +inf so the first validated epoch can always become the best one.
# NOTE(review): assumes rsna.io.save_params compares val_loss against best_loss
# and returns the updated best value -- confirm.
best_loss = float('inf')

# try/finally guarantees the logger is closed even when training raises or is
# interrupted from the notebook (e.g. KeyboardInterrupt).
try:
    # Record the full run configuration before training starts.
    logger.log(args)

    # Honour the configured starting epoch (0 by default, i.e. a full run).
    for epoch in range(args.start_epoch, args.epochs):

        # One pass over the training split; passing `trainer` is what
        # distinguishes this call from the validation call below.
        train_loss, train_acc = rsna.cnn.process_data(net, loss, tqdm(train_data), args.ctx, trainer=trainer)

        # Epoch logging
        if (epoch + 1) % args.val_interval == 0:
            # Block until all pending MXNet operations have finished before validating.
            mx.nd.waitall()
            val_loss, val_acc = rsna.cnn.process_data(net, loss, val_data, args.ctx)
            metrics = [train_acc, train_loss, val_acc, val_loss]
            rsna.logger.log_epoch_hooks(logger, epoch, metrics)
            # save_params returns the best loss to carry into the next epoch.
            best_loss = rsna.io.save_params(net, best_loss, val_loss, epoch, args.save_interval, args.save_prefix)
finally:
    logger.close()