In [None]:
import torch
torch.manual_seed(123)
import random
random.seed(123)


import torch.nn as nn
import os
import shutil
import itertools


import core.config as config
from chofer_tda_datasets import Anon10kEigenvaluePredict
from core.utils import *

from torchph.nn.slayer import SLayerExponential, \
SLayerRational, \
LinearRationalStretchedBirthLifeTimeCoordinateTransform, \
prepare_batch, SLayerRationalHat
from sklearn.model_selection import ShuffleSplit
from collections import Counter, defaultdict
from torch.utils.data import DataLoader, SubsetRandomSampler
from collections import OrderedDict
from torch.autograd import Variable

from sklearn.model_selection import StratifiedShuffleSplit

%matplotlib notebook

# Expose only the GPU with index 1 to this process (hard-coded; adjust to the
# machine the notebook runs on).
os.environ['CUDA_VISIBLE_DEVICES'] = str(1)


class AnonCollate:
    """Collate function for DataLoader: turns (sample, target) pairs into a batch.

    Args:
        cuda: when true, the first two entries of the prepared input batch and
            the target tensor are moved to the GPU inside the collate call.
    """

    def __init__(self, cuda=True):
        self.cuda = cuda

    def __call__(self, sample_target_iter):
        """Return ``(batch, targets)`` for an iterable of (sample, target) pairs."""
        # Materialise first so a one-shot iterator can be split into two lists.
        pairs = list(sample_target_iter)
        samples = [sample for sample, _ in pairs]
        targets = [target for _, target in pairs]

        batch = prepare_batch(samples, 2)
        target_tensor = torch.Tensor(targets)

        if not self.cuda:
            return batch, target_tensor

        # Only entries 0 and 1 of the prepared batch are moved to the GPU;
        # entries 2 and 3 are passed through unchanged.
        batch = (batch[0].cuda(), batch[1].cuda(), batch[2], batch[3])
        return batch, target_tensor.cuda()
    

class train_env:
    """Namespace holding the hyper-parameters used throughout this notebook."""

    # --- optimisation schedule ---
    n_epochs = 100        # epochs trained per run
    batch_size = 64       # batch size of the train and test loaders
    lr_initial = 0.5      # starting learning rate of the SGD optimizer
    lr_epoch_step = 10    # the learning rate is adapted every this many epochs

    # --- data / model ---
    train_size = 0.9      # fraction of the samples used for training in each split
    n_target_bins = 100   # number of histogram bins of the regression target
    nu = 0.01             # `nu` argument of the coordinate transform

    
# Load the eigenvalue-prediction dataset from the configured data root folder.
dataset = Anon10kEigenvaluePredict(data_root_folder_path=config.paths.data_root_dir)
# Select which barcode keys the dataset exposes.
# NOTE(review): exact semantics of these attributes live in chofer_tda_datasets — confirm there.
dataset.keys_essential = ('dim_0_ess', 'dim_1_ess')
dataset.keys_not_essential = ('dim_0',)


# Coordinate transform applied as the final step of the input pipeline below;
# its `nu` parameter is taken from train_env.
coordinate_transform  = LinearRationalStretchedBirthLifeTimeCoordinateTransform(nu=train_env.nu)    

        
# Per-sample input pipeline, applied in order:
#   1. pick the 'dim_0' entry and read it with `[()]`
#      (presumably materialises an on-disk array — TODO confirm),
#   2. convert the numpy array to a float tensor,
#   3. apply the coordinate transform.
dataset.data_transforms = [
    lambda x: x['dim_0'][()],
    lambda x: torch.from_numpy(x).float(), 
    coordinate_transform
]


def histogramize(x, n_bins=None, value_range=(0, 2)):
    """Turn a collection of values into a fixed-length density histogram.

    Args:
        x: array-like of values to bin.
        n_bins: number of equal-width bins. Defaults to
            ``train_env.n_target_bins`` (looked up at call time) so that the
            single-argument call used as a dataset target transform behaves
            exactly as before.
        value_range: ``(lower, upper)`` bounds of the histogram; values outside
            this interval are ignored by ``np.histogram``.

    Returns:
        list of float: the per-bin densities (``density=True``), of length
        ``n_bins``.
    """
    if n_bins is None:
        n_bins = train_env.n_target_bins
    densities, _ = np.histogram(x, bins=n_bins, range=value_range, density=True)
    return densities.tolist()
    

