# Causal STFT FNet: An Autoregressive Decoder Architecture

<a target="_blank" href="https://colab.research.google.com/github/dataopsnick/FNet-2025/blob/main/Causal_STFT_FNet_Student_Teacher_Knowledge_Distillation.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

This notebook implements a novel **Causal Short-Time Fourier Transform (STFT) FNet decoder** that extends the original FNet architecture [1] to autoregressive language modeling. The original FNet paper primarily focused on encoder architectures for tasks like masked language modeling, replacing self-attention with Fourier transforms to achieve significant speedup with competitive performance.

## Background: FNet and the Decoder Challenge

The original FNet [1] demonstrated that replacing self-attention sublayers with unparameterized Fourier transforms could achieve 92-97% of BERT's accuracy on GLUE benchmarks while training 80% faster on GPUs. However, the paper focused exclusively on encoder architectures, using 2D Discrete Fourier Transforms across both sequence and hidden dimensions.

As the authors noted in their conclusion:

> "Throughout this work we have restricted our focus to encoders. FNet decoders can be designed by "causally" masking the Vandermonde matrix, but a lower level implementation is required to introduce causal masking to FFTs. How to adapt Fourier mixing for encoder-decoder cross-attention is an open question as evidence suggests that cross-attention may be crucial to performance (You et al., 2020). We have focused on tasks which do not require generation so we leave FNet decoders and encoder-decoder setups to future work..."

## Key Contributions of This Implementation:

1. **Causal STFT Layer**: Unlike the original FNet which applies 2D FFT across the entire sequence, this implementation introduces a windowed Short-Time Fourier Transform with causal masking. This enables autoregressive generation while maintaining O(N log N) complexity within each window.

2. **Practical Decoder Architecture**: This notebook provides a concrete implementation of a causal FNet decoder for language modeling, addressing the challenge left open by the original paper.

3. **Sliding Window FFT Approach**: By using overlapping windows with configurable `stft_window_size` and applying FFT within each window, the model can capture local frequency patterns while maintaining strict causality through careful padding strategies.

4. **Knowledge Distillation Framework**: The implementation demonstrates how to train this novel architecture using knowledge distillation from a pre-trained transformer model (Qwen2-0.5B), showing a practical path to competitive performance.

## Technical Implementation Details:

- **Architecture**: Causal FNet decoder with configurable layers, hidden dimensions, and STFT window sizes
- **Causal Mechanism**: Uses zero-padding and strided tensor operations to ensure each position only attends to previous positions
- **Training Dataset**: GSM8K mathematical reasoning dataset for demonstrating capabilities on structured reasoning tasks
- **Model Sizes**: Demonstration uses a 4-layer, 256-hidden dimension student model for efficient training
- **Teacher Model**: Qwen2-0.5B-Instruct for knowledge distillation

## Comparison with Official FNet Implementation:

The official HuggingFace implementation (`transformers.models.fnet`) provides:
- Encoder-only models (FNetModel, FNetForMaskedLM, etc.)
- 2D FFT mixing across full sequences
- Support for TPU-optimized Fourier transforms

This implementation differs by:
- Focusing on decoder/autoregressive architectures
- Using windowed STFT instead of full-sequence FFT
- Implementing causal masking at the architecture level
- Demonstrating knowledge distillation for training efficiency

## References:

[1] J. Lee-Thorp, J. Ainslie, I. Eckstein, and S. Ontanon, "FNet: Mixing Tokens with Fourier Transforms," arXiv preprint arXiv:2105.03824, 2022.

---

In [None]:
#@title <h1> Autoregressive FNet with Causal STFT Layer</h1>
#@markdown This code cell implements a causally correct and efficient FNet decoder.
#@markdown ---
#@markdown ### **Code Structure**
#@markdown 1.  **Setup**: Installs specific, stable library versions.
#@markdown 2.  **Model Definition**: Contains the Causal STFT FNet architecture.
#@markdown 3.  **Data Loading & Preprocessing**: Prepares the GSM8K dataset.
#@markdown 4.  **Training**: Trains the model using the Hugging Face `Trainer`.
#@markdown 5.  **Inference**: Demonstrates text generation with the final, working model.
#@markdown ---

#@markdown ## 1. Setup
#@markdown **IMPORTANT**: Before running, please go to the menu and select "Runtime -> Restart runtime".
!pip install transformers>=4.31.0 datasets==2.18.0 accelerate>=0.21.0 evaluate==0.4.1 torch peft==0.10.0 -q

#@markdown ---

#@markdown ## 2. Model Definition
import torch
import torch.nn as nn
import torch.fft
from torch.nn import CrossEntropyLoss, functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutput
from transformers.generation import GenerationMixin
from typing import Optional, Tuple, Dict, Any, Union

