In [13]:
import os
import time
import psutil
import glob
import matplotlib.pyplot as plt
import seaborn as sns
import nni
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torchmetrics.classification import ConfusionMatrix, Precision, Recall, F1Score
from pytorch_lightning.loggers import TensorBoardLogger
from snntorch import surrogate
import snntorch as snn
from torch.utils.data import DataLoader, random_split, Subset
from typing import List
import random

from Server_DataLoader import CustomDataModule, CLASSES
from utils_metrics import compute_sharpness, compute_gradient_noise, visualize_loss_landscape

In [14]:
# Allow reduced-precision internal math for float32 matrix multiplications.
# With 'medium', PyTorch may use bfloat16 (not fp16) internally when a fast kernel is available, trading some precision for speed on supporting NVIDIA GPUs.
torch.set_float32_matmul_precision('medium')

In [15]:
# Hyperparameters and Constants
# Layer widths: fc1 maps input_size -> hidden_size, fc2/fc3 are hidden_size -> hidden_size,
# fc4 maps hidden_size -> output_size (output_size is also the num_classes of the metrics).
input_size = 16
hidden_size = 24
output_size = 4

# NOTE(review): the active Fake_Quantize_n_bit/QuantLinear in this notebook use their own
# default of 16 and do not read this variable; only the commented-out quantizer references it.
bit_width = 16 # n-bit quantization
num_epochs = 1   # passed to pl.Trainer(max_epochs=...)
batch_size = 32  # passed to CustomDataModule

# StepLR schedule: the learning rate is multiplied by scheduler_gamma every scheduler_step_size scheduler steps.
scheduler_step_size = 20
scheduler_gamma = 0.4

# Thresholds handed to snn.Leaky for the output layer and the three hidden layers respectively.
output_threshold = 1e7
hidden_threshold = 1

#num_workers = max(1, os.cpu_count() - 1)
num_workers = 1  # passed to CustomDataModule
# reset_mechanism values handed to snn.Leaky for the hidden layers and the output layer.
hidden_reset_mechanism = 'subtract'
output_reset_mechanism = 'none'

# Constant added to the label tensor to form the MSE target for the output membrane potential.
Vmem_shift_for_MSELoss = 0.2

# Global seed: applied via pl.seed_everything and also passed to CustomDataModule.
SEED = 42
pl.seed_everything(SEED) 

Seed set to 42


42

In [16]:
def generate_tau_beta_values(hidden_size, output_size):
    """Build per-neuron decay factors (beta) for the three hidden layers and the output layer.

    Each hidden layer gets time constants tau drawn from the powers of two
    2**1 .. 2**n (n = 2, 4, 8 for hidden layers 1, 2, 3), laid out in
    contiguous groups across the ``hidden_size`` neurons. The output layer
    uses tau = 10 for every neuron. Betas are computed as exp(-delta_t / tau)
    with delta_t = 1.

    Args:
        hidden_size: Number of neurons in each hidden layer.
        output_size: Number of neurons in the output layer.

    Returns:
        Tuple ``(beta_hidden_1, beta_hidden_2, beta_hidden_3, beta_output)`` of
        float32 tensors with lengths hidden_size, hidden_size, hidden_size and
        output_size.
    """
    delta_t = 1

    def create_power_vector(n, size):
        # Always return exactly `size` entries. When `size` is not a multiple
        # of `n`, the first `size % n` powers receive one extra neuron, so the
        # vector never comes out shorter than the layer it is meant for.
        powers = [2 ** i for i in range(1, n + 1)]
        base, extra = divmod(size, n)
        counts = [base + 1 if i < extra else base for i in range(n)]
        return np.repeat(powers, counts)

    def tau_to_beta(tau):
        return torch.exp(-torch.tensor(delta_t) / torch.tensor(tau, dtype=torch.float32))

    beta_hidden_1 = tau_to_beta(create_power_vector(n=2, size=hidden_size))
    beta_hidden_2 = tau_to_beta(create_power_vector(n=4, size=hidden_size))
    beta_hidden_3 = tau_to_beta(create_power_vector(n=8, size=hidden_size))

    beta_output = tau_to_beta(np.repeat(10, output_size))

    return beta_hidden_1, beta_hidden_2, beta_hidden_3, beta_output

