In [486]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [487]:
import time
import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.utils.data import DataLoader
import pytorch_lightning as pl

from src.autoencoder import AutoEncoder, NSAAutoEncoder
from src.utils import *
from src.loss import RTDLoss, NSALoss, LID_NSALoss
from src.top_ae import TopologicallyRegularizedAutoencoder

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler

from collections import defaultdict

from tqdm.notebook import tqdm

In [489]:
config = {
    "dataset_name":"F-MNIST",
    "version":"k_L256_N256",
    "model_name":"default",
    "max_epochs":250,
    "gpus":[0],
    "rtd_every_n_batches":1,
    "rtd_start_epoch":60,
    "rtd_l":1.0, # rtd loss
    "nsa_every_n_batches":1,
    "nsa_start_epoch":0,
    "nsa_l":1.0, # rtd loss
    "n_runs":1, # number of runs for each model
    "card":50, # number of points on the persistence diagram
    "n_threads":1, # number of threads for parallel ripser computation of pers homology
    "latent_dim":16, # latent dimension (2 or 3 for vizualization purposes)
    "input_dim":28*28,
    "n_hidden_layers":3,
    "hidden_dim":512,
    "batch_size":256,
    "engine":"ripser",
    "is_sym":True,
    "lr":1e-4,
}

In [490]:
def get_model(input_dim, latent_dim=2, n_hidden_layers=2, m_type='encoder', **kwargs):
    n = int(np.log2(input_dim))-1
    layers = []
    if m_type == 'encoder':
        in_dim = input_dim
        if input_dim  // 2 >= latent_dim:
            out_dim = input_dim // 2
        else:
            out_dim = input_dim
        for i in range(min(n, n_hidden_layers)):
            layers.extend([nn.Linear(in_dim, out_dim), nn.ReLU()])
            in_dim = out_dim
            if in_dim  // 2 >= latent_dim:
                out_dim = in_dim // 2
            else:
                out_dim = in_dim
        layers.extend([nn.Linear(in_dim, latent_dim)])
    elif m_type == 'decoder':
        in_dim = latent_dim
        out_dim = latent_dim * 2
        for i in range(min(n, n_hidden_layers)):
            layers.extend([nn.Linear(in_dim, out_dim), nn.ReLU()])
            in_dim = out_dim
            out_dim *= 2
        layers.extend([nn.Linear(in_dim, input_dim)])
    return nn.Sequential(*layers)

def get_list_of_models(**config):
    # define a list of models
    encoder = get_linear_model(
        m_type='encoder',
        **config
    )
    decoder = get_linear_model(
        m_type='decoder',
        **config
    )
    models = {
        # 'Basic AutoEncoder':AutoEncoder(
        #    encoder = encoder,
        #     decoder = decoder,
        #     MSELoss = nn.MSELoss(),
        #     **config
        # ),
        # 'Topological AutoEncoder':TopologicallyRegularizedAutoencoder(
        #     encoder = encoder,
        #     decoder = decoder,
        #     MSELoss = nn.MSELoss(),
        #     **config
        # ),
        # 'RTD AutoEncoder H1':AutoEncoder(
        #     encoder = encoder,
        #     decoder = decoder,
        #     RTDLoss = RTDLoss(dim=1, lp=1.0,  **config), # only H1
        #     MSELoss = nn.MSELoss(),
        #     **config
        # ),
        'LID_NSA AutoEncoder':NSAAutoEncoder(
            encoder = encoder,
            decoder = decoder,
            NSALoss = LID_NSALoss(k=config['batch_size']-1), # only H1
            MSELoss = None,
            **config
        ),
    }
    return models, encoder, decoder

In [491]:
def collate_with_matrix(samples):
    indicies, data, labels = zip(*samples)
    data, labels = torch.tensor(np.asarray(data)), torch.tensor(np.asarray(labels))
    if len(data.shape) > 2:
        dist_data = torch.flatten(data, start_dim=1)
    else:
        dist_data = data
    x_dist = torch.cdist(dist_data, dist_data, p=2) / np.sqrt(dist_data.shape[1])
