In [1]:
import torch
import random
import config
import glob
import torch.nn as nn
import os
import os.path as pth
import shutil
import itertools
import h5py

from functools import lru_cache

import torch
import torch.nn as nn
from torch.utils.data import Dataset, SubsetRandomSampler, DataLoader
from torch.optim.lr_scheduler import StepLR

from chofer_tda_datasets import Mpeg7, Animal, Reddit5kJmlr, Reddit12kJmlr, Reininghaus2014ShrecReal, SciNe01EEGBottomTopFiltration
from chofer_tda_datasets.transforms import Hdf5GroupToDict, Hdf5GroupToDictSelector
from chofer_tda_datasets.utils.h5py_dataset import Hdf5SupervisedDatasetOneFile

from jmlr_2018_code.utils import *
from persim import PersImage


class LabeledDataset(Dataset):
    """In-memory dataset that pairs each sample with an integer label.

    Args:
        data: the samples; must be a ``list`` (asserted).
        targets: one label per sample; every label is converted with ``int``.
            Its length must equal ``len(data)`` (asserted).

    Indexing returns the tuple ``(data[idx], targets[idx])``.
    """

    def __init__(self, data, targets):
        assert isinstance(data, list)
        self.data = data
        self.targets = list(map(int, targets))

        assert len(self.targets) == len(data)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.targets[idx]
        return sample, label
    
    
def _persistence_image_stack(x, data_keys, spreads, pixels):
    """Return the stacked persistence images of a single sample as an ndarray.

    One image is produced per ``(key, spread)`` pair, ordered key-major: all
    spreads of ``data_keys[0]`` first, then those of ``data_keys[1]``, and so
    on. Assuming ``PersImage.transform`` returns a ``pixels``-shaped image
    (the blank-image branch relies on the same assumption), the result has
    shape ``(len(data_keys) * len(spreads), *pixels)``.

    Args:
        x: mapping from data key to a barcode (persistence diagram).
        data_keys: keys of ``x`` to convert, in output order.
        spreads: kernel spreads passed to ``PersImage``; a ``None`` entry
            leaves ``PersImage``'s own default spread in place.
        pixels: image resolution passed to ``PersImage``.

    Raises:
        ValueError: if ``data_keys`` or ``spreads`` is empty (nothing to stack).
    """
    images = []
    for key in data_keys:
        barcode = x[key]
        for spread in spreads:
            if len(barcode) == 0:
                # An empty barcode has no points to smooth: use a blank image.
                images.append(np.zeros(pixels))
                continue
            pim_kwargs = {'pixels': pixels, 'verbose': False}
            if spread is not None:
                pim_kwargs['spread'] = spread
            # A fresh PersImage per barcode, on purpose. NOTE(review): persim's
            # PersImage appears to cache the birth/death range on its first
            # transform() call, so sharing one instance across barcodes could
            # silently change the image grid -- confirm before hoisting this.
            images.append(PersImage(**pim_kwargs).transform(barcode))
    if not images:
        raise ValueError('data_keys and spreads must both be non-empty.')
    return np.stack(images, axis=0)


def compute_persistent_images(dataset, data_keys, spreads=None, pixels=(20, 20)):
    """Convert every sample of ``dataset`` into a stack of persistence images.

    Args:
        dataset: iterable of ``(x, y)`` pairs where ``x[k]`` is a barcode for
            every ``k`` in ``data_keys``; ``len(dataset)`` is used only for
            the progress line.
        data_keys: keys of each ``x`` to turn into images, in output order.
        spreads: list of ``PersImage`` spreads; one image is made per key and
            spread. ``None`` (default) means one image per key using
            ``PersImage``'s own default spread.
        pixels: resolution of each image (default 20x20).

    Returns:
        dict with
        ``'data'``: one nested list per sample, obtained from the
        ``(len(data_keys) * len(spreads), *pixels)`` image stack via ``tolist``;
        ``'targets'``: the labels ``y`` in sample order;
        ``'spreads'``: the spreads actually used.
    """
    spreads = [None] if spreads is None else spreads
    labels = []
    data = []
    for j, (x, y) in enumerate(dataset):
        labels.append(y)
        # Nested Python lists keep the pickled cache independent of numpy.
        data.append(_persistence_image_stack(x, data_keys, spreads, pixels).tolist())
        print("Calculating persistence images ... {}/{}".format(j + 1, len(dataset)), end='\r')
    print('')

    return {'data': data, 'targets': labels, 'spreads': spreads}


