In [None]:
# # This Python 3 environment comes with many helpful analytics libraries installed
# # It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# # For example, here's several helpful packages to load

# import numpy as np # linear algebra
# import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# # Input data files are available in the read-only "../input/" directory
# # For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# # You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# # You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install pydub

Audio file conversion script (.wav to mp3)

In [5]:
import os
from pydub import AudioSegment

#Two versions of audio created
input_dir = "/kaggle/input/smalldataset" 
output_dir = "/kaggle/working/mp3_converted"
#output_dir = "/kaggle/working/mp3_converted_192k"

# Desired MP3 bitrate (e.g., '128k', '192k', '320k')
bitrate = "128k"
#bitrate = "192k"

# --- Conversion Logic ---
print(f"Starting conversion of WAV files from: {input_dir}")

# 1. Create the output directory if it doesn't exist
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
    print(f"Created output directory: {output_dir}")

# 2. List all files in the input directory
try:
    files = os.listdir(input_dir)
except FileNotFoundError:
    print(f"ERROR: Input directory not found: {input_dir}")
    print("Please make sure you have attached your dataset to the notebook and the path is correct.")
    files = []

converted_count = 0
# 3. Loop through files and convert the .wav files
for filename in files:
    if filename.lower().endswith(".wav"):
        try:
            # Construct full file paths
            wav_path = os.path.join(input_dir, filename)
            mp3_filename = os.path.splitext(filename)[0] + ".mp3"
            mp3_path = os.path.join(output_dir, mp3_filename)
            
            # Load the WAV file
            audio = AudioSegment.from_wav(wav_path)
            
            # Export as MP3 with specified bitrate
            print(f"Converting {filename} to MP3...")
            audio.export(mp3_path, format="mp3", bitrate=bitrate)
            converted_count += 1
            
        except Exception as e:
            print(f"Could not convert {filename}. Error: {e}")

print(f"\nConversion complete. Converted {converted_count} files.")
print(f"MP3 files are saved in: {output_dir}")

Starting conversion of WAV files from: /kaggle/input/smalldataset
Created output directory: /kaggle/working/mp3_converted
Converting a3.wav to MP3...
Converting a1.wav to MP3...
Converting a4.wav to MP3...
Converting a2.wav to MP3...

Conversion complete. Converted 4 files.
MP3 files are saved in: /kaggle/working/mp3_converted


Data manifest file in csv: will pair each MP3 with its corresponding WAV file

In [7]:
import os
import pandas as pd

# --- Configuration ---
# Define the paths to your data directories in Kaggle
wav_dir = "/kaggle/input/smalldataset/"
mp3_128k_dir = "/kaggle/working/mp3_converted/"
mp3_192k_dir = "/kaggle/working/mp3_converted_192k/"

# Output path for the final CSV file
output_csv_path = "/kaggle/working/dataset_manifest.csv"

# --- Logic to Create Pairs ---
data_pairs = []

# We'll use the original WAV files as the source of truth
print(f"Scanning for WAV files in: {wav_dir}")
for wav_filename in os.listdir(wav_dir):
    if wav_filename.lower().endswith(".wav"):
        base_name = os.path.splitext(wav_filename)[0]
        wav_path = os.path.join(wav_dir, wav_filename)

        # --- Pair with 128k MP3 ---
        mp3_128k_path = os.path.join(mp3_128k_dir, f"{base_name}.mp3")
        if os.path.exists(mp3_128k_path):
            data_pairs.append({
                "input_path": mp3_128k_path,
                "target_path": wav_path,
                "bitrate_kbps": 128,
                "original_id": base_name
            })
        else:
            print(f"Warning: Could not find matching 128k MP3 for {wav_filename}")

        # --- Pair with 192k MP3 ---
        mp3_192k_path = os.path.join(mp3_192k_dir, f"{base_name}.mp3")
        if os.path.exists(mp3_192k_path):
            data_pairs.append({
                "input_path": mp3_192k_path,
                "target_path": wav_path,
                "bitrate_kbps": 192,
                "original_id": base_name
            })
        else:
            print(f"Warning: Could not find matching 192k MP3 for {wav_filename}")

# --- Create and Save the DataFrame ---
if data_pairs:
    df = pd.DataFrame(data_pairs)
    df.to_csv(output_csv_path, index=False)
    print(f"\nSuccessfully created manifest file with {len(df)} pairs.")
    print(f"CSV saved to: {output_csv_path}")
    
    # Display the first few rows of the created table
    print("\n--- CSV Preview ---")
    print(df.head())
