In [None]:
import torch
import pandas as pd
import os 
import librosa
import librosa.display
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import timm
import wave
import skimage.io
import torchvision.models as models

from torchvision import transforms
from torch import nn
from tqdm import tqdm
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence 
from torch.nn import functional as F
from typing import List, Any, Tuple, Optional, TypeVar, Union, IO, Type
from transformers import AutoTokenizer
from timm.models.layers import to_2tuple

# Debugging aid: CUDA_LAUNCH_BLOCKING=1 asks CUDA to run kernels synchronously so
# device-side errors are reported at the call that triggered them.
# NOTE(review): this typically slows GPU training; consider removing once the
# notebook runs cleanly.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# Disable matplotlib's interactive mode; figures are only written to disk.
plt.ioff()
# Use the non-GUI Agg backend so spectrogram rendering works without a display.
matplotlib.use('agg')

In [5]:
#Bahdanau Attention
class Attention(nn.Module):
    """Additive (Bahdanau-style) attention over a set of encoder feature vectors.

    Every feature vector is scored against the current decoder hidden state;
    the softmax-normalised scores are used to build a weighted sum of the
    features (the context vector).
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()

        self.attention_dim = attention_dim

        # Projections of decoder state (W) and encoder features (U) into the
        # shared attention space, followed by the scalar scoring layer (A).
        self.W = nn.Linear(decoder_dim, attention_dim)
        self.U = nn.Linear(encoder_dim, attention_dim)
        self.A = nn.Linear(attention_dim, 1)

    def forward(self, features, hidden_state):
        """Compute attention weights and the resulting context vector.

        Args:
            features: encoder output, (batch_size, num_features, encoder_dim).
            hidden_state: decoder hidden state, (batch_size, decoder_dim).

        Returns:
            Tuple ``(alpha, context)`` where ``alpha`` has shape
            (batch_size, num_features) and sums to 1 along dim 1, and
            ``context`` has shape (batch_size, encoder_dim).
        """
        projected_features = self.U(features)                # (batch_size, num_features, attention_dim)
        projected_hidden = self.W(hidden_state).unsqueeze(1)  # (batch_size, 1, attention_dim)

        # Broadcasting adds the single hidden projection to every feature position.
        energy = torch.tanh(projected_features + projected_hidden)

        scores = self.A(energy).squeeze(2)                   # (batch_size, num_features)
        alpha = scores.softmax(dim=1)                        # (batch_size, num_features)

        weighted_features = alpha.unsqueeze(2) * features    # (batch_size, num_features, encoder_dim)
        context = weighted_features.sum(dim=1)               # (batch_size, encoder_dim)

        return alpha, context

In [None]:
class EncoderCNN(nn.Module):
    """ResNet-50 backbone that turns an image batch into a sequence of feature vectors."""

    def __init__(self):
        super().__init__()
        backbone = models.resnet50(pretrained=True)

        # Explicitly keep every backbone weight trainable (full fine-tuning).
        for weight in backbone.parameters():
            weight.requires_grad_(True)

        # Keep everything except the last two child modules so the network
        # outputs a spatial feature map instead of class logits.
        # NOTE(review): for torchvision's ResNet these two are expected to be
        # the global average pool and the fully connected head -- confirm.
        kept_layers = list(backbone.children())[:-2]
        self.resnet = nn.Sequential(*kept_layers)

    def forward(self, images):
        """Encode ``images`` into per-location feature vectors.

        Returns a tensor of shape (batch_size, h * w, channels), where
        (channels, h, w) is the shape of the backbone's output feature map.
        """
        feature_map = self.resnet(images)               # (batch_size, channels, h, w)
        feature_map = feature_map.permute(0, 2, 3, 1)   # (batch_size, h, w, channels)

        batch_size = feature_map.size(0)
        channels = feature_map.size(-1)
        # Flatten the two spatial axes into one "sequence" axis.
        return feature_map.view(batch_size, -1, channels)

In [6]:
class DecoderRNN(nn.Module):
    """LSTM-cell decoder that attends over encoder features at every step.

    Args:
        embed_size: dimensionality of the token embeddings.
        vocab_size: number of tokens in the vocabulary (size of the output layer).
        attention_dim: size of the hidden attention space.
        encoder_dim: dimensionality of each encoder feature vector.
        decoder_dim: size of the LSTM hidden/cell state.
        drop_prob: dropout probability applied to the hidden state before the
            output projection.
        device: kept for backward compatibility and stored on ``self.device``.
            Tensors created inside the decoder are placed on the device of the
            incoming ``features`` instead, so a mismatching default (``'cuda'``
            on a CPU-only machine) no longer crashes the model.
    """

    def __init__(self, embed_size,
                 vocab_size,
                 attention_dim,
                 encoder_dim,
                 decoder_dim,
                 drop_prob=0.3,
                 device='cuda'):

        super().__init__()

        self.vocab_size = vocab_size
        self.attention_dim = attention_dim
        self.decoder_dim = decoder_dim

        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)

        # Linear maps from the mean encoder feature to the initial LSTM state.
        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        # The LSTM input is the token embedding concatenated with the context vector.
        self.lstm_cell = nn.LSTMCell(embed_size + encoder_dim, decoder_dim, bias=True)
        # NOTE(review): f_beta is never used by forward() or generate_caption();
        # it is kept so existing checkpoints/state_dicts still load.
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)

        self.fcn = nn.Linear(decoder_dim, vocab_size)
        self.drop = nn.Dropout(drop_prob)

        self.device = device

    def forward(self, features, captions):
        """Teacher-forced decoding.

        Args:
            features: encoder output, (batch_size, num_features, encoder_dim).
            captions: token ids, (batch_size, caption_len). The token at step
                ``s`` is fed to predict the token at step ``s + 1``, so the last
                token is never used as an input.

        Returns:
            ``(preds, alphas)`` with shapes
            (batch_size, caption_len - 1, vocab_size) and
            (batch_size, caption_len - 1, num_features).
        """
        # Allocate outputs next to the inputs rather than on a configured device.
        device = features.device

        embeds = self.embedding(captions)

        # Initialize LSTM state
        h, c = self.init_hidden_state(features)  # (batch_size, decoder_dim)

        seq_length = captions.size(1) - 1  # the last token is only ever a target
        batch_size = captions.size(0)
        num_features = features.size(1)

        preds = torch.zeros(batch_size, seq_length, self.vocab_size, device=device)
        alphas = torch.zeros(batch_size, seq_length, num_features, device=device)

        for s in range(seq_length):
            alpha, context = self.attention(features, h)
            lstm_input = torch.cat((embeds[:, s], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))

            output = self.fcn(self.drop(h))

            preds[:, s] = output
            alphas[:, s] = alpha

        return preds, alphas

    def generate_caption(self, features, int2str, str2int, max_len=20):
        """Greedily decode a caption for a single encoded image.

        Only a batch of one is supported: the start token is created with batch
        size 1 and predictions are read with ``.item()``.

        Dropout is applied through ``self.drop``, so put the module in eval mode
        (``model.eval()``) before calling this for deterministic output.

        Args:
            features: encoder output, (1, num_features, encoder_dim).
            int2str: mapping from token id to token string.
            str2int: mapping from token string to token id; must contain '[BOS]'.
            max_len: maximum number of tokens to generate.

        Returns:
            ``(tokens, alphas)``: the generated token strings (including the
            final '[EOS]' if one was produced) and a list with one numpy
            attention array per generated token.
        """
        device = features.device
        batch_size = features.size(0)
        h, c = self.init_hidden_state(features)  # (batch_size, decoder_dim)

        alphas = []

        # Starting input: a (1, 1) tensor holding the begin-of-sequence id.
        word = torch.tensor(str2int['[BOS]'], device=device).view(1, -1)
        embeds = self.embedding(word)

        captions = []

        for _ in range(max_len):
            alpha, context = self.attention(features, h)

            # Store the alpha score
            alphas.append(alpha.cpu().detach().numpy())

            lstm_input = torch.cat((embeds[:, 0], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            output = self.fcn(self.drop(h))
            output = output.view(batch_size, -1)

            # Greedy choice: the highest-scoring token.
            predicted_word_idx = output.argmax(dim=1)
            predicted_id = predicted_word_idx.item()

            # Save the generated word
            captions.append(predicted_id)

            # End if "[EOS]" detected
            if int2str[predicted_id] == "[EOS]":
                break

            # Feed the generated word back in as the next input.
            embeds = self.embedding(predicted_word_idx.unsqueeze(0))

        # Convert the vocab indices to words and return the sentence
        return [int2str[caption] for caption in captions], alphas

    def init_hidden_state(self, encoder_out):
        """Derive the initial LSTM (h, c) from the mean encoder feature vector."""
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)  # (batch_size, decoder_dim)
        c = self.init_c(mean_encoder_out)
        return h, c

In [7]:
class Img2SeqModel(nn.Module):
    """Image-to-sequence model: a CNN encoder followed by an attention LSTM decoder.

    Args:
        embed_dim: token embedding size used by the decoder.
        encoder_dim: dimensionality of each encoder feature vector.
        decoder_dim: size of the decoder LSTM state.
        attention_dim: size of the attention hidden space.
        vocab_size: number of tokens in the vocabulary.
        device: device name forwarded to the decoder. When omitted, 'cuda' is
            used if available and 'cpu' otherwise, so the decoder is no longer
            left with an unconditional 'cuda' default.
    """

    def __init__(self,
                 embed_dim,
                 encoder_dim,
                 decoder_dim,
                 attention_dim,
                 vocab_size,
                 device: Optional[str] = None) -> None:

        super().__init__()

        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # ------------ INITIALIZING ENCODER ------------

        self.encoder = EncoderCNN()

        # ------------ INITIALIZING DECODER ------------
        self.decoder = DecoderRNN(embed_size=embed_dim,
                                  vocab_size=vocab_size,
                                  attention_dim=attention_dim,
                                  encoder_dim=encoder_dim,
                                  decoder_dim=decoder_dim,
                                  device=device)

    def forward(self, src: torch.Tensor,
                tgt: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode ``src`` and decode it against the target sequence ``tgt``.

        Returns whatever the decoder returns: a ``(predictions, attention
        weights)`` pair.
        """
        # Encode
        features = self.encoder(src)
        # Decode
        outputs = self.decoder(features, tgt)

        return outputs

    def get_shape(self, fstride, tstride, height, width):
        """Output grid size of a 16x16 patch convolution without padding.

        The previous implementation probed a throwaway ``nn.Conv2d`` whose
        channel count came from ``self.original_embedding_dim``, an attribute
        this class never defines, so every call raised ``AttributeError``. The
        spatial output size does not depend on the channel count, so it is
        computed directly here.

        NOTE(review): nothing in the visible notebook calls this method.

        Args:
            fstride: stride along the height (frequency) axis.
            tstride: stride along the width (time) axis.
            height: input height in pixels.
            width: input width in pixels.

        Returns:
            ``(f_dim, t_dim)``: number of patches along height and width.

        Raises:
            ValueError: if a stride is not positive or the input is smaller
                than the 16x16 kernel.
        """
        kernel = 16
        if fstride <= 0 or tstride <= 0:
            raise ValueError('fstride and tstride must be positive')
        if height < kernel or width < kernel:
            raise ValueError(f'input must be at least {kernel}x{kernel}')

        f_dim = (height - kernel) // fstride + 1
        t_dim = (width - kernel) // tstride + 1
        return f_dim, t_dim