# --- Configuration Class ---
class FNetConfig(PretrainedConfig):
    model_type = "causal_stft_fnet"
    def __init__(
        self, vocab_size=50257, hidden_size=768, num_hidden_layers=12, intermediate_size=3072,
        hidden_dropout_prob=0.1, max_position_embeddings=1024, stft_window_size=64,
        initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=50256,
        tie_word_embeddings=True, **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs)
        self.vocab_size, self.hidden_size, self.num_hidden_layers, self.intermediate_size, \
        self.hidden_dropout_prob, self.max_position_embeddings, self.stft_window_size, \
        self.initializer_range, self.layer_norm_eps = vocab_size, hidden_size, num_hidden_layers, \
        intermediate_size, hidden_dropout_prob, max_position_embeddings, stft_window_size, \
        initializer_range, layer_norm_eps

# --- Invented Layer: Causal Short-Time Fourier Transform ---
class CausalSTFTLayer(nn.Module):
    def __init__(self, config: FNetConfig):
        super().__init__()
        self.window_size = config.stft_window_size
        self.projection = nn.Linear(config.stft_window_size * config.hidden_size, config.hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, hidden_size = x.shape
        padded_x = F.pad(x, (0, 0, self.window_size - 1, 0))
        windows = padded_x.as_strided(
            size=(batch_size, seq_len, self.window_size, hidden_size),
            stride=(padded_x.stride(0), padded_x.stride(1), padded_x.stride(1), padded_x.stride(2))
        )
        fft_windows = torch.fft.fftn(windows, dim=(-2, -1)).real
        fft_windows = fft_windows.view(batch_size, seq_len, -1)
        return self.projection(fft_windows)

# --- Standard Layers ---
class FeedForwardLayer(nn.Module):
    def __init__(self, config: FNetConfig):
        super().__init__()
        self.dense1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.activation = nn.GELU()
        self.dense2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
    def forward(self, x): return self.dropout(self.dense2(self.activation(self.dense1(x))))

class CausalFNetEncoderBlock(nn.Module):
    def __init__(self, config: FNetConfig):
        super().__init__()
        self.causal_stft = CausalSTFTLayer(config)
        self.ffn = FeedForwardLayer(config)
        self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
    def forward(self, x):
        x = self.norm1(x + self.causal_stft(x))
        x = self.norm2(x + self.ffn(x))
        return x

class FNetEmbeddings(nn.Module):
    def __init__(self, config: FNetConfig):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.pos_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False)
    def forward(self, input_ids):
        seq_len = input_ids.size(1)
        pos_ids = self.position_ids[:, :seq_len]
        embeds = self.word_embeddings(input_ids) + self.pos_embeddings(pos_ids)
        return self.dropout(self.norm(embeds))

# --- Top-Level Causal Model ---
class CausalFNetForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = FNetConfig
    def __init__(self, config: FNetConfig):
        super().__init__(config)
        self.embeddings = FNetEmbeddings(config)
        self.encoder = nn.ModuleList([CausalFNetEncoderBlock(config) for _ in range(config.num_hidden_layers)])
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self): return self.embeddings.word_embeddings
    def set_input_embeddings(self, value): self.embeddings.word_embeddings = value
    def get_output_embeddings(self): return self.lm_head
    def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings

    def forward(self, input_ids, labels=None, **kwargs):
        x = self.embeddings(input_ids)
        for block in self.encoder: x = block(x)
        logits = self.lm_head(x)
        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
        return CausalLMOutput(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs: Any) -> Dict[str, Any]:
        return {"input_ids": input_ids}

print("✅ Causal STFT FNet model definitions are ready.")

#@markdown ---

#@markdown ## 3. Data Loading & Preprocessing
from datasets import load_dataset
from transformers import AutoTokenizer

DATASET_NAME = "gsm8k"
TOKENIZER_NAME = "gpt2"
MODEL_OUTPUT_DIR = "causal-stft-fnet-gsm8k-finetuned"

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
tokenizer.pad_token = tokenizer.eos_token
raw_datasets = load_dataset(DATASET_NAME, "main")

def preprocess_function(examples):
    text = [f"Question: {q}\nAnswer: {a}{tokenizer.eos_token}" for q, a in zip(examples['question'], examples['answer'])]
    return tokenizer(text, truncation=True, max_length=256)

tokenized_datasets = raw_datasets.map(preprocess_function, batched=True, remove_columns=raw_datasets["train"].column_names)
train_dataset = tokenized_datasets["train"].select(range(2000))
eval_dataset = tokenized_datasets["test"].select(range(200))

print("✅ Data preprocessing complete.")

#@markdown ---

#@markdown ## 4. Training
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling

