In [None]:
import torch
from transformers import AutoTokenizer, DataCollatorForSeq2Seq, AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, T5Config
from torch.nn import functional as F
from torch.utils.data import Dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.optim import Adam
from peft import get_peft_model, LoraConfig, TaskType
from peft import PeftModel
import torch.nn as nn
import json
from rouge_score import rouge_scorer


In [None]:
import os

from huggingface_hub import login

# Credentials must never be committed inside a notebook. The token is read from
# the HF_TOKEN environment variable; when it is unset, `token=None` makes
# `login()` fall back to its interactive prompt instead of using a literal.
login(token=os.environ.get("HF_TOKEN"))


In [None]:
# Select the GPU when one is available, otherwise fall back to the CPU.
# The cells below read this module-level `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoint name of the student. Nothing is loaded in this cell: the student
# weights are loaded later, right before each training run.
student_model_name = "google/flan-t5-base" 
# LoRA adapter settings handed to get_peft_model() for every student instance.
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=8,  # The rank of the LoRA layers
    lora_alpha=32,  # Scaling factor for the LoRA layers
    lora_dropout=0.1,  # Dropout probability for LoRA layers
    bias="none"
)

# Tokenizer of the student checkpoint; it is used by the datasets, the collator
# and generate_text() further down.
student_tokenizer = AutoTokenizer.from_pretrained(student_model_name)


In [None]:
def _load_4bit_teacher(model_name):
    """Load a seq2seq teacher checkpoint quantized to 4 bits.

    All three teachers are loaded with exactly the same options, so the call
    lives in one place: 4-bit weights, float16 compute dtype, and automatic
    device placement.

    Args:
        model_name: Hugging Face hub id of the checkpoint to load.

    Returns:
        The model returned by ``AutoModelForSeq2SeqLM.from_pretrained``.
    """
    return AutoModelForSeq2SeqLM.from_pretrained(
        model_name,
        load_in_4bit=True,
        device_map="auto",
        torch_dtype=torch.float16,
        bnb_4bit_compute_dtype=torch.float16,
    )


# Teacher 1: FLAN-T5 large (4-bit).
teacher_model_name = "google/flan-t5-large"
flan_model = _load_4bit_teacher(teacher_model_name)

# Teacher 2: T0 3B (4-bit).
t0_model_name = "bigscience/T0_3B"
t0_model = _load_4bit_teacher(t0_model_name)

# Teacher 3: FLAN-T5 XL (4-bit).
flan_t5_xl_model_name = "google/flan-t5-xl"
flan_t5_xl_model = _load_4bit_teacher(flan_t5_xl_model_name)

In [None]:
class InstructionDataset(Dataset):
    """Map-style dataset that tokenizes instruction/answer pairs on access.

    Args:
        instances: Sequence of dicts with an ``"input"`` and an ``"output"``
            string.
        tokenizer: Callable tokenizer; it is invoked with ``padding``,
            ``truncation``, ``max_length`` and ``return_tensors`` and must
            return an object exposing ``input_ids``.
        max_length: Length every sequence is padded / truncated to.
    """

    def __init__(self, instances, tokenizer, max_length=256):
        self.instances = instances
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.instances)

    def _encode(self, text):
        """Tokenize ``text`` to a 1-D tensor of exactly ``max_length`` ids."""
        encoding = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # Drop only the leading batch axis. A bare squeeze() would also
        # collapse the sequence axis when max_length == 1 and hand a 0-d
        # tensor to the collator.
        return encoding.input_ids.squeeze(0)

    def __getitem__(self, idx):
        """Return ``{"input_ids", "labels"}`` for the instance at ``idx``.

        Both tensors are 1-D and padded with the tokenizer's padding id.
        """
        instance = self.instances[idx]
        return {
            "input_ids": self._encode(instance["input"]),
            "labels": self._encode(instance["output"]),
        }




class CustomDataCollator(DataCollatorForSeq2Seq):
    """Collate tokenized instruction/answer pairs into a seq2seq batch.

    The batch contains:

    * ``input_ids``: encoder inputs, padded with the tokenizer's pad id.
    * ``attention_mask``: 1 for real encoder tokens, 0 for padding.
    * ``decoder_input_ids``: the *target* sequence shifted one step to the
      right (teacher forcing), starting with the pad id.
    * ``labels``: the target sequence with padding replaced by -100 so that
      losses using ``ignore_index=-100`` skip padded positions.

    All tensors stay on the CPU; callers move them to the training device.
    """

    def __call__(self, features):
        pad_id = self.tokenizer.pad_token_id
        pad_sequence = torch.nn.utils.rnn.pad_sequence

        input_ids = pad_sequence(
            [f['input_ids'] for f in features],
            batch_first=True, padding_value=pad_id
        )
        target_ids = pad_sequence(
            [f['labels'] for f in features],
            batch_first=True, padding_value=pad_id
        )

        # Without a mask the encoder attends to every padded position.
        attention_mask = (input_ids != pad_id).long()

        # Decoder inputs must be the shifted *targets*; shifting the encoder
        # input (as was done before) makes the logits unrelated to `labels`.
        # The pad id is kept as start token, as in the previous code
        # (this matches T5-style checkpoints - TODO confirm for others).
        decoder_input_ids = torch.full_like(target_ids, pad_id)
        decoder_input_ids[:, 1:] = target_ids[:, :-1]

        # Padded target positions must not contribute to the loss.
        labels = target_ids.masked_fill(target_ids == pad_id, -100)

        batch = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'decoder_input_ids': decoder_input_ids,
            'labels': labels,
        }

        return batch

