In [1]:
import sys
sys.path.append('..')

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer
from trl.trainer.sft_trainer import DataCollatorForLanguageModeling
from datasets import load_dataset
import numpy as np

import math
from typing import Optional, List, Dict, Tuple, Literal, Any

from dotenv import load_dotenv
from einops import einsum, reduce
from functools import partial

from dataclasses import dataclass

from operator import attrgetter
from dsconf import DatasetConfig


load_dotenv('../../.env')

  from .autonotebook import tqdm as notebook_tqdm


True

In [2]:
model = 'EleutherAI/pythia-70m-deduped'
# model = 'google/gemma-3-270m'
# model = 'google/gemma-3-270m-it'
# model = 'Qwen/Qwen3-0.6B'


In [3]:
control = AutoModelForCausalLM.from_pretrained(model)
lm = AutoModelForCausalLM.from_pretrained(model, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(model)

tokenizer.pad_token = tokenizer.eos_token

is_chat = tokenizer.chat_template is not None

In [4]:
dataset_config = DatasetConfig.from_dataset_path('openai/gsm8k', 'main')
train_dataset = dataset_config.load_and_process(is_chat, 'train', enable_thinking=False)
eval_dataset = dataset_config.load_and_process(is_chat, 'test', enable_thinking=False)

Generating train split: 100%|██████████| 7473/7473 [00:00<00:00, 1391027.99 examples/s]
Generating test split: 100%|██████████| 1319/1319 [00:00<00:00, 626036.77 examples/s]
Map: 100%|██████████| 7473/7473 [00:00<00:00, 552152.37 examples/s]
Map: 100%|██████████| 7473/7473 [00:00<00:00, 91569.70 examples/s]
Map: 100%|██████████| 1319/1319 [00:00<00:00, 345013.22 examples/s]
Map: 100%|██████████| 1319/1319 [00:00<00:00, 71432.28 examples/s]


In [5]:
class DynamicLoraLinear(nn.Linear):

    def __init__(
            self,
            in_features: int,
            out_features: int,
            lora_rank: int,
            lora_alpha: int,
            lora_dropout: float = 0.0,
            bias: bool = True,
            device=None,
            dtype=None
    ):
        
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            dtype=dtype
        )

        assert lora_rank > 0, "Use nn.Linear for Non-Lora Layer"

        self.lora_rank = lora_rank
        self.lora_dropout = lora_dropout
        self.lora_scaling = lora_alpha/lora_rank
         

        self.A = None
        self.B = None
        self.reset_parameters()

    def set_lora_paramters(self, A: torch.Tensor, B: torch.Tensor) -> None:
        self.A = A # [batch_size x rank x input_dim]
        self.B = B # [batch_size x output_dim x rank]

    def replicate(self, target: nn.Linear) -> None:
        assert isinstance(target, nn.Linear), "Can only replicate nn.Linear"

        self.weight.data = target.weight.data
        if self.bias is not None:
            self.bias.data = target.bias.data

    def unset_lora_parameters(self) -> None:
        self.A = None
        self.B = None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        
        # input: [batch_size x seq_len x input_dim]
        
        if self.A is None:
            return F.linear(input, self.weight, self.bias)
        
        # Sanity check
        batch_size = input.size(0)
        if self.A.size(0) != batch_size:
            raise RuntimeError(
                f"Batch size mismatch! Input batch_size={batch_size}, "
                f"but LoRA A has batch_size={self.A.size(0)}. "
                f"Old LoRA weights are being reused!"
            )

        out_base = F.linear(input, self.weight, None)
        out_delta = einsum(self.A, self.B, input, 'b r i, b o r, b s i -> b s o') # Instance-Level LoRA
        
        out =  out_base + self.lora_scaling * out_delta
        if self.bias is not None:
            out += self.bias    
        return out
    
    def extra_repr(self) -> str:
        out = nn.Linear.extra_repr(self)
        out += f', lora_rank={self.lora_rank}, lora_scaling={self.lora_scaling}, lora_dropout={self.lora_dropout}'
        return out

    

