In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchaudio
from transformers import BertTokenizer
from google.colab.patches import cv2_imshow

import cv2
import os
import numpy as np
import pandas as pd

class MELDDataset(Dataset):
    """MELD utterance dataset yielding one video frame, audio and tokenized text.

    Every CSV row must provide the ``Dialogue_ID``, ``Utterance_ID``,
    ``Utterance`` and ``Emotion`` columns. The clip for a row is looked up as
    ``dia<Dialogue_ID>_utt<Utterance_ID>.mp4`` inside ``path``.
    """

    def __init__(self, csv, path, transform=None, max_video_len=30, max_audio_len=16000, max_text_len=128):
        """Read the annotation CSV and prepare the tokenizer.

        Args:
            csv: Path of the annotation CSV file.
            path: Directory that contains the ``.mp4`` clips.
            transform: Callable applied to the selected video frame (a NumPy
                array as returned by OpenCV). When ``None`` the raw frame is
                returned as a tensor.
            max_video_len: Maximum number of frames read from a clip; the
                last frame read is the one returned.
            max_audio_len: Number of audio samples every waveform is
                truncated or zero-padded to.
            max_text_len: Token length every utterance is truncated or
                padded to.
        """
        self.df = pd.read_csv(csv)
        self.label = self._build_label_map(self.df['Emotion'])
        self.path = path
        self.max_video_len = max_video_len
        self.max_audio_len = max_audio_len
        self.max_text_len = max_text_len
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.video_transform = transform

    @staticmethod
    def _build_label_map(emotions):
        """Map every distinct emotion to an integer id, in sorted order.

        Sorting makes the mapping independent of the row order of the CSV, so
        splits that contain the same set of emotions share the same ids.
        """
        distinct = pd.Series(emotions).dropna().unique()
        return {emotion: index for index, emotion in enumerate(sorted(distinct))}

    def __len__(self):
        """Return the number of utterances (CSV rows)."""
        return len(self.df)

    def load_video(self, video_path):
        """Return a single representative frame of the clip.

        Frames are read until ``max_video_len`` frames were decoded or the
        clip ends; the last decoded frame is returned after applying the
        configured transform.

        Raises:
            RuntimeError: If not a single frame could be decoded.
        """
        # Always attempt to read at least one frame, even for a zero limit.
        frame_limit = max(1, self.max_video_len)
        cap = cv2.VideoCapture(video_path)
        last_frame = None
        try:
            for _ in range(frame_limit):
                ret, frame = cap.read()
                if not ret:
                    break
                last_frame = frame
        finally:
            cap.release()

        if last_frame is None:
            raise RuntimeError(f"Could not decode any frame from {video_path!r}")
        # Only the returned frame is transformed; transforming the discarded
        # frames as well would be wasted work.
        if self.video_transform is None:
            return torch.from_numpy(last_frame)
        return self.video_transform(last_frame)

    def load_audio(self, audio_path):
        """Return a mono waveform of shape ``(1, max_audio_len)``.

        Multi-channel audio is averaged to one channel, then truncated or
        right-padded with zeros. Loading is best effort: if the file cannot
        be read, a silent (all-zero) waveform is returned and the failure is
        logged.
        """
        import logging

        try:
            waveform, _ = torchaudio.load(audio_path)
            if waveform.size(0) > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)
            waveform = waveform[:, :self.max_audio_len]
            missing = self.max_audio_len - waveform.size(1)
            if missing > 0:
                waveform = torch.nn.functional.pad(waveform, (0, missing))
            return waveform
        except Exception as error:
            # Keep the deliberate fallback, but do not swallow
            # KeyboardInterrupt/SystemExit and leave a trace of the failure.
            logging.getLogger(__name__).warning(
                "Could not load audio from %s (%s); using silence instead.",
                audio_path,
                error,
            )
            return torch.zeros(1, self.max_audio_len)

    def tokenize_text(self, text):
        """Tokenize ``text`` and return ``(input_ids, attention_mask)``.

        Both tensors are one-dimensional with ``max_text_len`` entries.
        """
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_text_len,
            return_tensors='pt',
        )
        # Drop only the leading batch axis added by return_tensors='pt'.
        return encoding['input_ids'].squeeze(0), encoding['attention_mask'].squeeze(0)

    def __getitem__(self, idx):
        """Return ``(video, audio, input_ids, attention_mask, label)`` for row ``idx``."""
        row = self.df.iloc[idx]
        filename = f"dia{row['Dialogue_ID']}_utt{row['Utterance_ID']}.mp4"
        # Replace the 0x92 character found in the annotations by an apostrophe.
        text = row['Utterance'].replace('\x92', "'")

        # Video and audio are both read from the same .mp4 file.
        clip_path = os.path.join(self.path, filename)
        video = self.load_video(clip_path)
        audio = self.load_audio(clip_path)
        text, attention_mask = self.tokenize_text(text)
        label = self.label[row['Emotion']]
        return video, audio, text, attention_mask, label

