In [1]:
! unzip /content/gtzan_dataset-20250401T215714Z-001.zip

Archive:  /content/gtzan_dataset-20250401T215714Z-001.zip
  inflating: gtzan_dataset/images_original/blues/blues00092.png  
  inflating: gtzan_dataset/genres_original/hiphop/hiphop.00093.wav  
  inflating: gtzan_dataset/images_original/blues/blues00088.png  
  inflating: gtzan_dataset/images_original/blues/blues00085.png  
  inflating: gtzan_dataset/images_original/blues/blues00087.png  
  inflating: gtzan_dataset/images_original/blues/blues00097.png  
  inflating: gtzan_dataset/images_original/blues/blues00083.png  
  inflating: gtzan_dataset/images_original/blues/blues00091.png  
  inflating: gtzan_dataset/images_original/blues/blues00096.png  
  inflating: gtzan_dataset/images_original/blues/blues00090.png  
  inflating: gtzan_dataset/genres_original/hiphop/hiphop.00080.wav  
  inflating: gtzan_dataset/genres_original/hiphop/hiphop.00094.wav  
  inflating: gtzan_dataset/features_30_sec.csv  
  inflating: gtzan_dataset/images_original/blues/blues00065.png  
  inflating: gtzan_dataset

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import torchaudio.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import librosa
import librosa.display
import numpy as np
import os
from glob import glob
import matplotlib.pyplot as plt

# Define constants
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SAMPLE_RATE = 22050
N_MELS = 128
N_FFT = 2048
HOP_LENGTH = 512
BATCH_SIZE = 8
EPOCHS = 50
LEARNING_RATE = 2e-4

# Load GTZAN dataset
class GTZANDataset(Dataset):
    def __init__(self, root_dir):
        self.file_paths = glob(os.path.join(root_dir, "genres_original", "*", "*.wav"))

        # Check if the dataset is empty
        if len(self.file_paths) == 0:
            raise ValueError(f"No .wav files found in {root_dir}/genres_original. Please check your dataset path.")

        self.labels = [os.path.basename(os.path.dirname(fp)) for fp in self.file_paths]
        self.label_dict = {genre: idx for idx, genre in enumerate(sorted(set(self.labels)))}

        self.transform = transforms.MelSpectrogram(
            sample_rate=SAMPLE_RATE,
            n_mels=N_MELS,
            n_fft=N_FFT,
            hop_length=HOP_LENGTH
        )

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        file_path = self.file_paths[idx]
        label = self.label_dict[self.labels[idx]]
        waveform, _ = librosa.load(file_path, sr=SAMPLE_RATE)
        waveform = torch.tensor(waveform[:SAMPLE_RATE*3])  # 3 seconds clip
        mel_spec = self.transform(waveform)  # Shape: [1, n_mels, time]
        return mel_spec.squeeze(0), label  # Return [n_mels, time]

# Load dataset
dataset_path = "/content/gtzan_dataset"
train_dataset = GTZANDataset(dataset_path)

# Print dataset size for debugging
print(f"Number of samples in dataset: {len(train_dataset)}")

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Define Transformer-based Generator
class TransformerGenerator(nn.Module):
    def __init__(self, input_dim=N_MELS, num_heads=4, ff_dim=256):
        super(TransformerGenerator, self).__init__()
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=input_dim, nhead=num_heads, dim_feedforward=ff_dim)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=6)
        self.fc = nn.Linear(input_dim, input_dim)

    def forward(self, x):
        x = x.permute(2, 0, 1)  # (batch, n_mels, time) -> (time, batch, n_mels)
        x = self.transformer_encoder(x)
        x = self.fc(x)
        x = x.permute(1, 2, 0)  # Back to (batch, n_mels, time)
        return x

# Define Discriminator
class CNNDiscriminator(nn.Module):
    def __init__(self, input_dim=N_MELS):
        super(CNNDiscriminator, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2), padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.unsqueeze(1)  # Add channel dim: [batch, 1, n_mels, time]
        return self.conv(x)

