In [1]:
import torch    
torch.manual_seed(123)
import random
random.seed(123)

import torch.nn as nn
import os
import shutil
import itertools

from chofer_torchex.utils.data.collate import dict_sample_target_iter_concat
from chofer_torchex.utils.functional import collection_cascade, cuda_cascade
from chofer_tda_datasets import Reddit_5K
from jmlr_2018_code.utils import *
from sklearn.model_selection import ShuffleSplit
from chofer_torchex.nn.slayer import SLayerRationalHat, LinearRationalStretchedBirthLifeTimeCoordinateTransform, prepare_batch


from collections import Counter, defaultdict

%matplotlib notebook



from torch.utils.data import DataLoader

os.environ['CUDA_VISIBLE_DEVICES'] = str(1)

def save_extract_birth_times(tensor):
    if tensor.ndimension() == 0:
        return tensor
    else:
        return tensor[:, 0].unsqueeze(dim=1)

class RedditCollate:   
    def __init__(self, dataset, cuda=True):
        self.cuda = cuda
        self.dataset = dataset
        
    def __call__(self, sample_target_iter):
        x, y = dict_sample_target_iter_concat(sample_target_iter)

        for k in self.dataset.keys_not_essential:
            batch_view = x[k]
            x[k] = prepare_batch(batch_view, 2)
            
        for k in self.dataset.keys_essential:
            batch_view = x[k]
            x[k] = prepare_batch(batch_view, 1)            

        y = torch.LongTensor(y)    

        if self.cuda:
            # Shifting the necessary parts of the prepared batch to the cuda
            x = {k: collection_cascade(v,
                                       lambda x: isinstance(x, tuple),
                                       lambda x: (x[0].cuda(), x[1].cuda(), x[2], x[3]))
                 for k, v in x.items()}

            y = y.cuda()

        return x, y
    

class train_env:
    n_epochs = 200
    lr_initial = 0.01
    lr_epoch_step = 20
    batch_size = 100
    train_size = 0.9
    nu = 0.01


In [2]:
dataset = Reddit_5K(root_dir='/scratch1/chofer/jmlr2018_data/')
dataset.keys_essential = ('DegreeVertexFiltration_dim_0_essential', 'DegreeVertexFiltration_dim_1_essential')
dataset.keys_not_essential = ('DegreeVertexFiltration_dim_0',)
dataset.keys_of_interrest = dataset.keys_essential + dataset.keys_not_essential

Found data!


In [3]:


def reduce_essential_dgms(x):
    for k in dataset.keys_essential:
        x[k] = save_extract_birth_times(x[k])
        
    return x

def coordinate_transform(x):
    t = LinearRationalStretchedBirthLifeTimeCoordinateTransform(nu=train_env.nu)
    for k in dataset.keys_not_essential:
        x[k] = t(x[k])
        
    return x

dataset.data_transforms = \
[
    lambda x: {k: x[k] for k in dataset.keys_of_interrest},
    numpy_to_torch_cascade,
    reduce_essential_dgms,
    coordinate_transform    
]

In [4]:
reddit_collate = RedditCollate(dataset)         

In [None]:
def LinearCell(n_in, n_out):
    m = nn.Sequential(nn.Linear(n_in, n_out), 
                      nn.BatchNorm1d(n_out), 
                      nn.ReLU())
    m.out_features = m[0].out_features
    return m


def Slayer(n_elements, point_dim):
    return SLayerRationalHat(n_elements, point_dimension=2, radius_init=0.1)   


