<a href="https://colab.research.google.com/github/iLevyTate/SCAN/blob/main/STAC_Spiked_Transformer_Augmenting_Cognition_Smaller_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch torchvision transformers snntorch datasets -U bitsandbytes accelerate evaluate pytest

In [None]:
%%writefile test_script.py
import pytest
import torch
import torch.nn as nn  # Correct import for nn
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from torch.amp import autocast
from torch.nn.utils import prune
# Model components under test: AdExNeuron, SNNLayer, CombinedModel, and the apply_model_quantization / apply_model_pruning helpers.

class AdExNeuron(nn.Module):
    """Layer of adaptive integrate-and-fire neurons driven by a linear projection.

    Despite the AdEx name, the membrane update below has no exponential term.
    Each call does the following:
      * advances the membrane potential ``V`` and the adaptation variable ``w``
        by one step;
      * emits a spike wherever ``V >= V_th``;
      * resets the spiking neurons to ``V_reset``;
      * increments their ``w`` by ``b``.

    ``V`` and ``w`` are per-sample state carried across calls. They are stored
    as non-persistent buffers: they follow ``.to()`` / ``.cuda()``, but they are
    excluded from ``parameters()`` and ``state_dict()``. As a result, a
    checkpoint never depends on the batch size of the last forward pass.
    The state is (re)allocated at ``V_reset`` / ``0`` when the batch size or
    device changes. Call ``reset_state()`` to start from rest explicitly.

    Args:
        input_size: Number of input features fed to ``fc``.
        output_size: Number of neurons, i.e. the width of the spike output.
        tau_m: Membrane time constant (divides the potential update).
        tau_w: Adaptation time constant (divides the ``w`` update).
        a: Coupling of ``w`` to ``V - V_reset``.
        b: Increment added to ``w`` after each spike.
        V_th: Spike threshold.
        V_reset: Resting / reset potential.
    """

    def __init__(self, input_size, output_size, tau_m=20.0, tau_w=100.0, a=0.001, b=0.05, V_th=-50.0, V_reset=-65.0):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.tau_m = tau_m
        self.tau_w = tau_w
        self.a = a
        self.b = b
        self.V_th = V_th
        self.V_reset = V_reset
        self.fc = nn.Linear(input_size, output_size)
        # Lazily-sized state; non-persistent so it never lands in state_dict().
        self.register_buffer("V", None, persistent=False)
        self.register_buffer("w", None, persistent=False)

    def reset_state(self):
        """Drop the membrane/adaptation state so the next forward starts from rest."""
        self.V = None
        self.w = None

    def forward(self, input_tensor):
        """Advance the neurons one step and return a float tensor of 0/1 spikes.

        The first dimension of ``input_tensor`` is used as the batch size of
        the state tensors. The returned spikes come from a threshold comparison,
        so no gradient flows from them back into ``fc``.
        """
        batch_size = input_tensor.size(0)
        device = input_tensor.device
        if self.V is None or self.V.size(0) != batch_size or self.V.device != device:
            self.V = torch.ones(batch_size, self.output_size, device=device) * self.V_reset
            self.w = torch.zeros(batch_size, self.output_size, device=device)
        I = self.fc(input_tensor)
        # State updates are deliberately not tracked by autograd.
        with torch.no_grad():
            # Both increments are computed from the pre-update V and w.
            dV = (I - self.w - (self.V - self.V_reset) / self.tau_m) / self.tau_m
            dw = (self.a * (self.V - self.V_reset) - self.w) / self.tau_w
            self.V.add_(dV)
            self.w.add_(dw)
            spikes = (self.V >= self.V_th).float()
            # Reset neurons that fired and bump their adaptation variable.
            self.V.copy_(self.V * (1 - spikes) + self.V_reset * spikes)
            self.w.add_(self.b * spikes)
        return spikes

class SNNLayer(nn.Module):
    """Spiking layer that projects each time step and feeds it through AdEx neurons.

    For every time step ``t``, ``x[:, t]`` goes through ``fc``, then
    BatchNorm, then dropout, then a first ``AdExNeuron``. The result is then
    passed through ``num_recurrent_layers`` further ``AdExNeuron``s. Each of
    those receives the projected input gated by ``sigmoid`` of the previous
    spikes.

    Input is assumed to be ``(batch, seq_len, input_size)``: this follows from
    ``fc(x[:, t])`` feeding a ``BatchNorm1d``. The output is
    ``(batch, seq_len, output_size)``.
    """

    def __init__(self, input_size, output_size, num_recurrent_layers=1):
        super().__init__()
        self.fc = nn.Linear(input_size, output_size)
        self.adex = AdExNeuron(output_size, output_size)
        self.recurrent_layers = nn.ModuleList(
            [AdExNeuron(output_size, output_size) for _ in range(num_recurrent_layers)]
        )
        self.gate = nn.Sigmoid()
        self.dropout = nn.Dropout(p=0.2)
        self.batch_norm = nn.BatchNorm1d(output_size)

    def forward(self, x):
        batch_size, seq_len = x.size(0), x.size(1)
        if seq_len == 0:
            # torch.stack([]) would raise; return an empty spike train instead.
            return torch.zeros(batch_size, 0, self.fc.out_features, device=x.device)
        # BatchNorm1d cannot compute batch statistics from a single sample in
        # training mode, so skip it only there. In eval mode it must still run
        # (using its running stats); otherwise inference would see
        # unnormalised inputs that training never saw.
        use_batch_norm = (not self.training) or batch_size > 1
        spk_out = []
        for t in range(seq_len):
            input_fc = self.fc(x[:, t])
            if use_batch_norm:
                input_fc = self.batch_norm(input_fc)
            input_fc = self.dropout(input_fc)
            spk = self.adex(input_fc)
            for layer in self.recurrent_layers:
                spk = layer(self.gate(spk) * input_fc)
            spk_out.append(spk)
        return torch.stack(spk_out, dim=1)