#     x_dist = (x_dist + x_dist.T) / 2.0 # make symmetrical (cdist is prone to computational errors)
    return data, x_dist, labels

def collate_with_matrix_geodesic(samples):
    indicies, data, labels, dist_data = zip(*samples)
    data, labels = torch.tensor(np.asarray(data)), torch.tensor(np.asarray(labels))
    x_dist = torch.tensor(np.asarray(dist_data)[:, indicies])
    return data, x_dist, labels

In [492]:
dataset_name = config['dataset_name']
if dataset_name in ['COIL-20','COIL-100']:
    train_data = np.load(f'data/{dataset_name}/prepared/data.npy').astype(np.float32)
elif dataset_name.startswith('LinkPrediction'):
    train_data = np.load(f'data/{dataset_name}/LP_3_200.npz')
    train_data = dict(train_data)
    print(train_data.keys())
    key = list(train_data.keys())[-1]
    print(key)
    train_data = train_data[key]
else:
    train_data = np.load(f'data/{dataset_name}/prepared/train_data.npy').astype(np.float32)


try:        
    test_data = np.load(f'data/{dataset_name}/prepared/test_data.npy').astype(np.float32)
except FileNotFoundError:
    ids = np.random.choice(np.arange(len(train_data)), size=int(0.2*len(train_data)), replace=False)
    test_data = train_data[ids]

try:
    if dataset_name in ['COIL-20','COIL-100']:
        print("Inside here")
        train_labels = np.load(f'data/{dataset_name}/prepared/labels.npy')
    elif dataset_name.startswith('LinkPrediction'):
        train_labels = np.arange(1,len(train_data)+1)
    else:
        train_labels = np.load(f'data/{dataset_name}/prepared/train_labels.npy')
except FileNotFoundError:
    train_labels = None

try:
    test_labels = np.load(f'data/{dataset_name}/prepared/test_labels.npy')
except FileNotFoundError:
    if train_labels is None:
        test_labels = None
    else:
        test_labels = train_labels[ids]

In [493]:
print(train_data.shape)
print(train_labels[:10])
print(train_labels.shape)
print(test_data.shape)
print(test_labels[:10])
print(test_labels.shape)

(60000, 28, 28)
[9 0 0 3 0 2 7 2 5 5]
(60000,)
(10000, 28, 28)
[9 2 1 1 6 1 4 6 5 7]
(10000,)


In [494]:
import numpy as np

In [470]:
#Workflow

#Now pick a random index and add that index to the queue
#for each element in the queue, as you pop it
#1 add all the neighbors in the neighbors to the queue, either k or k+ connections necessary to form a single connected component
#Remove the node along with all its connections (edges) from the nearest neighbors graph and from potential indices that I can initially pick from
#Repeat until I pop batch_size elements.
#After one batch is ready, check for 1 cc and fix cc


import torch
from torch.utils.data import Dataset, DataLoader, Sampler
import numpy as np
import random
from scipy.sparse.csgraph import connected_components, shortest_path
from src.utils import _fix_connected_components
import copy


