# Model Showcase
`Img2Music` is a model that takes an album cover image as input and generates a music snippet based on the image's features.
It does this by predicting the most likely genre of the album cover, then generating an audio snippet conditioned on the predicted genre and the image's features.


In this notebook, we will:
1. initialize our datasets
2. declare and initialize our model using `PyTorch`
3. see what the model outputs when we pass in:
   1. album cover images from the dataset
   2. custom images

---
**Note** that we are simply showcasing the model's architecture and capabilities in this notebook.
For the full training pipeline, see the included notebook `[img2music] Full Training Pipeline.ipynb`.

# Setup

In [None]:
# @title Imports & Pre-trained Weights downloads
# @markdown Note that the pre-trained weights are obtained from the training pipeline and saved in google drive.
# @markdown For more information on how the models were trained, see the `Full Training Pipeline.ipynb` file.

%pip install torch-summary

import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

import os
import shutil
import random
import pathlib
import librosa
import kagglehub
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image
from io import BytesIO
from pathlib import Path
from google.colab import files
from torchvision import datasets
from torch.utils.data import DataLoader, random_split
from torchvision.transforms.v2 import GaussianNoise
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, Subset

from IPython.display import Audio

# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Fetch the pretrained checkpoints from Google Drive into the directory the
# Img2Music model is later constructed from ('/content/pretrained_weights/').
# NOTE(review): presumably these five file IDs are classifier.pth plus one
# <genre>.pth per kept genre (classical, hiphop, pop, jazz), which are the file
# names Img2Music loads - TODO confirm.
os.makedirs("/content/pretrained_weights/", exist_ok=True)
!gdown 1cxWq874GMY0-Fb-dLVOiVn8zNknHBiTc --output "/content/pretrained_weights/"
!gdown 1-9RjQlp8e4p9sejV7J_Sfin6fVhET7Ue --output "/content/pretrained_weights/"
!gdown 1AUZ_TZ8ZD_-LmmGXAWR8NKG4rjrHd-12 --output "/content/pretrained_weights/"
!gdown 1C1e2nuEQ05CQeWHrson4fFt6GCSlOk98 --output "/content/pretrained_weights/"
!gdown 1-haiRrcriPoVolAU7o5rylNXELVUZE0T --output "/content/pretrained_weights/"

In [None]:
# @title Declare Helper Functions
def get_class_subsets(dataset: datasets.ImageFolder, class_names: list[str]) -> dict[str, Subset]:
    """
    Partition an ImageFolder into one Subset per requested class folder.

    Used to pull out the images of a single genre when running the model.

    Args:
        dataset: The ImageFolder to slice; every returned Subset indexes into it.
        class_names: Class (folder) names to extract, looked up in ``dataset.class_to_idx``.

    Returns:
        Mapping from each class name to a Subset of that class's samples,
        kept in the dataset's own sample order.

    Raises:
        KeyError: If a name is not one of the dataset's classes.
    """
    sample_labels = [label for _, label in dataset.samples]

    def indices_for(label_id):
        # positions in dataset.samples whose label matches label_id
        return [pos for pos, label in enumerate(sample_labels) if label == label_id]

    return {name: Subset(dataset, indices_for(dataset.class_to_idx[name])) for name in class_names}

In [None]:
# @title Initialize Album Cover Images
# @markdown For more information, see the `Full Training Pipeline.ipynb` file.

# this code is lifted from the 'full training pipeline' notebook with slight modifications.

BATCH_SIZE = 50
NOISE_VARIANCE = 0.15


dataset_path = kagglehub.dataset_download("michaeljkerr/20k-album-covers-within-20-genres")
dataset_path = os.path.join(dataset_path, "GAID")

# Prune a local copy instead of the downloaded files; an existing copy is reused on re-runs.
copy_path = "/content/data/albumcovers/"
if not os.path.exists(copy_path):
    shutil.copytree(dataset_path, copy_path)
dataset_path = copy_path