else:
    print("\nNo data pairs were created. Please check your directory paths.")

Scanning for WAV files in: /kaggle/input/smalldataset/

Successfully created manifest file with 8 pairs.
CSV saved to: /kaggle/working/dataset_manifest.csv

--- CSV Preview ---
                                  input_path  \
0       /kaggle/working/mp3_converted/a3.mp3   
1  /kaggle/working/mp3_converted_192k/a3.mp3   
2       /kaggle/working/mp3_converted/a1.mp3   
3  /kaggle/working/mp3_converted_192k/a1.mp3   
4       /kaggle/working/mp3_converted/a4.mp3   

                         target_path  bitrate_kbps original_id  
0  /kaggle/input/smalldataset/a3.wav           128          a3  
1  /kaggle/input/smalldataset/a3.wav           192          a3  
2  /kaggle/input/smalldataset/a1.wav           128          a1  
3  /kaggle/input/smalldataset/a1.wav           192          a1  
4  /kaggle/input/smalldataset/a4.wav           128          a4  


In [None]:
!pip install librosa

Data Loader:
The core of the data loading process is a custom Dataset class. This class tells PyTorch three essential things:

__init__: How to initialize the dataset (e.g., by loading the CSV).

__len__: How many total items are in the dataset (the number of rows in the CSV).

__getitem__: How to get a single item (one input/target pair) from the dataset.

In [8]:
import torch
import torchaudio
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F # We need this for the padding function

# The Dataset class can remain the same as before
class AudioUpscalingDataset(Dataset):
    def __init__(self, manifest_path, target_sample_rate=44100):
        self.manifest = pd.read_csv(manifest_path)
        self.target_sample_rate = target_sample_rate

    def __len__(self):
        return len(self.manifest)

    def __getitem__(self, idx):
        input_path = self.manifest.iloc[idx]['input_path']
        target_path = self.manifest.iloc[idx]['target_path']
        
        try:
            input_waveform, orig_sr_input = torchaudio.load(input_path)
            target_waveform, orig_sr_target = torchaudio.load(target_path)
        except Exception as e:
            print(f"Error loading files at index {idx}: {e}")
            return torch.zeros(1, 1), torch.zeros(1, 1) # Return minimal tensor on error

        if orig_sr_input != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(orig_sr_input, self.target_sample_rate)
            input_waveform = resampler(input_waveform)
        
        if orig_sr_target != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(orig_sr_target, self.target_sample_rate)
            target_waveform = resampler(target_waveform)

        # We no longer need to manually fix lengths here. The collate_fn will handle it.
        return input_waveform, target_waveform

# --- NEW: The Custom Collate Function ---
def pad_collate_fn(batch):
    """
    Pads audio samples in a batch to the length of the longest sample.
    Args:
        batch: A list of tuples, where each tuple is (input_waveform, target_waveform).
    """
    # Separate the inputs and targets
    inputs, targets = zip(*batch)

    # Find the maximum length in the batch for inputs
    max_input_len = max(w.shape[1] for w in inputs)
    # Find the maximum length in the batch for targets
    max_target_len = max(w.shape[1] for w in targets)
    # Use the overall max length to be safe
    max_len = max(max_input_len, max_target_len)
    
    # Pad all inputs to the max_len
    # `pad` arguments are (left, right, top, bottom) for 2D tensors
    padded_inputs = torch.stack([
        F.pad(w, (0, max_len - w.shape[1])) for w in inputs
    ])

    # Pad all targets to the max_len
    padded_targets = torch.stack([
        F.pad(w, (0, max_len - w.shape[1])) for w in targets
    ])

    return padded_inputs, padded_targets

# --- UPDATED: Using the DataLoader ---

# 1. Define the path to your manifest file
manifest_file = "/kaggle/working/dataset_manifest.csv"

# 2. Create an instance of your custom Dataset
audio_dataset = AudioUpscalingDataset(manifest_path=manifest_file, target_sample_rate=44100)

# 3. Create the DataLoader, NOW WITH THE CUSTOM COLLATE FUNCTION
batch_size = 4
train_loader = DataLoader(
    audio_dataset, 
    batch_size=batch_size, 
    shuffle=True,
    collate_fn=pad_collate_fn  # <-- This is the crucial change!
)