class Reddit5KModel(nn.Module):    
    def __init__(self):
        super().__init__()  
        
        dim_0_n_elements = 200
        dim_0_ess_n_elements = 100
        dim_1_ess_n_elements = 100
        
        self.dim_0 = Slayer(dim_0_n_elements, 2)
        self.dim_0_ess = Slayer(dim_0_ess_n_elements, 1)
        self.dim_1_ess = Slayer(dim_1_ess_n_elements, 1)      
        
        self.dim_0_linear = LinearCell(dim_0_n_elements, int(dim_0_n_elements/2))
        self.dim_0_ess_linear = LinearCell(dim_0_ess_n_elements, int(dim_0_ess_n_elements/2))
        self.dim_1_ess_linear = LinearCell(dim_1_ess_n_elements, int(dim_1_ess_n_elements/2))
        
        cls_in_size = self.dim_0_linear.out_features + self.dim_0_ess_linear.out_features + self.dim_1_ess_linear.out_features
        self.classifer = nn.Sequential(
                                       LinearCell(cls_in_size, int(cls_in_size/2)),
#                                        nn.Dropout(0.2),
#                                        LinearCell(int(cls_in_size/2), int(cls_in_size/2)), 
                                       nn.Linear(int(cls_in_size/2), 5))
                         
    def forward(self, x):
        x_dim_0 = self.dim_0(x['DegreeVertexFiltration_dim_0'])        
        x_dim_0 = self.dim_0_linear(x_dim_0)
        
        x_dim_0_ess = self.dim_0_ess(x['DegreeVertexFiltration_dim_0_essential'])   
        x_dim_0_ess = self.dim_0_ess_linear(x_dim_0_ess)
        
        x_dim_1_ess = self.dim_1_ess(x['DegreeVertexFiltration_dim_1_essential'])       
        x_dim_1_ess = self.dim_1_ess_linear(x_dim_1_ess)
        
        x = torch.cat([x_dim_0, x_dim_0_ess, x_dim_1_ess], dim=1)
        
        x = self.classifer(x)        
        
        return x
        
    def centers_init(self):
        dim_0 = []
        for i in range(self.dim_0.centers.size(0)):
            x = random.uniform(0, 1)
            y = random.uniform(0, 1-x)
            dim_0.append((x,y))
        self.dim_0.centers.data = torch.Tensor(dim_0)
        
        self.dim_0_ess.centers.data.uniform_(0, 1)
        self.dim_1_ess.centers.data.uniform_(0, 1)
            
            
       

        

In [None]:
stats_of_runs = []
def experiment():      
    splitter = ShuffleSplit(n_splits=10, 
                            train_size=train_env.train_size, 
                            test_size=1-train_env.train_size, 
                            random_state=123)
    train_test_splits = list(splitter.split(X=dataset.labels, y=dataset.labels))
    train_test_splits = [(train_i.tolist(), test_i.tolist()) for train_i, test_i in train_test_splits]
    
    for run_i, (train_i, test_i) in enumerate(train_test_splits):
        print('')
        print('Run', run_i)
        
        model = Reddit5KModel()
        model.centers_init()
        model.cuda()

        stats = defaultdict(list)
        stats_of_runs.append(stats)
        
        opt=torch.optim.SGD(model.parameters(), lr=train_env.lr_initial, momentum=0.9)

        for i_epoch in range(1, train_env.n_epochs+1):      

            model.train()

            train_sampler = [i for i in train_i] 
            random.shuffle(train_sampler)
            dl_train = DataLoader(dataset,
                              batch_size=train_env.batch_size, 
                              collate_fn=reddit_collate,
                              sampler=train_sampler)

            dl_test = DataLoader(dataset,
                                 batch_size=train_env.batch_size, 
                                 collate_fn=reddit_collate, 
                                 sampler=test_i)

            epoch_loss = 0

            if i_epoch % train_env.lr_epoch_step == 0:
                adapt_lr(opt, lambda lr: lr*0.5)

            for i_batch, (x, y) in enumerate(dl_train, 1):              

                y = torch.autograd.Variable(y)

                def closure():
                    opt.zero_grad()
                    y_hat = model(x)            
                    loss = nn.functional.cross_entropy(y_hat, y)   
                    loss.backward()
                    return loss

                loss = opt.step(closure)

                epoch_loss += float(loss)
                stats['loss_by_batch'].append(float(loss))
                stats['centers'].append(model.dim_0.centers.data.cpu().numpy())

                print("Epoch {}/{}, Batch {}/{}".format(i_epoch, train_env.n_epochs, i_batch, len(dl_train)), end="       \r")

            stats['train_loss_by_epoch'].append(epoch_loss/len(dl_train))

            print("\n\r testing...")
            model.eval()    
            true_samples = 0
            seen_samples = 0
            epoch_test_loss = 0
            for i_batch, (x, y) in enumerate(dl_test):

                y_hat = model(x)
                epoch_test_loss += float(nn.functional.cross_entropy(y_hat, torch.autograd.Variable(y.cuda())).data)

                y_hat = y_hat.max(dim=1)[1].data.long()

                true_samples += (y_hat == y).sum()
                seen_samples += y.size(0)  

            stats['test_accuracy'].append(true_samples/seen_samples)
            stats['test_loss_by_epoch'].append(epoch_test_loss/len(dl_test))
            print(true_samples/seen_samples)       
        