# Prepare Data
def prepare_data(data_dir, tasks_dir, first_task,last_task):
    """Build the flat list of training instances for a range of tasks.

    Args:
        data_dir: Path of the text file listing one task name per line
            (despite the name this is a file, not a directory).
        tasks_dir: Directory containing one ``<task name>.json`` per task.
        first_task: Index (0-based, inclusive) of the first task to use.
            Blank lines in the task file are not counted.
        last_task: Index (0-based, inclusive) of the last task to use.

    Returns:
        A list of ``{"input": str, "output": str}`` dicts. ``input`` is the
        first task definition, a newline, the instance input and a trailing
        newline; ``output`` is the first reference answer. Instances without
        any reference answer are skipped.

    Raises:
        ValueError: If a selected task file has no ``"Definition"`` entry.
    """

    def load_task_names(file_path):
        with open(file_path, 'r', encoding='utf-8') as file:
            # Blank lines (e.g. a trailing newline) are not task names; kept
            # they would make the loader look for a file called ".json".
            return [line.strip() for line in file if line.strip()]

    def load_task_data(task_name):
        with open(f"{tasks_dir}/{task_name}.json", 'r', encoding='utf-8') as file:
            return json.load(file)

    def extract_instances(task_name, data):
        definitions = data.get("Definition")
        if not definitions:
            raise ValueError(f"Task {task_name!r} has no 'Definition' entry")
        prefix = definitions[0] + "\n"
        return [
            {
                "input": prefix + instance["input"] + "\n",
                "output": instance["output"][0],
            }
            for instance in data.get("Instances", [])
            # An empty answer list cannot supply a target string.
            if instance.get("output")
        ]

    tasks = load_task_names(data_dir)[first_task:last_task + 1]

    all_instances = []
    for task in tasks:
        all_instances.extend(extract_instances(task, load_task_data(task)))

    return all_instances

# --- Split files and task ranges -------------------------------------------
tasks_dir = "./tasks/"

train_data_dir = "./splits/default/train_tasks.txt"
first_train_task, last_train_task = 1, 15

val_data_dir = "./splits/default/test_tasks.txt"
first_val_tasks, last_val_task = 1, 1

# --- Training data: shuffled batches of 16 ----------------------------------
instances = prepare_data(train_data_dir, tasks_dir, first_train_task, last_train_task)
dataset = InstructionDataset(instances, student_tokenizer)
dataloader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=CustomDataCollator(tokenizer=student_tokenizer),
)

# --- Validation data: ordered batches of 8 ----------------------------------
val_instances = prepare_data(val_data_dir, tasks_dir, first_val_tasks, last_val_task)
val_dataset = InstructionDataset(val_instances, student_tokenizer)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=8,
    shuffle=False,
    collate_fn=CustomDataCollator(tokenizer=student_tokenizer),
)


In [None]:
class AttentionMechanism(nn.Module):
    """Learned scoring of teachers from their hidden states.

    A single linear layer maps the concatenation of all teacher hidden states
    to one score per teacher; the scores are averaged over the first two axes
    and turned into weights with a softmax.

    Args:
        num_teachers: Number of teachers, i.e. number of weights produced.
        hidden_sizes: Last-axis width of each teacher's hidden states, in the
            order the states are passed to ``forward``.
    """

    def __init__(self, num_teachers, hidden_sizes):
        super().__init__()
        self.num_teachers = num_teachers
        self.hidden_sizes = hidden_sizes
        self.query_layer = nn.Linear(sum(hidden_sizes), num_teachers)

    def forward(self, hidden_states):
        """Return a 1-D tensor of ``num_teachers`` weights summing to one.

        Args:
            hidden_states: Sequence of tensors that agree on every axis but
                the last; they are concatenated along the last axis.
        """
        # Join the teachers feature-wise and work in float32.
        joined = torch.cat(hidden_states, dim=-1).contiguous().to(torch.float32)

        # ReLU, then one score per teacher at every position.
        per_position_scores = self.query_layer(F.relu(joined))

        # Collapse axes 0 and 1 into a single score per teacher.
        teacher_scores = per_position_scores.mean(dim=[0, 1])

        return F.softmax(teacher_scores, dim=-1)


In [None]:

def entropy_based_weights(**kwargs):
    """Weight teachers by the inverse of their mean predictive entropy.

    Keyword Args:
        flan_teacher_logits, t0_teacher_logits, flan_t5_xl_teacher_logits:
            Logit tensors whose last axis is the class axis.
        temperature: Softmax temperature applied to the inverse entropies
            (default 1.0).

    Returns:
        A list of three floats (FLAN, T0, FLAN-T5-XL order) summing to one.
        Other keyword arguments are ignored.
    """
    temperature = kwargs.get('temperature', 1.0)
    teacher_keys = (
        'flan_teacher_logits',
        't0_teacher_logits',
        'flan_t5_xl_teacher_logits',
    )

    # Mean entropy of each teacher's categorical distribution.
    mean_entropies = [
        torch.distributions.Categorical(logits=kwargs.get(key)).entropy().mean()
        for key in teacher_keys
    ]

    # A confident (low-entropy) teacher gets a large inverse and so more weight.
    inverse = 1 / (torch.tensor(mean_entropies) + 1e-9)
    return torch.softmax(inverse / temperature, dim=0).tolist()

def gradient_based_weights(**kwargs):
    """Weight teachers by the inverse gradient norm they induce on the student.

    For every teacher the student logits are scored with cross-entropy against
    the teacher's arg-max predictions; the norm of the gradient of that loss
    with respect to the student logits is inverted and soft-maxed.

    Keyword Args:
        student_logits: Student logits; must be part of an autograd graph.
        flan_teacher_logits, t0_teacher_logits, flan_t5_xl_teacher_logits:
            Teacher logits with the same leading axes as ``student_logits``.
        temperature: Softmax temperature (default 1.0).

    Returns:
        A list of three floats (FLAN, T0, FLAN-T5-XL order) summing to one.
    """
    student_logits = kwargs.get('student_logits')
    temperature = kwargs.get('temperature', 1.0)
    teacher_keys = (
        'flan_teacher_logits',
        't0_teacher_logits',
        'flan_t5_xl_teacher_logits',
    )

    flat_student = student_logits.view(-1, student_logits.size(-1))

    grad_norms = []
    for key in teacher_keys:
        hard_targets = torch.argmax(kwargs.get(key), dim=-1).view(-1)
        loss = F.cross_entropy(flat_student, hard_targets)
        # Keep the graph alive: it is differentiated once per teacher.
        (grad,) = torch.autograd.grad(loss, student_logits, retain_graph=True)
        grad_norms.append(grad.norm().item())

    inverse = 1 / (torch.tensor(grad_norms) + 1e-9)
    return torch.softmax(inverse / temperature, dim=0).tolist()

def mutual_information_weights(**kwargs):
    """Weight teachers by a mutual-information style score against the student.

    Per teacher, a "joint" distribution is formed as the softmax of the
    averaged student and teacher logits; it is compared (log-ratio, summed)
    with the product of the two soft-maxed distributions averaged over axis 0.
    The three scores are min-max normalised and soft-maxed.

    Keyword Args:
        student_logits: Student logits.
        flan_teacher_logits, t0_teacher_logits, flan_t5_xl_teacher_logits:
            Teacher logits broadcastable against ``student_logits``.
        temperature: Softmax temperature (default 1.0).

    Returns:
        A list of three floats (FLAN, T0, FLAN-T5-XL order) summing to one.
    """
    student_logits = kwargs.get('student_logits')
    temperature = kwargs.get('temperature', 1.0)
    teacher_keys = (
        'flan_teacher_logits',
        't0_teacher_logits',
        'flan_t5_xl_teacher_logits',
    )
    epsilon = 1e-10

    # The student marginal does not depend on the teacher.
    student_marginal = F.softmax(student_logits, dim=-1).mean(dim=0)

    scores = []
    for key in teacher_keys:
        teacher_logits = kwargs.get(key)
        joint = F.softmax((student_logits + teacher_logits) / 2, dim=-1)
        teacher_marginal = F.softmax(teacher_logits, dim=-1).mean(dim=0)
        ratio = joint / (student_marginal * teacher_marginal + epsilon)
        scores.append((joint * torch.log(ratio + epsilon)).sum().item())

    score_tensor = torch.tensor(scores)
    lowest, highest = score_tensor.min(), score_tensor.max()
    rescaled = (score_tensor - lowest) / (highest - lowest + 1e-9)
    return torch.softmax(rescaled / temperature, dim=0).tolist()

def cross_entropy_weights(**kwargs):
    """Weight each teacher by how well the other teachers predict its outputs.

    For teacher X the score is the sum, over the two other teachers Y, of the
    inverse cross-entropy between Y's logits and X's predictions. The score
    vector is scaled to unit L2 norm and then soft-maxed.

    Keyword Args:
        flan_teacher_logits, t0_teacher_logits, flan_t5_xl_teacher_logits:
            Teacher logits whose last axis is the class axis.
        flan_teacher_preds, t0_teacher_preds, flan_t5_xl_teacher_preds:
            Integer class predictions of each teacher.
        temperature: Softmax temperature (default 1.0).

    Returns:
        A list of three floats (FLAN, T0, FLAN-T5-XL order) summing to one.
    """
    temperature = kwargs.get('temperature', 1.0)
    teachers = ('flan', 't0', 'flan_t5_xl')
    logits = {name: kwargs.get(f'{name}_teacher_logits') for name in teachers}
    preds = {name: kwargs.get(f'{name}_teacher_preds') for name in teachers}

    def inverse_ce(judge, target):
        """1 / CE(judge's logits, target's predictions), guarded against 0."""
        judge_logits = logits[judge]
        # float32 keeps the loss well defined for half-precision logits.
        flat_logits = judge_logits.float().reshape(-1, judge_logits.size(-1))
        ce = F.cross_entropy(flat_logits, preds[target].reshape(-1), reduction='mean').item()
        return 1 / ce if ce > 0 else 1 / (1e-9)

    raw_scores = [
        sum(inverse_ce(judge, target) for judge in teachers if judge != target)
        for target in teachers
    ]

    weights = torch.tensor(raw_scores)
    # The normalised vector must be assigned back. Previously the result was
    # discarded, so the raw (unbounded) scores saturated the softmax.
    weights = weights / weights.norm(dim=-1, keepdim=True)
    weights = torch.softmax(weights / temperature, dim=0)
    return weights.tolist()