# Define Genre Classifier
class GenreClassifier(nn.Module):
    def __init__(self, input_dim=N_MELS, num_classes=10):
        super(GenreClassifier, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.flattened_size = 128 * (N_MELS//8) * ((3*SAMPLE_RATE//HOP_LENGTH)//8)
        self.fc = nn.Sequential(
            nn.Linear(self.flattened_size, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = x.unsqueeze(1)  # Add channel dimension
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

# Initialize models
generator = TransformerGenerator().to(device)
discriminator = CNNDiscriminator().to(device)
classifier = GenreClassifier().to(device)

##############################
# Training the Genre Classifier
##############################
print("Training the genre classifier...")
criterion = nn.CrossEntropyLoss()
classifier_optimizer = optim.Adam(classifier.parameters(), lr=1e-3)

best_classifier_loss = float('inf')

for epoch in range(50):
    epoch_loss = 0.0
    for mel_specs, labels in train_loader:
        mel_specs, labels = mel_specs.to(device), labels.to(device)

        classifier_optimizer.zero_grad()
        outputs = classifier(mel_specs)
        loss = criterion(outputs, labels)
        loss.backward()
        classifier_optimizer.step()

        epoch_loss += loss.item()

    epoch_loss /= len(train_loader)
    # Save the best classifier model
    if epoch_loss < best_classifier_loss:
        best_classifier_loss = epoch_loss
        torch.save(classifier.state_dict(), "best_genre_classifier.pth")
        print(f"Saved best classifier model at epoch {epoch+1} with loss {epoch_loss:.4f}")

    print(f"Classifier Epoch {epoch+1}, Average Loss: {epoch_loss:.4f}")

# Save the final classifier model (if needed)
torch.save(classifier.state_dict(), "genre_classifier.pth")
classifier.eval()

##############################
# Training the GAN
##############################
adversarial_loss = nn.BCELoss()
content_loss = nn.MSELoss()
cycle_loss = nn.L1Loss()

g_optimizer = optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))

best_g_loss = float('inf')

print("Training the GAN...")
for epoch in range(EPOCHS):
    epoch_g_loss = 0.0
    for batch_idx, (real_data, _) in enumerate(train_loader):
        real_data = real_data.to(device)
        batch_size = real_data.size(0)

        # Generate fake data
        fake_data = generator(real_data)

        # Train Discriminator
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        d_optimizer.zero_grad()
        real_output = discriminator(real_data)
        real_loss = adversarial_loss(real_output, real_labels)
        fake_output = discriminator(fake_data.detach())
        fake_loss = adversarial_loss(fake_output, fake_labels)
        d_loss = real_loss + fake_loss
        d_loss.backward()
        d_optimizer.step()

        # Train Generator
        g_optimizer.zero_grad()
        g_loss_adv = adversarial_loss(discriminator(fake_data), real_labels)
        g_loss_content = content_loss(fake_data, real_data)
        g_loss_cycle = cycle_loss(generator(fake_data), real_data)
        g_loss = g_loss_adv + 0.5 * g_loss_content + 0.5 * g_loss_cycle
        g_loss.backward()
        g_optimizer.step()

        epoch_g_loss += g_loss.item()

        if batch_idx % 10 == 0:
            print(f"GAN Epoch [{epoch+1}/{EPOCHS}] Batch [{batch_idx}/{len(train_loader)}] "
                  f"D Loss: {d_loss.item():.4f} G Loss: {g_loss.item():.4f}")

    epoch_g_loss /= len(train_loader)
    # Save the best generator model
    if epoch_g_loss < best_g_loss:
        best_g_loss = epoch_g_loss
        torch.save(generator.state_dict(), "best_generator.pth")
        print(f"Saved best generator model at epoch {epoch+1} with average G Loss {epoch_g_loss:.4f}")

    print(f"GAN Epoch [{epoch+1}/{EPOCHS}], Average Generator Loss: {epoch_g_loss:.4f}")

print("Training Complete!")


Number of samples in dataset: 998
Training the genre classifier...
Saved best classifier model at epoch 1 with loss 45.8812
Classifier Epoch 1, Average Loss: 45.8812
Saved best classifier model at epoch 2 with loss 2.2407
Classifier Epoch 2, Average Loss: 2.2407
Saved best classifier model at epoch 3 with loss 2.1084
Classifier Epoch 3, Average Loss: 2.1084
Saved best classifier model at epoch 4 with loss 1.9449
Classifier Epoch 4, Average Loss: 1.9449
Saved best classifier model at epoch 5 with loss 1.7649
Classifier Epoch 5, Average Loss: 1.7649
Saved best classifier model at epoch 6 with loss 1.5039
Classifier Epoch 6, Average Loss: 1.5039
Saved best classifier model at epoch 7 with loss 1.3054
Classifier Epoch 7, Average Loss: 1.3054
Classifier Epoch 8, Average Loss: 1.3791
Saved best classifier model at epoch 9 with loss 1.0548
Classifier Epoch 9, Average Loss: 1.0548
Classifier Epoch 10, Average Loss: 1.5237
Classifier Epoch 11, Average Loss: 1.6468
Classifier Epoch 12, Average L

In [None]:
# Genre conversion function
def convert_genre(audio_tensor, source_genre, target_genre):
    print(f"Converting from {source_genre} to {target_genre}...")
    audio_tensor = audio_tensor.to(device)
    converted_audio = generator(audio_tensor.unsqueeze(0)).squeeze(0)

    # Predict the genre of the converted audio
    predicted_genre_idx = predict_genre(converted_audio)
    predicted_genre = list(train_dataset.label_dict.keys())[predicted_genre_idx]

    print(f"Converted audio predicted as: {predicted_genre}")
    return converted_audio

# Genre prediction function
def predict_genre(audio_tensor):
    audio_tensor = audio_tensor.to(device)
    with torch.no_grad():
        prediction = classifier(audio_tensor.unsqueeze(0))
    return torch.argmax(prediction, dim=1).item()

def load_and_preprocess_audio(file_path, duration=3):
    """Load and preprocess an audio file for conversion.

    Args:
        file_path (str): Path to the input audio file.
        duration (int): Desired duration in seconds (default is 3 seconds).

    Returns:
        torch.Tensor: Mel-spectrogram tensor of shape [n_mels, time].
    """
    # Load audio file using librosa
    waveform, sr = librosa.load(file_path, sr=SAMPLE_RATE)

    # Trim or pad to desired duration
    if len(waveform) > SAMPLE_RATE * duration:
        waveform = waveform[:SAMPLE_RATE * duration]
    else:
        padding = SAMPLE_RATE * duration - len(waveform)
        waveform = np.pad(waveform, (0, padding), mode='constant')

    # Convert waveform to mel-spectrogram using torchaudio
    transform = transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_mels=N_MELS,
        n_fft=N_FFT,
        hop_length=HOP_LENGTH
    )
    waveform_tensor = torch.tensor(waveform)
    mel_spec = transform(waveform_tensor).squeeze(0)  # Shape: [n_mels, time]

    return mel_spec

def convert_audio_file(input_path, output_path, source_genre, target_genre):
    """
    Convert an audio file from source_genre to target_genre and save the converted audio.

    Args:
        input_path (str): Path to the input audio file.
        output_path (str): Path to save the converted audio.
        source_genre (str): Name of the source genre (for logging).
        target_genre (str): Name of the target genre (for logging).
    """
    # Load and preprocess audio
    print(f"Loading and preprocessing {input_path}...")
    mel_spec = load_and_preprocess_audio(input_path)

    # Convert genre - ensure the generator is in evaluation mode
    generator.eval()
    with torch.no_grad():
        print(f"Converting from {source_genre} to {target_genre}...")
        converted_mel = convert_genre(mel_spec, source_genre, target_genre)

    # Convert mel-spectrogram back to audio using inverse transforms (Griffin-Lim)
    print("Converting back to waveform...")

    # Create inverse mel-scale transform
    inv_mel = transforms.InverseMelScale(
        n_stft=N_FFT // 2 + 1,
        n_mels=N_MELS,
        sample_rate=SAMPLE_RATE
    )

    # Create Griffin-Lim transform for spectrogram inversion
    griffin_lim = transforms.GriffinLim(
        n_fft=N_FFT,
        hop_length=HOP_LENGTH,
        n_iter=32
    )

    # Ensure proper shape: add batch dimension -> [1, n_mels, time]
    converted_mel = converted_mel.unsqueeze(0).detach().cpu()

    # Convert mel to linear spectrogram
    spec_estimate = inv_mel(converted_mel)

    # Reconstruct waveform from the spectrogram
    waveform = griffin_lim(spec_estimate)

    # Save the converted audio (waveform must be detached and on CPU)
    torchaudio.save(output_path, waveform.detach(), SAMPLE_RATE)
    print(f"Converted audio saved to {output_path}")

# Example usage:
if __name__ == "__main__":
    # Example file paths - replace these with your actual file paths
    input_audio = "/content/gtzan_dataset/genres_original/disco/disco.00000.wav"  # Your input audio file
    output_audio = "output_classical.wav"  # Where to save the converted audio

    # Ensure your train_dataset is defined globally and has been created during training.
    # This line prints the available genres in the dataset.
    available_genres = list(train_dataset.label_dict.keys())
    print("Available genres:", available_genres)

    # Example conversion: converting from 'metal' to 'rock'
    convert_audio_file(
        input_path=input_audio,
        output_path=output_audio,
        source_genre="disco",
        target_genre="classical"
    )


Available genres: ['blues', 'classical', 'country', 'disco', 'hiphop', 'jazz', 'metal', 'pop', 'reggae', 'rock']
Loading and preprocessing /content/gtzan_dataset/genres_original/disco/disco.00000.wav...
Converting from disco to classical...
Converting from disco to classical...
Converted audio predicted as: pop
Converting back to waveform...
Converted audio saved to output_classical.wav


Upar vala nhi chala neeche vala new version 1

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchaudio
import torchaudio.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import librosa
import numpy as np
import os
from glob import glob
import matplotlib.pyplot as plt
from tqdm import tqdm

# Constants
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SAMPLE_RATE = 22050
N_MELS = 128
N_FFT = 2048
HOP_LENGTH = 512
BATCH_SIZE = 8
EPOCHS = 20
LEARNING_RATE = 2e-4
NUM_GENRES = 10  # For GTZAN dataset

# Dataset Class with fixed dimensions
class GTZANDataset(Dataset):
    def __init__(self, root_dir, duration=3):
        self.file_paths = glob(os.path.join(root_dir, "genres_original", "*", "*.wav"))
        if not self.file_paths:
            raise ValueError(f"No .wav files found in {root_dir}")

        self.labels = [os.path.basename(os.path.dirname(fp)) for fp in self.file_paths]
        self.label_dict = {genre: idx for idx, genre in enumerate(sorted(set(self.labels)))}
        self.duration = duration

        # Calculate exact time dimension
        self.time_dim = (duration * SAMPLE_RATE) // HOP_LENGTH + 1

        self.mel_transform = transforms.MelSpectrogram(
            sample_rate=SAMPLE_RATE,
            n_mels=N_MELS,
            n_fft=N_FFT,
            hop_length=HOP_LENGTH
        )

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        file_path = self.file_paths[idx]
        label = self.label_dict[self.labels[idx]]

        # Load with exact duration
        waveform, _ = librosa.load(file_path, sr=SAMPLE_RATE, duration=self.duration)
        if len(waveform) < SAMPLE_RATE * self.duration:
            padding = SAMPLE_RATE * self.duration - len(waveform)
            waveform = np.pad(waveform, (0, padding), mode='constant')

        waveform = torch.FloatTensor(waveform)
        mel_spec = self.mel_transform(waveform)

        # Ensure exact dimensions
        if mel_spec.shape[-1] > self.time_dim:
            mel_spec = mel_spec[..., :self.time_dim]
        elif mel_spec.shape[-1] < self.time_dim:
            padding = self.time_dim - mel_spec.shape[-1]
            mel_spec = F.pad(mel_spec, (0, padding))

        # Log scaling and normalization
        mel_spec = torch.log(mel_spec + 1e-9)
        mel_spec = (mel_spec - mel_spec.mean()) / (mel_spec.std() + 1e-9)

        return mel_spec, label

# Genre-Conditioned Generator with fixed dimensions
class GenreConditionedGenerator(nn.Module):
    def __init__(self, num_genres=NUM_GENRES):
        super().__init__()

        # Genre embedding
        self.genre_embedding = nn.Embedding(num_genres, 64)

        # Initial convolution
        self.initial = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=7, stride=1, padding=3),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # Downsample blocks
        self.down1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True)
        )

        self.down2 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.ReLU(inplace=True)
        )

        # Residual blocks with genre conditioning
        self.resblocks = nn.ModuleList([
            ResidualBlock(256, num_genres) for _ in range(6)
        ])

        # Upsample blocks with output padding
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2,
                              padding=1, output_padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True)
        )

        self.up2 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2,
                              padding=1, output_padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # Final convolution
        self.final = nn.Sequential(
            nn.Conv2d(64, 1, kernel_size=7, stride=1, padding=3),
            nn.Tanh()
        )

    def forward(self, x, target_genre):
        # Store original dimensions
        original_size = x.size()

        # Initial processing
        x = self.initial(x)

        # Downsample
        x = self.down1(x)
        x = self.down2(x)

        # Get genre embedding
        genre_emb = self.genre_embedding(target_genre)  # [B, 64]
        genre_emb = genre_emb.unsqueeze(-1).unsqueeze(-1)  # [B, 64, 1, 1]

        # Residual blocks
        for block in self.resblocks:
            x = block(x, genre_emb)

        # Upsample
        x = self.up1(x)
        x = self.up2(x)

        # Final processing
        x = self.final(x)

        # Ensure output matches input dimensions
        if x.size() != original_size:
            x = F.interpolate(x, size=original_size[2:], mode='bilinear', align_corners=False)

        return x