# Keep only the genres the pretrained models cover. ImageFolder turns every
# sub-directory into a class, so the other genre folders are deleted. Entries
# that are not directories are left alone: rmtree would fail on them, and only
# sub-directories become classes anyway.
classes_to_keep = ["Classical", "HipHop", "Pop", "Jazz"]
directories_to_delete = [
    os.path.join(dataset_path, name)
    for name in os.listdir(dataset_path)
    if name not in classes_to_keep and os.path.isdir(os.path.join(dataset_path, name))
]
for dir_to_delete in directories_to_delete:
    shutil.rmtree(dir_to_delete)
classes = classes_to_keep


# Classifier preprocessing: 224x224 input for the ConvNeXt classifier, normalized
# with the standard ImageNet channel mean/std used by its ImageNet-pretrained backbone.
default_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def seed_worker(worker_id):
    """
    DataLoader ``worker_init_fn``: reseed NumPy and ``random`` inside a worker.

    The seed is derived from ``torch.initial_seed()`` folded into 32 bits (the
    range NumPy's seeding accepts); ``worker_id`` itself is not used.
    """
    derived_seed = torch.initial_seed() % 2**32
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)


dataset = datasets.ImageFolder(root=dataset_path)
dataset_size = len(dataset)
training_ratio = 0.8
validation_ratio = 0.1
testing_ratio = 1 - training_ratio - validation_ratio

# Truncate each ratio to a whole count, then hand the rounding remainder to the
# test split so the three lengths always add up to dataset_size.
split_lengths = [int(dataset_size * ratio) for ratio in (training_ratio, validation_ratio, testing_ratio)]
split_lengths[-1] += dataset_size - sum(split_lengths)
training_dataset, validation_dataset, testing_dataset = random_split(dataset, lengths=split_lengths)
apply_alt_training_transforms = False

# Each split keeps its indices but is re-pointed at a fresh ImageFolder over the
# same (unchanged) folder that applies default_transform, so indices stay valid.
for subset in (training_dataset, validation_dataset, testing_dataset):
    subset.dataset = datasets.ImageFolder(root=dataset_path, transform=default_transform)

training_dataloader = DataLoader(
    training_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    worker_init_fn=seed_worker,
    persistent_workers=False,
)
validation_dataloader = DataLoader(
    validation_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    worker_init_fn=seed_worker,
    persistent_workers=False,
)
testing_dataloader = DataLoader(
    testing_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=1,
    worker_init_fn=seed_worker,
    persistent_workers=False,
)

In [None]:
# @title Initialize Music Audio Snippets
# @markdown For more information, see the `Full Training Pipeline.ipynb` file.

# this code is lifted from the 'full training pipeline' notebook with slight modifications.

SR = 22_050  # sample rate (Hz) that every clip is resampled to by librosa.load


def load_audio_files(files, sr=SR):
    """
    Decode each audio file with librosa at sample rate ``sr``.

    Loading is best effort: a file that raises is reported and skipped, so the
    result can be shorter than ``files``.

    Returns:
        list: One waveform array per successfully loaded file, in input order.
    """
    waveforms = []
    for audio_path in files:
        try:
            waveform, _ = librosa.load(audio_path, sr=sr)
        except Exception as err:  # deliberately broad: one bad file must not abort the whole load
            print(f"Skipping file {audio_path} due to error: {err}")
            continue
        waveforms.append(waveform)
    return waveforms

def pad_audio_data(audio_data, target_length):
    """
    Force every clip to exactly ``target_length`` samples and stack them.

    Shorter clips are right-padded with zeros; longer ones are truncated.

    Returns:
        np.ndarray: Array of shape (len(audio_data), target_length) when
        ``audio_data`` is non-empty, keeping the clips' dtype.
    """
    def fit(clip):
        shortfall = target_length - len(clip)
        if shortfall > 0:
            return np.pad(clip, (0, shortfall), mode='constant')
        return clip[:target_length]

    return np.array([fit(clip) for clip in audio_data])

