In [1]:
import torch    
torch.manual_seed(123)
import random
random.seed(123)
# import torch.multiprocessing
# torch.multiprocessing.set_start_method("spawn")

import torch.nn as nn
import os
import shutil
import itertools

from chofer_torchex.utils.data.collate import dict_sample_target_iter_concat
from chofer_torchex.utils.functional import collection_cascade, cuda_cascade
from chofer_tda_datasets import AnonEigenvaluePredict
from chofer_tda_datasets.transforms import Hdf5GroupToDict
from jmlr_2018_code.utils import *
from chofer_torchex.nn.slayer import SLayerRationalHat, LinearRationalStretchedBirthLifeTimeCoordinateTransform, prepare_batch
from sklearn.model_selection import ShuffleSplit
from collections import Counter, defaultdict

%matplotlib notebook

from torch.utils.data import DataLoader

os.environ['CUDA_VISIBLE_DEVICES'] = str(2)


class AnonCollate:
    """DataLoader collate function turning (sample, target) pairs into a batch.

    Samples are packed with ``prepare_batch(samples, 2)`` and targets are turned
    into a single ``torch.Tensor``. With ``cuda=True`` entries 0 and 1 of the
    packed batch and the target tensor are moved to the GPU, while entries 2 and
    3 are forwarded unchanged.
    """

    def __init__(self, cuda=True):
        # Whether __call__ moves the batch to the GPU.
        self.cuda = cuda

    def __call__(self, sample_target_iter):
        unpacked = [(sample, target) for sample, target in sample_target_iter]
        samples = [sample for sample, _ in unpacked]
        targets = [target for _, target in unpacked]

        packed = prepare_batch(samples, 2)
        target_tensor = torch.Tensor(targets)

        if not self.cuda:
            return packed, target_tensor

        # Only entries 0 and 1 of the packed batch are moved; 2 and 3 are passed
        # through as-is (presumably size metadata — TODO confirm in prepare_batch).
        gpu_batch = (packed[0].cuda(), packed[1].cuda(), packed[2], packed[3])
        return gpu_batch, target_tensor.cuda()
    

class train_env:
    """Hyper-parameters of the experiment, read directly as class attributes."""

    # Parameter `nu` of the coordinate transform applied to every diagram.
    nu = 0.01
    # Number of histogram bins of the regression target (= model output size).
    n_target_bins = 100

    # Fraction of samples used for training; the remainder is the test split.
    train_size = 0.9

    # Optimisation schedule: SGD starts at lr_initial, and the training loop
    # halves the learning rate every lr_epoch_step epochs.
    n_epochs = 200
    batch_size = 64
    lr_initial = 0.01
    lr_epoch_step = 20

    
# Root folder of the dataset files. Defaults to the original cluster path; set
# the JMLR2018_DATA_ROOT environment variable to run on another machine.
DATA_ROOT = os.environ.get('JMLR2018_DATA_ROOT', '/scratch1/chofer/jmlr2018_data/')

dataset = AnonEigenvaluePredict(data_root_folder_path=DATA_ROOT)
# Presumably selects which stored groups the dataset loads — confirm against
# chofer_tda_datasets. Only 'dim_0' is consumed by data_transforms below.
dataset.keys_essential = ('dim_0_ess', 'dim_1_ess')
dataset.keys_not_essential = ('dim_0',)


coordinate_transform = LinearRationalStretchedBirthLifeTimeCoordinateTransform(nu=train_env.nu)
    
    
def normalize(x):
    """Divide a point tensor by the maximum of its column 1.

    Args:
        x: 2-D tensor of points (indexed as ``x[:, 1]``, so at least 2 columns).

    Returns:
        ``x / x[:, 1].max()``. `x` is returned unchanged when it is empty
        (``max`` of an empty tensor raises) or when that maximum is 0 (the
        division would produce inf/nan and poison training).
    """
    if x.numel() == 0:
        return x
    c = x[:, 1].max()
    if c == 0:
        return x
    return x / c

        