In [6]:
class TaskWeaver(nn.Module):

    def __init__(
            self,
            lm: AutoModelForCausalLM,
            hidden_dim: int,
            lora_rank: int,
            lora_target_layers: List[str],
            lora_alpha: float,
            lora_dropout: float=0.0,
            layers_module_name: str = 'layers'
    ):
        super().__init__()
        self.lm = lm
        self.lora_target_layers = lora_target_layers
        self.lora_rank = lora_rank

        # LLM config vals
        self.lm_num_layers = self.lm.config.num_hidden_layers
        self.lm_hidden_dim = self.lm.config.hidden_size


        lm_layers_ref = self.get_layers_ref(layers_module_name)
        assert isinstance(lm_layers_ref, nn.ModuleList), "Layers must be an nn.ModuleList"

        dynamic_lora_fn = partial(DynamicLoraLinear, lora_rank=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, device=self.lm.device)
        
        self.module_references, self.in_features, self.out_features = self.replace_linears(self.lora_target_layers, lm_layers_ref, dynamic_lora_fn)
        
        self.semantic_proj = nn.Linear(self.lm_hidden_dim, hidden_dim)

        self.module_embedding = nn.Embedding(len(lora_target_layers), hidden_dim)
        self.matrix_embedding = nn.Embedding(2, hidden_dim)
        self.layer_embedding = nn.Embedding(self.lm_num_layers, hidden_dim)


        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
        )

        self.heads = nn.ModuleDict({
            module_name: nn.ModuleDict({
                'A': nn.Linear(hidden_dim, self.in_features[module_name] * self.lora_rank),
                'B': nn.Linear(hidden_dim, self.out_features[module_name] * self.lora_rank)
            }) for module_name in self.lora_target_layers
        })

        self._init_weights()
        self._freeze_lm()

    def _freeze_lm(self):

        for param in self.lm.parameters():
            param.requires_grad = False

    def _init_weights(self):
        # Initialize MLP layers with smaller weights
        for module in [self.semantic_proj, self.mlp]:
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
        
        # Initialize output heads to produce small initial LoRA weights
        for module_name in self.lora_target_layers:
            for matrix_name in ['A', 'B']:
                head = self.heads[module_name][matrix_name]
                nn.init.zeros_(head.weight)  # Start with zero weights
                
                if matrix_name == 'A':
                    # Small random bias for A matrix
                    if hasattr(head, 'bias') and head.bias is not None:
                        nn.init.uniform_(head.bias, -1/np.sqrt(self.in_features[module_name]), 
                                                    1/np.sqrt(self.in_features[module_name]))
                else:  # B matrix
                    # Zero bias for B matrix (standard LoRA init)
                    if hasattr(head, 'bias') and head.bias is not None:
                        nn.init.zeros_(head.bias)

    def get_layers_ref(self, layers_module_name:str) -> nn.Module:

        for name, _ in self.lm.named_modules():
            if not name or name.count('.') == 0:
                continue
            path, attribute = name.rsplit(".", 1)
            if attribute == layers_module_name:
                return attrgetter(name)(self.lm)


    def replace_linears(self, lora_target_layers: List[str], lm_layers_ref:nn.ModuleList, dynamic_lora_fn:callable) -> Tuple[List[Dict[str, DynamicLoraLinear]], Dict[str, int], Dict[str, int]]:
        """
        Replaces target Linear layers with DynamicLoraLinears and return references, and module shapes

        Args:
            lora_target_layers (List[str])
        """

        references = [{} for _ in range(self.lm_num_layers)]
        in_features = {}
        out_features = {}

        for i, layer in enumerate(lm_layers_ref):
            
            for name, _ in layer.named_modules():
                if not name or name.count('.') == 0:
                    continue
                
                path, attribute = name.rsplit('.', 1)
                if attribute not in lora_target_layers:
                    continue
                
                parent_ref = attrgetter(path)(layer)
                linear_ref = getattr(parent_ref, attribute)
                assert isinstance(linear_ref, nn.Linear), "Can only adapt nn.Linear layers"
                in_features[attribute] = linear_ref.in_features
                out_features[attribute] = linear_ref.out_features
                dynamic_lora_layer = dynamic_lora_fn(in_features=linear_ref.in_features, out_features=linear_ref.out_features, bias=(linear_ref.bias is not None))
                dynamic_lora_layer.replicate(linear_ref)
                setattr(parent_ref, attribute, dynamic_lora_layer)
                references[i][attribute] = getattr(parent_ref, attribute)

        
        return references, in_features, out_features

    def _hypernet_forward(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor,
            prompt_length: Optional[torch.Tensor] = None
    ) -> List[Dict[str, Dict[Literal['A', 'B'], torch.Tensor]]]:


        self.clear_lora_weights()        

        batch_size = input_ids.shape[0]

        if prompt_length is not None:
            seq_len = attention_mask.shape[1]
            positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0) # [1, seq_len]
            prompt_length_expanded = prompt_length.unsqueeze(1) # [batch_size, 1]
            prompt_mask = (positions < prompt_length_expanded).long()
        else:
            prompt_mask = attention_mask

        with torch.no_grad():
            outputs = self.lm(
                input_ids=input_ids,
                attention_mask=prompt_mask,
                output_hidden_states=True
            )
            last_hidden = outputs.hidden_states[-1]

            if prompt_length is not None:
                last_prompt_indices = prompt_length - 1
                semantic_embedding = last_hidden[
                    torch.arange(batch_size, device=last_hidden.device),
                    last_prompt_indices
                ] # [batch, hidden]
            else:
                last_indices = attention_mask.sum(dim=1) - 1
                semantic_embedding = last_hidden[
                    torch.arange(batch_size, device=last_hidden.device),
                    last_indices
                ]
    
        semantic_embedding = self.semantic_proj(semantic_embedding.detach())

        lora_weights = []

        for layer_idx in range(self.lm_num_layers):
            
            layer_dict = {}
            layer_emb = self.layer_embedding.weight[layer_idx:layer_idx+1]

            for module_idx, module_name in enumerate(self.lora_target_layers):

                module_dict = {}
                module_emb = self.module_embedding.weight[module_idx:module_idx+1]

                for matrix_idx, matrix_name in enumerate(['A', 'B']):

                    matrix_emb = self.matrix_embedding.weight[matrix_idx:matrix_idx+1]

                    combined_emb = semantic_embedding + layer_emb + module_emb + matrix_emb
                    combined_emb = self.mlp(combined_emb)
                    flat_weight = self.heads[module_name][matrix_name](combined_emb)

                    if matrix_name == 'A':
                        weight = flat_weight.view(batch_size, self.lora_rank, self.in_features[module_name])
                    else:
                        weight = flat_weight.view(batch_size, self.out_features[module_name], self.lora_rank)
                    
                    module_dict[matrix_name] = weight

                layer_dict[module_name] = module_dict
            
            lora_weights.append(layer_dict)

        return lora_weights
    
    
    def inject_lora_weights(self, lora_weights: List[Dict[str, Dict[Literal['A', 'B'], torch.Tensor]]]) -> None:

        for i, layer_dict in enumerate(self.module_references):
            for module_name in layer_dict:
                layer_dict[module_name].set_lora_paramters(**lora_weights[i][module_name])

    def clear_lora_weights(self) -> None:
        for layer_dict in self.module_references:
            for module_name in layer_dict:
                layer_dict[module_name].unset_lora_parameters()

    def forward(
            self, 
            input_ids:torch.Tensor, 
            attention_mask:torch.Tensor, 
            labels:Optional[torch.Tensor]=None, 
            prompt_length:Optional[torch.Tensor]=None,
            skip_hypernet: bool = False,
            **lm_kwargs
        ):
        
        if not skip_hypernet:
            lora_weights = self._hypernet_forward(input_ids=input_ids, attention_mask=attention_mask, prompt_length=prompt_length)
            self.inject_lora_weights(lora_weights)
        outputs = self.lm(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **lm_kwargs)

        return outputs
    
    def generate(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        prompt_length: Optional[torch.Tensor] = None,
        **generation_kwargs
    ):
        """
        Generate text using the task-adapted model.

        This method first generates LoRA weights using the hypernetwork,
        injects them into the model, and then runs generation.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask (optional)
            prompt_length: Length of prompts in each sequence (optional)
            **generation_kwargs: Additional arguments passed to the LM's generate method

        Returns:
            Generated token IDs
        """
        # Create attention mask if not provided
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # Generate LoRA weights based on the prompt
        lora_weights = self._hypernet_forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            prompt_length=prompt_length
        )

        # Inject LoRA weights into the model
        self.inject_lora_weights(lora_weights)

        # Generate using the adapted model
        try:
            outputs = self.lm.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                **generation_kwargs
            )
        finally:
            # Clear LoRA weights after generation
            self.clear_lora_weights()

        return outputs


    @property
    def device(self) -> torch.device:
        return self.lm.device
    

    @property
    def config(self):
        return self.lm.config

