In [1]:
import sys, os, time
#os.environ["CUDA_VISIBLE_DEVICES"]="0"
%load_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as distrib
import torch.multiprocessing as mp
from torch.utils.data import Dataset, DataLoader
import gym
import numpy as np
%matplotlib notebook
#%matplotlib tk
import matplotlib.pyplot as plt
#plt.switch_backend('Qt5Agg') #('Qt5Agg')
import foundation as fd
from foundation import models
from foundation import util
from foundation import train

# Let numpy print up to 120 characters per line before wrapping array output.
np.set_printoptions(linewidth=120)

In [None]:
# Build the run configuration from the project's standard option parser, without reading CLI args.
parser = train.setup_standard_options(no_config=True)

args = parser.parse_args([])

args.no_test = True

args.device = 'cuda:0'
args.seed = 0

# --- Logging / checkpointing ---
args.logdate = True
args.tblog = False
args.txtlog = False
args.saveroot = 'trained_nets'
args.save_freq = -1  # <= 0 disables periodic checkpointing in the training loop

# --- Data ---
args.dataset = 'svhn'
# args.dataset = 'mnist'
# NOTE: for EMNIST the download URL must be changed to
# 'http://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip'
args.use_val = True
args.val_per = 1/6  # presumably the fraction of training data held out for validation - confirm in train.load_data

args.num_workers = 4
args.batch_size = 128

args.start_epoch = 0
args.epochs = 10

# Derive the run name from the dataset so the save dir matches what is actually trained
# (previously hard-coded to 'test_on_mnist' while training on SVHN).
args.name = 'test_on_{}'.format(args.dataset)


now = time.strftime("%y-%m-%d-%H%M%S")
if args.logdate:
    args.name = os.path.join(args.name, now)
args.save_dir = os.path.join(args.saveroot, args.name)
print('Save dir: {}'.format(args.save_dir))

if args.tblog or args.txtlog:
    util.create_dir(args.save_dir)
    print('Logging in {}'.format(args.save_dir))
logger = util.Logger(args.save_dir, tensorboard=args.tblog, txt=args.txtlog)

# Set seed (a seed of None means "pick a random one").
if getattr(args, 'seed', None) is None:
    args.seed = util.get_random_seed()
torch.manual_seed(args.seed)
np.random.seed(args.seed)
# torch.cuda.manual_seed is silently ignored when CUDA is unavailable, so no try/except is
# needed; the old bare `except:` could only hide genuine errors.
torch.cuda.manual_seed(args.seed)

if not torch.cuda.is_available():
    args.device = 'cpu'
print('Using device {} - random seed set to {}'.format(args.device, args.seed))

In [None]:
# Load the data splits and wrap each in a DataLoader.
# train.load_data returns either [train, test] or [train, val, test] (see the len check below);
# presumably the val split depends on args.use_val - confirm in train.load_data.
datasets = train.load_data(args=args)
# Only the training split is shuffled. zip() stops at len(datasets), so with no val split
# the second False applies to the test set.
shuffles = [True, False, False]

# Bug fix: `shuffle=s` was previously never passed, so training data was never shuffled.
loaders = [DataLoader(d, batch_size=args.batch_size, shuffle=s, num_workers=args.num_workers)
           for d, s in zip(datasets, shuffles)]

trainloader, testloader = loaders[0], loaders[-1]
valloader = None if len(loaders) == 2 else loaders[1]

# args.din / args.dout are presumably filled in by train.load_data - confirm.
print('Input: {}, Output: {}'.format(args.din, args.dout))
print('traindata len={}, trainloader len={}'.format(len(datasets[0]), len(trainloader)))
if valloader is not None:
    print('valdata len={}, valloader len={}'.format(len(datasets[1]), len(valloader)))
print('testdata len={}, testloader len={}'.format(len(datasets[-1]), len(testloader)))
print('Batch size: {} samples'.format(args.batch_size))

In [None]:
# Define Model
args.total_samples = {'train': 0, 'val':0, 'test': 0}
epoch = 0
best_loss = None
all_train_stats = []
all_val_stats = []
all_test_stats = []

args.din_flat = int(np.product(args.din))

class Simple(fd.Visualizable, fd.Trainable_Model):
    """Minimal classifier wrapper: runs `net` and trains it with cross-entropy loss.

    Registers two extra stats on ``self.stats``: 'confidence' (mean of the max softmax
    probability) and 'accuracy' (fraction of argmax predictions equal to the label).
    """

    def __init__(self, net, din=None, dout=None):
        """
        Args:
            net: module mapping a batch of inputs to class logits.
            din: input shape; defaults to the global ``args.din``.
            dout: output dimension (number of classes); defaults to the global ``args.dout``.
        """
        super().__init__(args.din if din is None else din,
                         args.dout if dout is None else dout)
        self.criterion = nn.CrossEntropyLoss()
        self.net = net

        self.stats.new('confidence', 'accuracy')

    def forward(self, x):
        """Return the logits produced by the wrapped network."""
        return self.net(x)

    def _visualize(self, info, logger=None):
        """Update the 'confidence' and 'accuracy' stats from a step's `info.pred` and `info.y`."""
        # `pred` holds raw logits (that is what CrossEntropyLoss consumes). The max logit is
        # unbounded and not a confidence, so convert to probabilities first. The argmax, and
        # therefore accuracy, is unchanged by the softmax.
        probs = F.softmax(info.pred.detach(), dim=-1)
        confidence, pick = probs.max(-1)

        correct = pick.eq(info.y).float()

        self.stats.update('confidence', confidence.mean())
        self.stats.update('accuracy', correct.mean())

    def _step(self, batch, out=None):
        """Run one forward pass on ``batch = (x, y)``; take an optimizer step if training.

        Returns `out` (a fresh util.TensorDict when None) with `loss`, `x`, `y` and `pred` set.
        """
        if out is None:
            out = util.TensorDict()

        x, y = batch

        pred = self(x)

        loss = self.criterion(pred, y)

        # train_me() presumably reports whether the model is in training mode - confirm in fd.
        if self.train_me():
            self.optim.zero_grad()
            loss.backward()
            self.optim.step()

        out.loss = loss
        out.x = x
        out.y = y
        out.pred = pred
        return out
    

