## Environment

In [None]:
import os
import sys
import tempfile
import torch
import librosa
import numpy as np
import torchaudio
import soundfile as sf
from tqdm import tqdm, trange

# DDP IMPORTS
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

# !pip install tinytag          #TODO: change tinytag to something with cross-platform support
# from tinytag import TinyTag

# HANNA AND ANNA ENV
# drive.mount('/content/drive')
# device = "cuda" if torch.cuda.is_available() else "cpu"
# fma_small_path = 'drive/Shareddrives/Computer Audition Project/fma_small/'
# from google.colab import drive


# IZZY ENV
# `device` is read by later cells (e.g. `audio_tensor.to(device)` in the Testing cell),
# so it has to be defined here; fall back to the CPU when no GPU is visible.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # GPU ids start from 0
fma_small_path = r"C:\Users\ihargrav\Desktop\fma_small"

#SET UP CUDA ENVIRONMENT
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
print(f"device count: {torch.cuda.device_count()}")

device count: 2


## DataLoaders

In [None]:
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split

class FMADataset(Dataset):
  """Dataset over the ``.mp3`` files found directly inside one directory.

  Every item is decoded, resampled to 16 kHz, peak-normalized, mu-law
  companded to 256 levels and one-hot encoded, giving a float32 tensor with
  one row per audio sample and 256 columns.
  """

  def __init__(self, directory, transform=None):
    """
      directory   :   Folder that is scanned (non-recursively) for ``.mp3`` files
      transform   :   Stored on the instance only; ``__getitem__`` does not apply it
    """
    self.directory = directory
    self.transform = transform
    # Hidden files (leading ".") are skipped, e.g. macOS resource forks.
    self.file_names = [f for f in os.listdir(directory) if f.endswith('.mp3') and not f.startswith(".")]
    print(f"File names: {self.file_names}")

  def __len__(self):
    # len() is queried repeatedly (split sizes, DataLoader, progress bars),
    # so this must stay silent instead of printing on every call.
    return len(self.file_names)

  def __getitem__(self, idx):
    """Load track ``idx`` and return its one-hot mu-law encoding (float32)."""
    file_path = os.path.join(self.directory, self.file_names[idx])
    audio, sr = librosa.load(file_path, sr=None)  # sr=None keeps the file's native rate

    # Resample
    if sr != 16000:
      audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)

    audio_tensor = torch.tensor(audio, dtype=torch.float32) # cast to tensor

    # Peak-normalize to [-1, 1] after resampling (resampling can change the
    # peak). A silent or empty clip has no peak to divide by: dividing by
    # zero would produce NaNs that cannot be quantized, so it is left as-is.
    peak = torch.max(torch.abs(audio_tensor)) if audio_tensor.numel() > 0 else 0
    if peak > 0:
      audio_tensor = torch.div(audio_tensor, peak)

    # Companding transforms
    encoder = torchaudio.transforms.MuLawEncoding(quantization_channels=256)
    audio_tensor = encoder(audio_tensor) # compand -> integer classes 0..255
    audio_tensor = torch.nn.functional.one_hot(audio_tensor, num_classes=256).to(torch.float32)

    return audio_tensor


#Instantiate dataset
dataset = FMADataset(fma_small_path)

#Split for train, valid, and test (80% / 10% / 10%; the test split absorbs any remainder)
num_tracks = len(dataset)
train_size = int(num_tracks * 0.8)
valid_size = (num_tracks - train_size) // 2
test_size = num_tracks - train_size - valid_size

# Seed the split: training is resumed from checkpoints across sessions, and an
# unseeded split would hand previously held-out tracks to the training set on
# the next run. (Assumes the directory listing order is the same between runs.)
split_generator = torch.Generator().manual_seed(0)
train_dataset, valid_dataset, test_dataset = random_split(
    dataset, [train_size, valid_size, test_size], generator=split_generator)

#Instantiate dataloaders. Dimensions are {batch_size, length, channels}
# batch_size stays 1: the tracks are not cropped to a common length, so they cannot be stacked.
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
print(f"train size: {len(train_dataloader)}")
valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)

# debug
# debug_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=False)


File names: ['000002.mp3', '000005.mp3', '000010.mp3', '000140.mp3', '000141.mp3', '000148.mp3', '000182.mp3', '000190.mp3', '000193.mp3', '000194.mp3', '000197.mp3', '000200.mp3', '000203.mp3', '000204.mp3', '000207.mp3', '000210.mp3', '000211.mp3', '000212.mp3', '000213.mp3', '000255.mp3', '000256.mp3', '000368.mp3', '000424.mp3', '000459.mp3', '000534.mp3', '000540.mp3', '000546.mp3', '000574.mp3', '000602.mp3', '000615.mp3', '000620.mp3', '000621.mp3', '000625.mp3', '000666.mp3', '000667.mp3', '000676.mp3', '000690.mp3', '000694.mp3', '000695.mp3', '000704.mp3', '000705.mp3', '000706.mp3', '000707.mp3', '000708.mp3', '000709.mp3', '000714.mp3', '000715.mp3', '000716.mp3', '000718.mp3', '000777.mp3', '000814.mp3', '000821.mp3', '000822.mp3', '000825.mp3', '000853.mp3', '000890.mp3', '000892.mp3', '000897.mp3', '000993.mp3', '000995.mp3', '000997.mp3', '000998.mp3', '001039.mp3', '001040.mp3', '001066.mp3', '001069.mp3', '001073.mp3', '001075.mp3', '001082.mp3', '001083.mp3', '001087

## WaveNet Building Blocks

In [None]:
class myCausalConv1d(torch.nn.Module):
  """1x1 input convolution whose output is delayed by one sample.

  Output position t only depends on input position t-1 (position 0 is all
  zeros), so the network never sees the sample it is asked to predict.
  """

  def __init__(self, in_channels, out_channels, device=None):
    '''
      in_channels   :   Number of features in the input signal
                        ex: a color image -> 3 in_channels (R,G,B)
                        ex: a black and white image -> 1 in_channel (black)

      out_channels  :   Number of channels produced by the convolution

      device        :   Where the convolution lives. Defaults to 'cuda:0'
                        when CUDA is available and to 'cpu' otherwise.
    '''

    super(myCausalConv1d, self).__init__()

    if device is None:
      device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # kernel_size = 1 with one zero of padding on each side: the output is
    # two samples longer than the input, [0, W*x0, ..., W*x(L-1), 0].
    # bias=False keeps the padded positions exactly zero.
    self.conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=1, bias=False).to(device)

    torch.nn.init.xavier_normal_(self.conv.weight, gain=1)  # init weights

  def forward(self, x):
    # x : [N, C_in, L], expected on the same device as the convolution
    output = self.conv(x.float())

    # Dropping the last two positions leaves [0, W*x0, ..., W*x(L-2)]:
    # the model doesn't use the current sample when predicting the current sample
    return output[:,:,:-2]    # [N, C_out, L]


class myDilatedConv1d(torch.nn.Module):
  """Causal dilated convolution with a two-tap kernel and no bias.

  Output position t combines input positions t - dilation and t, and the
  output has the same length as the input.
  """

  def __init__(self, channels, device, dilation=1):
    super().__init__()

    # With kernel_size = 2 the convolution spans (2 - 1) * dilation extra
    # samples, so that is how much padding is needed.
    self.pad = dilation

    conv = torch.nn.Conv1d(channels, channels, kernel_size=2, stride=1, dilation=dilation, padding=self.pad, bias=False)
    self.conv = conv.to(device)

    # Xavier-normal weights replace the default Conv1d initialization
    torch.nn.init.xavier_normal_(self.conv.weight, gain=1)

  def forward(self, x):
    padded_out = self.conv(x)

    # Conv1d pads both ends; discarding the trailing `pad` positions keeps
    # only outputs that look backwards in time.
    return padded_out[..., :-self.pad]   # [N, C, L]


