In [1]:
import torch
import pandas as pd
import os 
import librosa
import librosa.display
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import math
import einops

from torchvision import transforms
from torch import nn
from tqdm import tqdm
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.nn import functional as F
from typing import List, Any, Union, Tuple
from transformers import AutoTokenizer

# Debugging aid: request synchronous CUDA kernel launches so device-side errors
# surface at the call that caused them.
# NOTE(review): this is expected to slow GPU training; consider removing it once
# the model runs cleanly.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# Spectrogram figures below are only written to disk (savefig), never shown:
# turn off interactive mode and switch to the non-GUI 'agg' backend.
plt.ioff()
matplotlib.use('agg')

In [2]:
class PositionalEncoding(nn.Module):
    """
    Fixed sinusoidal positional encoding from 'Attention Is All You Need'.

    Adds a (non-learned) encoding to a batch-first sequence of embeddings and
    then applies dropout:

        PE(pos, 2i)     = sin(pos / 10000^(2i / dim_model))
        PE(pos, 2i + 1) = cos(pos / 10000^(2i / dim_model))

    Args:
        dim_model: embedding size of each token (odd sizes are supported).
        dropout_p: dropout probability applied after adding the encoding.
        max_len: longest sequence length the precomputed table covers.
    """
    def __init__(self, dim_model : int, dropout_p : float, max_len : int) -> None:
        super().__init__()
        # Modified version from: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
        self.dropout = nn.Dropout(dropout_p)

        positions_list = torch.arange(0, max_len, dtype=torch.float).view(-1, 1)  # (max_len, 1): 0, 1, 2, ...
        # 1 / 10000^(2i / dim_model), computed in log space.
        division_term = torch.exp(torch.arange(0, dim_model, 2).float() * (-math.log(10000.0)) / dim_model)

        pos_encoding = torch.zeros(max_len, dim_model)
        # PE(pos, 2i) = sin(pos / 10000^(2i / dim_model))
        pos_encoding[:, 0::2] = torch.sin(positions_list * division_term)
        # PE(pos, 2i + 1) = cos(pos / 10000^(2i / dim_model)).
        # For odd dim_model there is one fewer odd column than even column.
        pos_encoding[:, 1::2] = torch.cos(positions_list * division_term[: dim_model // 2])

        # Buffer: moves with .to(device) and is stored in state_dict, but is not trained.
        self.register_buffer("pos_encoding", pos_encoding.unsqueeze(0))  # (1, max_len, dim_model)

    def forward(self, token_embedding: torch.Tensor) -> torch.Tensor:
        """
        Add the positional encoding to ``token_embedding`` and apply dropout.

        Args:
            token_embedding: batch-first embeddings of shape (batch, seq_len, dim_model).

        Returns:
            Tensor of the same shape as ``token_embedding``.

        Raises:
            ValueError: if seq_len exceeds the ``max_len`` given at construction.
        """
        seq_len = token_embedding.size(1)
        max_len = self.pos_encoding.size(1)
        if seq_len > max_len:
            raise ValueError(f'Sequence length {seq_len} exceeds positional encoding max_len {max_len}')
        return self.dropout(token_embedding + self.pos_encoding[:, :seq_len, :])

In [3]:
class InputEmbeddingRecPatches(nn.Module):
    """
    Embeds an image batch as a sequence of (possibly rectangular) patches.

    The image is cut into non-overlapping patches, each patch is flattened and
    linearly projected to ``latent_size``, and a learned per-patch positional
    embedding (shared across the batch) is added.

    Args:
        patch_size: patch (height, width), or a single int for square patches.
        latent_size: size of each output patch embedding.
        img_shape: (channels, height, width) of the input images; height and
            width must be divisible by the patch height and width.
        device: device the input batch is moved to in ``forward``.
    """
    def __init__(self,
                 patch_size : Union[int, Tuple[int, int]],
                 latent_size : int,
                 img_shape: Tuple[int, int, int],
                 device : str) -> None:

        super().__init__()

        c, h, w = img_shape
        self.patch_size = (patch_size, patch_size) if isinstance(patch_size, int) else tuple(patch_size)
        ph, pw = self.patch_size
        # Check against the normalised tuple (indexing the raw int argument would fail).
        assert h % ph == 0 and w % pw == 0,\
            'Image dim. must be divisible by the patch size'

        self.device = device

        self.latent_size = latent_size
        self.num_patches = (h // ph) * (w // pw)
        self.input_size = ph * pw * c  # Flattened size of each patch
        self.linearProjection = nn.Linear(self.input_size, self.latent_size)

        # Learned positional embedding, one vector per patch position. Registered
        # as a parameter so it is trained, saved in state_dict and moved by .to().
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches, self.latent_size))

    def forward(self, input_data : torch.Tensor, batch_size : Union[int, None] = None) -> torch.Tensor:
        """
        Patchify, project and add positional embeddings.

        Args:
            input_data: image batch of shape (B, C, H, W) matching ``img_shape``.
            batch_size: unused; kept for backward compatibility (B is read from
                ``input_data``).

        Returns:
            Tensor of shape (B, num_patches, latent_size).
        """
        input_data = input_data.to(self.device)
        b, c, h, w = input_data.shape
        ph, pw = self.patch_size
        nh, nw = h // ph, w // pw

        # Same layout as einops 'b c (h h1) (w w1) -> b (h w) (h1 w1 c)':
        # patches in row-major order, each flattened as (patch row, patch col, channel).
        patches = (input_data
                   .reshape(b, c, nh, ph, nw, pw)
                   .permute(0, 2, 4, 3, 5, 1)
                   .reshape(b, nh * nw, ph * pw * c))

        # pos_embedding is (1, num_patches, latent) and broadcasts over the batch.
        return self.linearProjection(patches) + self.pos_embedding

In [25]:
class Img2SeqModel(nn.Module):
    """
    Encoder-decoder transformer mapping an image to a token sequence.

    The image is split into patches by ``InputEmbeddingRecPatches`` and encoded
    by a transformer encoder. A transformer decoder reads the embedded,
    positionally-encoded target tokens while attending to the encoder output,
    and a final linear layer produces per-position vocabulary logits.

    Args:
        device: device handed to the patch-embedding module.
        patch_size: patch (height, width), or an int for square patches.
        dim_model: model / embedding width.
        encoder_layers: number of stacked encoder layers.
        decoder_layers: number of stacked decoder layers.
        heads: attention heads per layer.
        mlp_dim: hidden size of each layer's feed-forward block.
        dropout_p: dropout probability.
        vocab_size: number of output classes (token ids).
        max_seq_len: longest target sequence the positional encoding supports.
        img_shape: (channels, height, width) of the input images.
    """
    def __init__(self,
                 device : str,
                 patch_size : Union[int, tuple[int]],
                 dim_model : int,
                 encoder_layers : int,
                 decoder_layers : int,
                 heads : int,
                 mlp_dim : int,
                 dropout_p: float,
                 vocab_size : int,
                 max_seq_len : int,
                 img_shape : Tuple[int]) -> None:

        super().__init__()
        self.dim_model = dim_model

        # Positional encoding for the target token sequence
        self.pos_enc = PositionalEncoding(
            dim_model=dim_model,
            dropout_p=dropout_p,
            max_len=max_seq_len
        )

        # Patch embedding for the source image
        self.src_embedding = InputEmbeddingRecPatches(
            patch_size=patch_size,
            latent_size=dim_model,
            img_shape=img_shape,
            device=device
        )

        # Token embedding table
        self.embedding = nn.Embedding(vocab_size, dim_model)

        # The template layers are kept as locals: nn.TransformerEncoder/Decoder
        # clone them num_layers times, so registering the templates themselves
        # would add parameters that never receive gradients.
        encoder_layer = nn.TransformerEncoderLayer(d_model=dim_model,
                                                   nhead=heads,
                                                   dim_feedforward=mlp_dim,
                                                   dropout=dropout_p,
                                                   activation='gelu',
                                                   batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer=encoder_layer,
                                             num_layers=encoder_layers)

        decoder_layer = nn.TransformerDecoderLayer(d_model=dim_model,
                                                   nhead=heads,
                                                   dim_feedforward=mlp_dim,
                                                   dropout=dropout_p,
                                                   batch_first=True)
        self.decoder = nn.TransformerDecoder(decoder_layer=decoder_layer,
                                             num_layers=decoder_layers)

        # Linear projection to vocabulary logits
        self.out = nn.Linear(dim_model, vocab_size)


    def forward(self, src : torch.Tensor,
                tgt : torch.Tensor,
                tgt_mask : torch.Tensor=None,
                tgt_pad_mask : torch.Tensor=None) -> torch.Tensor:
        """
        Compute vocabulary logits for every target position.

        Args:
            src: image batch, presumably (B, C, H, W) matching ``img_shape`` — TODO confirm.
            tgt: (B, T) target token ids (the layers are batch_first).
            tgt_mask: optional (T, T) causal mask, e.g. from ``get_tgt_mask``.
            tgt_pad_mask: optional (B, T) bool mask, True at padding positions,
                e.g. from ``create_pad_mask``.

        Returns:
            (B, T, vocab_size) logits.
        """
        src = self.src_embedding(src, batch_size=src.size(0))  # Embed image patches
        memory = self.encoder(src)  # Encoder output

        # Scale token embeddings by sqrt(d_model) before adding positions
        tgt = self.embedding(tgt) * math.sqrt(self.dim_model)
        tgt = self.pos_enc(tgt)

        decoder_out = self.decoder(tgt,
                                   memory,
                                   tgt_mask=tgt_mask,
                                   tgt_key_padding_mask=tgt_pad_mask)

        return self.out(decoder_out)


    def get_tgt_mask(self, size : int, device : Union[str, torch.device, None] = None) -> torch.Tensor:
        """
        Create an additive causal mask that stops each position attending to
        later positions: 0.0 on and below the diagonal, -inf above it.

        Args:
            size: target sequence length T.
            device: optional device for the (T, T) mask; defaults to CPU.
        """
        return torch.triu(torch.full((size, size), float('-inf'), device=device), diagonal=1)


    def create_pad_mask(self, matrix: torch.Tensor, pad_token: int) -> torch.Tensor:
        """
        Create a bool mask that is True wherever ``matrix`` holds ``pad_token``.

        E.g. matrix = [1, 2, 3, 0, 0, 0] with pad_token = 0 gives
        [False, False, False, True, True, True].
        """
        return matrix == pad_token

In [5]:
class CustomSpectroDataset(Dataset):
    """
    Dataset pairing spectrogram images with their transcriptions.

    Reads a CSV with 'transcription' and 'channel_1' (WAV file name) columns.
    It optionally renders a spectrogram JPEG for every WAV file, then loads
    the JPEGs that exist in ``spectr_folder``. Finally it retrains a sub-word
    tokenizer on the loaded transcriptions and adds [BOS]/[EOS] special tokens.

    Args:
        csv_data: path to the CSV file.
        recordings_folder: folder containing the WAV files.
        spectr_folder: folder spectrogram JPEGs are written to and read from.
        spectr_format: style used when ``produce_spectr`` is True:
            'log' (log-frequency axis), 'bw' (grayscale, Hz axis) or
            'df' (default colormap, Hz axis).
        produce_spectr: render spectrograms from the WAV files before loading.
        transform: callable applied to each PIL image; when None,
            ``transforms.ToTensor()`` is used.
        base_tokenizer: Hugging Face tokenizer whose algorithm is retrained.
        tokenizer_vocab_size: vocabulary size of the retrained tokenizer.

    Raises:
        ValueError: if ``produce_spectr`` is True and ``spectr_format`` is unknown.
    """

    # spectr_format value -> name of the method rendering one WAV file as a JPEG.
    _SPECTR_WRITERS = {'log': '_logar', 'bw': '_black_n_white', 'df': '_default'}

    def __init__(self,
                 csv_data : str,
                 recordings_folder : str='wav_recordings',
                 spectr_folder : str='output_logarithmic',
                 spectr_format : str='log',
                 produce_spectr : bool=False,
                 transform : Any=None,
                 base_tokenizer : str='ltg/norbert2',
                 tokenizer_vocab_size : int=1200):
        # Validate before doing any slow I/O.
        if produce_spectr and spectr_format not in self._SPECTR_WRITERS:
            raise ValueError(f"Unknown spectr_format {spectr_format!r}; "
                             f"expected one of {sorted(self._SPECTR_WRITERS)}")

        self.transform = transform
        # Default output folder for _black_n_white/_default.
        self.spectr_folder = spectr_folder
        self.df = pd.read_csv(csv_data)
        self.transcriptions, self.wav_filenames = self.df['transcription'].tolist(), self.df['channel_1'].tolist()
        assert len(self.transcriptions) == len(self.wav_filenames)

        # Optionally generate images from the WAV files (not needed if the images
        # already exist). WAV files that are missing on disk are skipped.
        if produce_spectr:
            writer = getattr(self, self._SPECTR_WRITERS[spectr_format])
            for name in tqdm(self.wav_filenames):
                wav_path = os.path.join(recordings_folder, name)
                if os.path.exists(wav_path):
                    writer(wav_path, spectr_folder)

        # Always produce a tensor, so every sample can be stacked by the collator.
        to_tensor = self.transform if self.transform is not None else transforms.ToTensor()

        # Load images and pair them with their transcriptions. Rows without an
        # image on disk are skipped (e.g. WAV files not processed yet).
        self.data = []
        for transcription, wav_name in zip(self.transcriptions, self.wav_filenames):
            img_path = self._image_path(spectr_folder, wav_name)
            if os.path.exists(img_path):
                with Image.open(img_path) as img_file:
                    img_tensor = to_tensor(img_file)
                self.data.append({'img': img_tensor, 'transcription': transcription})

        # Retrain the base tokenizer's algorithm on the loaded transcriptions.
        old_tokenizer = AutoTokenizer.from_pretrained(base_tokenizer)
        self.tokenizer = old_tokenizer.train_new_from_iterator(self._get_training_corpus(),
                                                               vocab_size=tokenizer_vocab_size)
        self.tokenizer.add_special_tokens({'additional_special_tokens': ['[BOS]', '[EOS]']})

        # Cache special-token ids so __getitem__ does not go through tokenizer.vocab
        # (which may build the full vocab dict on each access) for every sample.
        self.bos_id = self.tokenizer.convert_tokens_to_ids('[BOS]')
        self.eos_id = self.tokenizer.convert_tokens_to_ids('[EOS]')


    def longest_seq(self) -> int:
        """Token length of the longest loaded transcription, + 2 for [BOS] and [EOS]."""
        return max(len(self.tokenizer.tokenize(sample['transcription'])) for sample in self.data) + 2


    def get_vocab_size(self) -> int:
        """Number of entries in the tokenizer vocabulary."""
        return len(self.tokenizer.vocab)


    def get_all_transcriptions(self) -> List[str]:
        """Transcriptions of all loaded samples, in dataset order."""
        return [sample['transcription'] for sample in self.data]


    def __len__(self) -> int:
        return len(self.data)


    def __getitem__(self, index : int) -> dict[str, torch.Tensor]:
        """Return {'img': image tensor, 'seq': int64 ids as [BOS] + tokens + [EOS]}."""
        sample = self.data[index]
        token_ids = [self.bos_id,
                     *self.tokenizer.encode(sample['transcription'], add_special_tokens=False),
                     self.eos_id]
        return {'img': sample['img'],
                'seq': torch.tensor(token_ids)}


    def _get_training_corpus(self, batch_size : int = 1000):
        """Yield the transcriptions in lists of at most ``batch_size`` strings."""
        data = self.get_all_transcriptions()
        for start_idx in range(0, len(data), batch_size):
            yield data[start_idx : start_idx + batch_size]


    @staticmethod
    def _image_path(folder : str, wav_name : str) -> str:
        """
        Path of the spectrogram JPEG for ``wav_name``: <folder>/<stem>.jpg.

        Uses the base name without its last extension, so names containing
        extra dots (e.g. 'rec.v2.wav') work. This also matches where the
        spectrogram writers save their images.
        """
        stem = os.path.splitext(os.path.basename(wav_name))[0]
        return os.path.join(folder, f"{stem}.jpg")


    @staticmethod
    def _magnitude_db(wav_file_path : str) -> np.ndarray:
        """Load a WAV file and return its STFT magnitude in dB relative to the maximum."""
        y, _ = librosa.load(wav_file_path)
        D = librosa.stft(y)
        return librosa.amplitude_to_db(np.abs(D), ref=np.max)


    @staticmethod
    def _save_spectrogram(S_db : np.ndarray, output_path : str, figsize : Tuple[int, int], **specshow_kwargs) -> None:
        """Render ``S_db`` with librosa.display.specshow (axes hidden) and save it to ``output_path``."""
        output_folder = os.path.dirname(output_path)
        if output_folder:
            os.makedirs(output_folder, exist_ok=True)
        fig, ax = plt.subplots(figsize=figsize)
        try:
            librosa.display.specshow(S_db, ax=ax, **specshow_kwargs)
            ax.axis('off')
            fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
        finally:
            # Release the figure; otherwise one open figure accumulates per file.
            plt.close(fig)


    def _logar(self, wav_file_path : str, jpg_folder_output : str) -> None:
        """Save a log-frequency-axis spectrogram of ``wav_file_path`` into ``jpg_folder_output``."""
        self._save_spectrogram(self._magnitude_db(wav_file_path),
                               self._image_path(jpg_folder_output, wav_file_path),
                               figsize=(7, 3), y_axis='log', x_axis='time')


    def _black_n_white(self, file_path : str, jpg_folder_output : Union[str, None] = None) -> None:
        """Save a grayscale, Hz-axis spectrogram; the folder defaults to ``self.spectr_folder``."""
        folder = self.spectr_folder if jpg_folder_output is None else jpg_folder_output
        self._save_spectrogram(self._magnitude_db(file_path),
                               self._image_path(folder, file_path),
                               figsize=(10, 5), cmap='gray', y_axis='hz', x_axis='time')


    def _default(self, file_path : str, jpg_folder_output : Union[str, None] = None) -> None:
        """Save a Hz-axis spectrogram with the default colormap; the folder defaults to ``self.spectr_folder``."""
        folder = self.spectr_folder if jpg_folder_output is None else jpg_folder_output
        self._save_spectrogram(self._magnitude_db(file_path),
                               self._image_path(folder, file_path),
                               figsize=(10, 5), y_axis='hz', x_axis='time')

In [50]:
class CollateFunctor:
    """
    Batch collator for the spectrogram dataset.

    Samples are ordered by descending token-sequence length (images follow the
    same order). Images are stacked along a new batch dimension and sequences
    are right-padded with ``pad_idx``.
    """

    def __init__(self, pad_idx : int) -> None:
        self.pad_idx = pad_idx

    def __call__(self, samples) -> dict[str, torch.tensor]:
        ordered = sorted(samples, key=lambda item: len(item['seq']), reverse=True)
        seqs = [item['seq'] for item in ordered]
        imgs = [item['img'] for item in ordered]

        padded = pad_sequence(seqs, batch_first=True, padding_value=self.pad_idx)
        stacked = torch.stack(imgs, dim=0)
        return {'img': stacked, 'seq': padded}

In [8]:
# Dataset configuration
data_path = 'dataset.csv'
produce_spectr = False

# Resize every spectrogram to one fixed (H, W). The model's img_shape must
# equal this size, and its patch size must divide it.
transform = transforms.Compose([
    transforms.Resize((231, 540)),
    transforms.ToTensor(),
])

dataset = CustomSpectroDataset(
    csv_data=data_path,
    spectr_folder='log_rectangular',
    produce_spectr=produce_spectr,
    transform=transform,
)

In [51]:
# Train/test splits (seeded so the split is reproducible)
seed = 7
batch_size = 32
gen = torch.Generator().manual_seed(seed)
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [0.9, 0.1], generator=gen)