# Update the ResidualBlock class to fix the dimension mismatch
class ResidualBlock(nn.Module):
    def __init__(self, channels, num_genres):
        super().__init__()

        # Project genre embedding to match channel dimensions
        self.genre_proj = nn.Sequential(
            nn.Linear(64, channels),
            nn.ReLU(inplace=True)
        )

        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.in1 = nn.InstanceNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.in2 = nn.InstanceNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, genre_emb):
        residual = x

        # Process genre embedding
        genre_emb = self.genre_proj(genre_emb.squeeze(-1).squeeze(-1))  # [B, C]
        genre_emb = genre_emb.view(genre_emb.size(0), genre_emb.size(1), 1, 1)
        genre_emb = genre_emb.expand(-1, -1, x.size(2), x.size(3))

        # Add genre conditioning
        x = x + genre_emb

        x = self.relu(self.in1(self.conv1(x)))
        x = self.in2(self.conv2(x))

        return x + residual

# Discriminator with Genre Classification
class GenreAwareDiscriminator(nn.Module):
    def __init__(self, num_genres=NUM_GENRES):
        super().__init__()

        self.model = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),

            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2),

            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2),

            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2),

            nn.AdaptiveAvgPool2d(1)
        )

        # Real/fake discriminator
        self.discriminator = nn.Linear(512, 1)

        # Genre classifier
        self.genre_classifier = nn.Linear(512, num_genres)

    def forward(self, x):
        features = self.model(x).squeeze(-1).squeeze(-1)
        validity = self.discriminator(features)
        genre = self.genre_classifier(features)
        return validity, genre