def compute_persistent_images_h5file(file_path, dataset, data_keys, spreads=None, pixels=(20, 20)):
    """Like ``compute_persistent_images``, but stream the images into an HDF5 file.

    Each sample is written as soon as it is computed, so image stacks are not
    accumulated in memory (only the labels are). ``file_path`` is opened with
    mode ``'w'`` (overwritten if it exists) and receives:

    * ``data/<j>``: float32 image stack of sample ``j``;
    * ``target``: integer dataset with all labels in sample order.

    Args:
        file_path: output HDF5 file.
        dataset, data_keys, spreads, pixels: as in ``compute_persistent_images``.
    """
    spreads = [None] if spreads is None else spreads
    labels = []
    with h5py.File(file_path, 'w') as h5file:
        grp_data = h5file.create_group('data')

        for j, (x, y) in enumerate(dataset):
            labels.append(y)
            tens = _persistence_image_stack(x, data_keys, spreads, pixels)
            grp_data.create_dataset(str(j), data=tens.astype(np.float32))
            print("Calculating persistence images ... {}/{}".format(j + 1, len(dataset)), end='\r')

        h5file.create_dataset('target', dtype=int, data=labels)
        print('')


def compute_persistent_images_for_datasets():
    """Compute and cache persistence images for every benchmark dataset.

    Outputs go to ``<config.paths.data_root_dir>/persistent_images``. A dataset
    is skipped when its output file already exists, so re-running is cheap.
    Each output is first written to ``<path>.tmp`` and then moved into place
    with ``os.replace``. An interrupted run therefore never leaves a
    truncated file that a later run would mistake for a finished cache.

    NOTE(review): ``pickle`` is not imported in the visible import cell;
    presumably it is supplied by ``from jmlr_2018_code.utils import *``.
    """
    data_root = config.paths.data_root_dir
    root = pth.join(data_root, 'persistent_images')
    spreads = [0.1, 0.5, 1.0]
    os.makedirs(root, exist_ok=True)

    def cache(file_name, make_dataset, data_keys=None, as_h5=False):
        """Build the dataset and write its images unless ``file_name`` exists.

        ``data_keys=None`` uses every key of the first (transformed) sample.
        ``as_h5`` streams to HDF5 instead of pickling the whole result.
        """
        path = pth.join(root, file_name)
        if pth.isfile(path):
            return
        barcode_ds = make_dataset()
        keys = list(barcode_ds[0][0].keys()) if data_keys is None else data_keys
        tmp_path = path + '.tmp'
        if as_h5:
            compute_persistent_images_h5file(tmp_path, barcode_ds, keys, spreads=spreads)
        else:
            pim = compute_persistent_images(barcode_ds, keys, spreads=spreads)
            with open(tmp_path, 'bw') as fid:
                pickle.dump(pim, fid)
        os.replace(tmp_path, path)

    # Contour datasets share the keys dim_0_dir_0, dim_0_dir_2, ..., dim_0_dir_30.
    contour_keys = ["dim_0_dir_{}".format(i) for i in range(0, 32, 2)]
    cache('mpeg7_pers_img.pickle', lambda: Mpeg7(data_root), contour_keys)
    cache('animal_pers_img.pickle', lambda: Animal(data_root), contour_keys)

    def reddit(ds_cls):
        ds = ds_cls(data_root)
        ds.data_transforms = [Hdf5GroupToDict()]
        # Shift labels down by one (presumably stored 1-based in the source data).
        ds.target_transforms = [lambda x: int(x) - 1]
        return ds

    cache('reddit5k_pers_img.pickle', lambda: reddit(Reddit5kJmlr), ['dim_0'])
    cache('reddit12k_pers_img.pickle', lambda: reddit(Reddit12kJmlr), ['dim_0'])

    def shrec_real():
        ds = Reininghaus2014ShrecReal(data_root)
        # Keep only the '0' entry of every group.
        ds.data_transforms = [Hdf5GroupToDict(),
                              lambda x: {k: v['0'] for k, v in x.items()}]
        ds.target_transforms = [lambda x: int(x)]
        return ds

    cache('shrecReal_pers_img.pickle', shrec_real)

    def scine_eeg():
        ds = SciNe01EEGBottomTopFiltration(data_root_folder_path=data_root)
        sensor_indices = [str(i) for i in ds.sensor_configurations['low_resolution_whole_head']]
        selector = Hdf5GroupToDictSelector({'top': sensor_indices, 'bottom': sensor_indices})
        # Flatten {'top': {s: barcode}, 'bottom': {...}} into {'top_s': barcode, ...}.
        ds.data_transforms = [
            selector,
            lambda x: {k + '_' + kk: vv for k, v in x.items() for kk, vv in v.items()},
        ]
        ds.target_transforms = [lambda x: int(x)]
        return ds

    # Streamed to HDF5 rather than pickled (see compute_persistent_images_h5file).
    cache('ScineEEG_pers_img.h5', scine_eeg, as_h5=True)

        
