In [1]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import torch

#### Custom function(s)

In [None]:
# Designed to work with all the activity datasets collected from
# Braille and MNIST classification for the NTE-Encoding project
def load_activity(path, subset, stimulus_labels=None):
    """
    Load the per-channel spiking activity of one subset into a DataFrame.

    Every ``<stem>.npy`` file in ``<path>/<subset>`` holds spikes bit-packed
    along axis 0 (one bit per sample), with a matching ``<stem>_label.pt``
    file holding one label per sample. To load the dataset you need to
    - load the spikes using np.load() and then run np.unpackbits(axis=0)
    - load the labels using torch.load()

    Each sample is split into individual channel traces: the DataFrame gets
    one row per (sample, channel) pair, carrying that sample's label.

    Parameters
    ----------
    path : str
        Folder containing the subset sub-folders.
    subset : str
        Name of the sub-folder to load (e.g. "train", "eval", "test").
    stimulus_labels : sequence, optional
        Can be used for 'additional' labelling, as in the case of the Braille
        dataset, to keep the match between the actual labels and their numeric
        representation: ``stimulus_labels[label]`` is stored in a "Stimulus"
        column.

    Returns
    -------
    pandas.DataFrame
        Columns "Activity" (one 1-D float32 tensor per row), "Label"
        (0-d tensor) and, when ``stimulus_labels`` is given, "Stimulus".

    Raises
    ------
    FileNotFoundError
        If the subset folder contains no ``.npy`` file.

    Fra, Vittorio,
    Politecnico di Torino,
    EDA Group,
    Torino, Italy.
    """

    subset_dir = os.path.join(path, subset)
    # Sorted so the row order does not depend on the file system's listing order.
    spikes_file_list = sorted(ii for ii in os.listdir(subset_dir) if ii.endswith(".npy"))
    if not spikes_file_list:
        raise FileNotFoundError(f"No .npy activity files found in {subset_dir}")

    spikes = []
    labels = []

    for ii in spikes_file_list:

        # splitext removes exactly the ".npy" suffix; str.rstrip(".npy") would strip
        # any trailing '.', 'n', 'p', 'y' characters (e.g. "act_y.npy" -> "act_").
        stem = os.path.splitext(ii)[0]
        lbl = torch.load(os.path.join(subset_dir, stem + "_label.pt"))
        # count=lbl.shape[0] drops the zero padding np.packbits adds to reach a
        # multiple of 8 samples along axis 0.
        unpacked = np.unpackbits(np.load(os.path.join(subset_dir, ii)), axis=0, count=lbl.shape[0])
        # assumes unpacked is (samples, time steps, channels) — swapped to (samples, channels, time steps)
        spk = torch.swapdims(torch.as_tensor(unpacked, dtype=torch.float32), 1, 2)
        n_channels = spk.shape[1]
        # The reshape below is sample-major (all channels of sample 0, then of sample 1, ...),
        # so each label must be repeated consecutively; tile() would cycle through the labels
        # and attach them to the wrong rows.
        lbl = lbl.repeat_interleave(n_channels)
        spk = spk.reshape(-1, spk.shape[-1])
        labels.append(lbl)
        spikes.append(spk)

    overall_label = torch.cat(labels)
    individual_channels_activity = torch.cat(spikes)

    activity_df = pd.DataFrame()
    activity_df["Activity"] = list(individual_channels_activity)
    activity_df["Label"] = list(overall_label)
    if stimulus_labels is not None:
        activity_df["Stimulus"] = [stimulus_labels[int(label)] for label in overall_label]

    return activity_df

### Activity split analysis: label counts and spike-count distribution per seed and subset

In [None]:
# For Braille
stimulus_labels = ['Space', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K',
    'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']
root_actdir = "./MN_output_Braille/"
common_prefix = "GR_braille"

# # For MNIST
# stimulus_labels = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
# root_actdir = "./MN_Output_MNIST/"
# common_prefix = "GR_mnist"
# #root_actdir = "./Activity/MN_Output_MNIST_c/" # (compressed)
# #common_prefix = "GR_mnist_compressed" # (compressed)

suffix = ["w"]
subsets = ["train", "eval", "test"]