experiment()        


Run 0
Epoch 1/200, Batch 45/45       
 testing...
0.466
Epoch 2/200, Batch 45/45       
 testing...
0.506
Epoch 3/200, Batch 45/45       
 testing...
0.52
Epoch 4/200, Batch 45/45       
 testing...
0.522
Epoch 5/200, Batch 45/45       
 testing...
0.542
Epoch 6/200, Batch 45/45       
 testing...
0.534
Epoch 7/200, Batch 45/45       
 testing...
0.52
Epoch 8/200, Batch 45/45       
 testing...
0.516
Epoch 9/200, Batch 45/45       
 testing...
0.534
Epoch 10/200, Batch 45/45       
 testing...
0.536
Epoch 11/200, Batch 45/45       
 testing...
0.54
Epoch 12/200, Batch 45/45       
 testing...
0.538
Epoch 13/200, Batch 45/45       
 testing...
0.518
Epoch 14/200, Batch 45/45       
 testing...
0.532
Epoch 15/200, Batch 45/45       
 testing...
0.526
Epoch 16/200, Batch 45/45       
 testing...
0.52
Epoch 17/200, Batch 45/45       
 testing...
0.54
Epoch 18/200, Batch 45/45       
 testing...
0.538
Epoch 19/200, Batch 45/45       
 testing...
0.54
Epoch 20/200, Batch 45/45       
 testi

Epoch 121/200, Batch 45/45       
 testing...
0.556
Epoch 122/200, Batch 45/45       
 testing...
0.552
Epoch 123/200, Batch 45/45       
 testing...
0.548
Epoch 124/200, Batch 45/45       
 testing...
0.554
Epoch 125/200, Batch 45/45       
 testing...
0.548
Epoch 126/200, Batch 45/45       
 testing...
0.556
Epoch 127/200, Batch 45/45       
 testing...
0.552
Epoch 128/200, Batch 45/45       
 testing...
0.552
Epoch 129/200, Batch 45/45       
 testing...
0.548
Epoch 130/200, Batch 45/45       
 testing...
0.552
Epoch 131/200, Batch 45/45       
 testing...
0.558
Epoch 132/200, Batch 45/45       
 testing...
0.556
Epoch 133/200, Batch 45/45       
 testing...
0.558
Epoch 134/200, Batch 45/45       
 testing...
0.56
Epoch 135/200, Batch 45/45       
 testing...
0.55
Epoch 136/200, Batch 45/45       
 testing...
0.558
Epoch 137/200, Batch 45/45       
 testing...
0.554
Epoch 138/200, Batch 45/45       
 testing...
0.548
Epoch 139/200, Batch 45/45       
 testing...
0.542
Epoch 140/200,

0.57
Epoch 81/200, Batch 45/45       
 testing...
0.558
Epoch 82/200, Batch 45/45       
 testing...
0.568
Epoch 83/200, Batch 45/45       
 testing...
0.576
Epoch 84/200, Batch 45/45       
 testing...
0.562
Epoch 85/200, Batch 45/45       
 testing...
0.57
Epoch 86/200, Batch 45/45       
 testing...
0.568
Epoch 87/200, Batch 45/45       
 testing...
0.572
Epoch 88/200, Batch 45/45       
 testing...
0.576
Epoch 89/200, Batch 45/45       
 testing...
0.57
Epoch 90/200, Batch 45/45       
 testing...
0.564
Epoch 91/200, Batch 45/45       
 testing...
0.574
Epoch 92/200, Batch 45/45       
 testing...
0.566
Epoch 93/200, Batch 45/45       
 testing...
0.582
Epoch 94/200, Batch 45/45       
 testing...
0.562
Epoch 95/200, Batch 45/45       
 testing...
0.58
Epoch 96/200, Batch 45/45       
 testing...
0.572
Epoch 97/200, Batch 45/45       
 testing...
0.564
Epoch 98/200, Batch 45/45       
 testing...
0.564
Epoch 99/200, Batch 45/45       
 testing...
0.566
Epoch 100/200, Batch 45/45   

Epoch 199/200, Batch 45/45       
 testing...
0.528
Epoch 200/200, Batch 45/45       
 testing...
0.54

