<a href="https://colab.research.google.com/github/leenatantawy/AVA/blob/main/VisualLabs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import os
import torch
import torchaudio
import numpy as np
import torch.nn.functional as F
from torch import Tensor, nn
from typing import Dict, Iterable, Optional

In [3]:
# Audio front-end settings (module-level constants).
SAMPLE_RATE = 16_000                       # presumably samples per second
N_FFT = 1024
N_MELS = 128
HOP_LENGTH = int(SAMPLE_RATE * 0.01)       # 1% of SAMPLE_RATE -> 160
DURATION = 10                              # presumably seconds per clip
N_SAMPLES = int(SAMPLE_RATE * DURATION)    # 160000
N_FRAMES = 1 + N_SAMPLES // HOP_LENGTH     # one frame per hop, plus one -> 1001

In [4]:
def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding.

    Builds a table with one row per position. The first ``channels // 2``
    columns hold ``sin(position * inv_timescale)`` and the remaining
    ``channels // 2`` columns hold the matching ``cos`` values, where the
    inverse timescales form a geometric progression from ``1`` down to
    ``1 / max_timescale``.

    Args:
        length: number of positions (rows of the table).
        channels: embedding width (columns of the table); must be even so the
            sine and cosine halves add up to exactly ``channels`` columns.
        max_timescale: largest timescale of the geometric progression.

    Returns:
        Float tensor of shape ``(length, channels)``.

    Raises:
        ValueError: if ``channels`` is odd (the table would otherwise come out
            one column too narrow).
    """
    if channels % 2 != 0:
        raise ValueError(f"channels must be even, got {channels}")
    half = channels // 2
    # With a single sin/cos pair there is nothing to spread the timescales
    # over; clamp the denominator instead of dividing by zero.
    log_timescale_increment = np.log(max_timescale) / max(half - 1, 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(half))
    # Outer product: (length, 1) * (1, half) -> (length, half).
    scaled_time = torch.arange(length)[:, None] * inv_timescales[None, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)

In [5]:
class MelEncoder(nn.Module):
    """
    Time-frequency representation: waveform -> mel spectrogram in decibels.

    Pipeline: ``Spectrogram`` -> squared magnitude of the real part ->
    ``MelScale`` -> ``AmplitudeToDB``.

    Args:
        sample_rate: sample rate handed to the mel filter bank.
        f_min: lowest frequency of the mel filter bank.
        f_max: highest frequency of the mel filter bank.
        n_fft: FFT size; the spectrogram has ``n_fft // 2 + 1`` frequency bins.
        win_length: analysis window length.
        hop_length: hop between successive frames.
        n_mels: number of mel bins.
        power: forwarded to ``Spectrogram``. ``forward`` is written for the
            default ``None`` (complex output); any other value is squared a
            second time in ``forward``.
        pad, normalized, center, pad_mode: forwarded to ``Spectrogram``. The
            defaults here are the same as torchaudio's own defaults.
    """
    def __init__(self,
                sample_rate=16000,
                f_min=0,
                f_max=8000,
                n_fft=1024,
                win_length=1024,
                hop_length=int(0.01 * 16000),
                n_mels=128,
                power=None,
                pad=0,
                normalized=False,
                center=True,
                pad_mode="reflect"
                ):
        super(MelEncoder, self).__init__()
        # Not read by forward(); kept so the attribute stays available to
        # existing code. It is a plain tensor, not a registered buffer.
        self.window = torch.hann_window(win_length)
        self.spec_fn = torchaudio.transforms.Spectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad=pad,
            power=power,
            normalized=normalized,
            center=center,
            pad_mode=pad_mode,
        )
        self.mel_scale = torchaudio.transforms.MelScale(
            n_mels=n_mels,
            sample_rate=sample_rate,
            f_min=f_min,
            f_max=f_max,
            n_stft=n_fft // 2 + 1,
        )

        self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()

    def forward(self, wav):
        """Convert a waveform tensor into a mel spectrogram expressed in dB.

        Args:
            wav: waveform tensor accepted by ``torchaudio.transforms.Spectrogram``.

        Returns:
            The output of ``AmplitudeToDB`` applied to the mel-scaled spectrogram.
        """
        spec = self.spec_fn(wav)
        # NOTE(review): only the real part of the spectrogram is squared here;
        # a true power spectrogram would be spec.abs().pow(2). Left as-is on
        # purpose: changing it changes the features every downstream weight was
        # (presumably) trained on - confirm before touching.
        power_spec = spec.real.abs().pow(2)
        mel_spec = self.mel_scale(power_spec)
        mel_spec = self.amplitude_to_db(mel_spec)  # log scale (decibels)
        return mel_spec

In [6]:
class AudioEncoder(nn.Module):
    """Waveform -> sequence of embeddings with sinusoidal positions added.

    The waveform goes through ``MelEncoder``, one stride-1 convolution and
    ``num_of_stride_conv`` stride-2 convolutions (each followed by GELU), is
    transposed to put time before channels, and finally has a fixed sinusoidal
    positional embedding added.

    Args:
        n_mels: number of mel bins produced by the mel encoder and consumed by
            the first convolution.
        n_ctx: number of positions in the positional embedding table; must
            equal the number of time steps left after the strided convolutions.
        audio_dim: channel width of the convolutions.
        text_dim: width of the positional embedding. There is no projection
            between the two widths, so ``audio_dim`` must equal ``text_dim``
            for the addition in ``forward`` to work.
        num_of_stride_conv: number of stride-2 convolutions; each roughly
            halves the time axis.
    """

    def __init__(
        self, n_mels: int, n_ctx: int, audio_dim: int, text_dim: int, num_of_stride_conv: int,
    ):
        super().__init__()
        self.mel_encoder = MelEncoder(n_mels=n_mels)
        self.conv1 = nn.Conv1d(n_mels, audio_dim, kernel_size=3, padding=1)
        self.conv_stack = nn.ModuleList(
            [
                nn.Conv1d(audio_dim, audio_dim, kernel_size=3, stride=2, padding=1)
                for _ in range(num_of_stride_conv)
            ]
        )
        # Fixed (non-trainable) table, stored as a buffer so it follows .to().
        self.register_buffer("positional_embedding", sinusoids(n_ctx, text_dim))

    # forward() has to live inside the class body: defined in a separate cell
    # it is only a module-level function and nn.Module never sees it.
    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, waveform)
            single channel waveform

        Returns a tensor of shape (batch_size, time, audio_dim).
        """
        x = self.mel_encoder(x)  # (batch_size, n_mels, n_frames)
        x = F.gelu(self.conv1(x))
        for conv in self.conv_stack:
            x = F.gelu(conv(x))
        x = x.permute(0, 2, 1)  # channels-last: (batch_size, time, audio_dim)
        x = (x + self.positional_embedding).to(x.dtype)
        return x