def collate_fn(batch):
    """Merge dataset samples into batched tensors.

    Each sample is a ``(video, audio, text, attention_mask, label)`` tuple.
    The batch is returned in a different order:
    ``(texts, attention_masks, audios, videos, labels)``.
    """
    video_seq, audio_seq, token_seq, mask_seq, label_seq = zip(*batch)
    # Stack in the same order as the sample layout, then reorder on return.
    stacked = {
        key: torch.stack(items)
        for key, items in (
            ("videos", video_seq),
            ("audios", audio_seq),
            ("texts", token_seq),
            ("masks", mask_seq),
        )
    }
    label_tensor = torch.tensor(label_seq)
    return (
        stacked["texts"],
        stacked["masks"],
        stacked["audios"],
        stacked["videos"],
        label_tensor,
    )

def MELD(datatype, transform=None, batch_size=4, collate=collate_fn, shuffle=True):
    """Build a DataLoader over one MELD split.

    The CSV file and clip directory of every split are hard-coded Google
    Drive paths below; edit them to match your own layout.

    Args:
        datatype: Split to load: 'train', 'valid' or 'test'.
        transform: Callable applied to the video frame. When None, the frame
            is resized to 480x640 and converted to a tensor.
        batch_size: Number of utterances per batch.
        collate: Function merging samples into a batch.
        shuffle: Whether the loader reshuffles the data (default True).

    Returns:
        A DataLoader. With the default collate function and transform each
        batch is the tuple (texts, attention_masks, audios, videos, labels):
            Text   (Batch, tokenized Length)
            Mask   (Batch, tokenized Length)
            Audio  (Batch, Channel, Sample)
            Video  (Batch, Channel, Height, Width) -- one frame per clip
            Label  (Batch)

    Raises:
        ValueError: If datatype is not one of the known splits.
    """
    # Data to load: split name -> (annotation CSV, clip directory)
    split_locations = {
        'train': (
            '/content/drive/MyDrive/MELD/train_sent_emo.csv',
            '/content/drive/MyDrive/MELD/train_splits/',
        ),
        'valid': (
            '/content/drive/MyDrive/MELD/MELD.Raw/dev_sent_emo.csv',
            '/content/drive/MyDrive/MELD/dev_splits_complete/',
        ),
        'test': (
            '/content/drive/MyDrive/MELD/MELD.Raw/test_sent_emo.csv',
            '/content/drive/MyDrive/MELD/output_repeated_splits_test/',
        ),
    }
    if datatype not in split_locations:
        # Fail early with a clear message instead of an UnboundLocalError.
        raise ValueError(
            f"Unknown datatype {datatype!r}; expected one of {sorted(split_locations)}"
        )
    csv_file, data_folder = split_locations[datatype]

    # transform
    if transform is None:
        transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((480, 640)),
            transforms.ToTensor(),
        ])

    # Load data
    dataset = MELDDataset(csv_file, data_folder, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate)
    return dataloader

if __name__ == '__main__':
    # Smoke test: pull the first batch of the test split and show one frame.
    dataloader = MELD('test', batch_size=16)

    for batch in dataloader:
        (
            batch_input_tokens,
            batch_attention_masks,
            batch_audio,
            batch_video,
            batch_labels,
        ) = batch
        # Channels-first tensor -> channels-last array, rescaled by 255
        # for display.
        img = np.multiply(batch_video[0].permute(1, 2, 0).numpy(), 255)
        cv2_imshow(img)
        # One batch is enough for the visual check.
        break
# Notebook cells (IPython `!` / `%` magics): fetch the EnCLAP repository and
# change into it. Later cells import `modeling.enclap_bart`, which presumably
# lives in this checkout -- TODO confirm.
!git clone https://github.com/jaeyeonkim99/EnCLAP.git
%cd EnCLAP/
# Download the LAION-CLAP checkpoint into the current directory.
# NOTE(review): the URL is unquoted and ends in a query string, so wget
# presumably stores the file under a name ending in "?download=true" --
# confirm the saved name matches the checkpoint paths used below.
!wget https://huggingface.co/lukewys/laion_clap/resolve/main/630k-audioset-best.pt?download=true
dataloader = MELD('test', batch_size=16)