fnet_config = FNetConfig(
    vocab_size=tokenizer.vocab_size, pad_token_id=tokenizer.pad_token_id,
    hidden_size=256, num_hidden_layers=4, intermediate_size=1024,
    max_position_embeddings=1024, stft_window_size=32
)
model = CausalFNetForCausalLM(fnet_config)

training_args = TrainingArguments(
    output_dir=MODEL_OUTPUT_DIR,
    num_train_epochs=25,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    learning_rate=1e-4,
    weight_decay=0.01,
    warmup_steps=100,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    logging_steps=50,
    fp16=True,
    report_to="none",
    save_safetensors=False
)

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)
print("\n✅ Trainer is configured and ready.")

#@markdown ---

#@markdown ## 5. Execute Training
print("Starting training...")
trainer.train()
print("🎉 Training complete!")

#@markdown ---

#@markdown ## 6. Inference (Manual Method)
# Load the best model
best_model_path = trainer.state.best_model_checkpoint
print(f"Loading best model from: {best_model_path}")

model_for_inference = CausalFNetForCausalLM.from_pretrained(best_model_path)
model_for_inference.eval()

# Move to appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_for_inference = model_for_inference.to(device)

# Test on a sample question
sample_question = raw_datasets["test"][15]["question"]
prompt = f"Question: {sample_question}\nAnswer:"
print(f"\nPROMPT:\n{prompt}")

# Manual generation
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

with torch.no_grad():
    for _ in range(200):  # Generate up to 100 tokens
        outputs = model_for_inference(input_ids=input_ids)
        next_token_logits = outputs.logits[0, -1, :]

        # Apply temperature (optional)
        temperature = 0.8
        next_token_logits = next_token_logits / temperature

        # Get probabilities and sample
        probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        # Append to sequence
        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

        # Stop if EOS token is generated
        if next_token.item() == tokenizer.eos_token_id:
            break

generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
print("\nMODEL GENERATION:")
print(generated_text)

✅ Causal STFT FNet model definitions are ready.


Map:   0%|          | 0/7473 [00:00<?, ? examples/s]

Map:   0%|          | 0/1319 [00:00<?, ? examples/s]