# One stateless collator for both loaders: pads 'seq' with the tokenizer's pad id.
collate_fn = CollateFunctor(pad_idx=dataset.tokenizer.pad_token_id)

# Seeded shuffling makes the training batch order reproducible as well.
train_dataloader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              generator=torch.Generator().manual_seed(seed),
                              collate_fn=collate_fn)

# Evaluation does not need shuffling; a fixed order keeps test results reproducible.
test_dataloader = DataLoader(test_dataset,
                             batch_size=batch_size,
                             shuffle=False,
                             collate_fn=collate_fn)

In [58]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Model hyperparameters
patch_size = (231, 54)  # (height, width) of each image patch
dim_model = 240         # must be divisible by `heads` (nn.MultiheadAttention requirement)
encoder_layers = 4
decoder_layers = 4
heads = 6
mlp_dim = 1024
dropout_p = 0.1

vocab_size = dataset.get_vocab_size()
max_seq_len = dataset.longest_seq()

# Take the image shape from the data instead of repeating the Resize size
# from the transform cell, so the two cannot silently drift apart.
img_shape = tuple(dataset[0]['img'].shape)  # (C, H, W)
n_channels = img_shape[0]

In [59]:
model_kwargs = dict(
    device=device,
    patch_size=patch_size,
    dim_model=dim_model,
    encoder_layers=encoder_layers,
    decoder_layers=decoder_layers,
    heads=heads,
    mlp_dim=mlp_dim,
    dropout_p=dropout_p,
    vocab_size=vocab_size,
    max_seq_len=max_seq_len,
    img_shape=img_shape,
)