# Flatten the input, then an MLP with no hidden layers (presumably a single linear layer).
net = nn.Sequential(nn.Flatten(), models.make_MLP(args.din_flat, args.dout, hidden_dims=[], nonlin='prelu'))

model = Simple(net)
# Move parameters to the target device *before* constructing the optimizer,
# as the torch.optim documentation recommends.
model.to(args.device)
# NOTE(review): torch.optim.Adam has no `momentum` argument; confirm how set_optim
# treats `momentum` when optim_type='adam'.
model.set_optim(optim_type='adam', lr=1e-3, weight_decay=1e-4, momentum=0.9)
# Optional LR scheduler, e.g. torch.optim.lr_scheduler.StepLR(model.optim, step_size=6, gamma=0.2)
scheduler = None

print(model)
print(model.optim)
print('Model has {} parameters'.format(util.count_parameters(model)))

In [None]:
# Reseed after model init
torch.manual_seed(args.seed)
np.random.seed(args.seed)
try:
    torch.cuda.manual_seed(args.seed)
except:
    pass


if args.no_test:
    print('Will not run test data after training')
else:
    raise NotImplementedError

In [None]:
def _run_split(loader, mode, epoch):
    """Reset the model's stats, run one epoch over `loader` in `mode`, and return the StatsMeter."""
    model.reset()

    stats = util.StatsMeter()
    stats.shallow_join(model.stats)

    return train.run_epoch(model, loader, args, mode=mode,
                           epoch=epoch, print_freq=args.print_freq, logger=logger, silent=True,
                           viz_criterion_args=args.viz_criterion_args,
                           stats=stats, )


# Continues from the current global `epoch`, so re-running this cell trains for args.epochs more epochs.
for _ in range(args.epochs):

    train_stats = _run_split(trainloader, 'train', epoch)
    all_train_stats.append(train_stats.copy())

    val_stats = None
    if valloader is not None:
        val_stats = _run_split(valloader, 'val', epoch)
        all_val_stats.append(val_stats.copy())

    if scheduler is not None:
        scheduler.step()

    # Only report validation numbers when a val split exists; previously this raised
    # NameError (val_stats undefined) when valloader was None.
    msg = '[ {} ] Epoch {} Train={:.3f} ({:.3f})'.format(
        time.strftime("%H:%M:%S"), epoch+1,
        train_stats['accuracy'].avg.item(), train_stats['loss'].avg.item(),
    )
    if val_stats is not None:
        msg += ', Val={:.3f} ({:.3f})'.format(
            val_stats['accuracy'].avg.item(), val_stats['loss'].avg.item(),
        )
    print(msg)

    if args.save_freq > 0 and epoch % args.save_freq == 0:

        ckpt = {
            'epoch': epoch+1,
            'args': args,
            'model_str': str(model),
            'model_state': model.state_dict(),
            'all_train_stats': all_train_stats,
        }
        # Previously undefined (NameError at save time) unless track_best was enabled.
        is_best = False
        if getattr(args, 'track_best', False):
            # Track best on validation loss when available, otherwise on training loss.
            av_loss = (train_stats if val_stats is None else val_stats)['loss'].avg.item()
            if best_loss is None or av_loss < best_loss:
                is_best = True
                best_loss = av_loss
                best_epoch = epoch

            ckpt['loss'] = av_loss
            ckpt['best_loss'] = best_loss
            ckpt['best_epoch'] = best_epoch
        if all_val_stats:
            ckpt['all_val_stats'] = all_val_stats
        # NOTE(review): `save_checkpoint` is not defined or imported anywhere in this notebook
        # (presumably train.save_checkpoint or util.save_checkpoint); confirm before enabling save_freq.
        path = save_checkpoint(ckpt, args.save_dir, is_best=is_best, epoch=epoch+1)
        print('--- checkpoint saved to {} ---'.format(path))

    epoch += 1

In [None]:
# Plot `stat_key` across epochs for the training run and, when recorded, the validation run.
stat_key = 'accuracy'

figax = util.plot_stat(all_train_stats, key=stat_key, figax=None, label='train')
# Skip the val curve when no validation split was used (all_val_stats is then empty).
if all_val_stats:
    figax = util.plot_stat(all_val_stats, key=stat_key, figax=figax, label='val')
fig, ax = figax

# Use the Axes API instead of switching pyplot's current axes with plt.sca.
ax.set_xlabel('Epochs')
ax.set_ylabel(stat_key)
ax.set_title('{} per epoch'.format(stat_key))
ax.legend();