# Training function with dimension checks
def train():
    # Initialize dataset
    dataset = GTZANDataset("/content/gtzan_dataset")
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # Initialize models
    generator = GenreConditionedGenerator().to(device)
    discriminator = GenreAwareDiscriminator().to(device)

    # Loss functions
    adversarial_loss = nn.BCEWithLogitsLoss()
    genre_classification_loss = nn.CrossEntropyLoss()
    cycle_loss = nn.L1Loss()
    identity_loss = nn.L1Loss()

    # Optimizers
    g_optimizer = optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))

    for epoch in range(EPOCHS):
        for i, (real_imgs, real_genres) in enumerate(tqdm(dataloader)):
            real_imgs = real_imgs.unsqueeze(1).to(device)  # Add channel dim [B,1,M,T]
            real_genres = real_genres.to(device)

            # Random target genres different from source
            target_genres = torch.randint(0, NUM_GENRES, (real_imgs.size(0),)).to(device)
            target_genres = torch.where(target_genres == real_genres,
                                     (target_genres + 1) % NUM_GENRES,
                                     target_genres)

            # ---------------------
            #  Train Discriminator
            # ---------------------
            d_optimizer.zero_grad()

            # Generate fake images
            fake_imgs = generator(real_imgs, target_genres)

            # Real images
            real_validity, real_genre_pred = discriminator(real_imgs)
            d_real_loss = adversarial_loss(real_validity, torch.ones_like(real_validity))
            d_real_genre_loss = genre_classification_loss(real_genre_pred, real_genres)

            # Fake images
            fake_validity, _ = discriminator(fake_imgs.detach())
            d_fake_loss = adversarial_loss(fake_validity, torch.zeros_like(fake_validity))

            d_loss = d_real_loss + d_fake_loss + d_real_genre_loss
            d_loss.backward()
            d_optimizer.step()

            # -----------------
            #  Train Generator
            # -----------------
            g_optimizer.zero_grad()

            # Generate fake images and reconstruct
            fake_imgs = generator(real_imgs, target_genres)
            reconstructed_imgs = generator(fake_imgs, real_genres)

            # Adversarial and genre loss
            fake_validity, fake_genre_pred = discriminator(fake_imgs)
            g_adv_loss = adversarial_loss(fake_validity, torch.ones_like(fake_validity))
            g_genre_loss = genre_classification_loss(fake_genre_pred, target_genres)

            # Cycle consistency loss (with dimension check)
            if reconstructed_imgs.shape == real_imgs.shape:
                g_cycle_loss = cycle_loss(reconstructed_imgs, real_imgs)
            else:
                # Resize if dimensions don't match
                reconstructed_imgs = F.interpolate(
                    reconstructed_imgs, size=real_imgs.shape[2:], mode='bilinear', align_corners=False)
                g_cycle_loss = cycle_loss(reconstructed_imgs, real_imgs)

            # Identity loss
            identity_imgs = generator(real_imgs, real_genres)
            g_identity_loss = identity_loss(identity_imgs, real_imgs)

            g_loss = g_adv_loss + g_genre_loss + 10 * g_cycle_loss + 5 * g_identity_loss
            g_loss.backward()
            g_optimizer.step()

            if i % 50 == 0:
                print(f"[Epoch {epoch}/{EPOCHS}] [Batch {i}/{len(dataloader)}] "
                     f"D Loss: {d_loss.item():.4f} G Loss: {g_loss.item():.4f}")

        # Save checkpoints
        torch.save(generator.state_dict(), f"generator_{epoch}.pth")
        torch.save(discriminator.state_dict(), f"discriminator_{epoch}.pth")



# Audio Conversion Function
def convert_audio(input_path, output_path, source_genre, target_genre, generator, dataset):
    # Load and preprocess audio
    waveform, _ = librosa.load(input_path, sr=SAMPLE_RATE, duration=3)
    if len(waveform) < SAMPLE_RATE * 3:
        waveform = np.pad(waveform, (0, SAMPLE_RATE * 3 - len(waveform)))

    # Convert to mel-spectrogram
    mel_transform = transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_mels=N_MELS,
        n_fft=N_FFT,
        hop_length=HOP_LENGTH
    )
    mel_spec = mel_transform(torch.FloatTensor(waveform))
    mel_spec = torch.log(mel_spec + 1e-9)
    mel_spec = (mel_spec - mel_spec.mean()) / (mel_spec.std() + 1e-9)

    # Convert genre
    generator.eval()
    with torch.no_grad():
        target_idx = dataset.label_dict[target_genre]
        converted_mel = generator(
            mel_spec.unsqueeze(0).unsqueeze(0).to(device),
            torch.tensor([target_idx]).to(device)
        ).squeeze().cpu()

    # Inverse mel-spectrogram
    inv_mel = transforms.InverseMelScale(
        n_stft=N_FFT // 2 + 1,
        n_mels=N_MELS,
        sample_rate=SAMPLE_RATE
    )
    griffin_lim = transforms.GriffinLim(
        n_fft=N_FFT,
        hop_length=HOP_LENGTH,
        n_iter=32
    )

    # Denormalize and convert back to waveform
    converted_mel = converted_mel * mel_spec.std() + mel_spec.mean()
    converted_mel = torch.exp(converted_mel) - 1e-9
    spec_estimate = inv_mel(converted_mel.unsqueeze(0))
    waveform = griffin_lim(spec_estimate)

    # Save result
    torchaudio.save(output_path, waveform, SAMPLE_RATE)