# Targets are converted to fixed-length density histograms.
dataset.target_transforms = [histogramize]                      
# Collate on the CPU (cuda=False); the training loop in experiment() moves
# each batch to the GPU itself. experiment() refers to this object by the
# name `reddit_collate`.
reddit_collate = AnonCollate(cuda=False)   

In [None]:
def LinearCell(n_in, n_out):
    """Build a ``Linear -> BatchNorm1d -> ReLU`` block.

    Args:
        n_in: number of input features of the linear layer.
        n_out: number of output features (also the batch-norm width).

    Returns:
        nn.Sequential: the three-layer block, with an extra ``out_features``
        attribute mirroring the linear layer's output width.
    """
    linear = nn.Linear(n_in, n_out)
    cell = nn.Sequential(linear, nn.BatchNorm1d(n_out), nn.ReLU())
    # Expose the output width on the container itself.
    cell.out_features = linear.out_features
    return cell


def Slayer(n_elements, point_dim):
    """Create the structure-element input layer used by the model.

    Args:
        n_elements: number of structure elements of the layer.
        point_dim: dimension of the input points. This is now forwarded to the
            layer; previously it was ignored and 2 was always used.

    Returns:
        SLayerRationalHat: layer constructed with ``radius_init=250``.
    """
    return SLayerRationalHat(n_elements, point_dimension=point_dim, radius_init=250)


class AnonModel(nn.Module):
    """Regress a target histogram from a prepared batch of 2-D point sets.

    The input batch is first passed through a structure-element layer
    (``self.dim_0``, 100 elements) and the result through a small
    fully-connected regressor whose output has ``train_env.n_target_bins``
    entries.
    """

    def __init__(self):
        super().__init__()

        dim_0_n_elements = 100
        self.dim_0 = Slayer(dim_0_n_elements, 2)

        self.regressor = nn.Sequential(
            nn.Tanh(),
            LinearCell(dim_0_n_elements, train_env.n_target_bins),
            LinearCell(train_env.n_target_bins, train_env.n_target_bins),
            nn.Linear(train_env.n_target_bins, train_env.n_target_bins))

    def forward(self, x):
        """Apply the structure-element layer, then the regressor.

        Args:
            x: a prepared batch in the form expected by ``self.dim_0``.

        Returns:
            The regressor output for the batch.
        """
        x = self.dim_0(x)
        x = self.regressor(x)

        return x

    def centers_init(self, sample_target_iter):
        """Initialise the structure-element centers with k-means.

        All points of all samples are pooled, duplicates are removed, and the
        k-means cluster centers (one per structure element) replace
        ``self.dim_0.centers.data``.

        Args:
            sample_target_iter: iterable of ``(sample, target)`` pairs; each
                sample is a tensor of points that can be concatenated along
                dim 0. Targets are ignored.
        """
        # Local imports: neither `sklearn.cluster` nor numpy is imported
        # explicitly at the top of this notebook.
        import numpy as np
        from sklearn.cluster import KMeans

        points = torch.cat([x for x, _ in sample_target_iter], dim=0)

        # De-duplicate on the numeric values. Building a set of `tuple(row)`
        # does not work when the row elements are 0-dim tensors, because
        # tensors hash by identity rather than by value.
        points = np.unique(points.cpu().numpy().astype(np.float64), axis=0)

        # `n_jobs` is intentionally not passed: it was deprecated and later
        # removed from KMeans, and it never affected the resulting centers.
        kmeans = KMeans(n_clusters=self.dim_0.centers.size(0),
                        init='k-means++',
                        random_state=123,
                        n_init=2)
        kmeans.fit(points)

        centers = torch.from_numpy(kmeans.cluster_centers_).float()
        self.dim_0.centers.data = centers