✅ Data preprocessing complete.


  trainer = Trainer(



✅ Trainer is configured and ready.
Starting training...


Epoch,Training Loss,Validation Loss
1,22.6691,21.879351
2,16.5026,16.027336
3,12.6294,12.380377
4,10.2945,10.241842
5,8.9286,9.099848
6,7.9674,8.461625
7,7.2943,8.106636
8,6.7443,7.856195
9,6.2438,7.735987
10,5.7247,7.65914


🎉 Training complete!
Loading best model from: causal-stft-fnet-gsm8k-finetuned/checkpoint-2500

PROMPT:
Question: A merchant wants to make a choice of purchase between 2 purchase plans: jewelry worth $5,000 or electronic gadgets worth $8,000. His financial advisor speculates that the jewelry market will go up 2.5% while the electronic gadgets market will rise 1.2% within the same month. If the merchant is looking to maximize profit at the end of this month by making a choice, how much profit would this be?
Answer:

MODEL GENERATION:
Question: A merchant wants to make a choice of purchase between 2 purchase plans: jewelry worth $5,000 or electronic gadgets worth $8,000. His financial advisor speculates that the jewelry market will go up 2.5% while the electronic gadgets market will rise 1.2% within the same month. If the merchant is looking to maximize profit at the end of this month by making a choice, how much profit would this be?
Answer: pool gave 80 the going, theThen, x to305 each

In [None]:
#@title <h1>Knowledge Distillation: Causal STFT FNet with Qwen2 Teacher (FIXED)</h1>
#@markdown This code cell modifies the original Causal STFT FNet to perform knowledge distillation.
#@markdown It uses a powerful, pre-trained Qwen/Qwen2-0.5B-Instruct model as the "teacher" to train a smaller, more efficient FNet "student" model on the GSM8K math reasoning dataset.
#@markdown
#@markdown ### **Code Structure**
#@markdown 1. Setup: Installs required libraries.
#@markdown 2. Model Definitions:
#@markdown - CausalSTFT FNet: The student model architecture.
#@markdown - DistillationTrainer: A custom Hugging Face Trainer to handle the specialized loss calculation.
#@markdown 3. Data Loading & Preprocessing: Prepares the GSM8K dataset.
#@markdown 4. Training: Initializes the teacher and student models and starts the distillation process.
#@markdown 5. Inference: Demonstrates text generation with the trained student model.
#@markdown ---
#@markdown ## 1. Setup
#@markdown IMPORTANT: Before running, select "Runtime -> Restart runtime" from the menu.
!pip install transformers>=4.38.0 datasets==2.18.0 accelerate>=0.21.0 evaluate==0.4.1 torch peft==0.10.0 sentencepiece -q

# Set environment variables for debugging
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # For better error messages
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'  # Better memory management

#@markdown ---
#@markdown ## 2. Model & Trainer Definitions
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers import (PreTrainedModel, PretrainedConfig, Trainer, TrainingArguments,
AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling)
from transformers.modeling_outputs import CausalLMOutput
from transformers.generation import GenerationMixin
from datasets import load_dataset
from typing import Optional, Dict, Any
import gc

# Clear any existing CUDA cache
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    gc.collect()

#--- Configuration for the Student FNet Model ---

class FNetConfig(PretrainedConfig):
  model_type = "causal_stft_fnet"
  def __init__(
    self, vocab_size=50257, hidden_size=768, num_hidden_layers=12, intermediate_size=3072,
    hidden_dropout_prob=0.1, max_position_embeddings=1024, stft_window_size=64,
    initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=50256,
    tie_word_embeddings=True, gradient_checkpointing=False, **kwargs,
    ):
    super().__init__(pad_token_id=pad_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs)
    self.vocab_size = vocab_size
    self.hidden_size = hidden_size
    self.num_hidden_layers = num_hidden_layers
    self.intermediate_size = intermediate_size
    self.hidden_dropout_prob = hidden_dropout_prob
    self.max_position_embeddings = max_position_embeddings
    self.stft_window_size = stft_window_size
    self.initializer_range = initializer_range
    self.layer_norm_eps = layer_norm_eps
    self.gradient_checkpointing = gradient_checkpointing

#--- Student Model Layers ---
class FNetConfig(PretrainedConfig):
  model_type = "causal_stft_fnet"
  def __init__(
    self, vocab_size=50257, hidden_size=768, num_hidden_layers=12, intermediate_size=3072,
    hidden_dropout_prob=0.1, max_position_embeddings=1024, stft_window_size=64,
    initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=50256,
    tie_word_embeddings=True, gradient_checkpointing=False, **kwargs,
    ):
    super().__init__(pad_token_id=pad_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs)
    self.vocab_size = vocab_size
    self.hidden_size = hidden_size
    self.num_hidden_layers = num_hidden_layers
    self.intermediate_size = intermediate_size
    self.hidden_dropout_prob = hidden_dropout_prob
    self.max_position_embeddings = max_position_embeddings
    self.stft_window_size = stft_window_size
    self.initializer_range = initializer_range
    self.layer_norm_eps = layer_norm_eps
    self.gradient_checkpointing = gradient_checkpointing

#--- Student Model Layers ---

class CausalSTFTLayer(nn.Module):
  def __init__(self, config: FNetConfig):
    super().__init__()
    self.window_size = config.stft_window_size
    self.projection = nn.Linear(config.stft_window_size * config.hidden_size, config.hidden_size)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
      batch_size, seq_len, hidden_size = x.shape
      padded_x = F.pad(x, (0, 0, self.window_size - 1, 0))
      windows = padded_x.as_strided(
          size=(batch_size, seq_len, self.window_size, hidden_size),
          stride=(padded_x.stride(0), padded_x.stride(1), padded_x.stride(1), padded_x.stride(2))
      )
      fft_windows = torch.fft.fftn(windows, dim=(-2, -1)).real
      fft_windows = fft_windows.view(batch_size, seq_len, -1)
      return self.projection(fft_windows)

class FeedForwardLayer(nn.Module):
  def __init__(self, config: FNetConfig):
    super().__init__()
    self.dense1 = nn.Linear(config.hidden_size, config.intermediate_size)
    self.activation = nn.GELU()
    self.dense2 = nn.Linear(config.intermediate_size, config.hidden_size)
    self.dropout = nn.Dropout(config.hidden_dropout_prob)
  def forward(self, x):
    return self.dropout(self.dense2(self.activation(self.dense1(x))))

class CausalFNetEncoderBlock(nn.Module):
  def __init__(self, config: FNetConfig):
    super().__init__()
    self.causal_stft = CausalSTFTLayer(config)
    self.ffn = FeedForwardLayer(config)
    self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
    self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
    self.gradient_checkpointing = False

  def forward(self, x):
      # Original forward pass logic
      stft_output = self.causal_stft(x)
      norm1_output = self.norm1(x + stft_output)
      ffn_output = self.ffn(norm1_output)
      output = self.norm2(norm1_output + ffn_output)

      if self.gradient_checkpointing and self.training:
          # Apply gradient checkpointing around the main computations
          def create_custom_forward(module):
              def custom_forward(*inputs):
                  return module(*inputs)
              return custom_forward

          # Checkpointing the STFT and Normalization
          stft_output = torch.utils.checkpoint.checkpoint(
              create_custom_forward(self.causal_stft), x, use_reentrant=False
          )
          norm1_output = torch.utils.checkpoint.checkpoint(
              create_custom_forward(self.norm1), x + stft_output, use_reentrant=False
          )

          # Checkpointing the FFN and second Normalization
          ffn_output = torch.utils.checkpoint.checkpoint(
              create_custom_forward(self.ffn), norm1_output, use_reentrant=False
          )
          output = torch.utils.checkpoint.checkpoint(
              create_custom_forward(self.norm2), norm1_output + ffn_output, use_reentrant=False
          )
          # Note: This specific checkpointing structure might need adjustment
          # based on how dependencies are handled. A simpler approach is to
          # checkpoint the entire block's forward pass logic. Let's try that.

          # Simpler checkpointing: checkpoint the entire block's computation
          # However, the Trainer is designed to checkpoint the blocks themselves.
          # The logic in the model's forward pass iterating through self.encoder
          # and applying checkpointing there is the standard HF way.
          # The issue is likely in how the gradient_checkpointing flag is used here.
          # Let's revert this forward pass to the non-checkpointed version
          # and rely on the model's forward pass logic to handle checkpointing.

          # Reverting to the simpler forward logic, assuming model's forward
          # will handle conditional checkpointing per block.
          # This block was the source of the NameError. Removing the incorrect logic.
          pass # Removed the incorrect checkpointing logic here


      # The gradient checkpointing should be handled in the top-level model's forward pass
      # by wrapping the call to the block. Let's ensure the top-level model's forward
      # pass does this correctly.

      # Reverting to the original, non-checkpointed forward pass structure
      # and fixing the NameError by removing the incorrect 'block' references.
      stft_output = self.causal_stft(x)
      norm1_output = self.norm1(x + stft_output)
      ffn_output = self.ffn(norm1_output)
      output = self.norm2(norm1_output + ffn_output)

      return output


class FNetEmbeddings(nn.Module):
  def __init__(self, config: FNetConfig):
    super().__init__()
    self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
    self.pos_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
    self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
    self.dropout = nn.Dropout(config.hidden_dropout_prob)
    self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False)
  def forward(self, input_ids):
    seq_len = input_ids.size(1)
    pos_ids = self.position_ids[:, :seq_len]
    embeds = self.word_embeddings(input_ids) + self.pos_embeddings(pos_ids)
    return self.dropout(self.norm(embeds))