if __name__ == "__main__":
    # Initialize dataset
    dataset = GTZANDataset("/content/gtzan_dataset")
    print("Available genres:", list(dataset.label_dict.keys()))

    # Train or load pretrained models
    train()

    # Example conversion
    generator = GenreConditionedGenerator().to(device)
    generator.load_state_dict(torch.load("generator_99.pth"))

    convert_audio(
        input_path="/content/gtzan_dataset/genres_original/disco/disco.00000.wav",
        output_path="converted_to_blues.wav",
        source_genre="disco",
        target_genre="blues",
        generator=generator,
        dataset=dataset
    )

Available genres: ['blues', 'classical', 'country', 'disco', 'hiphop', 'jazz', 'metal', 'pop', 'reggae', 'rock']


  1%|          | 1/125 [00:00<01:03,  1.95it/s]

[Epoch 0/20] [Batch 0/125] D Loss: 3.8324 G Loss: 14.7006


 41%|████      | 51/125 [00:23<00:39,  1.89it/s]

[Epoch 0/20] [Batch 50/125] D Loss: 3.4323 G Loss: 9.0255


 81%|████████  | 101/125 [00:46<00:12,  1.89it/s]

[Epoch 0/20] [Batch 100/125] D Loss: 3.1408 G Loss: 8.8781


100%|██████████| 125/125 [00:56<00:00,  2.20it/s]
  1%|          | 1/125 [00:00<00:57,  2.16it/s]

[Epoch 1/20] [Batch 0/125] D Loss: 3.0127 G Loss: 8.8942


 41%|████      | 51/125 [00:22<00:38,  1.93it/s]

[Epoch 1/20] [Batch 50/125] D Loss: 2.7789 G Loss: 8.8588


 81%|████████  | 101/125 [00:45<00:12,  1.91it/s]

[Epoch 1/20] [Batch 100/125] D Loss: 2.7035 G Loss: 8.9456


100%|██████████| 125/125 [00:56<00:00,  2.23it/s]
  1%|          | 1/125 [00:00<00:57,  2.16it/s]

[Epoch 2/20] [Batch 0/125] D Loss: 2.5258 G Loss: 9.0807


 41%|████      | 51/125 [00:23<00:38,  1.91it/s]

[Epoch 2/20] [Batch 50/125] D Loss: 2.5087 G Loss: 8.8506


 81%|████████  | 101/125 [00:45<00:12,  1.90it/s]

[Epoch 2/20] [Batch 100/125] D Loss: 2.4414 G Loss: 8.9830


100%|██████████| 125/125 [00:56<00:00,  2.21it/s]
  1%|          | 1/125 [00:00<00:57,  2.15it/s]

[Epoch 3/20] [Batch 0/125] D Loss: 2.3388 G Loss: 9.3106


 41%|████      | 51/125 [00:23<00:39,  1.90it/s]

[Epoch 3/20] [Batch 50/125] D Loss: 2.3360 G Loss: 9.4505


 81%|████████  | 101/125 [00:46<00:12,  1.90it/s]

[Epoch 3/20] [Batch 100/125] D Loss: 2.2172 G Loss: 9.1809


100%|██████████| 125/125 [00:56<00:00,  2.20it/s]
  1%|          | 1/125 [00:00<00:57,  2.14it/s]

[Epoch 4/20] [Batch 0/125] D Loss: 2.2110 G Loss: 9.6384


 41%|████      | 51/125 [00:23<00:38,  1.90it/s]

[Epoch 4/20] [Batch 50/125] D Loss: 2.3223 G Loss: 9.9476


 81%|████████  | 101/125 [00:46<00:12,  1.90it/s]

[Epoch 4/20] [Batch 100/125] D Loss: 2.1574 G Loss: 9.5342


100%|██████████| 125/125 [00:56<00:00,  2.20it/s]
  1%|          | 1/125 [00:00<00:57,  2.17it/s]

[Epoch 5/20] [Batch 0/125] D Loss: 3.0467 G Loss: 7.8818


 41%|████      | 51/125 [00:23<00:38,  1.90it/s]

[Epoch 5/20] [Batch 50/125] D Loss: 2.4482 G Loss: 9.2156


 81%|████████  | 101/125 [00:46<00:12,  1.90it/s]

[Epoch 5/20] [Batch 100/125] D Loss: 2.2049 G Loss: 9.6237


100%|██████████| 125/125 [00:56<00:00,  2.20it/s]
  1%|          | 1/125 [00:00<00:57,  2.14it/s]

[Epoch 6/20] [Batch 0/125] D Loss: 2.3153 G Loss: 9.4407


 41%|████      | 51/125 [00:23<00:39,  1.89it/s]

[Epoch 6/20] [Batch 50/125] D Loss: 2.3202 G Loss: 9.4172


 81%|████████  | 101/125 [00:46<00:12,  1.89it/s]

[Epoch 6/20] [Batch 100/125] D Loss: 2.2289 G Loss: 10.0934


100%|██████████| 125/125 [00:56<00:00,  2.20it/s]
  1%|          | 1/125 [00:00<00:57,  2.17it/s]

[Epoch 7/20] [Batch 0/125] D Loss: 2.2796 G Loss: 9.7305


 41%|████      | 51/125 [00:23<00:39,  1.89it/s]

