In [2]:
from transformers import WhisperProcessor, WhisperForConditionalGeneration, WhisperTokenizer, AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from pathlib import Path
import os
import torch
from IPython.display import Audio
from tqdm.notebook import tqdm

# Root of the shared storage, read from the DSDIR environment variable
# (a KeyError is raised here if the variable is not set).
DSDIR = Path(os.environ["DSDIR"])
# Hugging Face identifier of the Whisper checkpoint loaded below.
WHISPER_PATH = "openai/whisper-small.en"
# Location of the phi-2 language model under DSDIR.
LLM_PATH = DSDIR / "HuggingFace_Models/microsoft/phi-2"
# Dataset directory under DSDIR; the per-split audio folders are derived from it later.
DATA_PATH = DSDIR / "meld_fidle"

# Turn off parallelism in the tokenizers library.
# NOTE(review): presumably to avoid fork warnings with DataLoader workers — confirm.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

def freeze_model(model):
    """Turn off gradient tracking for every parameter of ``model``.

    The module is modified in place; it is also returned so the call can be
    used in an assignment.
    """
    for weight in model.parameters():
        weight.requires_grad_(False)
    return model

# Load Whisper (speech model) in inference mode with all parameters frozen,
# together with its processor (used later to turn audio into input features).
whisper_model = WhisperForConditionalGeneration.from_pretrained(WHISPER_PATH)
whisper_model.eval()
whisper_model = freeze_model(whisper_model)
whisper_processor = WhisperProcessor.from_pretrained(WHISPER_PATH)

# Initialize the model and its tokenizer
# The LLM is loaded in bfloat16, moved to the GPU, set to eval mode and frozen.
llm_model = AutoModelForCausalLM.from_pretrained(
    LLM_PATH,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,  # Allow using code that was not written by HuggingFace
    attn_implementation="flash_attention_2"  # Optimize the model with Flash Attention
).to("cuda")
llm_model.eval()
llm_model = freeze_model(llm_model)
llm_tokenizer = AutoTokenizer.from_pretrained(LLM_PATH)
# The pad id set here is reused by add_padding, create_mask and the loss' ignore_index.
llm_tokenizer.pad_token_id = 50257  # Special token of phi

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
# Observe the data and check the whisper model

In [4]:
# Dataset split to work on; it selects the audio folders and the CSV file below.
mode = "test"
WAV_DATA = DATA_PATH / f"{mode}_wav"  # folder of .wav files for this split
PT_DATA = DATA_PATH / f"{mode}_pt"  # folder of .pt files for this split (read with torch.load)
# NOTE(review): relative file name, so it is resolved against the working directory, not DATA_PATH — confirm.
CSV_DF = f"{mode}_sent_emo.csv"

In [5]:
# Display the module tree of the Whisper model (encoder/decoder layers and their sizes).
whisper_model

