In [None]:
import numpy as np
import torch
from tqdm import tqdm

from datasets_torch_geometric.dataset_factory import create_dataset
import matplotlib.pyplot as plt

In [None]:
# Load the NASL dataset. dataset_type='all' presumably combines every split —
# confirm against datasets_torch_geometric.dataset_factory.
ds = create_dataset(dataset_name='NASL', dataset_type='all')

In [None]:
# Index the dataset by class and build label <-> number lookup tables.
# Iterates over every sample, so this cell takes a while to run!
#
# class_dict: label -> list of sample indices belonging to that label
# num2label:  number -> label           Example: 0 -> 'a'
# label2num:  label -> number           Example: 'a' -> 0
# name2ind:   file_id -> sample index   Example: 't_3803.mat' -> 17

class_dict = {}
num2label = {}
name2ind = {}

for idx, data in enumerate(ds):
    y, label, file_id = data.y[0].item(), data.label[0], data.file_id
    # the first sample seen for a label fixes that label's numeric id
    if label not in class_dict:
        num2label[y] = label
    class_dict.setdefault(label, []).append(idx)
    name2ind[file_id] = idx

label2num = dict(zip(num2label.values(), num2label.keys()))

In [None]:
# Number of events (graph nodes) in each sample of class 0, i.e. of label num2label[0].
num_events = [ds[idx].num_nodes for idx in class_dict[num2label[0]]]

In [None]:
# CPU was found to be faster for this task, so CUDA is opt-in.
# Set USE_CUDA = True to use a GPU when one is available.
USE_CUDA = False
device = 'cuda' if USE_CUDA and torch.cuda.is_available() else 'cpu'

In [None]:
# Accumulate, for each class, a per-pixel histogram of event counts
# (both polarities are summed together).
#
# all_images: label -> int64 tensor of shape (SENSOR_HEIGHT, SENSOR_WIDTH)
#             holding the number of events that fell on each (y, x) pixel.

# NOTE(review): 180x240 is assumed to be the NASL sensor resolution — confirm.
SENSOR_HEIGHT, SENSOR_WIDTH = 180, 240

all_images = {}

for key, val in class_dict.items():
    print(key)
    img = torch.zeros((SENSOR_HEIGHT, SENSOR_WIDTH), dtype=torch.int64, device=device)

    for idx in tqdm(val):
        data = ds[idx].to(device)
        # pos columns 0 / 1 are used as the x / y pixel coordinates
        x = data.pos[:, 0].long()
        y = data.pos[:, 1].long()
        # accumulate=True sums repeated (y, x) pairs instead of overwriting them;
        # y already lives on `device`, so ones_like(y) does too
        img.index_put_((y, x), torch.ones_like(y), accumulate=True)
    all_images[key] = img



In [None]:
# Compute the top K values and their indices

def top_k_max_values_and_indices(img, K):
    """Return the K largest entries of a 2-D tensor and their positions.

    Args:
        img: 2-D tensor, e.g. a per-pixel event-count image. It does not need
            to be contiguous (transposed or sliced views are accepted).
        K: number of entries to return; must not exceed ``img.numel()``
            (``torch.topk`` raises otherwise).

    Returns:
        topk_values: (K,) tensor of the largest values, in descending order.
        topk_coords: (K, 2) int64 tensor of the ``(row, col)`` position of each value.
        topk_normalized_values: ``topk_values / img.sum()``, i.e. each value's
            share of the image total (division by zero if the image sums to 0).
    """
    n_cols = img.shape[1]
    # reshape (unlike view) also works on non-contiguous tensors such as img.t()
    topk_values, topk_indices = torch.topk(img.reshape(-1), k=K)
    # map flat indices back to (row, col); explicit floor division avoids the
    # deprecated tensor `//` (floor_divide) path of older torch releases
    rows = torch.div(topk_indices, n_cols, rounding_mode='floor')
    cols = topk_indices % n_cols
    topk_coords = torch.stack([rows, cols], dim=-1)
    topk_normalized_values = topk_values / torch.sum(img)
    return topk_values, topk_coords, topk_normalized_values

In [None]:
# Event-count histogram for all classes together.
#
# global_img: int64 tensor with the accumulated events of every class.
# Summing all_images' values directly (rather than starting from
# zeros_like(all_images['a'])) avoids depending on one particular label.
global_img = torch.stack(list(all_images.values())).sum(dim=0).to(device)