for batch in dataloader:
    batch_input_tokens, batch_attention_masks, batch_audio, batch_video, batch_labels = batch
    # Runs the repository's inference script once per batch.
    # NOTE(review): the --ckpt path contains an unquoted space, so the shell
    # splits it into two arguments. `--input batch_audio` passes the literal
    # text "batch_audio", not the tensor above (IPython only expands
    # `{name}` / `$name`) -- confirm what inference.py expects as input.
    !python inference.py  --ckpt /content/drive/MyDrive/Audio captioning --clap_ckpt /content/drive/MyDrive/EnCLAP/630k-audioset-best.pt --input batch_audio
# Install the CLAP and EnCodec packages imported by the cells below.
!pip install laion_clap
!pip install encodec
# audio
from typing import Any, Dict

import argparse
import numpy as np
import torch
import torchaudio
from encodec import EncodecModel
from encodec.utils import convert_audio
from laion_clap import CLAP_Module
from transformers import AutoTokenizer

from modeling.enclap_bart import EnClapBartConfig, EnClapBartForConditionalGeneration


class EnClap:
    """Audio captioning wrapper combining EnCodec codes, a CLAP embedding and EnClap-BART.

    The public ``infer_*`` methods prepare EnCodec frames plus a CLAP audio
    embedding and hand them to :meth:`_infer`, which runs generation and
    decodes the result with the BART tokenizer.
    """

    def __init__(
        self,
        ckpt_path: str,
        clap_audio_model: str = "HTSAT-tiny",
        clap_enable_fusion: bool = True,
        clap_ckpt_path: str = None,
        device: str = "cuda",
    ):
        """Load the captioning model, the EnCodec codec and the CLAP encoder.

        Args:
            ckpt_path: Location of the EnClap-BART checkpoint (config and
                weights are both loaded from it).
            clap_audio_model: Audio backbone name passed to ``CLAP_Module``.
            clap_enable_fusion: ``enable_fusion`` flag passed to ``CLAP_Module``.
            clap_ckpt_path: CLAP checkpoint passed to ``load_ckpt``.
            device: Device all models and inputs are moved to.
        """
        config = EnClapBartConfig.from_pretrained(ckpt_path)
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
        self.model = (
            EnClapBartForConditionalGeneration.from_pretrained(ckpt_path)
            .to(self.device)
            .eval()
        )

        self.encodec = EncodecModel.encodec_model_24khz().to(self.device)
        self.encodec.set_target_bandwidth(12.0)
        self.clap_model = CLAP_Module(
            enable_fusion=clap_enable_fusion,
            amodel=clap_audio_model,
            device=self.device,
        )
        self.clap_model.load_ckpt(clap_ckpt_path)

        # Default keyword arguments for ``model.generate``.
        self.generation_config = {
            "_from_model_config": True,
            "bos_token_id": 0,
            "decoder_start_token_id": 2,
            "early_stopping": True,
            "eos_token_id": 2,
            "forced_bos_token_id": 0,
            "forced_eos_token_id": 2,
            "no_repeat_ngram_size": 3,
            "num_beams": 4,
            "pad_token_id": 1,
            "max_length": 50,
        }
        # Three positions are reserved for the two BOS rows and the EOS row
        # that _infer adds around the EnCodec frames.
        self.max_seq_len = config.max_position_embeddings - 3

    @torch.no_grad()
    def infer_from_audio_file(
        self, audio_file: str, generation_config: Dict[str, Any] = None
    ) -> list:
        """Caption the audio stored in ``audio_file``.

        Only the first channel of the file is used.

        Args:
            audio_file: Path of an audio file readable by ``torchaudio.load``.
            generation_config: Keyword arguments for ``model.generate``;
                defaults to ``self.generation_config``.

        Returns:
            List of decoded caption strings.
        """
        if generation_config is None:
            generation_config = self.generation_config
        audio, res = torchaudio.load(audio_file)
        # Forward the caller's configuration; it used to be dropped here.
        return self.infer_from_audio(audio[0], res, generation_config)

    @torch.no_grad()
    def infer_from_audio(
        self, audio: torch.Tensor, res: int, generation_config: Dict[str, Any] = None
    ) -> list:
        """Caption a waveform held in memory.

        Args:
            audio: Waveform tensor; presumably one-dimensional (samples,), as
                passed by ``infer_from_audio_file`` -- TODO confirm for other
                callers. int16/int32 input is rescaled to floats.
            res: Sample rate of ``audio``.
            generation_config: Keyword arguments for ``model.generate``;
                defaults to ``self.generation_config``.

        Returns:
            List of decoded caption strings.
        """
        if generation_config is None:
            generation_config = self.generation_config
        # Integer PCM -> float by dividing through the type's magnitude.
        if audio.dtype == torch.short:
            audio = audio / 2**15
        if audio.dtype == torch.int:
            audio = audio / 2**31

        # EnCodec branch: convert to the codec's sample rate / channel count.
        encodec_audio = (
            convert_audio(
                audio.unsqueeze(0), res, self.encodec.sample_rate, self.encodec.channels
            )
            .unsqueeze(0)
            .to(self.device)
        )
        encodec_frames = self.encodec.encode(encodec_audio)
        # Join the codes of all encoded chunks along the last axis, then swap
        # the last two axes.
        encodec_frames = torch.cat(
            [codebook for codebook, _ in encodec_frames], dim=-1
        ).mT

        # CLAP branch: resample to 48 kHz before embedding.
        clap_audio = torchaudio.transforms.Resample(res, 48000)(audio).unsqueeze(0)
        clap_embedding = self.clap_model.get_audio_embedding_from_data(
            clap_audio, use_tensor=True
        )

        return self._infer(encodec_frames, clap_embedding, generation_config)

    @torch.no_grad()
    def _infer(
        self,
        encodec_frames: torch.LongTensor,
        clap_embedding: torch.Tensor,
        generation_config: Dict[str, Any] = None,
    ) -> list:
        """Run generation on prepared inputs and decode the captions.

        Args:
            encodec_frames: EnCodec codes, presumably shaped
                (batch, frames, codebooks) -- TODO confirm. Frames beyond
                ``self.max_seq_len`` are dropped.
            clap_embedding: CLAP audio embedding passed through to the model.
            generation_config: Keyword arguments for ``model.generate``;
                defaults to ``self.generation_config``.

        Returns:
            List of decoded caption strings.
        """
        if generation_config is None:
            generation_config = self.generation_config

        batch_size = encodec_frames.shape[0]
        codebooks = encodec_frames.shape[-1]
        # Two BOS rows in front of the frames and one EOS row behind them.
        bos_rows = torch.full(
            (batch_size, 2, codebooks),
            self.tokenizer.bos_token_id,
            dtype=torch.long,
            device=self.device,
        )
        eos_rows = torch.full(
            (batch_size, 1, codebooks),
            self.tokenizer.eos_token_id,
            dtype=torch.long,
            device=self.device,
        )
        input_ids = torch.cat(
            [bos_rows, encodec_frames[:, : self.max_seq_len], eos_rows],
            dim=1,
        )
        # 1 marks real EnCodec frames, 0 marks the added BOS/EOS rows.
        # NOTE(review): the mask has a single row, i.e. it is built for a
        # batch of one.
        encodec_mask = torch.LongTensor(
            [[0, 0] + [1] * (input_ids.shape[1] - 3) + [0]]
        ).to(self.device)

        enclap_bart_inputs = {
            "input_ids": input_ids,
            "encodec_mask": encodec_mask,
            "clap_embedding": clap_embedding,
        }

        results = self.model.generate(**enclap_bart_inputs, **generation_config)
        caption = self.tokenizer.batch_decode(results, skip_special_tokens=True)

        return caption

    @torch.no_grad()
    def infer_from_encodec(
        self,
        encodec_path,
        clap_path,
        generation_config: Dict[str, Any] = None,
    ):
        """Caption from EnCodec frames and a CLAP embedding saved as ``.npy`` files.

        Args:
            encodec_path: Path of the ``.npy`` file with the EnCodec frames.
            clap_path: Path of the ``.npy`` file with the CLAP embedding.
            generation_config: Keyword arguments for ``model.generate``;
                defaults to ``self.generation_config``.

        Returns:
            List of decoded caption strings.
        """
        if generation_config is None:
            generation_config = self.generation_config
        # Use the configured device rather than assuming CUDA is available.
        encodec_frames = (
            torch.from_numpy(np.load(encodec_path)).unsqueeze(0).to(self.device)
        )
        clap_embedding = (
            torch.from_numpy(np.load(clap_path)).unsqueeze(0).to(self.device)
        )

        return self._infer(encodec_frames, clap_embedding, generation_config)
# Build the captioner from the checkpoints on Google Drive; use the GPU when
# one is available and fall back to the CPU otherwise.
enclap = EnClap(
    '/content/drive/MyDrive/Audio_captioning',
    clap_ckpt_path='/content/drive/MyDrive/EnCLAP/630k-audioset-fusion-best.pt',
    device=torch.device("cpu" if not torch.cuda.is_available() else "cuda"),
)

# Caption a single extracted utterance audio file.
predict = enclap.infer_from_audio_file(
    audio_file='/content/drive/MyDrive/final_videos_testdia101_utt2.wav'
)