# 4. Now, this loop will work without errors
print("Running DataLoader with padding collate function...")
for i, (inputs, targets) in enumerate(train_loader):
    print(f"Batch {i+1}:")
    print(f"  Input batch shape: {inputs.shape}")
    print(f"  Target batch shape: {targets.shape}")
    # All tensors in a batch will now have the same length.
    break

Running DataLoader with padding collate function...
Batch 1:
  Input batch shape: torch.Size([4, 2, 23653938])
  Target batch shape: torch.Size([4, 2, 23653938])


U-Net model with 1D convolutions

In [9]:
import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    """A single encoder block: Conv1D -> BatchNorm -> LeakyReLU"""
    def __init__(self, in_channels, out_channels, kernel_size=15, stride=1, padding=7):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm1d(out_channels)
        self.relu = nn.LeakyReLU(0.2)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

class UpConvBlock(nn.Module):
    """A single decoder block: Upsample -> Conv1D -> BatchNorm -> LeakyReLU"""
    def __init__(self, in_channels, out_channels, kernel_size=5, stride=1, padding=2):
        super().__init__()
        # Using ConvTranspose1d to upsample
        self.upconv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        self.bn = nn.BatchNorm1d(out_channels)
        self.relu = nn.LeakyReLU(0.2)

    def forward(self, x):
        return self.relu(self.bn(self.upconv(x)))

class UNet(nn.Module):
    def __init__(self, in_channels=2, out_channels=2):
        super().__init__()

        # --- Encoder ---
        self.enc1 = ConvBlock(in_channels, 16)
        self.pool1 = nn.AvgPool1d(kernel_size=2, stride=2) # Downsample
        self.enc2 = ConvBlock(16, 32)
        self.pool2 = nn.AvgPool1d(kernel_size=2, stride=2)
        self.enc3 = ConvBlock(32, 64)
        self.pool3 = nn.AvgPool1d(kernel_size=2, stride=2)

        # --- Bottleneck ---
        self.bottleneck = ConvBlock(64, 128)

        # --- Decoder ---
        # The in_channels for the decoder is double because of the skip connection concatenation
        self.upconv3 = UpConvBlock(128, 64)
        self.dec3 = ConvBlock(128, 64) # 64 from upconv + 64 from enc3 skip
        self.upconv2 = UpConvBlock(64, 32)
        self.dec2 = ConvBlock(64, 32)  # 32 from upconv + 32 from enc2 skip
        self.upconv1 = UpConvBlock(32, 16)
        self.dec1 = ConvBlock(32, 16)  # 16 from upconv + 16 from enc1 skip
        
        # --- Output Layer ---
        self.final_conv = nn.Conv1d(16, out_channels, kernel_size=1)
        self.final_tanh = nn.Tanh()

    def forward(self, x):
        # --- Encoder Path ---
        e1 = self.enc1(x)
        p1 = self.pool1(e1)
        e2 = self.enc2(p1)
        p2 = self.pool2(e2)
        e3 = self.enc3(p2)
        p3 = self.pool3(e3)

        # --- Bottleneck ---
        b = self.bottleneck(p3)

        # --- Decoder Path with Skip Connections ---
        u3 = self.upconv3(b)
        # Concatenate skip connection from encoder
        skip3 = torch.cat([u3, e3], dim=1)
        d3 = self.dec3(skip3)

        u2 = self.upconv2(d3)
        skip2 = torch.cat([u2, e2], dim=1)
        d2 = self.dec2(skip2)

        u1 = self.upconv1(d2)
        skip1 = torch.cat([u1, e1], dim=1)
        d1 = self.dec1(skip1)
        
        # --- Final Output ---
        out = self.final_conv(d1)
        
        return self.final_tanh(out)

# --- How to use it ---
# Assuming stereo audio (2 channels)
model = UNet(in_channels=2, out_channels=2)

# Create a dummy input batch to test the model
# (batch_size, num_channels, sequence_length)
dummy_input = torch.randn(4, 2, 44100 * 2) # Batch of 4, 2-second stereo clips at 44.1kHz

output = model(dummy_input)

print(f"Model created successfully!")
print(f"Input shape: {dummy_input.shape}")
print(f"Output shape: {output.shape}")
# Note: The output length might be slightly different due to convolutions.
# You may need to pad the input or crop the output to ensure they match exactly.

Model created successfully!
Input shape: torch.Size([4, 2, 88200])
Output shape: torch.Size([4, 2, 88200])