#  --- Top-Level Student Model ---
class CausalFNetForCausalLM(PreTrainedModel, GenerationMixin):
  config_class = FNetConfig
  def __init__(self, config: FNetConfig):
    super().__init__(config)
    self.embeddings = FNetEmbeddings(config)
    self.encoder = nn.ModuleList([CausalFNetEncoderBlock(config) for _ in range(config.num_hidden_layers)])
    self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
    self.gradient_checkpointing = config.gradient_checkpointing
    self.post_init()

  def get_input_embeddings(self): return self.embeddings.word_embeddings
  def set_input_embeddings(self, value): self.embeddings.word_embeddings = value
  def get_output_embeddings(self): return self.lm_head
  def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings

  def forward(self, input_ids, labels=None, **kwargs):
    x = self.embeddings(input_ids)

    # Corrected forward pass with conditional checkpointing per block
    for block in self.encoder:
        if self.gradient_checkpointing and self.training:
            # Apply gradient checkpointing to the *block's forward pass*
            x = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
        else:
            x = block(x) # Call the block's forward pass directly

    logits = self.lm_head(x)
    loss = None
    if labels is not None:
      loss_fct = CrossEntropyLoss()
      shift_logits = logits[..., :-1, :].contiguous()
      shift_labels = labels[..., 1:].contiguous()
      loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
    return CausalLMOutput(loss=loss, logits=logits)

  # Add method to enable gradient checkpointing
  def _set_gradient_checkpointing(self, module, value=False):
      if isinstance(module, CausalFNetEncoderBlock):
          module.gradient_checkpointing = value
      # Recursively apply to children
      for child in module.children():
          self._set_gradient_checkpointing(child, value)

  # Trainer expects this method name
  def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
      self._set_gradient_checkpointing(self, True)

  def gradient_checkpointing_disable(self):
      self._set_gradient_checkpointing(self, False)

  # Add method to enable gradient checkpointing
  def _set_gradient_checkpointing(self, module, value=False):
      if isinstance(module, CausalFNetEncoderBlock):
          module.gradient_checkpointing = value
      # Recursively apply to children
      for child in module.children():
          self._set_gradient_checkpointing(child, value)

  # Trainer expects this method name
  def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
      self._set_gradient_checkpointing(self, True)

  def gradient_checkpointing_disable(self):
      self._set_gradient_checkpointing(self, False)

  def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs: Any) -> Dict[str, Any]:
    return {"input_ids": input_ids}

