In [31]:
import glob
import os
import pickle
import sys
import uuid
from collections import defaultdict

import numpy as np
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10, CIFAR100

from chofer_torchex.pershom import pershom_backend

from autoencoder import EncDec
from dataset import ds_random_subset, ds_monkey_patch_target
# Handle to the compiled Vietoris-Rips persistence routine exposed by
# chofer_torchex's pershom_backend. `train` calls it as
# vr_l1_persistence(points, 0, 0, 'l1') and takes [0][0] of the result
# (non-essential H_0 pairs). NOTE(review): the "Cuda" in the attribute name
# suggests `points` must be a CUDA tensor — TODO confirm.
vr_l1_persistence = pershom_backend.__C.VRCompCuda__vr_persistence

In [9]:
DS_ROOT = '/scratch_nas/chofer/data'


def dataset_factory(dataset=None, train_ratio=None):
    """Build a training dataset by name, optionally randomly sub-sampled.

    Args:
        dataset: one of 'cifar10', 'cifar100', 'tiny-imagenet-200'. Data is
            read from ``os.path.join(DS_ROOT, dataset)``; nothing is downloaded.
        train_ratio: fraction of the training set to keep, in (0, 1]. Values
            below 1.0 are passed to ``ds_random_subset``.

    Returns:
        The training split after ``ds_monkey_patch_target`` has been applied,
        sub-sampled when ``train_ratio < 1.0``.

    Raises:
        AssertionError: if ``dataset`` or ``train_ratio`` is not given.
        ValueError: if ``train_ratio`` is outside (0, 1] or ``dataset`` is
            not one of the supported names.
    """
    assert train_ratio is not None, 'train_ratio must be given'
    assert dataset is not None, 'dataset must be given'
    if not 0.0 < train_ratio <= 1.0:
        raise ValueError(
            'train_ratio must be in (0, 1], got {!r}'.format(train_ratio))

    root = os.path.join(DS_ROOT, dataset)

    # Transforms are built per branch so only the one actually needed is created.
    if dataset in ('cifar10', 'cifar100'):
        ds_class = CIFAR10 if dataset == 'cifar10' else CIFAR100
        ds = ds_class(
            root=root,
            train=True,
            transform=transforms.Compose([transforms.ToTensor()]),
            download=False)
    elif dataset == 'tiny-imagenet-200':
        # NOTE(review): `TinyImageNet` is not imported in the visible notebook
        # cells; this branch raises NameError unless it is defined elsewhere.
        ds = TinyImageNet(
            root=root,
            # Resize to 32x32 so images match the CIFAR input resolution.
            transform=transforms.Compose([
                transforms.Resize((32, 32)),
                transforms.ToTensor()]),
            train=True)
    else:
        raise ValueError(
            "Unknown dataset {!r}; expected 'cifar10', 'cifar100' or "
            "'tiny-imagenet-200'".format(dataset))

    ds_monkey_patch_target(ds)
    if train_ratio < 1.0:
        ds = ds_random_subset(ds, train_ratio)

    return ds

In [15]:
def l1_loss(x_hat, x, reduce=True):
    """Per-sample L1 reconstruction error.

    Every sample (leading dim of ``x``) is flattened and its absolute
    differences are summed. With ``reduce=True`` the batch mean of these sums
    is returned; otherwise the 1-D tensor of per-sample sums.
    """
    batch_size = x.size(0)
    per_sample = (x - x_hat).abs().view(batch_size, -1).sum(dim=1)
    return per_sample.mean() if reduce else per_sample

In [55]:
device = "cuda"


def _topological_loss(z, latent_dim, branch_siz, ball_radius):
    """Topological regulariser for one batch of latent codes.

    ``z`` is split along dim 1 into consecutive chunks of ``branch_siz``
    columns (up to ``latent_dim``). For each chunk the non-essential 0-dim.
    VR persistence (l1 metric) is computed, and every death time is pushed
    towards ``2 * ball_radius``.

    Returns:
        (top_loss, lifetimes): ``top_loss`` is a 1-element tensor of the same
        type as ``z`` (zero when no chunk yields a 2-dim. persistence tensor);
        ``lifetimes`` holds one Python list of death times per contributing
        chunk.
    """
    top_loss = torch.tensor([0]).type_as(z)
    lifetimes = []
    for i in range(0, latent_dim, branch_siz):
        pers = vr_l1_persistence(
            z[:, i:i + branch_siz].contiguous(),  # per-branch z_1,...,z_B
            0, 0, 'l1')[0][0]  # [0][0] gives non-essential in H_0

        if pers.dim() == 2:
            pers = pers[:, 1]  # all 0-dim. features have birth 0 in VR complex
            lifetimes.append(pers.tolist())
            # Out-of-place accumulation keeps the autograd graph straightforward.
            top_loss = top_loss + (pers - 2.0 * ball_radius).abs().sum()
    return top_loss, lifetimes


