In [9]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoProcessor, Llama4ForConditionalGeneration, default_data_collator, get_linear_schedule_with_warmup
from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, PrefixTuningConfig, TaskType
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch
import os
import pandas as pd


# AbstractPrompt Classes

In [None]:
from peft import PeftModelForSeq2SeqLM, PeftConfig, PromptEncoderReparameterizationType, PromptEncoderConfig
from peft.tuners.prefix_tuning import PrefixEncoder
from peft.tuners.p_tuning import PromptEncoder
from peft.utils import _get_batch_size, PeftType, TaskType, TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING
from transformers import PreTrainedModel, DynamicCache, EncoderDecoderCache, map_cache_to_layer_device_map
from typing import Optional
import torch
import numpy as np
import warnings
from dataclasses import dataclass, field

@dataclass
class AbstractPromptEncoderConfig(PromptEncoderConfig):
    """
    This is the configuration class to store the configuration of an [`AbstractPromptEncoder`].

    It extends [`PromptEncoderConfig`] with the `num_subjects` and `padding_idx` fields declared below.

    Args:
        encoder_reparameterization_type (Union[[`PromptEncoderReparameterizationType`], `str`]):
            The type of reparameterization to use.
        encoder_hidden_size (`int`): The hidden size of the prompt encoder.
        encoder_num_layers (`int`): The number of layers of the prompt encoder.
        encoder_dropout (`float`): The dropout probability of the prompt encoder.
        num_subjects (`int`):
            The number of subjects. [`AbstractPromptEncoder`] sizes its embedding table as
            `num_virtual_tokens * num_subjects * num_transformer_submodules` rows.
        padding_idx (`Optional[int]`):
            The embedding row used as padding. When `None`, [`AbstractPromptEncoder`] appends one extra row to
            the table and uses that row as padding.
    """

    num_subjects: int = field(
        default=8,
        metadata={"help": "The number of subjects of the prompt encoder"},
    )
    # `None` is a valid value (it is the default), hence `Optional[int]`.
    padding_idx: Optional[int] = field(
        default=None,
        metadata={"help": "The padding index of the prompt encoder"},
    )
    def __post_init__(self):
        super().__post_init__()
        # Always overrides the PEFT type with P_TUNING -- presumably so the existing P-tuning code paths accept
        # this config until a dedicated type exists (see TODO).
        self.peft_type = PeftType.P_TUNING #TODO: switch to APTuning