def _load_raw_subset(subset_dir):
    """Return (labels, values) for every <stem>.npy / <stem>_label.pt pair in subset_dir.

    Spikes are bit-packed along axis 0; passing count=len(labels) to np.unpackbits
    removes the zero padding np.packbits adds to reach a multiple of 8 samples,
    which would otherwise show up as extra all-silent samples in the statistics.
    """
    subset_labels, subset_values = [], []
    for npy_name in sorted(f for f in os.listdir(subset_dir) if f.endswith(".npy")):
        stem = os.path.splitext(npy_name)[0]
        lbl = torch.load(os.path.join(subset_dir, stem + "_label.pt"))
        packed = np.load(os.path.join(subset_dir, npy_name))
        subset_labels.append(lbl)
        subset_values.append(np.unpackbits(packed, axis=0, count=lbl.shape[0]))
    return torch.cat(subset_labels), np.concatenate(subset_values)


for sfx in suffix:

    seed_folders = [ii for ii in os.listdir(root_actdir) if f"{common_prefix}_{sfx}" in ii]
    if not seed_folders:
        print(f"No '{common_prefix}_{sfx}' folders found in {root_actdir}")
        continue
    max_seed = max(int(fld.split("_")[-1]) for fld in seed_folders)

    for seed in range(max_seed + 1):

        seed_folder = f"{common_prefix}_{sfx}_{seed}"
        if seed_folder not in seed_folders:
            continue

        for subset in subsets:

            labels, values = _load_raw_subset(os.path.join(root_actdir, seed_folder, subset))

            print(f"Data from {os.path.join(seed_folder,subset)}:")
            print("\tLabels:")
            print(f"\t\tshape: {labels.shape}")
            unique_lbl, count_lbl = np.unique(labels, return_counts=True)
            if len(np.unique(count_lbl)) > 1:
                # int() keeps NumPy 2's "np.int64(...)" scalar repr out of the printed dict
                count_per_label = {stimulus_labels[el]: int(cnt) for el, cnt in zip(unique_lbl, count_lbl)}
                print(f"\t\tcount per label: {count_per_label}")
            else:
                # every label occurs the same number of times: print that single count
                print(f"\t\tcount per label: {int(count_lbl[0])}")

            print("\tValues:")
            print(f"\t\tshape: {values.shape}")

            # Spikes per trace, counted along axis 1 (time steps); percentages are taken over
            # all samples x channels (axis 0 x last axis) traces.
            spike_count, occurrences = np.unique(np.count_nonzero(values, axis=1), return_counts=True)
            fig, ax = plt.subplots()
            ax.scatter(spike_count, occurrences/(values.shape[0]*values.shape[-1])*100, s=12, c="tab:red", zorder=3)
            ax.set_xlabel(f"Spike count [over the {values.shape[1]} time steps]")
            ax.set_ylabel("Occurrences (%)")
            ax.set_xlim(-10, values.shape[1]*1.02)
            ax.set_ylim(1e-5, 200)
            ax.set_yscale("log")
            ax.grid(visible=True, which='both', axis='y', zorder=1)
            ax.set_title(os.path.join(seed_folder,subset))
            plt.show()
            plt.close(fig)  # avoid accumulating figures across the seed/subset loop

        print("\n")

### Load the activity data and create a dataframe

In [None]:
# # For Braille
# path = "./Activity/MN_output_Braille/GR_braille_w_0"
# subset = "test"
# stimulus_labels = ['Space'] + list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")

# For MNIST: class names are the ten digits as strings, indexed by numeric label
path = "./Activity/MN_Output_MNIST/GR_mnist_w_0"
subset = "test"
stimulus_labels = [str(digit) for digit in range(10)]

In [None]:
subset_dir = os.path.join(path, subset)
labels = []
values = []
# Pair each <stem>.npy with its <stem>_label.pt so np.unpackbits can be told how many
# samples are real (count=...); without it, the zero padding np.packbits adds to reach
# a multiple of 8 samples along axis 0 is returned as extra all-silent samples.
for npy_name in sorted(f for f in os.listdir(subset_dir) if f.endswith(".npy")):
    stem = os.path.splitext(npy_name)[0]
    lbl = torch.load(os.path.join(subset_dir, stem + "_label.pt"))
    labels.append(lbl)
    values.append(np.unpackbits(np.load(os.path.join(subset_dir, npy_name)), axis=0, count=lbl.shape[0]))