class myResidualBlock(torch.nn.Module):
  """One gated WaveNet layer: dilated conv -> tanh * sigmoid gate -> 1x1 residual and skip convs.

  All sub-layers are created on `device`, which is also remembered on the
  instance so callers can move inputs there before calling the block.
  """

  def __init__(self, residual_channels, skip_channels, dilation, device):
    super().__init__()

    self.device = device

    self.dilated_conv = myDilatedConv1d(residual_channels, device, dilation=dilation).to(device)
    self.residual_conv = torch.nn.Conv1d(residual_channels, residual_channels, kernel_size=1).to(device)
    self.skip_conv = torch.nn.Conv1d(residual_channels, skip_channels, kernel_size=1).to(device)

    self.gate_tanh = torch.nn.Tanh()
    self.gate_sig = torch.nn.Sigmoid()

  def forward(self, x, skip_size):
    """
      x          :   [N, C, L] input, on this block's device
      skip_size  :   number of trailing time steps kept in the skip output

      returns    :   (residual output, skip output), both [N, C, L]-shaped
    """
    features = self.dilated_conv(x)

    # Both gate branches read the same dilated features
    gated = self.gate_tanh(features) * self.gate_sig(features)

    # Skip path, trimmed to the last `skip_size` steps
    skip = self.skip_conv(gated)[:, :, -skip_size:]

    # Residual path: add back the matching tail of the input
    residual = self.residual_conv(gated)
    output = residual + x[:, :, -residual.size(2):]

    return output, skip   # [N, C, L]


class myResidualStack(torch.nn.Module):
  """Chain of `stack_size` x `layer_size` residual blocks with dilations 1, 2, 4, ...

  The blocks are held in a ``torch.nn.ModuleList`` so that their weights are
  registered with the module: they show up in ``parameters()`` (and are
  therefore seen by an optimizer built from the model) and in ``state_dict()``.

  NOTE: a state dict saved while the blocks were kept in a plain list has no
  ``residual_blocks.*`` entries; load such a file with ``strict=False`` and
  restore the blocks from their own files.
  """

  def __init__(self, layer_size, stack_size, residual_channels, skip_channels):
    super(myResidualStack, self).__init__()

    self.layer_size = layer_size    # 10 = layer[dilation=1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
    self.stack_size = stack_size    # 5 = stack[layer1, layer2, layer3, layer4, layer5]

    self.residual_blocks = self.stack_blocks(residual_channels, skip_channels)


  def stack_blocks(self, residual_chan, skip_chan):
    """Build every block, spreading them round-robin over the available devices."""
    # At most the first two GPUs are used (the placement this model was
    # written for); one GPU or a CPU-only machine still works.
    devices = [f'cuda:{i}' for i in range(min(torch.cuda.device_count(), 2))] or ['cpu']

    blocks = [
      self.make_block(residual_chan, skip_chan, dilation, devices[position % len(devices)])
      for position, dilation in enumerate(self.make_dilations())
    ]

    # ModuleList (not a plain list) so the block parameters are registered.
    return torch.nn.ModuleList(blocks)

  def make_dilations(self):
    """Return the dilation of every block in order: 1, 2, 4, ... repeated once per stack."""
    return [2 ** l for _ in range(self.stack_size) for l in range(self.layer_size)]

  def make_block(self, residual_chan, skip_chan, dilation, device):
    block = myResidualBlock(residual_chan, skip_chan, dilation, device)
    return block


  def forward(self, x, skip_size):
    """
      x          :   [N, C, L] output of the input convolution
      skip_size  :   number of trailing time steps every skip output is cut to

      returns    :   [K, N, C, L] skip outputs of all K blocks, on x's device
    """
    skip_connections = []
    output = x
    gather_device = x.device   # all skips are collected where the input lives

    for block in self.residual_blocks:
      # Use where the weights actually are; `block.device` is only the
      # device requested at construction and goes stale after a `.to()`.
      first_param = next(block.parameters(), None)
      block_device = first_param.device if first_param is not None else block.device

      output, skip = block(output.to(block_device), skip_size)
      skip_connections.append(skip.to(gather_device))

    return torch.stack(skip_connections)  # [K, N, C, L]


class myOutConv(torch.nn.Module):
  """Output head: ReLU -> 1x1 conv -> ReLU -> 1x1 conv, returning raw scores.

  A Softmax layer is kept as an attribute but is not applied in `forward`.
  """

  def __init__(self, skip_channels, out_channels):
    super().__init__()

    # 1x1 convolutions
    self.conv1 = torch.nn.Conv1d(skip_channels, skip_channels, kernel_size=1)
    self.conv2 = torch.nn.Conv1d(skip_channels, out_channels, kernel_size=1)

    self.relu = torch.nn.ReLU()
    self.softmax = torch.nn.Softmax(dim=1)

    # Xavier-normal weights and zero biases for both convolutions
    for conv in (self.conv1, self.conv2):
      torch.nn.init.xavier_normal_(conv.weight, gain=1)
      conv.bias.data.fill_(0)

  def forward(self, x):
    hidden = self.conv1(self.relu(x))
    scores = self.conv2(self.relu(hidden))

    # Softmax is intentionally left off (it was switched off for debugging)
    return scores   # [N, C, L]


## WaveNet Model

In [None]:
class myWaveNet(torch.nn.Module):
  """WaveNet-style model: shifted input conv -> residual stack -> summed skips -> output head.

  `forward` takes one-hot audio shaped [N, L, C] and returns unnormalized
  class scores of the same shape.
  """

  def __init__(self, layer_size=8, stack_size=4, in_channels=256, residual_channels=32, skip_channels=32, device=None):
    '''
      layer_size         :   Number of dilated layers per stack (dilations 1 .. 2**(layer_size-1))
      stack_size         :   Number of times the dilation cycle is repeated
      in_channels        :   Number of input (and output) classes
      residual_channels  :   Channels carried through the residual path
      skip_channels      :   Channels of every skip output
      device             :   Device of the input and output convolutions. Defaults
                             to 'cuda:0' when CUDA is available, else 'cpu'.
    '''
    super(myWaveNet, self).__init__()

    if device is None:
      device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    self.stack_size = stack_size
    self.layer_size = layer_size
    self.residual_channels = residual_channels
    # Sum of all dilations over every stack
    self.receptive_fields = np.sum([2 ** i for i in range(layer_size)] * self.stack_size)

    self.first_conv = myCausalConv1d(in_channels, residual_channels).to(device)
    self.residual_stack = myResidualStack(layer_size, stack_size, residual_channels, skip_channels)
    self.final_conv = myOutConv(skip_channels, in_channels).to(device)

  def forward(self, x):
    # Inputs are moved to wherever the first/last convolutions currently
    # live, so the caller does not have to know the model's placement.
    in_device = next(self.first_conv.parameters()).device
    out_device = next(self.final_conv.parameters()).device

    x = x.to(in_device).transpose(1, 2)   # [N, L, C] -> [N, C, L]
    size = int(x.size(2))                 # every skip output is cut to this length
    x = self.first_conv(x)  # [N, C, L]

    skip_connections = self.residual_stack(x, size)   # [K, N, C, L]
    output = torch.sum(skip_connections, dim=0)   # [N, C, L]

    output = self.final_conv(output.to(out_device))    # [N, C, L]
    output = output.transpose(1, 2)  # [N, L, C]

    return output