In [17]:
def get_nni_params():
    """Return the tunable hyperparameters, with NNI tuner values overriding the defaults.

    Returns:
        dict with keys 'learning_rate', 'optimizer_betas' and
        'fast_sigmoid_slope', plus any additional keys supplied by the tuner.
    """
    params = {
        'learning_rate': 0.001,
        'optimizer_betas': (0.9, 0.99),
        'fast_sigmoid_slope': 10,
    }
    tuner_params = nni.get_next_parameter()
    # get_next_parameter() may hand back None or an empty dict when no tuner
    # configuration is available; keep the defaults instead of crashing in update().
    if tuner_params:
        params.update(tuner_params)
    return params

In [18]:
# # Quantization with fixed range.

# class Fake_Quantize_n_bit(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.levels = 2 ** bit_width
#         self.w_min = -0.5
#         self.w_max = 1.0

#     def forward(self, input):
#         # Quantization steps:
#         # Clamp to [w_min, w_max] to avoid going out of range.
#         input_clamped = torch.clamp(input, self.w_min, self.w_max)
#         scale = (self.w_max - self.w_min) / (self.levels - 1)
#         quant_indices = torch.round((input_clamped - self.w_min) / scale)
#         quant_w = quant_indices * scale + self.w_min
#         return quant_w


In [19]:
# Quantization with dynamic range, based on the min/max of the current layer.

class Fake_Quantize_n_bit(nn.Module):
    """Fake-quantize a tensor onto ``2 ** bit_width`` evenly spaced levels.

    The quantization range is taken from the tensor itself (its current
    min and max). Rounding uses a straight-through estimator: the forward
    value is rounded, while the gradient passes through as if no rounding
    had happened.

    Args:
        bit_width: Number of bits; must be at least 1.

    Raises:
        ValueError: If ``bit_width`` is smaller than 1.
    """

    def __init__(self, bit_width=16):
        super().__init__()
        if bit_width < 1:
            # With fewer than two levels the scale denominator (levels - 1) would be zero.
            raise ValueError(f"bit_width must be >= 1, got {bit_width}")
        self.levels = 2 ** bit_width

    def forward(self, input):
        """Return ``input`` snapped to the quantization grid (same shape as ``input``)."""
        # Nothing to quantize; min()/max() would raise on an empty tensor.
        if input.numel() == 0:
            return input

        # Compute dynamic min and max from the current tensor
        w_min = input.min().item()
        w_max = input.max().item()

        # Compute scale based on dynamic range
        scale = (w_max - w_min) / (self.levels - 1)
        if scale == 0:
            # All elements are equal: every value already sits on the single
            # grid point. Dividing by a zero scale would produce NaNs.
            return input

        x = (input - w_min) / scale
        # Straight-through rounding: forward value is round(x), gradient is that of x.
        quant_indices = x + torch.round(x).detach() - x.detach()
        # Map back to the quantized range
        quant_w = quant_indices * scale + w_min
        return quant_w
    

In [20]:
class QuantLinear(nn.Linear):
    """Linear layer whose weight is fake-quantized on every forward pass.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Whether the layer has an additive bias (not quantized).
        bit_width: Bit width handed to the weight quantizer. Defaults to 16,
            which matches the quantizer's own default.
    """

    def __init__(self, in_features, out_features, bias=False, bit_width=16):
        super().__init__(in_features, out_features, bias=bias)
        self.fake_quant = Fake_Quantize_n_bit(bit_width=bit_width)

    def forward(self, input):
        """Apply the linear map using the quantized view of ``self.weight``."""
        quant_weight = self.fake_quant(self.weight)
        return nn.functional.linear(input, quant_weight, self.bias)

def finalize_quantization(model):
    """Overwrite the weights of every quantized layer with their quantized values, in place.

    A layer counts as quantized when it has both a ``fake_quant`` and a
    ``weight`` attribute. The model is put into eval mode while the weights
    are rewritten, and every module's original train/eval flag is restored
    afterwards, so calling this in the middle of training does not leave the
    model stuck in eval mode.

    Args:
        model: The ``nn.Module`` whose quantized layers should be snapped to the grid.
    """
    # Remember the mode of each module individually so mixed train/eval
    # configurations are restored exactly.
    saved_modes = [(module, module.training) for module in model.modules()]
    try:
        model.eval()
        with torch.no_grad():
            for module in model.modules():
                if hasattr(module, 'fake_quant') and hasattr(module, 'weight'):
                    quantized_weight = module.fake_quant(module.weight)
                    module.weight.data.copy_(quantized_weight)
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

