In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoProcessor, Llama4ForConditionalGeneration, default_data_collator, get_linear_schedule_with_warmup
from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, PrefixTuningConfig, TaskType
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch
import os
import pandas as pd

# Read the three MELD splits (train / dev / test) from data/MELD, in that order.
train_df, valid_df, test_df = map(
    pd.read_csv,
    (
        "data/MELD/train_sent_emo.csv",
        "data/MELD/dev_sent_emo.csv",
        "data/MELD/test_sent_emo.csv",
    ),
)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
train_df.head(1)  # show only the first training row

Unnamed: 0,Sr No.,Utterance,Speaker,Emotion,Sentiment,Dialogue_ID,Utterance_ID,Season,Episode,StartTime,EndTime
0,1,also I was the point person on my company’s tr...,Chandler,neutral,neutral,0,0,8,21,"00:16:16,059","00:16:21,731"


In [None]:
# Apply prefix tuning to solve the emotion recognition task.
# First, train on the language-modeling task with prefix tuning.
# The prefixes are the emotions in the dataset, and the model learns to predict the next word conditioned on the prefix.
# An additional prefix is used for answering the emotion recognition task.


# Hugging Face Hub checkpoint used for both the processor and the model.
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

# NOTE(review): get_data_loader() calls its `tokenizer` argument directly;
# presumably `processor.tokenizer` is what should be passed there — confirm.
processor = AutoProcessor.from_pretrained(model_id)

# Weights are loaded in bfloat16; device_map="auto" lets the library place
# them across the available devices; attention uses the "flex_attention"
# implementation.
model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flex_attention",
)

def get_data_loader(df, tokenizer, max_length=128, batch_size=4, *, shuffle=False):
    """
    Create a DataLoader over the ``Utterance`` column of ``df``.

    Every utterance is truncated/padded to exactly ``max_length`` tokens and
    stored as flat per-example lists, so ``default_data_collator`` stacks each
    batch into tensors of shape ``(batch_size, max_length)`` (the final batch
    may be smaller). No ``labels`` key is produced.

    Args:
        df: DataFrame with an ``Utterance`` column of strings.
        tokenizer: Tokenizer callable accepting a list of strings plus
            ``truncation``/``padding``/``max_length`` keyword arguments
            (presumably a Hugging Face tokenizer such as
            ``processor.tokenizer`` — confirm at the call site).
        max_length: Number of tokens every example is truncated/padded to.
        batch_size: Number of examples per batch.
        shuffle: Reshuffle the examples every epoch. Keyword-only; defaults
            to ``False``, which keeps the DataFrame's row order.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding dicts of tensors. An
        empty ``df`` yields a DataLoader with no batches.
    """
    utterances = df["Utterance"].tolist()

    encoded_dataset = []
    if utterances:
        # Tokenize the whole column in one call. Without `return_tensors`
        # each key maps to one flat list per example, so the collator stacks
        # them to (batch, max_length) rather than (batch, 1, max_length).
        # NOTE(review): padding="max_length" requires the tokenizer to have
        # a pad token configured — confirm for the chosen checkpoint.
        encodings = tokenizer(
            utterances,
            truncation=True,
            padding="max_length",
            max_length=max_length,
        )
        # Turn {key: [ex0, ex1, ...]} into [{key: ex0}, {key: ex1}, ...].
        keys = list(encodings.keys())
        encoded_dataset = [dict(zip(keys, row)) for row in zip(*encodings.values())]

    return DataLoader(
        encoded_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=default_data_collator,
    )