## Testing

In [None]:
# Smoke test: push one track through a saved model and write the argmax
# reconstruction to a wav file.
print("Building Model...")
model = myWaveNet()


model_path = r"C:\Users\ihargrav\Desktop\checkpoints\savedModel_fmaSmall_best.pt"
checkpoint = torch.load(model_path)

model.load_state_dict(checkpoint['state_dict'])

# Not used by this cell; kept because other cells read these notebook globals.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

print("Loading audio...")
audio_raw, sr = librosa.load(r"C:\Users\ihargrav\Desktop\fma_small\000002.mp3")  # TODO: get actual mp3 path
audio = librosa.util.normalize(audio_raw)

# Resample
if sr != 16000:
  audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
audio = np.pad(audio, (0, max(0, 16000*30 - len(audio))), mode='constant') # cut to 30s and pad with 0s
if audio.size > 16000*30:
  audio = audio[:(16000*30)]

audio_tensor = torch.tensor(audio, dtype=torch.float32) # cast to tensor
audio_tensor = audio_tensor.unsqueeze(0)

# Companding transforms
audio_tensor = torch.div(audio_tensor, torch.max(torch.abs(audio_tensor))) # normalize again

# compand
transform = torchaudio.transforms.MuLawEncoding(quantization_channels=256)
audio_tensor = transform(audio_tensor)

# one hot
audio_tensor = torch.nn.functional.one_hot(audio_tensor, num_classes=256).to(torch.float32)
# print(f"audio tensor: {audio_tensor.shape}")
# The model's input convolution lives on cuda:0 (same literal as the other
# cells); the notebook-level `device` name is not guaranteed to exist.
audio_tensor = audio_tensor.to('cuda:0')

target = torch.clone(audio_tensor)
target = target.transpose(1, 2)

minValLoss = np.inf
loss_list = []
loss_txt_path = r"C:\Users\ihargrav\Desktop\checkpoints\loss.txt"

# Inference only: eval mode, and no autograd graph for a 480k-sample sequence.
model.eval()
with torch.no_grad():
  output = model(audio_tensor)

decode = torchaudio.transforms.MuLawDecoding(quantization_channels=256)

# print(f"sequence: {sequence.shape}")
sequence_ = torch.argmax(output, axis=2)   # most likely class per time step
# print(f"sequence_: {sequence_.shape}")
sequence_ = sequence_.squeeze()
print(f"sequence_: {sequence_.shape}")

audio = decode(sequence_)

audio = audio.to("cpu")

# print(audio.numpy())
print(audio)

# soundfile works on NumPy arrays, so hand it one explicitly.
sf.write(r"C:\Users\ihargrav\Desktop\audio_gen\test.wav", audio.numpy(), 16000)

Building Model...


NameError: ignored

## Generation Model