[Epoch 7/20] [Batch 50/125] D Loss: 2.0537 G Loss: 10.0884


 81%|████████  | 101/125 [00:46<00:12,  1.89it/s]

[Epoch 7/20] [Batch 100/125] D Loss: 2.2209 G Loss: 10.0549


100%|██████████| 125/125 [00:56<00:00,  2.20it/s]
  1%|          | 1/125 [00:00<00:57,  2.16it/s]

[Epoch 8/20] [Batch 0/125] D Loss: 2.1375 G Loss: 9.9265


 41%|████      | 51/125 [00:23<00:39,  1.90it/s]

[Epoch 8/20] [Batch 50/125] D Loss: 1.9855 G Loss: 10.0698


 81%|████████  | 101/125 [00:46<00:12,  1.90it/s]

[Epoch 8/20] [Batch 100/125] D Loss: 2.1102 G Loss: 9.7423


100%|██████████| 125/125 [00:56<00:00,  2.20it/s]
  1%|          | 1/125 [00:00<00:57,  2.15it/s]

[Epoch 9/20] [Batch 0/125] D Loss: 2.0131 G Loss: 10.0839


 41%|████      | 51/125 [00:23<00:38,  1.90it/s]

[Epoch 9/20] [Batch 50/125] D Loss: 2.1265 G Loss: 9.7423


 81%|████████  | 101/125 [00:46<00:12,  1.90it/s]

[Epoch 9/20] [Batch 100/125] D Loss: 2.0831 G Loss: 10.7826


100%|██████████| 125/125 [00:56<00:00,  2.20it/s]
  1%|          | 1/125 [00:00<00:57,  2.17it/s]

[Epoch 10/20] [Batch 0/125] D Loss: 1.9601 G Loss: 9.3232


 41%|████      | 51/125 [00:23<00:39,  1.89it/s]

[Epoch 10/20] [Batch 50/125] D Loss: 2.0587 G Loss: 10.9260


 81%|████████  | 101/125 [00:46<00:12,  1.89it/s]

[Epoch 10/20] [Batch 100/125] D Loss: 1.9918 G Loss: 10.1880


100%|██████████| 125/125 [00:56<00:00,  2.20it/s]
  1%|          | 1/125 [00:00<00:57,  2.15it/s]

[Epoch 11/20] [Batch 0/125] D Loss: 2.1515 G Loss: 10.3195


 41%|████      | 51/125 [00:23<00:39,  1.90it/s]

[Epoch 11/20] [Batch 50/125] D Loss: 1.8921 G Loss: 10.9143


 81%|████████  | 101/125 [00:46<00:12,  1.90it/s]

[Epoch 11/20] [Batch 100/125] D Loss: 2.0883 G Loss: 10.1897


100%|██████████| 125/125 [00:56<00:00,  2.20it/s]
  1%|          | 1/125 [00:00<00:57,  2.16it/s]

[Epoch 12/20] [Batch 0/125] D Loss: 1.7190 G Loss: 10.4550


 41%|████      | 51/125 [00:23<00:38,  1.91it/s]

[Epoch 12/20] [Batch 50/125] D Loss: 1.7424 G Loss: 10.3497


 81%|████████  | 101/125 [00:45<00:12,  1.91it/s]

[Epoch 12/20] [Batch 100/125] D Loss: 2.0183 G Loss: 10.1219


100%|██████████| 125/125 [00:56<00:00,  2.21it/s]
  1%|          | 1/125 [00:00<00:57,  2.15it/s]

[Epoch 13/20] [Batch 0/125] D Loss: 1.8686 G Loss: 10.6336


 41%|████      | 51/125 [00:23<00:38,  1.92it/s]

[Epoch 13/20] [Batch 50/125] D Loss: 1.7766 G Loss: 10.3322


 81%|████████  | 101/125 [00:45<00:12,  1.92it/s]

[Epoch 13/20] [Batch 100/125] D Loss: 1.9324 G Loss: 10.5881


100%|██████████| 125/125 [00:56<00:00,  2.22it/s]
  1%|          | 1/125 [00:00<00:56,  2.18it/s]

[Epoch 14/20] [Batch 0/125] D Loss: 1.8068 G Loss: 10.2749


 41%|████      | 51/125 [00:23<00:38,  1.91it/s]

[Epoch 14/20] [Batch 50/125] D Loss: 1.9232 G Loss: 11.1747


 81%|████████  | 101/125 [00:45<00:12,  1.90it/s]

[Epoch 14/20] [Batch 100/125] D Loss: 1.7308 G Loss: 10.7394


100%|██████████| 125/125 [00:56<00:00,  2.22it/s]
  1%|          | 1/125 [00:00<00:57,  2.17it/s]

[Epoch 15/20] [Batch 0/125] D Loss: 1.7208 G Loss: 10.8055


 41%|████      | 51/125 [00:23<00:38,  1.91it/s]

[Epoch 15/20] [Batch 50/125] D Loss: 1.8494 G Loss: 10.8477


 81%|████████  | 101/125 [00:45<00:12,  1.90it/s]

[Epoch 15/20] [Batch 100/125] D Loss: 1.8771 G Loss: 10.9928


100%|██████████| 125/125 [00:56<00:00,  2.21it/s]
  1%|          | 1/125 [00:00<00:57,  2.15it/s]

[Epoch 16/20] [Batch 0/125] D Loss: 1.9138 G Loss: 10.4087


 41%|████      | 51/125 [00:23<00:38,  1.91it/s]

[Epoch 16/20] [Batch 50/125] D Loss: 1.7891 G Loss: 10.7447


 81%|████████  | 101/125 [00:46<00:12,  1.90it/s]

[Epoch 16/20] [Batch 100/125] D Loss: 1.3591 G Loss: 10.1569


100%|██████████| 125/125 [00:56<00:00,  2.20it/s]
  1%|          | 1/125 [00:00<00:57,  2.16it/s]

[Epoch 17/20] [Batch 0/125] D Loss: 1.7392 G Loss: 10.9263


 41%|████      | 51/125 [00:23<00:39,  1.89it/s]