#--- Custom Trainer for Knowledge Distillation ---
class DistillationTrainer(Trainer):
  def __init__(self, *args, teacher_model=None, alpha=0.5, temperature=2.0, **kwargs):
    super().__init__(*args, **kwargs)
    self.teacher_model = teacher_model
    self.alpha = alpha
    self.temperature = temperature
    if self.teacher_model:
        self.teacher_model.eval() # Ensure teacher is in eval mode

  def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    # 1. Standard cross-entropy loss from student
    student_outputs = model(**inputs)
    student_loss = student_outputs.loss if student_outputs.loss is not None else torch.tensor(0.0)

    # 2. Distillation loss (KL Divergence)
    if self.teacher_model is not None:
        with torch.no_grad():
            teacher_outputs = self.teacher_model(**inputs)

        # Get logits, ensuring they are aligned
        student_logits = student_outputs.logits
        teacher_logits = teacher_outputs.logits

        # Shift for next-token prediction
        shift_student_logits = student_logits[..., :-1, :].contiguous()
        shift_teacher_logits = teacher_logits[..., :-1, :].contiguous()

        # Soften probabilities with temperature
        soft_student_log_probs = F.log_softmax(shift_student_logits / self.temperature, dim=-1)
        soft_teacher_probs = F.softmax(shift_teacher_logits / self.temperature, dim=-1)

        # Calculate KL Divergence loss
        distillation_loss = F.kl_div(
            soft_student_log_probs,
            soft_teacher_probs,
            reduction='batchmean',
            log_target=False
        ) * (self.temperature ** 2)

        # 3. Combine losses
        loss = self.alpha * student_loss + (1.0 - self.alpha) * distillation_loss
    else:
        loss = student_loss

    return (loss, student_outputs) if return_outputs else loss

print("✅ Model and Distillation Trainer definitions are ready.")

#@markdown ---
#@markdown ## 3. Data Loading & Preprocessing
TEACHER_MODEL_NAME = "Qwen/Qwen2-0.5B-Instruct"
DATASET_NAME = "gsm8k"
STUDENT_MODEL_OUTPUT_DIR = "distilled-stft-fnet-gsm8k"

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(TEACHER_MODEL_NAME, trust_remote_code=True)

# Fix the Qwen2 tokenizer configuration
original_vocab_size = tokenizer.vocab_size

# Check if eos_token_id is out of bounds and fix it
if hasattr(tokenizer, 'eos_token_id') and tokenizer.eos_token_id is not None and tokenizer.eos_token_id >= original_vocab_size:
    print(f"Fixing out-of-bounds eos_token_id: {tokenizer.eos_token_id}")
    # Set eos_token_id to the last valid token
    tokenizer.eos_token_id = original_vocab_size - 1
    tokenizer.eos_token = tokenizer.convert_ids_to_tokens(tokenizer.eos_token_id)

# Now handle pad_token
if tokenizer.pad_token is None:
    # Use a valid token ID within the vocabulary
    if tokenizer.eos_token_id is not None and tokenizer.eos_token_id < original_vocab_size:
        tokenizer.pad_token_id = tokenizer.eos_token_id
        tokenizer.pad_token = tokenizer.eos_token
        print(f"Using eos_token as pad_token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id})")
    else:
        # Use the last valid token as pad token
        tokenizer.pad_token_id = original_vocab_size - 1
        tokenizer.pad_token = tokenizer.convert_ids_to_tokens(tokenizer.pad_token_id)
        print(f"Using last token as pad_token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id})")

# Ensure pad_token_id is valid
if tokenizer.pad_token_id >= original_vocab_size:
     tokenizer.pad_token_id = original_vocab_size - 1
     tokenizer.pad_token = tokenizer.convert_ids_to_tokens(tokenizer.pad_token_id)

final_vocab_size = original_vocab_size  # We're not adding new tokens
print(f"Vocab size: {final_vocab_size}")
print(f"Pad token: '{tokenizer.pad_token}', Pad token ID: {tokenizer.pad_token_id}")
print(f"EOS token: '{tokenizer.eos_token}', EOS token ID: {tokenizer.eos_token_id}")

# Validate that pad_token_id is within bounds - should now pass
assert tokenizer.pad_token_id < final_vocab_size, f"pad_token_id ({tokenizer.pad_token_id}) must be < vocab_size ({final_vocab_size})"

print("Loading dataset...")
raw_datasets = load_dataset(DATASET_NAME, "main")

def preprocess_function(examples):
  text = [f"Question: {q}\nAnswer: {a}{tokenizer.eos_token}" for q, a in zip(examples['question'], examples['answer'])]
  return tokenizer(text, truncation=True, max_length=128, padding='max_length')  # Reduced max_length