class ScineEEGPersImg(Hdf5SupervisedDatasetOneFile):
    """SciNe01 EEG persistence images, read from a single HDF5 file.

    The file is the one written by ``compute_persistent_images_for_datasets``
    (a ``data`` group with one entry per sample plus a ``target`` dataset).
    NOTE(review): assumes the base class locates the file through
    ``file_name`` inside ``data_root_folder_path`` -- confirm against
    ``chofer_tda_datasets``.
    """

    file_name: str = 'ScineEEG_pers_img.h5'
    
        
def pim_ds_factory(ds_name):
    """Load the cached persistence-image dataset called ``ds_name``.

    Args:
        ds_name: one of ``'mpeg7'``, ``'animal'``, ``'reddit5k'``,
            ``'reddit12k'``, ``'shrecReal'`` (pickled caches) or
            ``'scine_eeg'`` (HDF5 cache).

    Returns:
        For the pickled datasets, a ``LabeledDataset`` whose samples are
        ``torch.tensor`` objects built from the cached nested lists. For
        ``'scine_eeg'``, a ``ScineEEGPersImg`` whose samples are converted
        with ``torch.tensor``.

    Raises:
        ValueError: for an unknown ``ds_name``.
    """
    pickled_names = ('mpeg7', 'animal', 'reddit5k', 'reddit12k', 'shrecReal')
    if ds_name not in pickled_names and ds_name != 'scine_eeg':
        raise ValueError('Unknown dataset name {!r}; expected one of {}.'.format(
            ds_name, list(pickled_names) + ['scine_eeg']))

    path = pth.join(config.paths.data_root_dir, 'persistent_images')

    if ds_name == 'scine_eeg':
        ds = ScineEEGPersImg(data_root_folder_path=path)
        ds.data_transforms = [lambda x: torch.tensor(x)]
        return ds

    # NOTE: pickle.load can execute arbitrary code; only point this at cache
    # files written by compute_persistent_images_for_datasets.
    with open(pth.join(path, ds_name + '_pers_img.pickle'), 'br') as fid:
        cached = pickle.load(fid)

    data = [torch.tensor(x) for x in cached['data']]
    return LabeledDataset(data, cached['targets'])

In [2]:
# Build the persistence-image caches under <data_root_dir>/persistent_images;
# datasets whose cache file already exists are skipped.
compute_persistent_images_for_datasets()

In [7]:
class _PersImgConvNet(nn.Module):
    """Shared classifier architecture for stacked persistence images.

    The network has two parts:

    * a 3x3 convolution with stride 2 and no padding, followed by BatchNorm and ReLU;
    * an MLP head with hidden sizes 512 and 256 (each with BatchNorm and ReLU),
      ending in ``n_classes`` logits.

    The sub-modules are stored as ``self.conv`` and ``self.clf`` so that
    state_dict keys stay ``conv.*`` / ``clf.*``.

    Args:
        in_channels: number of stacked images per sample.
        conv_channels: output channels of the convolution.
        n_classes: number of output logits.
        image_size: side length of the square input images. It determines the
            flattened feature size fed to the MLP; 20 gives 9x9 feature maps.
    """

    def __init__(self, in_channels, conv_channels, n_classes, image_size=20):
        super().__init__()
        conv_out = (image_size - 3) // 2 + 1  # kernel 3, stride 2, no padding

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, conv_channels, kernel_size=3, stride=2),
            nn.BatchNorm2d(conv_channels),
            nn.ReLU())

        self.clf = nn.Sequential(
            nn.Linear(conv_channels * conv_out * conv_out, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, n_classes))

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        return self.clf(x)