[Epoch 17/20] [Batch 50/125] D Loss: 1.6168 G Loss: 12.0467


 81%|████████  | 101/125 [00:46<00:12,  1.90it/s]

[Epoch 17/20] [Batch 100/125] D Loss: 1.6085 G Loss: 10.8518


100%|██████████| 125/125 [00:56<00:00,  2.20it/s]
  1%|          | 1/125 [00:00<00:57,  2.16it/s]

[Epoch 18/20] [Batch 0/125] D Loss: 1.6396 G Loss: 10.7543


 41%|████      | 51/125 [00:23<00:39,  1.90it/s]

[Epoch 18/20] [Batch 50/125] D Loss: 1.5028 G Loss: 10.3606


 81%|████████  | 101/125 [00:46<00:12,  1.90it/s]

[Epoch 18/20] [Batch 100/125] D Loss: 1.4991 G Loss: 10.8356


100%|██████████| 125/125 [00:56<00:00,  2.20it/s]
  1%|          | 1/125 [00:00<00:58,  2.13it/s]

[Epoch 19/20] [Batch 0/125] D Loss: 2.0122 G Loss: 11.1396


 41%|████      | 51/125 [00:23<00:38,  1.90it/s]

[Epoch 19/20] [Batch 50/125] D Loss: 1.6557 G Loss: 10.3003


 81%|████████  | 101/125 [00:46<00:12,  1.90it/s]

[Epoch 19/20] [Batch 100/125] D Loss: 1.5751 G Loss: 10.8766


100%|██████████| 125/125 [00:56<00:00,  2.20it/s]
  generator.load_state_dict(torch.load("generator_99.pth"))


FileNotFoundError: [Errno 2] No such file or directory: 'generator_99.pth'

In [None]:
generator = GenreConditionedGenerator().to(device)
generator.load_state_dict(torch.load("generator_9.pth"))

convert_audio(
    input_path="/content/gtzan_dataset/genres_original/blues/blues.00000.wav",
    output_path="converted_to_classical.wav",
    source_genre="blues",
    target_genre="classical",
    generator=generator,
    dataset=dataset
)

  generator.load_state_dict(torch.load("generator_9.pth"))


In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchaudio
import torchaudio.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import librosa
import numpy as np
import os
from glob import glob
from tqdm import tqdm

# Constants
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SAMPLE_RATE = 22050
N_MELS = 128
N_FFT = 2048
HOP_LENGTH = 512
BATCH_SIZE = 8
EPOCHS = 10
LEARNING_RATE = 2e-4
NUM_GENRES = 10
DURATION = 5  # 5-second clips

# Dataset Class
class GTZANDataset(Dataset):
    def __init__(self, root_dir, duration=DURATION):
        self.file_paths = glob(os.path.join(root_dir, "genres_original", "*", "*.wav"))
        if not self.file_paths:
            raise ValueError(f"No .wav files found in {root_dir}")

        self.labels = [os.path.basename(os.path.dirname(fp)) for fp in self.file_paths]
        self.label_dict = {genre: idx for idx, genre in enumerate(sorted(set(self.labels)))}
        self.duration = duration
        self.time_dim = (duration * SAMPLE_RATE) // HOP_LENGTH + 1

        self.mel_transform = transforms.MelSpectrogram(
            sample_rate=SAMPLE_RATE,
            n_mels=N_MELS,
            n_fft=N_FFT,
            hop_length=HOP_LENGTH
        )

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        file_path = self.file_paths[idx]
        label = self.label_dict[self.labels[idx]]

        waveform, _ = librosa.load(file_path, sr=SAMPLE_RATE, duration=self.duration)
        if len(waveform) < SAMPLE_RATE * self.duration:
            waveform = np.pad(waveform, (0, SAMPLE_RATE * self.duration - len(waveform)), 'constant')

        waveform = torch.FloatTensor(waveform)
        mel_spec = self.mel_transform(waveform)

        # Ensure exact dimensions
        mel_spec = mel_spec[..., :self.time_dim] if mel_spec.shape[-1] > self.time_dim else F.pad(mel_spec, (0, self.time_dim - mel_spec.shape[-1]))

        # Log scaling with clipping
        mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))
        mel_spec = (mel_spec - mel_spec.mean()) / (mel_spec.std() + 1e-9)

        return mel_spec, label

# Fixed Generator Architecture
class GenreConditionedGenerator(nn.Module):
    def __init__(self, num_genres=NUM_GENRES):
        super().__init__()
        self.genre_embedding = nn.Embedding(num_genres, 128)

        # Initial convolution
        self.initial = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=7, stride=1, padding=3),
            nn.InstanceNorm2d(64),
            nn.LeakyReLU(0.2)
        )

        # Downsample
        self.down1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2)
        )

        # Residual blocks
        self.resblocks = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(128, 128, kernel_size=3, padding=1),
                nn.InstanceNorm2d(128),
                nn.LeakyReLU(0.2),
                nn.Conv2d(128, 128, kernel_size=3, padding=1),
                nn.InstanceNorm2d(128)
            ) for _ in range(6)
        ])

        # Genre projection
        self.genre_proj = nn.Linear(128, 128)

        # Upsample
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, output_padding=0),
            nn.InstanceNorm2d(64),
            nn.ReLU()
        )

        self.up2 = nn.Sequential(
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1, output_padding=0),
            nn.Tanh()
        )

    def forward(self, x, genre):
        # Store original size
        original_size = x.size()

        # Initial processing
        x = self.initial(x)

        # Downsample
        x = self.down1(x)

        # Add genre conditioning
        genre_emb = self.genre_embedding(genre)
        genre_emb = self.genre_proj(genre_emb)
        genre_emb = genre_emb.view(genre_emb.size(0), genre_emb.size(1), 1, 1)
        genre_emb = genre_emb.expand(-1, -1, x.size(2), x.size(3))
        x = x + genre_emb

        # Residual blocks
        for block in self.resblocks:
            residual = x
            x = block(x) + residual

        # Upsample
        x = self.up1(x)
        x = self.up2(x)

        # Match dimensions
        if x.size() != original_size:
            x = F.interpolate(x, size=original_size[2:], mode='bilinear', align_corners=False)

        return x

