In [None]:
import torch
torch.manual_seed(123)
import random
random.seed(123)


import torch.nn as nn
import os
import shutil
import itertools


import core.config as config
from chofer_tda_datasets import Mpeg7
from core.utils import *

from torchph.nn.slayer import SLayerExponential, \
SLayerRational, \
LinearRationalStretchedBirthLifeTimeCoordinateTransform, \
prepare_batch, SLayerRationalHat
from sklearn.model_selection import ShuffleSplit
from matplotlib.ticker import FormatStrFormatter
from collections import Counter, defaultdict
from torch.utils.data import DataLoader, Subset
from collections import OrderedDict
from torch.autograd import Variable

from sklearn.model_selection import StratifiedShuffleSplit

%matplotlib notebook

# Restrict CUDA to GPU 0; this works after `import torch` because CUDA is initialized lazily on first use.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'


class train_env:
    """Namespace of hyper-parameters shared by the experiment cells (used as a class, never instantiated)."""

    # coordinate transform parameter (passed to the birth/lifetime transform)
    nu: float = 0.01

    # optimisation schedule
    n_epochs: int = 200
    batch_size: int = 100
    lr_initial: float = 0.01
    momentum: float = 0.9
    lr_epoch_step: int = 40  # the learning rate is halved every this many epochs

    # data splitting; `train_test_splits` is filled in by the setup cell
    train_size: float = 0.9
    train_test_splits = None
    

coordinate_transform = LinearRationalStretchedBirthLifeTimeCoordinateTransform(nu=train_env.nu)

# Even-indexed dimension-0 direction keys: dim_0_dir_0, dim_0_dir_2, ..., dim_0_dir_30 (16 keys).
used_directions = ['dim_0_dir_{}'.format(i) for i in range(0, 32, 2)]

# Dataset location. Set the MPEG7_DATA_ROOT environment variable to point somewhere
# else instead of editing this hard-coded default path.
data_root = os.environ.get('MPEG7_DATA_ROOT', '/scratch1/chofer/jmlr2018_data/')
dataset = Mpeg7(root_dir=data_root)
dataset.data_transforms = [
    # keep only the selected direction entries of each sample dict
    lambda sample: {k: sample[k] for k in used_directions},
    numpy_to_torch_cascade,
    # apply the coordinate transform to every torch.Tensor in the nested sample
    # (the lambda keeps the late lookup of the global `coordinate_transform`)
    lambda sample: collection_cascade(sample,
                                      lambda obj: isinstance(obj, torch.Tensor),
                                      lambda t: coordinate_transform(t)),
]

# 10 stratified random train/test splits (fixed seed); every experiment reuses them
# so the configurations are compared on identical data.
splitter = StratifiedShuffleSplit(n_splits=10,
                                  train_size=train_env.train_size,
                                  test_size=1 - train_env.train_size,
                                  random_state=123)
train_test_splits = [(train_i.tolist(), test_i.tolist())
                     for train_i, test_i in splitter.split(X=dataset.labels, y=dataset.labels)]
train_env.train_test_splits = train_test_splits

output_path = './results/mpeg7__slayer_parameter_learning_impact.pickle'

In [None]:
class PHTCollate:
    """Collate function for a ``DataLoader`` over dict-of-directions samples.

    Concatenates the samples of a batch with ``dict_sample_target_iter_concat``,
    runs ``prepare_batch(..., 2)`` on every direction and converts the targets
    to a ``torch.LongTensor``.

    Args:
        nu: stored as ``self.nu``; not used by this class.
        cuda: if True, move the first two entries of each prepared batch tuple
            and the targets to the GPU.
        rotation_augmentation: if True, every sample gets its direction entries
            cyclically shifted by a random offset over ``used_directions``.
    """

    def __init__(self, nu, cuda=True, rotation_augmentation=False):
        self.nu = nu  # kept for reference only; __call__ does not use it
        self.cuda = cuda
        self.rotation_augmentation = rotation_augmentation

    def _rotate_directions(self, sample_target_iter):
        """Return (sample, target) pairs whose direction entries are cyclically shifted.

        For each sample a random offset ``i`` is drawn from the ``random`` module and
        key ``used_directions[j]`` receives the value stored under
        ``used_directions[(j + i) % len(used_directions)]``.
        """
        samples, targets = [], []
        for x, y in sample_target_iter:
            i = random.randint(0, len(used_directions) - 1)
            shifted_keys = used_directions[i:] + used_directions[:i]
            samples.append({k: x[ki] for k, ki in zip(used_directions, shifted_keys)})
            targets.append(y)
        return zip(samples, targets)

    def __call__(self, sample_target_iter):
        if self.rotation_augmentation:
            sample_target_iter = self._rotate_directions(sample_target_iter)

        x, y = dict_sample_target_iter_concat(sample_target_iter)

        # Only existing keys are reassigned, so mutating while iterating is safe.
        for k in x.keys():
            x[k] = prepare_batch(x[k], 2)

        y = torch.LongTensor(y)

        if self.cuda:
            # Move the first two tuple entries of each prepared batch to the GPU;
            # the remaining two entries are passed through untouched.
            x = {k: collection_cascade(v,
                                       lambda obj: isinstance(obj, tuple),
                                       lambda t: (t[0].cuda(), t[1].cuda(), t[2], t[3]))
                 for k, v in x.items()}
            y = y.cuda()

        return x, y
    
