In [31]:
import glob
import itertools
import os
import os.path as pth
import pickle
import random
import shutil
from functools import lru_cache

import numpy as np
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, SubsetRandomSampler, DataLoader

import config
from chofer_tda_datasets import Mpeg7, Animal, Reddit5kJmlr, Reddit12kJmlr, Reininghaus2014ShrecReal
from chofer_tda_datasets.transforms import Hdf5GroupToDict
from persim import PersImage


class LabeledDataset(Dataset):
    """Minimal in-memory, map-style dataset of ``(sample, label)`` pairs.

    Args:
        data: list of samples (must be a ``list``).
        targets: iterable of labels; each one is converted with ``int``.
            Must contain exactly as many labels as ``data`` has samples.
    """

    def __init__(self, data, targets):
        assert isinstance(data, list)
        self.data = data
        self.targets = list(map(int, targets))
        assert len(data) == len(self.targets)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        sample, label = self.data[idx], self.targets[idx]
        return sample, label
    
    
def compute_persistent_images(dataset, data_keys, spreads=None, pixels=(20, 20)):
    """Render every sample's barcodes as a stack of persistence images.

    For each ``(x, y)`` in ``dataset`` and each key in ``data_keys``, the
    barcode ``x[key]`` is rendered once per entry of ``spreads``. Empty
    barcodes become all-zero images. The images of one sample are stacked
    along a new leading axis in key-major, spread-minor order. This gives
    ``len(data_keys) * len(spreads)`` channels, each ``pixels``-shaped. The
    output of PersImage is assumed to match ``pixels``.

    Args:
        dataset: iterable of ``(x, y)`` pairs that supports ``len()``; ``x``
            must be indexable by every entry of ``data_keys``.
        data_keys: keys of the barcodes to render for each sample.
        spreads: non-empty sequence of ``spread`` values passed to
            ``PersImage``; one image is produced per value. Required in
            practice despite the ``None`` default (kept for compatibility).
        pixels: resolution of each image, passed to ``PersImage`` and used
            for the zero images of empty barcodes.

    Returns:
        dict with ``'data'`` (one nested list per sample), ``'targets'``
        (the labels ``y`` as yielded) and ``'spreads'`` (the argument as given).

    Raises:
        ValueError: if ``spreads`` is ``None`` or empty, or if no images are
            produced for a sample (e.g. empty ``data_keys``).
    """
    if spreads is None or len(spreads) == 0:
        raise ValueError('spreads must be a non-empty sequence of spread values')
    pixels = list(pixels)

    labels = []
    data = []
    for j, (x, y) in enumerate(dataset):
        labels.append(y)
        images = []
        for key in data_keys:
            barcode = x[key]
            for spread in spreads:
                if len(barcode) == 0:
                    # Empty barcode: contribute an all-zero image of the same size.
                    images.append(np.zeros(pixels))
                else:
                    # NOTE(review): a new PersImage per barcode is kept on purpose.
                    # persim's PersImage may remember the birth/death range
                    # (``specs``) of the first diagram it transforms and reuse it
                    # for later calls. Hoisting this out of the loops could
                    # therefore silently change the images. Confirm against the
                    # installed persim version before optimizing.
                    pim = PersImage(pixels=pixels, spread=spread, verbose=False)
                    images.append(pim.transform(barcode))
        data.append(np.stack(images, axis=0).tolist())

        print("Calculating persistence images ... {}/{}".format(j+1, len(dataset)), end='\r')
    print('')

    return {'data': data, 'targets': labels, 'spreads': spreads}


def _reddit_barcode_dataset(ds_class):
    """Instantiate a Reddit barcode dataset with dict samples and labels shifted by -1."""
    barcode_ds = ds_class(config.paths.data_root_dir)
    barcode_ds.data_transforms = [Hdf5GroupToDict()]
    # Shift labels down by one. Presumably the raw labels are 1-based and the
    # classifiers / CrossEntropyLoss need 0-based classes -- TODO confirm.
    barcode_ds.target_transforms = [lambda x: int(x) - 1]
    return barcode_ds