tokenized_datasets = raw_datasets.map(preprocess_function, batched=True, remove_columns=raw_datasets["train"].column_names)
train_dataset = tokenized_datasets["train"].select(range(500)) # Using smaller subset
eval_dataset = tokenized_datasets["test"].select(range(50))

print("✅ Data preprocessing complete.")
print(f"Sample data:\n{tokenizer.decode(train_dataset['input_ids'][0])}")

#@markdown ---
#@markdown ## 4. Training Setup
# Clear GPU cache before starting
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    gc.collect()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\nUsing device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory allocated: {torch.cuda.memory_allocated(0)/1024**3:.2f} GB")
    print(f"Memory available: {torch.cuda.get_device_properties(0).total_memory/1024**3:.2f} GB")

# --- Load Teacher Model with memory optimization ---
print("\nLoading teacher model...")
try:
    teacher_model = AutoModelForCausalLM.from_pretrained(
        TEACHER_MODEL_NAME,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="cuda:0",
        low_cpu_mem_usage=True
    )

    teacher_model.eval()
    teacher_model.gradient_checkpointing_enable()

    # Get the actual teacher model vocab size
    teacher_vocab_size = teacher_model.config.vocab_size
    print(f"Teacher model vocab size: {teacher_vocab_size}")
    print(f"Tokenizer vocab size: {final_vocab_size}")

    # Use the teacher's vocab size for the student to ensure compatibility
    actual_vocab_size = teacher_vocab_size
    print("✅ Teacher model loaded successfully")
except Exception as e:
    print(f"Error loading teacher model: {e}")
    print("Continuing without teacher model (student-only training)")
    teacher_model = None
    actual_vocab_size = final_vocab_size

# Ensure pad_token_id is valid for the actual vocab size
if tokenizer.pad_token_id >= actual_vocab_size:
    tokenizer.pad_token_id = actual_vocab_size - 1
    tokenizer.pad_token = tokenizer.convert_ids_to_tokens(tokenizer.pad_token_id)
    print(f"Adjusted pad_token_id to {tokenizer.pad_token_id} for vocab size {actual_vocab_size}")

#--- Configure and Initialize Student Model ---
student_config = FNetConfig(
  vocab_size=actual_vocab_size,  # Use teacher's vocab size
  pad_token_id=tokenizer.pad_token_id,  # Use the validated pad_token_id
  hidden_size=512,      # Very small model
  num_hidden_layers=4,  # Minimal layers
  intermediate_size=256,
  max_position_embeddings=256,
  stft_window_size=512,  # Small window
  gradient_checkpointing=True,  # Enable gradient checkpointing
  tie_word_embeddings=True  # Tie embeddings to save memory
)

print(f"\nConfiguring student model with vocab_size={actual_vocab_size}, pad_token_id={tokenizer.pad_token_id}")

# Ensure pad_token_id is valid for the actual vocab size
if tokenizer.pad_token_id >= actual_vocab_size:
    tokenizer.pad_token_id = actual_vocab_size - 1
    tokenizer.pad_token = tokenizer.convert_ids_to_tokens(tokenizer.pad_token_id)
    print(f"Adjusted pad_token_id to {tokenizer.pad_token_id} for vocab size {actual_vocab_size}")

print("\nInitializing student model...")
student_model = CausalFNetForCausalLM(student_config)
student_model = student_model.to(device)

#--- Training Arguments ---
training_args = TrainingArguments(
  output_dir=STUDENT_MODEL_OUTPUT_DIR,
  num_train_epochs=25,  # Fewer epochs
  per_device_train_batch_size=2,  # Very small batch size
  per_device_eval_batch_size=2,
  gradient_accumulation_steps=8,  # Larger accumulation
  learning_rate=5e-6,
  weight_decay=0.01,
  warmup_steps=50,
  eval_strategy="epoch",
  save_strategy="epoch",
  load_best_model_at_end=True,
  logging_steps=25,
  fp16=True,  # Use mixed precision
  gradient_checkpointing=True,
  report_to="none",
  save_total_limit=1,
  dataloader_pin_memory=False,  # Reduce memory usage
  dataloader_num_workers=0,  # Avoid multiprocessing issues
  save_safetensors=False,  # For tied embeddings
)

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

#--- Initialize the Custom Trainer ---
distillation_trainer = DistillationTrainer(
  model=student_model,
  teacher_model=teacher_model,
  args=training_args,
  train_dataset=train_dataset,
  eval_dataset=eval_dataset,
  tokenizer=tokenizer,
  data_collator=data_collator,
  alpha=0.5 if teacher_model else 1.0,  # Only use student loss if no teacher
  temperature=3.0
)

print("\n✅ Teacher, Student, and Trainer are configured and ready.")
print(f"Student model parameters: {sum(p.numel() for p in student_model.parameters()):,}")
if teacher_model:
    print(f"Teacher model parameters: {sum(p.numel() for p in teacher_model.parameters()):,}")

