In [18]:
%reload_ext autoreload
%matplotlib inline

import os

import torch
from PIL import Image
from torch import nn

from models.style_transfer import SoundStyleTransferModel

In [19]:
model = SoundStyleTransferModel()

An error occurred while trying to fetch riffusion/riffusion-model-v1: riffusion/riffusion-model-v1 does not appear to have a file named diffusion_pytorch_model.safetensors.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
An error occurred while trying to fetch riffusion/riffusion-model-v1: riffusion/riffusion-model-v1 does not appear to have a file named diffusion_pytorch_model.safetensors.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.


## Train TVE

In [20]:
def get_audio_dataloader(batch_size=1, dataset="church_bell"):
    images = []
    path = f"./audios/timbre/{dataset}"
    for filename in os.listdir(path):
        if filename.endswith(".png"):
            image = Image.open(os.path.join(path, filename))
            image = SoundStyleTransferModel.preprocess_image(image)
            image = image.squeeze(0)
            images.append(image)

    class MyDataset(torch.utils.data.Dataset):
        def __init__(self, dataset, label):
            self.dataset = dataset
            self.label = label

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, idx):
            return self.dataset[idx], self.label

    dataloader = torch.utils.data.DataLoader(MyDataset(images, dataset), batch_size=batch_size, shuffle=True)
    return dataloader

In [21]:
learning_rate = 0.001
batch_size = 1

for param in model.text_transform.text_encoder.parameters():
    param.requires_grad = False

for param in model.unet.parameters():
    param.requires_grad = False

optimizer = torch.optim.Adam(model.text_transform.tve.parameters(), lr=learning_rate)
criterion = nn.MSELoss()

dataloader = get_audio_dataloader(batch_size=batch_size, dataset="church_bell")

num_epochs = 500
for epoch in range(num_epochs):
    epoch_loss = 0
    for step, (images, labels) in enumerate(dataloader):
        images, labels = images.to(device=model.device), ["*"] * len(labels)

        optimizer.zero_grad()

        with torch.no_grad():
            init_latents = model.encode_images(images)

        noise = torch.randn_like(init_latents)
        timesteps = torch.randint(
            0,
            model.scheduler.config.num_train_timesteps,
            (batch_size,),
            dtype=torch.int64,
            device=model.device
        )
        noisy_latents = model.scheduler.add_noise(init_latents, noise, timesteps)

        with torch.no_grad():
            text_embeddings = [model.text_transform.embed_text(label) for label in labels]
            text_embeddings = torch.stack(text_embeddings).squeeze(dim=1).to(device=model.device)

        label_embeddings = model.text_transform.tve(timesteps, text_embeddings)

        with torch.amp.autocast("cuda"):
            pred_noise = model(noisy_latents, label_embeddings, timesteps)

            loss = criterion(pred_noise, noise)
            epoch_loss += loss.item()

            loss.backward()
            optimizer.step()

    print(f"Epoch {epoch + 1}/{num_epochs} Loss: {loss.item() / len(dataloader)}")

torch.save(model.text_transform.tve.state_dict(), "./data/tve.pth")

Epoch 1/500 Loss: 0.07465159893035889
Epoch 2/500 Loss: 0.005687912305196126
Epoch 3/500 Loss: 0.005481130753954251
Epoch 4/500 Loss: 0.004845332354307175
Epoch 5/500 Loss: 0.03421626736720403
Epoch 6/500 Loss: 0.008676807706554731
Epoch 7/500 Loss: 0.0007230174572517475
Epoch 8/500 Loss: 0.01728026568889618
Epoch 9/500 Loss: 0.019629814972480137
Epoch 10/500 Loss: 0.04181812206904093
Epoch 11/500 Loss: 0.0374076763788859
Epoch 12/500 Loss: 0.02551819384098053
Epoch 13/500 Loss: 0.01216357077161471
Epoch 14/500 Loss: 0.009954245140155157
Epoch 15/500 Loss: 0.06694128612677257
Epoch 16/500 Loss: 0.027562879025936127
Epoch 17/500 Loss: 0.00039504666347056627
Epoch 18/500 Loss: 0.0315287709236145
Epoch 19/500 Loss: 0.022748532394568127
Epoch 20/500 Loss: 0.004026090415815513
Epoch 21/500 Loss: 0.02466689298550288
Epoch 22/500 Loss: 0.012723686794439951
Epoch 23/500 Loss: 0.04133350153764089
Epoch 24/500 Loss: 0.006824962794780731
Epoch 25/500 Loss: 0.017137716213862102
Epoch 26/500 Loss: 

## Sample sound

In [22]:
model.text_transform.tve.load_state_dict(torch.load("./data/tve.pth", weights_only=True))
image = Image.open("sample.png")
print(image.size)

prompt_start = "*"
prompt_end = "jazz with piano"

image = model.transfer_style(image, prompt_start, prompt_end, use_tve=True)

image.save("out_sample.png")
print(image.size)

(502, 512)
torch.Size([1, 4, 64, 60])


100%|██████████| 38/38 [00:15<00:00,  2.39it/s]


(480, 512)