def get_split_dataloaders(padded_audio_data, test_size=0.2, random_state=42, batch_size=8):
    """
    Split equal-length clips into train/test sets, peak-normalize them, and wrap each in a DataLoader.

    Args:
        padded_audio_data: Stacked clips as produced by ``pad_audio_data``
            (presumably shape (n_clips, n_samples)).
        test_size: Held-out portion, forwarded to ``train_test_split``.
        random_state: Seed forwarded to ``train_test_split`` for a reproducible split.
        batch_size: Batch size of both loaders (default 8, the previous hard-coded value).

    Returns:
        (train_loader, test_loader): DataLoaders over float32 TensorDatasets with
        a channel axis added at dim 1; only the training loader shuffles.
    """
    # Previously test_size/random_state were accepted but ignored (hard-coded values
    # were passed); forward them so callers can actually control the split.
    train_data, test_data = train_test_split(padded_audio_data, test_size=test_size, random_state=random_state)

    # Scale both splits by one shared peak amplitude so they stay on the same scale.
    max_value = np.max([np.abs(train_data).max(), np.abs(test_data).max()])
    if max_value == 0:
        max_value = 1.0  # all-silent input: avoid 0/0 producing NaNs
    train_data_normalized = train_data / max_value
    test_data_normalized = test_data / max_value

    train_tensor = torch.tensor(train_data_normalized, dtype=torch.float32).unsqueeze(1)
    test_tensor = torch.tensor(test_data_normalized, dtype=torch.float32).unsqueeze(1)

    train_dataset = TensorDataset(train_tensor)
    test_dataset = TensorDataset(test_tensor)

    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader



path = kagglehub.dataset_download("andradaolteanu/gtzan-dataset-music-genre-classification")
path = f"{path}/Data/genres_original/"
print("Path to dataset files:", path)

# One list of .wav paths per kept genre; folders are looked up by the lower-cased genre name.
files_per_genre = [list(pathlib.Path(path).joinpath(genre.lower()).glob('*.wav')) for genre in classes]
audio_per_genre = [load_audio_files(genre_files) for genre_files in files_per_genre]
flattened = [clip for genre_clips in audio_per_genre for clip in genre_clips]

# Half of the shortest clip: every clip is at least this long, so all are simply truncated to it.
target_length = min(len(clip) for clip in flattened) // 2
audio_per_genre_padded = [pad_audio_data(genre_clips, target_length) for genre_clips in audio_per_genre]

# Per-genre (train_loader, test_loader) pairs, each paired with its genre name.
loaders_per_genre = [get_split_dataloaders(genre_audio) for genre_audio in audio_per_genre_padded]
genre_loader_pairs = list(zip(classes, loaders_per_genre))

In [None]:
# @title Declare Module for Classes

