In [None]:
# Install required libraries
!pip install torch transformers datasets wandb

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import logging
import wandb
from google.colab import files

# Configure the root logger so that INFO-level (and more severe) records are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

class SurrogateSpikeFunction(torch.autograd.Function):
    """Step (Heaviside) spike nonlinearity with a straight-through gradient.

    Forward emits 1 where the input is strictly positive and 0 elsewhere.
    Backward hands the incoming gradient back unchanged, so gradients can
    flow through the otherwise non-differentiable threshold.
    """

    @staticmethod
    def forward(ctx, input):
        """Return a 0/1 spike tensor with the same shape as ``input``.

        Floating-point inputs keep their dtype (float32 stays float32);
        any other dtype yields float32 spikes.
        """
        # Nothing is saved on ctx: backward() does not depend on the input,
        # so retaining it would only keep activations alive for no benefit.
        out_dtype = input.dtype if input.is_floating_point() else torch.float32
        return (input > 0).to(out_dtype)

    @staticmethod
    def backward(ctx, grad_output):
        """Pass the upstream gradient straight through (identity surrogate)."""
        return grad_output.clone()


# Functional entry point used by the neuron modules.
surrogate_spike = SurrogateSpikeFunction.apply

class AdExNeuron(nn.Module):
    """Linear projection followed by one step of adaptive spiking dynamics.

    Each call to :meth:`forward` advances the membrane potential ``V`` and
    the adaptation variable ``w`` by one step, emits spikes where ``V``
    exceeds ``V_th``, and resets spiking units to ``V_reset``.

    NOTE(review): the class name suggests the adaptive-exponential model,
    but the update implemented here contains no exponential term.

    Args:
        input_size: Number of input features fed to the linear projection.
        output_size: Number of neurons (output features of the projection).
        tau_m: Divisor applied to the membrane-potential increment.
        tau_w: Divisor applied to the adaptation increment.
        a: Coupling of ``(V - V_reset)`` into the adaptation increment.
        b: Amount added to ``w`` for every emitted spike.
        V_th: Spike threshold; a unit spikes when ``V - V_th > 0``.
        V_reset: Value ``V`` is set to after a spike; also the resting
            level the leak term pulls towards.
    """

    def __init__(self, input_size, output_size, tau_m=20.0, tau_w=100.0, a=0.001, b=0.05, V_th=0.0, V_reset=-65.0):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.tau_m = tau_m
        self.tau_w = tau_w
        self.a = a
        self.b = b
        self.V_th = V_th
        self.V_reset = V_reset
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, input_tensor, V, w):
        """Advance the neuron state by one step.

        The state tensors passed in are left untouched; updated copies are
        returned instead, so callers can safely keep references to the
        previous state.

        Args:
            input_tensor: Input whose last dimension is ``input_size``.
            V: Current membrane potential (broadcastable with the projection output).
            w: Current adaptation variable (same shape as ``V``).

        Returns:
            Tuple ``(spikes, V, w)`` with the 0/1 spike tensor and the new state.
        """
        I = self.fc(input_tensor)
        # Both increments are evaluated from the incoming state, before either
        # variable is advanced.
        dV = (I - w - (V - self.V_reset)) / self.tau_m
        dw = (self.a * (V - self.V_reset) - w) / self.tau_w
        # Out-of-place updates: never mutate the caller's tensors.
        V = V + dV
        w = w + dw
        spikes = surrogate_spike(V - self.V_th)
        # Units that spiked are reset; the others keep their potential.
        V = V * (1 - spikes) + self.V_reset * spikes
        # Spike-triggered adaptation.
        w = w + self.b * spikes
        return spikes, V, w

