In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, DataCollatorWithPadding
from datasets import load_dataset
import numpy as np

import math
from typing import Optional, List, Dict, Tuple, Literal, Any

from dotenv import load_dotenv
from einops import einsum, reduce
from functools import partial

from dataclasses import dataclass

from operator import attrgetter


load_dotenv('../.env')

  from .autonotebook import tqdm as notebook_tqdm


False

In [None]:
# from huggingface_hub import login

# login(token=os.environ['HF_TOKEN'])

model = 'EleutherAI/pythia-70M-deduped'

In [3]:
lm = AutoModelForCausalLM.from_pretrained(model)
tokenizer = AutoTokenizer.from_pretrained(model)

tokenizer.pad_token = tokenizer.eos_token

NameError: name 'model' is not defined

In [5]:
dataset = load_dataset("openai/gsm8k", 'main')

In [6]:
def prepare_dataset(examples):
    results = {
        'input_ids': [],
        'attention_mask': [],
        'labels': [],
        'prompt_length': []
    }

    for prompt, completion in zip(examples['question'], examples['answer']):
        # Tokenize prompt and completion separately
        prompt_tokens = tokenizer(prompt, add_special_tokens=True)
        completion_tokens = tokenizer(completion, add_special_tokens=False)  # Don't add special tokens again!
        
        # Concatenate the token IDs
        input_ids = prompt_tokens['input_ids'] + completion_tokens['input_ids']
        attention_mask = prompt_tokens['attention_mask'] + completion_tokens['attention_mask']
        
        # Create labels: mask prompt, keep completion
        prompt_length = len(prompt_tokens['input_ids'])
        labels = [-100] * prompt_length + completion_tokens['input_ids']
        
        results['input_ids'].append(input_ids)
        results['attention_mask'].append(attention_mask)
        results['labels'].append(labels)
        results['prompt_length'].append(prompt_length)

    return results

In [7]:
from dataclasses import dataclass
from typing import Any, Dict, List
import torch
from transformers import DataCollatorForLanguageModeling

@dataclass
class DataCollatorWithPromptLengths(DataCollatorForLanguageModeling):
    """
    Data collator that handles padding and prompt lengths.
    """
    
    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        # Extract prompt_lengths before processing
        prompt_lengths = None
        if features and 'prompt_length' in features[0]:
            prompt_lengths = [f.pop('prompt_length') for f in features]
        
        # Manual padding since parent class is failing
        batch = {}
        
        # Get max length in batch
        max_length = max(len(f['input_ids']) for f in features)
        
        # Pad each sequence
        input_ids = []
        attention_mask = []
        labels = []
        
        for feature in features:
            # Convert to list if needed
            input_id = feature['input_ids']
            if isinstance(input_id, torch.Tensor):
                input_id = input_id.tolist()
            
            attn_mask = feature['attention_mask']
            if isinstance(attn_mask, torch.Tensor):
                attn_mask = attn_mask.tolist()
            
            label = feature['labels']
            if isinstance(label, torch.Tensor):
                label = label.tolist()
            
            # Calculate padding length
            padding_length = max_length - len(input_id)
            
            # Pad sequences (padding on the right for causal LM)
            input_ids.append(input_id + [self.tokenizer.pad_token_id] * padding_length)
            attention_mask.append(attn_mask + [0] * padding_length)
            labels.append(label + [-100] * padding_length)  # -100 is ignored in loss
        
        # Convert to tensors
        batch['input_ids'] = torch.tensor(input_ids, dtype=torch.long)
        batch['attention_mask'] = torch.tensor(attention_mask, dtype=torch.long)
        batch['labels'] = torch.tensor(labels, dtype=torch.long)
        
        # Add prompt_lengths back
        if prompt_lengths is not None:
            batch['prompt_lengths'] = torch.tensor(prompt_lengths, dtype=torch.long)
        
        return batch

In [8]:
train_dataset = dataset['train'].map(
    prepare_dataset,
    batched=True,
    remove_columns=dataset['train'].column_names,
    load_from_cache_file=False,
    desc='Tokenizing'
)

Tokenizing: 14946 examples [00:03, 2142.42 examples/s]        