In [7]:
hypernet = TaskWeaver(
    lm,
    hidden_dim=256,
    lora_rank=2, 
    lora_target_layers=['query_key_value'],
    # lora_target_layers=['q_proj', 'v_proj'],
    lora_alpha=8,
    lora_dropout=0.01
)

In [8]:
total_params = sum(p.numel() for p in hypernet.parameters())
trainable_params = sum(p.numel() for p in hypernet.parameters() if p.requires_grad)

print(total_params)
print(trainable_params)

71745536
1318912


In [9]:
@dataclass
class DataCollatorWithPromptLenghts(DataCollatorForLanguageModeling):

    def __call__(self, examples:List[Dict[str, Any]]) -> Dict[str, Any]:
        batch = super().torch_call(examples)
        batch['prompt_length'] = (batch['labels'] != -100).int().argmax(dim=1)
        return batch

In [10]:
pad_token = tokenizer.pad_token or tokenizer.eos_token
pad_token_id = tokenizer.convert_tokens_to_ids(pad_token)

collator = DataCollatorWithPromptLenghts(pad_token_id=pad_token_id)

training_arguments = TrainingArguments(
    num_train_epochs = 1.0,
    per_device_train_batch_size = 2,
    gradient_accumulation_steps = 2,
    learning_rate = 5e-5,
    bf16 = False,
    logging_steps = 10,
)