class AbstractPromptEncoder(PromptEncoder):
    """
    Prompt encoder whose embedding table holds one group of virtual tokens per subject.

    The parent constructor runs first; this constructor then replaces `self.embedding` with a table of
    `num_virtual_tokens * num_subjects * num_transformer_submodules` rows (plus one padding row when the config
    gives no `padding_idx`) and, outside inference mode, rebuilds the reparameterization head.

    `self.token_dim`, `self.input_size`, `self.hidden_size`, `self.output_size` and `self.encoder_type` are read
    here but not assigned -- presumably they are set by `PromptEncoder.__init__`.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_subjects = config.num_subjects

        # One embedding row per (virtual token, subject, transformer submodule).
        row_count = config.num_virtual_tokens * config.num_subjects * config.num_transformer_submodules
        if config.padding_idx is None:
            # No padding row configured: append one right after the regular rows.
            self.padding_idx = row_count
            row_count += 1
        else:
            self.padding_idx = config.padding_idx
        self.total_virtual_tokens = row_count

        # embedding
        self.embedding = torch.nn.Embedding(self.total_virtual_tokens, self.token_dim, padding_idx=self.padding_idx)

        # In inference mode only the embedding table is built.
        if config.inference_mode:
            return

        reparam = self.encoder_type
        if reparam == PromptEncoderReparameterizationType.LSTM:
            # Bidirectional LSTM followed by a two-layer MLP over the concatenated directions.
            self.lstm_head = torch.nn.LSTM(
                input_size=self.input_size,
                hidden_size=self.hidden_size,
                num_layers=config.encoder_num_layers,
                dropout=config.encoder_dropout,
                bidirectional=True,
                batch_first=True,
            )
            bidirectional_width = self.hidden_size * 2
            self.mlp_head = torch.nn.Sequential(
                torch.nn.Linear(bidirectional_width, bidirectional_width),
                torch.nn.ReLU(),
                torch.nn.Linear(bidirectional_width, self.output_size),
            )
            return

        if reparam == PromptEncoderReparameterizationType.MLP:
            # The MLP depth is fixed; warn when the caller asked for something else.
            default_depth = PromptEncoderConfig.encoder_num_layers
            if config.encoder_num_layers != default_depth:
                warnings.warn(
                    f"for {self.encoder_type.value}, the argument `encoder_num_layers` is ignored. "
                    f"Exactly {default_depth} MLP layers are used."
                )
            self.mlp_head = torch.nn.Sequential(
                torch.nn.Linear(self.input_size, self.hidden_size),
                torch.nn.ReLU(),
                torch.nn.Linear(self.hidden_size, self.hidden_size),
                torch.nn.ReLU(),
                torch.nn.Linear(self.hidden_size, self.output_size),
            )
            return

        raise ValueError("Prompt encoder type not recognized. Please use one of MLP (recommended) or LSTM.")


class AbstractPeftModelForSeq2SeqLM(PeftModelForSeq2SeqLM):
    """
    Seq2seq PEFT wrapper whose virtual prompts are selected per example through `object_ids`.

    Each object id owns a contiguous group of `num_virtual_tokens` rows in the [`AbstractPromptEncoder`]
    embedding table. `forward` looks those rows up, encodes them and merges the resulting prompts with the
    input embeddings before calling the base model.
    """

    def __init__(self, model, peft_config, adapter_name: str = "default", **kwargs):
        """
        Args:
            model: The base transformers model to wrap.
            peft_config: The prompt-learning config (expected to carry `num_subjects` / `padding_idx`).
            adapter_name: Name under which the adapter is registered (forwarded to the parent class).
            **kwargs: Forwarded unchanged to `PeftModelForSeq2SeqLM.__init__`.
        """
        # The parent constructor already stores the base model and registers `peft_config` under
        # `adapter_name`. Assigning `self.peft_config = peft_config` afterwards replaced that per-adapter
        # mapping with a bare config (breaking `self.peft_config[adapter_name]`), and `self.model = model`
        # registered the base model a second time as a submodule. Neither assignment is repeated here.
        super().__init__(model, peft_config, adapter_name=adapter_name, **kwargs)

    @property
    def model(self):
        """Backward-compatible read-only alias for the wrapped base model (`self.base_model`)."""
        return self.base_model

    def _setup_prompt_encoder(self, adapter_name: str):
        """
        Freeze the base model, locate its word embeddings and create the prompt encoder for `adapter_name`.

        Sets `self.word_embeddings`, adds an [`AbstractPromptEncoder`] to `self.prompt_encoder` and stores the
        full range of regular prompt token ids in `self.prompt_tokens[adapter_name]`.
        """
        config = self.peft_config[adapter_name]
        if not hasattr(self, "prompt_encoder"):
            self.prompt_encoder = torch.nn.ModuleDict({})
            self.prompt_tokens = {}
        transformer_backbone = None
        for name, module in self.base_model.named_children():
            # Freeze every parameter of the base model.
            for param in module.parameters():
                param.requires_grad = False
            if isinstance(module, PreTrainedModel):
                # Remember the first transformers model found among the children as the backbone.
                if transformer_backbone is None:
                    transformer_backbone = module
                    self.transformer_backbone_name = name
        if transformer_backbone is None:
            transformer_backbone = self.base_model

        if config.num_transformer_submodules is None:
            config.num_transformer_submodules = 2 if config.task_type == TaskType.SEQ_2_SEQ_LM else 1

        # determine the word embeddings
        word_embeddings = None
        try:
            # First try to find the word embeddings based on the module name, this should work for models like Bert,
            # Roberta, Deberta, etc.
            word_embeddings = self.base_model.get_submodule("embeddings.word_embeddings")
        except AttributeError:
            pass

        if word_embeddings is None:
            # Word embeddings could not be determined. Next try to guess them by checking which parameter has the size
            # of the vocab.
            for named_param, value in list(transformer_backbone.named_parameters()):
                # for ZeRO-3, the tensor is sharded across accelerators and deepspeed modifies it to a tensor with shape
                # [0] the actual unsharded shape is stored in "ds_shape" attribute special handling is needed in case
                # the model is initialized in deepspeed.zero.Init() context or HfDeepSpeedConfig has been called before
                # For reference refer to issue: https://github.com/huggingface/peft/issues/996
                deepspeed_distributed_tensor_shape = getattr(value, "ds_shape", None)

                if value.shape[0] == self.base_model.config.vocab_size or (
                    deepspeed_distributed_tensor_shape is not None
                    and deepspeed_distributed_tensor_shape[0] == self.base_model.config.vocab_size
                ):
                    word_embeddings = transformer_backbone.get_submodule(named_param.replace(".weight", ""))
                    break

        self.word_embeddings = word_embeddings

        prompt_encoder = AbstractPromptEncoder(config)

        prompt_encoder = prompt_encoder.to(self.device)
        self.prompt_encoder.update(torch.nn.ModuleDict({adapter_name: prompt_encoder}))

        # Ids of all regular (non-padding) prompt rows.
        self.prompt_tokens[adapter_name] = torch.arange(
            config.num_virtual_tokens * config.num_subjects * config.num_transformer_submodules
        ).long()

    def get_prompt(self, batch_size: int, object_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Returns the virtual prompts to use for Peft. Only applicable when using a prompt learning method.

        Args:
            batch_size: Number of examples in the batch.
            object_ids: Integer ids that can be reshaped to `(batch_size, num_objects)`. Object id `k` selects
                prompt rows `k * num_virtual_tokens ... (k + 1) * num_virtual_tokens - 1`; rows beyond the
                padding index are replaced by the padding row.

        Returns:
            Tensor of shape `(batch_size, num_objects * num_virtual_tokens, token_dim)`.

        Raises:
            ValueError: If `object_ids` is `None`.
            NotImplementedError: If the active config is in inference mode (not supported yet).
        """
        if object_ids is None:
            raise ValueError("object_ids is None, please provide object_ids for AbstractPeftModelForSeq2SeqLM")

        peft_config = self.active_peft_config
        prompt_encoder = self.prompt_encoder[self.active_adapter]

        if peft_config.inference_mode:
            # TODO: add support for inference mode. Fail explicitly instead of returning an unbound variable.
            raise NotImplementedError(
                "inference_mode is not supported yet by AbstractPeftModelForSeq2SeqLM.get_prompt"
            )

        num_virtual_tokens = peft_config.num_virtual_tokens
        embedding_device = prompt_encoder.embedding.weight.device

        # Build the row indices with torch on the embedding's device; mixing in a NumPy array does not work
        # for tensors living on an accelerator.
        object_ids = torch.as_tensor(object_ids, dtype=torch.long, device=embedding_device).reshape(batch_size, -1)
        offsets = torch.arange(num_virtual_tokens, dtype=torch.long, device=embedding_device)
        prompt_tokens = (object_ids * num_virtual_tokens)[:, :, None] + offsets

        # set token id to padding id if it is out of range
        prompt_tokens = prompt_tokens.clamp(max=prompt_encoder.padding_idx)

        # Flatten to a plain index sequence so that both the MLP and the LSTM head receive 3-D embeddings and
        # the result has the documented (batch, tokens, dim) layout.
        prompt_tokens = prompt_tokens.reshape(batch_size, -1)
        return prompt_encoder(prompt_tokens)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        decoder_inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        object_ids=None,

        **kwargs,
    ):
        """
        Run the base model on the inputs with the object-specific prompts merged in.

        Args:
            object_ids: Object ids used to select the prompt tokens (see `get_prompt`).
                * 3-D input embeddings `(batch, seq, hidden)`: all selected prompts are prepended to the sequence.
                * 4-D input embeddings `(batch, conversation_len, seq, hidden)`: exactly one object id per
                  utterance is expected; each utterance gets its own prompt block prepended and the
                  conversation axis is then flattened into the sequence axis.
            Remaining arguments are forwarded to the base model.

        Raises:
            ValueError: If `object_ids` is `None` or the input embeddings are neither 3-D nor 4-D.
        """
        peft_config = self.active_peft_config
        batch_size = _get_batch_size(input_ids, inputs_embeds)
        num_virtual_tokens = peft_config.num_virtual_tokens

        # Only prefix-style methods extend the decoder mask; skip the allocation otherwise.
        if decoder_attention_mask is not None and peft_config.peft_type not in [
            PeftType.PROMPT_TUNING,
            PeftType.P_TUNING,
        ]:
            decoder_prefix_mask = decoder_attention_mask.new_ones(batch_size, num_virtual_tokens)
            decoder_attention_mask = torch.cat((decoder_prefix_mask, decoder_attention_mask), dim=1)

        if kwargs.get("position_ids", None) is not None:
            warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
            kwargs["position_ids"] = None
        if kwargs.get("token_type_ids", None) is not None:
            warnings.warn("Token type ids are not supported for parameter efficient tuning. Ignoring token type ids")
            kwargs["token_type_ids"] = None

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        prompts = self.get_prompt(batch_size=batch_size, object_ids=object_ids)
        prompts = prompts.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)

        # The attention mask is extended by exactly the positions that are added to `inputs_embeds`, so the
        # two always stay the same length.
        if inputs_embeds.dim() == 4:
            # inputs_embeds: (batch_size, conversation_len, sequence_len, hidden_size)
            # prompts: (batch_size, conversation_len*num_virtual_tokens, hidden_size)
            conversation_len = inputs_embeds.shape[1]
            prompts = prompts.reshape(batch_size, conversation_len, num_virtual_tokens, peft_config.token_dim)
            # NOTE(review): the conversation axis is flattened because the base seq2seq model is assumed to
            # take (batch, seq, hidden) embeddings -- confirm this is the intended merge layout.
            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=2).flatten(1, 2)
            if attention_mask is not None:
                attention_mask = attention_mask.reshape(batch_size, conversation_len, -1)
                prefix_attention_mask = attention_mask.new_ones(batch_size, conversation_len, num_virtual_tokens)
                attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=2).flatten(1, 2)
        elif inputs_embeds.dim() == 3:
            # inputs_embeds: (batch_size, sequence_len, hidden_size)
            # prompts: (batch_size, num_objects*num_virtual_tokens, hidden_size)
            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
            if attention_mask is not None:
                prefix_attention_mask = attention_mask.new_ones(batch_size, prompts.shape[1])
                attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
        else:
            raise ValueError(
                f"inputs_embeds must have 3 or 4 dimensions, got {inputs_embeds.dim()}"
            )

        kwargs.update(
            {
                "attention_mask": attention_mask,
                "decoder_attention_mask": decoder_attention_mask,
                "labels": labels,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
            }
        )

        return self.base_model(
            inputs_embeds=inputs_embeds,
            decoder_input_ids=decoder_input_ids,
            decoder_inputs_embeds=decoder_inputs_embeds,
            **kwargs,
        )