In [None]:
class CustomSpectroDataset(Dataset):
    """Dataset pairing spectrogram images with whitespace-tokenised transcriptions.

    The CSV must provide a ``channel_1`` column (wav file names) and a
    ``transcription`` column. Rows with any missing value are dropped. A wav
    file is paired with every image in ``spectr_folder`` whose name shares the
    same prefix up to the first '.'.

    Args:
        csv_data: path to the CSV file.
        recordings_folder: folder holding the wav files (only read when
            ``produce_spectr`` is True).
        spectr_folder: folder holding (or receiving) the spectrogram images.
        produce_spectr: render spectrogram images from the wav files first.
        transform: callable applied to the opened PIL image. When falsy,
            ``transforms.ToTensor()`` is used so ``__getitem__`` always returns
            a tensor instead of failing on an unset variable.
    """

    def __init__(self,
                 csv_data : str,
                 recordings_folder : str='wav_recordings',
                 spectr_folder : str='log_rectangular_small',
                 produce_spectr : bool=False,
                 transform : Any=None,
                 ):

        self.transform = transform
        self.df = pd.read_csv(csv_data).dropna(how='any')
        self.spectr_folder = spectr_folder
        self.wav_filenames = self.df['channel_1'].tolist()
        self.transcriptions = self.df['transcription'].tolist()
        assert len(self.wav_filenames) == len(self.transcriptions)

        # If we want to generate images from WAV-files (not needed if images are already present)
        if produce_spectr:
            for name in tqdm(self.wav_filenames):
                wav_path = os.path.join(recordings_folder, name)
                if os.path.exists(wav_path):
                    self._logar(wav_path, spectr_folder)

        # List the folder only AFTER the optional generation step, so freshly
        # rendered images are picked up and a folder created by _logar exists.
        # Sorted for a deterministic order across file systems.
        self.jpg_paths = sorted(os.listdir(self.spectr_folder))

        self.data = self._match_spectrograms(self.wav_filenames,
                                             self.transcriptions,
                                             self.jpg_paths)

        # SUPER SIMPLE TOKENIZER: SPLIT ON SPACEBAR
        self._build_vocab()

    @staticmethod
    def _match_spectrograms(wav_filenames, transcriptions, image_names):
        """Pair each wav (and its transcription) with the images sharing its name prefix.

        Uses a prefix -> images index instead of comparing every wav with every
        image. Samples are emitted in wav order, then in ``image_names`` order.
        """
        images_by_prefix = {}
        for image_name in image_names:
            images_by_prefix.setdefault(image_name.split('.')[0], []).append(image_name)

        matched = []
        for wav, transcription in zip(wav_filenames, transcriptions):
            for image_name in images_by_prefix.get(wav.split('.')[0], []):
                matched.append({'img_name': image_name, 'seq': transcription, 'wav': wav})
        return matched

    def _build_vocab(self) -> None:
        """Build ``int2str`` / ``str2int`` from the transcriptions in ``self.data``.

        Tokens are sorted before ids are assigned: iterating a ``set`` of
        strings directly gives an order that can differ between interpreter
        runs, which would make token ids (and saved checkpoints) irreproducible.
        """
        self.int2str = {0: '[PAD]', 1: '[BOS]', 2: '[EOS]'}

        tokens = set()
        for sample in self.data:
            tokens.update(sample['seq'].split())
        # A transcription containing a literal special token must not get a
        # second id, otherwise str2int would no longer point at ids 0-2.
        tokens.difference_update(self.int2str.values())

        for token in sorted(tokens):
            self.int2str[len(self.int2str)] = token

        self.str2int = {st: i for i, st in self.int2str.items()}

    def _encode(self, seq : str) -> torch.Tensor:
        """Turn a transcription into token ids wrapped in [BOS] ... [EOS]."""
        token_ids = [self.str2int['[BOS]']]
        token_ids.extend(self.str2int[token] for token in seq.split())
        token_ids.append(self.str2int['[EOS]'])
        return torch.tensor(token_ids)

    def __getitem__(self, index : int) -> dict[str, Any]:
        """Return ``{'img': transformed image, 'seq': token-id tensor, 'wav': wav file name}``."""
        sample = self.data[index]

        # Fall back to ToTensor so a missing transform cannot leave the image unset.
        to_tensor = self.transform if self.transform else transforms.ToTensor()
        with Image.open(os.path.join(self.spectr_folder, sample['img_name'])) as img_file:
            # Convert while the file is still open: PIL loads pixel data lazily.
            img_tensor = to_tensor(img_file)

        return {'img': img_tensor,
                'seq': self._encode(sample['seq']),
                'wav': sample['wav']}

    def get_vocab_size(self) -> int:
        """Number of tokens in the vocabulary, special tokens included."""
        return len(self.int2str)

    def get_all_transcriptions(self) -> List[str]:
        """Raw transcription strings of all matched samples, in dataset order."""
        transcriptions = [sample['seq'] for sample in self.data]
        return transcriptions

    def __len__(self) -> int:
        return len(self.data)

    def _logar(self, wav_file_path : str, jpg_folder_output : str) -> None:
        """Render a log-frequency dB spectrogram of a wav file into ``jpg_folder_output``.

        The image is named after the wav file with a ``.jpg`` extension; the
        output folder is created if needed.
        """
        y, sr = librosa.load(wav_file_path)

        D = librosa.stft(y)
        # The number of STFT frames depends on the recording length.
        S = librosa.amplitude_to_db(librosa.magphase(D)[0], ref=np.max)
        file_name = os.path.splitext(os.path.basename(wav_file_path))[0]

        os.makedirs(jpg_folder_output, exist_ok=True)
        output_path = os.path.join(jpg_folder_output, f"{file_name}.jpg")

        fig = plt.figure(figsize=(7, 3))
        try:
            # Pass the sample rate actually used for loading so the axes match the audio.
            librosa.display.specshow(S, sr=sr, y_axis='log', x_axis='time')
            plt.axis('off')
            plt.savefig(output_path, bbox_inches = 'tight', pad_inches = 0)
        finally:
            # Always release the figure, even if rendering or saving fails.
            plt.close(fig)