Run 4
Epoch 1/200, Batch 45/45       
 testing...
0.472
Epoch 2/200, Batch 45/45       
 testing...
0.512
Epoch 3/200, Batch 45/45       
 testing...
0.492
Epoch 4/200, Batch 45/45       
 testing...
0.526
Epoch 5/200, Batch 45/45       
 testing...
0.51
Epoch 6/200, Batch 45/45       
 testing...
0.51
Epoch 7/200, Batch 45/45       
 testing...
0.506
Epoch 8/200, Batch 45/45       
 testing...
0.5
Epoch 9/200, Batch 45/45       
 testing...
0.524
Epoch 10/200, Batch 45/45       
 testing...
0.512
Epoch 11/200, Batch 45/45       
 testing...
0.544
Epoch 12/200, Batch 45/45       
 testing...
0.502
Epoch 13/200, Batch 45/45       
 testing...
0.55
Epoch 14/200, Batch 45/45       
 testing...
0.52
Epoch 15/200, Batch 45/45       
 testing...
0.508
Epoch 16/200, Batch 45/45       
 testing...
0.52
Epoch 17/200, Batch 45/45       
 testing...
0.524
Epoch 18/200, Batch 45/45       
 testi

Epoch 119/200, Batch 45/45       
 testing...
0.562
Epoch 120/200, Batch 45/45       
 testing...
0.554
Epoch 121/200, Batch 45/45       
 testing...
0.566
Epoch 122/200, Batch 45/45       
 testing...
0.556
Epoch 123/200, Batch 45/45       
 testing...
0.546
Epoch 124/200, Batch 45/45       
 testing...
0.552
Epoch 125/200, Batch 45/45       
 testing...
0.558
Epoch 126/200, Batch 45/45       
 testing...
0.564
Epoch 127/200, Batch 45/45       
 testing...
0.56
Epoch 128/200, Batch 45/45       
 testing...
0.566
Epoch 129/200, Batch 45/45       
 testing...
0.562
Epoch 130/200, Batch 45/45       
 testing...
0.564
Epoch 131/200, Batch 45/45       
 testing...
0.56
Epoch 132/200, Batch 45/45       
 testing...
0.552
Epoch 133/200, Batch 45/45       
 testing...
0.564
Epoch 134/200, Batch 45/45       
 testing...
0.564
Epoch 135/200, Batch 45/45       
 testing...
0.568
Epoch 136/200, Batch 45/45       
 testing...
0.556
Epoch 137/200, Batch 45/45       
 testing...
0.564
Epoch 138/200,

0.538
Epoch 79/200, Batch 45/45       
 testing...
0.536
Epoch 80/200, Batch 45/45       
 testing...
0.538
Epoch 81/200, Batch 45/45       
 testing...
0.522
Epoch 82/200, Batch 45/45       
 testing...
0.526
Epoch 83/200, Batch 45/45       
 testing...
0.528
Epoch 84/200, Batch 45/45       
 testing...
0.54
Epoch 85/200, Batch 45/45       
 testing...
0.532
Epoch 86/200, Batch 45/45       
 testing...
0.528
Epoch 87/200, Batch 45/45       
 testing...
0.536
Epoch 88/200, Batch 45/45       
 testing...
0.526
Epoch 89/200, Batch 45/45       
 testing...
0.53
Epoch 90/200, Batch 45/45       
 testing...
0.546
Epoch 91/200, Batch 45/45       
 testing...
0.534
Epoch 92/200, Batch 45/45       
 testing...
0.542
Epoch 93/200, Batch 45/45       
 testing...
0.53
Epoch 94/200, Batch 45/45       
 testing...
0.534
Epoch 95/200, Batch 45/45       
 testing...
0.528
Epoch 96/200, Batch 45/45              
 testing...
0.542
Epoch 97/200, Batch 45/45       
 testing...
0.53
Epoch 98/200, Batch 45

Epoch 38/200, Batch 45/45       
 testing...
0.572
Epoch 39/200, Batch 45/45       
 testing...
0.552
Epoch 40/200, Batch 45/45       
 testing...
0.56
Epoch 41/200, Batch 45/45       
 testing...
0.56
Epoch 42/200, Batch 45/45       
 testing...
0.548
Epoch 43/200, Batch 45/45       
 testing...
0.564
Epoch 44/200, Batch 45/45       
 testing...