labels = torch.cat(labels)
values = np.concatenate(values)

print(f"Subset: {subset}")
print("\tLabels:")
print(f"\t\tshape: {labels.shape}")
unique_lbl, count_lbl = np.unique(labels, return_counts=True)
# int() keeps NumPy 2's "np.int64(...)" scalar repr out of the printed dict
count_per_label = {stimulus_labels[el]: int(cnt) for el, cnt in zip(unique_lbl, count_lbl)}
print(f"\t\tcount per label: {count_per_label}")

print("\tValues:")
print(f"\t\tshape: {values.shape}")

In [None]:
# Per-channel activity of the selected subset, with numeric labels and stimulus names
df = load_activity(path=path, subset=subset, stimulus_labels=stimulus_labels)
df.head(n=5)

In [None]:
# Channels per sample come from the raw `values` array (last axis) loaded in an earlier cell.
n_channels = values.shape[-1]
n_samples_exact = df.shape[0] / n_channels
reps_per_class = n_samples_exact / len(stimulus_labels)
unique_lbl, count_lbl = np.unique(df["Stimulus"], return_counts=True)

print(f"Data from: \n\t{os.path.join(path,subset)}")
print(f"Number of samples: \n\t{int(n_samples_exact)}")
# Print the ratio as an integer when it is whole, as a float otherwise
reps_shown = int(reps_per_class) if reps_per_class.is_integer() else reps_per_class
print(f"Number of repetitions for each class sample: \n\t{reps_shown}")
print(f"Number of single channel recordings: \n\t{df.shape[0]}")
if len(np.unique(count_lbl)) == 1:
    # every class has the same number of recordings: print that single count
    print(f"Number of single channel recordings for each class sample: \n\t{int(count_lbl[0])}")
else:
    # str()/int() keep NumPy 2's "np.str_(...)"/"np.int64(...)" scalar reprs out of the printout
    per_class = {str(stim): int(cnt) for stim, cnt in zip(unique_lbl, count_lbl)}
    print(f"Number of single channel recordings for each class sample: \n\t{per_class}")

### **Read the pickle from activity classification**

In [18]:
# Activity-classification results (GR_braille_w_0, train subset, judging by the file name).
# NOTE: pd.read_pickle can execute arbitrary code; only load pickles produced by our own pipeline.
act_clas_path = "../../results/activity_classification/MN_activity/MN_output_Braille/GR_braille_w_0_train_20240116_105004.pkl"
act_clas = pd.read_pickle(act_clas_path)
act_clas.head()

Unnamed: 0,Letter,Behaviour,Probabilities,Spikes,Sparsity
0,G,O,"[[0.0, 0.0, 11.42, 0.0, 0.0, 0.0, 0.0, 0.0, 0....",59,0.8127
1,S,O,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",34,0.8921
2,J,G,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 100.0, 0.0, 0....",2,0.9937
3,Y,O,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",41,0.8698
4,W,A,"[[24.69, 0.0, 24.69, 0.0, 0.0, 0.0, 24.69, 24....",107,0.6603


In [19]:
chars, counts = np.unique(act_clas["Letter"], return_counts=True)

print(f"Total number of active channels found: {len(act_clas)}\n")

print("Number of active channels found for each Braille character:")
# str()/int() turn NumPy scalars into built-ins so the dict prints as plain values
# (NumPy >= 2 reprs them as np.str_(...)/np.int64(...)).
print({str(char): int(count) for char, count in zip(chars, counts)})

Total number of active channels found: 38334

Number of active channels found for each Braille character:
{'A': 1416, 'B': 1526, 'C': 1312, 'D': 1435, 'E': 1315, 'F': 1582, 'G': 1389, 'H': 1475, 'I': 1342, 'J': 1406, 'K': 1506, 'L': 1329, 'M': 1599, 'N': 1360, 'O': 1429, 'P': 1340, 'Q': 1419, 'R': 1519, 'S': 1330, 'Space': 1356, 'T': 1375, 'U': 1536, 'V': 1360, 'W': 1588, 'X': 1454, 'Y': 1271, 'Z': 1365}
