In [16]:
import torch
import pandas as pd
import os 
import librosa
import librosa.display
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import math
import einops

from torchvision import transforms
from torch import nn
from tqdm import tqdm
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.nn import functional as F
from typing import List, Any

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
plt.ioff()
matplotlib.use('agg')


In [17]:
class PositionalEncoding(nn.Module):
    """
    Computes positional encoding as given in paper 'Attention is all you need'.

    """
    def __init__(self, dim_model : int, dropout_p : float, max_len : int) -> torch.Tensor:
        super().__init__()
        # Modified version from: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
        # max_len determines how far the position can have an effect on a token (window)
        
        # Dropout
        self.dropout = nn.Dropout(dropout_p)
        
        # Encoding - From formula
        pos_encoding = torch.zeros(max_len, dim_model)
        positions_list = torch.arange(0, max_len, dtype=torch.float).view(-1, 1) # 0, 1, 2, 3, 4, 5
        division_term = torch.exp(torch.arange(0, dim_model, 2).float() * (-math.log(10000.0)) / dim_model) # 1000^(2i/dim_model)
        
        # PE(pos, 2i) = sin(pos/1000^(2i/dim_model))
        pos_encoding[:, 0::2] = torch.sin(positions_list * division_term)
        
        # PE(pos, 2i + 1) = cos(pos/1000^(2i/dim_model))
        pos_encoding[:, 1::2] = torch.cos(positions_list * division_term)
        
        # Saving buffer (same as parameter without gradients needed)
        pos_encoding = pos_encoding.unsqueeze(0)
        
        self.register_buffer("pos_encoding",pos_encoding)

    def forward(self, token_embedding: torch.tensor) -> torch.tensor:
        # Residual connection + pos encoding
        return self.dropout(token_embedding + self.pos_encoding[:, :token_embedding.size(1), :])
        

In [18]:
class InputEmbedding(nn.Module):
    """
    Computes embeddings for each patch in the image. 
    
    """
    def __init__(self, patch_size : int, n_channels : int, device : str, latent_size : int) -> torch.Tensor:
        super(InputEmbedding, self).__init__()
        self.patch_size = patch_size
        self.latent_size = latent_size
        self.n_channels = n_channels
        self.device = device
        # Embedding size of each spatch
        self.input_size = self.patch_size * self.patch_size * self.n_channels

        self.linearProjection = nn.Linear(self.input_size, self.latent_size)


    def forward(self, input_data : torch.Tensor, batch_size : int) -> torch.Tensor:
        input_data = input_data.to(self.device)
        pos_embedding = nn.Parameter(torch.randn(batch_size, 1, self.latent_size)).to(self.device)  # (B, 1, C)

        # Patchify
        patches = einops.rearrange(
            input_data, 'b c (h h1) (w w1) -> b (h w) (h1 w1 c)', h1=self.patch_size, w1=self.patch_size
        )

        linearProjection = self.linearProjection(patches).to(self.device)
        b, n, _ = linearProjection.shape

        # Extend pos. embeddings to match number of patches
        pos_embed = einops.repeat(pos_embedding, 'b 1 d -> b m d', m=n)
        linearProjection += pos_embed
        return linearProjection