class CombinedModel(nn.Module):
    """Transformer backbone whose final hidden states are fed into an SNNLayer.

    Args:
        transformer_model: Model exposing ``config.hidden_size`` and returning
            ``hidden_states`` when called with ``output_hidden_states=True``.
        snn_output_size: Width of the spike output of the SNN layer.
    """

    def __init__(self, transformer_model, snn_output_size):
        super().__init__()
        self.transformer = transformer_model
        hidden_size = transformer_model.config.hidden_size
        self.snn_layer = SNNLayer(hidden_size, snn_output_size)

    def forward(self, input_ids, attention_mask=None):
        """Return SNN spikes computed from the last hidden state of the transformer."""
        amp_on = torch.cuda.is_available()
        with autocast(device_type='cuda', enabled=amp_on):
            backbone_out = self.transformer(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
            )
            final_hidden = backbone_out.hidden_states[-1]
            spikes = self.snn_layer(final_hidden)
        return spikes

def apply_model_quantization(model, dtype=torch.qint8, inplace=False):
    """Dynamically quantize every ``nn.Linear`` in *model* and return it on CPU.

    Dynamic quantization needs neither a ``qconfig`` nor calibration data. Eager
    ``prepare``/``convert``, by contrast, only converts modules that carry a
    ``qconfig``, so running them on an unconfigured model leaves it in float.
    The quantized kernels run on CPU, so the result lives on the CPU.

    Args:
        model: Module to quantize. It is put into eval mode.
        dtype: Weight dtype of the quantized Linear layers (default ``torch.qint8``).
        inplace: If False (the default), quantize a deep copy. The caller's
            model then keeps its device and its float layers. If True, modify
            *model* itself.

    Returns:
        The quantized model (*model* itself when ``inplace`` is True).
    """
    import copy

    model.eval()
    target = model if inplace else copy.deepcopy(model)
    target.to('cpu')
    return torch.quantization.quantize_dynamic(target, {nn.Linear}, dtype=dtype, inplace=True)

def apply_model_pruning(model, amount=0.4):
    """Globally L1-prune the weights of every ``nn.Linear`` in *model*, in place.

    All Linear weights are ranked together by absolute value, and the smallest
    ``amount`` of them are masked to zero. Pruned modules keep torch's pruning
    reparametrization (``weight_orig`` parameter + ``weight_mask`` buffer); call
    ``torch.nn.utils.prune.remove`` to make the pruning permanent.

    NOTE(review): if a Linear weight is shared with another module (e.g. tied
    embeddings), only the pruned module sees the mask. Confirm that this is
    intended.

    Args:
        model: Module to prune. It is put into eval mode.
        amount: Fraction (float in [0, 1]) or absolute count (int) of weights
            to prune. Forwarded to ``prune.global_unstructured``.

    Returns:
        The same *model* object.
    """
    model.eval()
    parameters_to_prune = [
        (module, 'weight') for module in model.modules() if isinstance(module, nn.Linear)
    ]
    if not parameters_to_prune:
        # global_unstructured concatenates the selected tensors and raises on an empty list.
        return model
    prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=amount)
    return model

# Pytest Fixtures and Tests
@pytest.fixture(scope="module")
def config():
    """Shared test settings: backbone name, target device, and SNN input shape."""
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    settings = dict(
        transformer_name="gpt2",
        device=torch.device(device_name),
        input_shape=(2, 10, 768),
    )
    return settings

@pytest.fixture(scope="module")
def tokenizer(config):
    """Tokenizer for ``config["transformer_name"]`` with EOS reused as the pad token."""
    tok = AutoTokenizer.from_pretrained(config["transformer_name"])
    # Reuse EOS for padding so padding=True works (presumably the checkpoint defines no pad token).
    tok.pad_token = tok.eos_token
    return tok

@pytest.fixture(scope="module")
def transformer_model(config):
    """Causal LM for ``config["transformer_name"]``, loaded once per module onto ``config["device"]``.

    NOTE(review): module-scoped, so every test in this file receives the same
    instance. Tests that mutate it affect later tests.
    """
    return AutoModelForCausalLM.from_pretrained(config["transformer_name"]).to(config["device"])