# create next token prediction dataset

## Done part

In [14]:
# Load the MELD train / dev / test splits from the CSV files under data/MELD.
# The dev split is kept under the name `raw_valid_df`.
raw_train_df = pd.read_csv("data/MELD/train_sent_emo.csv")
raw_valid_df = pd.read_csv("data/MELD/dev_sent_emo.csv")
raw_test_df = pd.read_csv("data/MELD/test_sent_emo.csv")

In [15]:
# Number of rows (utterances) in the training split.
print(f"train_df: {len(raw_train_df)}")
# Distinct label values found in the Emotion column.
print(raw_train_df['Emotion'].unique())
# Bare expression: shown as the cell output when run in a notebook.
raw_train_df

train_df: 9989
['neutral' 'surprise' 'fear' 'sadness' 'joy' 'disgust' 'anger']


Unnamed: 0,Sr No.,Utterance,Speaker,Emotion,Sentiment,Dialogue_ID,Utterance_ID,Season,Episode,StartTime,EndTime
0,1,also I was the point person on my company’s tr...,Chandler,neutral,neutral,0,0,8,21,"00:16:16,059","00:16:21,731"
1,2,You must’ve had your hands full.,The Interviewer,neutral,neutral,0,1,8,21,"00:16:21,940","00:16:23,442"
2,3,That I did. That I did.,Chandler,neutral,neutral,0,2,8,21,"00:16:23,442","00:16:26,389"
3,4,So let’s talk a little bit about your duties.,The Interviewer,neutral,neutral,0,3,8,21,"00:16:26,820","00:16:29,572"
4,5,My duties? All right.,Chandler,surprise,positive,0,4,8,21,"00:16:34,452","00:16:40,917"
...,...,...,...,...,...,...,...,...,...,...,...
9984,10474,You or me?,Chandler,neutral,neutral,1038,13,2,3,"00:00:48,173","00:00:50,799"
9985,10475,"I got it. Uh, Joey, women don't have Adam's ap...",Ross,neutral,neutral,1038,14,2,3,"00:00:51,009","00:00:53,594"
9986,10476,"You guys are messing with me, right?",Joey,surprise,positive,1038,15,2,3,"00:01:00,518","00:01:03,520"
9987,10477,Yeah.,All,neutral,neutral,1038,16,2,3,"00:01:05,398","00:01:07,274"


