In [1]:
import copy
import os

import numpy as np
import torch
from sklearn.metrics import confusion_matrix
from torch import nn
from torch.utils.data import DataLoader, Dataset

In [2]:
# Prefer the first GPU, but fall back to CPU so the notebook still runs on
# machines without CUDA (a bare torch.device('cuda:0') fails at the first .to()).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
# Dataset definition
class ViTacDataset(Dataset):
    """Map-style dataset of pre-extracted tactile / visual spike tensors.

    Each row of ``sampleFile`` is ``<sample index> <class label>``. Spike arrays
    are read from ``<datasetPath><sample index>_tact.npy`` and/or
    ``<datasetPath><sample index>_vis.npy``. The paths are built by plain string
    concatenation, so ``datasetPath`` must end with a path separator.

    Args:
        datasetPath: directory prefix for the ``.npy`` spike files.
        sampleFile: whitespace-separated text file of (index, label) rows.
        modals: which modality ``__getitem__`` returns:
            0 -> (tact, desiredClass, classLabel)
            1 -> (vis, desiredClass, classLabel, sampleIndex)
            2 -> (tact, vis, desiredClass, classLabel)
        num_classes: length of the one-hot ``desiredClass`` target (default 20).

    Raises:
        ValueError: if ``modals`` is not 0, 1 or 2.
    """

    _MODALITIES = (0, 1, 2)

    def __init__(self, datasetPath, sampleFile, modals=0, num_classes=20):
        # Fail fast: an unknown modality used to make __getitem__ return None.
        if modals not in self._MODALITIES:
            raise ValueError(f"modals must be one of {self._MODALITIES}, got {modals!r}")
        self.path = datasetPath
        # ndmin=2 keeps a one-row sample file 2-D so samples[index, 0] still works.
        self.samples = np.loadtxt(sampleFile, ndmin=2).astype('int')
        self.modality = modals
        self.num_classes = num_classes

    def _load(self, sample_index, suffix):
        """Load ``<path><sample_index><suffix>.npy`` as a float32 tensor."""
        return torch.FloatTensor(np.load(self.path + str(sample_index) + suffix + '.npy'))

    @staticmethod
    def _flatten_tact(tact):
        """Reshape tactile spikes to (-1, 1, 1, last_dim), keeping the last axis."""
        return tact.reshape((-1, 1, 1, tact.shape[-1]))

    def __getitem__(self, index):
        inputIndex = self.samples[index, 0].item()
        classLabel = self.samples[index, 1]
        # One-hot target shaped (num_classes, 1, 1, 1).
        desiredClass = torch.zeros((self.num_classes, 1, 1, 1))
        desiredClass[classLabel, ...] = 1

        if self.modality == 0:
            tact = self._load(inputIndex, '_tact')
            return self._flatten_tact(tact), desiredClass, classLabel
        if self.modality == 1:
            vis = self._load(inputIndex, '_vis')
            return vis, desiredClass, classLabel, inputIndex
        if self.modality == 2:
            tact = self._load(inputIndex, '_tact')
            vis = self._load(inputIndex, '_vis')
            return self._flatten_tact(tact), vis, desiredClass, classLabel
        raise ValueError(f"modals must be one of {self._MODALITIES}, got {self.modality!r}")

    def __len__(self):
        return self.samples.shape[0]

In [4]:
# Root of the ViTac dataset. Set the VT_SNN_DATA_DIR environment variable to
# override this machine-specific default instead of editing the notebook.
# A trailing separator is guaranteed because file paths are later built by
# plain string concatenation (e.g. data_dir + str(i) + '_vis.npy').
data_dir = os.path.join(
    os.environ.get('VT_SNN_DATA_DIR', '/home/tasbolat/some_python_examples/data_VT_SNN/'),
    '',
)

In [5]:
# Dataset and DataLoader instances: one train/test pair per 80/20 split.
split_list = ['80_20_1', '80_20_2', '80_20_3', '80_20_4', '80_20_5']

BATCH_SIZE = 8
NUM_WORKERS = 8
VIS_MODALITY = 1  # ViTacDataset modality 1 -> (vis, desiredClass, classLabel, sampleIndex)


def _make_loader(kind, split):
    """Build a visual-modality DataLoader over the rows of ``<kind>_<split>.txt``.

    ``kind`` is 'train' or 'test'. ``shuffle=False`` keeps the sample order
    deterministic.
    """
    dataset = ViTacDataset(datasetPath=data_dir,
                           sampleFile=data_dir + "/" + kind + "_" + split + ".txt",
                           modals=VIS_MODALITY)
    return DataLoader(dataset=dataset, batch_size=BATCH_SIZE,
                      shuffle=False, num_workers=NUM_WORKERS)