In [21]:
class SNNQUT(nn.Module):
    """Four-layer feed-forward spiking network built from QuantLinear + snn.Leaky pairs.

    Layers 1-3 are hidden layers of width ``hidden_size`` sharing the hidden
    reset mechanism, threshold and fast-sigmoid surrogate gradient; layer 4 is
    the output layer with its own reset mechanism and threshold. All linear
    layers are bias-free and Xavier-normal initialised.
    """

    def __init__(self, input_size, hidden_size, output_size, 
                 beta_hidden_1, beta_hidden_2, beta_hidden_3, beta_output, 
                 hidden_reset_mechanism, output_reset_mechanism, 
                 hidden_threshold, output_threshold, fast_sigmoid_slope):
        super().__init__()

        def hidden_lif(beta):
            # All hidden neurons share reset mechanism, threshold and surrogate slope;
            # only the decay vector differs between layers.
            return snn.Leaky(
                beta=beta,
                reset_mechanism=hidden_reset_mechanism,
                threshold=hidden_threshold,
                spike_grad=surrogate.fast_sigmoid(slope=fast_sigmoid_slope),
            )

        # Replace QuantLinear with nn.Linear below to train without fake quantization.
        self.fc1 = QuantLinear(input_size, hidden_size, bias=False)
        self.lif1 = hidden_lif(beta_hidden_1)

        self.fc2 = QuantLinear(hidden_size, hidden_size, bias=False)
        self.lif2 = hidden_lif(beta_hidden_2)

        self.fc3 = QuantLinear(hidden_size, hidden_size, bias=False)
        self.lif3 = hidden_lif(beta_hidden_3)

        self.fc4 = QuantLinear(hidden_size, output_size, bias=False)
        self.lif4 = snn.Leaky(
            beta=beta_output,
            reset_mechanism=output_reset_mechanism,
            threshold=output_threshold,
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialise the four linear weights with Xavier-normal values, in layer order."""
        for linear in (self.fc1, self.fc2, self.fc3, self.fc4):
            nn.init.xavier_normal_(linear.weight)

    def forward(self, x):
        """Run the network over every time step of ``x``.

        Args:
            x: 3-D tensor indexed as (batch, time step, feature); cast to float.

        Returns:
            Tuple ``(spk1, spk2, spk3, spk4, mem4)``: the spikes of each layer
            and the output-layer membrane potential, each stacked along a new
            leading time dimension.
        """
        x = x.float()
        n_batch, n_steps, _ = x.shape

        stages = (
            (self.fc1, self.lif1),
            (self.fc2, self.lif2),
            (self.fc3, self.lif3),
            (self.fc4, self.lif4),
        )

        # One zero-initialised membrane state per layer, on the input's device.
        membranes = [
            torch.zeros(n_batch, linear.out_features, device=x.device)
            for linear, _ in stages
        ]
        spike_history = [[] for _ in stages]
        output_mem_history = []

        for t in range(n_steps):
            signal = x[:, t, :]
            # The spikes of each layer are the input current source of the next one.
            for idx, (linear, neuron) in enumerate(stages):
                signal, membranes[idx] = neuron(linear(signal), membranes[idx])
                spike_history[idx].append(signal)
            output_mem_history.append(membranes[-1])

        stacked_spikes = [torch.stack(history, dim=0) for history in spike_history]
        return (*stacked_spikes, torch.stack(output_mem_history, dim=0))
    

In [22]:
class Lightning_SNNQUT(pl.LightningModule):
    """LightningModule that trains, validates and tests an SNNQUT network.

    The loss is the MSE between the output membrane potential at every time
    step and the (shifted) label vector. Training additionally applies a
    spike-count penalty. The module logs weights, gradients, timing, memory
    and learning-rate information to the attached logger and re-quantizes
    the weights after every backward pass.
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        output_size,
        beta_hidden_1,
        beta_hidden_2,
        beta_hidden_3,
        beta_output,
        hidden_reset_mechanism,
        output_reset_mechanism,
        hidden_threshold,
        output_threshold,
        learning_rate,
        scheduler_step_size,
        scheduler_gamma,
        optimizer_betas,
        fast_sigmoid_slope,
    ):
        super().__init__()
        # Captures every constructor argument into self.hparams.
        self.save_hyperparameters()
        self.model = SNNQUT(
            input_size=self.hparams.input_size,
            hidden_size=self.hparams.hidden_size,
            output_size=self.hparams.output_size,
            beta_hidden_1=beta_hidden_1,
            beta_hidden_2=beta_hidden_2,
            beta_hidden_3=beta_hidden_3,
            beta_output=beta_output,
            hidden_reset_mechanism=hidden_reset_mechanism,
            output_reset_mechanism=output_reset_mechanism,
            output_threshold=output_threshold,
            hidden_threshold=hidden_threshold,
            fast_sigmoid_slope=fast_sigmoid_slope,
        )

        self.loss_function = nn.MSELoss()
        self.train_confmat = ConfusionMatrix(task='multiclass', num_classes=self.hparams.output_size)
        self.val_confmat = ConfusionMatrix(task='multiclass', num_classes=self.hparams.output_size)
        self.test_confmat = ConfusionMatrix(task='multiclass', num_classes=self.hparams.output_size)

        # Per-class precision, recall, F1 (average=None), accumulated over validation batches.
        self.precision_metric = Precision(task='multiclass', num_classes=self.hparams.output_size, average=None)
        self.recall_metric = Recall(task='multiclass', num_classes=self.hparams.output_size, average=None)
        self.f1_metric = F1Score(task='multiclass', num_classes=self.hparams.output_size, average=None)

        # Wall-clock start of the current training epoch; set in on_train_epoch_start.
        self.epoch_start_time = None

    def forward(self, x):
        """Delegate to the wrapped SNNQUT; returns (spk1, spk2, spk3, spk4, mem4)."""
        return self.model(x)

    def compute_loss_and_metrics(self, spk4_rec, mem4_rec, labels, mode='Vmem', dict_spkcount=None):
        """Compute the loss, batch accuracy and class predictions.

        Args:
            spk4_rec: Output-layer spikes stacked over time (time first).
            mem4_rec: Output-layer membrane potentials stacked over time (time first).
            labels: 2-D label tensor (one row per sample); the target class is
                the arg-max of each row.
            mode: 'Vmem' to classify from the time-summed membrane potential,
                'spike' to classify from the time-summed output spikes.
            dict_spkcount: Optional dict with keys 'spk1'..'spk4' holding the
                total spike count of each layer (tensors or numbers). When
                given, ``(sum of counts) ** 2 * 1e-12`` is added to the loss.

        Returns:
            Tuple ``(loss, accuracy, predicted, targets)`` where ``accuracy``
            is a Python float in [0, 1].

        Raises:
            ValueError: If ``mode`` is neither 'Vmem' nor 'spike'.
        """
        # Repeat the labels along the time axis so every time step is pulled towards the target.
        labels_expanded = labels.unsqueeze(0).expand(mem4_rec.size(0), -1, -1)
        loss = self.loss_function(mem4_rec, (labels_expanded + Vmem_shift_for_MSELoss))
        if dict_spkcount is not None:
            total_spikes = (dict_spkcount['spk1'] + dict_spkcount['spk2']
                            + dict_spkcount['spk3'] + dict_spkcount['spk4'])
            # Out-of-place add keeps the MSE node untouched for autograd.
            loss = loss + total_spikes ** 2 * 1e-12

        if mode == 'Vmem':
            final_out = mem4_rec.sum(0)
        elif mode == 'spike':
            final_out = spk4_rec.sum(0)
        else:
            raise ValueError(f"Invalid mode: {mode}, use 'Vmem' or 'spike'")

        _, predicted = final_out.max(-1)
        _, targets = labels.max(-1)

        correct = predicted.eq(targets).sum().item()
        total = targets.numel()
        accuracy = correct / total if total > 0 else 0.0

        return loss, accuracy, predicted, targets

    def training_step(self, batch, batch_idx):
        """Run one training batch and return the regularized loss."""
        inputs, labels = batch
        spk1, spk2, spk3, spk4, mem4 = self(inputs)
        # Keep the spike counts as tensors: converting them with .item() would
        # detach the penalty from the graph, leaving it a constant offset
        # that contributes no gradient.
        dict_spkcount = {
            'spk1': spk1.sum(),
            'spk2': spk2.sum(),
            'spk3': spk3.sum(),
            'spk4': spk4.sum(),
        }
        loss, accuracy, predicted, targets = self.compute_loss_and_metrics(
            spk4, mem4, labels, mode='Vmem', dict_spkcount=dict_spkcount)
        self.train_confmat.update(predicted, targets)
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_accuracy', accuracy*100, on_epoch=True, prog_bar=True)
        return loss

    def on_train_epoch_start(self):
        """Record the epoch start time for the per-epoch runtime log."""
        self.epoch_start_time = time.time()

    def on_train_epoch_end(self):
        """Log weight/gradient histograms, epoch runtime, memory use and learning rate."""
        # Weight distributions logging
        for name, param in self.model.named_parameters():
            if 'weight' in name:
                self.logger.experiment.add_histogram(f'weights/{name}', param, self.current_epoch)

        # Gradient histograms logging
        for name, param in self.named_parameters():
            if param.grad is not None:
                self.logger.experiment.add_histogram(f'grads/{name}', param.grad, self.current_epoch)

        # Runtime per epoch logging
        epoch_time = time.time() - self.epoch_start_time
        self.log('epoch_time', epoch_time, on_epoch=True)

        # Memory usage logging
        if torch.cuda.is_available():
            mem_alloc = torch.cuda.memory_allocated()/(1024**2)
            self.log('gpu_memory_MB', mem_alloc, on_epoch=True)
        process = psutil.Process(os.getpid())
        cpu_mem = process.memory_info().rss / (1024**2)
        self.log('cpu_memory_MB', cpu_mem, on_epoch=True)

        # Learning rate logging
        optimizer = self.optimizers()
        lr = optimizer.param_groups[0]['lr']
        self.log('learning_rate', lr, on_epoch=True)

        self.train_confmat.reset()

    def on_validation_epoch_end(self):
        """Log per-class validation metrics and output-layer statistics, then reset metric state."""
        self.val_confmat.reset()

        # Per-class precision / recall / F1 accumulated in validation_step.
        # Resetting afterwards keeps sanity-check batches and earlier epochs
        # from leaking into later results.
        per_class_metrics = (
            ('precision', self.precision_metric),
            ('recall', self.recall_metric),
            ('f1', self.f1_metric),
        )
        for metric_name, metric in per_class_metrics:
            for class_idx, value in enumerate(metric.compute()):
                self.log(f'val_{metric_name}_class{class_idx}', value)
            metric.reset()

        # Logging membrane potentials and spiking activity for a small batch
        sample_batch = next(iter(self.trainer.datamodule.val_dataloader()))
        sample_input, sample_labels = sample_batch
        sample_input = sample_input.to(self.device)

        with torch.no_grad():
            spk1, spk2, spk3, spk4, mem4 = self(sample_input)

            self.logger.experiment.add_histogram('mem4_distribution', mem4.cpu(), self.current_epoch)
            self.logger.experiment.add_histogram('spk4_distribution', spk4.cpu().float(), self.current_epoch)
            self.logger.experiment.add_scalar('spk1_count', spk1.sum().cpu(), self.current_epoch)
            self.logger.experiment.add_scalar('spk2_count', spk2.sum().cpu(), self.current_epoch)
            self.logger.experiment.add_scalar('spk3_count', spk3.sum().cpu(), self.current_epoch)
            self.logger.experiment.add_scalar('spk4_count', spk4.sum().cpu(), self.current_epoch)

    def on_after_backward(self):
        """Re-quantize the weights and log the effective step size (lr * global grad norm)."""
        # Re-quantize the model weights after backward to ensure no drift from quantized levels
        finalize_quantization(self.model)

        # Logging effective step size = lr * grad_norm
        optimizer = self.optimizers()
        lr = optimizer.param_groups[0]['lr']
        total_grad_norm = 0.0
        for p in self.parameters():
            if p.grad is not None:
                param_norm = p.grad.data.norm(2)
                total_grad_norm += param_norm.item()**2
        total_grad_norm = total_grad_norm**0.5
        effective_step_size = lr * total_grad_norm
        self.log('effective_step_size', effective_step_size, on_epoch=True)

    def on_fit_end(self):
        """Quantize the weights one last time and save the wrapped model's state_dict."""
        # Final quantization of weights and model saving
        finalize_quantization(self.model)
        model_dir = os.path.join(self.logger.save_dir, self.logger.name)
        # Make sure the target directory exists before writing the checkpoint.
        os.makedirs(model_dir, exist_ok=True)
        model_path = os.path.join(model_dir, f"final_quantized_model_epoch{self.current_epoch}.pt")
        torch.save(self.model.state_dict(), model_path)

        # # Compute sharpness and gradient noise on a small batch
        # train_loader = self.trainer.datamodule.train_dataloader()
        # inputs, targets = next(iter(train_loader))
        # inputs, targets = inputs.to(self.device), targets.to(self.device)

        # sharpness = compute_sharpness(self.model, inputs, targets, self.loss_function, epsilon=1e-3)
        # self.logger.experiment.add_scalar('sharpness', sharpness, self.current_epoch)

        # gradient_noise = compute_gradient_noise(self.model, train_loader, self.loss_function, device=self.device, num_batches=5)
        # self.logger.experiment.add_scalar('gradient_noise', gradient_noise, self.current_epoch)

        # # Visualize loss landscape
        # fig = visualize_loss_landscape(self.model, inputs, targets, self.loss_function, d1_scale=0.05, d2_scale=0.05, steps=5)
        # self.logger.experiment.add_figure('loss_landscape', fig, self.current_epoch)
        # plt.close(fig)

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch, update the validation metrics and log loss/accuracy."""
        inputs, labels = batch
        spk1, spk2, spk3, spk4, mem4 = self(inputs)
        loss, accuracy, predicted, targets = self.compute_loss_and_metrics(spk4, mem4, labels, mode='Vmem')
        self.val_confmat.update(predicted, targets)
        self.precision_metric.update(predicted, targets)
        self.recall_metric.update(predicted, targets)
        self.f1_metric.update(predicted, targets)

        if batch_idx == 0:  # once per epoch print some info to see what's happening
            final_out = mem4.sum(0)
            print("final_out:", final_out[:5])    # print only first 5 samples
            print("targets:", labels[:5])
            print("predicted classes:", predicted[:10])
            print("true classes:", targets[:10])

        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_accuracy', accuracy * 100, on_epoch=True, prog_bar=True)
        return {'val_loss': loss, 'val_accuracy': accuracy}

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch and log loss/accuracy."""
        inputs, labels = batch
        spk1, spk2, spk3, spk4, mem4 = self(inputs)
        # Classify from the membrane potential, exactly like training and
        # validation. The loss is defined on the membrane, and with a very high
        # output threshold the output spike counts stay at zero, which would
        # make spike-based predictions degenerate.
        loss, accuracy, predicted, targets = self.compute_loss_and_metrics(spk4, mem4, labels, mode='Vmem')
        self.test_confmat.update(predicted, targets)
        self.log('test_loss', loss, on_epoch=True, prog_bar=True)
        self.log('test_accuracy', accuracy * 100, on_epoch=True, prog_bar=True)
        return {'test_loss': loss, 'test_accuracy': accuracy}

    def configure_optimizers(self):
        """Create the Adamax optimizer and its StepLR schedule from the saved hyperparameters."""
        optimizer = optim.Adamax(
            self.parameters(),
            lr=self.hparams.learning_rate,
            betas=self.hparams.optimizer_betas,
        )
        scheduler = optim.lr_scheduler.StepLR(
            optimizer,
            step_size=self.hparams.scheduler_step_size,
            gamma=self.hparams.scheduler_gamma,
        )
        return [optimizer], [scheduler]