In [16]:
# Check the emotion label types.
print(raw_train_df['Emotion'].unique())

['neutral' 'surprise' 'fear' 'sadness' 'joy' 'disgust' 'anger']


In [70]:
# Keep only the Dialogue_ID, Utterance and Emotion columns of each split.
train_df = raw_train_df[['Dialogue_ID', 'Utterance', 'Emotion']]
valid_df = raw_valid_df[['Dialogue_ID', 'Utterance', 'Emotion']]
test_df = raw_test_df[['Dialogue_ID', 'Utterance', 'Emotion']]


### Tokenize the data

In [72]:
# Tokenize the utterances.
# Load the tokenizer and define the names / hyperparameters used by the cells below.

model_name_or_path = "t5-base"
tokenizer_name_or_path = "t5-base"

text_column = "Utterance"
label_column = "Emotion"
max_length = 128
lr = 1e-2
num_epochs = 5
batch_size = 8

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
# Map each emotion name to an integer id, in order of first appearance in the training split.
emotion_name_to_id = {name:i for i, name in enumerate(raw_train_df['Emotion'].unique())}
# First id after the emotion ids -- presumably reserved for the utterance whose emotion is being queried.
query_id = len(emotion_name_to_id)

def tokenize_function(examples):
    """
    Tokenize one dataframe row.

    Reads the `text_column` / `label_column` entries of `examples`, tokenizes the text without padding or
    truncation and maps the label through `emotion_name_to_id`.

    Returns:
        dict with the row's "Dialogue_ID", the token ids ("UtteranceID", as returned by the tokenizer with
        `return_tensors="pt"`) and the integer "EmotionID".
    """
    row_text, row_emotion = examples[text_column], examples[label_column]
    encoded = tokenizer(
        row_text,
        max_length=max_length,
        padding=False,
        truncation=False,
        return_tensors="pt",
    )
    token_ids = encoded["input_ids"]
    label_id = emotion_name_to_id[row_emotion]
    return {
        "Dialogue_ID": examples["Dialogue_ID"],
        "UtteranceID": token_ids,
        "EmotionID": label_id,
    }
    