def compute_persistent_images_for_datasets():
    """Compute and cache persistence images for every supported dataset.

    For each dataset, the dict returned by ``compute_persistent_images``
    (keys ``data``, ``targets``, ``spreads``) is pickled to
    ``<config.paths.data_root_dir>/persistent_images/<name>_pers_img.pickle``.
    Datasets whose pickle already exists are skipped, so the function can be
    re-run to resume after an interruption.
    """
    root = pth.join(config.paths.data_root_dir, 'persistent_images')
    spreads = [0.1, 0.5, 1.0]
    # makedirs also creates missing parent directories and tolerates an
    # already existing directory.
    os.makedirs(root, exist_ok=True)

    # Barcode keys dim_0_dir_0, dim_0_dir_2, ..., dim_0_dir_30 (16 keys).
    shape_keys = ["dim_0_dir_{}".format(i) for i in range(0, 32, 2)]

    # (dataset name, lazy dataset loader, barcode keys). A loader is only
    # called when the corresponding pickle is missing.
    # TODO: shrecReal (Reininghaus2014ShrecReal) is not cached yet; its pickle
    # would be 'shrecReal_pers_img.pickle'.
    jobs = [
        ('mpeg7', lambda: Mpeg7(config.paths.data_root_dir), shape_keys),
        ('animal', lambda: Animal(config.paths.data_root_dir), shape_keys),
        ('reddit5k', lambda: _reddit_barcode_dataset(Reddit5kJmlr), ['dim_0']),
        ('reddit12k', lambda: _reddit_barcode_dataset(Reddit12kJmlr), ['dim_0']),
    ]

    for ds_name, load_dataset, data_keys in jobs:
        path = pth.join(root, ds_name + '_pers_img.pickle')
        if pth.isfile(path):
            continue

        pim = compute_persistent_images(load_dataset(), data_keys, spreads=spreads)

        # Dump to a temporary file and rename it into place. An interrupted
        # dump then never leaves a truncated pickle that the isfile() check
        # above would skip on the next run.
        tmp_path = path + '.tmp'
        with open(tmp_path, 'wb') as fid:
            pickle.dump(pim, fid)
        os.replace(tmp_path, path)


def pim_ds_factory(ds_name):
    """Load the cached persistence images of ``ds_name`` as a LabeledDataset.

    Reads ``<config.paths.data_root_dir>/persistent_images/<ds_name>_pers_img.pickle``.
    Each entry of its ``'data'`` list becomes a ``torch.tensor``; the
    ``'targets'`` become the dataset labels.
    """
    pickle_path = pth.join(config.paths.data_root_dir,
                           'persistent_images',
                           ds_name + '_pers_img.pickle')

    # pickle.load can execute arbitrary code: only point this at files
    # produced locally by compute_persistent_images_for_datasets.
    with open(pickle_path, 'br') as fid:
        cached = pickle.load(fid)

    raw_data, targets = cached['data'], cached['targets']
    return LabeledDataset(list(map(torch.tensor, raw_data)), targets)

In [None]:
# Build the persistence-image pickles under <data_root_dir>/persistent_images.
# Datasets whose pickle already exists are skipped, so re-running resumes.
compute_persistent_images_for_datasets()

Calculating persistence images ... 8424/11929

In [36]:
class Mpeg7Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(16*3,64,kernel_size=3,stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU())

        self.clf = nn.Sequential(
            nn.Linear(5184,512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512,256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256,70))

    def forward(self,x):
        x = self.conv(x)
        x = x.view(x.size(0),-1)
        x = self.clf(x)
        return x
    
class AnimalModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(16*3,64,kernel_size=3,stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU())

        self.clf = nn.Sequential(
            nn.Linear(5184,512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512,256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256,20))

    def forward(self,x):
        x = self.conv(x)
        x = x.view(x.size(0),-1)
        x = self.clf(x)
        return x
    
    
class Reddit5kModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(1*3,64,kernel_size=3,stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU())

        self.clf = nn.Sequential(
            nn.Linear(5184,512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512,256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256,5))

    def forward(self,x):
        x = self.conv(x)
        x = x.view(x.size(0),-1)
        x = self.clf(x)
        return x
    
    
class Reddit12kModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(1*3,64,kernel_size=3,stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU())

        self.clf = nn.Sequential(
            nn.Linear(5184,512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512,256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256,11))

    def forward(self,x):
        x = self.conv(x)
        x = x.view(x.size(0),-1)
        x = self.clf(x)
        return x


def model_factory(ds_name):
    """Return a freshly initialised classifier for dataset ``ds_name``.

    Args:
        ds_name: one of 'mpeg7', 'animal', 'reddit5k', 'reddit12k'.

    Raises:
        NotImplementedError: for 'shrecReal', which has no model yet.
            Returning None here would only fail later inside training.
        ValueError: for any other unknown dataset name.
    """
    if ds_name == 'mpeg7':
        return Mpeg7Model()
    if ds_name == 'animal':
        return AnimalModel()
    if ds_name == 'reddit5k':
        return Reddit5kModel()
    if ds_name == 'reddit12k':
        return Reddit12kModel()
    if ds_name == 'shrecReal':
        raise NotImplementedError("no model is defined for dataset 'shrecReal' yet")
    raise ValueError('unknown dataset name: {!r}'.format(ds_name))

In [37]:
def evaluate_model(dl_test, net, device='cuda'):
    """Return the classification accuracy of ``net`` on ``dl_test`` in percent.

    ``net`` is switched to eval mode and left there. Gradient tracking is
    disabled, so no autograd graph is built during evaluation.

    Args:
        dl_test: iterable of ``(x, y)`` batches; ``x`` is cast to float and
            ``y`` holds integer class labels.
        net: module mapping a batch to per-class scores ``(batch, n_classes)``.
        device: device the batches are moved to; must match ``net``'s device.

    Returns:
        float: ``100 * correct / total``.

    Raises:
        ValueError: if ``dl_test`` yields no samples.
    """
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in dl_test:
            x, y = x.float().to(device), y.to(device)
            _, predicted = net(x).max(1)
            correct += predicted.eq(y).sum().item()
            total += y.size(0)

    if total == 0:
        raise ValueError('dl_test yielded no samples; cannot compute accuracy')
    return (correct / total) * 100