In [9]:
class CollateFunctor:
    """
    Collator that pads token sequences and image widths so a batch can be stacked.

    Args:
        pad_idx: token id used to pad the shorter sequences.
        max_width: width (last dimension) every image is brought to; defaults
            to 1024, the value that used to be hard-coded. It is stored on
            ``self.max``.
    """

    def __init__(self, pad_idx : int, max_width : int = 1024) -> None:
        self.pad_idx = pad_idx
        self.max = max_width

    def __call__(self, samples) -> dict[str, Any]:
        """Collate dataset samples into one batch.

        Samples are ordered by sequence length, longest first. Images are
        zero-padded on the right of their last dimension up to ``self.max``;
        all other image dimensions must already agree for stacking to work.
        NOTE(review): an image wider than ``self.max`` yields a negative pad
        amount, which is expected to crop it -- confirm this is intended.

        Returns:
            ``{'img': stacked image tensor, 'seq': padded (batch, max_len)
            token tensor, 'wav': list of wav file names}``.
        """
        ordered = sorted(samples, key=lambda sample: len(sample['seq']), reverse=True)

        sequences = [sample['seq'] for sample in ordered]
        wavs = [sample['wav'] for sample in ordered]

        # Padding sequences
        padded_seq_tensors = pad_sequence(sequences,
                                          batch_first=True,
                                          padding_value=self.pad_idx)

        # Padding images along the width axis only
        img_tensors = [
            F.pad(sample['img'], (0, self.max - sample['img'].shape[-1]), 'constant', 0)
            for sample in ordered
        ]

        return {'img': torch.stack(img_tensors, dim=0),
                'seq': padded_seq_tensors,
                'wav': wavs}

