In [1]:
import math
import torch
import torchaudio
from torch import nn
from torch.utils.data import DataLoader, Subset
import numpy as np
import matplotlib.pyplot as plt
import random
import torchaudio.transforms as TT
import librosa
import os
import logging
import tqdm as tqdm
import pandas as pd

In [2]:
dataset = []

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Assuming you have labeled_data.csv in the same directory as your script
#csv_file_path = 'labeled_data.csv'
data_directory = '/nas/home/jalbarracin/datasets/hrir_st'

dataset = []

for filename in os.listdir(data_directory):
    # Extract subject ID and measurement point from the file name
    parts = os.path.splitext(filename)[0].split('_')
    subject_id = int(parts[0][2:])  # Extract subject ID, assuming it starts with "pp"
    measurement_point = int(parts[-1])  # Extract measurement point

    # Load audio data
    file_path = os.path.join(data_directory, filename)
    wave, sr = torchaudio.load(file_path, normalize=True)

    # Append the audio data, subject ID, and measurement point to the dataset
    dataset.append({'audio': wave, 'subject_id': subject_id, 'measurement_point': measurement_point})


Using device: cuda


In [3]:
print(f"Dataset length: {len(dataset)}")
element = dataset[0]
audio_ex = element['audio']
print(f"Sample channels and length: {audio_ex.shape}")
for item in dataset:
    audio_data = item['audio']
    subject_id = item['subject_id']
    measurement_point = item['measurement_point']
    #print(subject_id)

Dataset length: 3096
Sample channels and length: torch.Size([2, 256])


In [4]:
def collate_fn(batch):
    """
    Collate function for the DataLoader. Assumes each element in batch is a dictionary with 'audio' and 'measurement_point' keys.
    """
    audio_batch = [item['audio'] for item in batch]
    measurement_point_batch = [item['measurement_point'] for item in batch]
    return {'audio': torch.stack(audio_batch), 'measurement_point': torch.tensor(measurement_point_batch)}

In [5]:
BATCH_SIZE = 1  # len(dataset) // 4
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
    collate_fn=collate_fn,
)

