In [544]:
import os
from sklearn.neighbors import KDTree
import joblib
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from ignite.metrics import Loss, RootMeanSquaredError

## Intro

In [545]:
# id = 1 # da 0 a 200

In [546]:
# Paths
path_EEG = '../data/THINGS-EEG2/preprocessed_data/sub-01/preprocessed_eeg_test.npy'
path_feat = '../data/THINGS/Images_features/resnet18/'
path_regr = '../trained_models/Regression_EEG-FEAT/'
path_tree = '../trained_models/Retrieval_IMG/kdtree.joblib'

## From FEAT_EEG to FEAT_IMG

In [547]:
path_mod = '../trained_models/Allignement/'
n_classes = 100
names = sorted([i for i in os.listdir(path_mod) if i[6:9] == str(n_classes)], reverse=False)
names

['model_100_rmse8.0252.pt']

In [548]:
best_name = names[0]
best_name

'model_100_rmse8.0252.pt'

In [549]:
model = torch.load(os.path.join(path_mod, best_name))
model

FC(
  (linear): Linear(in_features=100, out_features=100, bias=True)
)

In [550]:
# train set
class Dataset(Dataset):
  # Class to create the dataset
  def __init__(self):
    # EEG FEATURES
    path = '../data/THINGS-EEG2/preprocessed_data_features/sub-01/preprocessed_eeg_training.npy'
    feat_eeg = np.load(path, allow_pickle=True)
    # feat_eeg = np.median(np.split(feat_eeg, 200), axis=1) # media o mediana
    print('eeg', feat_eeg.shape)

    # IMG FEATURES
    path = '../data/THINGS-EEG2/image_set_features/resnet18/training_images_100_mean.npy'
    feat_img = np.load(path, allow_pickle=True)
    feat_img = np.repeat(feat_img, 4, axis=0)
    print('img', feat_img.shape)

    # only test img and first 100 classes
    path = '../data/THINGS-EEG2/image_metadata/split_rep_train.npy'
    test_idx = np.load(path, allow_pickle=True)[:4*100]
    print('test', test_idx.shape)
    feat_eeg = feat_eeg[test_idx]
    feat_img = feat_img[test_idx]

    self.X = torch.from_numpy(feat_eeg).type(torch.FloatTensor).to('cuda')
    self.Y = torch.from_numpy(feat_img).type(torch.FloatTensor).to('cuda')
    self.len = self.X.shape[0]

  def __len__(self):
    return self.len
  
  def __getitem__(self, index):
    return self.X[index], self.Y[index]

In [551]:
# # test set
# class Dataset(Dataset):
#   # Class to create the dataset
#   def __init__(self):
#     # EEG FEATURES
#     path = '../data/THINGS-EEG2/preprocessed_data_features/sub-01/preprocessed_eeg_test.npy'
#     feat_eeg = np.load(path, allow_pickle=True)
#     # feat_eeg = np.median(np.split(feat_eeg, 200), axis=1) # media o mediana
#     print(feat_eeg.shape)

#     # IMG FEATURES
#     path = '../data/THINGS-EEG2/image_set_features/resnet18/test_images_100.npy'
#     feat_img = np.load(path, allow_pickle=True)
#     feat_img = np.repeat(feat_img, 80, axis=0)
#     print(feat_img.shape)

#     self.X = torch.from_numpy(feat_eeg).type(torch.FloatTensor).to('cuda')
#     self.Y = torch.from_numpy(feat_img).type(torch.FloatTensor).to('cuda')
#     self.len = self.X.shape[0]

#   def __len__(self):
#     return self.len
  
#   def __getitem__(self, index):
#     return self.X[index], self.Y[index]

In [552]:
test = DataLoader(Dataset(), batch_size=4, shuffle=False, drop_last=False)
test.__len__()

eeg (66160, 100)
img (66160, 100)
test (400,)


100

In [553]:
rmse = RootMeanSquaredError()

In [554]:
model.eval()

test_feat = []

for batch in test:
    rmse.reset()

    X = batch[0]
    y = batch[1]
    # print(np.max(X.cpu().numpy()), np.min(X.cpu().numpy()))

    pred = model(X)
    test_feat.append(pred.detach().cpu().numpy())

    rmse.update((pred, y))

print('RMSE:', rmse.compute())

test_feat = np.concatenate(test_feat)
test_feat.shape

RMSE: 6.401962408563804


(400, 100)

In [555]:
np.mean(test_feat), np.min(test_feat), np.max(test_feat)

(-0.030199323, -6.9676266, 9.234492)

## From IMG_FEAT to IMAGE

#### Query

In [556]:
# Load tree
tree = joblib.load(path_tree)

In [557]:
# Query the tree
k = 10
dist, ind = tree.query(test_feat, k=k) # query the tree
dist.shape, ind.shape

((400, 10), (400, 10))

In [558]:
np.unique(ind).shape # se funzionasse perfettamente dovrebbe essere 1000

(978,)

In [559]:
test_info = pd.read_csv('../data/THINGS/Images_features/resnet18/training_info.csv', index_col=0)
test_info['class'] = test_info['name'].apply(lambda x: '_'.join(map(str, x.split('_')[:-1])))
print(test_info.shape)
last = test_info['class'].unique()[100]#.shape

test_info = test_info[test_info['class'] < last]
test_info

(23212, 3)


Unnamed: 0,name,idx,class
0,aardvark_01b.jpg,True,aardvark
1,aardvark_02s.jpg,True,aardvark
2,aardvark_03s.jpg,True,aardvark
3,aardvark_04s.jpg,True,aardvark
4,aardvark_05s.jpg,True,aardvark
...,...,...,...
1530,bee_23s.jpg,False,bee
1531,bee_24s.jpg,False,bee
1532,bee_25s.jpg,False,bee
1533,bee_26s.jpg,False,bee


In [560]:
# test_info = pd.read_csv('../data/THINGS/Images_features/resnet18/test_info.csv', index_col=0)
# test_info['class'] = test_info['name'].apply(lambda x: '_'.join(map(str, x.split('_')[:-1])))
# test_info

#### Metriche estrazione concept esatto

In [561]:
class_rep = np.repeat(test_info['class'].unique(), 4, axis=0)

diz = test_info['class'].to_dict()
vfunc = np.vectorize(lambda x: diz[x])
result = vfunc(ind)

class_rep.shape

(400,)

In [562]:
# class_rep = np.repeat(test_info['class'].unique(), 80, axis=0)

# diz = test_info['class'].to_dict()
# vfunc = np.vectorize(lambda x: diz[x])
# result = vfunc(ind)

# class_rep.shape

In [563]:
a = 0
for n, (c, i) in enumerate(zip(class_rep, result)):
    if c in i:
        a += 1
a

24

In [564]:
print('random:', 9.38) # calcolo combinatorio (considerando 15 img per classe in media)
print('retrieved:', a/len(ind)*100)

random: 9.38
retrieved: 6.0


#### Metriche estrazione immagine esatta

In [565]:
id_rep = np.repeat(test_info[test_info['idx']==True].index, 4)
id_rep

Int64Index([   0,    0,    0,    0,    1,    1,    1,    1,    2,    2,
            ...
            1515, 1515, 1516, 1516, 1516, 1516, 1517, 1517, 1517, 1517],
           dtype='int64', length=4000)

In [566]:
a = 0
for n, (c, i) in enumerate(zip(id_rep, ind)):
    if c in i:
        a += 1
a

2

In [567]:
print('random:', k/len(test_info)*100) # prob
print('retrieved:', a/len(ind)*100) # freq

random: 0.6514657980456027
retrieved: 0.5
