In [4]:
import torch
import pipeline
import model_loader
import numpy as np
#from tqdm import tqdm
#from ddpm import DDPMSampler
import torchaudio
import torch.nn.functional as F
#from diffusion import Diffusion
#from encoder import VAE_Encoder
#from decoder import VAE_Decoder
import pandas as pd
#from transformers import CLIPTokenizer
import os

from dataclasses import dataclass

# Root folder holding the audio clips and labels.csv. Override it with the
# AUDIO_PATH environment variable instead of editing the notebook.
# os.path.join(..., "") guarantees a trailing separator, so code that builds
# paths as `AUDIO_PATH + relative_path` keeps working.
AUDIO_PATH = os.path.join(
    os.environ.get("AUDIO_PATH", "/Users/nathanielsmith/Desktop/ai/audio_diffusion/audio/"),
    "",
)
LABELS_FILE_PATH = os.path.join(AUDIO_PATH, "labels.csv")

SAMPLE_RATE = 44100
INPUT_AUDIO_LENGTH_SECONDS = 1

# Device used by the training cells: the MPS (Metal) backend when available,
# otherwise CPU. Defined here so later cells run on a fresh kernel.
DEVICE = torch.device("mps:0") if torch.backends.mps.is_available() else torch.device("cpu")

# Empty (0, channels, samples) container and the target clip shape.
# NOTE(review): both are redefined by the dataset-loading cell; they are unused here.
dataset = torch.randn(0, 2, 65536)
desired_dimensions = [2, 65536]

# Columns of labels.csv: row id, text label, path of the clip relative to AUDIO_PATH.
data = pd.read_csv(LABELS_FILE_PATH, names=["id", "label", "audio_path"])

print(torch.__version__)

# Quick sanity check: print the (channels, samples) shape of every clip.
for path in data["audio_path"]:
    waveform, sample_rate = torchaudio.load(os.path.join(AUDIO_PATH, path))
    print(waveform.size())


2.0.1
torch.Size([2, 4629])


In [None]:
# Checkpoint file, presumably the Stable Diffusion v1.5 weights (judging by the
# file name). NOTE(review): only the commented-out preload call below reads
# it, so it is currently unused.
model_file = "../data/v1-5-pruned-emaonly.ckpt"
#models = model_loader.preload_models_from_standard_weights(model_file, DEVICE)

def get_time_embedding(timestep):
    """Sinusoidal embedding of a single diffusion timestep.

    Args:
        timestep: scalar timestep (Python number or 0-d tensor).

    Returns:
        float32 tensor of shape (1, 320): the first 160 columns are cosines,
        the last 160 are sines of ``timestep * 10000 ** (-k / 160)``.
    """
    half_dim = 160
    # Geometric frequency ladder, shape (half_dim,).
    inv_freq = torch.pow(10000, -torch.arange(start=0, end=half_dim, dtype=torch.float32) / half_dim)
    # Outer product of the (1,) timestep with the frequencies -> (1, half_dim).
    angles = torch.tensor([timestep], dtype=torch.float32).unsqueeze(1) * inv_freq.unsqueeze(0)
    return torch.cat((angles.cos(), angles.sin()), dim=-1)

@dataclass
class TrainingConfig:
    """Hyper-parameters and bookkeeping options for the training run.

    Every attribute carries a type annotation so ``@dataclass`` turns it into
    a real field. Defaults can therefore be overridden through the constructor
    (e.g. ``TrainingConfig(num_epochs=50)``), and the values appear in
    ``repr`` and ``dataclasses.asdict``. Plain attribute access
    (``config.learning_rate``) behaves exactly as before.

    NOTE(review): several fields (image_size, eval_batch_size,
    save_image_epochs, output_dir "ddpm-butterflies-128", hub_*) look carried
    over from an image-diffusion template; confirm which apply to audio.
    """

    image_size: int = 128  # the generated image resolution
    train_batch_size: int = 16
    eval_batch_size: int = 16  # how many images to sample during evaluation
    num_epochs: int = 1  # 50
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    save_image_epochs: int = 10
    save_model_epochs: int = 30
    mixed_precision: str = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
    output_dir: str = "ddpm-butterflies-128"  # the model name locally and on the HF Hub

    push_to_hub: bool = False  # whether to upload the saved model to the HF Hub
    hub_model_id: str = "<your-username>/<my-awesome-model>"  # the name of the repository to create on the HF Hub
    hub_private_repo: bool = False
    overwrite_output_dir: bool = True  # overwrite the old model when re-running the notebook
    seed: int = 0


# Run configuration with all defaults from TrainingConfig.
config = TrainingConfig()

# NOTE(review): CLIPTokenizer is not imported in this notebook (its import line
# at the top is commented out), so this line raises NameError on a fresh
# kernel. `tokenizer` is only referenced by commented-out conditioning code in
# the training loop; restore the import or drop this line.
tokenizer = CLIPTokenizer("../data/vocab.json", merges_file="../data/merges.txt")
#clip = models["clip"]

# load dataset
#
# Every clip is forced to exactly `desired_dimensions` = (channels, samples):
# shorter clips are zero-padded at the end, longer ones truncated, and all clips
# are stacked into `dataset` of shape (num_clips, channels, samples).
# NOTE(review): the sample rate returned by torchaudio.load is ignored (clips
# are not resampled) — this assumes every file is already at SAMPLE_RATE.

#dataset = torch.randn(0, 2, 44100)
#desired_dimensions = [2, 44100]