In [19]:
class Img2SeqModel(nn.Module):
    def __init__(self,
                 device : str,
                 patch_size : int,
                 dim_model : int,
                 encoder_layers : int,
                 decoder_layers : int,
                 heads : int,
                 mlp_dim : int,
                 dropout_p: float,
                 vocab_size : int,
                 n_channels : int,
                 max_seq_len : int) -> None:
        
        super().__init__()
        self.dim_model = dim_model

        # Init target seq. positional encoding instance
        self.pos_enc = PositionalEncoding(
            dim_model=dim_model,
            dropout_p=dropout_p,
            max_len=max_seq_len
        )

        # Init iamge encoding instance
        self.src_embedding = InputEmbedding(patch_size=patch_size,
                                            n_channels=n_channels,
                                            device=device,
                                            latent_size=dim_model)

        # Init charachter embeddings table
        self.embedding = nn.Embedding(vocab_size, dim_model)

        # Instance of encoder layer
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=dim_model,
                                                        nhead=heads,
                                                        dim_feedforward=mlp_dim,
                                                        dropout=dropout_p,
                                                        activation='gelu',
                                                        batch_first=True)
        
        # Whole encoder made of encoder layers
        self.encoder = nn.TransformerEncoder(encoder_layer=self.encoder_layer,
                                             num_layers=encoder_layers)
        
        # Instance of decoder layer
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=dim_model,
                                                        nhead=heads,
                                                        dim_feedforward=mlp_dim,
                                                        dropout=dropout_p,
                                                        batch_first=True)
        
        # Whole decoder made of decoder layers
        self.decoder = nn.TransformerDecoder(decoder_layer=self.decoder_layer,
                                             num_layers=decoder_layers)
        
        # Linear proj. to vocab. size
        self.out = nn.Linear(dim_model, vocab_size)


    def forward(self, src : torch.tensor, 
                tgt : torch.tensor, 
                tgt_mask : torch.tensor=None, 
                tgt_pad_mask : torch.tensor=None) -> torch.Tensor:
        
        src = self.src_embedding(src, batch_size=src.size(0))  # Embed img.
        memory = self.encoder(src)  # Encoder output
        
        # Embed transcriptions
        tgt = self.embedding(tgt) * math.sqrt(self.dim_model) 
        tgt = self.pos_enc(tgt)

        # Out. features of decoder
        decoder_out = self.decoder(tgt,
                                   memory,
                                   tgt_mask=tgt_mask,
                                   tgt_key_padding_mask=tgt_pad_mask)
        
        # Linear projection
        out = self.out(decoder_out)
        return out
    

    def get_tgt_mask(self, size : int) -> torch.tensor:
        """
        Creates mask to prevent decoder sequences looking
        at future characters.
        """

        mask = torch.tril(torch.ones(size, size) == 1) # Lower triangular matrix
        mask = mask.float()
        mask = mask.masked_fill(mask == 0, float('-inf')) # Convert zeros to -inf
        mask = mask.masked_fill(mask == 1, float(0.0)) # Convert ones to 0

        return mask
    
    
    def create_pad_mask(self, matrix: torch.tensor, pad_token: int) -> torch.tensor:
        """
        Creates mask that masks padding tokens in decoder. 
        """
        # If matrix = [1,2,3,0,0,0] where pad_token=0, the result mask is
        # [False, False, False, True, True, True]
        return (matrix == pad_token)