In [None]:
# import torch
# import torch.nn as nn
# import torch.optim as optim
# from torch.utils.data import DataLoader

# # --- 1. Setup and Hyperparameters ---

# # Check for GPU availability and set the device
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"Using device: {device}")

# # Hyperparameters
# NUM_EPOCHS = 25
# LEARNING_RATE = 1e-4 # A smaller learning rate is often good for U-Nets
# BATCH_SIZE = 4       # Adjust based on your GPU memory
# MANIFEST_FILE = "/kaggle/working/dataset_manifest.csv"
# MODEL_SAVE_PATH = "/kaggle/working/audio_upscaler_model.pth"

# # --- 2. Initialize Components ---

# # Instantiate the Dataset and DataLoader
# # (Assuming AudioUpscalingDataset and pad_collate_fn are defined in previous cells)
# audio_dataset = AudioUpscalingDataset(manifest_path=MANIFEST_FILE, target_sample_rate=44100)
# train_loader = DataLoader(
#     audio_dataset, 
#     batch_size=BATCH_SIZE, 
#     shuffle=True,
#     collate_fn=pad_collate_fn,
#     num_workers=2 # Use multiple cores to load data faster
# )

# # Instantiate the Model and move it to the selected device
# model = UNet(in_channels=2, out_channels=2).to(device)

# # Instantiate the Loss Function and Optimizer
# criterion = nn.L1Loss() # Mean Absolute Error is great for audio
# optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# print("Setup complete. Starting training...")

# # --- 3. The Training Loop ---

# for epoch in range(NUM_EPOCHS):
#     model.train()  # Set the model to training mode
#     running_loss = 0.0
    
#     # Loop over the data loader
#     for i, (inputs, targets) in enumerate(train_loader):
#         # Move tensors to the configured device (GPU or CPU)
#         inputs = inputs.to(device)
#         targets = targets.to(device)
        
#         # --- Forward Pass ---
#         # Get model outputs
#         outputs = model(inputs)
        
#         # --- Crop Output ---
#         # Ensure the output and target have the exact same length before calculating loss
#         # This handles any minor size differences from the U-Net's convolutions.
#         min_len = min(outputs.shape[2], targets.shape[2])
#         outputs = outputs[:, :, :min_len]
#         targets = targets[:, :, :min_len]
        
#         # --- Calculate Loss ---
#         loss = criterion(outputs, targets)
        
#         # --- Backward Pass and Optimization ---
#         # 1. Clear previous gradients
#         optimizer.zero_grad()
#         # 2. Calculate gradients
#         loss.backward()
#         # 3. Update model weights
#         optimizer.step()
        
#         # --- Statistics ---
#         running_loss += loss.item()
        
#         # Print progress for each batch
#         if (i + 1) % len(train_loader) == 0: # Print at the end of each epoch
#              print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], '
#                    f'Batch [{i+1}/{len(train_loader)}], '
#                    f'Loss: {loss.item():.4f}')

#     # Print average loss for the epoch
#     epoch_loss = running_loss / len(train_loader)
#     print(f'--- End of Epoch [{epoch+1}/{NUM_EPOCHS}], Average Loss: {epoch_loss:.4f} ---')


# print('Finished Training!')

# # --- 4. Save the Trained Model ---
# torch.save(model.state_dict(), MODEL_SAVE_PATH)
# print(f"Model saved to {MODEL_SAVE_PATH}")

Optimized training script (contains light U-net model using chunk of 2 sec audio)

In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchaudio
import pandas as pd
import numpy as np