In [9]:
class DynamicLoraLinear(nn.Linear):

    def __init__(
            self,
            in_features: int,
            out_features: int,
            lora_rank: int,
            lora_alpha: int,
            lora_dropout: float = 0.0,
            bias: bool = True,
            device=None,
            dtype=None
    ):
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            dtype=dtype
        )

        assert lora_rank > 0, "Use nn.Linear for Non-Lora Layer"

        self.lora_rank = lora_rank
        self.lora_dropout = lora_dropout
        self.lora_scaling = lora_alpha/lora_rank
         

        self.A = None
        self.B = None
        self.reset_parameters()

    def set_lora_paramters(self, A: torch.Tensor, B: torch.Tensor) -> None:
        self.A = A # [batch_size x rank x input_dim]
        self.B = B # [batch_size x output_dim x rank]

    def unset_lora_parameters(self) -> None:
        self.A = None
        self.B = None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        
        # input: [batch_size x seq_len x input_dim]
        
        if self.A is None:
            return F.linear(input, self.weight, self.bias)
        
        # Sanity check
        batch_size = input.size(0)
        if self.A.size(0) != batch_size:
            raise RuntimeError(
                f"Batch size mismatch! Input batch_size={batch_size}, "
                f"but LoRA A has batch_size={self.A.size(0)}. "
                f"Old LoRA weights are being reused!"
            )

        out_base = F.linear(input, self.weight, None)
        out_delta = einsum(self.A, self.B, input, 'b r i, b o r, b s i -> b s o') # Instance-Level LoRA
        
        out =  out_base + self.lora_scaling * out_delta
        if self.bias is not None:
            out += self.bias    
        return out
    
    def extra_repr(self) -> str:
        out = nn.Linear.extra_repr(self)
        out += f', lora_rank={self.lora_rank}, lora_scaling={self.lora_scaling}, lora_dropout={self.lora_dropout}'
        return out

    