# Shared collate function: moves batches to the GPU, no rotation augmentation.
collate_fn = PHTCollate(nu=train_env.nu, cuda=True)

In [None]:
def Slayer(n_elements):
    """Build the SLayer used per direction: ``SLayerRationalHat`` with radius_init=0.25, exponent=1."""
    layer = SLayerRationalHat(n_elements, radius_init=0.25, exponent=1)
    return layer

def LinearCell(n_in, n_out):
    """Return ``nn.Sequential(Linear(n_in, n_out), BatchNorm1d(n_out), ReLU())``.

    The container additionally carries ``out_features`` (equal to ``n_out``) so
    callers can chain cells without recomputing widths.
    """
    linear = nn.Linear(n_in, n_out)
    cell = nn.Sequential(linear, nn.BatchNorm1d(n_out), nn.ReLU())
    cell.out_features = linear.out_features
    return cell


class MpegModel(nn.Module):
    """One SLayer per direction in ``used_directions`` followed by an MLP classifier.

    The per-direction SLayer outputs are concatenated along dim 1 and fed to
    ``self.cls``.

    Args:
        n_elements: number of structure elements of each SLayer.
        train_slayers: if False, ``parameters()`` omits the SLayer parameters, so an
            optimizer built from ``model.parameters()`` never updates them.
        n_classes: number of output logits (defaults to 70, as before).
    """

    def __init__(self, n_elements: int, train_slayers: bool, n_classes: int = 70):
        super().__init__()
        self.n_elements = n_elements
        self.train_slayers = bool(train_slayers)

        # NOTE(review): `ModuleDict` is unqualified; presumably provided by
        # `from core.utils import *` — TODO confirm (torch also has nn.ModuleDict).
        self.slayers = ModuleDict()
        for k in used_directions:
            self.slayers[k] = nn.Sequential(Slayer(self.n_elements))

        cls_in_dim = len(used_directions) * self.n_elements
        hidden_1 = LinearCell(cls_in_dim, cls_in_dim // 4)
        hidden_2 = LinearCell(hidden_1.out_features, cls_in_dim // 16)
        self.cls = nn.Sequential(
            nn.Dropout(0.3),
            hidden_1,
            nn.Dropout(0.2),
            hidden_2,
            nn.Dropout(0.1),
            # Chain from the previous cell's width. The former hard-coded
            # `self.n_elements` matched only when len(used_directions) == 16.
            nn.Linear(hidden_2.out_features, n_classes))

    def forward(self, input):
        """Encode each direction with its SLayer, concatenate along dim 1, classify."""
        x = torch.cat([self.slayers[k](input[k]) for k in used_directions], dim=1)
        return self.cls(x)

    def _set_centers(self, centers):
        """Assign each ``centers[k]`` tensor to the ``centers`` of SLayer ``k``."""
        for k, v in centers.items():
            self.slayers[k][0].centers.data = v

    def k_means_center_init(self, sample_target_iter):
        """Initialise SLayer centers via the module-level ``k_means_center_init`` helper."""
        self._set_centers(k_means_center_init(sample_target_iter, self.n_elements))

    def min_max_random_init(self, sample_target_iter):
        """Initialise SLayer centers via the module-level ``min_max_random_init`` helper."""
        self._set_centers(min_max_random_init(sample_target_iter, self.n_elements))

    def parameters(self, recurse: bool = True):
        """Iterate over the parameters to optimise.

        Always yields the classifier parameters; the SLayer parameters only when
        ``train_slayers`` is True. ``recurse`` mirrors ``nn.Module.parameters``.
        """
        if not recurse:
            # This module registers no parameters directly on itself, only on submodules.
            return iter(())
        modules = [self.cls]
        if self.train_slayers:
            modules.extend(self.slayers._modules.values())
        return itertools.chain.from_iterable(m.parameters() for m in modules)

In [None]:
def _evaluate(model, dl_test):
    """Return ``(accuracy, mean batch loss)`` of ``model`` on ``dl_test``.

    Runs in eval mode under ``torch.no_grad()``. The accuracy is a plain Python
    float, so pickled statistics never contain (CUDA) tensors.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    loss_sum = 0.0
    with torch.no_grad():
        for x, y in dl_test:
            y_hat = model(x)
            y = y.to(y_hat.device)
            loss_sum += float(nn.functional.cross_entropy(y_hat, y))
            n_correct += int((y_hat.argmax(dim=1) == y).sum())
            n_seen += y.size(0)
    return n_correct / n_seen, loss_sum / len(dl_test)


def experiment(n_elements: int, train_slayers: bool, centers_init: str):
    """Train and evaluate a fresh ``MpegModel`` on every split in ``train_env.train_test_splits``.

    Args:
        n_elements: number of SLayer structure elements per direction.
        train_slayers: whether the SLayer parameters are optimised.
        centers_init: 'k_means' or 'min_max_random'. Any other value keeps the
            SLayers' constructor initialisation.

    Returns:
        One ``defaultdict(list)`` per split with keys 'loss_by_batch', 'centers'
        (snapshots of the 'dim_0_dir_0' SLayer centers after every batch),
        'train_loss_by_epoch', 'test_accuracy' (floats) and 'test_loss_by_epoch'.
    """
    stats_of_runs = []
    n_runs = len(train_env.train_test_splits)

    for run_i, (train_i, test_i) in enumerate(train_env.train_test_splits):
        model = MpegModel(n_elements=n_elements, train_slayers=train_slayers)

        if centers_init == 'k_means':
            model.k_means_center_init([dataset[i] for i in train_i])
        elif centers_init == 'min_max_random':
            model.min_max_random_init([dataset[i] for i in train_i])

        model.cuda()

        stats = defaultdict(list)
        stats_of_runs.append(stats)

        opt = torch.optim.SGD(model.parameters(), lr=train_env.lr_initial, momentum=train_env.momentum)

        # The loaders do not depend on the epoch; shuffle=True reshuffles on every pass.
        dl_train = DataLoader(Subset(dataset, train_i),
                              batch_size=train_env.batch_size,
                              collate_fn=collate_fn,
                              shuffle=True)
        dl_test = DataLoader(Subset(dataset, test_i),
                             batch_size=train_env.batch_size,
                             collate_fn=collate_fn)

        for i_epoch in range(1, train_env.n_epochs + 1):
            model.train()

            # Halve the learning rate every `lr_epoch_step` epochs.
            if i_epoch % train_env.lr_epoch_step == 0:
                adapt_lr(opt, lambda lr: lr * 0.5)

            epoch_loss = 0
            for i_batch, (x, y) in enumerate(dl_train, 1):
                opt.zero_grad()
                loss = nn.functional.cross_entropy(model(x), y)
                loss.backward()
                opt.step()

                epoch_loss += float(loss)
                stats['loss_by_batch'].append(float(loss))
                # Copy so a snapshot never aliases the live parameter (e.g. on CPU).
                stats['centers'].append(
                    model.slayers['dim_0_dir_0'][0].centers.detach().cpu().numpy().copy())

                print("Run, {}/{}, Epoch {}/{}, Batch {}/{}".format(
                    run_i + 1, n_runs, i_epoch, train_env.n_epochs, i_batch, len(dl_train)),
                    end="       \r")

            stats['train_loss_by_epoch'].append(epoch_loss / len(dl_train))

            test_accuracy, test_loss = _evaluate(model, dl_test)
            stats['test_accuracy'].append(test_accuracy)
            stats['test_loss_by_epoch'].append(test_loss)

        print('\r' + ' ' * 100, end='\r')

    return stats_of_runs

In [None]:
n = [25, 50, 75, 100, 125, 150]

# Create the results directory up front, so a missing directory fails immediately
# instead of after the first (long) experiment.
results_dir = os.path.dirname(output_path)
if results_dir:
    os.makedirs(results_dir, exist_ok=True)

# Resume from previously pickled results, so finished configurations are skipped.
# (pickle.load executes code from the file — only load locally produced results.)
if os.path.isfile(output_path):
    with open(output_path, 'br') as f:
        overall_result = pickle.load(f)
else:
    overall_result = OrderedDict()

for n_elements, train_slayers, centers_init in itertools.product(n, [True, False], ['k_means', 'min_max_random']):
    print(n_elements, train_slayers, centers_init)
    key = (n_elements, train_slayers, centers_init)

    if key in overall_result:
        stats_of_runs = overall_result[key]
    else:
        stats_of_runs = experiment(n_elements=n_elements,
                                   train_slayers=train_slayers,
                                   centers_init=centers_init)
        overall_result[key] = stats_of_runs

        # Write to a temporary file and rename it over the cache, so an interrupted
        # dump cannot leave a truncated, unloadable results file behind.
        tmp_path = output_path + '.tmp'
        with open(tmp_path, 'bw') as f:
            pickle.dump(overall_result, f)
        os.replace(tmp_path, output_path)

    # mean over the last 10 epochs of each split, then mean/std across splits
    accs = [np.mean(s['test_accuracy'][-10:]) for s in stats_of_runs]
    print('->', np.mean(accs), '+-', np.std(accs))
    
    
    
    

In [None]:
# Reload the pickled results written by the experiment loop.
with open(output_path, mode='rb') as results_file:
    overall_result = pickle.load(results_file)

In [None]:
def get_label_from_key(key):
    """Build a legend label such as 'Learned (k-means)' from a ``(learned, init)`` pair.

    ``learned`` is treated as a truth value ('Learned' vs. 'Frozen'). Any ``init``
    other than 'k_means' is labelled 'random'.
    """
    learned, init = key
    status = 'Learned' if learned else 'Frozen'
    init_name = 'k-means' if init == 'k_means' else 'random'
    return '{} ({})'.format(status, init_name)

n = [25, 50, 75, 100, 125, 150]

# Per (trained, init) curve, collect (n_elements, value) pairs of the mean and std
# across splits of the per-split accuracy (averaged over the last 10 epochs).
accuracies = defaultdict(list)
stddevs = defaultdict(list)

for (n_elements, trained, init), stats_of_runs in overall_result.items():
    run_accuracies = [np.mean(s['test_accuracy'][-10:]) for s in stats_of_runs]
    curve_key = (trained, init)
    accuracies[curve_key].append((n_elements, np.mean(run_accuracies)))
    stddevs[curve_key].append((n_elements, np.std(run_accuracies)))
    
def clean(d):
    """Map each key to its values ordered by the first element of each (x, value) pair.

    The sort is stable, so pairs with equal x keep their original relative order.
    Returns a new plain dict; ``d`` is not modified.
    """
    return {key: [value for _, value in sorted(pairs, key=lambda pair: pair[0])]
            for key, pairs in d.items()}

# Remember each curve's x-values (n_elements) before `clean` strips them, so every
# curve is plotted against the N values it actually has results for.
n_by_curve = {k: sorted(n_el for n_el, _ in v) for k, v in accuracies.items()}

accuracies = clean(accuracies)
stddevs = clean(stddevs)

# Set to True to shade mean +/- one standard deviation around each curve.
show_std_band = False

plt.figure(figsize=(4, 3))

plt.style.use('ggplot')


fs = 11
plt.xlabel('N', fontsize=fs)
plt.ylabel('Avg. accuracy [%]', fontsize=fs)

plt.gca().tick_params(labelsize=fs)

colors = ['#375E97', '#FB6542']
for k, acc_i in accuracies.items():
    std_i = stddevs[k]
    n_i = n_by_curve[k]
    acc_i = [x*100 for x in acc_i]

    # solid line = learned SLayers, dashed = frozen; colors[0] = k-means init, colors[1] = other
    line_style = '-' if k[0] else '--'
    c = colors[0] if k[1] == 'k_means' else colors[1]

    plt.plot(n_i, acc_i, label=get_label_from_key(k), c=c, linestyle=line_style)
    if show_std_band:
        upper_bound = [x + 100*y for x, y in zip(acc_i, std_i)]
        lower_bound = [x - 100*y for x, y in zip(acc_i, std_i)]
        plt.fill_between(n_i, lower_bound, upper_bound, color=c, alpha=0.15)


plt.legend(fontsize=fs)

plt.savefig('/tmp/mpeg7_impact_of_slayer_trainig.pdf', bbox_inches='tight', pad_inches=0)

In [None]:
# Statistics of the first split for N=50, learned SLayers, 'min_max_random' init.
stats = overall_result[(50, True, 'min_max_random')][0]


if 'centers' in stats:

    plt.figure(figsize=(4, 3))
    plt.gca().tick_params(labelsize=fs)

    c_start = stats['centers'][0]
    c_end = stats['centers'][-1]

    # Raw strings: '\m' is an invalid escape sequence in a regular string literal.
    plt.plot(c_start[:, 0], c_start[:, 1], '*', label=r'$\mu$ (initialization)', c=colors[0])
    plt.plot(c_end[:, 0], c_end[:, 1], 'o', label=r'$\mu$ (learned)', c=colors[1])
    plt.ylim(-0.1, 1.5)

    # blank labels, presumably to keep the layout aligned with the other figures
    plt.xlabel(' ', fontsize=fs)
    plt.ylabel(' ', fontsize=fs)

    # all_centers[s, i, :] is the i-th center in snapshot s; draw each center's trajectory
    all_centers = np.stack(stats['centers'], axis=0)
    for i in range(all_centers.shape[1]):
        points = all_centers[:, i, :]
        plt.plot(points[:, 0], points[:, 1], '-k', alpha=0.40)

    plt.legend(loc=1, fontsize=fs)

    # Save/show inside the branch: otherwise, with no new figure, savefig would write
    # the previously active figure to the centers file.
    plt.savefig('/tmp/centers_learning_progress.pdf', bbox_inches='tight', pad_inches=0)
    plt.show()

plt.figure(figsize=(4, 3))
plt.gca().tick_params(labelsize=fs)

plt.gca().yaxis.set_major_formatter(FormatStrFormatter('%.2f'))

plt.xlabel('Epochs', fontsize=fs)
plt.ylabel('Loss', fontsize=fs)
plt.plot(stats['train_loss_by_epoch'], label='Train', c=colors[0])
plt.plot(stats['test_loss_by_epoch'], label='Test', c=colors[1])

plt.legend(fontsize=fs)
plt.savefig('/tmp/loss_learning_progress.pdf', bbox_inches='tight', pad_inches=0)
plt.show()


In [None]:
# Single configuration: N=25, learned SLayers, centers_init='min_max_random'.
result = experiment(n_elements=25, train_slayers=True, centers_init='min_max_random')
clear_output()
accuracies = [np.mean(run_stats['test_accuracy'][-10:]) for run_stats in result]
print(accuracies, '->', np.mean(accuracies), '+/-', np.std(accuracies))

In [None]:
# Single configuration: N=25, frozen SLayers, centers_init='min_max_random'.
# (keyword must be `n_elements` to match experiment's signature)
result = experiment(n_elements=25, 
                    train_slayers=False, 
                    centers_init='min_max_random')
clear_output()
accuracies = [np.mean(s['test_accuracy'][-10:]) for s in result]
print(accuracies, '->', np.mean(accuracies), '+/-', np.std(accuracies))

In [None]:
# Single configuration: N=25, learned SLayers, centers_init='k_means'.
# (keyword must be `n_elements` to match experiment's signature)
result = experiment(n_elements=25, 
                    train_slayers=True, 
                    centers_init='k_means')
clear_output()
accuracies = [np.mean(s['test_accuracy'][-10:]) for s in result]
print(accuracies, '->', np.mean(accuracies), '+/-', np.std(accuracies))

In [None]:
# Single configuration: N=25, frozen SLayers, centers_init='k_means'.
# (keyword must be `n_elements` to match experiment's signature)
result = experiment(n_elements=25, 
                    train_slayers=False, 
                    centers_init='k_means')
clear_output()
accuracies = [np.mean(s['test_accuracy'][-10:]) for s in result]
print(accuracies, '->', np.mean(accuracies), '+/-', np.std(accuracies))

In [None]:
# NOTE(review): `stats_of_runs` is whatever the most recently executed loop left
# behind (results loop or aggregation loop) — hidden notebook state.
per_run_accuracy = [np.mean(s['test_accuracy'][-10:]) for s in stats_of_runs]
print(np.mean(per_run_accuracy))
per_run_accuracy