# --- (Optional but Recommended) Lighter U-Net Model ---
# Using a model with fewer channels reduces memory usage significantly.
class LightUNet(nn.Module):
    def __init__(self, in_channels=2, out_channels=2):
        super().__init__()
        # A simplified U-Net with fewer channels
        self.enc1 = nn.Conv1d(in_channels, 16, kernel_size=15, padding=7)
        self.pool1 = nn.MaxPool1d(2)
        self.enc2 = nn.Conv1d(16, 32, kernel_size=15, padding=7)
        self.pool2 = nn.MaxPool1d(2)
        
        self.bottleneck = nn.Conv1d(32, 64, kernel_size=15, padding=7)
        
        self.upconv2 = nn.ConvTranspose1d(64, 32, kernel_size=4, stride=2, padding=1)
        self.dec2 = nn.Conv1d(64, 32, kernel_size=5, padding=2) # 32 + 32
        self.upconv1 = nn.ConvTranspose1d(32, 16, kernel_size=4, stride=2, padding=1)
        self.dec1 = nn.Conv1d(32, 16, kernel_size=5, padding=2) # 16 + 16
        
        self.final_conv = nn.Conv1d(16, out_channels, kernel_size=1)
        self.final_tanh = nn.Tanh()

    def forward(self, x):
        e1 = self.enc1(x)
        p1 = self.pool1(e1)
        e2 = self.enc2(p1)
        p2 = self.pool2(e2)
        b = self.bottleneck(p2)
        u2 = self.upconv2(b)
        d2 = self.dec2(torch.cat([u2, e2], dim=1))
        u1 = self.upconv1(d2)
        d1 = self.dec1(torch.cat([u1, e1], dim=1))
        return self.final_tanh(self.final_conv(d1))

# --- NEW Dataset that uses fixed-size chunks ---
class AudioChunkDataset(Dataset):
    def __init__(self, manifest_path, sample_rate=44100, chunk_duration_secs=2):
        self.manifest = pd.read_csv(manifest_path)
        self.sample_rate = sample_rate
        self.chunk_size = sample_rate * chunk_duration_secs

    def __len__(self):
        return len(self.manifest)

    def __getitem__(self, idx):
        input_path = self.manifest.iloc[idx]['input_path']
        target_path = self.manifest.iloc[idx]['target_path']
        
        input_waveform, _ = torchaudio.load(input_path)
        target_waveform, _ = torchaudio.load(target_path)
        
        # Get a random chunk
        # If the file is shorter than the chunk size, it will be padded later.
        if input_waveform.shape[1] > self.chunk_size:
            start = np.random.randint(0, input_waveform.shape[1] - self.chunk_size)
            input_chunk = input_waveform[:, start:start + self.chunk_size]
            target_chunk = target_waveform[:, start:start + self.chunk_size]
        else:
            input_chunk = input_waveform
            target_chunk = target_waveform

        # Pad if the chunk (or original file) is shorter than the desired chunk size
        pad_len = self.chunk_size - input_chunk.shape[1]
        if pad_len > 0:
            input_chunk = torch.nn.functional.pad(input_chunk, (0, pad_len))
            target_chunk = torch.nn.functional.pad(target_chunk, (0, pad_len))
            
        return input_chunk, target_chunk

# --- 1. Setup and Hyperparameters ---
device = torch.device("cpu") # Forcing CPU
print(f"Using device: {device}")

# Hyperparameters optimized for low memory
NUM_EPOCHS = 25
LEARNING_RATE = 1e-4
BATCH_SIZE = 1 # Process one file at a time
ACCUMULATION_STEPS = 4 # Simulate a batch size of 1 * 4 = 4
MANIFEST_FILE = "/kaggle/working/dataset_manifest.csv"
MODEL_SAVE_PATH = "/kaggle/working/audio_upscaler_cpu_model.pth"

# --- 2. Initialize Components ---
# Use the new chunk-based dataset
audio_dataset = AudioChunkDataset(manifest_path=MANIFEST_FILE)
# No custom collate_fn needed! num_workers=0 is safer for low-memory.
train_loader = DataLoader(audio_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

# Use the lighter model
model = LightUNet(in_channels=2, out_channels=2).to(device)
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

print("Setup complete. Starting memory-efficient training...")

# --- 3. The UPDATED Training Loop with Gradient Accumulation ---
for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    
    # Reset gradients at the start of the epoch
    optimizer.zero_grad()
    
    for i, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)
        
        # Forward Pass
        outputs = model(inputs)
        
        # Crop output to match target length precisely
        min_len = min(outputs.shape[2], targets.shape[2])
        outputs = outputs[:, :, :min_len]
        targets = targets[:, :, :min_len]
        
        loss = criterion(outputs, targets)
        
        # Scale the loss for accumulation
        loss = loss / ACCUMULATION_STEPS
        
        # Backward Pass
        loss.backward()
        
        # --- Gradient Accumulation Step ---
        # Update weights only every ACCUMULATION_STEPS
        if (i + 1) % ACCUMULATION_STEPS == 0:
            optimizer.step()  # Update weights
            optimizer.zero_grad() # Reset gradients for the next accumulation cycle

        running_loss += loss.item() * ACCUMULATION_STEPS # Un-scale for logging

    # Print average loss for the epoch
    epoch_loss = running_loss / len(train_loader)
    print(f'--- End of Epoch [{epoch+1}/{NUM_EPOCHS}], Average Loss: {epoch_loss:.4f} ---')

