In [1]:
# Run configuration: a dataclass instance (it is passed to dataclasses.asdict for W&B)
# from the project-local `processing` module. Later cells read its dataset source,
# image dims, validation split size, batch size, learning rate, epoch count,
# validation interval and W&B project name.
from processing import COMP5421Config
config = COMP5421Config()

  from .autonotebook import tqdm as notebook_tqdm
Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [2]:
import os
import torch
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from diffusers import DDPMScheduler, UNet2DModel
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import huggingface_hub
from tqdm import tqdm, trange
from dataclasses import asdict
from dotenv import load_dotenv

# Populate os.environ from a local .env file (HF_TOKEN, WANDB_API_KEY, ...).
load_dotenv()

# Authenticate against the Hugging Face Hub with the token taken from the environment.
hf_token = os.getenv("HF_TOKEN")
huggingface_hub.login(hf_token)

# Train on the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Single-channel UNet denoiser: three down/up block pairs, with attention in the middle pair.
unet_kwargs = dict(
    sample_size=config.img_dims,
    in_channels=1,
    out_channels=1,
    layers_per_block=2,
    block_out_channels=(32, 64, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
)
model = UNet2DModel(**unet_kwargs).to(device)

# Load the dataset referenced by config.dataset_src.
# It may need access rights on the Hub; ask Darin if you need it.
dataset = load_dataset(config.dataset_src)

def collate_fn(batch, target_device=None):
    """Stack the ``'mel'`` arrays of a list of dataset rows into one float32 batch tensor.

    Args:
        batch: List of dataset rows (as handed over by ``DataLoader``). Each row's
            ``'mel'`` entry is an array-like (nested lists or a NumPy array). All rows
            are assumed to share one shape.
        target_device: Device for the returned tensor. Defaults to the notebook-level
            ``device``, so ``DataLoader(collate_fn=collate_fn)`` behaves as before.

    Returns:
        Tensor of shape ``[len(batch), 1, *mel.shape]`` with dtype float32.
        It is presumably ``[B, 1, 128, 432]`` for this dataset (TODO confirm).
    """
    if target_device is None:
        target_device = device
    # Cast explicitly. torch.tensor() keeps float64 for float64 NumPy input and an
    # integer dtype for int data. The float32 UNet weights reject both.
    mels = torch.stack([torch.as_tensor(item['mel'], dtype=torch.float32) for item in batch])
    # Insert the single channel dimension: [B, H, W] -> [B, 1, H, W].
    return mels.unsqueeze(1).to(target_device)

# Hold out config.val_size of the 'train' split for validation. test_size accepts a
# fraction or an absolute count. A fixed seed keeps the split identical across kernel
# restarts, so validation losses from different runs refer to the same examples.
SPLIT_SEED = 42
train_test = dataset['train'].train_test_split(test_size=config.val_size, seed=SPLIT_SEED)
train_loader = DataLoader(train_test['train'], batch_size=config.batch_size, shuffle=True, collate_fn=collate_fn)
# shuffle=False: the validation loader yields its batches in a fixed order.
val_loader = DataLoader(train_test['test'], batch_size=config.batch_size, shuffle=False, collate_fn=collate_fn)

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [None]:
import wandb

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
loss_func = torch.nn.MSELoss()

# Set to False to train without Weights & Biases logging. The training loop calls
# wandb.log unconditionally, and wandb.log fails if wandb.init was never called.
# So logging is switched off with mode="disabled" (wandb.log then does nothing)
# rather than by skipping wandb.init.
USE_WANDB = True

api_key = os.getenv("WANDB_API_KEY")
if USE_WANDB:
    wandb.login(key=api_key)

wandb.init(
    project=config.training_name,
    config=asdict(config),
    # None keeps wandb's default mode (honours WANDB_MODE and similar settings).
    mode=None if USE_WANDB else "disabled",
)

def validate(model, loader, noise_scheduler, loss_func, max_steps=None):
    """Estimate the denoising loss on at most ``max_steps`` batches of ``loader``.

    For each batch, Gaussian noise and one random timestep per sample are drawn. The
    batch is noised with ``noise_scheduler.add_noise``, and the model's prediction is
    scored against the true noise with ``loss_func``. This is the same objective as
    the training loop. The model is put in eval mode during validation and restored
    to its previous mode afterwards.

    Args:
        model: Denoiser called as ``model(noisy_batch, timesteps)``. Element ``[0]``
            of its output is taken as the predicted noise.
        loader: Iterable of batched tensors to evaluate (e.g. ``val_loader``).
        noise_scheduler: Object exposing ``config.num_train_timesteps`` and
            ``add_noise(samples, noise, timesteps)``, e.g. a ``DDPMScheduler``.
        loss_func: Callable ``loss_func(prediction, target)`` returning a scalar tensor.
        max_steps: Maximum number of batches to evaluate. Defaults to ``config.val_step``.

    Returns:
        Scalar tensor holding the mean of the per-batch losses.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    if max_steps is None:
        max_steps = config.val_step
    was_training = model.training
    model.eval()
    total_loss = 0.0
    num_batches = 0
    try:
        with torch.no_grad():
            for val_batch in tqdm(loader, desc="Validating...", total=max_steps):
                noise = torch.randn_like(val_batch)
                # One timestep per sample. Use the real batch size so a short final
                # batch still broadcasts correctly inside add_noise.
                timesteps = torch.randint(
                    0,
                    noise_scheduler.config.num_train_timesteps,
                    (val_batch.shape[0],),
                    device=val_batch.device,
                    dtype=torch.int64,
                )
                noisy_batch = noise_scheduler.add_noise(val_batch, noise, timesteps)
                noise_pred = model(noisy_batch, timesteps)[0]
                total_loss += loss_func(noise_pred, noise)
                num_batches += 1
                if num_batches >= max_steps:
                    break
    finally:
        # Hand the model back in whatever mode the caller had it in.
        model.train(was_training)
    if num_batches == 0:
        raise ValueError("validate() received an empty loader")
    return total_loss / num_batches

# Training loop: DDPM noise-prediction objective, periodic validation, W&B logging.
step_count = 0  # optimizer steps across all epochs; also used as the W&B step axis
try:
    for epoch in range(config.num_epochs):
        model.train()
        epoch_loss = 0.0
        for batch in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{config.num_epochs}'):
            step_count += 1
            optimizer.zero_grad()

            # Add noise with one random timestep per sample. Size timesteps from the
            # actual batch: the DataLoader keeps a short final batch when the split size
            # is not a multiple of config.batch_size. A length mismatch would fail to
            # broadcast inside add_noise.
            noise = torch.randn_like(batch)
            timesteps = torch.randint(
                0,
                noise_scheduler.config.num_train_timesteps,
                (batch.shape[0],),
                device=device,
                dtype=torch.int64,
            )
            noisy_batch = noise_scheduler.add_noise(batch, noise, timesteps)

            # Forward pass: element [0] of the UNet output is the predicted noise.
            noise_pred = model(noisy_batch, timesteps)[0]

            # Loss against the noise that was actually added.
            loss = loss_func(noise_pred, noise)
            loss.backward()
            optimizer.step()

            if step_count % config.val_step == 0:
                with torch.no_grad():
                    val_loss = validate(model, val_loader, noise_scheduler, loss_func)
                wandb.log({"val_batch_loss": val_loss.item()}, step=step_count)

            batch_loss = loss.item()  # read the scalar once per step
            epoch_loss += batch_loss
            wandb.log({"batch_loss": batch_loss}, step=step_count)

        average_epoch_loss = epoch_loss / len(train_loader)
        print(f'Epoch {epoch + 1} completed, Average Loss: {average_epoch_loss}')
        wandb.log({"epoch_loss": average_epoch_loss}, step=step_count)
finally:
    # Finish the W&B run (and upload pending data) even when training is interrupted.
    wandb.finish()
print("Training completed.")

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: C:\Users\Darin\.netrc
[34m[1mwandb[0m: Currently logged in as: [33myfdchau[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 1/10:   0%|          | 3/11165 [01:14<74:07:07, 23.90s/it]