class SNNLayer(nn.Module):
    """Runs a stack of :class:`AdExNeuron` modules over a sequence, step by step.

    Args:
        input_size: Feature size of each time step of the input.
        output_size: Number of neurons in every stacked module.
        num_recurrent_layers: How many additional ``output_size -> output_size``
            neuron modules are applied after the first one at every step.
    """

    def __init__(self, input_size, output_size, num_recurrent_layers=1):
        super().__init__()
        self.adex = AdExNeuron(input_size, output_size)
        self.recurrent_layers = nn.ModuleList(
            [AdExNeuron(output_size, output_size) for _ in range(num_recurrent_layers)]
        )
        # Neither module below is applied in forward(); they are kept so that
        # existing attribute access on this layer continues to work.
        self.gate = nn.Sigmoid()
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x):
        """Return the spike train produced for every time step of ``x``.

        Args:
            x: Tensor indexed as ``x[:, t, :]``, i.e. (batch, time, features).

        Returns:
            Tensor of shape (batch, time, output_size) holding the spikes of
            the last stacked module at each step. An input with zero time
            steps yields an empty (batch, 0, output_size) tensor.
        """
        batch_size, seq_len = x.size(0), x.size(1)
        num_neurons = self.adex.output_size
        if seq_len == 0:
            # Nothing to simulate; stacking an empty list would raise.
            return x.new_zeros(batch_size, 0, num_neurons)

        # State lives on the input's device and in the input's dtype, so no
        # silent dtype mixing occurs for non-float32 inputs.
        V = x.new_full((batch_size, num_neurons), self.adex.V_reset)
        w = x.new_zeros(batch_size, num_neurons)

        spikes_per_step = []
        for t in range(seq_len):
            spk, V, w = self.adex(x[:, t, :], V, w)
            # NOTE(review): every stacked module reads and overwrites the same
            # (V, w) state rather than keeping its own; confirm this sharing
            # is intended before changing it.
            for layer in self.recurrent_layers:
                spk, V, w = layer(spk, V, w)
            spikes_per_step.append(spk)
        return torch.stack(spikes_per_step, dim=1)

class CombinedModel(nn.Module):
    """Transformer backbone whose final hidden states drive an SNN layer and a vocabulary head.

    Args:
        transformer_model: Model exposing ``config.hidden_size``,
            ``config.vocab_size`` and returning ``hidden_states`` when called
            with ``output_hidden_states=True``.
        snn_output_size: Number of neurons in the SNN layer (input width of
            the output projection).
    """

    def __init__(self, transformer_model, snn_output_size):
        super().__init__()
        backbone_config = transformer_model.config
        self.transformer = transformer_model
        self.snn_layer = SNNLayer(backbone_config.hidden_size, snn_output_size)
        self.output_layer = nn.Linear(snn_output_size, backbone_config.vocab_size)

    def forward(self, input_ids, attention_mask=None):
        """Return per-position vocabulary logits for ``input_ids``.

        The attention mask is forwarded to the backbone only; the SNN layer
        and the output projection see every position.
        """
        backbone_out = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,  # needed to read the hidden states below
        )
        # Last entry of the hidden-state tuple; indexed as (batch, seq, hidden)
        # by the SNN layer.
        final_hidden = backbone_out.hidden_states[-1]
        spike_features = self.snn_layer(final_hidden)
        # Project spike features to vocabulary logits.
        return self.output_layer(spike_features)

def _batch_loss(model, batch, criterion, device):
    """Return the next-token loss for one batch, or None if it has no real targets."""
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    # Position t predicts token t+1, so the targets are the inputs shifted left
    # by one. Targets that are padding (attention_mask == 0) are replaced by the
    # criterion's ignore_index so they do not contribute to the loss.
    targets = input_ids[..., 1:].masked_fill(attention_mask[..., 1:] == 0, criterion.ignore_index)
    if not bool((targets != criterion.ignore_index).any()):
        # A mean over zero targets would be NaN; let the caller skip the batch.
        return None
    logits = model(input_ids, attention_mask)[..., :-1, :]
    return criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))