# Per-sample input pipeline, applied in order:
#   1. read the 'dim_0' points (x is presumably an h5py Group) into a numpy array,
#   2. convert to a float32 torch tensor,
#   3. apply the nu-parameterised coordinate transform,
#   4. rescale with `normalize`.
# `dset[()]` reads the whole dataset; it replaces `Dataset.value`, which was
# deprecated in h5py 2.x and removed in h5py 3.0.
dataset.data_transforms = [lambda x: x['dim_0'][()],
                           lambda x: torch.from_numpy(x).float(),
                           coordinate_transform,
                           normalize
                           ]
def histogramize(x, bins=None, value_range=(0, 2)):
    """Turn a collection of values into a density histogram, returned as a list.

    Args:
        x: array-like of values.
        bins: number of equal-width bins; defaults to ``train_env.n_target_bins``.
        value_range: ``(low, high)`` interval covered by the bins; values outside
            it are ignored.

    Returns:
        list of ``bins`` floats, the density per bin (it integrates to 1 over
        `value_range` when at least one value lies inside it).
    """
    import numpy as np  # explicit instead of relying on the star import

    if bins is None:
        bins = train_env.n_target_bins
    # `density=True` replaces `normed=True` (removed in NumPy 1.24); for
    # equal-width bins both yield identical values.
    density, _ = np.histogram(x, bins=bins, range=value_range, density=True)
    return density.tolist()
    

# Collate function shared by all DataLoaders below (the name is kept because
# later cells refer to it); it moves every batch to the GPU.
reddit_collate = AnonCollate(cuda=True)

# Regression target: density histogram of the raw target values.
dataset.target_transforms = [histogramize]

In [None]:
def LinearCell(n_in, n_out):
    """Linear -> BatchNorm1d -> ReLU block.

    The returned ``nn.Sequential`` also exposes ``out_features`` (= `n_out`),
    mirroring ``nn.Linear``, so it can be chained like a plain linear layer.
    """
    layers = [nn.Linear(n_in, n_out), nn.BatchNorm1d(n_out), nn.ReLU()]
    cell = nn.Sequential(*layers)
    cell.out_features = layers[0].out_features
    return cell


def Slayer(n_elements, point_dim):
    """Build a rational-hat SLayer with `n_elements` structure elements.

    Args:
        n_elements: number of learnable structure elements (centers).
        point_dim: dimension of the input points, forwarded as
            ``point_dimension`` (it was previously ignored and hard-coded to 2).
    """
    return SLayerRationalHat(n_elements, point_dimension=point_dim, radius_init=1)


class AnonModel(nn.Module):
    """SLayer embedding of the 'dim_0' points followed by an MLP regressor.

    The SLayer has 200 structure elements; its output is batch-normalised, fed
    through a LinearCell and a final linear layer with
    ``train_env.n_target_bins`` outputs (one per target histogram bin).
    """

    def __init__(self):
        super().__init__()

        dim_0_n_elements = 200
        self.dim_0 = Slayer(dim_0_n_elements, 2)

        self.regressor = nn.Sequential(
            nn.BatchNorm1d(dim_0_n_elements),
            LinearCell(dim_0_n_elements, train_env.n_target_bins),
            nn.Linear(train_env.n_target_bins, train_env.n_target_bins))

    def forward(self, x):
        """Embed the collated batch `x` with the SLayer and regress the histogram."""
        x = self.dim_0(x)
        x = self.regressor(x)
        return x

    def centers_init(self, sample_target_iter):
        """Initialise the SLayer centers with k-means over all given points.

        Args:
            sample_target_iter: iterable of ``(points, target)`` pairs; the
                ``points`` tensors are concatenated along dim 0 and must live on
                the CPU (``.numpy()`` is called). Presumably each is ``(n_i, 2)``.
        """
        import sklearn.cluster  # explicit instead of relying on the star import

        points = torch.cat([sample for sample, _ in sample_target_iter], dim=0)
        # `n_jobs` was dropped: it was removed from KMeans in scikit-learn 1.0.
        kmeans = sklearn.cluster.KMeans(n_clusters=self.dim_0.centers.size(0),
                                        init='k-means++',
                                        random_state=123,
                                        n_init=2)
        kmeans.fit(points.numpy())
        # KMeans may return float64 centers; keep the parameter's own dtype so
        # forward() does not hit a dtype mismatch.
        centers = torch.from_numpy(kmeans.cluster_centers_).type_as(self.dim_0.centers.data)
        self.dim_0.centers.data = centers