# Iterate over split_list itself (not range(5)) so the loop bound can never
# drift out of sync with the list of splits.
training_loader = [_make_loader('train', split) for split in split_list]
testing_loader = [_make_loader('test', split) for split in split_list]

In [6]:
class SimplePool(nn.Module):
    """Thin wrapper around a single ``nn.AvgPool3d``.

    With the defaults, the first pooled axis is left unchanged (kernel 1,
    stride 1, no padding). The last two axes are downsampled 4x, with one
    element of zero padding on each side. Padded zeros count toward the
    average (``AvgPool3d``'s default ``count_include_pad=True``). Inputs use
    the layout ``nn.AvgPool3d`` expects: (N, C, D, H, W) or (C, D, H, W).

    Args:
        kernel_size: pooling window, forwarded to ``nn.AvgPool3d``.
        padding: implicit zero padding, forwarded to ``nn.AvgPool3d``.
        stride: pooling stride, forwarded to ``nn.AvgPool3d``.
        The defaults reproduce the original hard-coded configuration.
    """

    def __init__(self, kernel_size=(1, 4, 4), padding=(0, 1, 1), stride=(1, 4, 4)):
        super().__init__()
        self.pool = nn.AvgPool3d(kernel_size, padding=padding, stride=stride)

    def forward(self, input_data):
        """Return ``input_data`` average-pooled by ``self.pool``."""
        return self.pool(input_data)

In [8]:
# Average-pool every visual sample and collect the pooled tensors on the CPU.
path = data_dir + "pooled_vis/"
N_SAMPLES = 400  # sample files are read as 0_vis.npy ... 399_vis.npy

net = SimplePool().to(device)
big_viz = []
with torch.no_grad():  # pure preprocessing: no autograd graph needed
    for i in range(N_SAMPLES):
        vis_np = np.load(data_dir + str(i) + '_vis.npy')
        vis = torch.as_tensor(vis_np, dtype=torch.float32, device=device)
        # Add a batch axis and move the last axis to position 2, the axis
        # SimplePool leaves unpooled. Each sample is presumably
        # (C, H, W, T), i.e. time last -- TODO confirm.
        vis = vis.unsqueeze(0).permute(0, 1, 4, 2, 3)
        out = net(vis)  # call the module, not .forward(), so hooks run
        # Squeeze only the batch axis: a bare squeeze() would also drop any
        # other size-1 axis and break the 4-axis permute below.
        out = out.squeeze(0).permute(0, 2, 3, 1)
        big_viz.append(out.cpu())

In [9]:
# Join the per-sample pooled tensors along a new leading (sample) axis.
big_viz_data = torch.stack(big_viz, dim=0)

In [10]:
# Display the stacked tensor's shape as this cell's output.
big_viz_data.size()

torch.Size([400, 2, 63, 50, 325])

In [11]:
# The pooled_vis/ output directory may not exist yet; create it so
# torch.save does not fail with FileNotFoundError.
os.makedirs(path, exist_ok=True)
torch.save(big_viz_data, path + 'ds_vis.pt')

In [20]:


# NOTE(review): dead code kept for reference. This is an earlier export path
# that pooled each sample via the per-split training/testing loaders and wrote
# one '<index>_vis.npy' file per sample. The pooling cell earlier in this
# notebook instead stacks all samples into a single tensor saved as
# 'ds_vis.pt'. Consider deleting this cell.
# The 0..4 printed below are stale output from when the loop body was active.
# for k in range(5):
#     print(k)
#     trainLoader = training_loader[k]
#     testLoader = testing_loader[k]
#     # Define model
#     net = SimplePool().to(device)
#     # Create snn loss instance.
    
#     for i, (input_tact_left, _, label, inpInd) in enumerate(trainLoader.dataset, 0):
#         #print(input_tact_left.shape)
#         input_tact_left = input_tact_left.unsqueeze(0)
#         input_tact_left = input_tact_left.to(device)
#         input_tact_left = input_tact_left.permute(0,1,4,2,3)
#         out = net.forward(input_tact_left)

#         out = out.squeeze().permute(0,2,3,1)
#         np.save(path + str(inpInd) +'_vis.npy', out.cpu().numpy())

#     for i, (input_tact_left, _, label, inpInd) in enumerate(testLoader.dataset, 0):
#         #print(input_tact_left.shape)
#         input_tact_left = input_tact_left.unsqueeze(0)
#         input_tact_left = input_tact_left.to(device)
#         input_tact_left = input_tact_left.permute(0,1,4,2,3)
#         out = net.forward(input_tact_left)
#         out = out.squeeze().permute(0,2,3,1)
#         np.save(path + str(inpInd) +'_vis.npy', out.cpu().numpy())


0
1
2
3
4