desired_dimensions = [2, 65536]  # 65536 = 2**16 samples per clip
target_num_samples = desired_dimensions[1]

data = pd.read_csv(LABELS_FILE_PATH, names=["id", "label", "audio_path"])

waveforms = []
for path in data["audio_path"]:
    waveform, sample_rate = torchaudio.load(os.path.join(AUDIO_PATH, path))

    # Decide pad vs. truncate against the same length we pad/truncate to
    # (not INPUT_AUDIO_LENGTH_SECONDS * SAMPLE_RATE), so every clip ends up
    # exactly target_num_samples long and the stack below cannot fail on length.
    if waveform.size(1) < target_num_samples:
        padding_width = target_num_samples - waveform.size(1)
        waveform = F.pad(waveform, (0, padding_width), mode="constant", value=0)
    elif waveform.size(1) > target_num_samples:
        waveform = waveform[:, :target_num_samples]

    waveforms.append(waveform)

# Stack once instead of re-concatenating inside the loop (which copies the
# whole dataset on every iteration). Keep an empty, correctly shaped tensor
# when labels.csv lists no clips.
dataset = torch.stack(waveforms) if waveforms else torch.empty(0, *desired_dimensions)


# Build the data loader, models, optimizer and noise scheduler.
#
# NOTE(review): VAE_Encoder, Diffusion, VAE_Decoder and DDPMSampler are used
# here, but their imports at the top of the notebook are commented out;
# restore them before running this cell.

NUM_INFERENCE_STEPS = 50  # denoising steps the scheduler iterates over

# Seed the global torch RNG (weight init, DataLoader shuffling, torch.randn in
# the training loop) so config.seed actually controls the run.
torch.manual_seed(config.seed)

train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)

encoder = VAE_Encoder()
diffusion = Diffusion()
decoder = VAE_Decoder()

encoder.to(DEVICE)
diffusion.to(DEVICE)
decoder.to(DEVICE)

loss_function = torch.nn.MSELoss()  # Choose an appropriate loss function
# A single optimizer updates all three sub-models jointly.
optimizer = torch.optim.AdamW(
    list(encoder.parameters()) + list(diffusion.parameters()) + list(decoder.parameters()),
    lr=config.learning_rate,
)

# Dedicated, seeded generator handed to the noise scheduler.
generator = torch.Generator(device=DEVICE)
generator.manual_seed(config.seed)

noise_scheduler = DDPMSampler(generator)
noise_scheduler.set_inference_timesteps(NUM_INFERENCE_STEPS)

# Training loop
#
# For each batch: encode the audio to latents, noise them, run the scheduler's
# full reverse chain through the diffusion model, decode, and minimise the MSE
# between the decoded audio and the input.
#
# NOTE(review): `tqdm` is used below, but its import at the top of the notebook
# is commented out.
# NOTE(review): gradients flow through every scheduler step, so activation
# memory grows with the number of inference steps. This also differs from the
# usual DDPM objective (predict the added noise at one random timestep);
# confirm this end-to-end reconstruction loss is intended.
for epoch in range(config.num_epochs):
    running_loss = 0.0
    num_batches = 0

    for batch_data in tqdm(train_dataloader, desc=f"Epoch {epoch}"):
        # Tensor.to() returns a new tensor rather than moving in place, so the
        # result must be kept for the models on DEVICE to receive it.
        input_data = batch_data.to(DEVICE)
        bs = input_data.shape[0]

        # Placeholder conditioning: random context instead of CLIP text
        # embeddings (the tokenizer/CLIP path below is not wired up yet).
        context = torch.randn(1, 77, 768, device=DEVICE)

        #tokens = tokenizer.batch_encode_plus(
        #    [label], padding="max_length", max_length=77
        #).input_ids
        #tokens = torch.tensor(tokens, dtype=torch.long, device=DEVICE)
        #context = clip(tokens)

        # Independent encoder noise for every item in the batch (a batch
        # dimension of 1 would share a single draw across the whole batch).
        # NOTE(review): assumes the encoder's latent is (bs, 4, 8192) — confirm.
        encoder_noise = torch.randn([bs, 4, 8192], device=DEVICE)

        latents = encoder(input_data, encoder_noise)
        # Noise the latents at the first scheduled timestep (presumably the
        # noisiest one — confirm against DDPMSampler.timesteps ordering).
        latents = noise_scheduler.add_noise(latents, noise_scheduler.timesteps[0])

        for timestep in tqdm(noise_scheduler.timesteps, leave=False):
            time_embedding = get_time_embedding(timestep).to(DEVICE)
            model_output = diffusion(latents, context, time_embedding)
            latents = noise_scheduler.step(timestep, latents, model_output)

        output = decoder(latents)

        # Compute loss
        loss = loss_function(output, input_data)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # Report the mean batch loss for the epoch; guard the empty-loader case,
    # where no loss was ever computed.
    if num_batches:
        print(f'Epoch [{epoch+1}/{config.num_epochs}], Loss: {running_loss / num_batches}')
    else:
        print(f'Epoch [{epoch+1}/{config.num_epochs}], no batches (empty dataset)')

# Persist the three sub-models' weights together with the optimizer state so
# both can be restored later.
checkpoint = {
    "encoder_state_dict": encoder.state_dict(),
    "diffusion_state_dict": diffusion.state_dict(),
    "decoder_state_dict": decoder.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}
torch.save(checkpoint, "../data/data.pth")