class ShrecRealModel(_PersImgConvNet):
    """Shrec (real) classifier: 30 input images, 90 conv channels, 40 classes."""

    def __init__(self):
        super().__init__(in_channels=30, conv_channels=3 * 30, n_classes=40)


class Mpeg7Model(_PersImgConvNet):
    """MPEG-7 classifier: 16 data keys x 3 spreads = 48 input images, 70 classes."""

    def __init__(self):
        super().__init__(in_channels=16 * 3, conv_channels=64, n_classes=70)


class AnimalModel(_PersImgConvNet):
    """Animal classifier: 16 data keys x 3 spreads = 48 input images, 20 classes."""

    def __init__(self):
        super().__init__(in_channels=16 * 3, conv_channels=64, n_classes=20)


class Reddit5kModel(_PersImgConvNet):
    """Reddit-5K classifier: 1 data key x 3 spreads = 3 input images, 5 classes."""

    def __init__(self):
        super().__init__(in_channels=1 * 3, conv_channels=64, n_classes=5)


class Reddit12kModel(_PersImgConvNet):
    """Reddit-12K classifier: 1 data key x 3 spreads = 3 input images, 11 classes."""

    def __init__(self):
        super().__init__(in_channels=1 * 3, conv_channels=64, n_classes=11)


class ScineEEGModel(_PersImgConvNet):
    """SciNe EEG classifier: 40 data keys x 3 spreads = 120 input images, 7 classes."""

    def __init__(self):
        super().__init__(in_channels=40 * 3, conv_channels=64, n_classes=7)


def model_factory(ds_name):
    """Return a freshly initialised classifier for dataset ``ds_name``.

    Raises:
        ValueError: for an unknown ``ds_name``.
    """
    model_classes = {
        'mpeg7': Mpeg7Model,
        'animal': AnimalModel,
        'reddit5k': Reddit5kModel,
        'reddit12k': Reddit12kModel,
        'shrecReal': ShrecRealModel,
        'scine_eeg': ScineEEGModel,
    }
    try:
        model_cls = model_classes[ds_name]
    except KeyError:
        raise ValueError('Unknown dataset name {!r}; expected one of {}.'.format(
            ds_name, sorted(model_classes))) from None
    return model_cls()

In [4]:
def evaluate_model(dl_test, net, device='cuda'):
    """Return the classification accuracy of ``net`` on ``dl_test``, in percent.

    Puts ``net`` in eval mode (and leaves it there). The loop runs under
    ``torch.no_grad()``, so no autograd graph is built for the test batches.

    Args:
        dl_test: iterable of ``(x, y)`` batches; ``x`` is cast to float.
        net: model whose output for a batch is scored with ``max(1)`` against ``y``.
        device: device the batches are moved to.

    Raises:
        ValueError: if ``dl_test`` yields no samples.
    """
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in dl_test:
            x, y = x.float().to(device), y.to(device)
            _, predicted = net(x).max(1)
            correct += predicted.eq(y).sum().item()
            total += y.size(0)

    if total == 0:
        raise ValueError('evaluate_model received an empty test loader.')
    return (correct / total) * 100