# nn.Module.to() returns the module itself, so construction and device move chain.
model = Img2SeqModel(**model_kwargs).to(device)

In [60]:
# Training hyperparameters: number of passes over the data and the AdamW learning rate
epochs = 3
init_lr = 1e-5

optimizer = torch.optim.AdamW(model.parameters(), lr=init_lr)

In [61]:
@torch.no_grad()
def generate(model, img, device, iterations=10, tokenizer=None):
    """
    Greedily decode text for a single spectrogram, print it and return it.

    Decoding starts from [BOS] and appends the arg-max token at each step. It
    stops after ``iterations`` steps or once [EOS] is produced. The same causal
    target mask as in training is applied. The model's train/eval mode is
    restored before returning, so calling this mid-training is safe.

    Args:
        model: Img2SeqModel-like module with a ``get_tgt_mask(size)`` method.
        img: image batch holding one sample, presumably (1, C, H, W) — TODO confirm.
        device: device the token sequence is built on.
        iterations: maximum number of tokens to generate.
        tokenizer: tokenizer with [BOS]/[EOS] tokens and ``decode``; defaults to
            the global ``dataset.tokenizer``.

    Returns:
        The decoded string, including the special tokens.
    """
    if tokenizer is None:
        tokenizer = dataset.tokenizer
    bos_id = tokenizer.convert_tokens_to_ids('[BOS]')
    eos_id = tokenizer.convert_tokens_to_ids('[EOS]')

    was_training = model.training
    model.eval()
    try:
        seq_inp = torch.tensor([[bos_id]], dtype=torch.long, device=device)
        for _ in range(iterations):
            # Causal mask so decoding matches how the model was trained.
            tgt_mask = model.get_tgt_mask(seq_inp.size(1)).to(device)
            logits = model(img, seq_inp, tgt_mask=tgt_mask)
            # Greedy choice: argmax of the logits equals argmax of the softmax.
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            seq_inp = torch.cat((seq_inp, next_token), dim=1)
            if next_token.item() == eos_id:
                break
    finally:
        model.train(was_training)

    text = tokenizer.decode(seq_inp.squeeze(0).cpu())
    print(text)
    return text