In [26]:
class CustomSpectroDataset(Dataset):
    def __init__(self, 
                 csv_data : str,
                 recordings_folder : str='wav_recordings', 
                 spectr_folder : str='output_logarithmic',
                 spectr_format : str='log',
                 produce_spectr : bool=False, 
                 transform : Any=None):
        
        self.transform = transform
        self.df = pd.read_csv(csv_data)
        self.transcriptions, self.wav_filenames = self.df['transcription'].tolist(), self.df['channel_1'].tolist()
        self.spectr_folder = spectr_folder
        assert len(self.transcriptions) == len(self.wav_filenames)
        
        # If we want to generate images from WAV-files (not needed if images are already present)
        if produce_spectr:
            if spectr_format == 'log':
                for name in tqdm(self.wav_filenames):
                    if os.path.exists(os.path.join(recordings_folder, name)):
                        self._logar(os.path.join(recordings_folder, name), spectr_folder)
            
            elif spectr_format == 'bw':
                for name in tqdm(self.wav_filenames):
                    if os.path.exists(os.path.join(recordings_folder, name)):
                        self._black_n_white(os.path.join(recordings_folder, name), spectr_folder)

            elif spectr_format == 'df':
                for name in tqdm(self.wav_filenames):
                    if os.path.exists(os.path.join(recordings_folder, name)):
                        self._default(os.path.join(recordings_folder, name), spectr_folder)


        # Acquiring vocab and adding beginning-, end-of-sequence and padding tokens
        self.all_text = "".join(sample for sample in self.transcriptions)
        self.vocab = list(set(self.all_text))
        self.vocab.append('<BOS>')
        self.vocab.append('<EOS>')
        self.vocab.append('<PAD>')

        # Mapping vocab to ints and other way around 
        self.char_to_int = {char:i for i, char in enumerate(self.vocab)}
        self.int_to_char = {i:char for char, i in self.char_to_int.items()}

    def longest_seq(self) -> int:
        # + 2 to count for beginning and end
        return max([len(sample['transcription']) for sample in self.data]) + 2
    

    def get_vocab(self) -> List[str]:
        return self.vocab
    

    def __len__(self) -> int:
        return len(self.data)


    def __getitem__(self, index : int) -> dict[str, torch.tensor]:
        img_name, _ = self.wav_filenames[index].split('.')
        transcription = self.transcriptions[index]
        
        if os.path.exists(os.path.join(self.spectr_folder, img_name + '.jpg')):
            img = Image.open(os.path.join(self.spectr_folder, fname + '.jpg'))
            if self.transform:
                img = self.transform(img)

        # Turning str seq into ints, adding start- and end tokens
        transformed_seq = [self.char_to_int['<BOS>']]
        for ch in seq:
            transformed_seq.append(self.char_to_int[ch]) 
        transformed_seq.append(self.char_to_int['<EOS>'])

        return {'img': img, 
                'seq': torch.tensor(transformed_seq)}
        

    def _logar(self, wav_file_path : str, jpg_folder_output : str) -> None:
        y, sr = librosa.load(wav_file_path)

        D = librosa.stft(y)
        S = librosa.amplitude_to_db(librosa.magphase(D)[0], ref=np.max)

        file_name = os.path.splitext(os.path.basename(wav_file_path))[0]

        output_folder = jpg_folder_output

        os.makedirs(output_folder, exist_ok=True)

        output_path = os.path.join(output_folder, f"{file_name}.jpg")

        plt.figure(figsize=(8, 8))
        librosa.display.specshow(S, y_axis='log', x_axis='time')

        plt.axis('off')
        plt.savefig(output_path, bbox_inches = 'tight', pad_inches = 0)
        plt.close()


    def _black_n_white(self, wav_file_path : str, jpg_folder_output : str) -> None:
        y, sr = librosa.load(wav_file_path)

        D = librosa.stft(y)
        S = librosa.amplitude_to_db(librosa.magphase(D)[0], ref=np.max)

        file_name = os.path.splitext(os.path.basename(wav_file_path))[0]

        output_folder = jpg_folder_output

        os.makedirs(output_folder, exist_ok=True)

        output_path = os.path.join(output_folder, f"{file_name}.jpg")

        plt.figure(figsize=(8, 8))
        librosa.display.specshow(S, cmap='gray', y_axis='hz', x_axis='time')
        plt.axis('off')
        plt.savefig(output_path, bbox_inches = 'tight', pad_inches = 0)
        plt.close()


    def _default(self, wav_file_path : str, jpg_folder_output : str) -> None:
        y, sr = librosa.load(wav_file_path)

        D = librosa.stft(y)
        file_name = os.path.splitext(os.path.basename(wav_file_path))[0]

        output_folder = jpg_folder_output

        os.makedirs(output_folder, exist_ok=True)

        output_path = os.path.join(output_folder, f"{file_name}.jpg")

        plt.figure(figsize=(8, 8))
        librosa.display.specshow(librosa.amplitude_to_db(D, ref=np.max), y_axis='hz', x_axis='time')
        plt.axis('off')
        plt.savefig(output_path, bbox_inches = 'tight', pad_inches = 0)
        plt.close()

In [27]:
class CollateFunctor:
    """
    Simple collator to pad decoder sequences
    """

    def __init__(self, pad_idx : int) -> None:
        self.pad_idx = pad_idx
        
    def __call__(self, samples) -> dict[str, torch.tensor]:
        img_tensors = []
        sequences = []

        for sample in sorted(samples, key=lambda x: len(x['seq']), reverse=True):
            img_tensors.append(sample['img'])
            sequences.append(sample['seq'])
        
        padded_seq_tensors = pad_sequence(sequences, 
                                          batch_first=True,
                                          padding_value=self.pad_idx)
        
        img_tensors = torch.stack(tuple(img_tensors), dim=0)
        return {'img': img_tensors,
                'seq': padded_seq_tensors}

In [28]:
# Dataset variables 
data_path = 'dataset.csv' 
produce_spectr = False
transform = transforms.Compose([
    transforms.Resize((620, 620)),
    transforms.ToTensor(),
])

dataset = CustomSpectroDataset(csv_data=data_path,
                               produce_spectr=produce_spectr,
                               transform=transform)

In [29]:
# Train/test splits
seed = 7
gen = torch.Generator().manual_seed(seed)
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [0.9, 0.1], generator=gen)

train_dataloader = DataLoader(train_dataset, 
                              batch_size=32, 
                              shuffle=True,
                              collate_fn=CollateFunctor(pad_idx=dataset.char_to_int['<PAD>']))

