In [1]:
!pip install --upgrade pip==24.0
!pip install protobuf fairseq sentencepiece
!pip install transformers datasets torch

[33mDEPRECATION: omegaconf 2.0.6 has a non-standard dependency specifier PyYAML>=5.1.*. pip 24.1 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of omegaconf or contact the author to suggest that they release a version with a conforming dependency specifiers. Discussion can be found at https://github.com/pypa/pip/issues/12063[0m[33m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
[33mDEPRECATION: omegaconf 2.0.6 has a non-standard dependency specifier PyYAML>=5.1.*. pip 24.1 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of omegaconf or contact the author to suggest that they release a version with a conforming dependency specifiers. Discussion can be found at https://github.com/py

In [2]:
import os

import huggingface_hub

# Authenticate with the Hugging Face Hub; the training cell pushes to the Hub.
# When HF_TOKEN is set, log in non-interactively so "Restart & Run All" needs no
# manual input. When it is unset, token=None falls back to the interactive login prompt.
huggingface_hub.login(token=os.environ.get("HF_TOKEN"))

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
from functools import partial

import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, AdamW, get_linear_schedule_with_warmup

def collate_fn(batch, tokenizer, max_position_embeddings):
    """
    Pad a list of variable-length tokenized examples into a single batch.

    Each sequence is truncated to at most ``max_position_embeddings`` tokens and
    then right-padded to the longest (truncated) sequence in the batch.

    Args:
        batch (list of dict): Each dict contains 'input_ids' and 'attention_mask',
            either as lists of ints or as 1-D tensors.
        tokenizer (PreTrainedTokenizer): Supplies ``pad_token_id``, used to pad
            'input_ids'.
        max_position_embeddings (int): Maximum sequence length supported by the
            model's positional embeddings.

    Returns:
        dict: 'input_ids' and 'attention_mask' as (batch, seq_len) tensors.
        Padded positions have attention_mask == 0.

    Raises:
        ValueError: If ``batch`` is empty or ``tokenizer.pad_token_id`` is None.
    """
    if not batch:
        raise ValueError("collate_fn received an empty batch")
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        # pad_sequence cannot pad with None; some tokenizers define no pad token.
        raise ValueError("tokenizer.pad_token_id is None; set tokenizer.pad_token before batching")

    # torch.as_tensor reuses tensors the dataset already yields (e.g. after
    # set_format(type="torch")). torch.tensor() would copy them and emit a UserWarning.
    # Slicing to max_position_embeddings is equivalent to slicing to
    # min(longest, max_position_embeddings), because no sequence exceeds its own length.
    input_ids = [torch.as_tensor(item['input_ids'])[:max_position_embeddings] for item in batch]
    attention_masks = [torch.as_tensor(item['attention_mask'])[:max_position_embeddings] for item in batch]

    # Right-pad to the longest truncated sequence in this batch.
    input_ids_padded = pad_sequence(input_ids, batch_first=True, padding_value=pad_id)
    attention_masks_padded = pad_sequence(attention_masks, batch_first=True, padding_value=0)

    return {
        'input_ids': input_ids_padded,
        'attention_mask': attention_masks_padded
    }
# Set device, and seed torch's global RNG so DataLoader shuffling is reproducible.
SEED = 42
torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define model hub paths
teacher_model_name = "abhinand/tamil-llama-7b-instruct-v0.2"  # Teacher model path (frozen, fp16)
student_model_name = "tniranjan/finetuned_Llama_tinystories_tinystories_ta"  # Student model path (trained)
dataset_name = "tniranjan/tinystories_ta_google_translate"  # Dataset path

# Teacher and student are assumed to share one tokenizer/vocabulary. The KD loss
# compares their logits position by position, so this must hold.
tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
if tokenizer.pad_token is None:
    # Batching needs a pad id and some tokenizers ship without one. Reusing EOS is
    # a common choice for causal LMs; padded positions should be masked out of the
    # loss via attention_mask.
    tokenizer.pad_token = tokenizer.eos_token

# Load teacher model and student model
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model_name, torch_dtype=torch.float16)
teacher_model.to(device)
teacher_model.eval()  # inference-mode behaviour; does NOT by itself freeze weights
teacher_model.requires_grad_(False)  # freeze teacher weights

student_model = AutoModelForCausalLM.from_pretrained(student_model_name)
student_model.to(device)
student_model.train()  # Set student to training mode

# Load dataset (assumes a "text" field)
dataset = load_dataset(dataset_name, split="train")

MAX_LENGTH = 512  # per-example token budget applied at tokenization time


def tokenize_function(example):
    """Tokenize a batch of raw stories, truncating to MAX_LENGTH tokens without padding.

    Padding is deferred to ``collate_fn`` so each DataLoader batch is padded only
    to its own longest sequence. Padding here with padding="longest" would pad
    each batched-``map`` chunk to a different length, and the default collate
    could then fail to stack rows drawn from different chunks.
    """
    return tokenizer(example["text"], truncation=True, max_length=MAX_LENGTH)


tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)
tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])
dataloader = DataLoader(
    tokenized_dataset,
    batch_size=4,
    shuffle=True,
    # Pad per batch and never exceed the student's positional range.
    collate_fn=partial(
        collate_fn,
        tokenizer=tokenizer,
        max_position_embeddings=getattr(student_model.config, "max_position_embeddings", MAX_LENGTH),
    ),
)