class AutoEncoder(nn.Module):
    """
    Convolutional AutoEncoder for 1D signals.

    The project refers to it as a Variational AutoEncoder (VAE), but as written it
    is a deterministic encoder -> bottleneck -> decoder stack: this class contains
    no latent sampling and no KL term.

    Submodule attribute names and layer order define the ``state_dict`` keys that
    the pretrained per-genre checkpoints are loaded into, so they must not change.

    Original Author:
    https://yuehan-z.medium.com/introduction-to-vaes-in-ai-music-generation-d8e0cfc2245b

    Args:
        in_channels (int): Number of channels in the input tensor.
        out_channels (int): Number of channels in the output tensor.
        down_channels (list[int]): Output channels for each encoder block.
        up_channels (list[int]): Output channels for each decoder block.
        down_rate (list[int]): Stride of the first convolution in each encoder block.
        up_rate (list[float]): Length multiplier applied (nearest-neighbour
            interpolation) before each decoder block.
        cross_attention_dim (int): Number of channels in the bottleneck layers.
        output_length (int): Maximum number of samples ``forward`` returns; longer
            outputs are truncated, shorter ones are returned unchanged.
            Defaults to 330_000.
    """
    def __init__(
        self,
        in_channels,
        out_channels,
        down_channels,
        up_channels,
        down_rate,
        up_rate,
        cross_attention_dim,
        output_length=330_000,
    ):
        """
        Initialize encoder, bottleneck, decoder, and output layers.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.down_channels = down_channels
        self.up_channels = up_channels
        self.down_rate = down_rate
        self.up_rate = up_rate
        self.cross_attention_dim = cross_attention_dim
        self.output_length = output_length

        # encoder: block i maps to down_channels[i] and downsamples by down_rate[i]
        self.encoder_layers = nn.ModuleList()
        for i, out_ch in enumerate(down_channels):
            in_ch = in_channels if i == 0 else down_channels[i - 1]
            self.encoder_layers.append(self._conv_block(nn.Conv1d, in_ch, out_ch, down_rate[i]))

        # bottleneck: length-preserving, projects to cross_attention_dim channels
        self.bottleneck_layers = self._conv_block(nn.Conv1d, down_channels[-1], cross_attention_dim)

        # decoder: length-preserving transposed convs; the upsampling happens in decode()
        self.decoder_layers = nn.ModuleList()
        for i, out_ch in enumerate(up_channels):
            in_ch = cross_attention_dim if i == 0 else up_channels[i - 1]
            self.decoder_layers.append(self._conv_block(nn.ConvTranspose1d, in_ch, out_ch))

        # final projection to out_channels, squashed into (0, 1) by the sigmoid
        self.output_layer = nn.Sequential(
            nn.ConvTranspose1d(up_channels[-1], out_channels, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

    @staticmethod
    def _conv_block(conv_cls, in_ch, out_ch, stride=1):
        """
        Return (conv -> BatchNorm -> ReLU) x 2 with kernel 3 / padding 1; only the first conv uses ``stride``.

        The layer order fixes the state_dict keys (convs at ``.0``/``.3``, norms at
        ``.1``/``.4``) that the pretrained checkpoints rely on.
        """
        return nn.Sequential(
            conv_cls(in_ch, out_ch, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm1d(out_ch),
            nn.ReLU(inplace=True),
            conv_cls(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """
        Forward pass through the autoencoder.

        Args:
            x (Tensor): Input tensor of shape (B, in_channels, L).

        Returns:
            Tensor: Reconstruction of shape (B, out_channels, min(L_out, output_length)),
            where L_out is the decoder's output length. Outputs are truncated, never padded.
        """
        x = self.decode(self.bottleneck(self.encode(x)))
        x_out = self.output_layer(x)
        return x_out[:, :, :self.output_length]

    def encode(self, x):
        """
        Apply encoder layers to downsample input.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: Latent representation.
        """
        for layer in self.encoder_layers:
            x = layer(x)
        return x

    def bottleneck(self, x):
        """
        Process latent tensor through bottleneck convolutional blocks.

        Args:
            x (Tensor): Encoder output tensor.

        Returns:
            Tensor: Features at bottleneck.
        """
        return self.bottleneck_layers(x)

    def decode(self, x):
        """
        Upsample (nearest-neighbour, by up_rate[i]) then apply decoder block i, for each block.

        Args:
            x (Tensor): Bottleneck feature tensor.

        Returns:
            Tensor: Upsampled feature tensor before the output layer.
        """
        for i, layer in enumerate(self.decoder_layers):
            target_length = int(x.shape[2] * self.up_rate[i])
            x = nn.functional.interpolate(x, size=target_length, mode='nearest')
            x = layer(x)
        return x



class Img2Music(nn.Module):
    """
    Image-to-Music generator combining a ConvNeXt classifier with per-genre AutoEncoders.

    The classifier's argmax index selects, by position, both the genre's
    AutoEncoder and its pool of seed clips, so ``classes`` and ``genre_audios``
    must be ordered like the classifier's output units.

    Args:
        weights_dir (str):             Directory containing ``classifier.pth`` and one
                                       ``<genre>.pth`` (lower-cased name) per entry of ``classes``.
        genre_audios (list[list]):     Per-genre collections of seed audio clips.
        classes (list[str]):           Genre names; position i must match classifier output i.
        img_features_weight (float):   Scale of the image-derived offset added to the latent.
    """
    def __init__(
        self,
        weights_dir: str,
        genre_audios: list,
        classes: list,
        img_features_weight: float = 1e-3,
    ):
        """
        Build the ConvNeXt-Base classifier and the per-genre AutoEncoders from saved weights.

        - Replaces the classifier's final linear layer to output one logit per genre,
          then loads ``classifier.pth``.
        - Instantiates an AutoEncoder per genre and loads ``<genre>.pth``.
        """
        super().__init__()
        self.genre_audios = genre_audios
        self.img_features_weight = img_features_weight

        # The strictly loaded checkpoint below overwrites every parameter, so the
        # ImageNet weights need not be downloaded first (same final model).
        convnext = torchvision.models.convnext_base(weights=None)
        convnext.classifier[2] = nn.Linear(in_features=1024, out_features=len(classes), bias=True)
        # map_location="cpu" lets checkpoints saved on a GPU load on the CPU fallback.
        convnext.load_state_dict(torch.load(os.path.join(weights_dir, "classifier.pth"), map_location="cpu"))
        self.classifier = convnext

        # Load per-genre AutoEncoders
        model_cfg = {
            'in_channels': 1,
            'out_channels': 1,
            'down_channels': [16, 32, 64, 128],
            'up_channels': [384, 192, 96, 48],
            'down_rate': [4, 4, 3, 2],
            'up_rate': [2, 3, 4, 4],
            'cross_attention_dim': 1024,
        }
        # ModuleList (unlike a plain list) registers the AutoEncoders as submodules,
        # so .to(), .eval()/.train() and state_dict() on this model reach them too.
        self.vaes = nn.ModuleList()
        for cls_name in classes:
            vae = AutoEncoder(**model_cfg)
            vae.load_state_dict(torch.load(os.path.join(weights_dir, f"{cls_name.lower()}.pth"), map_location="cpu"))
            vae.to(DEVICE)
            vae.eval()
            self.vaes.append(vae)

    def forward(self, x: torch.Tensor) -> tuple[int, torch.Tensor]:
        """
        Generate a music sample from a single input image.

        Args:
            x (Tensor): Preprocessed image batch of shape (1, C, H, W); only a
                batch of one image is supported.

        Returns:
            genre (int): Predicted genre index (position in ``classes``).
            audio (Tensor): Generated waveform with the batch and channel axes squeezed out.

        Raises:
            ValueError: If ``x`` is not a 4-D batch holding exactly one image.
        """
        if x.dim() != 4 or x.shape[0] != 1:
            raise ValueError(f"Img2Music expects a (1, C, H, W) image batch, got shape {tuple(x.shape)}")

        # predict genre and select corresponding AutoEncoder.
        # NOTE(review): torchvision's ImageFolder numbers classes in sorted folder-name
        # order (Classical, HipHop, Jazz, Pop), whereas `classes` in this notebook is
        # (Classical, HipHop, Pop, Jazz). If the classifier was trained on ImageFolder
        # labels, indices 2 and 3 are swapped here - TODO confirm against training.
        genre = self.classifier(x).argmax(dim=1).item()
        vae = self.vaes[genre]

        # pick a random seed clip of the predicted genre and encode it
        sample = random.choice(self.genre_audios[genre])
        sample = torch.tensor(sample).unsqueeze(0).unsqueeze(0).to(DEVICE)
        sample_embed = vae.bottleneck(vae.encode(sample))

        # Grayscale the image and resize it to the latent's (channels, length) so it
        # can be added as an offset. Sizing from sample_embed (instead of a hard-coded
        # (1024, 3438)) keeps this valid for any seed-clip length.
        x_gray = x.mean(dim=1, keepdim=True)
        feature_vec = nn.functional.interpolate(
            x_gray, size=tuple(sample_embed.shape[-2:]), mode='bilinear', align_corners=False
        )
        feature_vec = feature_vec.squeeze(1).to(sample_embed.device) * self.img_features_weight

        # generate new audio by decoding the shifted latent
        decoded = vae.output_layer(vae.decode(sample_embed + feature_vec))
        audio = decoded.squeeze(0).squeeze(0)
        return genre, audio


# Assemble the full pipeline on DEVICE and put it in inference mode.
model = Img2Music('/content/pretrained_weights/', audio_per_genre_padded, classes).to(DEVICE)
model.eval()

# Per-genre views of the (unsplit) album-cover folder, used by the demo cells below.
alt_dataset = datasets.ImageFolder(root=dataset_path, transform=default_transform)
classes = ["Classical", "HipHop", "Pop", "Jazz"]
album_covers = get_class_subsets(alt_dataset, class_names=classes)

# Running the model

In [None]:
# @title Sample from dataset
# @markdown Here, we pass an album cover image from the dataset.
# @markdown Configure the below values to select which album cover image from which genre to use.
# @markdown ****

# setup
genre = "Jazz" # @param ["Classical", "HipHop", "Pop", "Jazz"]
genre_idx = classes.index(genre)

idx = 0 # @param {type:"integer"}
# Wrap the requested index into the number of covers this genre actually has,
# instead of assuming exactly 1000 per genre (% with a positive modulus is never negative).
genre_covers = album_covers[genre]
idx = idx % len(genre_covers)

album_cover, _ = genre_covers[idx]
album_cover = album_cover.unsqueeze(0)  # add a batch dimension: (1, C, H, W)


# display the input album cover image, min-max rescaled to [0, 1] for imshow
img = album_cover.squeeze(0).permute(1, 2, 0).numpy()
img_range = img.max() - img.min()
img = (img - img.min()) / img_range if img_range > 0 else np.zeros_like(img)  # guard flat images

plt.imshow(img)
plt.axis('off')
plt.title(f'Input Album Cover Image ({classes[genre_idx]})')
plt.show()  # render now so the image appears above the model output


# pass it through the model; inference only, so skip building the autograd graph
with torch.no_grad():
    pred_genre, audio = model(album_cover.to(DEVICE))
audio_np = audio.detach().cpu().numpy()
print(f"Predicted genre of input image: \"{classes[pred_genre]}\"")
display(Audio(audio_np, rate=SR))

In [None]:
#@title Custom Sample
# @markdown Here, we can upload our own image and pass it through the model.
# @markdown The classifier will predict what the most likely genre of the image is, then generate audio corresponding to the specific genre.

# @markdown Simply run this code block to start.

# setup
uploaded = files.upload()  # displays file picker
if not uploaded:
    raise RuntimeError("No file was uploaded - re-run this cell and pick an image.")
fname, data = next(iter(uploaded.items()))  # only the first uploaded file is used
pil_image = Image.open(BytesIO(data)).convert("RGB")  # force 3 channels to match the RGB normalization
tensor = default_transform(pil_image)
tensor = tensor.unsqueeze(0)  # add a batch dimension: (1, C, H, W)


# display the input image, min-max rescaled to [0, 1] for imshow
img = tensor.squeeze(0).permute(1, 2, 0).numpy()
img_range = img.max() - img.min()
img = (img - img.min()) / img_range if img_range > 0 else np.zeros_like(img)  # guard single-colour images

plt.imshow(img)
plt.axis('off')
plt.title(f'Input image: {fname}')
plt.show()  # render now so the image appears above the model output


# pass it through the model; inference only, so skip building the autograd graph
with torch.no_grad():
    pred_genre, audio = model(tensor.to(DEVICE))
audio_np = audio.detach().cpu().numpy()
print(f"\n\nPredicted genre of input image: \"{classes[pred_genre]}\"")
display(Audio(audio_np, rate=SR))