# TESTING SCRIPT
This script:
1. Loads the final quantized model from a saved checkpoint (a state dict).
2. Iterates over the test-dataset folders, loads each CSV file individually (the files are large), feeds it through the model, and reports the overall test accuracy.

In [None]:
import glob
import os

import numpy as np
import torch

from Server_DataLoader import CustomSNNTestDataset, CLASSES

# Root folder of the test split; passed to CustomSNNTestDataset.
test_dir = 'data/PROCESSED_YES_COCHLEA/CUT/TEST'
# Checkpoint file holding the model's state dict (loaded with torch.load and
# then model.load_state_dict). Point this at the run you want to evaluate.
model_path = 'logs/Mikel_LIF_quant_5bit/final_quantized_model_epoch1.pt'  # Example path

In [None]:
class FakeQuantize5bit(torch.nn.Module):
    """Snap values onto a uniform 32-level (5-bit) grid spanning [0.001, 1.0].

    Inputs are first clamped into [w_min, w_max], so every output is strictly
    positive. Each clamped value is then rounded to the nearest of ``levels``
    evenly spaced points; the endpoints w_min and w_max are both on the grid.
    The module has no learnable parameters or buffers.

    NOTE(review): torch.round has zero gradient almost everywhere, so no
    gradient reaches the input through this op. That is fine for inference.
    Training through it would need a straight-through estimator.
    """

    def __init__(self):
        super().__init__()
        self.levels = 2 ** 5  # 32 representable values
        self.w_min = 0.001
        self.w_max = 1.0

    def forward(self, input):
        lo, hi = self.w_min, self.w_max
        # Distance between neighbouring grid points.
        step = (hi - lo) / (self.levels - 1)
        # Clamping first keeps every grid index inside [0, levels - 1].
        idx = ((input.clamp(lo, hi) - lo) / step).round()
        return idx * step + lo

class QuantLinear(torch.nn.Linear):
    """``nn.Linear`` (bias off by default) that multiplies with fake-quantized weights.

    The full-precision ``weight`` parameter is stored unchanged, so the
    state_dict keys are those of a plain Linear. On every forward call the
    weight is passed through ``FakeQuantize5bit``, and that quantized copy is
    used in the matmul.
    """

    def __init__(self, in_features, out_features, bias=False):
        super().__init__(in_features, out_features, bias=bias)
        # Has no parameters of its own; it only rounds the weight.
        self.fake_quant = FakeQuantize5bit()

    def forward(self, input):
        linear = torch.nn.functional.linear
        w_q = self.fake_quant(self.weight)
        return linear(input, w_q, self.bias)

In [None]:
# SNNQUT definition, copied from the training code.
# It must match the trained model's structure exactly, or load_state_dict() will fail on mismatched keys or shapes.
from snntorch import surrogate
import snntorch as snn