In [23]:
def main():
    """Train the quantized SNN once and report the final validation accuracy to NNI.

    Steps: fetch (possibly tuner-supplied) hyperparameters, build the decay
    vectors, set up the data module, fit the Lightning model with a
    TensorBoard logger, run a validation pass and report 'val_accuracy'.
    """
    tuned = get_nni_params()
    beta_hidden_1, beta_hidden_2, beta_hidden_3, beta_output = generate_tau_beta_values(
        hidden_size, output_size
    )

    datamodule_kwargs = dict(
        train_val_dir='data/YES_COCHLEA_DATASET/TRAIN_VAL',
        test_dir='data/YES_COCHLEA_DATASET/TEST',
        batch_size=batch_size,
        num_workers=num_workers,
        seed=SEED,
    )
    data_module = CustomDataModule(**datamodule_kwargs)
    data_module.setup('fit')

    # Network shape and neuron settings come from the notebook-level constants;
    # learning rate, optimizer betas and surrogate slope come from the tuner.
    model_kwargs = dict(
        input_size=input_size,
        hidden_size=hidden_size,
        output_size=output_size,
        beta_hidden_1=beta_hidden_1,
        beta_hidden_2=beta_hidden_2,
        beta_hidden_3=beta_hidden_3,
        beta_output=beta_output,
        hidden_reset_mechanism=hidden_reset_mechanism,
        output_reset_mechanism=output_reset_mechanism,
        hidden_threshold=hidden_threshold,
        output_threshold=output_threshold,
        scheduler_step_size=scheduler_step_size,
        scheduler_gamma=scheduler_gamma,
        learning_rate=tuned['learning_rate'],
        optimizer_betas=tuned['optimizer_betas'],
        fast_sigmoid_slope=tuned['fast_sigmoid_slope'],
    )
    model = Lightning_SNNQUT(**model_kwargs)

    tb_logger = TensorBoardLogger(save_dir='logs', name='Mikel_LIF_quant_5bit')
    trainer = pl.Trainer(max_epochs=num_epochs, log_every_n_steps=10, logger=tb_logger)

    trainer.fit(model, datamodule=data_module)

    # The on_fit_end hook has already written the final quantized weights to
    # disk at this point, so a separate test script can pick them up later.

    # One more validation pass so the reported metric reflects the final weights.
    trainer.validate(model, datamodule=data_module)
    missing_metric = torch.tensor(0)
    val_accuracy = trainer.callback_metrics.get('val_accuracy', missing_metric).item()

    nni.report_final_result(val_accuracy)