def _save_run(root_folder, config, model, log):
    """Write the run under a fresh UUID prefix in ``root_folder`` and return it.

    Files: ``<prefix>.model.pht`` (model state dict), ``<prefix>.config.pickle``
    and ``<prefix>.log.pickle``.
    """
    basefile = os.path.join(root_folder, str(uuid.uuid4()))
    torch.save(model.state_dict(), '.'.join([basefile, 'model', 'pht']))
    for obj, ext in ((config, 'config'), (log, 'log')):
        with open('.'.join([basefile, ext, 'pickle']), 'wb') as fid:
            pickle.dump(obj, fid)
    return basefile


def train(root_folder, config):
    """Train an ``EncDec`` autoencoder with a topological latent regulariser.

    Args:
        root_folder: directory where the run artefacts are written.
        config: dict with
            - 'data_args'  : kwargs for ``dataset_factory``,
            - 'model_args' : kwargs for ``EncDec``,
            - 'train_args' : 'n_epochs', 'learning_rate', 'batch_size',
              'rec_loss_w', 'top_loss_w' and optionally 'ball_radius'
              (default 1.0).

    Each batch minimises ``rec_loss_w * L1 reconstruction + top_loss_w *
    topological loss``. Per-batch losses and lifetimes are logged, and the
    per-epoch weighted means are printed.

    Returns:
        The file prefix (``root_folder/<uuid>``) under which the model state
        dict, the config and the log were pickled.

    Raises:
        ValueError: if the dataset yields no full batch (``drop_last=True``).
    """
    train_args = config['train_args']
    model_args = config['model_args']

    model = EncDec(**model_args).to(device)

    latent_dim = model.n_branches * model.out_features_branch
    branch_siz = model.out_features_branch
    ball_radius = train_args.get('ball_radius', 1.0)

    optim = Adam(model.parameters(), lr=train_args['learning_rate'])

    ds = dataset_factory(**config['data_args'])
    dl = DataLoader(ds,
                    batch_size=train_args['batch_size'],
                    shuffle=True,
                    drop_last=True)
    # With drop_last=True every epoch yields exactly len(dl) batches.
    n_batches = len(dl)
    if n_batches == 0:
        raise ValueError(
            'dataset of size {} yields no full batch of size {}'.format(
                len(ds), train_args['batch_size']))

    log = defaultdict(list)

    model.train()
    for epoch in range(1, train_args['n_epochs'] + 1):

        for x, _ in dl:
            x = x.to(device)
            x_hat, z = model(x)

            rec_loss = l1_loss(x_hat, x, reduce=True)
            top_loss, lifetimes = _topological_loss(
                z, latent_dim, branch_siz, ball_radius)

            log['lifetimes'].append(lifetimes)
            log['top_loss'].append(top_loss.item())
            log['rec_loss'].append(rec_loss.item())

            loss = (train_args['rec_loss_w'] * rec_loss
                    + train_args['top_loss_w'] * top_loss)

            model.zero_grad()
            loss.backward()
            optim.step()

        print('{}: rec_loss: {:.4f} | top_loss: {:.4f}'.format(
            epoch,
            np.mean(log['rec_loss'][-n_batches:]) * train_args['rec_loss_w'],
            np.mean(log['top_loss'][-n_batches:]) * train_args['top_loss_w']))

    return _save_run(root_folder, config, model, log)

In [65]:
# Experiment configuration consumed by `train`: dataset selection,
# EncDec keyword arguments (empty -> constructor defaults) and the
# optimisation / loss-weighting hyper-parameters.
config = dict(
    data_args=dict(
        dataset='cifar100',
        train_ratio=0.5,
    ),
    model_args=dict(),
    train_args=dict(
        n_epochs=50,
        learning_rate=1e-3,
        batch_size=100,
        top_loss_w=10.0,
        rec_loss_w=1.0,
    ),
)

# To launch a run (writes <uuid>.* artefacts into /tmp):
# train('/tmp', config)

In [66]:
import json

# Persist the experiment configuration as human-readable JSON.
with open('/tmp/config.json', mode='w') as fid:
    fid.write(json.dumps(config))

In [67]:
with open('/tmp/config.json', 'r') as fid:
    check = json.load(fid)
# The JSON round-trip must reproduce the in-memory config exactly; otherwise
# the saved config.json cannot be used to reproduce a run.
assert check == config, 'config changed during JSON round-trip: {!r}'.format(check)
print(check)

{'data_args': {'dataset': 'cifar100', 'train_ratio': 0.5}, 'model_args': {}, 'train_args': {'n_epochs': 50, 'learning_rate': 0.001, 'batch_size': 100, 'top_loss_w': 10.0, 'rec_loss_w': 1.0}}