test_dataloader = DataLoader(test_dataset, 
                             batch_size=32, 
                             shuffle=True,
                             collate_fn=CollateFunctor(pad_idx=dataset.char_to_int['<PAD>']))

AttributeError: 'CustomSpectroDataset' object has no attribute 'data'

In [9]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Hyperparameters
patch_size = 20
dim_model = 528
encoder_layers = 6
decoder_layers = 6
heads = 6
mlp_dim = 2048
dropout_p = 0.1

vocab_size = len(dataset.get_vocab())
max_seq_len = dataset.longest_seq()
n_channels = 3 

In [10]:
model = Img2SeqModel(device=device,
                     patch_size=patch_size,
                     dim_model=dim_model,
                     encoder_layers=encoder_layers,
                     decoder_layers=decoder_layers,
                     heads=heads,
                     mlp_dim=mlp_dim,
                     dropout_p=dropout_p,
                     vocab_size=vocab_size,
                     n_channels=n_channels,
                     max_seq_len=max_seq_len)

model = model.to(device)

In [11]:
epochs = 2
init_lr = 1e-4
optimizer = torch.optim.AdamW(params=model.parameters(), lr=init_lr)

In [12]:
def train(model, iterator, optimizer, device):
    model.train()
    total_loss = 0
    for batch in tqdm(iterator):
        img = batch['img'].to(device)
        target_seq = batch['seq'].to(device)

        sequence_len = target_seq.size(1)
        tgt_attention_mask = model.get_tgt_mask(sequence_len).to(device)
        tgt_pad_mask = model.create_pad_mask(target_seq, dataset.char_to_int['<PAD>'])

        logits = model(img, target_seq, tgt_mask=tgt_attention_mask, tgt_pad_mask=tgt_pad_mask)
        B, T, C = logits.shape
        logits = logits.view(B*T, C)
        target_seq = target_seq.view(B*T)
        loss = F.cross_entropy(logits, target_seq, ignore_index=dataset.char_to_int['<PAD>'])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.detach().item()
    
    return total_loss / len(iterator)

In [13]:
for epoch in range(epochs):
    print(f'Epoch: {epoch + 1}', end='\n\n')
    epoch_loss = train(model, train_dataloader, optimizer, device)
    print()
    print(f'Loss: {epoch_loss}')    

Epoch: 1



  0%|          | 0/263 [00:00<?, ?it/s]




ValueError: Operation on closed image

In [82]:
@torch.no_grad()
def generate(model, img, device):
    model.eval()
    seq_inp = torch.tensor([1], dtype=torch.long, device=device).unsqueeze(0)
    for _ in range(15):
        logits = model(img, seq_inp)
        logits = logits[:, -1, :]
        probs = F.softmax(logits, dim=-1).detach().to('cpu')
        idx_next = np.argmax(probs, axis=-1).unsqueeze(0).to(device)
        seq_inp = torch.cat(((seq_inp, idx_next)), dim=1)
        print(seq_inp)
batch = next(iter(test_dataloader))
img = batch['img'][0].unsqueeze(0)
generate(model, img, device)

tensor([[1, 4]], device='cuda:0')
tensor([[1, 4, 4]], device='cuda:0')
tensor([[1, 4, 4, 4]], device='cuda:0')
tensor([[1, 4, 4, 4, 4]], device='cuda:0')
tensor([[1, 4, 4, 4, 4, 4]], device='cuda:0')
tensor([[1, 4, 4, 4, 4, 4, 4]], device='cuda:0')
tensor([[1, 4, 4, 4, 4, 4, 4, 4]], device='cuda:0')
tensor([[1, 4, 4, 4, 4, 4, 4, 4, 4]], device='cuda:0')
tensor([[1, 4, 4, 4, 4, 4, 4, 4, 4, 4]], device='cuda:0')
tensor([[1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]], device='cuda:0')
tensor([[1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]], device='cuda:0')
tensor([[1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]], device='cuda:0')
tensor([[1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]], device='cuda:0')
tensor([[1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]], device='cuda:0')
tensor([[1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]], device='cuda:0')


To try:
- Resolve data problems to increase training samples
- Try lowering lr
- Try lowering number of parameters
- Try different sampling stratetgies from the decoder
- Potentially pre-trained decoder?