@pytest.fixture(scope="module")
def snn_layer(config, transformer_model):
    """SNNLayer sized to the transformer's hidden width, emitting 512 spike channels."""
    layer = SNNLayer(
        input_size=transformer_model.config.hidden_size,
        output_size=512,
    )
    return layer.to(config["device"])

def test_snn_layer(config, snn_layer):
    """SNNLayer maps a (2, 10, 768) input to binary spikes of shape (2, 10, 512)."""
    input_data = torch.zeros(config["input_shape"], device=config["device"])
    with autocast(device_type='cuda', enabled=torch.cuda.is_available()):
        output_data = snn_layer(input_data)
    assert output_data.shape == (2, 10, 512)
    # Spikes come from a threshold comparison cast to float, so every entry must be exactly 0 or 1.
    assert torch.all((output_data == 0) | (output_data == 1))

def test_combined_model(config, transformer_model):
    """CombinedModel turns (batch, seq) token ids into (batch, seq, 512) spikes."""
    device = config["device"]
    model = CombinedModel(transformer_model=transformer_model, snn_output_size=512).to(device)

    batch_size, seq_len = 2, 10
    input_ids = torch.zeros((batch_size, seq_len), dtype=torch.long, device=device)
    attention_mask = torch.ones_like(input_ids)

    with autocast(device_type='cuda', enabled=torch.cuda.is_available()):
        output_data = model(input_ids, attention_mask)
    assert output_data.shape == (batch_size, seq_len, 512)

def test_data_loading_and_tokenization(tokenizer):
    """A 1% slice of the SQuAD train split loads and tokenizes into an ``input_ids`` column."""
    dataset = load_dataset("squad", split="train[:1%]")
    assert len(dataset) > 0

    def _encode_contexts(batch):
        # Pad/truncate every context to a fixed 512 tokens.
        return tokenizer(batch['context'], padding="max_length", truncation=True, max_length=512)

    encoded = dataset.map(_encode_contexts, batched=True)
    assert 'input_ids' in encoded.features

def test_inference_pipeline(config, tokenizer, transformer_model):
    """Generation from the shared model appends at least one token to the prompt."""
    input_text = "How does the prefrontal cortex handle decision-making?"
    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True).to(config["device"])
    prompt_length = inputs['input_ids'].shape[1]

    transformer_model.eval()
    with torch.no_grad():
        with autocast(device_type='cuda', enabled=torch.cuda.is_available()):
            # pad/eos ids are passed explicitly, so the module-scoped model's
            # config does not need to be (and is not) modified.
            outputs = transformer_model.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_length=50,
                use_cache=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
    # The decoded text always contains the prompt, so check the token count instead.
    assert outputs.shape[1] > prompt_length
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    assert len(generated_text) > 0

def test_model_quantization(config, transformer_model, tokenizer):
    """A quantized CombinedModel runs on CPU and keeps the (1, seq_len, 512) output shape."""
    import copy

    # Wrap a private copy: quantization moves the model to CPU and rewrites its
    # Linear layers, which must not leak into the module-scoped fixture.
    combined_model = CombinedModel(
        transformer_model=copy.deepcopy(transformer_model),
        snn_output_size=512
    ).to(config["device"])

    quantized_model = apply_model_quantization(combined_model)

    input_text = "How does the prefrontal cortex handle decision-making?"
    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True).to('cpu')

    quantized_model.to('cpu')  # Ensure the quantized model is on the CPU

    with torch.no_grad():
        outputs = quantized_model(inputs['input_ids'], inputs['attention_mask'])

    assert outputs.shape == (1, inputs['input_ids'].shape[1], 512)

def test_model_pruning(config, transformer_model):
    """Global L1 pruning masks about 40% of all nn.Linear weights in a CombinedModel."""
    import copy

    # Prune a private copy so the module-scoped transformer fixture stays unpruned.
    combined_model = CombinedModel(
        transformer_model=copy.deepcopy(transformer_model),
        snn_output_size=512
    ).to(config["device"])

    pruned_model = apply_model_pruning(combined_model)

    linear_layers = [m for m in pruned_model.modules() if isinstance(m, nn.Linear)]
    assert linear_layers
    # torch pruning reparametrizes each pruned tensor as weight_orig * weight_mask.
    assert all(hasattr(m, "weight_orig") and hasattr(m, "weight_mask") for m in linear_layers)
    total = sum(m.weight.numel() for m in linear_layers)
    zeros = sum(int((m.weight == 0).sum()) for m in linear_layers)
    assert abs(zeros / total - 0.4) < 0.01

if __name__ == '__main__':
    import sys

    # pytest.main() with no arguments collects from the current directory, not
    # this file. Target this file explicitly, forward any extra CLI options, and
    # return pytest's exit code to the shell.
    raise SystemExit(pytest.main([__file__, *sys.argv[1:]]))


In [None]:
!pytest -v test_script.py