print('Finished Training!')

# --- 4. Save the Trained Model ---
torch.save(model.state_dict(), MODEL_SAVE_PATH)
print(f"Model saved to {MODEL_SAVE_PATH}")

Using device: cpu
Setup complete. Starting memory-efficient training...
--- End of Epoch [1/25], Average Loss: 0.1361 ---
--- End of Epoch [2/25], Average Loss: 0.1304 ---
--- End of Epoch [3/25], Average Loss: 0.1299 ---
--- End of Epoch [4/25], Average Loss: 0.1251 ---
--- End of Epoch [5/25], Average Loss: 0.1186 ---
--- End of Epoch [6/25], Average Loss: 0.1139 ---
--- End of Epoch [7/25], Average Loss: 0.1111 ---
--- End of Epoch [8/25], Average Loss: 0.1058 ---
--- End of Epoch [9/25], Average Loss: 0.0983 ---
--- End of Epoch [10/25], Average Loss: 0.0946 ---
--- End of Epoch [11/25], Average Loss: 0.0885 ---
--- End of Epoch [12/25], Average Loss: 0.0829 ---
--- End of Epoch [13/25], Average Loss: 0.0781 ---
--- End of Epoch [14/25], Average Loss: 0.0672 ---
--- End of Epoch [15/25], Average Loss: 0.0550 ---
--- End of Epoch [16/25], Average Loss: 0.0476 ---
--- End of Epoch [17/25], Average Loss: 0.0426 ---
--- End of Epoch [18/25], Average Loss: 0.0327 ---
--- End of Epoch [1

Performing inference and then listening to the output

In [13]:
import torch
import torchaudio
import numpy as np

# --- 1. Setup ---
device = torch.device("cpu")
MODEL_PATH = "/kaggle/working/audio_upscaler_cpu_model.pth"
# Use one of your original MP3 files as input
INPUT_AUDIO_PATH = "/kaggle/working/mp3_converted/a1.mp3" 
OUTPUT_AUDIO_PATH = "/kaggle/working/upscaled_output.wav"
SAMPLE_RATE = 44100
CHUNK_DURATION_SECS = 2 # Use the same chunk size as in training
CHUNK_SIZE = SAMPLE_RATE * CHUNK_DURATION_SECS

# --- 2. Load Model ---
# Make sure the LightUNet class is defined in a previous cell
model = LightUNet(in_channels=2, out_channels=2).to(device)
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval() # Set the model to evaluation mode (very important!)

print("Model loaded. Starting inference...")

# --- 3. Load and Process Audio in Chunks ---
input_waveform, _ = torchaudio.load(INPUT_AUDIO_PATH)
input_waveform = input_waveform.to(device)
output_chunks = []

# Process the audio chunk by chunk to avoid memory errors
with torch.no_grad(): # Disable gradient calculation for efficiency
    for i in range(0, input_waveform.shape[1], CHUNK_SIZE):
        chunk = input_waveform[:, i:i + CHUNK_SIZE]
        
        # Pad the last chunk if it's smaller than the required size
        if chunk.shape[1] < CHUNK_SIZE:
            pad_len = CHUNK_SIZE - chunk.shape[1]
            chunk = torch.nn.functional.pad(chunk, (0, pad_len))

        # Add a batch dimension and run through the model
        chunk = chunk.unsqueeze(0) # Shape: [1, num_channels, chunk_size]
        output_chunk = model(chunk)
        output_chunks.append(output_chunk.squeeze(0)) # Remove batch dimension

# --- 4. Stitch Chunks Together and Save ---
# Concatenate all the processed chunks
output_waveform = torch.cat(output_chunks, dim=1)

# Trim any excess padding from the end by matching the original input length
output_waveform = output_waveform[:, :input_waveform.shape[1]]

# Save the final upscaled audio
torchaudio.save(OUTPUT_AUDIO_PATH, output_waveform.cpu(), SAMPLE_RATE)

print(f"Inference complete! Upscaled audio saved to: {OUTPUT_AUDIO_PATH}")

Model loaded. Starting inference...
Inference complete! Upscaled audio saved to: /kaggle/working/upscaled_output.wav