In [62]:
def train(model, iterator, optimizer, device, pad_token_id=None, eval_batches=(50, 100, 150, 200, 250)):
    """
    Run one teacher-forced training epoch and return the mean batch loss.

    The decoder is fed ``seq[:, :-1]`` and trained to predict ``seq[:, 1:]``,
    so each position learns the *next* token. Predicting the token at the
    same position would be trivial copying, because the causal mask lets each
    position see itself.

    Args:
        model: Img2SeqModel-like module (forward, get_tgt_mask, create_pad_mask).
        iterator: DataLoader yielding dicts with 'img' and padded 'seq' tensors.
        optimizer: optimizer over the model's parameters.
        device: device to move each batch to.
        pad_token_id: padding id, ignored in the loss and masked in the decoder;
            defaults to the global ``dataset.tokenizer.pad_token_id``.
        eval_batches: 1-based batch numbers after which a sample is decoded
            with ``generate`` and printed.

    Returns:
        Mean loss over the processed batches (0.0 if there were none).
    """
    if pad_token_id is None:
        pad_token_id = dataset.tokenizer.pad_token_id

    model.train()
    total_loss = 0.0
    num_batches = 0
    for batch_num, batch in enumerate(tqdm(iterator), 1):
        img = batch['img'].to(device)
        target_seq = batch['seq'].to(device)

        # Shift by one: the decoder input at step t must predict token t + 1.
        decoder_input = target_seq[:, :-1]
        labels = target_seq[:, 1:]

        tgt_attention_mask = model.get_tgt_mask(decoder_input.size(1)).to(device)
        tgt_pad_mask = model.create_pad_mask(decoder_input, pad_token_id)

        logits = model(img, decoder_input, tgt_mask=tgt_attention_mask, tgt_pad_mask=tgt_pad_mask)
        # reshape (not view): the shifted slices are not contiguous.
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                               labels.reshape(-1),
                               ignore_index=pad_token_id)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

        # Decode a single spectrogram from this batch as a progress check.
        if batch_num in eval_batches:
            generate(model, img[0].unsqueeze(0), device)
            # generate() may switch the model to eval mode; re-enable dropout.
            model.train()

    return total_loss / max(num_batches, 1)

In [None]:
for epoch in range(epochs):
    epoch_header = f'Epoch: {epoch + 1}'
    print(epoch_header, end='\n\n')
    epoch_loss = train(model, train_dataloader, optimizer, device)
    # Blank line after the progress bar, then the epoch's mean loss.
    print(f'\nLoss: {epoch_loss}')