WhisperForConditionalGeneration(
  (model): WhisperModel(
    (encoder): WhisperEncoder(
      (conv1): Conv1d(80, 768, kernel_size=(3,), stride=(1,), padding=(1,))
      (conv2): Conv1d(768, 768, kernel_size=(3,), stride=(2,), padding=(1,))
      (embed_positions): Embedding(1500, 768)
      (layers): ModuleList(
        (0-11): 12 x WhisperEncoderLayer(
          (self_attn): WhisperSdpaAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=False)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
        

In [6]:
# Load the utterance table of the selected split; the dataset class below reads
# the question, answer, Dialogue_ID and Utterance_ID columns from it.
df = pd.read_csv(CSV_DF)
df

Unnamed: 0,Sr No.,Utterance,Speaker,Emotion,Sentiment,Dialogue_ID,Utterance_ID,Season,Episode,StartTime,EndTime,question,answer
0,1,Why do all youre coffee mugs have numbers on ...,Mark,surprise,positive,0,0,3,19,"00:14:38,127","00:14:40,378",Who is the speaker in this dialogue?,"Monica, who is saying ""I'm gonna miss you!"" an..."
1,2,Oh. Thats so Monica can keep track. That way ...,Rachel,anger,negative,0,1,3,19,"00:14:40,629","00:14:47,385",What is the speaker's emotion?,The speaker's emotion is Joy.
2,3,Y'know what?,Rachel,neutral,neutral,0,2,3,19,"00:14:56,353","00:14:57,520",Who is the speaker?,The speaker's emotion is Joy.
3,19,"Come on, Lydia, you can do it.",Joey,neutral,neutral,1,0,1,23,"0:10:44,769","0:10:46,146",The speaker is trying to,The speaker's emotion is Joy.
4,20,Push!,Joey,joy,positive,1,1,1,23,"0:10:46,146","0:10:46,833",Who is the speaker in this dialogue?,"Monica, who is saying ""I'm gonna miss you!"" an..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...
2605,2760,"Yeah, I mean, come on Ross, no one will even n...",Rachel,neutral,neutral,279,11,6,4,"00:14:35,457","00:14:40,211",Who is the speaker?,The speaker's emotion is Joy.
2606,2761,They’re not listening too me?,Ross,surprise,negative,279,12,6,4,"00:14:42,256","00:14:43,840",Who is the speaker?,The speaker's emotion is Joy.
2607,2762,Of course they’re listening to you! Everybody ...,Rachel,neutral,neutral,279,13,6,4,"00:14:44,008","00:14:48,511",Who is the speaker in this dialogue?,"Monica, who is saying ""I'm gonna miss you!"" an..."
2608,2763,Monica you really think I should try this phas...,Ross,neutral,neutral,279,14,6,4,"00:14:48,138","00:14:52,390",Who is the speaker in this dialogue?,"Monica, who is saying ""I'm gonna miss you!"" an..."


In [7]:
# Pick one row (by index label) to inspect and to use as a running example.
idx = 53
df.loc[idx]

Sr No.                                                         69
Utterance       Yeah, its two guys in a ring, and the rules a...
Speaker                                                  Chandler
Emotion                                                   neutral
Sentiment                                                 neutral
Dialogue_ID                                                     7
Utterance_ID                                                    3
Season                                                          3
Episode                                                        24
StartTime                                            00:03:33,129
EndTime                                              00:03:37,341
question                     Who is the speaker of the monologue?
answer                              The speaker's emotion is Joy.
Name: 53, dtype: object

In [8]:
# Load the saved tensor of the example utterance and embed a player for the matching .wav file.
# Both files are named "dia<Dialogue_ID>_utt<Utterance_ID>" with their own extension.
# NOTE(review): torch.load unpickles the file — only use it on trusted data.
inp = torch.load(PT_DATA / f"dia{df.loc[idx]['Dialogue_ID']}_utt{df.loc[idx]['Utterance_ID']}.pt")
Audio(os.path.join(WAV_DATA, f"dia{df.loc[idx]['Dialogue_ID']}_utt{df.loc[idx]['Utterance_ID']}.wav"), embed=True)

In [9]:
# Sanity check: transcribe the example utterance with the full Whisper model.
# assumes `inp` holds a raw waveform sampled at 16 kHz — TODO confirm
input_features = whisper_processor(inp, return_tensors="pt", sampling_rate=16000).input_features
generated_ids = whisper_model.generate(input_features)
# Decode the generated token ids back to text, dropping special tokens.
transcription = whisper_processor.batch_decode(generated_ids, skip_special_tokens=True)
transcription

[" Yeah, it's two guys in the ring and the rules are... There are no rules!"]

In [10]:
# Design the model

In [11]:
# Use only Whisper's encoder as the audio feature extractor and run it on the GPU.
# NOTE(review): Module.to() moves the module in place, so the encoder inside
# `whisper_model` is on the GPU after this cell; re-running the transcription
# cell above with CPU inputs afterwards will likely fail — confirm.
whisper_encoder = whisper_model.model.encoder.to("cuda")
out_enc = whisper_encoder(input_features.to("cuda"))
# Shape of the encoder hidden states for the example utterance.
out_enc["last_hidden_state"].shape

torch.Size([1, 1500, 768])

In [12]:
class Projector(torch.nn.Module):
    """Turn audio-encoder hidden states into a fixed number of LLM-sized vectors.

    The input goes through one transformer encoder layer, the first
    ``nb_feat_tokens`` positions are kept, and each kept vector is linearly
    projected from ``encoder_hidden_dim`` to ``llm_hidden_dim``.

    Args:
        encoder_hidden_dim: Size of the incoming feature vectors.
        llm_hidden_dim: Size of the projected vectors.
        nhead: Number of attention heads of the transformer encoder layer.
        nb_feat_tokens: Number of leading positions kept as audio features.
    """

    def __init__(
        self, encoder_hidden_dim=768, llm_hidden_dim=2560, nhead=8, nb_feat_tokens=5
    ):
        super().__init__()

        self.encoder_layer = torch.nn.TransformerEncoderLayer(
            d_model=encoder_hidden_dim, nhead=nhead, activation="gelu", batch_first=True
        )
        self.projection_layer = torch.nn.Linear(encoder_hidden_dim, llm_hidden_dim)

        self.nb_feat_tokens = nb_feat_tokens
        self.encoder_hidden_dim = encoder_hidden_dim
        self.llm_hidden_dim = llm_hidden_dim

    def forward(self, x):
        """Project a batch of encoder states.

        Args:
            x: Tensor of shape ``(batch, seq_len, encoder_hidden_dim)`` with
                ``seq_len >= nb_feat_tokens``.

        Returns:
            Tensor of shape ``(batch, nb_feat_tokens, llm_hidden_dim)``.

        Raises:
            ValueError: If ``x`` is not 3-D or has fewer than
                ``nb_feat_tokens`` positions.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected a (batch, seq_len, hidden) tensor, got shape {tuple(x.shape)}"
            )
        # Without this check a too-short sequence would be silently regrouped
        # across batch elements when restoring the (batch, tokens, dim) layout.
        if x.shape[1] < self.nb_feat_tokens:
            raise ValueError(
                f"sequence length {x.shape[1]} is shorter than nb_feat_tokens={self.nb_feat_tokens}"
            )

        x = self.encoder_layer(x)
        # We select a fixed number of vectors that will represent the audio features
        x = x[:, :self.nb_feat_tokens, :]
        # Linear acts on the last dimension, so the (batch, tokens, dim) layout is kept as is.
        return self.projection_layer(x)

In [13]:
# Build the projector with its default sizes, on the GPU and in bfloat16
# (the same dtype the LLM was loaded with).
projector = Projector().to("cuda").to(torch.bfloat16)

In [14]:
# Smoke test: cast the example encoder output to bfloat16 and project it,
# then check the resulting shape.
feat_extracted = out_enc["last_hidden_state"].to(torch.bfloat16)
feat_project = projector(feat_extracted)
feat_project.shape

torch.Size([1, 5, 2560])

In [15]:
def add_padding(list_ids: list[torch.Tensor]) -> torch.Tensor:
    """Left-pad a list of id tensors into one batch tensor.

    ``pad_sequence`` (as called here) appends padding after each sequence, so
    every sequence is reversed first and the padded batch is reversed back
    along the time axis, which puts the pad ids at the front.
    """
    reversed_ids = []
    for ids in list_ids:
        reversed_ids.append(torch.flip(ids, dims=(0,)))

    right_padded = torch.nn.utils.rnn.pad_sequence(
        reversed_ids,
        batch_first=True,
        padding_value=llm_tokenizer.pad_token_id,
    )
    return torch.flip(right_padded, dims=(1,))


def create_mask(padded_tensor: torch.Tensor) -> torch.Tensor:
    """Build the attention mask that goes with a padded batch of token ids.

    Positions holding the tokenizer's pad id get 0, every other position
    gets 1; the result has dtype ``torch.int``.
    """
    is_real_token = padded_tensor != llm_tokenizer.pad_token_id
    return is_real_token.to(dtype=torch.int)

In [16]:
class FriendsDataset(Dataset):
    """Dataset pairing a question/answer text with the audio of its utterance.

    Args:
        df: Table with at least the columns ``question``, ``answer``,
            ``Dialogue_ID`` and ``Utterance_ID``.
        tokenizer: Callable returning a mapping with ``'input_ids'``; must
            expose ``eos_token_id``.
        processor: Callable turning an audio tensor into an object with an
            ``input_features`` attribute.
        pt_dir: Directory holding the ``dia<Dialogue_ID>_utt<Utterance_ID>.pt``
            files. When ``None`` (default) the notebook-level ``PT_DATA`` is used.
    """

    def __init__(self, df, tokenizer, processor, pt_dir=None):
        self.df = df
        self.tokenizer = tokenizer
        self.processor = processor
        self.pt_dir = pt_dir

    def __len__(self) -> int:
        """Return the number of element of the dataset"""
        return len(self.df)

    def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the input for the model and the label for the loss

        Returns:
            ``(input_ids, target_ids, audio_features)`` where ``target_ids`` is
            ``input_ids`` shifted left by one with the EOS id appended.
        """
        # Samplers hand out positions in [0, len(self)), so index by position
        # and always read from the frame owned by this dataset.
        df_elem = self.df.iloc[idx]

        tokens_ids_inp = self.tokenizer(
            f"{df_elem['question']} {df_elem['answer']}", add_special_tokens=False
        )['input_ids']
        # Next-token targets: drop the first id and close the sequence with EOS.
        tokens_ids_target = tokens_ids_inp[1:] + [self.tokenizer.eos_token_id]

        pt_dir = Path(self.pt_dir) if self.pt_dir is not None else PT_DATA
        audio_path = pt_dir / f"dia{df_elem['Dialogue_ID']}_utt{df_elem['Utterance_ID']}.pt"
        # NOTE: torch.load unpickles the file; only point pt_dir at trusted data.
        audio_torch = torch.load(audio_path)
        audio_torch = self.processor(audio_torch, return_tensors="pt", sampling_rate=16000).input_features

        return (
            torch.tensor(tokens_ids_inp, dtype=torch.int64),
            torch.tensor(tokens_ids_target, dtype=torch.int64),
            audio_torch,
        )

In [17]:
def collate_fn(batch):
    """Assemble dataset samples into padded batch tensors.

    Each sample is read as ``(input_ids, target_ids, audio_features)``.
    Returns the padded inputs, the padded targets, the attention mask built
    from the padded inputs, and the audio features concatenated along dim 0.
    """
    inputs, targets, audios = [], [], []
    for sample in batch:
        inputs.append(sample[0])
        targets.append(sample[1])
        audios.append(sample[2])

    audio_batch = torch.cat(audios)

    padded_inputs = add_padding(inputs)
    padded_targets = add_padding(targets)

    # The mask is derived from the inputs only; targets share the same layout.
    attention_mask = create_mask(padded_inputs)

    return padded_inputs, padded_targets, attention_mask, audio_batch


In [18]:
# Wrap the utterance table in the dataset and batch it with the custom collate
# function, which pads the token ids and builds the attention mask.
dataset = FriendsDataset(df=df, tokenizer=llm_tokenizer, processor=whisper_processor)
dataloader = DataLoader(
    dataset,
    batch_size=4,
    num_workers=4,  # samples are loaded in 4 worker processes
    prefetch_factor=2,  # batches loaded in advance by each worker
    shuffle=True,
    collate_fn=collate_fn
)

In [21]:
class LlavaFriends(torch.nn.Module):
    """Language model conditioned on audio through projected encoder features.

    Audio features go through ``whisper_encoder`` and ``projector``; the
    resulting vectors are inserted into the token embeddings of each sample
    right before its first unmasked token, and the combined sequence is fed
    to ``llm_model``.

    Args:
        whisper_encoder: Module returning a mapping with ``"last_hidden_state"``.
        llm_model: Causal LM exposing ``model.embed_tokens`` and accepting
            ``inputs_embeds`` / ``attention_mask``.
        projector: Module mapping encoder states to LLM-sized vectors.
    """

    def __init__(
        self, whisper_encoder, llm_model, projector
    ):
        super().__init__()
        self.whisper_encoder = whisper_encoder
        self.llm_model = llm_model
        self.projector = projector

    def add_audio_feat(self, batch_embed_inp, batch_audio_torch, decoder_mask):
        """Insert the audio vectors in front of the first unmasked token of each sample.

        Args:
            batch_embed_inp: Token embeddings, ``(batch, seq_len, dim)``.
            batch_audio_torch: Audio vectors, ``(batch, nb_audio_tokens, dim)``.
            decoder_mask: Attention mask, ``(batch, seq_len)``, 1 on real tokens.

        Returns:
            ``(decoder_mask, batch_embed_inp)``, both extended by
            ``nb_audio_tokens`` positions; the inserted positions are unmasked.
            The returned mask keeps the dtype and device of the input mask.
        """
        nb_audio_tokens = batch_audio_torch.shape[1]
        # argmax returns the first position holding the maximum, i.e. the first 1;
        # one .tolist() avoids a device sync per sample when slicing below.
        first_token_idx = torch.argmax(decoder_mask, dim=1).tolist()
        # Built from the mask itself so it lives on the same device (no hard-coded "cuda").
        audio_mask = decoder_mask.new_ones(nb_audio_tokens)

        decoder_mask = torch.stack([
            torch.cat([row[:start], audio_mask, row[start:]])
            for row, start in zip(decoder_mask, first_token_idx)
        ])
        batch_embed_inp = torch.stack([
            torch.cat([embeds[:start], audio_vec, embeds[start:]])
            for embeds, start, audio_vec in zip(batch_embed_inp, first_token_idx, batch_audio_torch)
        ])

        return decoder_mask, batch_embed_inp

    def forward(
        self, batch_inp, decoder_mask, batch_audio_torch
    ):
        """Run the audio-conditioned LLM.

        Args:
            batch_inp: Padded token ids, ``(batch, seq_len)``.
            decoder_mask: Attention mask matching ``batch_inp``.
            batch_audio_torch: Audio input features for the encoder.

        Returns:
            Logits with the first ``nb_audio_tokens`` positions removed, so the
            sequence length matches ``batch_inp`` again. With left-padded
            batches the text tokens are shifted right by exactly
            ``nb_audio_tokens``, hence the remaining logits line up with the
            text positions of ``batch_inp``.
        """
        batch_embed_inp = self.llm_model.model.embed_tokens(batch_inp)
        audio_hidden = self.whisper_encoder(batch_audio_torch)["last_hidden_state"]

        # The projector runs in bfloat16, so the encoder output is cast first.
        audio_embeds = self.projector(audio_hidden.to(torch.bfloat16))
        nb_audio_tokens = audio_embeds.shape[1]

        decoder_mask, batch_embed_inp = self.add_audio_feat(batch_embed_inp, audio_embeds, decoder_mask)

        out = self.llm_model(inputs_embeds=batch_embed_inp, attention_mask=decoder_mask)
        return out.logits[:, nb_audio_tokens:, :]

# Assemble the full model from the Whisper encoder, the LLM and the projector built above.
llava_model = LlavaFriends(whisper_encoder, llm_model, projector)

In [22]:
def prepare_for_loss(logits, labels):
    """Flatten batch and sequence axes so CrossEntropyLoss sees one row per token.

    ``logits`` of shape (batch, seq, vocab) becomes (batch * seq, vocab) and
    ``labels`` becomes a vector of length batch * seq.
    """
    n_batch, n_steps, n_vocab = logits.shape
    n_rows = n_batch * n_steps
    return logits.reshape(n_rows, n_vocab), labels.reshape(n_rows)

In [23]:
# Initialize Optimizer and Criterion
# We choose the CrossEntropyLoss and Adam because they're the most used
# Target positions equal to the pad id are excluded from the loss via ignore_index.
criterion = torch.nn.CrossEntropyLoss(ignore_index=llm_tokenizer.pad_token_id)
# Every parameter of llava_model is handed to Adam; the Whisper and LLM weights
# were frozen with freeze_model, so they have requires_grad=False.
optimizer = torch.optim.Adam(llava_model.parameters(), lr=1e-4)

In [24]:
# One pass over the dataloader: forward, loss on the text tokens, backward, update.
loop = tqdm(dataloader)
for batch_inp, batch_target, decoder_mask, batch_audio_torch in loop:
    batch_inp = batch_inp.to("cuda")
    batch_target = batch_target.to("cuda")
    decoder_mask = decoder_mask.to("cuda")
    batch_audio_torch = batch_audio_torch.to("cuda")

    # backward() adds to the stored gradients, so reset them before each step;
    # otherwise every update would use the sum of all previous batches' gradients.
    optimizer.zero_grad()

    logits = llava_model(batch_inp, decoder_mask, batch_audio_torch)
    logits, labels = prepare_for_loss(logits, batch_target)
    loss = criterion(logits, labels)

    loss.backward()
    optimizer.step()

    # print next to progress bar
    loop.set_postfix(loss=loss.item())

  0%|          | 0/653 [00:00<?, ?it/s]