class NearestNeighborBatchSampler(Sampler):
    def __init__(self, dataset, batch_size, num_neighbors):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_neighbors = num_neighbors
        #Given input data X, compute k nearest neighbors for each point
        self.kng = kneighbors_graph(self.dataset.data, n_neighbors=self.num_neighbors, mode='connectivity')
        self.kng_dict_front = self.create_kng_dict_front()
        self.kng_dict_back = self.create_kng_dict_back()
        #Use fix connected components to generate one connected component
        num_ccs, component_labels = connected_components(self.kng)
        if num_ccs > 1:
            self.kng = _fix_connected_components(self.dataset.data, self.kng, num_ccs, component_labels, mode='connectivity', metric='euclidean')
            #Check again to confirm if it worked
            num_ccs, component_labels = connected_components(self.kng)
            if num_ccs >1:
                raise ValueError("Increase nearest neighbor size; cannot generate a single connected component with the given knn size.")

    def create_kng_dict_front(self):
        dict = [list(self.nearest_neighbors(i)) for i in range(len(self.dataset))]
        return dict
        
    def create_kng_dict_back(self):
        dict = [list(self.nearest_neighbors_of(i)) for i in range(len(self.dataset))]
        return dict

    def nearest_neighbors(self, index, kng=None):
        if kng is None:
            kng = self.kng
        return kng.getrow(index).nonzero()[1]

    def nearest_neighbors_of(self, index, kng=None):
        if kng is None:
            kng = self.kng
        return kng.getcol(index).nonzero()[0]
        
    # def remove_point_existence(self, index):
    #     for i in self.kng_dict_back[index]:
    #         self.kng_dict_front[i].remove(index)
    #     for i in self.kng_dict_front[index]:
    #         self.kng_dict_back[i].remove(index)
    #     self.kng_dict_front[index] = []
    #     self.kng_dict_back[index] = []
    
    def __iter__(self):
        print("Running iter")
        # iter_kng = self.kng.copy()
        numbatches = len(self.dataset) // self.batch_size
        indices = list(range(len(self.dataset)))
        random.shuffle(indices)
        indices = indices[:numbatches]
        # self.kng_dict_front = self.create_kng_dict_front()
        # self.kng_dict_back = self.create_kng_dict_back()
        # indices = set(indices)
        batches = []
        while indices:
            kng_front = copy.deepcopy(self.kng_dict_front)
            kng_back = copy.deepcopy(self.kng_dict_back)
            # if len(indices) < self.batch_size:
            #     #Add functionality to sort the data in case you want to retain ordering
            #     batches.append(indices)
            #     break
            # print(indices[:10])
            batch = set()
            queue = []
            queue_set = set()
            next_point = indices.pop()
            # print("Current index:",next_point)
            # print("Remaining indices:",len(indices))
            # print("Number of batches:", len(batches))
            batch.add(next_point)
            queue.extend(self.kng_dict_front[next_point])
            queue_set.update(self.kng_dict_front[next_point])
            for i in kng_back[next_point]:
                kng_front[i].remove(next_point)
            for i in kng_front[next_point]:
                kng_back[i].remove(next_point)
            kng_front[next_point] = []
            kng_back[next_point] = []
            # self.remove_point_existence(next_point)
            while len(batch) + len(queue) < self.batch_size:
                # print("Queue Length:",len(queue))
                if queue:
                    next_point = queue.pop(0)
                    if next_point not in batch:
                        batch.add(next_point)
                        # indices.remove(next_point)
                        for new_point in self.kng_dict_front[next_point]:
                            if new_point not in queue_set:
                                queue.append(new_point)
                                queue_set.add(new_point)
                        for i in kng_back[next_point]:
                            kng_front[i].remove(next_point)
                        for i in kng_front[next_point]:
                            kng_back[i].remove(next_point)
                        kng_front[next_point] = []
                        kng_back[next_point] = []
                        # queue.extend(self.kng_dict_front[next_point])
                        # self.remove_point_existence(next_point)
                else:
                #     # print("Queue is empty but batch requirements are not met starting from point:",next_point)
                    if indices:
                        next_point = indices.pop()
                        batch.add(next_point)
                        queue.extend(self.kng_dict_front[next_point])
                        queue_set.update(self.kng_dict_front[next_point])
                        for i in kng_back[next_point]:
                            kng_front[i].remove(next_point)
                        for i in kng_front[next_point]:
                            kng_back[i].remove(next_point)
                        kng_front[next_point] = []
                        kng_back[next_point] = []
                        # queue.extend(self.kng_dict_front[next_point])
                        # self.remove_point_existence(next_point)
                # Remove duplicates in queue
                #queue = list(dict.fromkeys(queue))
            batch.update(queue)
            # print("Created a batch of size:",len(batch))
        
            batches.append(list(batch))
            # print("Previous batch size:", len(batches[-1]))
            #This might not be necessary
            # num_ccs, component_labels = connected_components(iter_kng)
            # if num_ccs > 1:
            #     print("Connected component check failed during batching")
            #     # iter_kng = _fix_connected_components(self.dataset.data, iter_kng, num_ccs, component_labels, mode='connectivity', metric='euclidean')
            #     # #Check to see if it was fixed
            #     # num_ccs, component_labels = connected_components(iter_kng)
            #     # if n_ccs >1:
            #     #     raise ValueError("Connected component check failed during batching; increase nearest neighbor size.")
        
        return iter(batches)

    def __len__(self):
        return len(self.dataset) // self.batch_size