In [None]:
# Dataset configuration
data_path = 'dataset.csv'
produce_spectr = False  # set to True to render spectrogram images from the wav files first
transform = transforms.Compose([transforms.ToTensor()])

# Build the dataset from the CSV and the folder of spectrogram images.
dataset = CustomSpectroDataset(
    csv_data=data_path,
    spectr_folder='log_rectangular_small',
    transform=transform,
    produce_spectr=produce_spectr,
)

In [18]:
# Train/test splits (seeded so the 90/10 split is the same on every run)
seed = 77
gen = torch.Generator().manual_seed(seed)
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [0.9, 0.1], generator=gen)


train_dataloader = DataLoader(train_dataset,
                              batch_size=16,
                              shuffle=True,
                              collate_fn=CollateFunctor(pad_idx=dataset.str2int['[PAD]']))

# The evaluation loader is NOT shuffled: evaluate() only looks at the first
# batch, so a fixed order shows the same samples after every epoch and makes
# the printed captions comparable across epochs and runs.
test_dataloader = DataLoader(test_dataset,
                             batch_size=16,
                             shuffle=False,
                             collate_fn=CollateFunctor(pad_idx=dataset.str2int['[PAD]']))

In [19]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Hyperparameters
# NOTE(review): dim_model, dropout_p and img_shape_small are not read by any
# cell visible in this notebook.
dim_model = 768
dropout_p = 0.1
vocab_size = dataset.get_vocab_size()
img_shape_small = (1, 128, 1024)

