In [None]:
%cd ..\src

In [None]:
import time

In [None]:
import logging
import os

import numpy as np
import torch
from torch.utils.data import DataLoader

from configs.utils import get_config
from echovpr.datasets.utils import load_np_file
from echovpr.models.single_esn import SingleESN
from echovpr.models.utils import get_sparsity

# Root logger configured at INFO; `log` is handed to get_config when the config is loaded.
logging.basicConfig(level="INFO")
log = logging.getLogger()


In [None]:
# Build the config path with os.path.join so it resolves on Windows and POSIX alike;
# a literal "\\" separator only works as a path separator on Windows.
config = get_config(os.path.join("configs", "train_esn_nordland.ini"), log)['main']

In [None]:
# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Reservoir hyper-parameters, cast explicitly (presumably the config stores them as strings).
in_features, reservoir_size, out_features = (
    int(config[key]) for key in ('model_in_features', 'model_reservoir_size', 'model_out_features')
)
# NOTE(review): out_features is parsed but not used anywhere below.

esn_alpha, esn_gamma, esn_rho = (
    float(config[key]) for key in ('model_esn_alpha', 'model_esn_gamma', 'model_esn_rho')
)
esn_num_connections = int(config['model_esn_num_connections'])

# Sparsity of the reservoir is derived from the per-unit connection count and reservoir size.
encoder = SingleESN(
    in_features,
    reservoir_size,
    alpha=esn_alpha,
    gamma=esn_gamma,
    rho=esn_rho,
    sparsity=get_sparsity(esn_num_connections, reservoir_size),
    device=device,
)

# Eval mode: this notebook only runs forward passes (for timing), never training.
encoder.eval().to(device)

In [None]:
# Load the precomputed summer hidden representations and divide them by their global maximum.
summer_hidden_repr = torch.from_numpy(load_np_file(config['dataset_nordland_summer_hidden_repr_file_path']))
# In-place division cannot store a float result into an integer tensor, so promote first.
if not summer_hidden_repr.is_floating_point():
    summer_hidden_repr = summer_hidden_repr.to(torch.get_default_dtype())
max_n = summer_hidden_repr.max()
# A zero maximum would silently turn the whole dataset into NaN/inf.
if max_n == 0:
    raise ValueError("cannot normalise hidden representations: maximum value is 0")
_ = summer_hidden_repr.divide_(max_n)

dataset = summer_hidden_repr

In [None]:
# Batch size and worker count come from the config; order is kept (shuffle=False) so every
# sample is visited exactly once per pass in file order.
batchsize, dataLoader_threads = (int(config[key]) for key in ('train_batchsize', 'dataloader_threads'))

dataLoader = DataLoader(
    dataset,
    batch_size=batchsize,
    shuffle=False,
    num_workers=dataLoader_threads,
)

In [None]:
# Time only the encoder forward pass of every batch; the host->device copy is excluded.
batch_timings = []
is_cuda = device.type == "cuda"

with torch.no_grad():
    for x in dataLoader:
        x = x.to(device)
        if is_cuda:
            # CUDA work is queued asynchronously: drain the queue (including the copy above)
            # before starting the clock ...
            torch.cuda.synchronize(device)

        start = time.perf_counter()
        _ = encoder(x)
        if is_cuda:
            # ... and wait for the forward-pass kernels to finish before stopping it,
            # otherwise only the kernel-launch overhead would be measured.
            torch.cuda.synchronize(device)
        stop = time.perf_counter()

        batch_timings.append(stop - start)

In [None]:
# (stop - start) / len(dataset)  -- NB: start/stop only bracket the last batch

In [None]:
# (stop - start) / len(dataset)  -- NB: start/stop only bracket the last batch

In [None]:
# Mean encoder latency per image in milliseconds: total timed seconds / number of images, scaled to ms.
np.divide(np.sum(batch_timings, dtype=np.float64), len(dataset)) * 1000

In [None]:
# 0.8998490395766889 * 27592

In [None]:
# Sparse copy of the reservoir weight matrix, moved to host memory for the CPU matmul benchmarks.
W_sparse = encoder.W.to_sparse().to("cpu")
W_sparse

In [None]:
# Shape of the stored-values tensor, i.e. how many non-zero weights the sparse copy holds.
W_sparse.values().size()

In [None]:
# Reservoir density in percent (stored non-zeros / total entries), measured on W_sparse itself
# instead of the hard-coded 160000 / (8000 * 8000), which only matches one configuration.
(W_sparse.values().numel() / (W_sparse.shape[0] * W_sparse.shape[1])) * 100

In [None]:
# Random dense right-hand operand (500 columns) for the sparse-vs-dense matmul comparison.
# Its row count must equal W's column count, so it is taken from W_sparse rather than
# hard-coded to 8000; a local seeded generator keeps it reproducible without touching
# the global RNG state.
Win = torch.randn((W_sparse.shape[1], 500), generator=torch.Generator().manual_seed(0)).cpu()

In [None]:
import time

In [None]:
# Untimed warm-up call so any one-off start-up cost inside torch is not charged to the
# single measurement below.
_ = torch.sparse.mm(W_sparse, Win)

start = time.perf_counter_ns()
_ = torch.sparse.mm(W_sparse, Win)
stop = time.perf_counter_ns()
stop - start  # elapsed nanoseconds for one sparse @ dense product on the CPU

In [None]:
# Dense reservoir weight matrix in host memory, the baseline for the sparse product timing.
W_cpu = encoder.W.to("cpu")

In [None]:
# Untimed warm-up call so any one-off start-up cost inside torch is not charged to the
# single measurement below.
_ = torch.mm(W_cpu, Win)

start = time.perf_counter_ns()
_ = torch.mm(W_cpu, Win)
stop = time.perf_counter_ns()
stop - start  # elapsed nanoseconds for one dense @ dense product on the CPU

In [None]:
# NOTE(review): literals appear to be nanosecond timings copied from an earlier run of the
# two matmul timing cells (which one is sparse vs dense is not recorded) -- re-check after re-running.
235_455_300 - 72_910_700

In [None]:
# Relative saving: the timing gap divided by the slower (larger) timing; the numerator is
# written as the difference of the two recorded literals rather than a copied result.
(235_455_300 - 72_910_700) / 235_455_300

In [None]:
# 1 - 690367704523.9889
# 2 - 910950134767.0991
# 3 - 66434.13380690056
# 4 - 913526.5837924036
# 5 - 893609.2381849812

# 0.880147912438388 (direct)
# 0.9139255327631198
# 0.8998490395766889

In [None]:
# 0.9043767903740253 ms per image 