class SNNQUT(torch.nn.Module):
    """Four-layer leaky integrate-and-fire network with fake-quantized weights.

    Pipeline per time step: fc1 -> lif1 -> fc2 -> lif2 -> fc3 -> lif3 -> fc4 -> lif4.
    Every ``fcN`` is a bias-free ``QuantLinear``. The three hidden LIF layers
    share the reset mechanism, threshold and fast-sigmoid surrogate gradient,
    and differ only in their ``beta``. The output layer gets its own reset
    mechanism and threshold, and no ``spike_grad`` argument is passed to it.

    The attribute names fc1..fc4 and lif1..lif4 form the state_dict keys, so
    they must stay identical to the training-time model.
    """

    def __init__(self, input_size, hidden_size, output_size, 
                 beta_hidden_1, beta_hidden_2, beta_hidden_3, beta_output, 
                 hidden_reset_mechanism, output_reset_mechanism, 
                 hidden_threshold, output_threshold, fast_sigmoid_slope):
        super().__init__()

        def hidden_lif(beta):
            # Hidden LIF layers are configured identically apart from beta.
            return snn.Leaky(
                beta=beta,
                reset_mechanism=hidden_reset_mechanism,
                threshold=hidden_threshold,
                spike_grad=surrogate.fast_sigmoid(slope=fast_sigmoid_slope),
            )

        self.fc1 = QuantLinear(input_size, hidden_size, bias=False)
        self.lif1 = hidden_lif(beta_hidden_1)

        self.fc2 = QuantLinear(hidden_size, hidden_size, bias=False)
        self.lif2 = hidden_lif(beta_hidden_2)

        self.fc3 = QuantLinear(hidden_size, hidden_size, bias=False)
        self.lif3 = hidden_lif(beta_hidden_3)

        self.fc4 = QuantLinear(hidden_size, output_size, bias=False)
        self.lif4 = snn.Leaky(
            beta=beta_output,
            reset_mechanism=output_reset_mechanism,
            threshold=output_threshold,
        )

    def forward(self, x):
        """Run the network over every time step of ``x``.

        Args:
            x: 3-D tensor indexed as x[batch, step, feature]. The shape unpack
               below rejects any other rank. It is cast to float32 first.

        Returns:
            (spikes, membranes): the output layer's spikes and membrane
            potentials per step, stacked along a new leading time dimension.
        """
        x = x.float()
        batch_size, time_steps, _ = x.shape
        device = x.device

        # Every layer starts each call from a zero membrane potential.
        mem1, mem2, mem3, mem4 = (
            torch.zeros(batch_size, layer.out_features, device=device)
            for layer in (self.fc1, self.fc2, self.fc3, self.fc4)
        )

        spikes_out = []
        membranes_out = []
        for t in range(time_steps):
            spk1, mem1 = self.lif1(self.fc1(x[:, t, :]), mem1)
            spk2, mem2 = self.lif2(self.fc2(spk1), mem2)
            spk3, mem3 = self.lif3(self.fc3(spk2), mem3)
            spk4, mem4 = self.lif4(self.fc4(spk3), mem4)
            spikes_out.append(spk4)
            membranes_out.append(mem4)

        return torch.stack(spikes_out, dim=0), torch.stack(membranes_out, dim=0)