def train_and_evaluate(dataset,
                       model,
                       epochs=100,
                       lr_initial=0.0001,
                       shedule_stepsize=50,
                       shedule_gamme=0.1,
                       train_ratio=0.9,
                       batch_size=100,
                       device=None):
    """Train ``model`` on a random split of ``dataset`` and return test accuracy.

    A random ``train_ratio`` fraction of the indices is used for training
    with Adam and a StepLR schedule; the remaining indices are held out for
    evaluation. The indices are drawn with the global ``np.random`` state, so
    seed it for reproducible splits.

    Args:
        dataset: map-style dataset yielding ``(x, y)``.
        model: network to train; moved to ``device``.
        epochs: number of training epochs.
        lr_initial: initial Adam learning rate.
        shedule_stepsize: StepLR ``step_size`` in epochs.
        shedule_gamme: StepLR decay factor ``gamma``.
        train_ratio: fraction of samples used for training.
        batch_size: batch size of the train and test loaders.
        device: torch device; defaults to 'cuda' when available, else 'cpu'.

    Returns:
        float: test accuracy in percent, as computed by ``evaluate_model``.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    n_samples = len(dataset)
    train_i = np.random.choice(n_samples,
                               size=int(n_samples * train_ratio),
                               replace=False)
    # Set membership keeps the complement O(n). `i not in <ndarray>` scans
    # the whole array for every index, which is O(n^2).
    train_set = set(train_i.tolist())
    test_i = [i for i in range(n_samples) if i not in train_set]
    assert len(train_i) + len(test_i) == n_samples

    dl_train = DataLoader(dataset,
                          sampler=SubsetRandomSampler(train_i),
                          batch_size=batch_size)

    dl_test = DataLoader(dataset,
                         sampler=SubsetRandomSampler(test_i),
                         batch_size=batch_size)

    net = model.to(device)

    optim = torch.optim.Adam(net.parameters(), lr=lr_initial)
    scheduler = StepLR(optim, step_size=shedule_stepsize, gamma=shedule_gamme)
    criterion = torch.nn.CrossEntropyLoss()

    net.train()
    for epoch_i in range(1, epochs + 1):
        epoch_loss = 0.0
        for x, y in dl_train:
            x, y = x.float().to(device), y.to(device)

            optim.zero_grad()
            loss = criterion(net(x), y)
            loss.backward()
            optim.step()

            epoch_loss += loss.item()

        # Since PyTorch 1.1 the scheduler must step after the optimizer.
        # Stepping it first skips the initial LR value and moves every decay
        # one epoch earlier.
        scheduler.step()

        mean_loss = epoch_loss / max(len(dl_train), 1)
        print('epoch {}/{} loss {:.4f}'.format(epoch_i, epochs, mean_loss), end='\r')

    acc = evaluate_model(dl_test, net, device=device)
    print('')

    return acc

In [38]:
def experiment(ds_name, n_repititions):
    """Run ``n_repititions`` independent train/evaluate rounds on ``ds_name``.

    The cached persistence-image dataset is loaded once. Every round builds
    a fresh model via ``model_factory``, and ``train_and_evaluate`` draws a
    new random train/test split.

    Returns:
        list of test accuracies (percent), one per round, in run order.
    """
    dataset = pim_ds_factory(ds_name)
    return [train_and_evaluate(dataset, model_factory(ds_name))
            for _ in range(n_repititions)]

In [41]:
# Mean test accuracy (%) over 10 rounds, each with a fresh model and random split.
res = experiment(ds_name='mpeg7', n_repititions=10)
np.mean(res)

epoch 100/100
epoch 100/100
epoch 100/100
epoch 100/100
epoch 100/100
epoch 100/100
epoch 100/100
epoch 100/100
epoch 100/100
epoch 100/100


91.57142857142858

In [42]:
# Mean test accuracy (%) over 10 rounds, each with a fresh model and random split.
res = experiment(ds_name='animal', n_repititions=10)
np.mean(res)

epoch 100/100
epoch 100/100
epoch 100/100
epoch 100/100
epoch 100/100
epoch 100/100
epoch 100/100
epoch 100/100
epoch 100/100
epoch 100/100


65.85

In [43]:
# Mean test accuracy (%) over 10 rounds, each with a fresh model and random split.
res = experiment(ds_name='reddit5k', n_repititions=10)
np.mean(res)

epoch 100/100
epoch 100/100
epoch 100/100
epoch 100/100
epoch 100/100
epoch 100/100
epoch 100/100
epoch 100/100
epoch 100/100
epoch 100/100


48.42

In [44]:
# Mean test accuracy (%) over 10 rounds, each with a fresh model and random split.
res = experiment(ds_name='reddit12k', n_repititions=10)
np.mean(res)

epoch 100/100
epoch 100/100
epoch 100/100
epoch 100/100
epoch 100/100
epoch 100/100
epoch 100/100
epoch 100/100
epoch 100/100
epoch 100/100


38.90192791282482