In [None]:
def compute_losses(student_logits, teacher_logits, teacher_preds):
    """Return the hard-label and soft-label distillation losses.

    Args:
        student_logits: Student logits; the last axis is the class axis.
        teacher_logits: Teacher logits with the same shape.
        teacher_preds: Integer teacher predictions, one per position;
            positions equal to -100 are ignored by the cross-entropy.

    Returns:
        ``(ce_loss, kl_loss)``: cross-entropy of the student against the
        teacher predictions and KL(teacher || student), both averaged per
        position so that they live on the same scale.
    """
    # Work in float32 regardless of the (possibly half precision) inputs.
    student_flat = student_logits.float().reshape(-1, student_logits.size(-1))
    teacher_flat = teacher_logits.float().reshape(-1, teacher_logits.size(-1))

    ce_loss = F.cross_entropy(student_flat, teacher_preds.reshape(-1), ignore_index=-100)

    # 'batchmean' divides by the size of the first axis only. On unflattened
    # [batch, seq, vocab] logits that is a per-sequence sum, seq-length times
    # larger than the per-token cross-entropy it is mixed with.
    kl_loss = F.kl_div(
        F.log_softmax(student_flat, dim=-1),
        F.softmax(teacher_flat, dim=-1),
        reduction='batchmean',
    )
    return ce_loss, kl_loss


def compute_teacher_outputs(model, input_ids, decoder_input_ids, device, attention_mask=None):
    """Run one teacher without gradients.

    Args:
        model: Teacher model; called with keyword tensors and expected to
            return an object exposing ``logits``.
        input_ids: Encoder input ids.
        decoder_input_ids: Decoder input ids.
        device: Device the returned tensors are moved to.
        attention_mask: Optional encoder attention mask (``None`` keeps the
            model's default behaviour).

    Returns:
        ``(logits, preds)`` where ``preds`` is the arg-max over the last axis.
    """
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
        )
        logits = outputs.logits.to(device)
        preds = torch.argmax(logits, dim=-1)
    return logits, preds


def _teacher_encoder_states(model, input_ids, attention_mask, target_device):
    """Return a frozen teacher's encoder hidden states as float16 on ``target_device``."""
    # Teachers are never trained, so no autograd graph is needed here.
    with torch.no_grad():
        hidden = model.encoder(input_ids=input_ids, attention_mask=attention_mask)[0]
    return hidden.to(device=target_device, dtype=torch.float16)