In [None]:
stats_of_runs = []


def _evaluate(model, dl_test):
    """Return the sample-weighted mean MSE of `model` over the batches of `dl_test`.

    ``mse_loss`` already averages within a batch, so each batch loss is weighted
    by its batch size; a plain mean over batches would over-weight a short final
    batch. Returns ``float('nan')`` if `dl_test` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    seen_samples = 0
    for x, y in dl_test:
        y_hat = model(x)
        batch_loss = float(nn.functional.mse_loss(y_hat, torch.autograd.Variable(y.cuda())).data)
        total_loss += batch_loss * y.size(0)
        seen_samples += y.size(0)
    return total_loss / seen_samples if seen_samples else float('nan')


def experiment():
    """Train and evaluate `AnonModel` on shuffled train/test split(s).

    For each split (one, since ``n_splits=1``) a fresh model is built, its SLayer
    centers are initialised by k-means on the training points, and it is trained
    with SGD + momentum for ``train_env.n_epochs`` epochs, halving the learning
    rate every ``train_env.lr_epoch_step`` epochs. Per-run statistics (batch and
    epoch losses, center snapshots, split indices, final model on the CPU) are
    appended to the global ``stats_of_runs``.
    """
    splitter = ShuffleSplit(n_splits=1,
                            train_size=train_env.train_size,
                            test_size=1 - train_env.train_size,
                            random_state=123)

    train_test_splits = [(train_i.tolist(), test_i.tolist())
                         for train_i, test_i in splitter.split(X=dataset.targets, y=dataset.targets)]

    for run_i, (train_i, test_i) in enumerate(train_test_splits):
        print('')
        print('Run', run_i)

        model = AnonModel()
        model.centers_init([dataset[i] for i in train_i])
        model.cuda()

        stats = defaultdict(list)
        stats_of_runs.append(stats)

        opt = torch.optim.SGD(model.parameters(), lr=train_env.lr_initial, momentum=0.9)

        # The test loader does not change between epochs, so build it once.
        dl_test = DataLoader(dataset,
                             batch_size=train_env.batch_size,
                             collate_fn=reddit_collate,
                             sampler=test_i)

        for i_epoch in range(1, train_env.n_epochs + 1):
            model.train()

            # Fresh shuffle of the training indices every epoch.
            train_sampler = list(train_i)
            random.shuffle(train_sampler)
            dl_train = DataLoader(dataset,
                                  batch_size=train_env.batch_size,
                                  collate_fn=reddit_collate,
                                  sampler=train_sampler)

            if i_epoch % train_env.lr_epoch_step == 0:
                adapt_lr(opt, lambda lr: lr * 0.5)

            epoch_loss = 0
            for i_batch, (x, y) in enumerate(dl_train, 1):
                y = torch.autograd.Variable(y)

                def closure():
                    opt.zero_grad()
                    y_hat = model(x)
                    loss = nn.functional.mse_loss(y_hat, y)
                    loss.backward()
                    return loss

                # SGD.step returns the closure's return value, i.e. the batch loss.
                loss = opt.step(closure)

                epoch_loss += float(loss)
                stats['loss_by_batch'].append(float(loss))
                # Snapshot of the SLayer centers after every step (for plotting).
                stats['centers'].append(model.dim_0.centers.data.cpu().numpy())

                print("Epoch {}/{}, Batch {}/{}".format(i_epoch, train_env.n_epochs, i_batch, len(dl_train)), end="       \r")

            # Mean of the per-batch training losses.
            stats['train_loss_by_epoch'].append(epoch_loss / len(dl_train))

            print("\n\r testing...")
            avg_epoch_test_loss = _evaluate(model, dl_test)
            stats['test_loss_by_epoch'].append(avg_epoch_test_loss)
            print(avg_epoch_test_loss)

        stats['train_i'] = train_i
        stats['test_i'] = test_i
        stats['model'] = model.cpu()


experiment()


Run 0
Epoch 1/200, Batch 29/44       

In [None]:
stats = stats_of_runs[0]

# Center snapshots recorded after every training step, stacked to
# (n_steps, n_centers, point_dim) — presumably point_dim == 2.
all_centers = np.stack(stats['centers'], axis=0)
c_start = all_centers[0]
c_end = all_centers[-1]

fig, ax = plt.subplots()
ax.plot(c_start[:, 0], c_start[:, 1], 'bo', label='center initialization')
ax.plot(c_end[:, 0], c_end[:, 1], 'ro', label='center learned')
# Passing 2-D arrays draws one line per column, i.e. one trajectory per center
# (the same picture as looping over the centers, in a single call).
ax.plot(all_centers[:, :, 0], all_centers[:, :, 1], '-k')
ax.set(title='SLayer centers during training', xlabel='coordinate 0', ylabel='coordinate 1')
ax.legend()

# Epochs are 1-based; each curve gets its own x so a run interrupted between
# the train and test phase of an epoch still plots.
train_loss = stats['train_loss_by_epoch']
test_loss = stats['test_loss_by_epoch']
fig, ax = plt.subplots()
ax.plot(np.arange(1, len(train_loss) + 1), train_loss, label='train_loss')
ax.plot(np.arange(1, len(test_loss) + 1), test_loss, label='test_loss')
ax.set(title='Loss per epoch', xlabel='epoch', ylabel='MSE')
ax.legend()
plt.show()

In [None]:
stats = stats_of_runs[0]
test_i = stats['test_i']

dl_test = DataLoader(dataset,
                     batch_size=train_env.batch_size,
                     collate_fn=reddit_collate,
                     sampler=test_i)

model = stats['model']
model.cuda()  # NOTE: moves the model stored in `stats` to the GPU in place
model.eval()

y_true = []
y_pred = []
for x_batch, y_true_batch in dl_test:
    y_pred_batch = model(x_batch)
    # Iterating a 2-D array yields one row (= one sample's histogram) at a time.
    y_true.extend(y_true_batch.cpu().numpy())
    y_pred.extend(y_pred_batch.data.cpu().numpy())

# Plot at the centers of the bins produced by `histogramize`
# (train_env.n_target_bins equal-width bins on the range (0, 2)).
bin_edges = np.linspace(0, 2, train_env.n_target_bins + 1)
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2

out_dir = './anon_images'
os.makedirs(out_dir, exist_ok=True)  # savefig fails if the folder is missing
for i, (y_true_i, y_pred_i) in enumerate(zip(y_true, y_pred)):
    fig, ax = plt.subplots()
    ax.plot(bin_centers, y_true_i, 'go', label='true')
    ax.plot(bin_centers, y_pred_i, 'bo', label='pred')
    ax.set(title='Test sample {}'.format(i), xlabel='value', ylabel='density')
    ax.legend()
    fig.savefig(os.path.join(out_dir, str(i) + '.png'))
    # One open figure per test sample would exhaust memory; the images are on disk.
    plt.close(fig)
    




In [None]:
# Scratch check: iterating a 2-D array yields its rows.
x = np.random.rand(5, 10)
print(x)
list(x)