In [None]:
# Entry point: runs the full train + validate + NNI report pipeline defined in main().
if __name__ == '__main__':
    main()

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type                      | Params | Mode 
-----------------------------------------------------------------------
0 | model            | SNNQUT                    | 1.6 K  | train
1 | loss_function    | MSELoss                   | 0      | train
2 | train_confmat    | MulticlassConfusionMatrix | 0      | train
3 | val_confmat      | MulticlassConfusionMatrix | 0      | train
4 | test_confmat     | MulticlassConfusionMatrix | 0      | train
5 | precision_metric | MulticlassPrecision       | 0      | train
6 | recall_metric    | MulticlassRecall          | 0      | train
7 | f1_metric        | MulticlassF1Score         | 0      | train
-----------------------------------------------------------------------
1.6 K     Trainable params
0         Non-trainable params
1.6 K     Total params
0.007     Total estima

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]final_out: tensor([[-1503.1726,  1637.3700,  1434.5337,  5215.9688],
        [ -936.0920,  1092.3643,  1480.9445,  4592.2100],
        [-2756.4077,  3505.3445,  1363.5493,  7212.2573],
        [-1391.9362,  1573.7344,  1605.1852,  5164.8145],
        [-1574.7911,  1828.5530,  1567.7847,  5513.4351]], device='cuda:0')