In [30]:
class Block(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        if up:
            self.conv1 = nn.Conv1d(2 * in_ch, out_ch, kernel_size=3, padding=1)
            self.transform = nn.ConvTranspose1d(out_ch, out_ch, kernel_size=4, stride=2, padding=1)
        else:
            self.conv1 = nn.Conv1d(in_ch, out_ch, kernel_size=3, padding=1)
            self.transform = nn.Conv1d(out_ch, out_ch, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv1d(out_ch, out_ch, kernel_size=3, padding=1)
        self.bnorm1 = nn.BatchNorm1d(out_ch)
        self.bnorm2 = nn.BatchNorm1d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        # Remove the last singleton dimension
        #print("x before", x.shape)
        x = x.squeeze(2)
        #print("x shape",x.shape)
        # First Conv
        h = self.bnorm1(self.relu(self.conv1(x)))  # Adjust here
        #print("After conv1 shape:", h.shape)
        # Time embedding
        time_emb = self.relu(self.time_mlp(t))
        # Extend last dimension
        time_emb = time_emb.unsqueeze(-1)
        # Add time channel
        h = h + time_emb
        # Second Conv
        h = self.bnorm2(self.relu(self.conv2(h)))
        #print("After conv2 shape:", h.shape)
        # Down or Upsample
        return self.transform(h)

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        # TODO: Double check the ordering here
        return embeddings

class SimpleUnet_cond(nn.Module):
    def __init__(self, num_classes=None):
        super().__init__()
        audio_channels = 2  # Adjust for stereo audio
        down_channels = (64, 128, 256, 512, 1024)
        up_channels = (1024, 512, 256, 128, 64)
        out_dim = 2  # Adjust for stereo audio
        time_emb_dim = 32

        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU()
        )

        # Initial projection
        self.conv0 = nn.Conv1d(audio_channels, down_channels[0], kernel_size=3, padding=1)  # Adjust here

        # Downsample
        self.downs = nn.ModuleList([Block(down_channels[i], down_channels[i + 1], time_emb_dim) \
                                    for i in range(len(down_channels) - 1)])
        # Upsample
        self.ups = nn.ModuleList([Block(up_channels[i], up_channels[i + 1], time_emb_dim, up=True) \
                                  for i in range(len(up_channels) - 1)])

        # Output layer
        self.output = nn.Conv1d(up_channels[-1], out_dim, kernel_size=1)
        
        if num_classes is not None:
            self.label_emb = nn.Embedding(num_classes,time_emb_dim)

    def forward(self, x, timestep, y):
        # Embed time
        t = self.time_mlp(timestep)
        print("Time embedding shape:", t.shape)
        if y is not None:
            t +=self.label_emb(y)
        # Initial conv
        x = x.squeeze(0)
        print("before conv0 shape:", x.shape)
        x = self.conv0(x) 
        print("After conv0 shape:", x.shape)
        # Unet
        residual_inputs = []
        for i, down in enumerate(self.downs):
            #print("hola")
            x = down(x, t)
            #print(f"After downsampling block {i} shape:", x.shape)
            residual_inputs.append(x)
        for i, up in enumerate(self.ups):
            residual_x = residual_inputs.pop()
            # Add residual x as additional channels
            x = torch.cat((x, residual_x), dim=1)
            x = up(x, t)
            #print(f"After upsampling block {i} shape:", x.shape)
        # Assuming the final output layer is 1D
        return self.output(x)


# Example usage
model = SimpleUnet_cond(num_classes=36)
print("Num params: ", sum(p.numel() for p in model.parameters()))
model

Num params:  18570466


SimpleUnet_cond(
  (time_mlp): Sequential(
    (0): SinusoidalPositionEmbeddings()
    (1): Linear(in_features=32, out_features=32, bias=True)
    (2): ReLU()
  )
  (conv0): Conv1d(2, 64, kernel_size=(3,), stride=(1,), padding=(1,))
  (downs): ModuleList(
    (0): Block(
      (time_mlp): Linear(in_features=32, out_features=128, bias=True)
      (conv1): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))
      (transform): Conv1d(128, 128, kernel_size=(4,), stride=(2,), padding=(1,))
      (conv2): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))
      (bnorm1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bnorm2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
    )
    (1): Block(
      (time_mlp): Linear(in_features=32, out_features=256, bias=True)
      (conv1): Conv1d(128, 256, kernel_size=(3,), stride=(1,), padding=(1,))
      (transform): Conv1d(256, 256, kern

In [7]:
class EMA:
    def __init__(self, beta):
        super().__init__()
        self.beta = beta
        self.step = 0

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

    def step_ema(self, ema_model, model, step_start_ema=2000):
        if self.step < step_start_ema:
            self.reset_parameters(ema_model, model)
            self.step += 1
            return
        self.update_model_average(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model):
        ema_model.load_state_dict(model.state_dict())

In [8]:
class Diffusion:
    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, audio_length=256, device="cuda"):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end

        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

        self.audio_length = audio_length
        self.device = device

    def prepare_noise_schedule(self):
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        Ɛ = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * Ɛ, Ɛ

    def sample_timesteps(self, n):
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def sample(self, model, n, labels, cfg_scale=3):
        logging.info(f"Sampling {n} new audio....")
        model.eval()
        with torch.no_grad():
            x = torch.randn((n, 2, 256)).to(self.device)
            for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
                t = (torch.ones(n) * i).long().to(self.device)
                predicted_noise = model(x, t, labels)
                if cfg_scale > 0:
                    uncond_predicted_noise = model(x, t, None)
                    predicted_noise = torch.lerp(uncond_predicted_noise, predicted_noise, cfg_scale)
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                if i > 1:
                    noise = torch.randn_like(x)
                else:
                    noise = torch.zeros_like(x)
                x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
        model.train()
        #x = (x.clamp(-1, 1) + 1) / 2
        #x = (x * 255).type(torch.uint8)
        return x

In [32]:
from torch.optim import Adam
import copy
import torch.nn.functional as F
torch.autograd.set_detect_anomaly(True)
def train(epochs,dataloader):
    #setup_logging(args.run_name)
    #device = args.device
    #dataloader = get_data(args)
    model = SimpleUnet_cond(num_classes=36).to(device)
    optimizer = Adam(model.parameters(), lr=0.001)
    #mse = nn.MSELoss()
    diffusion = Diffusion(device=device)
    #logger = SummaryWriter(os.path.join("runs", args.run_name))
    l = len(dataloader)
    ema = EMA(0.995)
    ema_model = copy.deepcopy(model).eval().requires_grad_(False)

    for epoch in range(epochs):
        logging.info(f"Starting epoch {epoch}:")
        #pbar = tqdm(dataloader)
        for step, batch in enumerate(dataloader):
            audio = batch['audio'].to(device)
            labels = batch['measurement_point'].to(device)
            t = diffusion.sample_timesteps(audio.shape[0]).to(device)
            x_t, noise = diffusion.noise_images(audio, t)
            if np.random.random() < 0.1:
                labels = None
            predicted_noise = model(x_t, t, labels)
            loss = F.l1_loss(noise, predicted_noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            ema.step_ema(ema_model, model)

            #pbar.set_postfix(L1_Loss=loss.item())
            #logger.add_scalar("MSE", loss.item(), global_step=epoch * l + i)

        if epoch % 10 == 0:
            labels = torch.arange(10).long().to(device)
            sampled_images = diffusion.sample(model, n=len(labels), labels=labels)
            ema_sampled_images = diffusion.sample(ema_model, n=len(labels), labels=labels)
            #plot_images(sampled_images)
            #save_images(sampled_images, os.path.join("results", args.run_name, f"{epoch}.jpg"))
            #save_images(ema_sampled_images, os.path.join("results", args.run_name, f"{epoch}_ema.jpg"))
            #torch.save(model.state_dict(), os.path.join("models", args.run_name, f"ckpt.pt"))
            #torch.save(ema_model.state_dict(), os.path.join("models", args.run_name, f"ema_ckpt.pt"))
            #torch.save(optimizer.state_dict(), os.path.join("models", args.run_name, f"optim.pt"))

In [33]:
train(epochs=100, dataloader=dataloader)

Time embedding shape: torch.Size([1, 32])
before conv0 shape: torch.Size([1, 2, 256])
After conv0 shape: torch.Size([1, 64, 256])


  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/nas/home/jalbarracin/miniconda3/lib/python3.11/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/nas/home/jalbarracin/miniconda3/lib/python3.11/site-packages/traitlets/config/application.py", line 1053, in launch_instance
    app.start()
  File "/nas/home/jalbarracin/miniconda3/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 737, in start
    self.io_loop.start()
  File "/nas/home/jalbarracin/miniconda3/lib/python3.11/site-packages/tornado/platform/asyncio.py", line 195, in start
    self.asyncio_loop.run_forever()
  File "/nas/home/jalbarracin/miniconda3/lib/python3.11/asyncio/base_events.py", line 607, in run_forever
    self._run_once()
  File "/nas/home/jalbarracin/miniconda3/lib/python3.11/asyncio/base_events.py", line 1922, in _run_once
    handle._run()
  File "/nas/home/jalbarracin/miniconda3/lib/python3.11/

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 32]], which is output 0 of ReluBackward0, is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!