def load_model(model_path, input_size=16, hidden_size=24, output_size=4,
               beta_hidden_1=None, beta_hidden_2=None, beta_hidden_3=None, beta_output=None,
               hidden_reset_mechanism='subtract', output_reset_mechanism='none',
               hidden_threshold=1, output_threshold=1e7, fast_sigmoid_slope=10):
    """Build an SNNQUT, load its weights from ``model_path`` and return it in eval mode.

    Args:
        model_path: File containing the model's state dict. It is loaded onto
            the CPU.
        input_size, hidden_size, output_size: Layer widths passed to SNNQUT.
        beta_hidden_1, beta_hidden_2, beta_hidden_3, beta_output: Per-neuron
            decay factors. Each one left as ``None`` falls back to the default
            schedule below, which is intended to match the training script.
            Explicitly passed values are used unchanged.
        hidden_reset_mechanism, output_reset_mechanism, hidden_threshold,
        output_threshold, fast_sigmoid_slope: Forwarded unchanged to SNNQUT.

    Default schedule (delta_t = 1, beta = exp(-delta_t / tau)):
        * hidden layers 1/2/3: taus 2, 4, ..., 2**n with n = 2/4/8, each tau
          repeated hidden_size // n times;
        * output layer: tau = 10 for every output neuron.

    Returns:
        The SNNQUT instance with the checkpoint loaded, switched to eval mode.

    Raises:
        ValueError: if a default hidden-layer schedule is needed and
            ``hidden_size`` is not a multiple of its tau count (2, 4 or 8).
            Otherwise the beta vector would be shorter than the layer.
    """
    delta_t = 1

    def create_power_vector(n, size):
        # n distinct time constants 2**1 .. 2**n, each repeated size // n times.
        if size % n != 0:
            raise ValueError(
                f"hidden_size={size} must be divisible by {n} to build the "
                f"default tau schedule; pass the beta values explicitly instead"
            )
        powers = [2 ** i for i in range(1, n + 1)]
        return np.repeat(powers, size // n)

    def beta_from_tau(tau):
        return torch.exp(-torch.tensor(delta_t) / torch.tensor(tau, dtype=torch.float32))

    # Fill in only the betas the caller left as None, so that explicitly
    # supplied values are never overwritten.
    if beta_hidden_1 is None:
        beta_hidden_1 = beta_from_tau(create_power_vector(n=2, size=hidden_size))
    if beta_hidden_2 is None:
        beta_hidden_2 = beta_from_tau(create_power_vector(n=4, size=hidden_size))
    if beta_hidden_3 is None:
        beta_hidden_3 = beta_from_tau(create_power_vector(n=8, size=hidden_size))
    if beta_output is None:
        beta_output = beta_from_tau(np.repeat(10, output_size))

    model = SNNQUT(
        input_size=input_size,
        hidden_size=hidden_size,
        output_size=output_size,
        beta_hidden_1=beta_hidden_1,
        beta_hidden_2=beta_hidden_2,
        beta_hidden_3=beta_hidden_3,
        beta_output=beta_output,
        hidden_reset_mechanism=hidden_reset_mechanism,
        output_reset_mechanism=output_reset_mechanism,
        hidden_threshold=hidden_threshold,
        output_threshold=output_threshold,
        fast_sigmoid_slope=fast_sigmoid_slope,
    )

    # NOTE(review): torch.load unpickles the file, which can run arbitrary code.
    # Only load trusted checkpoints. On PyTorch >= 1.13 consider passing
    # weights_only=True, which newer releases use by default.
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    return model

def _evaluate(model, dataset):
    """Run ``model`` on every CSV recording in ``dataset`` and count correct predictions.

    ``dataset`` is accessed by index: ``len(dataset)`` and ``dataset[i]`` must
    work, and each item must be a ``(folder_path, label)`` pair. Every
    ``*.csv`` file directly inside ``folder_path`` is one sample carrying that
    folder's label. A file is predicted by summing the model's second output
    (the output-layer membrane potentials) over its leading dimension, then
    taking the arg-max over the last dimension.

    Returns:
        (total_correct, total_samples) as ints.
    """
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for i in range(len(dataset)):
            folder_path, label = dataset[i]
            # The label is reduced with max(-1), so it is presumably a one-hot
            # tensor over the classes (TODO: confirm against
            # CustomSNNTestDataset). It is the same for every file in the
            # folder, so it is computed once here.
            _, targets = label.unsqueeze(0).max(-1)

            # Sort so the processing order is deterministic; glob's order is not.
            csv_files = sorted(glob.glob(os.path.join(folder_path, '*.csv')))
            for csv_file in csv_files:
                # Files can be large (the original note says (60000, 16)).
                # ndmin=2 keeps a single-row file 2-D, so the model still
                # receives a 3-D batch.
                data = np.loadtxt(csv_file, delimiter=',', dtype=np.float32, ndmin=2)
                inputs = torch.tensor(data).unsqueeze(0)  # (1, time_steps, input_dim)
                _, mem4 = model(inputs)
                final_out = mem4.sum(0)  # (batch, output_size)
                _, predicted = final_out.max(-1)
                total_correct += predicted.eq(targets).sum().item()
                total_samples += targets.numel()
    return total_correct, total_samples


if __name__ == "__main__":
    model = load_model(model_path)
    test_dataset = CustomSNNTestDataset(test_dir)

    total_correct, total_samples = _evaluate(model, test_dataset)

    accuracy = (total_correct / total_samples)*100 if total_samples > 0 else 0.0
    print(f"Test Accuracy: {accuracy:.2f}%")