fig, ax = plt.subplots(figsize=(14, 9))
# np.ma.log masks pixels with zero events instead of producing -inf (and a
# divide-by-zero RuntimeWarning); masked pixels are drawn blank.
im = ax.imshow(np.ma.log(global_img.cpu().numpy()))
ax.set(title='log(event count), all classes', xlabel='x [pixel]', ylabel='y [pixel]')
fig.colorbar(im, ax=ax)
plt.show()


In [None]:
# The K pixels with the most events over all classes combined.
K = 10

global_K_val, global_K_idx, global_K_normalized_val = top_k_max_values_and_indices(
    global_img, K
)
# coordinates come back as (row, col) = (y, x); flip the columns to print (x, y)
print(global_K_idx.flip(1))

In [None]:
# The class_K pixels with the most events, computed separately for each class.
#
# max_dict: label -> [coords, counts, fractions], where for that class
#   coords    - (class_K, 2) tensor of (row, col) pixel positions,
#   counts    - the event counts at those pixels,
#   fractions - counts divided by the class's total event count.
# Built with a comprehension so no loop temporaries (img, letter, class_K_*)
# leak into the notebook namespace and shadow names used by other cells.
class_K = 5
max_dict = {
    letter: [coords, counts, fractions]
    for letter, (counts, coords, fractions) in zip(
        all_images,
        (top_k_max_values_and_indices(img, class_K) for img in all_images.values()),
    )
}
    

In [None]:
# Concatenate the per-class results of every class into single tensors:
# all_max = [coords (N, 2), counts (N,), fractions (N,)], N = num_classes * class_K.
all_max = [torch.cat([v[i] for v in max_dict.values()], dim=0) for i in range(3)]

# Keep one entry per distinct pixel. np.unique needs a host array, so move the
# coordinates to CPU explicitly (implicit tensor -> numpy conversion fails on CUDA).
# return_index=True gives the first occurrence of each pixel, ordered by the
# sorted unique rows.
_, unique_idx = np.unique(all_max[0].cpu().numpy(), axis=0, return_index=True)
unique_idx = torch.as_tensor(unique_idx, device=all_max[0].device)
all_max_unique = [v[unique_idx] for v in all_max]

In [None]:
# Display the unique peak pixels' fractions in ascending order (values and indices).
all_max_unique[2].sort()

In [None]:
# Keep only the unique peak pixels whose fraction of their class's events exceeds
# Thresh_normalized (value presumably picked from the sorted fractions above).
Thresh_normalized = 14e-5
unique_idx_thresh_normalized = all_max_unique[2].gt(Thresh_normalized)
all_max_unique_thresh_normalized = list(
    map(lambda t: t[unique_idx_thresh_normalized], all_max_unique)
)

In [None]:
# Keep only the unique peak pixels whose raw event count exceeds Thresh.
Thresh = 17000
unique_idx_thresh = all_max_unique[1].gt(Thresh)
all_max_unique_thresh = [t[unique_idx_thresh] for t in all_max_unique]

In [None]:
# For each class, draw only its top-K pixels: log(event count) at each peak
# pixel and zero elsewhere. The canvas takes its size from the accumulated
# images instead of repeating the hard-coded sensor resolution.
canvas_shape = next(iter(all_images.values())).shape
for letter, max_mat in max_dict.items():
    img = torch.zeros(canvas_shape, dtype=torch.float, device=device)
    rows, cols = max_mat[0][:, 0], max_mat[0][:, 1]
    # counts are int64; cast explicitly so the log is taken in float
    img.index_put_((rows, cols), torch.log(max_mat[1].float()), accumulate=False)
    fig, ax = plt.subplots()
    im = ax.imshow(img.cpu().numpy())
    ax.set(title=f"class '{letter}': log(event count) at its top-{len(rows)} pixels",
           xlabel='x [pixel]', ylabel='y [pixel]')
    fig.colorbar(im, ax=ax)
    plt.show()

In [None]:
# Zero out the globally hottest pixels in a copy of every class histogram and
# show what remains. (Alternative: remove all_max_unique_thresh_normalized[0].)
rmove_pixels = global_K_idx
for letter, img_orig in all_images.items():
    print(letter)
    img = img_orig.clone()  # keep all_images untouched
    img[rmove_pixels[:, 0], rmove_pixels[:, 1]] = 0
    plt.imshow(img.cpu().numpy())
    plt.colorbar()
    plt.show()