def train_epoch(student_model, optimizer, flan_model, t0_model, flan_t5_xl_model, dataloader, device, alpha, temperature, weight_function=None, attention_module=None, use_hard_weighting=False):
    """Train the student for one pass over ``dataloader``.

    Args:
        student_model: Model being trained.
        optimizer: Optimizer stepping the trainable parameters.
        flan_model, t0_model, flan_t5_xl_model: Frozen teachers.
        dataloader: Yields dict batches with ``input_ids`` and
            ``decoder_input_ids`` (and optionally ``attention_mask``).
        device: Device inputs, logits and weights are moved to.
        alpha: Mixing factor, ``loss = alpha * CE + (1 - alpha) * KL``.
        temperature: Forwarded to ``weight_function``.
        weight_function: Optional callable returning three teacher weights;
            used only when ``attention_module`` is not given.
        attention_module: Optional module mapping the list of teacher encoder
            states to three weights; takes precedence over
            ``weight_function``.
        use_hard_weighting: When true, only the highest-weighted teacher is
            used (weight 1, the others 0).

    Returns:
        The mean batch loss of the epoch (0 when the dataloader is empty).
    """
    student_model.train()
    total_loss = 0
    progress_bar = tqdm(dataloader, desc="Training")
    teachers = (flan_model, t0_model, flan_t5_xl_model)

    for i, batch in enumerate(progress_bar):
        input_ids = batch['input_ids'].to(device)
        decoder_input_ids = batch['decoder_input_ids'].to(device)
        # Optional: batches without a mask behave exactly as before.
        attention_mask = batch.get('attention_mask')
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
        optimizer.zero_grad()

        (
            (flan_teacher_logits, flan_teacher_preds),
            (t0_teacher_logits, t0_teacher_preds),
            (flan_t5_xl_teacher_logits, flan_t5_xl_teacher_preds),
        ) = [
            compute_teacher_outputs(teacher, input_ids, decoder_input_ids, device, attention_mask)
            for teacher in teachers
        ]

        student_outputs = student_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
        )
        student_logits = student_outputs.logits

        if attention_module is not None:
            # Feed the module on whatever device its parameters live on, then
            # bring the resulting weights to the training device.
            module_device = next(attention_module.parameters()).device
            teacher_hidden_states_list = [
                _teacher_encoder_states(teacher, input_ids, attention_mask, module_device)
                for teacher in teachers
            ]
            teacher_weights = attention_module(teacher_hidden_states_list).to(device)
            weight_flan, weight_t0, weight_flant5xl = teacher_weights[0], teacher_weights[1], teacher_weights[2]
        elif weight_function:
            # Use the provided weight function to compute weights
            weights = weight_function(
                student_logits=student_logits,
                flan_teacher_logits=flan_teacher_logits,
                flan_teacher_preds=flan_teacher_preds,
                t0_teacher_logits=t0_teacher_logits,
                t0_teacher_preds=t0_teacher_preds,
                flan_t5_xl_teacher_logits=flan_t5_xl_teacher_logits,
                flan_t5_xl_teacher_preds=flan_t5_xl_teacher_preds,
                temperature=temperature
            )
            weights = torch.tensor(weights).to(device)
            weight_flan, weight_t0, weight_flant5xl = weights[0], weights[1], weights[2]
        else:
            # Default to equal weighting if no attention module or weight function is provided
            weight_flan = weight_t0 = weight_flant5xl = torch.tensor(1.0 / 3.0).to(device)

        if use_hard_weighting:
            # Exactly one teacher wins. Comparing each weight with the maximum
            # gave several teachers weight 1.0 whenever weights were tied.
            soft_weights = torch.stack([weight_flan, weight_t0, weight_flant5xl])
            hard_weights = F.one_hot(soft_weights.argmax(), num_classes=soft_weights.numel()).to(soft_weights.dtype)
            weight_flan, weight_t0, weight_flant5xl = hard_weights[0], hard_weights[1], hard_weights[2]

        ce_loss_flan, kl_loss_flan = compute_losses(student_logits, flan_teacher_logits, flan_teacher_preds)
        ce_loss_t0, kl_loss_t0 = compute_losses(student_logits, t0_teacher_logits, t0_teacher_preds)
        ce_loss_flan_t5_xl, kl_loss_flan_t5_xl = compute_losses(student_logits, flan_t5_xl_teacher_logits, flan_t5_xl_teacher_preds)

        ce_loss = (weight_flan * ce_loss_flan + weight_t0 * ce_loss_t0 + weight_flant5xl * ce_loss_flan_t5_xl).mean()
        kl_loss = (weight_flan * kl_loss_flan + weight_t0 * kl_loss_t0 + weight_flant5xl * kl_loss_flan_t5_xl).mean()
        loss = alpha * ce_loss + (1 - alpha) * kl_loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        progress_bar.set_description(
            f"Batch {i+1}/{len(dataloader)} | "
            f"Total Loss: {total_loss/(i+1):.4f} | "
            f"Weights -> FLAN: {weight_flan.mean().item():.2f}, T0: {weight_t0.mean().item():.2f}, FLAN-XL: {weight_flant5xl.mean().item():.2f}"
        )

    # Guard against an empty dataloader instead of dividing by zero.
    avg_loss = total_loss / max(len(dataloader), 1)
    return avg_loss

def validate_student_model(student_model, val_dataloader, device):
    """Return the mean cross-entropy of the student on the validation set.

    Args:
        student_model: Model to evaluate; it is switched to eval mode. When it
            exposes ``config.pad_token_id`` that id is treated as padding in
            the labels.
        val_dataloader: Iterable of dict batches with ``input_ids``,
            ``decoder_input_ids`` and ``labels`` (optionally
            ``attention_mask``).
        device: Device the batch tensors are moved to.

    Returns:
        The average of the per-batch losses, or ``nan`` when there are no
        batches.
    """
    student_model.eval()
    loss_function = torch.nn.CrossEntropyLoss(ignore_index=-100)

    # Labels padded with the pad id (instead of -100) would otherwise be
    # scored as real targets. Masking is a no-op for labels already at -100.
    pad_token_id = getattr(getattr(student_model, "config", None), "pad_token_id", None)

    val_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for val_batch in val_dataloader:
            input_ids = val_batch['input_ids'].to(device)
            decoder_input_ids = val_batch['decoder_input_ids'].to(device)
            labels = val_batch['labels'].to(device)
            attention_mask = val_batch.get('attention_mask')
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)
            if pad_token_id is not None:
                labels = labels.masked_fill(labels == pad_token_id, -100)

            # `labels` is deliberately not passed to the model: the loss is
            # computed below, so an internal loss would be wasted work.
            student_outputs = student_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
            )
            student_logits = student_outputs.logits

            ce_loss = loss_function(student_logits.reshape(-1, student_logits.size(-1)), labels.reshape(-1))
            val_loss += ce_loss.item()
            num_batches += 1

    if num_batches == 0:
        return float("nan")
    return val_loss / num_batches