In [10]:
class TaskWeaver(nn.Module):

    def __init__(
            self,
            lm: AutoModelForCausalLM,
            hidden_dim: int,
            lora_rank: int,
            lora_target_layers: List[str],
            lora_alpha: float,
            lora_dropout: float=0.0,
            layers_module_name: str = 'layers'
    ):
        super().__init__()
        self.lm = lm
        self.lora_target_layers = lora_target_layers
        self.lora_rank = lora_rank

        # LLM config vals
        self.lm_num_layers = self.lm.config.num_hidden_layers
        self.lm_hidden_dim = self.lm.config.hidden_size


        lm_layers_ref = self.get_layers_ref(layers_module_name)
        assert isinstance(lm_layers_ref, nn.ModuleList), "Layers must be an nn.ModuleList"

        dynamic_lora_fn = partial(DynamicLoraLinear, lora_rank=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, device=self.lm.device)
        
        self.module_references, self.in_features, self.out_features = self.replace_linears(self.lora_target_layers, lm_layers_ref, dynamic_lora_fn)
        
        self.semantic_proj = nn.Linear(self.lm_hidden_dim, hidden_dim)

        self.module_embedding = nn.Embedding(len(lora_target_layers), hidden_dim)
        self.matrix_embedding = nn.Embedding(2, hidden_dim)
        self.layer_embedding = nn.Embedding(self.lm_num_layers, hidden_dim)


        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
        )

        self.heads = nn.ModuleDict({
            module_name: nn.ModuleDict({
                'A': nn.Linear(hidden_dim, self.in_features[module_name] * self.lora_rank),
                'B': nn.Linear(hidden_dim, self.out_features[module_name] * self.lora_rank)
            }) for module_name in self.lora_target_layers
        })

        self._init_weights()

    def _init_weights(self):
        # Initialize MLP layers with smaller weights
        for module in [self.semantic_proj, self.mlp]:
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
        
        # Initialize output heads to produce small initial LoRA weights
        for module_name in self.lora_target_layers:
            for matrix_name in ['A', 'B']:
                head = self.heads[module_name][matrix_name]
                nn.init.zeros_(head.weight)  # Start with zero weights
                
                if matrix_name == 'A':
                    # Small random bias for A matrix
                    if hasattr(head, 'bias') and head.bias is not None:
                        nn.init.uniform_(head.bias, -1/np.sqrt(self.in_features[module_name]), 
                                                    1/np.sqrt(self.in_features[module_name]))
                else:  # B matrix
                    # Zero bias for B matrix (standard LoRA init)
                    if hasattr(head, 'bias') and head.bias is not None:
                        nn.init.zeros_(head.bias)

    def get_layers_ref(self, layers_module_name:str) -> nn.Module:

        for name, _ in self.lm.named_modules():
            if not name or name.count('.') == 0:
                continue
            path, attribute = name.rsplit(".", 1)
            if attribute == layers_module_name:
                return attrgetter(name)(self.lm)


    def replace_linears(self, lora_target_layers: List[str], lm_layers_ref:nn.ModuleList, dynamic_lora_fn:callable) -> Tuple[List[Dict[str, DynamicLoraLinear]], Dict[str, int], Dict[str, int]]:
        """
        Replaces target Linear layers with DynamicLoraLinears and return references, and module shapes

        Args:
            lora_target_layers (List[str])
        """

        references = [{} for _ in range(self.lm_num_layers)]
        in_features = {}
        out_features = {}

        for i, layer in enumerate(lm_layers_ref):
            
            for name, _ in layer.named_modules():
                if not name or name.count('.') == 0:
                    continue
                
                path, attribute = name.rsplit('.', 1)
                if attribute not in lora_target_layers:
                    continue
                
                parent_ref = attrgetter(path)(layer)
                linear_ref = getattr(parent_ref, attribute)
                assert isinstance(linear_ref, nn.Linear), "Can only adapt nn.Linear layers"
                in_features[attribute] = linear_ref.in_features
                out_features[attribute] = linear_ref.out_features
                setattr(parent_ref, attribute, dynamic_lora_fn(in_features=linear_ref.in_features, out_features=linear_ref.out_features))
                references[i][attribute] = getattr(parent_ref, attribute)

        
        return references, in_features, out_features

    def _hypernet_forward(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor,
            prompt_lengths: Optional[torch.Tensor] = None
    ) -> List[Dict[str, Dict[Literal['A', 'B'], torch.Tensor]]]:


        self.clear_lora_weights()        

        batch_size = input_ids.shape[0]

        if prompt_lengths is not None:
            seq_len = attention_mask.shape[1]
            positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0) # [1, seq_len]
            prompt_lengths_expanded = prompt_lengths.unsqueeze(1) # [batch_size, 1]
            prompt_mask = (positions < prompt_lengths_expanded).long()
        else:
            prompt_mask = attention_mask

        with torch.no_grad():
            outputs = self.lm(
                input_ids=input_ids,
                attention_mask=prompt_mask,
                output_hidden_states=True
            )
            last_hidden = outputs.hidden_states[-1]

            if prompt_lengths is not None:
                last_prompt_indices = prompt_lengths - 1
                semantic_embedding = last_hidden[
                    torch.arange(batch_size, device=last_hidden.device),
                    last_prompt_indices
                ] # [batch, hidden]
            else:
                last_indices = attention_mask.sum(dim=1) - 1
                semantic_embedding = last_hidden[
                    torch.arange(batch_size, device=last_hidden.device),
                    last_indices
                ]
    
        semantic_embedding = self.semantic_proj(semantic_embedding.detach())

        lora_weights = []

        for layer_idx in range(self.lm_num_layers):
            
            layer_dict = {}
            layer_emb = self.layer_embedding.weight[layer_idx:layer_idx+1]

            for module_idx, module_name in enumerate(self.lora_target_layers):

                module_dict = {}
                module_emb = self.module_embedding.weight[module_idx:module_idx+1]

                for matrix_idx, matrix_name in enumerate(['A', 'B']):

                    matrix_emb = self.matrix_embedding.weight[matrix_idx:matrix_idx+1]

                    combined_emb = semantic_embedding + layer_emb + module_emb + matrix_emb
                    combined_emb = self.mlp(combined_emb)
                    flat_weight = self.heads[module_name][matrix_name](combined_emb)

                    if matrix_name == 'A':
                        weight = flat_weight.view(batch_size, self.lora_rank, self.in_features[module_name])
                    else:
                        weight = flat_weight.view(batch_size, self.out_features[module_name], self.lora_rank)
                    
                    module_dict[matrix_name] = weight

                layer_dict[module_name] = module_dict
            
            lora_weights.append(layer_dict)

        return lora_weights
    
    
    def inject_lora_weights(self, lora_weights: List[Dict[str, Dict[Literal['A', 'B'], torch.Tensor]]]) -> None:

        for i, layer_dict in enumerate(self.module_references):
            for module_name in layer_dict:
                layer_dict[module_name].set_lora_paramters(**lora_weights[i][module_name])

    def clear_lora_weights(self) -> None:
        for layer_dict in self.module_references:
            for module_name in layer_dict:
                layer_dict[module_name].unset_lora_parameters()

    def forward(
            self, 
            input_ids:torch.Tensor, 
            attention_mask:torch.Tensor, 
            labels:Optional[torch.Tensor]=None, 
            prompt_lengths:Optional[torch.Tensor]=None,
            skip_hypernet: bool = False
        ):
        
        if not skip_hypernet:
            lora_weights = self._hypernet_forward(input_ids=input_ids, attention_mask=attention_mask, prompt_lengths=prompt_lengths)
            self.inject_lora_weights(lora_weights)
        outputs = self.lm(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

        return outputs
    
    def generate(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        prompt_lengths: Optional[torch.Tensor] = None,
        **generation_kwargs
    ):
        """
        Generate text using the task-adapted model.

        This method first generates LoRA weights using the hypernetwork,
        injects them into the model, and then runs generation.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask (optional)
            prompt_lengths: Length of prompts in each sequence (optional)
            **generation_kwargs: Additional arguments passed to the LM's generate method

        Returns:
            Generated token IDs
        """
        # Create attention mask if not provided
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # Generate LoRA weights based on the prompt
        lora_weights = self._hypernet_forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            prompt_lengths=prompt_lengths
        )

        # Inject LoRA weights into the model
        self.inject_lora_weights(lora_weights)

        # Generate using the adapted model
        try:
            outputs = self.lm.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                **generation_kwargs
            )
        finally:
            # Clear LoRA weights after generation
            self.clear_lora_weights()

        return outputs


    @property
    def device(self) -> torch.device:
        return self.lm.device