def train_and_evaluate(dataset,
                       model,
                       epochs=100,
                       lr_initial=0.01,
                       shedule_stepsize=50,
                       shedule_gamme=0.1,
                       n_processes_collate=None,
                       device='cuda',
                       batch_size=100,
                       train_ratio=0.9):
    """Train ``model`` on a random split of ``dataset`` and return its test accuracy.

    Every call draws a fresh random split with NumPy's global RNG: a fraction
    ``train_ratio`` of the indices for training, the rest for testing.
    ``model`` is trained in place with Adam and cross-entropy. The learning
    rate starts at ``lr_initial`` and is multiplied by ``shedule_gamme``
    after every ``shedule_stepsize`` completed epochs.

    Args:
        dataset: indexable dataset of ``(x, y)`` pairs.
        model: ``nn.Module`` to train; it is moved to ``device``.
        epochs: number of passes over the training split.
        lr_initial: initial Adam learning rate.
        shedule_stepsize: StepLR step size in epochs. The misspelled name is
            kept for backward compatibility.
        shedule_gamme: StepLR decay factor. The misspelled name is kept for
            backward compatibility.
        n_processes_collate: DataLoader ``num_workers`` (``None`` means 0).
        device: training / evaluation device.
        batch_size: batch size for both loaders.
        train_ratio: fraction of samples used for training.

    Returns:
        Test accuracy in percent (see ``evaluate_model``).
    """
    n_samples = len(dataset)
    train_i = np.random.choice(n_samples,
                               size=int(n_samples * train_ratio),
                               replace=False)
    # Set membership keeps the complement O(n) instead of scanning train_i per index.
    train_set = set(train_i.tolist())
    test_i = [i for i in range(n_samples) if i not in train_set]
    assert len(train_i) + len(test_i) == n_samples

    num_workers = 0 if n_processes_collate is None else n_processes_collate
    dl_train = DataLoader(dataset,
                          sampler=SubsetRandomSampler(train_i),
                          batch_size=batch_size,
                          num_workers=num_workers)
    dl_test = DataLoader(dataset,
                         sampler=SubsetRandomSampler(test_i),
                         batch_size=batch_size,
                         num_workers=num_workers)

    net = model.to(device)
    n_params = sum(p.numel() for p in net.parameters())
    print('#Params: ', n_params)

    optim = torch.optim.Adam(net.parameters(), lr=lr_initial)
    scheduler = StepLR(optim, step_size=shedule_stepsize, gamma=shedule_gamme)
    criterion = torch.nn.CrossEntropyLoss()

    net.train()
    for epoch_i in range(1, epochs + 1):
        epoch_loss = 0.0
        for x, y in dl_train:
            x, y = x.float().to(device), y.to(device)

            optim.zero_grad()
            loss = criterion(net(x), y)
            loss.backward()
            optim.step()

            epoch_loss += loss.item()

        # Step the schedule once per epoch, after the optimizer updates
        # (required order since PyTorch 1.1). Stepping first would decay
        # the LR one epoch early.
        scheduler.step()
        print('epoch {}/{} loss {:.4f}'.format(
            epoch_i, epochs, epoch_loss / max(len(dl_train), 1)), end='\r')

    acc = evaluate_model(dl_test, net, device=device)
    print('')

    return acc

In [5]:
def experiment(ds_name,
               n_repititions,
               n_processes_collate=None):
    """Repeatedly train a fresh model on ``ds_name`` and collect test accuracies.

    The dataset is loaded once with ``pim_ds_factory``. Each repetition then
    builds a new model with ``model_factory`` and hands it to
    ``train_and_evaluate``, which draws its own random train/test split. The
    accuracy of every run is printed as ``Run <i>: <acc>``.

    Returns:
        list with one test accuracy (percent) per repetition, in run order.
    """
    dataset = pim_ds_factory(ds_name)

    def _single_run(run_idx):
        acc = train_and_evaluate(dataset,
                                 model_factory(ds_name),
                                 n_processes_collate=n_processes_collate)
        print('Run {}: {}'.format(run_idx, acc))
        return acc

    return [_single_run(run_idx) for run_idx in range(n_repititions)]

In [78]:
# Shrec (real): mean test accuracy (%) over 10 runs, each with a freshly
# initialised model and a new random 90/10 train/test split (trained with lr=0.01).
res = experiment('shrecReal', 10)
np.mean(res)

#Params:  3900706
epoch 100/100
#Params:  3900706
epoch 100/100
#Params:  3900706
epoch 100/100
#Params:  3900706
epoch 100/100
#Params:  3900706
epoch 100/100
#Params:  3900706
epoch 100/100
#Params:  3900706
epoch 100/100
#Params:  3900706
epoch 100/100
#Params:  3900706
epoch 100/100
#Params:  3900706
epoch 100/100


69.0

In [81]:
# MPEG-7: mean test accuracy (%) over 10 runs with fresh models and splits.
res = experiment('mpeg7', 10)
np.mean(res)