In [495]:
#Workflow

#Now pick a random index and add that index to the queue
#for each element in the queue, as you pop it
#1 add all the neighbors in the neighbors to the queue, either k or k+ connections necessary to form a single connected component
#Remove the node along with all its connections (edges) from the nearest neighbors graph and from potential indices that I can initially pick from
#Repeat until I pop batch_size elements.
#After one batch is ready, check for 1 cc and fix cc


import torch
from torch.utils.data import Dataset, DataLoader, Sampler
import numpy as np
import random
from scipy.sparse.csgraph import connected_components, shortest_path
from src.utils import _fix_connected_components
import copy
import threading, queue


class NearestNeighborBatchSamplerMulti(Sampler):
    def __init__(self, dataset, batch_size, num_neighbors, num_threads=24):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_neighbors = num_neighbors
        #Given input data X, compute k nearest neighbors for each point
        self.kng = kneighbors_graph(self.dataset.data, n_neighbors=self.num_neighbors, mode='connectivity')
        self.kng_dict_front = self.create_kng_dict_front()
        #Use fix connected components to generate one connected component
        num_ccs, component_labels = connected_components(self.kng)
        if num_ccs > 1:
            self.kng = _fix_connected_components(self.dataset.data, self.kng, num_ccs, component_labels, mode='connectivity', metric='euclidean')
            #Check again to confirm if it worked
            num_ccs, component_labels = connected_components(self.kng)
            if num_ccs >1:
                raise ValueError("Increase nearest neighbor size; cannot generate a single connected component with the given knn size.")
        self.results_queue = queue.Queue()
        self.num_threads = num_threads
        self.indices = []

    def create_kng_dict_front(self):
        dict = [list(self.nearest_neighbors(i)) for i in range(len(self.dataset))]
        return dict
        
    def create_kng_dict_back(self):
        dict = [list(self.nearest_neighbors_of(i)) for i in range(len(self.dataset))]
        return dict

    def nearest_neighbors(self, index, kng=None):
        if kng is None:
            kng = self.kng
        return kng.getrow(index).nonzero()[1]

    def nearest_neighbors_of(self, index, kng=None):
        if kng is None:
            kng = self.kng
        return kng.getcol(index).nonzero()[0]

    def one_minibatch(self, point):
        batch = set()
        point_queue = []
        queue_set = set()
        batch.add(point)
        point_queue.extend(self.kng_dict_front[point])
        queue_set.update(self.kng_dict_front[point])
        while len(batch) + len(point_queue) < self.batch_size:
            if point_queue:
                next_point = point_queue.pop(0)
                if next_point not in batch:
                    batch.add(next_point)
                    for new_point in self.kng_dict_front[next_point]:
                        if new_point not in queue_set:
                            point_queue.append(new_point)
                            queue_set.add(new_point)
            else:
                if self.indices:
                    next_point = self.indices.pop()
                    batch.add(next_point)
                    point_queue.extend(self.kng_dict_front[next_point])
                    queue_set.update(self.kng_dict_front[next_point])
        batch.update(point_queue)
        return batch
        
    def thread_minibatch(self):
        while True:
            try:
                point = self.indices.pop()  # Get a value to process
            except IndexError:
                break
            batch = self.one_minibatch(point)
            self.results_queue.put(batch)
    
    def __iter__(self):
        #print("Running iter")
        numbatches = len(self.dataset) // self.batch_size
        self.indices = list(range(len(self.dataset)))
        random.shuffle(self.indices)
        self.indices = self.indices[:numbatches]
        batches = []
        threads = []
        for _ in range(self.num_threads):
            thread = threading.Thread(target=self.thread_minibatch)
            thread.start()
            threads.append(thread)
        for thread in threads:
            thread.join()
        while not self.results_queue.empty():
            result = self.results_queue.get()
            batches.append(list(result))
        return iter(batches)

    def __len__(self):
        return len(self.dataset) // self.batch_size