# Define loss functions
# Targets equal to ignore_index (-100, the CrossEntropyLoss default) contribute nothing.
ce_loss_fn = nn.CrossEntropyLoss(ignore_index=-100)  # Standard language-modeling loss
# reduction="batchmean" divides the summed KL by input.size(0). Given (num_tokens, vocab)
# rows, that yields a per-token mean.
kl_loss_fn = nn.KLDivLoss(reduction="batchmean")  # KL divergence loss for distillation

# Hyperparameters
alpha = 0.5        # Weight for distillation (KD) loss; (1 - alpha) for CE loss
temperature = 2.0  # Temperature for softening logits
scaler = GradScaler()  # loss scaling for fp16 autocast

# Optimizer and scheduler for the student model.
# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and weight_decay are
# pinned to transformers.AdamW's defaults (1e-6, 0.0) so optimisation is unchanged;
# torch's own defaults (1e-8, 0.01) would differ.
optimizer = torch.optim.AdamW(student_model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)
num_epochs = 3
total_steps = len(dataloader) * num_epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
# The untrained student is deliberately not pushed to the Hub here; uploading before
# training would publish the unmodified student weights under the distilled repo name.
# Training loop: the student fits the ground-truth next tokens (CE) and the teacher's
# temperature-softened distribution (KD). Padding positions (attention_mask == 0) are
# excluded from both terms.
for epoch in tqdm.tqdm(range(num_epochs), desc="epochs"):
    progress = tqdm.tqdm(dataloader, desc=f"epoch {epoch + 1}", leave=False)
    for batch in progress:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        optimizer.zero_grad()

        # Teacher forward pass (no gradient computation)
        with torch.no_grad(), autocast():
            teacher_logits = teacher_model(input_ids=input_ids, attention_mask=attention_mask).logits

        # Student forward pass
        with autocast():
            student_logits = student_model(input_ids=input_ids, attention_mask=attention_mask).logits

        if student_logits.size(-1) != teacher_logits.size(-1):
            raise ValueError(
                f"Teacher/student vocab mismatch: {teacher_logits.size(-1)} vs {student_logits.size(-1)}"
            )

        # Compute the losses in float32 for numerical stability.
        student_logits = student_logits.float()
        teacher_logits = teacher_logits.float()

        # CE: logits at position t predict token t+1. Padded targets are set to -100,
        # CrossEntropyLoss's default ignore_index, so they are excluded.
        shift_student_logits = student_logits[:, :-1, :]
        shift_labels = input_ids[:, 1:].masked_fill(attention_mask[:, 1:] == 0, -100)
        loss_ce = ce_loss_fn(
            shift_student_logits.reshape(-1, shift_student_logits.size(-1)),
            shift_labels.reshape(-1),
        )

        # KD: compare distributions only at real tokens. Boolean indexing flattens to
        # (num_tokens, vocab), so "batchmean" averages the KL per token instead of
        # summing over every (possibly padded) position in each sequence.
        token_mask = attention_mask.bool()
        teacher_probs = torch.softmax(teacher_logits[token_mask] / temperature, dim=-1)
        student_log_probs = torch.log_softmax(student_logits[token_mask] / temperature, dim=-1)
        loss_kd = kl_loss_fn(student_log_probs, teacher_probs) * (temperature ** 2)

        # Combined loss
        loss = alpha * loss_kd + (1 - alpha) * loss_ce

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        # Show the running loss on the progress bar instead of printing one line per step.
        progress.set_postfix(loss=f"{loss.item():.4f}")

    # Upload a checkpoint after every epoch. The upload after the last epoch is the
    # final distilled model, so no separate push is needed after the loop.
    student_model.push_to_hub("distilled_tinyllama", commit_message=f"Epoch {epoch + 1}")
    tokenizer.push_to_hub("distilled_tinyllama", commit_message=f"Epoch {epoch + 1}")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Map:   0%|          | 0/1000000 [00:00<?, ? examples/s]