targets: tensor([[0., 1., 0., 0.],
        [0., 0., 0., 1.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [1., 0., 0., 0.]], device='cuda:0')
predicted classes: tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')
true classes: tensor([1, 3, 0, 1, 0, 3, 1, 2, 0, 1], device='cuda:0')
Epoch 0:   0%|          | 0/576 [00:00<?, ?it/s]                           loss: tensor(15.6338, device='cuda:0', grad_fn=<MseLossBackward0>)
reg 0.455279465536
Epoch 0:   0%|          | 1/576 [00:01<09:50,  0.97it/s, v_num=26, train_accuracy_step=18.80]loss: tensor(11.1601, device='cuda:0', grad_fn=<MseLossBac

In [None]:
# Load the TensorBoard notebook extension and open the dashboard on the 'logs'
# directory, which is the save_dir given to TensorBoardLogger in main().
%load_ext tensorboard
%tensorboard --logdir logs

# TESTS

In [14]:
# import torch
# import matplotlib.pyplot as plt

# # Step 1: Generate a random normal (Gaussian) distribution of values
# torch.manual_seed(0)  # For reproducibility
# original_values = torch.randn(10000) * 0.2 + 0.5  # mean=0.5, std=0.2

# # Step 2: Define a 5-bit quantization function
# def quantize_5bit(x, w_min=0.001, w_max=1.0, levels=16):
#     # Ensure x is within the defined range
#     x_clamped = torch.clamp(x, w_min, w_max)
#     scale = (w_max - w_min) / (levels - 1)
#     # Map to quantization indices
#     quant_indices = torch.round((x_clamped - w_min) / scale)
#     # Map back to original range
#     quant_x = quant_indices * scale + w_min
#     return quant_x

# # Step 3: Apply quantization
# quantized_values = quantize_5bit(original_values)

# # Step 4: Plot histograms
# fig, axes = plt.subplots(1, 2, figsize=(10, 4))

# axes[0].hist(original_values.numpy(), bins=50, color='blue', alpha=0.7, edgecolor='black')
# axes[0].set_title('Original Distribution')
# axes[0].set_xlabel('Value')
# axes[0].set_ylabel('Count')

# axes[1].hist(quantized_values.numpy(), bins=32, color='green', alpha=0.7, edgecolor='black')
# axes[1].set_title('Quantized (5-bit) Distribution')
# axes[1].set_xlabel('Value')
# axes[1].set_ylabel('Count')

# plt.tight_layout()
# plt.show()

In [15]:
# import os
# import glob
# import torch
# import numpy as np
# from torch.utils.data import Dataset, DataLoader

# # Define your classes here
# CLASSES = ["CAR", "STREET", "HOME", "CAFE"]

# def get_label_from_folder(folder_name):
#     # Extract label from folder name by splitting at '-'
#     # Adjust this logic as needed for your naming convention.
#     base = os.path.basename(folder_name)
#     label = base.split('-')[0]
#     return label

# def one_hot_encode(label):
#     idx = CLASSES.index(label)
#     vec = np.zeros(len(CLASSES), dtype=np.float32)
#     vec[idx] = 1.0
#     return vec

# class CustomSNNTrainValDataset(Dataset):
#     def __init__(self, root_dir, class_list=CLASSES, time_steps=1000, input_dim=16, debug=False):
#         self.root_dir = root_dir
#         self.class_list = class_list
#         self.time_steps = time_steps
#         self.input_dim = input_dim
#         self.debug = debug
#         self.samples = []

#         # Collect all CSV files and their labels
#         for folder in os.listdir(self.root_dir):
#             folder_path = os.path.join(self.root_dir, folder)
#             if not os.path.isdir(folder_path):
#                 continue
#             label = get_label_from_folder(folder_path)
#             if label not in self.class_list:
#                 continue
#             label_vec = one_hot_encode(label)
#             csv_files = glob.glob(os.path.join(folder_path, '*.csv'))
#             for csv_file in csv_files:
#                 data_array = self.load_csv(csv_file, self.time_steps, self.input_dim)
#                 self.samples.append((csv_file, data_array, label_vec))

#     def load_csv(self, csv_path, time_steps, input_dim):
#         data = np.loadtxt(csv_path, delimiter=',', dtype=np.float32)
#         if data.shape != (time_steps, input_dim):
#             raise ValueError(f"CSV {csv_path} shape mismatch: {data.shape}")
#         return data

#     def __len__(self):
#         return len(self.samples)

#     def __getitem__(self, idx):
#         csv_file, data, label = self.samples[idx]
#         if self.debug:
#             print(f"[DEBUG] Loading file: {csv_file}")
#             print(f"[DEBUG] Data shape: {data.shape}")
#             print(f"[DEBUG] Label: {label}")
#         return torch.tensor(data), torch.tensor(label)

# # Example usage:
# if __name__ == "__main__":
#     # Replace with your actual directory
#     train_val_dir = 'data/YES_COCHLEA_DATASET/TRAIN_VAL'
    
#     # Create the dataset with debug mode enabled
#     dataset = CustomSNNTrainValDataset(train_val_dir, debug=True)

#     # Wrap in DataLoader
#     dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

#     # Inspect some batches
#     # This will trigger prints from the dataset __getitem__ if debug=True
#     for batch_idx, (inputs, labels) in enumerate(dataloader):
#         print(f"Batch {batch_idx}:")
#         print(f"  inputs shape: {inputs.shape} (expected [batch_size, time_steps, input_dim])")
#         print(f"  labels shape: {labels.shape} (expected [batch_size, num_classes])")

#         # Print sample values
#         # Inputs is typically [batch, time_steps, input_dim]
#         # Labels is [batch, num_classes]
#         print(f"  Sample input[0, 0, :5]: {inputs[0, 0, :5]}")
#         print(f"  Sample label[0]: {labels[0]}")

#         # After a couple of batches, stop to avoid too much output
#         if batch_idx == 2:
#             break