In [496]:
print(train_data.shape)
print(train_labels[:10])
print(train_labels.shape)
print(test_data.shape)
print(test_labels[:10])
print(test_labels.shape)

(60000, 28, 28)
[9 0 0 3 0 2 7 2 5 5]
(60000,)
(10000, 28, 28)
[9 2 1 1 6 1 4 6 5 7]
(10000,)


In [497]:
import numpy as np

class CustomMinMaxScaler:
    def __init__(self):
        self.min_vals = train_data.min()
        self.max_vals = train_data.max()
        self.is_fitted = True
        
    def fit(self, data):
        self.min_vals = np.min(data, axis=0)
        self.max_vals = np.max(data, axis=0)
        self.is_fitted = True
        
    def transform(self, data):
        if not self.is_fitted:
            raise NotFittedError
        scaled_data = (data - self.min_vals) / (self.max_vals - self.min_vals)
        return scaled_data
    
    def fit_transform(self, data):
        self.fit(data)
        return self.transform(data)

In [498]:
import torch
from torch.utils.data import Dataset, DataLoader, Sampler
import random

In [499]:
class CustomDataLoader(DataLoader):
    def __init__(self, dataset, k, **kwargs):
        super().__init__(dataset, **kwargs)
        self.k = k
        self.epoch = 0
    def __iter__(self):
        if self.epoch % self.k == 0:
            self._iterator = self._get_iterator()
            self.epoch = 0  # Reset the epoch counter
        self.epoch += 1
        return self._iterator

In [500]:
scaler = CustomMinMaxScaler()
#scaler = None
flatten = True
geodesic = False

train = FromNumpyDataset(
    train_data, 
    train_labels, 
    geodesic=geodesic, 
    scaler=scaler, 
    flatten=flatten, 
    n_neighbors=2
)
print("Train done")
test = FromNumpyDataset(
    test_data, 
    test_labels, 
    geodesic=geodesic, 
    scaler = train.scaler,    
    flatten=flatten, 
    n_neighbors=2
)
train_sampler = NearestNeighborBatchSamplerMulti(train, config['batch_size'], num_neighbors=config['batch_size']-1, num_threads=24)
#val_sampler = NearestNeighborBatchSampler(test, config['batch_size'], num_neighbors=5)

train_loader = DataLoader(
    train, 
    batch_sampler=train_sampler, 
    num_workers=24, 
    collate_fn=collate_with_matrix_geodesic if geodesic else collate_with_matrix, 
)

# train_loader = CustomDataLoader(
#     train,
#     batch_sampler=train_sampler,
#     num_workers=0,
#     k=10,
#     collate_fn=collate_with_matrix_geodesic if geodesic else collate_with_matrix,
# )

# val_loader = DataLoader(
#     test,
#     batch_sampler=val_sampler,
#     num_workers=2,
#     collate_fn=collate_with_matrix_geodesic if geodesic else collate_with_matrix,
# )

# train_loader = DataLoader(
#     train, 
#     batch_size=config["batch_size"], 
#     num_workers=24, 
#     collate_fn=collate_with_matrix_geodesic if geodesic else collate_with_matrix, 
#     shuffle=True,
#     drop_last=True
# )

# val_loader = DataLoader(
#     test,
#     batch_size=config["batch_size"],
#     num_workers=24,
#     collate_fn=collate_with_matrix_geodesic if geodesic else collate_with_matrix,
#     shuffle=False,
#     drop_last=True
# )

Train done


In [478]:
#next(iter(train_loader))

In [501]:
torch.__version__

'1.13.0+cu117'