0.562
Epoch 45/200, Batch 45/45       
 testing...
0.56
Epoch 46/200, Batch 45/45       
 testing...
0.576
Epoch 47/200, Batch 45/45       
 testing...
0.558
Epoch 48/200, Batch 45/45       
 testing...
0.564
Epoch 49/200, Batch 45/45       
 testing...
0.566
Epoch 50/200, Batch 45/45       
 testing...
0.564
Epoch 51/200, Batch 45/45       
 testing...
0.56
Epoch 52/200, Batch 45/45       
 testing...
0.562
Epoch 53/200, Batch 45/45       
 testing...
0.556
Epoch 54/200, Batch 45/45       
 testing...
0.554
Epoch 55/200, Batch 45/45       
 testing...
0.57
Epoch 56/200, Batch 45/45       
 testing...
0.558
Epoch 57/200, Batch 45/45       
 te

In [None]:
print(np.mean([np.mean(s['test_accuracy'][-10:]) for s in stats_of_runs]))
[np.mean(s['test_accuracy'][-10:]) for s in stats_of_runs]

In [None]:
# model = Reddit5KModel()
# model.centers_init()
# model.cuda()

# stats = defaultdict(list)
# opt=torch.optim.SGD(model.parameters(), lr=train_env.lr_initial, momentum=0.9)

# for i_epoch in range(1, train_env.n_epochs+1):      
    
#     model.train()
    
#     train_sampler = [i for i in train_i] 
#     random.shuffle(train_sampler)
#     dl_train = DataLoader(dataset,
#                       batch_size=train_env.batch_size, 
#                       collate_fn=reddit_collate,
#                       sampler=train_sampler)

#     dl_test = DataLoader(dataset,
#                          batch_size=train_env.batch_size, 
#                          collate_fn=reddit_collate, 
#                          sampler=test_i)
    
#     epoch_loss = 0
    
#     if i_epoch % train_env.lr_epoch_step == 0:
#         adapt_lr(opt, lambda lr: lr*0.5)
    
#     for i_batch, (x, y) in enumerate(dl_train, 1):              
        
#         y = torch.autograd.Variable(y)
        
#         def closure():
#             opt.zero_grad()
#             y_hat = model(x)            
#             loss = nn.functional.cross_entropy(y_hat, y)   
#             loss.backward()
#             return loss

#         loss = opt.step(closure)
        
#         epoch_loss += float(loss)
#         stats['loss_by_batch'].append(float(loss))
#         stats['centers'].append(model.dim_0.centers.data.cpu().numpy())
        
#         print("Epoch {}/{}, Batch {}/{}".format(i_epoch, train_env.n_epochs, i_batch, len(dl_train)), end="       \r")
  
#     stats['train_loss_by_epoch'].append(epoch_loss/len(dl_train))
        
#     print("\n\r testing...")
#     model.eval()    
#     true_samples = 0
#     seen_samples = 0
#     epoch_test_loss = 0
#     for i_batch, (x, y) in enumerate(dl_test):

#         y_hat = model(x)
#         epoch_test_loss += float(nn.functional.cross_entropy(y_hat, torch.autograd.Variable(y.cuda())).data)

#         y_hat = y_hat.max(dim=1)[1].data.long()

#         true_samples += (y_hat == y).sum()
#         seen_samples += y.size(0)  
     
#     stats['test_accuracy'].append(true_samples/seen_samples)
#     stats['test_loss_by_epoch'].append(epoch_test_loss/len(dl_test))
#     print(true_samples/seen_samples)

In [None]:
plt.figure()
stats = stats_of_runs[4]
c_start = stats['centers'][0]
c_end = stats['centers'][-1]

plt.plot(c_start[:,0], c_start[:, 1], 'bo', label='center initialization')
plt.plot(c_end[:,0], c_end[:, 1], 'ro', label='center learned')

all_centers = numpy.stack(stats['centers'], axis=0)
for i in range(all_centers.shape[1]):
    points = all_centers[:,i, :]
    plt.plot(points[:, 0], points[:, 1], '-k')
    
plt.legend()

plt.figure()
plt.plot(stats['train_loss_by_epoch'], label='train_loss')
plt.plot(stats['test_loss_by_epoch'], label='test_loss')

plt.legend()
plt.show()

In [None]:
print(model.dim_0.exponent)
print(model.dim_0.sharpness)