def train_student_model(student_model, flan_model, flan_t5_xl_model, t0_model, dataloader, val_dataloader, weight_function=None, attention_module=None, use_hard_weighting=False, alpha=0.8, num_epochs=1, lr=5e-5, weight_decay=0.01, temperature=0.5, device=None):
    """Distil the three teachers into ``student_model``.

    Args:
        student_model: Model to train (moved to ``device``).
        flan_model, flan_t5_xl_model, t0_model: Teachers; put in eval mode.
        dataloader: Training batches, passed to ``train_epoch``.
        val_dataloader: Validation batches, passed to
            ``validate_student_model``.
        weight_function: Optional teacher-weighting function.
        attention_module: Optional learned teacher-weighting module. It is
            moved to ``device`` and trained together with the student.
        use_hard_weighting: Use only the highest-weighted teacher per batch.
        alpha: Weight of the hard-label loss; ``1 - alpha`` weighs the KL term.
        num_epochs: Number of passes over ``dataloader``.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        temperature: Temperature forwarded to ``weight_function``.
        device: Training device; defaults to CUDA when available.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    student_model.to(device)
    trainable_params = list(student_model.parameters())

    if attention_module is not None:
        # The module must live on the training device, and its parameters
        # must be known to the optimizer; otherwise it keeps its random
        # initial weights for the whole run.
        attention_module.to(device)
        attention_module.train()
        trainable_params += list(attention_module.parameters())

    optimizer = Adam(trainable_params, lr=lr, weight_decay=weight_decay)

    flan_model.eval()
    flan_t5_xl_model.eval()
    t0_model.eval()

    for epoch in range(num_epochs):
        avg_loss = train_epoch(student_model, optimizer, flan_model, t0_model, flan_t5_xl_model, dataloader, device, alpha, temperature, weight_function, attention_module, use_hard_weighting)
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

        avg_val_loss = validate_student_model(student_model, val_dataloader, device)
        print(f"Epoch {epoch+1}/{num_epochs}, Validation Loss: {avg_val_loss:.4f}")


In [None]:
# Teacher-weighting strategies to compare. This list and
# `weighting_function_names` below are index-aligned: entry i of one
# corresponds to entry i of the other.
weighting_functions = [
    entropy_based_weights,
    gradient_based_weights,
    mutual_information_weights,
    cross_entropy_weights
]

# Short labels used to build the run name and the adapter save directory.
weighting_function_names = [
    "entropy_based",
    "gradient_based",
    "mutual_information",
    "cross_entropy"
]

# Train one student per (weighting function, soft/hard) combination:
# 4 functions x 2 modes = 8 runs.
for i, weight_function in enumerate(weighting_functions):
    for use_hard_weighting in [False, True]:
        # Start every run from a freshly loaded base checkpoint wrapped with LoRA
        student_model_name = "google/flan-t5-base"  
        student_model = AutoModelForSeq2SeqLM.from_pretrained(student_model_name)
        student_model = get_peft_model(student_model, lora_config)
        
        student_model.to(device)

        # e.g. "entropy_based_soft_weighting"
        config_name = f"{weighting_function_names[i]}_{'hard' if use_hard_weighting else 'soft'}_weighting"

        print(f"Training with {config_name}")

        train_student_model(
            student_model=student_model,
            flan_model=flan_model,
            flan_t5_xl_model=flan_t5_xl_model,
            t0_model=t0_model,
            dataloader=dataloader,
            val_dataloader=val_dataloader,
            weight_function=weight_function,
            attention_module=None, 
            use_hard_weighting=use_hard_weighting,
            device=device  
        )

        # Save the trained model to the directory "student_model_<config_name>"
        student_model_save_path = f"student_model_{config_name}"
        student_model.save_pretrained(student_model_save_path)

        # Drop the student and release cached GPU memory before the next run
        del student_model
        torch.cuda.empty_cache()


In [None]:
num_teachers = 3
# Encoder widths are read from the teacher configs, in the order train_epoch
# concatenates the hidden states (FLAN, T0, FLAN-T5-XL), instead of being
# hard-coded.
hidden_sizes = [teacher.config.d_model for teacher in (flan_model, t0_model, flan_t5_xl_model)]
# The module must be on the same device as the hidden states it receives;
# left on the CPU, its linear layer fails on GPU inputs.
attention_module = AttentionMechanism(num_teachers=num_teachers, hidden_sizes=hidden_sizes).to(device)

student_model_name = "google/flan-t5-base"
student_model = AutoModelForSeq2SeqLM.from_pretrained(student_model_name)
student_model = get_peft_model(student_model, lora_config)

student_model.to(device)

config_name = "attention_based_weighting"
print(f"Training with {config_name}")

# Train the student model with attention-based weighting
train_student_model(
    student_model=student_model,
    flan_model=flan_model,
    flan_t5_xl_model=flan_t5_xl_model,
    t0_model=t0_model,
    dataloader=dataloader,
    val_dataloader=val_dataloader,
    weight_function=None,  # Set this to None when using attention
    attention_module=attention_module,  # Use the attention module
    use_hard_weighting=False,
    device=device  # Ensure model and data are on the correct device
)

# Saved under "student_model_attention_based_weighting".
student_model_save_path = f"student_model_{config_name}"
student_model.save_pretrained(student_model_save_path)


In [None]:
# Baseline run: a fresh LoRA student distilled with equal teacher weights.
student_model_name = "google/flan-t5-base"  # Adjust the model name if necessary
student_model = AutoModelForSeq2SeqLM.from_pretrained(student_model_name)
student_model = get_peft_model(student_model, lora_config)

student_model.to(device)

config_name = "equal_weighting"
print(f"Training with {config_name}")

# With neither a weight function nor an attention module, train_epoch falls
# back to a weight of 1/3 per teacher.
train_student_model(
    student_model=student_model,
    flan_model=flan_model,
    flan_t5_xl_model=flan_t5_xl_model,
    t0_model=t0_model,
    dataloader=dataloader,
    val_dataloader=val_dataloader,
    weight_function=None,  
    attention_module=None,  
    use_hard_weighting=False,
    device=device  
)

# The model is written twice: to "student_model_equal_weighting" and, a second
# time, to "lora_adapter_equal_weighting".
student_model_save_path = f"student_model_{config_name}"
student_model.save_pretrained(student_model_save_path)
student_model.save_pretrained(f"lora_adapter_{config_name}")

# Drop the student and release cached GPU memory.
del student_model
torch.cuda.empty_cache()

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Directory written by the training cells for the run to inspect.
model_path = "student_model_entropy_based_soft_weighting"
base_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")

# Load the adapter from `model_path`. The previous literal "lora_adapter" is a
# directory that no cell in this notebook ever writes.
loaded_model = PeftModel.from_pretrained(base_model, model_path)

# Move the model to the device
loaded_model.to(device)

# Untouched base checkpoint for a before/after comparison. The cell ends with
# an assignment so that no value is echoed as the cell output.
original_model = AutoModelForSeq2SeqLM.from_pretrained(student_model_name).to(device)


In [None]:
def generate_text(student_model, text, max_new_tokens=30):
    """Greedy-decode ``text`` and print the output plus a top-5 token table.

    The table shows the five most likely tokens at the second-to-last
    generation step (the last step when only one step was generated).

    Args:
        student_model: Any model exposing ``generate``; despite the name the
            teachers are passed here as well.
        text: Prompt string; tokenized with the module-level
            ``student_tokenizer`` and moved to the module-level ``device``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        None. Everything is printed.
    """
    inputs = student_tokenizer(text, return_tensors="pt").to(device)
    outputs = student_model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        output_scores=True,
        return_dict_in_generate=True
    )

    generated_text = student_tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
    print("Generated text:", generated_text)

    step_scores = outputs.scores
    if not step_scores:
        # Nothing was generated, so there is no distribution to show.
        return
    # scores[-2] does not exist when generation stopped after a single step.
    logits = step_scores[-2] if len(step_scores) >= 2 else step_scores[-1]

    # Apply softmax to get probabilities
    probs = F.softmax(logits, dim=-1)

    # Get the top 5 tokens and their probabilities
    top5_probs, top5_tokens = torch.topk(probs, 5, dim=-1)

    # Decode and print the top 5 tokens with their probabilities
    print("Top 5 most likely tokens and their probabilities:")
    for i in range(5):
        token = student_tokenizer.decode(top5_tokens[0, i].item())
        probability = top5_probs[0, i].item()
        print(f"Token: '{token}', Probability: {probability:.4f}")


In [None]:
# One multiple-choice maths question, formatted like the training inputs
# (task definition, newline, instance input, newline).
definition = "You are given a question on high school mathematics. You are also given 4 answer options (associated with \"A\", \"B\", \"C\", \"D\"), out of which only one is correct. You need to answer the question by selecting the correct option. You should only answer with the choice letter, not the whole answer."
input_text = "What is the greatest common factor of 252 and 96?\n(A)6 (B)24 (C)5 (D)12"
prompt = f"{definition}\n{input_text}\n"

# Same prompt through the three teachers, then the student before and after
# distillation; each run is preceded by its banner.
for banner, demo_model in (
    ("\nflan model\n", flan_model),
    ("\nxl model\n", flan_t5_xl_model),
    ("\nt0 model\n", t0_model),
    ("Before training\n", original_model),
    ("after  training\n", loaded_model),
):
    print(banner)
    generate_text(demo_model, prompt)


In [None]:
# Evaluation label -> adapter directory written by the training cells
# ("student_model_<config_name>").
student_model_paths = {
    'cross_entropy_soft': "student_model_cross_entropy_soft_weighting",
    'cross_entropy_hard': "student_model_cross_entropy_hard_weighting",
    'entropy_based_soft': "student_model_entropy_based_soft_weighting",
    'entropy_based_hard': "student_model_entropy_based_hard_weighting",
    'mutual_information_soft': "student_model_mutual_information_soft_weighting",
    'mutual_information_hard': "student_model_mutual_information_hard_weighting",
    'gradient_based_soft': "student_model_gradient_based_soft_weighting",
    'gradient_based_hard': "student_model_gradient_based_hard_weighting",
    # The attention run uses config_name "attention_based_weighting", so its
    # directory has no "_soft" suffix.
    'attention_based_soft': "student_model_attention_based_weighting",
    'equal_weighting': "student_model_equal_weighting"
}

# Add the base flan-t5 model to the list for evaluation
student_model_paths['flan_t5_base'] = "google/flan-t5-base"

# Define the path to your test data directory
test_data_dir = "./splits/default/test_tasks.txt"  # Adjust this path as needed
tasks_dir = "./tasks/"  # Ensure this points to the correct tasks directory

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")

# Prepare the data loader
def prepare_dataloader_for_task(test_data_dir, tasks_dir, task_num, tokenizer, max_instances=100):
    """Build a batch-size-1 dataloader over the first instances of one task.

    Args:
        test_data_dir: Path of the file listing the test task names.
        tasks_dir: Directory holding the task JSON files.
        task_num: Index of the single task to load.
        tokenizer: Tokenizer handed to ``InstructionDataset``.
        max_instances: Maximum number of instances kept for the task.

    Returns:
        An unshuffled ``DataLoader`` with batch size 1, or ``None`` when 600
        or more instances remain after truncation.
    """
    kept_instances = prepare_data(test_data_dir, tasks_dir, task_num, task_num)[:max_instances]

    # NOTE(review): this size check runs *after* truncation, so it can only
    # trigger when max_instances >= 600 - confirm whether it was meant to
    # look at the untruncated task size.
    if len(kept_instances) >= 600:
        return None

    return DataLoader(
        dataset=InstructionDataset(kept_instances, tokenizer),
        batch_size=1,
        shuffle=False
    )

# Evaluation function to test the student models
def evaluate_student_model_with_rouge(model, dataloader, tokenizer, device):
    """Generate an answer per batch and score it against the reference.

    Args:
        model: Model exposing ``generate``.
        dataloader: Yields dict batches with ``input_ids`` and ``labels``.
        tokenizer: Used to decode both generations and reference labels.
        device: Device the input ids are moved to.

    Returns:
        ``(predictions, scores)``: the decoded predictions and, for each of
        them, the dict returned by the ROUGE scorer (rouge1, rouge2, rougeL).
    """
    predictions = []
    per_example_scores = []
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

    for batch in tqdm(dataloader, desc="Evaluating Model"):
        source_ids = batch['input_ids'].to(device)
        reference_ids = batch['labels']

        with torch.no_grad():
            generated_ids = model.generate(input_ids=source_ids)

        batch_predictions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        batch_references = tokenizer.batch_decode(reference_ids, skip_special_tokens=True)
        predictions.extend(batch_predictions)

        # The scorer takes (target, prediction), in that order.
        per_example_scores.extend(
            scorer.score(reference, prediction)
            for prediction, reference in zip(batch_predictions, batch_references)
        )

    return predictions, per_example_scores

# Results are appended to this file, one record per (task, model) pair.
output_file_path = "average_rouge_scores.txt"

# Iterate over tasks and models, keeping a single model in memory at a time.
with open(output_file_path, "a", encoding="utf-8") as f:
    for task_num in range(1,14):
        test_dataloader = prepare_dataloader_for_task(test_data_dir, tasks_dir, task_num, tokenizer, max_instances=110)
        # Skip tasks that were rejected (None) or that have no instances.
        if test_dataloader is None or len(test_dataloader) == 0:
            continue

        for model_key, model_path in student_model_paths.items():
            if model_key == 'flan_t5_base':
                # Load the base model without LoRA adapter
                model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(device)
            else:
                # Load the base FLAN-T5 model, then attach the LoRA adapter.
                adapter_base = AutoModelForSeq2SeqLM.from_pretrained(student_model_paths['flan_t5_base'])
                model = PeftModel.from_pretrained(adapter_base, model_path).to(device)
                # Drop the extra reference right away; a lingering name would
                # keep the wrapped weights alive after `del model` below.
                del adapter_base

            model.eval()  # Set the model to evaluation mode

            # Evaluate the model
            _, student_rouge_scores = evaluate_student_model_with_rouge(model, test_dataloader, tokenizer, device)

            # Average ROUGE-L F-measure over the evaluated instances
            avg_score = {
                'rougeL': sum(score['rougeL'].fmeasure for score in student_rouge_scores) / len(student_rouge_scores),
            }

            # Write the average ROUGE scores to the output file
            f.write(f"Task: {task_num}\n")
            f.write(f"  Model: {model_key}\n")
            f.write(f"    ROUGE-L: {avg_score['rougeL']:.4f}\n")
            f.write("\n")

            # Clear the model from memory
            del model
            torch.cuda.empty_cache()

print(f"Average ROUGE scores written to {output_file_path}")
