In [4]:
import torch
from torch.utils.data import Dataset
import numpy as np
from pathlib import Path
import slayerSNN as snn

In [None]:
class ViTacMMDataset(Dataset):
    def __init__(self, path, sample_file, output_size, spike=True, rectangular=False):
        self.path = path
        self.samples = np.loadtxt(Path(path) / sample_file).astype("int")
        self.output_size = output_size
        
        self.spike=spike
        if spike:
            self.ds_vis = torch.load(Path(path) / "ds_vis.pt")
        else:
            self.ds_vis = torch.load(Path(path) / "ds_vis_non_spike.pt")
            
        self.rectangular = rectangular
        if rectangular:
            self.right_tact = torch.load(Path(path) / "tac_right.pt")
            self.left_tact = torch.load(Path(path) / "tac_left.pt")
        else:
            tact = torch.load(Path(path) / "tact.pt")
            self.tact = tact.reshape(tact.shape[0], -1, 1, 1, tact.shape[-1])

    def __getitem__(self, index):
        input_index = self.samples[index, 0]
        class_label = self.samples[index, 1]
        target_class = torch.zeros((self.output_size, 1, 1, 1))
        target_class[class_label, ...] = 1

        if rectangular:
            return (
                self.right_tact[input_index],
                self.left_tact[input_index],
                self.ds_vis[input_index],
                class_label,
            )
        else:
            return (
                self.tact[input_index],
                self.ds_vis[input_index],
                target_class,
                class_label,
            )

    def __len__(self):
        return self.samples.shape[0]

In [15]:
train_dataset = ViTacDataset(
    path='../data_VT_SNN_new/', sample_file=f"test_80_20_1.txt", output_size=20
)
train_dataset2 = ViTacDataset(
    path='../data_VT_SNN_new/', sample_file=f"test_80_20_1.txt", output_size=20, rectangular=True
)

In [16]:
a, b, c, d = train_dataset[0]
a2, b2, c2, d = train_dataset2[0]

In [17]:
a.shape, b.shape, c.shape, d

(torch.Size([156, 1, 1, 325]), torch.Size([]), torch.Size([20, 1, 1, 1]), 0)

In [18]:
a2.shape, b2.shape, c.shape, d

(torch.Size([2, 9, 7, 325]),
 torch.Size([2, 9, 7, 325]),
 torch.Size([20, 1, 1, 1]),
 0)

In [23]:
torch.unique(b)

tensor([0])

In [24]:
torch.unique(a2), torch.unique(b2)

(tensor([0., 1.]), tensor([0., 1.]))

In [29]:
a = torch.load(Path("../data_VT_SNN/tactile_rectangular/") / "tac_left.pt")

In [30]:
b = torch.load(Path("../data_VT_SNN_new/") / "tac_left.pt")

In [31]:
a == b

tensor([[[[[True, True, True,  ..., True, True, True],
           [True, True, True,  ..., True, True, True],
           [True, True, True,  ..., True, True, True],
           ...,
           [True, True, True,  ..., True, True, True],
           [True, True, True,  ..., True, True, True],
           [True, True, True,  ..., True, True, True]],

          [[True, True, True,  ..., True, True, True],
           [True, True, True,  ..., True, True, True],
           [True, True, True,  ..., True, True, True],
           ...,
           [True, True, True,  ..., True, True, True],
           [True, True, True,  ..., True, True, True],
           [True, True, True,  ..., True, True, True]],

          [[True, True, True,  ..., True, True, True],
           [True, True, True,  ..., True, True, True],
           [True, True, True,  ..., True, True, True],
           ...,
           [True, True, True,  ..., True, True, True],
           [True, True, True,  ..., True, True, True],
           [T