#Params:  2833414
epoch 100/100
#Params:  2833414
epoch 100/100
#Params:  2833414
epoch 100/100
#Params:  2833414
epoch 100/100
#Params:  2833414
epoch 100/100
#Params:  2833414
epoch 100/100
#Params:  2833414
epoch 100/100
#Params:  2833414
epoch 100/100
#Params:  2833414
epoch 100/100
#Params:  2833414
epoch 100/100


92.35714285714286

In [84]:
# Animal: mean test accuracy (%) over 10 runs with fresh models and splits.
res = experiment('animal', 10)
np.mean(res)

#Params:  2820564
epoch 100/100
Run 0: 75.5
#Params:  2820564
epoch 100/100
Run 1: 66.5
#Params:  2820564
epoch 100/100
Run 2: 68.0
#Params:  2820564
epoch 100/100
Run 3: 72.0
#Params:  2820564
epoch 100/100
Run 4: 66.0
#Params:  2820564
epoch 100/100
Run 5: 74.0
#Params:  2820564
epoch 100/100
Run 6: 68.0
#Params:  2820564
epoch 100/100
Run 7: 65.0
#Params:  2820564
epoch 100/100
Run 8: 68.5
#Params:  2820564
epoch 100/100
Run 9: 70.0


69.35

In [85]:
# Reddit-5K: mean test accuracy (%) over 10 runs with fresh models and splits.
res = experiment('reddit5k', 10)
np.mean(res)

#Params:  2790789
epoch 100/100
Run 0: 45.0
#Params:  2790789
epoch 100/100
Run 1: 49.0
#Params:  2790789
epoch 100/100
Run 2: 45.0
#Params:  2790789
epoch 100/100
Run 3: 49.4
#Params:  2790789
epoch 100/100
Run 4: 52.400000000000006
#Params:  2790789
epoch 100/100
Run 5: 45.6
#Params:  2790789
epoch 100/100
Run 6: 45.0
#Params:  2790789
epoch 100/100
Run 7: 45.0
#Params:  2790789
epoch 100/100
Run 8: 44.6
#Params:  2790789
epoch 100/100
Run 9: 45.800000000000004


46.68

In [86]:
# Reddit-12K: mean test accuracy (%) over 10 runs with fresh models and splits.
res = experiment('reddit12k', 10)
np.mean(res)

#Params:  2792331
epoch 100/100
Run 0: 35.5406538139145
#Params:  2792331
epoch 100/100
Run 1: 38.055322715842415
#Params:  2792331
epoch 100/100
Run 2: 30.259849119865883
#Params:  2792331
epoch 100/100
Run 3: 37.63621123218776
#Params:  2792331
epoch 100/100
Run 4: 34.36714165968148
#Params:  2792331
epoch 100/100
Run 5: 32.35540653813915
#Params:  2792331
epoch 100/100
Run 6: 38.558256496228
#Params:  2792331
epoch 100/100
Run 7: 35.20536462699078
#Params:  2792331
epoch 100/100
Run 8: 34.702430846605196
#Params:  2792331
epoch 100/100
Run 9: 34.03185247275775


35.07124895222129

In [8]:
# SciNe EEG: mean test accuracy (%) over 10 runs. The dataset is read from
# HDF5, so 10 DataLoader worker processes (num_workers) are used for loading.
res = experiment('scine_eeg', 10, n_processes_collate=10)
np.mean(res)

#Params:  2858695
epoch 100/100
Run 0: 30.984126984126988
#Params:  2858695
epoch 100/100
Run 1: 30.22222222222222
#Params:  2858695
epoch 100/100
Run 2: 30.476190476190478
#Params:  2858695
epoch 100/100
Run 3: 29.174603174603174
#Params:  2858695
epoch 100/100
Run 4: 31.492063492063494
#Params:  2858695
epoch 100/100
Run 5: 31.492063492063494
#Params:  2858695
epoch 100/100
Run 6: 29.58730158730159
#Params:  2858695
epoch 100/100
Run 7: 30.698412698412696
#Params:  2858695
epoch 100/100
Run 8: 29.523809523809526
#Params:  2858695
epoch 100/100
Run 9: 30.634920634920636


30.428571428571434