In [6]:
import time
from wavenet_model import *
from audio_data import WavenetDataset
from wavenet_training import *
from model_logging import *
from scipy.io import wavfile
import torch
import numpy as np
from model_logging import TensorboardLogger
import torch.nn.functional as F
from tqdm import tqdm
import os

# Module-level switch read by debug_print(); debug output is off while it is falsy.
debug_enabled = False


def debug_print(msg):
    """Print ``msg`` only while the module-level ``debug_enabled`` flag is truthy.

    Args:
        msg: Object handed to ``print`` unchanged.
    """
    if not debug_enabled:
        return
    print(msg)

def set_debug(enabled):
    """Store ``enabled`` in the module-level ``debug_enabled`` flag.

    Args:
        enabled: New value for the flag; stored as given, without conversion.
    """
    globals()["debug_enabled"] = enabled


# Start with debug output switched off.
set_debug(False)

def get_device():
    """Pick the device string to run on, preferring CUDA, then MPS, then CPU.

    Returns:
        str: ``"cuda"`` if CUDA is available, otherwise ``"mps"`` if the MPS
        backend exists and reports itself available, otherwise ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"

    # PyTorch builds without the MPS backend module have no
    # ``torch.backends.mps``; look it up defensively so they fall through
    # to the CPU instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"

    return "cpu"

# Report which accelerator backends this PyTorch build can see (one line each).
print(
    f"MPS available: {torch.backends.mps.is_available()}",
    f"MPS built: {torch.backends.mps.is_built()}",
    f"CUDA available: {torch.cuda.is_available()}",
    sep="\n",
)

# Device string used by the rest of the notebook for model/tensor placement.
device = get_device()
print(f"Using device: {device}")

MPS available: True
MPS built: True
CUDA available: False
Using device: mps


In [7]:
# Pass-through dataset wrapper that prints shape diagnostics
class MemoryDataset(torch.utils.data.Dataset):
    """Thin wrapper around an existing dataset that adds shape diagnostics.

    Despite the name, nothing is copied or cached: every lookup is delegated
    to the wrapped dataset. Two values are captured once at construction:

    * ``length`` - ``len(dataset)`` at wrap time; ``__len__`` returns this
      stored value.
    * ``train`` - the wrapped dataset's ``train`` attribute if it has one,
      else ``True``. It is a plain attribute of the wrapper; assigning to it
      later is not forwarded to the wrapped dataset.
    """

    def __init__(self, dataset):
        """Wrap ``dataset`` and print the type/shape of its first item.

        Args:
            dataset: Object supporting ``len()`` and integer indexing. An
                empty dataset is accepted; the sample printout is skipped.
        """
        self.dataset = dataset
        self.length = len(dataset)
        self.train = getattr(dataset, 'train', True)

        # An empty dataset has no item 0 to inspect; report that instead of
        # failing inside the constructor.
        if self.length == 0:
            print("\nDataset sample info:")
            print("Dataset is empty; no sample item to inspect")
            return

        # Print sample item shape during initialization
        sample_item = self.dataset[0]
        print("\nDataset sample info:")
        print(f"Sample item type: {type(sample_item)}")
        if isinstance(sample_item, tuple):
            print(f"Sample x shape: {sample_item[0].shape}")
            print(f"Sample target shape: {sample_item[1].shape}")
        else:
            print(f"Sample shape: {sample_item.shape}")

    def __getitem__(self, idx):
        """Return ``self.dataset[idx]`` unchanged.

        When the module-level ``debug_enabled`` flag is set, tuple items at
        every 1000th index also have their shapes printed.
        """
        item = self.dataset[idx]
        # Print shape occasionally for debugging
        if idx % 1000 == 0 and debug_enabled:
            if isinstance(item, tuple):
                print(f"\nBatch {idx} shapes:")
                print(f"x: {item[0].shape}")
                print(f"target: {item[1].shape}")
        return item

    def __len__(self):
        """Return the length recorded when the wrapper was constructed."""
        return self.length

In [8]:
# Freshly initialised network; this is what gets used when no snapshot exists.
model = WaveNetModel(layers=6,
                     blocks=4,
                     dilation_channels=16,
                     residual_channels=16,
                     skip_channels=32,
                     output_length=8,
                     bias=False)

# Load latest model if available: only attempt the load when the snapshot
# directory exists and has at least one entry, so a first run keeps the
# freshly constructed model instead of failing inside the loader.
if os.path.isdir('snapshots') and os.listdir('snapshots'):
    # A loaded snapshot replaces the model built above, including its
    # architecture settings.
    model = load_latest_model_from('snapshots', use_cuda=False)
else:
    print("No snapshot found in 'snapshots'; using the freshly initialised model")

# Ensure model is on correct device
model = model.to(device)
print(f"Model device after transfer: {next(model.parameters()).device}")
print(f"Start conv weight device: {model.start_conv.weight.device}")

print('model: ', model)
print('receptive field: ', model.receptive_field)
print('parameter count: ', model.parameter_count())

Initializing WaveNetModel...
Parameters: layers=6, blocks=4, dilation_channels=16, residual_channels=16, skip_channels=32, classes=256, output_length=8
Creating start convolution...
Creating dilated convolution 0...
Creating dilated convolution 1...
Creating dilated convolution 2...
Creating dilated convolution 3...
Creating dilated convolution 4...
Creating dilated convolution 5...
Creating dilated convolution 0...
Creating dilated convolution 1...
Creating dilated convolution 2...
Creating dilated convolution 3...
Creating dilated convolution 4...
Creating dilated convolution 5...
Creating dilated convolution 0...
Creating dilated convolution 1...
Creating dilated convolution 2...
Creating dilated convolution 3...
Creating dilated convolution 4...
Creating dilated convolution 5...
Creating dilated convolution 0...
Creating dilated convolution 1...
Creating dilated convolution 2...
Creating dilated convolution 3...
Creating dilated convolution 4...
Creating dilated convolution 5...
Wa

In [9]:
# Create dataset
# Each item spans receptive_field + output_length - 1 input samples and
# carries output_length targets, both taken from the current `model`.
data = WavenetDataset(dataset_file='train_samples/bach_chaconne/dataset.npz',
                      item_length=model.receptive_field + model.output_length - 1,
                      target_length=model.output_length,
                      file_location='train_samples/bach_chaconne',
                      test_stride=20)

# Load dataset file and print info
# The array stored under 'arr_0' is read from the .npz file and assigned to
# data.data, replacing whatever the constructor put there.
# NOTE(review): the recorded output shows a "Converted data ... int64" line
# followed by "Dataset dtype: uint8" after this assignment - presumably the
# constructor converts the data and this overwrite stores the raw array
# again; confirm that is intended.
print("Loading dataset file:", data.dataset_file)
with np.load(data.dataset_file) as dataset:
    print("Available keys in dataset:", dataset.files)
    data.data = dataset['arr_0']

print('the dataset has ' + str(len(data)) + ' items')
print(f"Dataset type: {type(data.data)}")
print(f"Dataset shape: {data.data.shape}")
print(f"Dataset dtype: {data.data.dtype}")

Raw data shape: (9594672,), dtype: uint8
Converted data shape: (9594672,), dtype: int64
Loading dataset file: train_samples/bach_chaconne/dataset.npz
Available keys in dataset: ['arr_0']
the dataset has 479579 items
Dataset type: <class 'numpy.ndarray'>
Dataset shape: (9594672,)
Dataset dtype: uint8


In [10]:
# Wrap the dataset; MemoryDataset prints one sample's shapes when constructed.
memory_dataset = MemoryDataset(data)

# Hyper-parameters gathered in one place, then handed to the trainer as
# keyword arguments.
trainer_settings = dict(
    batch_size=2,          # JOS: was 8
    val_batch_size=2,      # JOS: was 32
    val_subset_size=500,
    lr=0.001,
    weight_decay=0.0,
    gradient_clipping=1,
    snapshot_interval=1000,
    snapshot_path='snapshots',
    val_interval=1000,
)

# Create trainer with memory_dataset
trainer = WavenetTrainer(model=model, dataset=memory_dataset, **trainer_settings)


Dataset sample info:
Sample item type: <class 'tuple'>
Sample x shape: torch.Size([256, 3085])
Sample target shape: torch.Size([16])

splitting dataset:
  total size: 479579
  train size: 431621
  val size: 47958


NameError: name 'true' is not defined

In [6]:
# Updated generate_audio function
def _mu_law_decode(indices, classes=256):
    """Expand quantised class indices to 16-bit PCM using mu-law expansion.

    Index 0 maps to -32767, index ``classes - 1`` maps to +32767.
    """
    mu = max(classes - 1, 1)
    # Rescale indices 0..mu to the companded range [-1, 1].
    companded = 2.0 * np.asarray(indices, dtype=np.float64) / mu - 1.0
    expanded = np.sign(companded) * np.expm1(np.abs(companded) * np.log1p(mu)) / mu
    return np.round(expanded * 32767.0).astype(np.int16)


def generate_audio(model, length=16000, temperature=1.0):
    """
    Generate audio samples using the trained model, one sample at a time.

    Args:
        model: Trained WaveNet model. Must expose ``receptive_field`` and,
            when called on a ``(1, 1, receptive_field)`` tensor, return logits
            indexable as ``[:, :, -1]`` with the class axis at dim 1.
        length: Number of samples to generate. ``0`` yields an empty array.
        temperature: Controls randomness (higher = more random, lower = more
            deterministic). ``0`` selects the most likely class at every step
            instead of sampling.

    Returns:
        numpy.ndarray of dtype ``int16`` with ``length`` PCM samples.

    Raises:
        ValueError: If ``temperature`` is negative.
    """
    if temperature < 0:
        raise ValueError("temperature must be non-negative")

    model.eval()  # Set to evaluation mode

    # Move model to the notebook-level `device` if not already there
    model = model.to(device)

    # Start with zeros
    # NOTE(review): the context window holds raw class indices in a single
    # channel, while the recorded dataset sample has shape [256, 3085] -
    # confirm the model accepts single-channel input.
    current_sample = torch.zeros(1, 1, model.receptive_field).to(device)
    generated_samples = []
    num_classes = 256  # replaced by the model's actual class count once known

    print(f"\nGenerating {length} samples...")

    with torch.no_grad():
        for step in tqdm(range(length)):
            # Logits for the newest time step only
            logits = model(current_sample)[:, :, -1]
            num_classes = logits.shape[1]

            if temperature == 0:
                # Dividing by zero would turn the logits into inf/nan, so
                # temperature 0 is handled as a greedy arg-max pick.
                next_index = int(torch.argmax(logits, dim=1).item())
            else:
                # Temperature-scaled sampling from the output distribution
                probabilities = F.softmax(logits / temperature, dim=1)
                next_index = int(torch.multinomial(probabilities, 1).item())

            generated_samples.append(next_index)

            # Shift input window and add new sample
            current_sample = torch.roll(current_sample, -1, dims=2)
            current_sample[0, 0, -1] = next_index

    # Scale back to audio range. Subtracting 2**15 from an int16 array either
    # overflows or yields a near-constant offset, so the class indices are
    # decoded instead.
    # NOTE(review): assumes the training data was mu-law quantised, as is
    # conventional for 256-class WaveNet - confirm against the dataset code.
    return _mu_law_decode(np.array(generated_samples, dtype=np.int64), num_classes)

In [7]:
# Updated generate_and_log_samples function
def generate_and_log_samples(step):
    """Generate two short clips from the newest snapshot and log or save them.

    One clip is generated at temperature 0 and one at 0.5 (both values are
    forwarded to ``generate_audio`` as-is). If a module-level ``logger`` exists
    and TensorFlow can be imported, the clips go to ``logger.audio_summary``;
    otherwise they are written as WAV files in the working directory.

    Args:
        step: Training step, used as the logging step and in WAV file names.
    """
    sample_length = 4000
    # Load without CUDA and move to `device` explicitly, matching how the
    # training model is loaded, so this also works on machines without CUDA.
    gen_model = load_latest_model_from('snapshots', use_cuda=False)
    gen_model = gen_model.to(device)

    print("Start generating samples for logging...")

    # (log tag, file-name stem, samples) for each temperature, in order
    clips = [
        ('temperature 0', 'temp0',
         generate_audio(gen_model, length=sample_length, temperature=0)),
        ('temperature 0.5', 'temp05',
         generate_audio(gen_model, length=sample_length, temperature=0.5)),
    ]

    # `logger` is created in a later cell whose failure is tolerated, so it
    # may be missing or None; look it up instead of assuming it exists.
    active_logger = globals().get('logger')

    tf = None
    if active_logger is None:
        print("No logger available, skipping audio logging")
    else:
        try:
            import tensorflow as tf
        except ImportError:
            print("TensorFlow not available, skipping audio logging")

    if tf is not None:
        for tag, _stem, samples in clips:
            tf_samples = tf.convert_to_tensor(samples, dtype=tf.float32)
            active_logger.audio_summary(tag, tf_samples, step, sr=16000)
    else:
        # Save as WAV files instead
        for _tag, stem, samples in clips:
            wavfile.write(f'generated_{stem}_step{step}.wav', 16000, samples)

    print("Audio clips generated")

In [None]:
# Set up TensorboardLogger
# Bind the name first so a failed initialisation leaves `logger` as None
# rather than undefined (or still pointing at a logger from an earlier run).
# NOTE(review): the trainer is constructed in an earlier cell without a logger
# argument - confirm how this logger is meant to be attached to training.
logger = None
try:
    logger = TensorboardLogger(log_interval=200,
                               validation_interval=200,
                               generate_interval=500,
                               generate_function=generate_and_log_samples,
                               log_dir="logs")
    print("TensorboardLogger initialized successfully")
except Exception as e:
    # Logging is optional here: report the problem and carry on without it.
    print(f"Error initializing TensorboardLogger: {e}")
    print("Continuing without logging...")

In [None]:
# Start training with error handling
print('\nStarting training...')
tic = time.time()
try:
    trainer.train(epochs=2)  # JOS: was 20
except Exception as e:
    # Print some context for diagnosis, then let the error propagate.
    print(f"\nError during training:")
    print(f"Error type: {type(e).__name__}")
    print(f"Error message: {str(e)}")
    print(f"Model device: {next(model.parameters()).device}")
    raise
else:
    # Reached only when training finished without raising.
    toc = time.time()
    print('Training took {} seconds.'.format(toc - tic))

In [None]:
# Set to evaluation mode
# These two assignments set a `train` attribute on the wrapper object and on
# whatever the trainer's dataloader holds as its dataset.
# NOTE(review): MemoryDataset keeps `train` as its own plain attribute, so the
# first assignment does not change `data.train` on the wrapped dataset -
# confirm whether `data` itself is meant to be switched as well.
memory_dataset.train = False
trainer.dataloader.dataset.train = False

# len(memory_dataset) is the length recorded when the wrapper was built.
print("Dataloader length:", len(trainer.dataloader))
print("Test dataset length:", len(memory_dataset))
print("Sample length:", data.item_length)

In [None]:
# Generate audio samples
# Seed for generation: take the input part of dataset item 100 and reduce it
# over dim 0 to the index of the maximum at each remaining position.
# (presumably this turns a one-hot [classes, time] input into class indices -
# the recorded sample x shape is [256, 3085]; confirm)
start_data = torch.max(data[100][0], dim=0).indices
print("Starting data shape:", start_data.shape)
print("Starting data sample:", start_data)

In [None]:
# Progress callback for generation
def prog_callback(step, total_steps):
    """Print generation progress as a whole-number percentage.

    Args:
        step: Number of steps completed so far.
        total_steps: Total number of steps; must be non-zero.
    """
    percent_done = (100 * step) // total_steps
    print(str(percent_done) + "% generated")

# Make sure model is in evaluation mode
model.eval()

# Generate audio using the fast method
# Any Exception raised inside the try block - including one from the queue
# loop below - is reported and followed by the slower generate_audio()
# fallback for the same number of samples.
try:
    # Ensure dilated queues are on the correct device
    # NOTE(review): the assignment only rebinds the loop variable `q`; the
    # entries of model.dilated_queues are not replaced, so this has an effect
    # only if `.to()` moves the queue in place - confirm.
    for q in model.dilated_queues:
        q = q.to(device)
        
    generated1 = model.generate_fast(num_samples=160000, 
                                    first_samples=start_data,
                                    progress_callback=prog_callback,
                                    progress_interval=1000,
                                    temperature=1.0)
except Exception as e:
    print(f"Error during fast generation: {e}")
    print("Falling back to standard generation...")
    generated1 = generate_audio(model, length=160000, temperature=1.0)

In [None]:
# Play the generated audio
import IPython.display as ipd

# Last expression of the cell, so the notebook renders the audio player;
# the samples are played back at 16000 Hz.
ipd.Audio(data=generated1, rate=16000)

In [None]:
# Save the generated audio
# Written to the working directory as a 16000 Hz WAV file.
output_file = "generated_audio.wav"
wavfile.write(filename=output_file, rate=16000, data=generated1)
print(f"Generated audio saved to: {output_file}")

In [None]:
# Visualize the generated audio
%matplotlib inline
from matplotlib import pyplot as plt
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.plot(generated1); ax1.set_title('Raw audio signal')
ax2.specgram(generated1, Fs=16000); ax2.set_title('Spectrogram');