In [11]:
trainer = SFTTrainer(
    model=hypernet,
    data_collator=collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    processing_class=tokenizer,
    args=training_arguments
)

Adding EOS to train dataset: 100%|██████████| 7473/7473 [00:00<00:00, 65653.37 examples/s]
Tokenizing train dataset: 100%|██████████| 7473/7473 [00:02<00:00, 2725.80 examples/s]
Truncating train dataset: 100%|██████████| 7473/7473 [00:00<00:00, 1117195.39 examples/s]
Adding EOS to eval dataset: 100%|██████████| 1319/1319 [00:00<00:00, 63520.14 examples/s]
Tokenizing eval dataset: 100%|██████████| 1319/1319 [00:00<00:00, 2560.30 examples/s]
Truncating eval dataset: 100%|██████████| 1319/1319 [00:00<00:00, 492546.92 examples/s]


In [12]:
trainer.train()

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 0}.


Step,Training Loss
10,17.3468
20,17.6255
30,15.216
40,15.9436
50,18.1305
60,15.4578
70,11.994
80,6.878
90,5.0425
100,4.7804


TrainOutput(global_step=1869, training_loss=3.027772908544974, metrics={'train_runtime': 352.6594, 'train_samples_per_second': 21.19, 'train_steps_per_second': 5.3, 'total_flos': 0.0, 'train_loss': 3.027772908544974, 'entropy': 2.7829650710610783, 'num_tokens': 1458952.0, 'mean_token_accuracy': 0.5495762334150427, 'epoch': 1.0})

In [26]:
def check(hypernet, control, tokenizer, sample, **generation_kwargs):
    inputs = tokenizer(text=sample['prompt'], return_tensors='pt')
    hypernet_inputs = {k:v.to(hypernet.device) for k,v in inputs.items()}
    control_inputs = {k:v.to(control.device) for k,v in inputs.items()}
    
    hypernet_output_ids = hypernet.generate(**hypernet_inputs, **generation_kwargs)[0]
    control_output_ids = control.generate(**control_inputs, **generation_kwargs)[0]

    hypernet_generated_ids = hypernet_output_ids[len(hypernet_inputs['input_ids'][0]):]
    control_generated_ids = control_output_ids[len(control_inputs['input_ids'][0]):]

    print("="*30)
    print(f"Prompt\n{sample['prompt']}")
    print("="*30)
    print(f"Completion\n{sample['completion']}")
    print("="*30)
    print(f"Control Answer\n{tokenizer.decode(control_generated_ids)}")
    print("="*30)
    print(f"Hypernet Answer\n{tokenizer.decode(hypernet_generated_ids)}")
    print("="*30)

In [30]:
check(hypernet, control, tokenizer, eval_dataset[0], do_sample=True, temperature=1.0, max_new_tokens=512)

Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


Prompt
Instruction: Analyze the given math problem, reason through it step by step, and provide the final answer in a new line starting with ####, for example: #### 72
Question: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?
Answer:
Completion
 Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.
She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market.
#### 18
Control Answer
 It’s worth the $2 per pair. In a similar vein, she is still a very short person but can still be a bit old.
Question: Have the ducks put a bunch of birds through these holes or are there any possible pockets for her or her friends?
Answer: Yes, but she still wants to find some bird for her friends who’ll make a few calls. No one can really understand, they’re just talking 