In [502]:
def train_autoencoder(model, train_loader, val_loader=None, model_name='default', 
                      dataset_name='F-MNIST', gpus=[0], max_epochs=100, run=0, version="d1"):
    version = f"{dataset_name}_{model_name}_{version}_{run}"
    logger = pl.loggers.TensorBoardLogger(save_dir=os.getcwd(), name='lightning_logs', version=version)
    trainer = pl.Trainer(
        logger=logger, 
        gpus=gpus, 
        max_epochs=max_epochs, 
        log_every_n_steps=1, 
        num_sanity_val_steps=0
    )
    trainer.fit(model, train_loader, val_loader)
    return model

def dump_figures(figures, dataset_name, version):
    for model_name in figures:
        figures[model_name].savefig(f'results/{dataset_name}/{model_name}_{version}.png')

def train_models(train_loader, val_loader, dataset_name="", max_epochs=1, gpus=[], n_neighbors=[1], n_runs=1, version='', **kwargs):
    models, encoder, decoder = get_list_of_models(**kwargs)
    
    for model_name in tqdm(models, desc=f"Training models"):
        if 'AutoEncoder' in model_name: # train an autoencoder
            models[model_name] = train_autoencoder(
                models[model_name], 
                train_loader, 
                val_loader, 
                model_name, 
                dataset_name,
                gpus,
                max_epochs,
                0,
                version
            )
        else: # umap / pca / t-sne (sklearn interface)
            train_latent = models[model_name].fit_transform(train_loader.dataset.data)
        # measure training time
    return encoder, decoder, models

In [503]:
encoder, decoder, trained_models = train_models(train_loader, val_loader, **config)

Training models:   0%|          | 0/1 [00:00<?, ?it/s]

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [2]
Set SLURM handle signals.

  | Name    | Type        | Params
----------------------------------------
0 | encoder | Sequential  | 1.2 M 
1 | decoder | Sequential  | 1.2 M 
2 | NSALoss | LID_NSALoss | 0     
----------------------------------------
2.4 M     Trainable params
0         Non-trainable params
2.4 M     Total params
9.588     Total estimated model params size (MB)
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
2

Validating: 0it [00:00, ?it/s]

Running iter
Running iter
255 torch.Size([256, 256])
255 torch.Size([258, 258])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([259, 259])
255 torch.Size([257, 257])
255 torch.Size([257, 257])
255 torch.Size([258, 258])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([259, 259])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([258, 258])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([258, 258])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([259, 259])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([258, 258])
255 torch.Size([257, 257])
25

Validating: 0it [00:00, ?it/s]

Running iter
Running iter
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([258, 258])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([258, 258])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([257, 257])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([258, 258])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([259, 259])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([258, 258])
25

Validating: 0it [00:00, ?it/s]

Running iter
Running iter
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([258, 258])
255 torch.Size([259, 259])
255 torch.Size([258, 258])
255 torch.Size([256, 256])
255 torch.Size([259, 259])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([259, 259])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([258, 258])
255 torch.Size([257, 257])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([260, 260])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([258, 258])
255 torch.Size([259, 259])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([258, 258])
25

Validating: 0it [00:00, ?it/s]

Running iter
Running iter
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([258, 258])
255 torch.Size([258, 258])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([258, 258])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([258, 258])
255 torch.Size([259, 259])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([259, 259])
255 torch.Size([259, 259])
255 torch.Size([256, 256])
255 torch.Size([258, 258])
255 torch.Size([258, 258])
25

Validating: 0it [00:00, ?it/s]

Running iter
Running iter
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([258, 258])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([259, 259])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([259, 259])
255 torch.Size([257, 257])
255 torch.Size([257, 257])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
25

Validating: 0it [00:00, ?it/s]

Running iter
Running iter
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([259, 259])
255 torch.Size([258, 258])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([258, 258])
255 torch.Size([257, 257])
255 torch.Size([258, 258])
255 torch.Size([256, 256])
255 torch.Size([259, 259])
255 torch.Size([258, 258])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([259, 259])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([258, 258])
255 torch.Size([259, 259])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([258, 258])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
25

Validating: 0it [00:00, ?it/s]