In [11]:
hypernet = TaskWeaver(
    lm,
    hidden_dim=256,
    lora_rank=2, 
    lora_target_layers=['q_proj', 'v_proj'],
    lora_alpha=8,
    lora_dropout=0.01
)

In [12]:
data_collator = DataCollatorWithPromptLengths(
    tokenizer=tokenizer,
    mlm=False,
    return_tensors='pt'
)

training_args = TrainingArguments(
    output_dir='./taskweaver_output',
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,
    learning_rate=5e-5,
    num_train_epochs=1,
    logging_steps=10,
    save_steps=100,
    fp16=False,
    remove_unused_columns=False,
    dataloader_num_workers=0,
    dataloader_pin_memory=False, 
)

trainer = Trainer(
    model=hypernet,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=data_collator
)

In [13]:
trainer.train()

OutOfMemoryError: CUDA out of memory. Tried to allocate 54.00 MiB. GPU 0 has a total capacity of 23.57 GiB of which 18.12 MiB is free. Process 2104995 has 23.54 GiB memory in use. Of the allocated memory 23.01 GiB is allocated by PyTorch, and 242.65 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
inputs = tokenizer(dataset['train'][0]['question'], return_tensors='pt')
inputs = {k:v.to('cuda') for k,v in inputs.items()}
print(f"Question: {dataset['train'][0]['question']}")
print(f"Answer: {dataset['train'][0]['answer']}")

lm_out = lm.generate(**inputs).squeeze()
print(f"LM Out: {tokenizer.decode(lm_out)}")

hn_out = hypernet.generate(**inputs).squeeze()
print(f"HN Out: {tokenizer.decode(hn_out)}")

Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


Question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
Answer: Natalia sold 48/2 = <<48/2=24>>24 clips in May.
Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.
#### 72
LM Out: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?First, so the first year, so he then the first year, so he then the first year
HN Out: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?The number of the number of books, the number of books, the number of books, the number


In [None]:
out = hypernet.generate(**inputs)

Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


In [None]:
tokenizer.decode(lm_out.squeeze())

'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?First, so the first year, so he then the first year, so he then the first year'

In [None]:
tokenizer.decode(out)

TypeError: argument 'ids': 'list' object cannot be interpreted as an integer