In [2]:
import os
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.nn.functional as F
from torchvision.datasets import ImageFolder
import torch.nn.utils.rnn as rnn_utils

# import tim

import matplotlib.pyplot as plt
# import pandas as pd
import numpy as np

In [2]:
input_folder = '../spectrograms/2004'
output_folder = '../midi-processed-values/2004'

In [3]:
def collate_fn(batch):
  images, npy_arrays = zip(*batch)

  # Stack images (they should all be the same size)
  images = torch.stack(images)

  # Pad NPY arrays to the same length
  lengths = [len(npy) for npy in npy_arrays]
  padded_npy_arrays = rnn_utils.pad_sequence(npy_arrays, batch_first=True)

  return images, padded_npy_arrays, lengths

In [4]:
class SpectrogramDataset(Dataset):
    def __init__(self, image_folder, npy_folder):
        self.image_folder = image_folder
        self.npy_folder = npy_folder
        self.image_files = [f for f in os.listdir(image_folder) if f.endswith('.png')]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_name = self.image_files[idx]
        image_path = os.path.join(self.image_folder, image_name)
        npy_path = os.path.join(self.npy_folder, image_name.replace('.png', '.midi.npy'))

        image = Image.open(image_path).convert('RGB')
        image = np.array(image)
        image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1)  # Convert to CxHxW

        npy_array = np.load(npy_path)
        npy_array = torch.tensor(npy_array, dtype=torch.float32)

        return image, npy_array

# Example usage
image_folder = '../spectrograms/2004'
npy_folder = '../midi-processed-values/2004'
dataset = SpectrogramDataset(image_folder, npy_folder)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=collate_fn, num_workers=0)

In [5]:
class CNNFeatureExtractor(nn.Module):
    def __init__(self):
        super(CNNFeatureExtractor, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(128 * 62 * 175, 512)  # Adjust based on input image size
        self.fc2 = nn.Linear(512, 256)
    
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(-1, 128 * 62 * 175)  # Flatten
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


In [6]:
class TransformerModel(nn.Module):
    def __init__(self, input_dim, output_dim, num_heads, num_layers):
        super(TransformerModel, self).__init__()
        self.transformer = nn.Transformer(d_model=input_dim, nhead=num_heads, num_encoder_layers=num_layers, num_decoder_layers=num_layers)
        self.fc = nn.Linear(input_dim, output_dim)
    
    def forward(self, src, tgt, tgt_mask=None):
        output = self.transformer(src, tgt, tgt_mask=tgt_mask)
        output = self.fc(output)
        return output

In [7]:
len(dataset.image_files)

132

In [8]:
for images, npy_arrays, lengths in dataloader:
        print("IMAGES SHAPE: ", images.shape)
        print("NPY SHAPE: ", npy_arrays.shape)
        print("NPY AT 0 SHAPE", npy_arrays[0].shape)
        break

IMAGES SHAPE:  torch.Size([64, 3, 500, 1400])
NPY SHAPE:  torch.Size([64, 64239, 17])
NPY AT 0 SHAPE torch.Size([64239, 17])


: 

In [None]:
# Combine CNN and Transformer
class CombinedModel(nn.Module):
    def __init__(self, cnn, transformer, d_model):
        super(CombinedModel, self).__init__()
        self.cnn = cnn
        self.transformer = transformer
        self.fc_tgt = nn.Linear(17, d_model)

    def forward(self, image, tgt, lengths):
        print("THESE ARE THE SEQUENCES", tgt.shape)
        print("IMAGE TENSOR SHAPE PRE CNN", image.shape)
        features = self.cnn(image)
        print("IMAGE TENSOR SHAPE POST CNN", features.shape)
        src = features.unsqueeze(1).repeat(1, tgt.size(1), 1).permute(1, 0, 2)  # Add batch dimension for transformer
        print("IMAGE TENSOR SHAPE POST UNSQUEEZE", src.shape)
        tgt = self.fc_tgt(tgt).permute(1, 0, 2)

        print("SOURCE SHAPE", src.shape)
        print("TARGET SHAPE", tgt.shape)

        # Generate mask for tgt sequences
        print("GENERATING TGT MASK...")
        tgt_mask = self.generate_square_subsequent_mask(tgt.size(0)).to(tgt.device)
        print("Done with TGT MASK.")

        # Pack the padded sequences
        # packed_tgt = rnn_utils.pack_padded_sequence(tgt, lengths, batch_first=True, enforce_sorted=False)
        # print("THESE ARE THE PACKED SEQUENCES", packed_tgt.data.shape)

        # print("RIGHT ABOUT TO ENTER THE TRANSFORMER")
        # print(packed_tgt.data.shape)
        output = self.transformer(src, tgt, tgt_mask=tgt_mask)
        return output
    
    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

cnn = CNNFeatureExtractor()
transformer = TransformerModel(input_dim=256, output_dim=17, num_heads=4, num_layers=2)
model = CombinedModel(cnn, transformer, d_model=256)

# Training loop
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(10):  # Number of epochs
    for images, npy_arrays, lengths in dataloader:
        optimizer.zero_grad()
        outputs = model(images, npy_arrays, lengths)
        loss = criterion(outputs, npy_arrays)
        loss.backward()
        optimizer.step()

    print(f'Epoch {epoch+1}, Loss: {loss.item()}')



THESE ARE THE SEQUENCES torch.Size([64, 73394, 17])
IMAGE TENSOR SHAPE PRE CNN torch.Size([64, 3, 500, 1400])
IMAGE TENSOR SHAPE POST CNN torch.Size([64, 256])
IMAGE TENSOR SHAPE POST UNSQUEEZE torch.Size([73394, 64, 256])
SOURCE SHAPE torch.Size([73394, 64, 256])
TARGET SHAPE torch.Size([73394, 64, 256])
GENERATING TGT MASK...
Done with TGT MASK.


False