In [20]:
# Build the encoder/decoder model and move it to the selected device.
model = Img2SeqModel(
    embed_dim=300,
    encoder_dim=2048,
    decoder_dim=512,
    attention_dim=256,
    vocab_size=vocab_size,
)
model = model.to(device)

In [21]:
# Training schedule
epochs = 5

# Separate optimizers so the encoder and the decoder can use different learning rates.
enc_lr = 1e-4
dec_lr = 3e-4
enc_optim = torch.optim.AdamW(model.encoder.parameters(), lr=enc_lr)
dec_optim = torch.optim.AdamW(model.decoder.parameters(), lr=dec_lr)

# Padding positions do not contribute to the loss.
pad_token_id = dataset.str2int['[PAD]']
criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id)

In [22]:
def train(model, iterator, enc_optim, dec_optim, device, loss_fn=None):
    """Run one training epoch and return the mean batch loss.

    Args:
        model: callable as ``model(img, target_seq)`` returning
            ``(predictions, attention_weights)``.
        iterator: iterable of batches with 'img' and 'seq' tensors; must
            support ``len()``.
        enc_optim: optimizer for the encoder parameters.
        dec_optim: optimizer for the decoder parameters.
        device: device the batch tensors are moved to.
        loss_fn: loss taking (N, vocab) logits and (N,) targets. Defaults to
            the notebook-level ``criterion``.

    Returns:
        Sum of the per-batch losses divided by the number of batches.
    """
    if loss_fn is None:
        loss_fn = criterion

    model.train()
    total_loss = 0.0

    for batch in tqdm(iterator):
        img = batch['img'].to(device)
        target_seq = batch['seq'].to(device)

        # Gradients accumulate by default; clear them so each step only uses
        # the gradients of the current batch.
        enc_optim.zero_grad()
        dec_optim.zero_grad()

        outputs, _ = model(img, target_seq)
        # The model predicts token s+1 from token s, so drop the leading token.
        targets = target_seq[:, 1:]

        # Take the vocabulary size from the logits instead of a global.
        loss = loss_fn(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))
        loss.backward()

        enc_optim.step()
        dec_optim.step()

        total_loss += loss.item()

    return total_loss / len(iterator)