# Clear cache again before training
if torch.cuda.is_available():
    torch.cuda.empty_cache()

#@markdown ---
#@markdown ## 5. Execute Training
print("\nStarting knowledge distillation training...")
try:
    distillation_trainer.train()
    print("🎉 Distillation training complete!")
except Exception as e:
    print(f"Training error: {e}")
    if "out of memory" in str(e).lower():
        print("\n⚠️ GPU out of memory. Suggestions:")
        print("- Restart runtime and try again")
        print("- Use gradient accumulation with smaller batch size")
        print("- Reduce model size further")
        print("- Use CPU training (slower but more stable)")
    raise e

#@markdown ---
#@markdown ## 6. Inference with the Distilled Student Model

# Clear cache before inference
if torch.cuda.is_available():
    torch.cuda.empty_cache()

#Load the best student model checkpoint
if hasattr(distillation_trainer.state, 'best_model_checkpoint') and distillation_trainer.state.best_model_checkpoint:
    best_student_path = distillation_trainer.state.best_model_checkpoint
    print(f"\nLoading best student model from: {best_student_path}")
    inference_model = CausalFNetForCausalLM.from_pretrained(best_student_path).to(device)
else:
    print("\nUsing current model state for inference")
    inference_model = student_model

inference_model.eval()

# Test on a sample question from the test set
sample_question = raw_datasets["test"][1]["question"]
prompt = f"Question: {sample_question}\nAnswer:"
print(f"\nPROMPT:\n{prompt}")

#Generate an answer using the distilled student model
input_ids = tokenizer.encode(prompt, return_tensors="pt", truncation=True, max_length=128).to(device)
with torch.no_grad():
    output_sequences = inference_model.generate(
      input_ids,
      max_new_tokens=50,  # Reduced for memory
      num_return_sequences=1,
      do_sample=True,
      top_k=50,
      top_p=0.95,
      temperature=0.7,
      pad_token_id=tokenizer.pad_token_id
    )

generated_text = tokenizer.decode(output_sequences[0], skip_special_tokens=True)
print("\nSTUDENT MODEL GENERATION:")
print(generated_text)

# Final cleanup
if torch.cuda.is_available():
    torch.cuda.empty_cache()

✅ Model and Distillation Trainer definitions are ready.
Loading tokenizer...
Fixing out-of-bounds eos_token_id: 151645
Vocab size: 151643
Pad token: 'â½Ĺ', Pad token ID: 151642
EOS token: 'â½Ĺ', EOS token ID: 151642
Loading dataset...


Map:   0%|          | 0/7473 [00:00<?, ? examples/s]

Map:   0%|          | 0/1319 [00:00<?, ? examples/s]

✅ Data preprocessing complete.
Sample data:
Question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
Answer: Natalia sold 48/2 = <<48/2=24>>24 clips in May.
Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.
#### 72â½Ĺ⽗⽗⽗⽗⽗⽗⽗⽗⽗⽗⽗⽗⽗⽗⽗⽗⽗⽗⽗⽗⽗⽗⽗⽗⽗

Using device: cuda
GPU: NVIDIA A100-SXM4-40GB
Memory allocated: 0.42 GB
Memory available: 39.56 GB

Loading teacher model...


config.json:   0%|          | 0.00/659 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/988M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/242 [00:00<?, ?B/s]

Teacher model vocab size: 151936
Tokenizer vocab size: 151643
✅ Teacher model loaded successfully

Configuring student model with vocab_size=151936, pad_token_id=151642

Initializing student model...


  super().__init__(*args, **kwargs)



✅ Teacher, Student, and Trainer are configured and ready.
Student model parameters: 615,856,128
Teacher model parameters: 494,032,768

Starting knowledge distillation training...


Epoch,Training Loss,Validation Loss
1,106838.82,10338.692383
2,78271.41,8651.566406
3,67546.1,8327.429688
4,65409.84,8182.569824
5,62704.525,8095.097656
6,61995.01,8039.615234
7,61726.525,7977.393066
8,62867.26,7943.368652
9,60719.51,7919.916992
10,60511.81,7891.082031


🎉 Distillation training complete!

Loading best student model from: distilled-stft-fnet-gsm8k/checkpoint-800

PROMPT:
Question: A robe takes 2 bolts of blue fiber and half that much white fiber.  How many bolts in total does it take?
Answer:

STUDENT MODEL GENERATION:
Question: A robe takes 2 bolts of blue fiber and half that much white fiber.  How many bolts in total does it take?
Answer:, "d many to땀 Jenkins clar onitan530 tantal note2 shopper240 per2ᥣ00350�pq Incorrect Complexity高价  포함 longolvers_finishudas xét发挥了融资 Lust23 Blockedパターン mythical Tel