Running iter
Running iter
255 torch.Size([259, 259])
255 torch.Size([259, 259])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([257, 257])
255 torch.Size([257, 257])
255 torch.Size([257, 257])
255 torch.Size([259, 259])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([257, 257])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([260, 260])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([257, 257])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([259, 259])
255 torch.Size([258, 258])
255 torch.Size([258, 258])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([259, 259])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
25

Validating: 0it [00:00, ?it/s]

Running iter
Running iter
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([258, 258])
255 torch.Size([258, 258])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([258, 258])
255 torch.Size([256, 256])
255 torch.Size([258, 258])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([257, 257])
255 torch.Size([257, 257])
255 torch.Size([258, 258])
255 torch.Size([257, 257])
255 torch.Size([258, 258])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([258, 258])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
25

Validating: 0it [00:00, ?it/s]

Running iter
Running iter
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([258, 258])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([257, 257])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([260, 260])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([259, 259])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([256, 256])
255 torch.Size([257, 257])
255 torch.Size([257, 257])
255 torch.Size([258, 258])
255 torch.Size([258, 258])
255 torch.Size([257, 257])
255 torch.Size([257, 257])
255 torch.Size([255, 255])


RuntimeError: selected index k out of range

In [442]:
from src.utils import *

In [443]:
trained_models

{'LID_NSA AutoEncoder': NSAAutoEncoder(
   (encoder): Sequential(
     (0): Linear(in_features=256, out_features=512, bias=True)
     (1): ReLU()
     (2): Linear(in_features=512, out_features=512, bias=True)
     (3): ReLU()
     (4): Linear(in_features=512, out_features=512, bias=True)
     (5): ReLU()
     (6): Linear(in_features=512, out_features=512, bias=True)
     (7): ReLU()
     (8): Linear(in_features=512, out_features=64, bias=True)
   )
   (decoder): Sequential(
     (0): Linear(in_features=64, out_features=512, bias=True)
     (1): ReLU()
     (2): Linear(in_features=512, out_features=512, bias=True)
     (3): ReLU()
     (4): Linear(in_features=512, out_features=512, bias=True)
     (5): ReLU()
     (6): Linear(in_features=512, out_features=512, bias=True)
     (7): ReLU()
     (8): Linear(in_features=512, out_features=256, bias=True)
   )
   (NSALoss): LID_NSALoss()
   (MSELoss): MSELoss()
 )}

In [444]:
config['version']

'k_L64_N32'

In [445]:
version = config['version']
train_loader = DataLoader(
    train,
    batch_size=config["batch_size"],
    num_workers=0,
    collate_fn=collate_with_matrix_geodesic if geodesic else collate_with_matrix,
    shuffle=False
)

for model_name in trained_models:
    latent, labels = get_latent_representations(trained_models[model_name], train_loader)
    print(latent.shape)
    np.save(f'data/{dataset_name}/{model_name}_latent_output_{version}.npy', latent)
    np.save(f'data/{dataset_name}/{model_name}_latent_labels_{version}.npy', labels)

for model_name in trained_models:
    latent, labels = get_output_representations(trained_models[model_name], train_loader)
    print(latent.shape)
    np.save(f'data/{dataset_name}/{model_name}_final_output_{version}.npy', latent)
    np.save(f'data/{dataset_name}/{model_name}_final_labels_{version}.npy', labels)

(19717, 64)
(19717, 256)


In [446]:
for model_name in trained_models:
    latent, labels = get_latent_representations(trained_models[model_name], val_loader)
    print(latent.shape)
    np.save(f'data/{dataset_name}/{model_name}_latent_output_{version}_test.npy', latent)
    np.save(f'data/{dataset_name}/{model_name}_latent_labels_{version}_test.npy', labels)

for model_name in trained_models:
    latent, labels = get_output_representations(trained_models[model_name], val_loader)
    print(latent.shape)
    np.save(f'data/{dataset_name}/{model_name}_final_output_{version}_test.npy', latent)
    np.save(f'data/{dataset_name}/{model_name}_final_labels_{version}_test.npy', labels)

(3840, 64)
(3840, 256)