In [None]:
class Generator(torch.nn.Module):
  """Sample-by-sample generator that replays a trained myWaveNet one step at a time.

  Every residual block gets a queue holding the last `dilation` inputs of
  that block, so each new sample costs one small matrix product per block
  instead of a convolution over the whole history. The per-step arithmetic
  mirrors the blocks' forward pass: dilated conv, tanh * sigmoid gate,
  1x1 residual conv plus input, and 1x1 skip conv summed over all blocks.

  The block weights are copied to `device` when the generator is built, so
  build it after the trained weights have been loaded into `model`.
  """

  def __init__(self, model, device, batch_size=1, channels=256):
    '''
      model       :   trained network providing `first_conv`, `residual_stack`,
                      `final_conv`, `stack_size`, `layer_size`, `residual_channels`
      device      :   device the queues and copied weights are placed on
      batch_size  :   number of sequences generated in parallel
      channels    :   number of quantization classes (one-hot width)
    '''
    super(Generator, self).__init__()

    self.model = model
    self.residual_stack = model.get_submodule("residual_stack").residual_blocks
    self.channels = channels
    self.batch_size = batch_size
    self.stack_size = model.stack_size
    self.layer_size = model.layer_size

    self.device = device

    self.in_conv = self.model.get_submodule("first_conv").get_submodule("conv")
    self.out_conv = self.model.get_submodule("final_conv")


    self.queue_list = []    # {input, layer1, layer2, ...} == {[1], [2], [4], [8], ...}
    self.w_r = []           # kernel tap 0: applied to the delayed sample taken from the queue
    self.w_i = []           # kernel tap 1: applied to the current input of the layer
    self.w_res = []         # 1x1 residual conv weights  [C, C]
    self.b_res = []         # 1x1 residual conv biases   [C] or None
    self.w_skip = []        # 1x1 skip conv weights      [C_skip, C]
    self.b_skip = []        # 1x1 skip conv biases       [C_skip] or None


    count = 0

    # For every layer...
    for b in range(self.stack_size):
      for l in range(self.layer_size):
        q_len = 2**l        # queue length == dilation size

        # Initialize queue for this layer (zeros == the zero padding seen in training)
        self.queue_list.append(torch.zeros(q_len, batch_size, self.model.residual_channels).to(self.device)) # K-element list: [qL, N, C]

        block = self.residual_stack[count]

        # Recall dilation weights for this layer
        dil_conv = block.get_submodule("dilated_conv").get_submodule("conv")
        w_d = dil_conv.get_parameter("weight").detach().to(self.device)  # [C, C, Kernel] == [32, 32, 2]

        self.w_r.append(w_d[:,:,0])
        self.w_i.append(w_d[:,:,1])

        # Recall the 1x1 residual / skip convolutions. The blocks may sit on
        # different devices, so their weights are copied to self.device.
        for conv_name, weights, biases in (("residual_conv", self.w_res, self.b_res),
                                           ("skip_conv", self.w_skip, self.b_skip)):
          conv = block.get_submodule(conv_name)
          weights.append(conv.weight.detach()[:,:,0].to(self.device))
          biases.append(None if conv.bias is None else conv.bias.detach().to(self.device))

        count += 1


  def push_back(self, queue, input):
    # if    :   q = [1, 2, 3, 4]  and  in = 9
    # then  :   [2, 3, 4, 9]

    # queue :   [qL, N, C]
    # input :   [N, C]

    input = input.unsqueeze(0)

    # Drop the oldest entry, append the newest
    queue = queue[1:, :, :]
    queue = torch.cat((queue, input))

    return queue

  def causal_lin(self, input, res, count, activation = None):
    # input      : [N, C, 1] *where C = 32*, current input of the layer
    # res        : [N, C] *where C = 32*, the layer's input from `dilation` steps ago
    # count      : integer, index of the layer (over all stacks)
    # activation : optional callable applied to the result
    # output     : [N, C, 1] *where C = 32*

    # Get weights
    w_past = self.w_r[count]   # kernel tap 0 -> delayed sample
    w_now = self.w_i[count]    # kernel tap 1 -> current sample

    # Apply weights. Conv weights are stored [C_out, C_in], so a convolution
    # step is x @ W.T, which is exactly what `linear` computes.
    current = input.squeeze(2)
    output = torch.nn.functional.linear(current, w_now) + torch.nn.functional.linear(res, w_past)

    # Non-linear activation
    if activation:
      output = activation(output)

    return output.unsqueeze(2)


  def predict(self, input):
    # input  :  [N, 256, 1] one-hot of the previous sample
    # output :  [N, 256, 1] class probabilities for the next sample

    # A length-1 input through the padded 1x1 input conv gives 3 positions;
    # the middle one is the conv applied to the actual sample.
    x = self.in_conv(input)[:,:,1]          # [N, 256, 1] -> [N, 32]

    skip_sum = None

    # Single pass through the network
    for count in range(len(self.queue_list)):
      delayed = self.queue_list[count][0]   # this layer's input from `dilation` steps ago

      # Dilated convolution followed by the tanh * sigmoid gate
      dilated = self.causal_lin(x.unsqueeze(2), delayed, count, activation=None).squeeze(2)
      gated = torch.tanh(dilated) * torch.sigmoid(dilated)

      # Remember this layer's input for later steps
      self.queue_list[count] = self.push_back(self.queue_list[count], x)

      # Skip path: accumulated over all layers
      skip = torch.nn.functional.linear(gated, self.w_skip[count], self.b_skip[count])
      skip_sum = skip if skip_sum is None else skip_sum + skip

      # Residual path: becomes the input of the next layer
      x = torch.nn.functional.linear(gated, self.w_res[count], self.b_res[count]) + x

    # The output head consumes the summed skip connections
    output = self.out_conv(skip_sum.unsqueeze(2))    # [N, 32, 1] -> [N, 256, 1]
    output = torch.nn.functional.softmax(output, dim=1)

    return output


  def run(self, seq_len, input=None):
    # input   : [N, C, 1] one-hot seed sample where C = 256; defaults to the
    #           middle quantization level
    # seq_len : integer, length of generated sequence
    # returns : [N, C, seq_len] class probabilities for every generated step

    predictions = []

    # Seed sample
    if(input is None):
      input = torch.zeros(self.batch_size, self.channels, 1)  # [N, 256, 1]
      input[:, self.channels // 2, :] = 1

    input = input.to(self.device)

    # Pure inference: never build an autograd graph
    with torch.no_grad():
      # Save first sample
      print(f"\nBeginning sequence...")
      sample = self.predict(input)
      predictions.append(sample)

      # Generate sequence
      print(f"\nGenerating sequence...")
      for s in range(seq_len - 1):
        print(f"\rSample {s+2} / {seq_len}", end='')

        # Feed the most likely class back in as the next one-hot input
        arg_max = torch.argmax(sample, dim=1, keepdim=True).to(self.device)   # [N, 1, 1]
        input = torch.zeros(sample.shape, device=self.device)
        input.scatter_(1, arg_max, 1.0)

        sample = self.predict(input)

        predictions.append(sample)

      print()

      # Each prediction is [N, C, 1]; joining along time gives [N, C, L]
      return torch.cat(predictions, dim=2)

def temperature_sampling(logits, temperature=1.0):
    """
    Adjusts the logits based on the temperature and samples one class per time step.
    :param logits: Logits output from the model. Shape: [1, bit_depth, num_samples].
    :param temperature: Positive temperature controlling randomness. Higher temperature
                        increases randomness.
    :return: Sampled indices (long tensor on the CPU). Shape: [num_samples].
    :raises ValueError: If the temperature is not positive or logits is not
                        a 3-D tensor with a batch dimension of 1.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    if logits.dim() != 3 or logits.shape[0] != 1:
        raise ValueError(
            f"logits must have shape [1, bit_depth, num_samples], got {tuple(logits.shape)}")

    # Nothing to sample from an empty sequence
    if logits.shape[2] == 0:
        return torch.zeros(0, dtype=torch.long)

    # One row of class scores per time step: [num_samples, bit_depth]
    per_step = logits[0].transpose(0, 1)

    if temperature != 1.0:
        # Adjust logits by the temperature
        per_step = per_step / temperature

    probabilities = torch.nn.functional.softmax(per_step, dim=1)

    # A single batched categorical draw covers every time step at once
    sample_dist = torch.distributions.Categorical(probs=probabilities)
    sampled_indices = sample_dist.sample()

    return sampled_indices.to(device='cpu', dtype=torch.long)

## One Song Experiment (training)

In [None]:
# Overfit the model on a single 30 s clip, checkpointing after every epoch.
print("Building Model...")
model = myWaveNet(stack_size=4, layer_size=8)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

print("Loading audio...")
audio_raw, sr = librosa.load(r"C:\Users\ihargrav\Desktop\piano.mp3")  # TODO: get actual mp3 path
audio = librosa.util.normalize(audio_raw)

# Resample
if sr != 16000:
  audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
audio = np.pad(audio, (0, max(0, 16000*30 - len(audio))), mode='constant') # cut to 30s and pad with 0s
if audio.size > 16000*30:
  audio = audio[:(16000*30)]

audio_tensor = torch.tensor(audio, dtype=torch.float32) # cast to tensor
audio_tensor = audio_tensor.unsqueeze(0)

# Companding transforms
audio_tensor = torch.div(audio_tensor, torch.max(torch.abs(audio_tensor))) # normalize again

# compand
transform = torchaudio.transforms.MuLawEncoding(quantization_channels=256)
audio_tensor = transform(audio_tensor)

# one hot
audio_tensor = torch.nn.functional.one_hot(audio_tensor, num_classes=256).to(torch.float32)
print(f"audio tensor: {audio_tensor.shape}")
audio_tensor = audio_tensor.to('cuda:0')


# The one-hot input doubles as the target, laid out [N, C, L] for the loss
target = torch.clone(audio_tensor)
target = target.transpose(1, 2)

minValLoss = np.inf
loss_list = []
loss_txt_path = r"C:\Users\ihargrav\Desktop\checkpoints\loss.txt"
model.train()

# The loss log is opened once and flushed every epoch, instead of rewriting
# the whole (growing) history on every epoch.
with open(loss_txt_path, 'w') as loss_file:
  # Training loop
  for epoch in range(10000):

    print(f"epoch: {epoch}")

    output = model(audio_tensor)
    print("finished forward pass")
    output = output.transpose(1,2)   # [N, L, C] -> [N, C, L]

    loss = criterion(output, target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # scheduler.step()

    print(f"loss: {loss}")

    # Plain float from here on: a stored loss tensor would stay on the GPU
    # and be pickled into every checkpoint.
    loss_value = loss.item()

    # Save model (last)
    checkpoint = {
        'state_dict': model.state_dict(),
        'minValLoss': minValLoss
    }
    torch.save(checkpoint, r'C:\Users\ihargrav\Desktop\checkpoints\full_model\ioConv_last.pt')

    for count, block in enumerate(model.get_submodule("residual_stack").residual_blocks):
      path = (f"C:/Users/ihargrav/Desktop/checkpoints/full_model/b{count}_last.pt")
      torch.save(block.state_dict(), path)

    # Save model (best)
    # TODO: write real validation method for whole dataset
    if loss_value < minValLoss:
      minValLoss = loss_value
      checkpoint = {
        'state_dict': model.state_dict(),
        'minValLoss': minValLoss
      }
      torch.save(checkpoint, r'C:\Users\ihargrav\Desktop\checkpoints\full_model\ioConv_best.pt')

      for count, block in enumerate(model.get_submodule("residual_stack").residual_blocks):
        path = (f"C:/Users/ihargrav/Desktop/checkpoints/full_model/b{count}_best.pt")
        torch.save(block.state_dict(), path)

    loss_list.append(loss_value)
    loss_file.write("%s\n" % loss_value)
    loss_file.flush()   # keep the file current if the run is interrupted



    # t.set_description(f"epoch : {epoch}, loss {loss}")

print("\nModel Finished")
print(f"output: {output.shape}")


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
loss: 5.204665660858154
epoch: 10
finished forward pass
loss: 5.176656246185303
epoch: 11
finished forward pass
loss: 5.145015716552734
epoch: 12
finished forward pass
loss: 5.109809875488281
epoch: 13
finished forward pass
loss: 5.071117401123047
epoch: 14
finished forward pass
loss: 5.028620719909668
epoch: 15
finished forward pass
loss: 4.982072830200195
epoch: 16
finished forward pass
loss: 4.931706428527832
epoch: 17
finished forward pass
loss: 4.878355026245117
epoch: 18
finished forward pass
loss: 4.823121070861816
epoch: 19
finished forward pass
loss: 4.766985893249512
epoch: 20
finished forward pass
loss: 4.710554122924805
epoch: 21
finished forward pass
loss: 4.654173374176025
epoch: 22
finished forward pass
loss: 4.598211288452148
epoch: 23
finished forward pass
loss: 4.543301105499268
epoch: 24
finished forward pass
loss: 4.490271091461182
epoch: 25
finished forward pass
loss: 4.439911842346191
epoch: 26
finis

KeyboardInterrupt: ignored

## One Song Experiment (generation)

In [None]:
# Restore the best checkpoint of the one-song experiment
model_path = "C:/Users/ihargrav/Desktop/checkpoints/full_model/ioConv_best.pt"
checkpoint = torch.load(model_path)

model_ref = myWaveNet(stack_size=4, layer_size=8)

# The main checkpoint carries the first and last convolutions
model_ref.load_state_dict(checkpoint['state_dict'])

# Every residual block is restored from its own file
blocks_this = model_ref.get_submodule("residual_stack").residual_blocks

for idx, block in enumerate(blocks_this):
  path = (f"C:/Users/ihargrav/Desktop/checkpoints/full_model/b{idx}_best.pt")
  blocks_ref_dict = torch.load(path)
  block.load_state_dict(blocks_ref_dict)

model_ref.eval()

# Wrap the trained model in the step-by-step generator
generator = Generator(model_ref, device='cuda:0').to('cuda:0')

# One second of audio at 16 kHz, as per-step class probabilities
sequence = generator.run(16000)
print(sequence.shape)

# Temperature sampling is available but currently switched off
temperature = 1.0
# sequence_ = temperature_sampling(sequence, temperature=temperature)

# Greedy decoding: most likely class at every step
sequence_ = torch.argmax(sequence, axis=1).squeeze()

# Mu-law classes back to a waveform, then off the GPU for writing
decode = torchaudio.transforms.MuLawDecoding(quantization_channels=256)
audio = decode(sequence_).to("cpu")

# print(audio)

sf.write(r"C:\Users\ihargrav\Desktop\audio_gen\test_generated.wav", audio, 16000)

NameError: ignored

In [None]:
## GET AUDIO DEBUG
# Benchmark: run a real track through a trained model (teacher forcing) and
# write the argmax reconstruction to disk.

# Load Model
model_path = "C:/Users/ihargrav/Desktop/checkpoints/demo_singSong_longField/demo_singSong_longField_full_model/ioConv_best.pt"
checkpoint = torch.load(model_path)

model_test = myWaveNet(stack_size=1, layer_size=24)
model_test.load_state_dict(checkpoint['state_dict'])    # Load first and last convolutions!!!

# Load residual blocks!!!
blocks_this = model_test.get_submodule("residual_stack").residual_blocks

for idx, block in enumerate(blocks_this):
  path = (f"C:/Users/ihargrav/Desktop/checkpoints/demo_singSong_longField/demo_singSong_longField_full_model/b{idx}_best.pt")
  blocks_ref_dict = torch.load(path)
  block.load_state_dict(blocks_ref_dict)


model_test.eval()


# Load audio
print("Loading audio...")
audio_raw, sr = librosa.load(r"C:\Users\ihargrav\Desktop\fma_small\000002.mp3")
audio = librosa.util.normalize(audio_raw)

# Resample
if sr != 16000:
  audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)
audio = np.pad(audio, (0, max(0, 16000*30 - len(audio))), mode='constant') # cut to 30s and pad with 0s
if audio.size > 16000*30:
  audio = audio[:(16000*30)]

audio_tensor = torch.tensor(audio, dtype=torch.float32) # cast to tensor
audio_tensor = audio_tensor.unsqueeze(0)

# Companding transforms
audio_tensor = torch.div(audio_tensor, torch.max(torch.abs(audio_tensor))) # normalize again

# compand
transform = torchaudio.transforms.MuLawEncoding(quantization_channels=256)
audio_tensor = transform(audio_tensor)

# one hot
audio_tensor = torch.nn.functional.one_hot(audio_tensor, num_classes=256).to(torch.float32)
audio_tensor = audio_tensor.to('cuda:0')

# Inference only: without no_grad every layer's activations for all
# 480k samples are kept alive for a backward pass that never happens.
with torch.no_grad():
  output = model_test(audio_tensor)

# Convert to audio
decode = torchaudio.transforms.MuLawDecoding(quantization_channels=256)

# print(f"sequence: {sequence.shape}")
sequence_ = torch.argmax(output, axis=2)   # most likely class per time step
# print(f"sequence_: {sequence_.shape}")
sequence_ = sequence_.squeeze()
print(f"sequence_: {sequence_.shape}")

audio = decode(sequence_)

audio = audio.to("cpu")
print(audio)

# soundfile works on NumPy arrays, so hand it one explicitly.
sf.write(r"C:\Users\ihargrav\Desktop\audio_gen\test_benchmark.wav", audio.numpy(), 16000)



## FMA Training

In [None]:
# Train on the FMA training split, one track per step, checkpointing every step.
model = myWaveNet(stack_size=4, layer_size=8)

# COMMENT OUT TO START FRESH, UNCOMMENT TO PICK UP ===========================
model_path = "C:/Users/ihargrav/Desktop/checkpoints/full_model/ioConv_best.pt"
checkpoint = torch.load(model_path)
print('loaded model path')
# Load first and last convolutions
model.load_state_dict(checkpoint['state_dict'])

# Load residual blocks
blocks_this = model.get_submodule("residual_stack").residual_blocks

for idx, block in enumerate(blocks_this):
  path = (f"C:/Users/ihargrav/Desktop/checkpoints/full_model/b{idx}_best.pt")
  blocks_ref_dict = torch.load(path)
  block.load_state_dict(blocks_ref_dict)
print('got residual blocks')
# COMMENT OUT TO START FRESH, UNCOMMENT TO PICK UP ===========================

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

minValLoss = np.inf
loss_list = []
loss_txt_path = r"C:\Users\ihargrav\Desktop\checkpoints\loss_fmaSmall.txt"
model.train()

## TRAINING LOOP

# The loss log is opened once and flushed every step, instead of rewriting
# the whole (growing) history on every step.
with open(loss_txt_path, 'w') as loss_file:
  for epoch in range(2000):
    iterator = iter(train_dataloader)
    with trange(len(train_dataloader)) as t:
      for idx in t:
        try:
          # Load sample and target
          sample = next(iterator)
          sample = sample.to('cuda:0')
          target = torch.clone(sample).transpose(1, 2)   # [N, L, C] -> [N, C, L]

          # Forward pass
          output = model(sample)
          output = output.transpose(1,2)

          # Backprop
          loss = criterion(output, target)

          optimizer.zero_grad()
          loss.backward()
          optimizer.step()

          # Plain float from here on: a stored loss tensor would stay on the
          # GPU and be pickled into every checkpoint.
          loss_value = loss.item()

          # Save model (last)
          checkpoint = {
            'state_dict': model.state_dict(),
            'minValLoss': minValLoss
          }
          torch.save(checkpoint, 'C:/Users/ihargrav/Desktop/checkpoints/full_model/ioConv_last.pt')

          for count, block in enumerate(model.get_submodule("residual_stack").residual_blocks):
            path = (f"C:/Users/ihargrav/Desktop/checkpoints/full_model/b{count}_last.pt")
            torch.save(block.state_dict(), path)

          # save model (best)
          if loss_value < minValLoss:
            minValLoss = loss_value
            checkpoint = {
              'state_dict': model.state_dict(),
              'minValLoss': minValLoss
            }
            torch.save(checkpoint, r'C:\Users\ihargrav\Desktop\checkpoints\full_model\ioConv_best.pt')

            for count, block in enumerate(model.get_submodule("residual_stack").residual_blocks):
              path = (f"C:/Users/ihargrav/Desktop/checkpoints/full_model/b{count}_best.pt")
              torch.save(block.state_dict(), path)

          # Update loss.txt
          loss_list.append(loss_value)
          loss_file.write("%s\n" % loss_value)
          loss_file.flush()

          t.set_description(f"epoch: {epoch}, loss: {loss}")
        except Exception as err:
          # Best effort: one bad track (or step) must not end a multi-day
          # run, but the failure is reported instead of vanishing. Catching
          # Exception (not a bare except) lets Ctrl-C still stop the loop.
          tqdm.write(f"epoch {epoch}, batch {idx} skipped: {err!r}")
          continue
print("\nModel Finished")
print(f"output: {output.shape}")

loaded model path
got residual blocks


  audio, sr = librosa.load(file_path, sr=None)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)
epoch: 0, loss: 3.3234918117523193: 100%|████████████████████████████████████████| 6400/6400 [4:08:47<00:00,  2.33s/it]
epoch: 1, loss: 3.2723100185394287: 100%|████████████████████████████████████████| 6400/6400 [4:12:13<00:00,  2.36s/it]
epoch: 2, loss: 4.3123884201049805: 100%|████████████████████████████████████████| 6400/6400 [4:10:33<00:00,  2.35s/it]
epoch: 3, loss: 3.4535350799560547: 100%|████████████████████████████████████████| 6400/6400 [4:12:12<00:00,  2.36s/it]
epoch: 4, loss: 4.760231018066406: 100%|█████████████████████████████████████████| 6400/6400 [4:13:45<00:00,  2.38s/it]
epoch: 5, loss: 3.262321949005127: 100%|█████████████████████████████████████████| 6400/6400 [4:17:30<00:00,  2.41s/it]
epoch: 6, loss: 3.8314266204833984: 100%|████████████████████████████████████████|

## FMA Validation

In [None]:
# Load Model
model_path = "C:/Users/ihargrav/Desktop/checkpoints/full_model/ioConv_best.pt"
checkpoint = torch.load(model_path)
criterion = torch.nn.CrossEntropyLoss()
model = myWaveNet(stack_size=4, layer_size=8)

# Load first and last convolutions
model.load_state_dict(checkpoint['state_dict'])

# Load residual blocks
blocks_this = model.get_submodule("residual_stack").residual_blocks

for idx, block in enumerate(blocks_this):
  path = (f"C:/Users/ihargrav/Desktop/checkpoints/full_model/b{idx}_best.pt")
  blocks_ref_dict = torch.load(path)
  block.load_state_dict(blocks_ref_dict)

model.eval()

## VALIDATION LOOP

val_loss_list = []
val_loss_txt_path = r"C:\Users\ihargrav\Desktop\checkpoints\val_loss_fmaSmall.txt"

val_loss = 0   # running sum of the per-track losses

# no_grad: evaluation needs no autograd graph; keeping one for full-length
# tracks is what exhausts GPU memory here.
with torch.no_grad(), open(val_loss_txt_path, 'w') as file:
  for num, sample in enumerate(valid_dataloader):
    # Load sample and target
    sample = sample.to('cuda:0')
    target = torch.clone(sample).transpose(1, 2)   # [N, L, C] -> [N, C, L]

    # Forward pass
    output = model(sample)
    output = output.transpose(1,2)

    # Update loss
    val_loss += criterion(output, target).item()

    print(f"\r\nCumulative loss: {val_loss}", end='')

    val_loss_list.append(float(val_loss))
    file.write("%s\n" % float(val_loss))
    file.flush()

# Mean loss per track; max(..., 1) guards an empty validation split
mean_val_loss = val_loss / max(len(valid_dataloader), 1)
print(f"Loss: {mean_val_loss}")

# Save model (best)
# Compare like with like: the mean per-track loss against the score stored
# with the checkpoint (a per-track loss), not the sum over the whole split.
# Reading the score from the checkpoint also keeps this cell runnable on its own.
minValLoss = float(checkpoint.get('minValLoss', np.inf))
if mean_val_loss < minValLoss:
  minValLoss = mean_val_loss
  checkpoint = {
    'state_dict': model.state_dict(),
    'minValLoss': minValLoss
  }
  torch.save(checkpoint, r'C:\Users\ihargrav\Desktop\checkpoints\full_model\ioConv_best.pt')

  for count, block in enumerate(model.get_submodule("residual_stack").residual_blocks):
    path = (f"C:/Users/ihargrav/Desktop/checkpoints/full_model/b{count}_best.pt")
    torch.save(block.state_dict(), path)


Cumulative loss: 5.331389427185059

OutOfMemoryError: ignored

# Jingle Dataloader

In [None]:
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split

class JingleDataset(Dataset):
  '''
    Dataset of jingle clips labelled with a mood read from the file name.

    Every file directly inside `directory` whose name ends in '.mp3' and does
    not start with '.' is one item. The first three characters of the file
    name are looked up in `mood_enum`; prefixes that are not in the table map
    to -1.

    __getitem__ returns (audio_tensor, mood_tag):
      audio_tensor  :   float32 one-hot tensor of shape
                        [SAMPLE_RATE * CLIP_SECONDS, QUANT_CHANNELS]
      mood_tag      :   int from `mood_enum`, or -1 for an unknown prefix

    directory   :   folder containing the .mp3 files
    transform   :   stored on the instance; __getitem__ does not apply it
  '''

  SAMPLE_RATE = 16000     # every clip is resampled to this rate
  CLIP_SECONDS = 30       # every clip is padded / cut to this duration
  QUANT_CHANNELS = 256    # number of mu-law quantization levels (one-hot width)

  def __init__(self, directory, transform=None):
    self.directory = directory
    self.transform = transform
    self.file_names = [f for f in os.listdir(directory) if f.endswith('.mp3') and not f.startswith(".")]
    print(f"File names: {self.file_names}")

    # Create a mapping from mood names to numerical identifiers
    self.mood_enum = {'upb': 0, 'rlx': 1, 'plf': 2, 'mys': 3, 'mel': 4}

  def __len__(self):
    return len(self.file_names)

  @staticmethod
  def _select_channel(audio):
    # Reduce decoded audio to one channel. Mono input (1-D, or 2-D with a
    # single column) is returned as-is instead of crashing on audio[:, 1];
    # multi-channel input keeps channel index 1, as before.
    if audio.ndim == 1:
      return audio
    if audio.shape[1] == 1:
      return audio[:, 0]
    return audio[:, 1]

  @classmethod
  def _fit_length(cls, audio):
    # Zero-pad at the end or truncate so the clip is exactly
    # SAMPLE_RATE * CLIP_SECONDS samples long.
    target_len = cls.SAMPLE_RATE * cls.CLIP_SECONDS
    if len(audio) < target_len:
      return np.pad(audio, (0, target_len - len(audio)), mode='constant')
    return audio[:target_len]

  @staticmethod
  def _peak_normalize(audio_tensor):
    # Scale to a peak magnitude of 1. An all-zero clip is returned unchanged:
    # dividing by a zero peak would fill the tensor with NaNs.
    peak = torch.max(torch.abs(audio_tensor))
    if peak > 0:
      audio_tensor = torch.div(audio_tensor, peak)
    return audio_tensor

  def __getitem__(self, idx):
    name = self.file_names[idx]
    file_path = os.path.join(self.directory, name)
    audio, sr = sf.read(file_path)
    print(audio.shape)
    audio = self._select_channel(audio)
    print(audio.shape)

    print(f"PATH: {file_path}\nAUDIO: {audio.shape}")

    # GET THE MOOD
    mood_name = name[:3]
    mood_tag = self.mood_enum.get(mood_name, -1)  # Use -1 for unknown moods

    print(f"file name: {self.file_names[idx]}")
    print(f"name = {mood_name}, tag = {mood_tag}")

    # PREPROCESSING
    # normalize() returns a new array, so the result has to be assigned
    audio = librosa.util.normalize(audio)

    # Resample
    if sr != self.SAMPLE_RATE:
      audio = librosa.resample(audio, orig_sr=sr, target_sr=self.SAMPLE_RATE)
    audio = self._fit_length(audio)   # cut to 30s and pad with 0s

    audio_tensor = torch.tensor(audio, dtype=torch.float32) # cast to tensor

    # Companding transforms
    audio_tensor = self._peak_normalize(audio_tensor)
    transform = torchaudio.transforms.MuLawEncoding(quantization_channels=self.QUANT_CHANNELS)
    audio_tensor = transform(audio_tensor) # compand
    audio_tensor = torch.nn.functional.one_hot(audio_tensor, num_classes=self.QUANT_CHANNELS).to(torch.float32)

    return audio_tensor, mood_tag

# Build the jingle dataset from the files under jingle_path
dataset = JingleDataset(jingle_path)

# 80% train; the remainder is halved into validation and test, with test
# receiving the extra item when the remainder is odd
n_total = len(dataset)
train_size = int(n_total * 0.8)
valid_size = (n_total - train_size) // 2
test_size = n_total - train_size - valid_size

split_sizes = [train_size, valid_size, test_size]
train_dataset, valid_dataset, test_dataset = random_split(dataset, split_sizes)

# Full-size loaders, currently disabled in favour of the debug loader below
# train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# valid_dataloader = DataLoader(valid_dataset, batch_size=32, shuffle=False)
# test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)

# Debug loader: one clip per batch, drawn from the training split
debug_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)

# Number of items in the underlying split, then number of batches
print(len(debug_dataloader.dataset))
print(len(debug_dataloader))

print(train_dataset)




# Jingle Building Blocks

In [None]:
class myCausalConv1d(torch.nn.Module):
  '''
    Bias-free 1x1 convolution whose output is shifted one step to the right.

    The convolution pads one zero on each side (kernel_size=1, padding=1), so
    its raw output has two more steps than the input. forward() drops the last
    two, which leaves the left pad position followed by all but the last input
    step: output step t is computed from input step t-1.

    in_channels   :   Number of features in the input signal
    out_channels  :   Number of channels produced by the convolution
  '''

  def __init__(self, in_channels, out_channels):
    super().__init__()

    self.conv = torch.nn.Conv1d(
      in_channels,
      out_channels,
      kernel_size=1,
      stride=1,
      padding=1,
      bias=False,
    ).to(device)

  def forward(self, x):
    # The input is cast to float before the convolution
    padded_out = self.conv(x.float())

    # Drop the two trailing steps so the length matches the input again
    keep = padded_out.size(2) - 2
    return padded_out[:, :, :keep]    # [N, C, L]


class myDilatedConv1d(torch.nn.Module):
  '''
    Bias-free dilated convolution (kernel_size=2) that preserves sequence length.

    The convolution pads `dilation` zeros on both sides; forward() then removes
    the last `dilation` output steps, so only the left padding remains in
    effect and the output has as many steps as the input.

    channels  :   Number of input channels, also used as the number of output channels
    dilation  :   Spacing between the two kernel taps; also the padding size
  '''

  def __init__(self, channels, dilation=1):
    super().__init__()

    # With kernel_size = 2, (kernel_size - 1) * dilation == dilation
    self.pad = dilation
    self.conv = torch.nn.Conv1d(
      channels,
      channels,
      kernel_size=2,
      stride=1,
      dilation=dilation,
      padding=self.pad,
      bias=False,
    ).to(device)

  def forward(self, x):
    padded_out = self.conv(x)

    # Remove the steps produced by the right-hand padding
    trimmed = padded_out[:, :, :-self.pad]
    return trimmed   # [N, C, L]


class myResidualBlock(torch.nn.Module):
  '''
    One gated residual layer: dilated convolution -> tanh * sigmoid gate ->
    two 1x1 convolutions producing the residual output and the skip output.

    residual_channels :   Channels of the block input and of the residual output
    skip_channels     :   Channels of the skip output
    dilation          :   Dilation passed to the dilated convolution
  '''

  def __init__(self, residual_channels, skip_channels, dilation):
    super().__init__()

    self.dilated_conv = myDilatedConv1d(residual_channels, dilation=dilation).to(device)
    self.residual_conv = torch.nn.Conv1d(residual_channels, residual_channels, kernel_size=1).to(device)
    self.skip_conv = torch.nn.Conv1d(residual_channels, skip_channels, kernel_size=1).to(device)

    self.gate_tanh = torch.nn.Tanh()
    self.gate_sig = torch.nn.Sigmoid()

  def forward(self, x, skip_size):
    '''
      x         :   Block input, [N, C, L]
      skip_size :   Number of trailing steps kept in the skip output

      Returns (residual output, skip output), both [N, C, L].
    '''
    # Both gate branches read the same dilated convolution output
    dilated = self.dilated_conv(x)
    gated = self.gate_tanh(dilated) * self.gate_sig(dilated)

    # Skip path: 1x1 convolution, then keep only the last skip_size steps
    skip = self.skip_conv(gated)[:, :, -skip_size:]

    # Residual path: 1x1 convolution plus the matching tail of the input
    residual = self.residual_conv(gated)
    residual += x[:, :, -residual.size(2):]

    return residual, skip   # [N, C, L]


class myResidualStack(torch.nn.Module):
  '''
    Chain of residual blocks with dilations 1, 2, 4, ..., 2**(layer_size - 1),
    repeated stack_size times.

    The blocks are kept in a torch.nn.ModuleList so they are registered as
    submodules. With a plain Python list their weights are missing from
    parameters() (so an optimizer built from model.parameters() never updates
    them), from state_dict() and from .to(...).

    NOTE: state_dict() now contains "residual_blocks.<i>.*" keys. A checkpoint
    saved while the blocks were a plain list lacks them and needs
    load_state_dict(..., strict=False) followed by loading each block.

    layer_size        :   Number of dilation levels per stack
    stack_size        :   Number of times the dilation sequence is repeated
    residual_channels :   Channels flowing through the residual path
    skip_channels     :   Channels of every skip output
  '''

  def __init__(self, layer_size, stack_size, residual_channels, skip_channels):
    super(myResidualStack, self).__init__()

    self.layer_size = layer_size    # e.g. 10 -> dilations 1, 2, 4, 8, 16, 32, 64, 128, 256, 512
    self.stack_size = stack_size    # e.g. 5 -> the dilation sequence above is used 5 times

    self.residual_blocks = self.stack_blocks(residual_channels, skip_channels)

  def stack_blocks(self, residual_chan, skip_chan):
    # One block per dilation, in order. ModuleList supports len(), indexing and
    # iteration just like the list it replaces.
    blocks = [self.make_block(residual_chan, skip_chan, d) for d in self.make_dilations()]
    return torch.nn.ModuleList(blocks)

  def make_dilations(self):
    # 1, 2, 4, ..., 2**(layer_size - 1), repeated stack_size times
    return [2 ** l for _ in range(self.stack_size) for l in range(self.layer_size)]

  def make_block(self, residual_chan, skip_chan, dilation):
    block = myResidualBlock(residual_chan, skip_chan, dilation)
    return block

  def forward(self, x, skip_size):
    '''
      x         :   Input to the first block, [N, C, L]
      skip_size :   Forwarded to every block

      Returns the skip outputs of all K blocks stacked on a new leading axis.
    '''
    skip_connections = []
    output = x

    # Each block consumes the previous block's residual output
    for block in self.residual_blocks:
      output, skip = block(output, skip_size)
      skip_connections.append(skip)

    return torch.stack(skip_connections)  # [K, N, C, L]


class myOutConv(torch.nn.Module):
  '''
    Output head: ReLU -> 1x1 conv -> ReLU -> 1x1 conv -> softmax over channels.

    skip_channels :   Channels of the input and of the intermediate 1x1 convolution
    out_channels  :   Channels of the final output

    NOTE(review): the training and validation cells in this notebook pass this
    softmax output to torch.nn.CrossEntropyLoss. Confirm that applying softmax
    before that loss is intended.
  '''

  def __init__(self, skip_channels, out_channels):
    super().__init__()

    # 1x1 convolutions
    self.conv1 = torch.nn.Conv1d(skip_channels, skip_channels, kernel_size=1).to(device)
    self.conv2 = torch.nn.Conv1d(skip_channels, out_channels, kernel_size=1).to(device)

    self.relu = torch.nn.ReLU()
    self.softmax = torch.nn.Softmax(dim=1)

  def forward(self, x):
    hidden = self.conv1(self.relu(x))
    scores = self.conv2(self.relu(hidden))

    # Normalize across the channel dimension (dim=1)
    return self.softmax(scores)   # [N, C, L]


# Jingle Model

In [None]:
class myWaveNet(torch.nn.Module):
  '''
    WaveNet-style model: causal input convolution -> residual stack -> sum of
    the skip outputs -> output head. An optional global condition (e.g. a mood
    tag) can be supplied per batch item.

    layer_size                    :   Dilation levels per stack (forwarded to myResidualStack)
    stack_size                    :   Number of stacks (forwarded to myResidualStack)
    in_channels                   :   Channels of the input; also the channels of the output
    residual_channels             :   Channels of the residual path
    skip_channels                 :   Channels of the skip path
    global_condition_channels     :   Width of the condition embedding
    global_condition_cardinality  :   Number of distinct condition ids

    Conditioning is enabled only when both global_condition_* arguments are given.
  '''

  def __init__(self, layer_size=10, stack_size=4, in_channels=256, residual_channels=32, skip_channels=32, global_condition_channels=None,global_condition_cardinality=None):
    super(myWaveNet, self).__init__()

    # Sum of all dilations used by the residual stack
    self.receptive_fields = np.sum([2 ** i for i in range(layer_size)] * stack_size)

    self.first_conv = myCausalConv1d(in_channels, residual_channels)
    self.residual_stack = myResidualStack(layer_size, stack_size, residual_channels, skip_channels)
    self.final_conv = myOutConv(skip_channels, in_channels)

    # Global conditioning: an embedding per condition id, projected to the
    # residual width by a 1x1 convolution so it can be added to the features.
    # Both are moved to `device` in the constructor like the other layers.
    if global_condition_channels is not None and global_condition_cardinality is not None:
      self.embedding = torch.nn.Embedding(global_condition_cardinality, global_condition_channels).to(device)
      self.gc_projection = torch.nn.Conv1d(global_condition_channels, residual_channels, kernel_size=1, bias=False).to(device)
    else:
      self.embedding = None
      self.gc_projection = None

  def forward(self, x, mood_tag=None):
    '''
      x         :   Input, [N, L, C]
      mood_tag  :   Optional integer tensor of condition ids, one per batch
                    item ([N]); ignored when conditioning is disabled

      Returns the output of the final convolution, [N, L, C].
    '''
    x = x.transpose(1, 2)   # [N, C, L]

    size = int(x.size(2))

    x = self.first_conv(x)  # [N, C, L]

    # Add the projected condition to every time step. The previous version
    # concatenated the embedding along the time axis, which made the output
    # one step longer than the target and only worked for a 1-wide embedding.
    if self.embedding is not None and mood_tag is not None:
      gc_embedding = self.embedding(mood_tag).unsqueeze(2)   # [N, G, 1]
      x = x + self.gc_projection(gc_embedding)               # broadcast over L

    # Apply residual stack
    skip_connections = self.residual_stack(x, size)   # [K, N, C, L]

    # Sum the skip outputs of all blocks
    output = torch.sum(skip_connections, dim=0)   # [N, C, L]

    output = self.final_conv(output)    # [N, C, L]

    output = output.transpose(1, 2)  # [N, L, C]

    return output


# Jingle Debug Training

In [None]:
print("Building Model...")
model = myWaveNet(global_condition_channels=1,global_condition_cardinality=5)
model.to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)

minValLoss = np.inf
loss_list = []
loss_txt_path = "loss.txt"
model.train()

# Stays None if the dataloader yields no batches, so the final print is safe
output = None

# Training loop
# The loss log is opened once and flushed after every step; it holds one loss
# per line, exactly as when the whole list was rewritten each step.
with open(loss_txt_path, 'w') as file:
  for epoch in range(2):
    print(f"epoch: {epoch + 1}")
    epoch_losses = []

    for num, (sample, tag) in enumerate(debug_dataloader):
      print(f"SAMPLE: {sample}")
      sample = sample.to(device)
      tag = tag.to(device)
      # The model is trained to reproduce its (one-hot) input
      target = torch.clone(sample).transpose(1, 2).to(device)

      print(f"epoch: {epoch+1}, batch number {num} of {len(debug_dataloader)}")

      print(f"Audio Shape: {sample.shape}")
      print(f"Audio: {sample}")
      print(f"Tag Shape: {tag.shape}")
      print(f"Tag: {tag}")

      output = model(sample, mood_tag = tag)
      print("finished forward pass")
      output = output.transpose(1,2)

      loss = criterion(output, target)

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      print(f"loss: {loss}")

      # Keep plain floats only: holding on to the loss tensor would keep its
      # autograd graph alive and put a tensor into the checkpoint
      loss_value = loss.item()
      epoch_losses.append(loss_value)
      loss_list.append(loss_value)
      file.write("%s\n" % loss_value)
      file.flush()

    checkpoint = {
      'state_dict': model.state_dict(),
      'minValLoss': minValLoss
    }
    torch.save(checkpoint, 'savedModel_jingleConditioning_last.pt')

    # TODO: write real validation method for whole dataset
    # Until then the mean training loss of the epoch selects the "best"
    # checkpoint, rather than the loss of whichever batch came last.
    if epoch_losses:
      epoch_loss = sum(epoch_losses) / len(epoch_losses)
      if epoch_loss < minValLoss:
        minValLoss = epoch_loss
        checkpoint = {
          'state_dict': model.state_dict(),
          'minValLoss': minValLoss
        }
        torch.save(checkpoint, 'savedModel_jingleConditioning_best.pt')

print("\nModel Finished")
if output is not None:
  print(f"output: {output.shape}")