In [23]:
# Maximum number of samples from the first batch that evaluate() prints.
evaluations = 10

@torch.no_grad()
def evaluate(model, iterator, device):
    """Print generated vs. target captions for samples of the first batch.

    At most ``evaluations`` samples are shown; a smaller first batch is handled
    by showing only the samples it contains. Token ids are mapped to strings
    with the notebook-level ``dataset`` vocabulary.

    Args:
        model: model exposing ``encoder`` and ``decoder.generate_caption``.
        iterator: iterable of batches with 'img' and 'seq' tensors.
        device: device the image is moved to before encoding.
    """
    model.eval()
    batch = next(iter(iterator))

    # Never index past the end of the batch.
    num_samples = min(evaluations, len(batch['img']))

    for sample_idx in range(num_samples):
        img = batch['img'][sample_idx].unsqueeze(0).to(device)
        target_seq = batch['seq'][sample_idx]

        features = model.encoder(img)
        caps, _ = model.decoder.generate_caption(features, int2str=dataset.int2str, str2int=dataset.str2int, max_len=16)
        caption = ' '.join(caps)
        target = [dataset.int2str[token_id] for token_id in target_seq.detach().cpu().tolist()]

        print(f'Generated caption: {caption}')
        print(f'Target caption: {" ".join(t for t in target)}', end='\n-----------------------\n\n')

In [None]:
# Train for the configured number of epochs, printing sample captions after each one.
for epoch in range(epochs):
    epoch_loss = train(
        model,
        train_dataloader,
        enc_optim,
        dec_optim,
        device,
    )
    progress_line = f'Epoch: {epoch + 1}, Loss: {epoch_loss}'
    print(progress_line, end='\n\n')
    evaluate(model, test_dataloader, device)