In [9]:
!pip install transformers

Collecting transformers
  Downloading transformers-4.34.0-py3-none-any.whl (7.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.7/7.7 MB[0m [31m41.9 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.16.4 (from transformers)
  Downloading huggingface_hub-0.18.0-py3-none-any.whl (301 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m302.0/302.0 kB[0m [31m28.0 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers<0.15,>=0.14 (from transformers)
  Downloading tokenizers-0.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.8/3.8 MB[0m [31m73.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors>=0.3.1 (from transformers)
  Downloading safetensors-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m70.3 MB/s[0m eta [36m0:00:00[0m
Col

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers import BartForConditionalGeneration, BartTokenizer, BartConfig

In [11]:
class BartCaptionModel(nn.Module):
    """Audio captioning model: ``AudioEncoder`` embeddings fed to a BART encoder-decoder.

    All methods are kept inside the class body; defined in separate cells they
    are only module-level functions and never become part of the model.

    Args:
        n_mels: number of mel bins used by the audio encoder.
        num_of_conv: total number of convolutions in the audio encoder; all but
            the first are stride-2.
        sr: sample rate used to derive the expected number of frames.
        duration: clip length (multiplied by ``sr`` to get the sample count).
        max_length: maximum number of caption tokens kept by the tokenizer
            during training.
        label_smoothing: label smoothing passed to the cross-entropy loss.
        bart_type: name of the BART configuration/tokenizer to load.
        audio_dim: channel width of the audio encoder.
    """

    def __init__(self, n_mels=128, num_of_conv=6, sr=16000, duration=10, max_length=128, label_smoothing=0.1, bart_type="facebook/bart-base", audio_dim=768):
        super(BartCaptionModel, self).__init__()
        # Non-finetuning case: only the configuration is loaded, so the BART
        # weights are freshly initialised rather than pretrained.
        bart_config = BartConfig.from_pretrained(bart_type)
        self.tokenizer = BartTokenizer.from_pretrained(bart_type)
        self.bart = BartForConditionalGeneration(bart_config)

        self.n_sample = sr * duration
        self.hop_length = int(0.01 * sr)  # hard coding hop_size
        self.n_frames = int(self.n_sample // self.hop_length)
        self.num_of_stride_conv = num_of_conv - 1
        # Each stride-2 convolution halves the time axis.
        self.n_ctx = int(self.n_frames // 2**self.num_of_stride_conv) + 1
        self.audio_encoder = AudioEncoder(
            n_mels=n_mels,  # hard coding n_mel
            n_ctx=self.n_ctx,
            audio_dim=audio_dim,
            text_dim=self.bart.config.hidden_size,
            num_of_stride_conv=self.num_of_stride_conv,
        )

        self.max_length = max_length
        self.loss_fct = nn.CrossEntropyLoss(label_smoothing=label_smoothing, ignore_index=-100)

    @property
    def device(self):
        """Device of the model's first parameter."""
        return next(self.parameters()).device

    def shift_tokens_right(self, input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
        """
        Shift input ids one token to the right.

        Position 0 becomes ``decoder_start_token_id``, the last token of every
        row is dropped, and any ``-100`` left after the shift is replaced by
        ``pad_token_id``. The input tensor is not modified.

        Raises:
            ValueError: if ``pad_token_id`` is ``None``.
        """
        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")

        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
        shifted_input_ids[:, 0] = decoder_start_token_id
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
        return shifted_input_ids

    def forward_encoder(self, audio):
        """Encode audio.

        Returns:
            ``(encoder_outputs, audio_embs)``: the BART encoder's last hidden
            state and the audio-encoder embeddings it was computed from.
        """
        audio_embs = self.audio_encoder(audio)
        encoder_outputs = self.bart.model.encoder(
            input_ids=None,
            inputs_embeds=audio_embs,
            return_dict=True
        )["last_hidden_state"]
        return encoder_outputs, audio_embs

    def forward_decoder(self, text, encoder_outputs):
        """Compute the captioning loss for ``text`` given encoder states.

        Args:
            text: caption string(s) passed to the tokenizer.
            encoder_outputs: last hidden state of the BART encoder.

        Returns:
            Label-smoothed cross-entropy loss; padding positions are ignored.
        """
        text = self.tokenizer(text,
                              padding='longest',
                              truncation=True,
                              max_length=self.max_length,
                              return_tensors="pt")
        input_ids = text["input_ids"].to(self.device)
        attention_mask = text["attention_mask"].to(self.device)

        # Padding positions get -100 so the loss ignores them.
        decoder_targets = input_ids.masked_fill(
            input_ids == self.tokenizer.pad_token_id, -100
        )

        decoder_input_ids = self.shift_tokens_right(
            decoder_targets, self.bart.config.pad_token_id, self.bart.config.decoder_start_token_id
        )

        decoder_outputs = self.bart(
            input_ids=None,
            attention_mask=None,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=attention_mask,
            inputs_embeds=None,
            labels=None,
            encoder_outputs=(encoder_outputs,),
            return_dict=True
        )
        lm_logits = decoder_outputs["logits"]
        # Flatten with the logits' own vocabulary dimension: it is always
        # consistent with the tensor, whereas the tokenizer's vocab_size only
        # matches when tokenizer and model config agree.
        loss = self.loss_fct(lm_logits.view(-1, lm_logits.size(-1)), decoder_targets.view(-1))
        return loss

    def forward(self, audio, text):
        """Training step: encode ``audio`` and return the loss against ``text``."""
        encoder_outputs, _ = self.forward_encoder(audio)
        loss = self.forward_decoder(text, encoder_outputs)
        return loss

    @torch.no_grad()  # inference only: no autograd graph for the encoder pass
    def generate(self,
                 samples,
                 use_nucleus_sampling=False,
                 num_beams=5,
                 max_length=128,
                 min_length=2,
                 top_p=0.9,
                 repetition_penalty=1.0,
                 ):
        """Generate captions for a batch of audio.

        Args:
            samples: audio batch passed to the audio encoder.
            use_nucleus_sampling: sample with ``top_p`` instead of beam search.
            num_beams: beam width (beam-search branch only).
            max_length: maximum generated length.
            min_length: minimum generated length.
            top_p: nucleus probability mass (sampling branch only).
            repetition_penalty: used by the beam-search branch only; see the
                note in the sampling branch.

        Returns:
            List of decoded caption strings, special tokens removed.
        """
        audio_embs = self.audio_encoder(samples)
        encoder_outputs = self.bart.model.encoder(
            input_ids=None,
            attention_mask=None,
            head_mask=None,
            inputs_embeds=audio_embs,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=True)

        # Every sequence starts from the decoder start token.
        batch_size = encoder_outputs['last_hidden_state'].size(0)
        input_ids = torch.full(
            (batch_size, 1),
            self.bart.config.decoder_start_token_id,
            dtype=torch.long,
            device=self.device,
        )
        decoder_attention_mask = torch.ones((batch_size, 1), dtype=torch.long, device=self.device)

        shared_kwargs = dict(
            input_ids=None,
            attention_mask=None,
            decoder_input_ids=input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            max_length=max_length,
            min_length=min_length,
        )
        if use_nucleus_sampling:
            # NOTE(review): this branch hard-codes repetition_penalty=1.1 and
            # ignores the `repetition_penalty` argument. Kept as-is because
            # honouring the argument would silently change default sampling
            # behaviour (1.1 -> 1.0) - decide which is intended.
            outputs = self.bart.generate(
                do_sample=True,
                top_p=top_p,
                num_return_sequences=1,
                repetition_penalty=1.1,
                **shared_kwargs)
        else:
            outputs = self.bart.generate(
                head_mask=None,
                decoder_head_mask=None,
                inputs_embeds=None,
                decoder_inputs_embeds=None,
                use_cache=None,
                output_attentions=None,
                output_hidden_states=None,
                num_beams=num_beams,
                repetition_penalty=repetition_penalty,
                **shared_kwargs)

        captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return captions