# Tokenize every row, then collect the per-utterance values of each dialogue into lists
# (one output row per Dialogue_ID).
dialogue_train_df = (
    train_df.apply(tokenize_function, axis=1, result_type="expand")
    .groupby("Dialogue_ID")
    .agg(list)
    .reset_index()
)

# train_dataloader = DataLoader(
#     train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True
# )
# eval_dataloader = DataLoader(eval_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)

In [None]:
#

In [19]:
# Display the list of token-id tensors collected for the first dialogue.
dialogue_train_df["UtteranceID"][0]

[tensor([[  92,   27,   47,    8,  500,  568,   30,   82,  349,   22,    7, 3508,
            45,    8,  480,  434, 4525,   12,    3, 8727, 5783,  358,    5,    1]]),
 tensor([[ 148,  398,   22,  162,  141,   39, 1780,  423,    5,    1]]),
 tensor([[466,  27, 410,   5, 466,  27, 410,   5,   1]]),
 tensor([[ 264,  752,   22,    7, 1350,    3,    9,  385,  720,   81,   39, 9353,
             5,    1]]),
 tensor([[ 499, 9353,   58,  432,  269,    5,    1]]),
 tensor([[ 852,   25,   22,  195,   36, 6904,    3,    9,  829, 4889,    6,   78,
            25,   22,  195,   43,    3,    9,  418,   13, 9353,    5,    1]]),
 tensor([[ 27, 217,   5,   1]]),
 tensor([[  299,   132,    22,   195,    36,  2361,   604,   151,   365,    25,
             78,    25,    54, 11986,     3,     9,   824,   866,    30,   135,
              5,     1]]),
 tensor([[1804,   12,  214,    5,    1]]),
 tensor([[ 101,   54,  281,  139, 2736,    1]]),
 tensor([[465, 278,  22,  17,  27,  36, 122,  13,  25,  55,   1]]),

In [20]:
# Display the per-dialogue lists of emotion ids.
dialogue_train_df["EmotionID"]

0              [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 1, 0]
1                                   [1, 3, 1, 2, 0, 0, 0]
2                 [0, 0, 4, 3, 1, 0, 5, 3, 0, 0, 4, 0, 4]
3                             [1, 1, 1, 0, 0, 0, 1, 3, 0]
4           [1, 4, 1, 0, 0, 0, 0, 0, 4, 4, 4, 3, 0, 0, 0]
                              ...                        
1033                                            [0, 4, 1]
1034                                [6, 0, 6, 0, 0, 6, 4]
1035    [4, 3, 1, 3, 1, 0, 1, 0, 3, 0, 6, 0, 0, 0, 0, ...
1036                                            [0, 0, 0]
1037    [0, 4, 0, 0, 1, 5, 0, 5, 1, 0, 0, 5, 5, 0, 0, ...
Name: EmotionID, Length: 1038, dtype: object

### ProgressiveDialogueDataset Class

In [49]:
from torch.utils.data import Dataset, DataLoader
import bisect


class ProgressiveDialogueDataset(Dataset):
    """
    Dataset with one item per utterance: the dialogue history up to and including that utterance.

    Every utterance in the history is preceded by `num_virtual_tokens` placeholder positions (filled with
    `utterance_padding_token_id`). Older utterances are dropped from the front when the history would exceed
    `max_length`; shorter histories are right-padded to `max_length`.
    """

    def __init__(
        self,
        dialogue_data,
        utterance_name="UtteranceID",
        emotion_name="EmotionID",
        max_length=None,
        utterance_padding_token_id=0,
        query_id=7,
        num_virtual_tokens=20,
    ):
        """
        Args:
            dialogue_data: Mapping-like object (e.g. a DataFrame or a dict) indexed by column name. Column
                `utterance_name` holds, per dialogue, a list of `(1, length)` token-id tensors; column
                `emotion_name` holds, per dialogue, the list of emotion ids of those utterances.
            utterance_name: Name of the utterance column.
            emotion_name: Name of the emotion column.
            max_length: Length of every returned sequence. When `None`, the length of the longest full
                dialogue (including the placeholder positions) is used.
            utterance_padding_token_id: Token id used for placeholders and right padding.
            query_id: Emotion id appended for the utterance whose emotion is to be predicted.
            num_virtual_tokens: Number of placeholder positions inserted before every utterance.
        """
        # Plain lists are addressed by position, which also works when `dialogue_data` is a DataFrame whose
        # index is not 0..n-1.
        self.utterances = list(dialogue_data[utterance_name])
        self.emotions = list(dialogue_data[emotion_name])

        if max_length is None:
            max_length = max(
                (sum(utt.shape[1] + num_virtual_tokens for utt in utts) for utts in self.utterances),
                default=0,
            )
        self.max_length = max_length

        self.query_id = query_id
        self.utterance_padding_token_id = utterance_padding_token_id
        self.num_virtual_tokens = num_virtual_tokens

        # cumulative_sentence_counts[r] = number of utterances in dialogues 0..r (inclusive).
        self.cumulative_sentence_counts = []
        self.row_lengths = [len(row) for row in self.utterances]
        self.total_sentences = 0
        for length in self.row_lengths:
            self.total_sentences += length
            self.cumulative_sentence_counts.append(self.total_sentences)

    def __len__(self):
        """Return the total number of utterances over all dialogues."""
        return self.total_sentences

    def __getitem__(self, idx):
        """
        Return the progressive history ending at the `idx`-th utterance (counted over all dialogues).

        Returns:
            dict with
            * "UtteranceID": `(1, max_length)` long tensor of token ids,
            * "EmotionID": emotion ids of the kept earlier utterances followed by `query_id`,
            * "AttentionMask": `(1, max_length)` long tensor, 1 on history positions and 0 on right padding,
            * "Label": emotion id of the final utterance.

        Raises:
            IndexError: If `idx` is outside `[-len(self), len(self))`.
        """
        idx = int(idx)
        if idx < 0:
            idx += self.total_sentences
        if not 0 <= idx < self.total_sentences:
            raise IndexError(f"index {idx} is out of range for {self.total_sentences} utterances")

        # Find which row this index belongs to
        row_index = bisect.bisect_right(self.cumulative_sentence_counts, idx)

        # Position inside the row: subtract the utterances of all PREVIOUS rows (not of this row).
        sentences_before_row = self.cumulative_sentence_counts[row_index - 1] if row_index > 0 else 0
        sentence_index = idx - sentences_before_row

        # Slicing copies the list, so the stored dialogue is never modified.
        history = self.utterances[row_index][: sentence_index + 1]
        placeholder = torch.full(
            (1, self.num_virtual_tokens), self.utterance_padding_token_id, dtype=torch.long
        )

        # Walk backwards from the newest utterance and keep as many as fit into max_length.
        begin_index = len(history)
        cumulative_length = 0
        for i in range(len(history) - 1, -1, -1):
            cumulative_length += history[i].shape[1] + self.num_virtual_tokens
            if cumulative_length > self.max_length:
                break
            begin_index = i

        if begin_index == len(history):
            # Even the newest utterance alone is too long: truncate it along the sequence axis.
            utterance = torch.cat([placeholder, history[-1]], dim=1)[:, : self.max_length]
            emotions = [self.query_id]
            attention_mask = torch.ones((1, utterance.shape[1]), dtype=torch.long)
        else:
            pieces = []
            for utt in history[begin_index:]:
                pieces.extend((placeholder, utt))
            utterance = torch.cat(pieces, dim=1)
            used_length = utterance.shape[1]
            pad_length = self.max_length - used_length

            emotions = list(self.emotions[row_index][begin_index:sentence_index]) + [self.query_id]
            attention_mask = torch.cat(
                [
                    torch.ones((1, used_length), dtype=torch.long),
                    torch.zeros((1, pad_length), dtype=torch.long),
                ],
                dim=1,
            )
            utterance = torch.cat(
                [
                    utterance,
                    torch.full((1, pad_length), self.utterance_padding_token_id, dtype=torch.long),
                ],
                dim=1,
            )

        labels = self.emotions[row_index][sentence_index]

        return {
            "UtteranceID": utterance,
            "EmotionID": emotions,
            "AttentionMask": attention_mask,
            "Label": labels,
        }




## working part

In [None]:
def 

# T5-Large

In [None]:
# The cells below use the Hugging Face `datasets` API (`dataset.map(batched=True, ...)`,
# `dataset["train"]`, `dataset["validation"]`) and the column names "sentence" / "text_label".
# A pandas DataFrame has neither those splits nor a batched `map`, so wrap the MELD frames in a
# `DatasetDict` that exposes exactly that layout.
from datasets import Dataset as HFDataset, DatasetDict  # aliased: `Dataset` is torch's class in this notebook


def _meld_split_to_hf(df):
    """Convert one MELD dataframe into a Hugging Face dataset with `sentence` / `text_label` columns."""
    renamed = df.rename(columns={"Utterance": "sentence", "Emotion": "text_label"})
    return HFDataset.from_pandas(renamed[["sentence", "text_label"]], preserve_index=False)


dataset = DatasetDict(
    {
        "train": _meld_split_to_hf(train_df),
        "validation": _meld_split_to_hf(valid_df),
    }
)

In [3]:
# Run on the GPU when one is available; fall back to the CPU instead of failing on `.to("cuda")`.
device = "cuda" if torch.cuda.is_available() else "cpu"
# model_name_or_path = "t5-large"
# tokenizer_name_or_path = "t5-large"

model_name_or_path = "t5-base"
tokenizer_name_or_path = "t5-base"

# Column names read by `preprocess_function`.
text_column = "sentence"
label_column = "text_label"
# Tokenized input length, learning rate, number of epochs and dataloader batch size.
max_length = 128
lr = 1e-2
num_epochs = 5
batch_size = 8


In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)


def preprocess_function(examples):
    """
    Tokenize a batch of examples for seq2seq training.

    The `text_column` entries are padded / truncated to `max_length`; the `label_column` entries are
    padded / truncated to 2 tokens and their padding positions are replaced by -100.

    Returns:
        The tokenizer output for the inputs with an added "labels" entry.
    """
    source_texts, target_texts = examples[text_column], examples[label_column]

    model_inputs = tokenizer(
        source_texts,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    # NOTE(review): targets longer than 2 tokens are truncated -- confirm every label fits.
    target_ids = tokenizer(
        target_texts,
        max_length=2,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )["input_ids"]
    pad_positions = target_ids == tokenizer.pad_token_id
    target_ids[pad_positions] = -100

    model_inputs["labels"] = target_ids
    return model_inputs
    
# Tokenize every split in batches, dropping the original columns.
map_options = {
    "batched": True,
    "num_proc": 1,
    "remove_columns": dataset["train"].column_names,
    "load_from_cache_file": False,
    "desc": "Running tokenizer on dataset",
}
processed_datasets = dataset.map(preprocess_function, **map_options)

train_dataset, eval_dataset = processed_datasets["train"], processed_datasets["validation"]

# Both loaders share collation, batch size and pinned memory; only the training loader shuffles.
loader_options = {"collate_fn": default_data_collator, "batch_size": batch_size, "pin_memory": True}
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_options)
eval_dataloader = DataLoader(eval_dataset, **loader_options)

## train

In [1]:
import torch

# Scratch check: a dict mapping 0..2 to the long tensor [0, 1, 2] each.
{i:torch.arange(3).long() for i in range(3)}

{0: tensor([0, 1, 2]), 1: tensor([0, 1, 2]), 2: tensor([0, 1, 2])}

In [4]:
# Prefix-tuning configuration for a seq2seq model with 20 virtual tokens, in training mode.
peft_config = PrefixTuningConfig(task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, num_virtual_tokens=20)

# Load the pretrained seq2seq model and wrap it with the PEFT configuration above.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


pytorch_model.bin:   0%|          | 0.00/892M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

trainable params: 368,640 || all params: 223,272,192 || trainable%: 0.1651


In [6]:
# Display the wrapped model's module structure.
model


PeftModelForSeq2SeqLM(
  (base_model): T5ForConditionalGeneration(
    (shared): Embedding(32128, 768)
    (encoder): T5Stack(
      (embed_tokens): Embedding(32128, 768)
      (block): ModuleList(
        (0): T5Block(
          (layer): ModuleList(
            (0): T5LayerSelfAttention(
              (SelfAttention): T5Attention(
                (q): Linear(in_features=768, out_features=768, bias=False)
                (k): Linear(in_features=768, out_features=768, bias=False)
                (v): Linear(in_features=768, out_features=768, bias=False)
                (o): Linear(in_features=768, out_features=768, bias=False)
                (relative_attention_bias): Embedding(32, 12)
              )
              (layer_norm): T5LayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (1): T5LayerFF(
              (DenseReluDense): T5DenseActDense(
                (wi): Linear(in_features=768, out_features=3072, bias=False)
                (wo): Li

In [None]:
# AdamW over the model's parameters with learning rate `lr`.
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
# Linear schedule without warmup, spanning every optimizer step of the whole run
# (batches per epoch * number of epochs).
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=(len(train_dataloader) * num_epochs),
)

In [None]:
# Train for `num_epochs` epochs; after each epoch evaluate on `eval_dataloader` and print the
# average loss and its exponential (perplexity) for both splits.
model = model.to(device)

for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    total_loss = 0
    for step, batch in enumerate(tqdm(train_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        # Accumulate a detached copy so the running total does not keep the graph alive.
        total_loss += loss.detach().float()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

    # ---- evaluation pass ----
    model.eval()
    eval_loss = 0
    # Reset every epoch: after the loop this holds the predictions of the last epoch only.
    eval_preds = []
    for step, batch in enumerate(tqdm(eval_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        loss = outputs.loss
        eval_loss += loss.detach().float()
        # Predicted token = argmax over the last logits axis, decoded back to text.
        eval_preds.extend(
            tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)
        )

    # Average the summed batch losses over the number of batches.
    eval_epoch_loss = eval_loss / len(eval_dataloader)
    eval_ppl = torch.exp(eval_epoch_loss)
    train_epoch_loss = total_loss / len(train_dataloader)
    train_ppl = torch.exp(train_epoch_loss)
    print(f"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}")

In [None]:
# Exact-match accuracy (after stripping whitespace) of the decoded predictions against the
# validation labels.
reference_labels = dataset["validation"]["text_label"]
correct = 0
total = 0
for pred, true in zip(eval_preds, reference_labels):
    total += 1
    correct += int(pred.strip() == true.strip())
accuracy = correct / total * 100
print(f"{accuracy=} % on the evaluation dataset")
print(f"{eval_preds[:10]=}")
print(f"{dataset['validation']['text_label'][:10]=}")

# Llama-4

In [None]:
# Apply prefix tuning to solve the emotion recognition task.
# First, train on the language modeling task with prefix tuning.
# The prefixes correspond to the emotions in the dataset, and the model learns to predict the next word based on the prefix.
# There is also an additional prefix for answering the emotion recognition task.


# Checkpoint used for both the processor and the model.
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

processor = AutoProcessor.from_pretrained(model_id)
# Load the model in bfloat16 with flex attention and automatic device placement.
model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    attn_implementation="flex_attention",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
# A single user turn whose content is the emotion label followed by the utterance of the first
# training row.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": train_df["Emotion"][0]},
            {"type": "text", "text": train_df["Utterance"][0]},
        ]
    },
]
# Render the chat template into tokenized PyTorch tensors, with the generation prompt appended.
# NOTE(review): `inputs` is not used or moved to the model's device in the visible code.
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)
# def get_data_loader(df, processor, max_length=128, batch_size=4):
#     """
#     Create a DataLoader for the dataset.
#     """
#     def encode(examples):
#         return processor(
#             examples["Emotion"],
#             examples["Utterance"],
#             truncation=True,
#             max_length=max_length,
#             padding="max_length",
#         )

#     # Encode the dataset
#     encoded_dataset = df.apply(encode, axis=1).tolist()
    
#     # Create DataLoader
#     data_loader = DataLoader(
#         encoded_dataset,
#         batch_size=batch_size,
#         collate_fn=default_data_collator
#     )
    
#     return data_loader