# Discriminator
class GenreAwareDiscriminator(nn.Module):
    def __init__(self, num_genres=NUM_GENRES):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=2),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=2),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d(1)
        )

        self.validity = nn.Linear(256, 1)
        self.genre_classifier = nn.Linear(256, num_genres)

    def forward(self, x):
        features = self.main(x).view(x.size(0), -1)
        return self.validity(features), self.genre_classifier(features)

# Training Function
def train():
    dataset = GTZANDataset("/content/gtzan_dataset")
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    generator = GenreConditionedGenerator().to(device)
    discriminator = GenreAwareDiscriminator().to(device)

    # Loss functions
    adversarial_loss = nn.BCEWithLogitsLoss()
    genre_loss = nn.CrossEntropyLoss()
    cycle_loss = nn.L1Loss()

    # Optimizers
    g_optimizer = optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))

    best_g_loss = float('inf')

    for epoch in range(EPOCHS):
        g_running_loss = 0.0
        d_running_loss = 0.0

        for real, real_genre in tqdm(dataloader):
            real = real.unsqueeze(1).to(device)
            real_genre = real_genre.to(device)

            # Random target genres
            target_genre = torch.randint(0, NUM_GENRES, (real.size(0),)).to(device)

            # Train Discriminator
            d_optimizer.zero_grad()

            fake = generator(real, target_genre)

            # Real loss
            real_validity, real_genre_pred = discriminator(real)
            d_real_loss = adversarial_loss(real_validity, torch.ones_like(real_validity))
            d_real_genre_loss = genre_loss(real_genre_pred, real_genre)

            # Fake loss
            fake_validity, _ = discriminator(fake.detach())
            d_fake_loss = adversarial_loss(fake_validity, torch.zeros_like(fake_validity))

            d_total_loss = d_real_loss + d_fake_loss + d_real_genre_loss
            d_total_loss.backward()
            d_optimizer.step()

            # Train Generator
            g_optimizer.zero_grad()

            fake = generator(real, target_genre)
            reconstructed = generator(fake, real_genre)

            # Adversarial loss
            fake_validity, fake_genre_pred = discriminator(fake)
            g_adv_loss = adversarial_loss(fake_validity, torch.ones_like(fake_validity))

            # Genre classification loss
            g_genre_loss = genre_loss(fake_genre_pred, target_genre)

            # Cycle consistency
            g_cycle_loss = cycle_loss(reconstructed, real)

            g_total_loss = g_adv_loss + g_genre_loss + 10*g_cycle_loss
            g_total_loss.backward()
            g_optimizer.step()

            g_running_loss += g_total_loss.item()
            d_running_loss += d_total_loss.item()

        # Save best model
        if g_running_loss < best_g_loss:
            best_g_loss = g_running_loss
            torch.save(generator.state_dict(), "best_generator.pth")
            torch.save(discriminator.state_dict(), "best_discriminator.pth")

        print(f"Epoch {epoch+1}/{EPOCHS} | G Loss: {g_running_loss/len(dataloader):.4f} | D Loss: {d_running_loss/len(dataloader):.4f}")

# Conversion Function
def convert_audio(input_path, output_path, source_genre, target_genre):
    # Load models
    generator = GenreConditionedGenerator().to(device)
    discriminator = GenreAwareDiscriminator().to(device)
    generator.load_state_dict(torch.load("best_generator.pth"))
    discriminator.load_state_dict(torch.load("best_discriminator.pth"))
    generator.eval()
    discriminator.eval()

    # Load audio
    y, sr = librosa.load(input_path, sr=SAMPLE_RATE, duration=DURATION)
    if len(y) < SAMPLE_RATE * DURATION:
        y = np.pad(y, (0, SAMPLE_RATE * DURATION - len(y)), 'constant')

    # Compute STFT
    stft = librosa.stft(y, n_fft=N_FFT, hop_length=HOP_LENGTH)
    mag = np.abs(stft)
    phase = np.angle(stft)

    # Convert to tensor
    mag = torch.from_numpy(mag).float().unsqueeze(0).unsqueeze(0).to(device)
    mag = torch.log(torch.clamp(mag, min=1e-5))
    mag = (mag - mag.mean()) / (mag.std() + 1e-9)

    # Generate transformed spectrogram
    with torch.no_grad():
        target_idx = dataset.label_dict[target_genre]
        transformed = generator(mag, torch.tensor([target_idx]).to(device))

        # Predict genre
        _, genre_pred = discriminator(transformed)
        predicted_idx = torch.argmax(genre_pred).item()
        predicted_genre = list(dataset.label_dict.keys())[predicted_idx]
        print(f"Predicted genre: {predicted_genre}")

    # Convert back to audio
    transformed = transformed.squeeze().cpu().numpy()
    transformed = np.exp(transformed)

    # Reconstruct with phase
    y_recon = librosa.griffinlim(
        transformed * np.exp(1j * phase),
        n_iter=64,
        hop_length=HOP_LENGTH,
        n_fft=N_FFT
    )

    # Save output
    import soundfile as sf
    sf.write(output_path, y_recon, sr)
    print(f"Converted {source_genre} to {target_genre}. Saved to {output_path}")

if __name__ == "__main__":
    # Initialize dataset
    dataset = GTZANDataset("/content/gtzan_dataset")
    print("Available genres:", list(dataset.label_dict.keys()))

    # Train models
    #train()

    # Example conversion
    convert_audio(
        input_path="/content/gtzan_dataset/genres_original/disco/disco.00003.wav",
        output_path="disco_to_class.wav",
        source_genre="disco",
        target_genre="classical"
    )

Available genres: ['blues', 'classical', 'country', 'disco', 'hiphop', 'jazz', 'metal', 'pop', 'reggae', 'rock']
Predicted genre: classical


  generator.load_state_dict(torch.load("best_generator.pth"))
  discriminator.load_state_dict(torch.load("best_discriminator.pth"))


Converted disco to classical. Saved to disco_to_class.wav