In [None]:
def experiment():
    """Train and evaluate a fresh AnonModel on 10 random train/test splits.

    For every split a new model is trained on the GPU with SGD (momentum 0.9)
    for ``train_env.n_epochs`` epochs, and evaluated on the held-out indices
    after the last epoch. Requires CUDA.

    Returns:
        list: one ``defaultdict(list)`` per run with the keys
        ``'loss_by_batch'``, ``'centers'`` (a snapshot of the structure-element
        centers after every batch), ``'train_loss_by_epoch'``, ``'y_true'``,
        ``'y_pred'``, ``'test_histogram_intersection'`` and ``'model'`` (the
        trained model moved back to the CPU).
    """
    stats_of_runs = []
    splitter = ShuffleSplit(n_splits=10,
                            train_size=train_env.train_size,
                            test_size=1-train_env.train_size,
                            random_state=123)

    train_test_splits = list(splitter.split(X=dataset.targets, y=dataset.targets))
    train_test_splits = [(train_i.tolist(), test_i.tolist()) for train_i, test_i in train_test_splits]

    for run_i, (train_i, test_i) in enumerate(train_test_splits):
        print('')
        print('Run', run_i)

        model = AnonModel()
        # Optional k-means initialisation of the centers (currently disabled):
        # model.centers_init([dataset[i] for i in train_i])
        model.cuda()

        stats = defaultdict(list)
        stats_of_runs.append(stats)

        opt = torch.optim.SGD(model.parameters(), lr=train_env.lr_initial, momentum=0.9)

        # The loaders do not depend on the epoch, so build them once per run
        # instead of once per epoch. SubsetRandomSampler still draws a new
        # permutation every time a loader is iterated.
        dl_train = DataLoader(dataset,
                              batch_size=train_env.batch_size,
                              collate_fn=reddit_collate,
                              sampler=SubsetRandomSampler(train_i),
                              num_workers=5)

        dl_test = DataLoader(dataset,
                             batch_size=train_env.batch_size,
                             collate_fn=reddit_collate,
                             sampler=SubsetRandomSampler(test_i),
                             num_workers=5)

        for i_epoch in range(1, train_env.n_epochs+1):

            model.train()

            epoch_loss = 0

            # Every `lr_epoch_step` epochs the learning rate is passed through
            # `lr * 0.5` (adapt_lr is expected to apply the function to the
            # optimizer's learning rate).
            if i_epoch % train_env.lr_epoch_step == 0:
                adapt_lr(opt, lambda lr: lr*0.5)

            for i_batch, (x, y) in enumerate(dl_train, 1):

                # Collation happens on the CPU; move the tensor parts of the
                # batch to the GPU here.
                x = (x[0].cuda(), x[1].cuda(), x[2], x[3])
                y = y.cuda()

                y = torch.autograd.Variable(y)

                def closure():
                    opt.zero_grad()
                    y_hat = model(x)
                    loss = histogram_intersection_loss(y_hat, y)

                    loss.backward()
                    return loss

                loss = opt.step(closure)

                epoch_loss += float(loss)
                stats['loss_by_batch'].append(float(loss))
                stats['centers'].append(model.dim_0.centers.data.cpu().numpy())

                print("Epoch {}/{}, Batch {}/{}".format(i_epoch, train_env.n_epochs, i_batch, len(dl_train)), end="       \r")

            stats['train_loss_by_epoch'].append(epoch_loss/len(dl_train))

            # last epoch dump test results
            if i_epoch == train_env.n_epochs:
                # Evaluate in inference mode: without this the BatchNorm1d
                # layers normalise with the statistics of each test batch and
                # update their running statistics from test data.
                model.eval()

                y_true = []
                y_pred = []
                for i_batch, (x, y_true_i) in enumerate(dl_test):
                    x = (x[0].cuda(), x[1].cuda(), x[2], x[3])
                    y_pred_i = model(x)

                    y_true.append(y_true_i.cpu())
                    y_pred.append(y_pred_i.data.cpu())

                y_true = torch.cat(y_true, dim=0)
                y_pred = torch.cat(y_pred, dim=0)

                stats['y_true'] = y_true
                stats['y_pred'] = y_pred

                # The loss is negated so the stored per-sample value is the
                # intersection itself.
                stats['test_histogram_intersection'] = -histogram_intersection_loss(y_pred, y_true, reduce=False).numpy()
                print('')

        stats['model'] = model.cpu()
    return stats_of_runs
# Run the full experiment (all splits, all epochs) and keep the per-run statistics.
stats_of_runs = experiment()        

In [None]:
# Mean test histogram intersection of each run, followed by the mean and the
# standard deviation of those per-run means across all runs.
test_histogram_intersections = [
    np.array(run_stats['test_histogram_intersection']).mean()
    for run_stats in stats_of_runs
]
print(np.mean(test_histogram_intersections))
print(np.std(test_histogram_intersections))