def train_model(model, train_dataloader, val_dataloader, num_epochs, learning_rate, device, checkpoint_dir="checkpoints", checkpoint_every=500):
    """Train ``model`` as a next-token predictor and checkpoint it along the way.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` and
            returning logits indexed as (batch, position, vocab).
        train_dataloader: Iterable of dicts with ``'input_ids'`` and
            ``'attention_mask'`` tensors used for optimisation.
        val_dataloader: Iterable of the same kind used for validation.
        num_epochs: Number of passes over ``train_dataloader``.
        learning_rate: Initial AdamW learning rate.
        device: Device the batches are moved to.
        checkpoint_dir: Directory receiving the ``.pth`` state-dict files;
            created if missing.
        checkpoint_every: Save an intra-epoch checkpoint every this many
            training batches. A falsy value (``0``/``None``) disables them.

    Returns:
        A list with one dict per epoch holding ``'epoch'``, ``'train_loss'``
        and ``'val_loss'`` (``nan`` when no batch contributed).
    """
    log = logging.getLogger(__name__)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)
    os.makedirs(checkpoint_dir, exist_ok=True)

    history = []
    for epoch in range(num_epochs):
        model.train()
        train_total, train_batches = 0.0, 0
        for batch_idx, batch in enumerate(train_dataloader):
            optimizer.zero_grad()
            loss = _batch_loss(model, batch, criterion, device)
            if loss is not None:
                loss.backward()
                optimizer.step()
                train_total += loss.item()
                train_batches += 1
            if checkpoint_every and (batch_idx + 1) % checkpoint_every == 0:
                checkpoint_path = os.path.join(checkpoint_dir, f'model_batch_{epoch+1}_{batch_idx+1}.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print(f"Checkpoint saved at {checkpoint_path}")
        # Averages are taken over the batches that actually produced a loss,
        # which also avoids dividing by zero for an empty dataloader.
        avg_train_loss = train_total / train_batches if train_batches else float('nan')

        model.eval()
        val_total, val_batches = 0.0, 0
        with torch.no_grad():
            for batch in val_dataloader:
                loss = _batch_loss(model, batch, criterion, device)
                if loss is not None:
                    val_total += loss.item()
                    val_batches += 1
        avg_val_loss = val_total / val_batches if val_batches else float('nan')

        if val_batches:
            scheduler.step(avg_val_loss)
        else:
            # Without a validation signal there is nothing to schedule on.
            log.warning("Epoch %d: no validation batches with real targets; LR scheduler not stepped", epoch + 1)

        log.info("Epoch %d/%d - train loss: %.4f - val loss: %.4f", epoch + 1, num_epochs, avg_train_loss, avg_val_loss)
        history.append({'epoch': epoch + 1, 'train_loss': avg_train_loss, 'val_loss': avg_val_loss})

        checkpoint_path = os.path.join(checkpoint_dir, f'model_epoch_{epoch+1}.pth')
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Checkpoint saved at {checkpoint_path}")
    return history

def save_model(model, tokenizer, save_directory):
    """Write the model weights and the tokenizer under ``save_directory``.

    The state dict goes to ``<save_directory>/model_weights.pth`` and the
    tokenizer is saved via ``save_pretrained`` into
    ``<save_directory>/tokenizer``.

    Args:
        model: Module whose ``state_dict()`` is serialised.
        tokenizer: Object providing ``save_pretrained(path)``.
        save_directory: Target directory; created (with parents) if missing.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_directory, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_directory, 'model_weights.pth'))
    tokenizer.save_pretrained(os.path.join(save_directory, 'tokenizer'))

def main():
    """Fine-tune the GPT-2 + SNN model on WikiText-2 and save the result.

    Starts a Weights & Biases run, tokenizes the dataset to fixed-length
    sequences, trains :class:`CombinedModel` for five epochs and writes the
    weights and tokenizer to ``saved_model``. The W&B run is closed even if
    training fails.
    """
    wandb.init(project="STAC")
    try:
        # Local name chosen so it does not read like the `datasets` library.
        raw_datasets = load_dataset("wikitext", "wikitext-2-v1")
        # Rows whose text is empty or whitespace-only would tokenize to pure
        # padding and carry no training signal, so drop them up front.
        raw_datasets = raw_datasets.filter(lambda example: example["text"].strip() != "")

        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        # A pad token is required for padding="max_length"; reuse EOS for it.
        tokenizer.pad_token = tokenizer.eos_token

        def tokenize_function(examples):
            return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)

        tokenized_datasets = raw_datasets.map(
            tokenize_function,
            batched=True,
            remove_columns=raw_datasets["train"].column_names,
        )
        train_dataset = tokenized_datasets['train']
        val_dataset = tokenized_datasets['validation']

        def collate_fn(batch):
            # Every row has the same length (max_length padding), so the lists
            # can be stacked directly into tensors.
            input_ids = torch.tensor([item['input_ids'] for item in batch])
            attention_mask = torch.tensor([item['attention_mask'] for item in batch])
            return {'input_ids': input_ids, 'attention_mask': attention_mask}

        train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
        val_dataloader = DataLoader(val_dataset, batch_size=8, collate_fn=collate_fn)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        transformer_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
        combined_model = CombinedModel(transformer_model, snn_output_size=512).to(device)

        train_model(combined_model, train_dataloader, val_dataloader, num_epochs=5, learning_rate=5e-5, device=device)

        save_directory = "saved_model"
        save_model(combined_model, tokenizer, save_directory)
    finally:
        # Always close the W&B run so it is not left open after a failure.
        wandb.finish()

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()