# Imports:

In [3]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from torchvision.models import resnet50, ResNet50_Weights
from transformers import AutoTokenizer, AutoModel


  from .autonotebook import tqdm as notebook_tqdm


# Model:

In [4]:
class EmotionResNet50(nn.Module):
    """Pretrained ResNet-50 (``ResNet50_Weights.DEFAULT``) with a dropout + linear head.

    ``forward`` returns ``num_classes`` logits. ``get_embedding`` returns the
    flattened global-average-pooled features that feed the classifier head.
    """

    def __init__(self, num_classes, dropout_rate=0.4):
        super().__init__()
        backbone = resnet50(weights=ResNet50_Weights.DEFAULT)

        # Swap the stock classifier for Dropout -> Linear with the same input width.
        head_in = backbone.fc.in_features
        backbone.fc = nn.Sequential(
            nn.Dropout(p=dropout_rate),
            nn.Linear(head_in, num_classes),
        )
        self.base_model = backbone

    def forward(self, x):
        return self.base_model(x)

    def get_embedding(self, x):
        """Run every backbone stage except ``fc`` and flatten to ``[B, features]``."""
        net = self.base_model
        stem_and_stages = (
            net.conv1, net.bn1, net.relu, net.maxpool,
            net.layer1, net.layer2, net.layer3, net.layer4,
            net.avgpool,
        )
        for stage in stem_and_stages:
            x = stage(x)
        return torch.flatten(x, 1)

In [5]:
class ResidualConvBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers with an identity shortcut, then LeakyReLU(0.2).

    Channel count and spatial size are preserved (padding=1 on every conv).
    """

    def __init__(self, channels):
        super().__init__()
        layers = [
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
        ]
        self.block = nn.Sequential(*layers)
        self.activation = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        residual = self.block(x)
        return self.activation(x + residual)

class ConvBlock(nn.Module):
    """3x3 conv -> BatchNorm -> LeakyReLU(0.2), optionally followed by a ResidualConvBlock.

    Spatial size is preserved; channels go from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, use_residual=True):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        norm = nn.BatchNorm2d(out_channels)
        act = nn.LeakyReLU(0.2, inplace=True)
        self.initial = nn.Sequential(conv, norm, act)

        if use_residual:
            self.residual = ResidualConvBlock(out_channels)
        else:
            self.residual = nn.Identity()

    def forward(self, x):
        return self.residual(self.initial(x))


class DownBlock(nn.Module):
    """Encoder stage: ConvBlock followed by 2x2 max-pooling (H and W halved, rounding down)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [ConvBlock(in_channels, out_channels), nn.MaxPool2d(2)]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        pooled = self.block(x)
        return pooled


class UpBlock(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, gated skip fusion, then ConvBlock.

    Args:
        in_channels: channels of the incoming (coarser) feature map.
        out_channels: channels after upsampling. ``skip`` is expected to have
            this many channels too, since ``conv`` takes ``out_channels * 2``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.attn = SkipAttention(out_channels)
        self.conv = ConvBlock(out_channels * 2, out_channels)

    def forward(self, x, skip):
        x = self.up(x)
        # Only H/W must agree for the channel concat. A kernel-2/stride-2
        # transposed conv yields exactly 2x the input, which falls one short when
        # the matching encoder map had an odd size, so resize in that case.
        if x.shape[-2:] != skip.shape[-2:]:
            x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=True)
        skip = self.attn(skip)
        x = torch.cat([x, skip], dim=1)
        return self.conv(x)

class SkipAttention(nn.Module):
    """Input-dependent sigmoid gate for a skip connection: ``skip * sigmoid(f(skip))``.

    ``f`` is a 1x1 conv bottleneck (channels -> hidden -> channels) with ReLU,
    so the gate has the same shape as ``skip``.
    """

    def __init__(self, channels):
        super().__init__()
        # At least one hidden channel: with channels == 1, `channels // 2` would
        # be 0 and the gate would collapse to a constant sigmoid(bias).
        hidden = max(1, channels // 2)
        self.attn = nn.Sequential(
            nn.Conv2d(channels, hidden, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1),
            nn.Sigmoid()
        )

    def forward(self, skip):
        return skip * self.attn(skip)



In [6]:
class MelEncoder(nn.Module):
    """Convolutional encoder for single-channel mel-spectrograms.

    Input ``[B, 1, H, W]``. The map is max-pooled five times (four DownBlocks
    plus ``final_block``), each pool halving H and W with floor rounding, to
    give ``[B, embedding_dim, H/32, W/32]``. ``forward`` also returns the five
    intermediate feature maps (32, 64, 128, 256, 512 channels) as ``skips``.
    """

    def __init__(self, embedding_dim=1024):
        super().__init__()
        self.initial_conv = ConvBlock(1, 32)
        self.down1 = DownBlock(32, 64)
        self.down2 = DownBlock(64, 128)
        self.down3 = DownBlock(128, 256)
        self.down4 = DownBlock(256, 512)
        self.final_block = nn.Sequential(
            nn.MaxPool2d(2),
            ConvBlock(512, embedding_dim),
        )

    def forward(self, x):
        x = self.initial_conv(x)
        skips = [x]
        for stage in (self.down1, self.down2, self.down3, self.down4):
            x = stage(x)
            skips.append(x)
        return self.final_block(x), skips

In [7]:
class TransformerEncoder(nn.Module):
    """Sentence encoder: a pretrained Hugging Face model followed by mean pooling.

    Args:
        model_name: checkpoint passed to ``AutoModel.from_pretrained``. The
            tokenizer that produces ``input_ids`` should come from the same
            checkpoint.
    """

    def __init__(self, model_name="bert-base-multilingual-cased"):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask=None):
        """Return one vector per sequence, shape ``[batch_size, hidden_dim]``.

        With ``attention_mask`` the mean is taken only over positions where the
        mask is non-zero, so padding tokens do not dilute the embedding.
        Without a mask every position is averaged.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state  # [batch, seq_len, hidden_dim]
        if attention_mask is None:
            return hidden.mean(dim=1)

        # Callers such as DiffusionModelDataset pad to a fixed max_length, so an
        # unmasked mean would be dominated by pad positions for short sentences.
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1.0)  # avoid 0/0 for an all-padding row
        return summed / counts
    
# Tokenizer for the same checkpoint as TransformerEncoder's default model_name,
# so the token ids it produces index that encoder's vocabulary.
tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")


In [8]:
class DiffusionModelDataset(Dataset):
    """Conditioning embeddings plus the normalised mel-spectrogram for each utterance.

    ``__getitem__`` loads one ``.npy`` mel-spectrogram, min-max normalises it
    to [0, 1], and runs three frozen encoders under ``torch.no_grad()``:

    * the SER model's ``get_embedding`` on an image view of the mel
      (``self.transform`` applied to an 8-bit PIL image),
    * the MelEncoder on the ``[1, 1, H, W]`` normalised mel,
    * the text model on the tokenised ``sentence``.

    All returned tensors live on ``self.device``. NOTE(review): because model
    inference happens inside ``__getitem__``, a DataLoader over this dataset
    presumably needs ``num_workers=0`` when running on CUDA. Confirm before
    adding workers.
    """

    def __init__(self, csv_path, ser_model, mel_encoder, text_model, tokenizer, max_text_len=200, transform=None,
                 ser_checkpoint_path='models/SER.pth', mel_checkpoint_path='models/MelEncoder.pth',
                 num_emotion_classes=6, mel_embedding_dim=1024):
        """
        Args:
            csv_path (str): CSV with ``mel_npy_path``, ``emotion_encoded`` and ``sentence`` columns.
            ser_model: SER model *class*, instantiated as ``ser_model(num_classes=...)``;
                must provide ``get_embedding``.
            mel_encoder: MelEncoder *class*, instantiated as ``mel_encoder(embedding_dim=...)``.
            text_model: text encoder *class* (e.g. TransformerEncoder), instantiated with no arguments.
            tokenizer: Hugging Face-style tokenizer matching ``text_model``.
            max_text_len: Texts are padded/truncated to this many tokens.
            transform: Optional transform for the mel image fed to the SER model. Defaults to
                ToTensor -> Resize(224, 224) -> repeat to 3 channels.
            ser_checkpoint_path: Checkpoint whose ``'model_state_dict'`` entry holds the SER weights.
            mel_checkpoint_path: Checkpoint containing the MelEncoder state dict directly.
            num_emotion_classes: ``num_classes`` passed to ``ser_model`` (must match the checkpoint).
            mel_embedding_dim: ``embedding_dim`` passed to ``mel_encoder`` (must match the checkpoint).
        """
        self.df = pd.read_csv(csv_path)
        self.mel_paths = self.df["mel_npy_path"].tolist()
        self.labels = self.df["emotion_encoded"].tolist()  # stored but not returned by __getitem__
        self.texts = self.df["sentence"].tolist()

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = tokenizer
        self.max_text_len = max_text_len

        # Initialize models
        self.ser_model = ser_model(num_classes=num_emotion_classes).to(self.device)
        self.mel_encoder = mel_encoder(embedding_dim=mel_embedding_dim).to(self.device)
        self.text_model = text_model().to(self.device)

        # Load checkpoints
        ser_checkpoint = torch.load(ser_checkpoint_path, map_location=self.device)
        self.ser_model.load_state_dict(ser_checkpoint['model_state_dict'])
        self.ser_model.eval()

        mel_checkpoint = torch.load(mel_checkpoint_path, map_location=self.device)
        self.mel_encoder.load_state_dict(mel_checkpoint)
        self.mel_encoder.eval()

        # The text encoder is used frozen as well; eval() disables any dropout so
        # the same sentence always yields the same embedding.
        self.text_model.eval()

        self.transform = transform if transform else transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((224, 224)),
            transforms.Lambda(lambda x: x.repeat(3, 1, 1))
        ])

    @staticmethod
    def _min_max_normalize(mel_spec):
        """Scale ``mel_spec`` to [0, 1]; a constant (zero-range) array maps to zeros instead of NaN."""
        lo, hi = mel_spec.min(), mel_spec.max()
        span = hi - lo
        if span > 0:
            return (mel_spec - lo) / span
        return np.zeros_like(mel_spec, dtype=np.float32)

    def __len__(self):
        return len(self.mel_paths)

    def __getitem__(self, idx):
        # Load and normalise the mel-spectrogram.
        mel_spec = self._min_max_normalize(np.load(self.mel_paths[idx]))
        mel_tensor = self.transform(
            Image.fromarray((mel_spec * 255).astype(np.uint8), mode='L')
        ).to(self.device)

        # Input for the MelEncoder: add a channel axis here and a batch axis below.
        mel = torch.tensor(mel_spec, dtype=torch.float32).unsqueeze(0).to(self.device)
        init_mel = mel.squeeze(0)

        # Tokenise the text to a fixed length.
        text = self.texts[idx]
        inputs = self.tokenizer(
            text,
            max_length=self.max_text_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        ).to(self.device)
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']

        with torch.no_grad():
            # Audio embeddings (batch axis added, then removed).
            emotions_embedding = self.ser_model.get_embedding(mel_tensor.unsqueeze(0))
            emotions_embedding = emotions_embedding.view(-1)
            utterance_vector, _ = self.mel_encoder(mel.unsqueeze(0))
            utterance_vector = utterance_vector.squeeze(0)

            # Text embedding
            text_embedding = self.text_model(input_ids, attention_mask)
            text_embedding = text_embedding.squeeze(0)

        return {
            'emotion_embedding': emotions_embedding,
            'mel_embedding': utterance_vector,
            'text_embedding': text_embedding,
            'initial_mel': init_mel
        }

In [9]:
# Both splits use the same encoder classes and tokenizer; only the CSV differs.
_shared_dataset_components = dict(
    ser_model=EmotionResNet50,
    mel_encoder=MelEncoder,
    text_model=TransformerEncoder,
    tokenizer=tokenizer,
)

train_dataset = DiffusionModelDataset(
    csv_path='Data/split_train_val/train_split.csv',
    **_shared_dataset_components,
)
val_dataset = DiffusionModelDataset(
    csv_path='Data/split_train_val/val_split.csv',
    **_shared_dataset_components,
)

  ser_checkpoint = torch.load('models/SER.pth', map_location=self.device)
  mel_checkpoint = torch.load('models/MelEncoder.pth', map_location=self.device)
  ser_checkpoint = torch.load('models/SER.pth', map_location=self.device)
  mel_checkpoint = torch.load('models/MelEncoder.pth', map_location=self.device)


In [10]:
# Training batches are shuffled. Validation runs in a fixed order so per-epoch
# validation losses are comparable and "first N samples" helpers really see the
# first N rows of the split.
# num_workers stays at the default (0): DiffusionModelDataset runs its encoders
# on self.device inside __getitem__.
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=2, shuffle=False)

In [11]:
# Fetch a single batch and print the shape of every field as a sanity check.
batch = next(iter(train_dataloader), None)
if batch is None:
    print("train_dataloader is empty")
else:
    print("emotion_embedding shape:", batch['emotion_embedding'].shape)
    print("mel_embedding shape:", batch['mel_embedding'].shape)
    print("text_embedding shape:", batch['text_embedding'].shape)
    print("initial_mel shape:", batch['initial_mel'].shape)

emotion_embedding shape: torch.Size([2, 2048])
mel_embedding shape: torch.Size([2, 1024, 4, 26])
text_embedding shape: torch.Size([2, 768])
initial me shape: torch.Size([2, 128, 861])


In [12]:
class EmbeddingProjector(nn.Module):
    """Project emotion, text and mel-style embeddings into a shared ``cond_dim`` space.

    Args:
        emotion_dim: last-dim size of the emotion embedding.
        mel_style_dim, mel_h, mel_w: channels/height/width of the mel-style
            feature map, which is flattened to ``mel_style_dim * mel_h * mel_w``.
        text_dim: last-dim size of the text embedding.
        cond_dim: output size shared by all three projections.
        mel_hidden_dim1, mel_hidden_dim2: widths of the two intermediate mel
            projections (defaults 8192 and 4096, as before).

    Returns (from ``forward``):
        ``(cond_emotion, cond_text, cond_mel)``, each ``[B, cond_dim]``.
    """

    def __init__(self, emotion_dim=2048, mel_style_dim=1024, mel_h=4, mel_w=26, text_dim=768, cond_dim=1024,
                 mel_hidden_dim1=8192, mel_hidden_dim2=4096):
        super(EmbeddingProjector, self).__init__()
        self.cond_dim = cond_dim

        # Single linear projection for each of the vector modalities.
        self.proj_emotion = nn.Linear(emotion_dim, cond_dim)
        self.proj_text = nn.Linear(text_dim, cond_dim)

        # Mel path: flattened map -> hidden1 -> hidden2 -> cond_dim.
        # NOTE(review): with the defaults the first layer alone holds
        # 1024*4*26*8192 ~ 872M weights, and the three layers are applied with
        # no nonlinearity in between, so together they compose to a single
        # linear map. Worth revisiting the design.
        self.proj_mel_first = nn.Linear(mel_style_dim * mel_h * mel_w, mel_hidden_dim1)
        self.proj_mel_intermediate = nn.Linear(mel_hidden_dim1, mel_hidden_dim2)
        self.proj_mel = nn.Linear(mel_hidden_dim2, cond_dim)

    def forward(self, emotion_embedding, mel_embedding, text_embedding):
        cond_emotion = self.proj_emotion(emotion_embedding)
        cond_text = self.proj_text(text_embedding)

        # torch.flatten (unlike .view) also accepts non-contiguous inputs,
        # e.g. permuted or sliced feature maps.
        mel_flat = torch.flatten(mel_embedding, 1)
        mel_hidden = self.proj_mel_first(mel_flat)
        mel_hidden = self.proj_mel_intermediate(mel_hidden)
        cond_mel = self.proj_mel(mel_hidden)

        return cond_emotion, cond_text, cond_mel


In [13]:
class FiLM(nn.Module):
    """Feature-wise linear modulation: ``gamma(cond) * x + beta(cond)``.

    A single linear layer maps ``cond`` ``[B, cond_dim]`` to per-channel scale
    and shift (``[B, feature_dim]`` each), which are broadcast over the
    spatial dims of ``x`` ``[B, feature_dim, H, W]``.
    """

    def __init__(self, cond_dim, feature_dim):
        super().__init__()
        self.film = nn.Linear(cond_dim, 2 * feature_dim)

    def forward(self, x, cond):
        scale, shift = torch.chunk(self.film(cond), 2, dim=1)
        return scale[:, :, None, None] * x + shift[:, :, None, None]

In [14]:
class ResNetBlockWithFiLM(nn.Module):
    """Residual block whose second conv output is FiLM-modulated by ``cond``.

    out = ReLU(FiLM(conv2(ReLU(conv1(x))), cond) + shortcut(x)), where the
    shortcut is a 1x1 conv when the channel count changes and identity otherwise.
    """

    def __init__(self, in_channels, out_channels, cond_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.film = FiLM(cond_dim, out_channels)
        if in_channels == out_channels:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x, cond):
        shortcut = self.skip(x)
        hidden = self.conv2(self.relu(self.conv1(x)))
        return self.relu(self.film(hidden, cond) + shortcut)


In [15]:
def forward_diffusion(self, mel_spec, timesteps):
    """Build a stack of increasingly noisy copies of ``mel_spec``.

    Copy ``k`` is ``mel_spec + N(0, 1) * level_k``, where ``level`` runs
    linearly from 0 to 1 over ``timesteps`` values, so copy 0 is the clean
    input. ``self`` is not used; the signature only mirrors a method.

    Returns:
        Tensor of shape ``[timesteps, *mel_spec.shape]``.
    """
    levels = torch.linspace(0, 1, timesteps).to(mel_spec.device)
    noisy_copies = [mel_spec + torch.randn_like(mel_spec) * level for level in levels]
    return torch.stack(noisy_copies)


In [16]:
def reverse_diffusion(self, noisy_latent_tensor, emotion_embedding, mel_embedding, text_embedding, timesteps):
    """Iteratively denoise ``noisy_latent_tensor`` with ``self``'s conditioned network.

    ``self`` is expected to expose the EmotionDiffusionModel submodules
    (``embedding_projector``, ``initial_conv``, ``resnet_block_1..5``,
    ``downsample_1/2``, ``upsample_1/2``, ``final_conv``). The network is
    applied ``timesteps`` times in sequence; ``timesteps == 0`` returns the
    input unchanged.

    Args:
        noisy_latent_tensor: ``[B, H, W]`` or ``[B, 1, H, W]`` noisy spectrogram.
        emotion_embedding, mel_embedding, text_embedding: conditioning inputs
            for ``self.embedding_projector``.
        timesteps: number of denoising passes.

    Returns:
        Tensor with the same shape as ``noisy_latent_tensor``.

    NOTE(review): this is repeated refinement through the full network, not a
    DDPM/DDIM sampler. During training, gradients flow through all
    ``timesteps`` passes, which is likely memory-heavy; confirm intended.
    """
    # The conditioning depends only on the embeddings: project once, not per step.
    cond_emotion, cond_text, cond_mel = self.embedding_projector(
        emotion_embedding, mel_embedding, text_embedding
    )

    # The convolutions need an explicit channel axis.
    added_channel = noisy_latent_tensor.dim() == 3
    x = noisy_latent_tensor.unsqueeze(1) if added_channel else noisy_latent_tensor

    for _ in range(timesteps):
        # Lift the 1-channel input to the 256 channels resnet_block_1 expects.
        h0 = self.resnet_block_1(self.initial_conv(x), cond_emotion)
        h1 = self.resnet_block_2(self.downsample_1(h0), cond_mel)
        h2 = self.resnet_block_3(self.downsample_2(h1), cond_text)

        # Stride-2 down-sampling rounds odd sizes up and the 2x up-sampling then
        # overshoots by one, so crop back to the matching encoder size.
        u1 = self.upsample_1(h2)[..., :h1.shape[-2], :h1.shape[-1]]
        u1 = self.resnet_block_4(u1, cond_mel)
        u2 = self.upsample_2(u1)[..., :h0.shape[-2], :h0.shape[-1]]
        u2 = self.resnet_block_5(u2, cond_emotion)

        # No per-step rescaling by the noise schedule: dividing by (1 - level)
        # with a schedule that reaches 1 is singular and made the output inf/NaN.
        x = self.final_conv(u2)

    return x.squeeze(1) if added_channel else x


In [17]:
class EmotionDiffusionModel(nn.Module):
    """Emotion/text/mel-style conditioned denoiser for mel-spectrograms.

    The denoiser is an encoder-decoder without skip connections:
    initial_conv -> FiLM block -> stride-2 down -> FiLM block -> stride-2 down
    -> FiLM block -> 2x up -> FiLM block -> 2x up -> FiLM block -> final_conv.
    Each FiLM block is conditioned on one of the projected emotion / mel-style /
    text embeddings.

    Spectrogram inputs may be ``[B, H, W]`` (the ``initial_mel`` layout of
    DiffusionModelDataset) or ``[B, 1, H, W]``; outputs match the input shape.
    """

    def __init__(self, emotion_dim=2048, mel_style_dim=1024, mel_h=4, mel_w=26, text_dim=768, cond_dim=1024):
        super(EmotionDiffusionModel, self).__init__()

        # Projects the emotion / mel-style / text embeddings to a shared cond_dim.
        self.embedding_projector = EmbeddingProjector(
            emotion_dim=emotion_dim,
            mel_style_dim=mel_style_dim,
            mel_h=mel_h,
            mel_w=mel_w,
            text_dim=text_dim,
            cond_dim=cond_dim
        )

        # Lifts the single-channel noisy spectrogram to 256 feature channels.
        self.initial_conv = nn.Conv2d(1, 256, kernel_size=3, padding=1)

        # Stride-2 convs map H/W to ceil(n/2); the transposed convs double them.
        self.downsample_1 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
        self.downsample_2 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1)
        self.upsample_1 = nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1)
        self.upsample_2 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)

        # Back to a single output channel.
        self.final_conv = nn.Conv2d(256, 1, kernel_size=3, padding=1)

        # ResNet blocks with FiLM conditioning layers
        self.resnet_block_1 = ResNetBlockWithFiLM(256, 256, cond_dim)
        self.resnet_block_2 = ResNetBlockWithFiLM(512, 512, cond_dim)
        self.resnet_block_3 = ResNetBlockWithFiLM(1024, 1024, cond_dim)
        self.resnet_block_4 = ResNetBlockWithFiLM(512, 512, cond_dim)
        self.resnet_block_5 = ResNetBlockWithFiLM(256, 256, cond_dim)

    def forward(self, emotion_embedding, mel_embedding, text_embedding, noisy_latent_tensor, timesteps):
        """Denoise ``noisy_latent_tensor`` (see ``reverse_diffusion``)."""
        return self.reverse_diffusion(noisy_latent_tensor, emotion_embedding, mel_embedding, text_embedding, timesteps)

    def forward_diffusion(self, mel_spec, timesteps):
        """Return ``timesteps`` increasingly noisy copies of ``mel_spec``, stacked on dim 0.

        Copy ``k`` is ``mel_spec + N(0, 1) * level_k`` with ``level`` linear from
        0 to 1, so the result has shape ``[timesteps, *mel_spec.shape]``.
        """
        noise_levels = torch.linspace(0, 1, timesteps, device=mel_spec.device)
        return torch.stack([mel_spec + torch.randn_like(mel_spec) * level for level in noise_levels])

    def reverse_diffusion(self, noisy_latent_tensor, emotion_embedding, mel_embedding, text_embedding, timesteps):
        """Apply the conditioned network ``timesteps`` times to ``noisy_latent_tensor``.

        ``timesteps == 0`` returns the input unchanged. NOTE(review): this is
        repeated refinement, not a DDPM/DDIM sampler, and training backprops
        through every pass, which is likely memory-heavy. Confirm intended.
        """
        # The conditioning depends only on the embeddings: project once, not per step.
        cond_emotion, cond_text, cond_mel = self.embedding_projector(
            emotion_embedding, mel_embedding, text_embedding
        )

        added_channel = noisy_latent_tensor.dim() == 3
        x = noisy_latent_tensor.unsqueeze(1) if added_channel else noisy_latent_tensor
        for _ in range(timesteps):
            x = self._denoise_step(x, cond_emotion, cond_text, cond_mel)
        return x.squeeze(1) if added_channel else x

    def _denoise_step(self, x, cond_emotion, cond_text, cond_mel):
        """One pass through the encoder-decoder; output has the same shape as ``x``."""
        h0 = self.resnet_block_1(self.initial_conv(x), cond_emotion)
        h1 = self.resnet_block_2(self.downsample_1(h0), cond_mel)
        h2 = self.resnet_block_3(self.downsample_2(h1), cond_text)

        # 2 * ceil(n / 2) >= n, so cropping always restores the encoder size
        # (e.g. width 861 -> 431 -> 216 -> 432 -> crop 431 -> 862 -> crop 861).
        u1 = self.resnet_block_4(self._crop_to(self.upsample_1(h2), h1), cond_mel)
        u2 = self.resnet_block_5(self._crop_to(self.upsample_2(u1), h0), cond_emotion)
        return self.final_conv(u2)

    @staticmethod
    def _crop_to(x, reference):
        """Crop the last two (H, W) dims of ``x`` to those of ``reference``."""
        return x[..., :reference.shape[-2], :reference.shape[-1]]


In [18]:
# Sizes of the conditioning inputs produced by DiffusionModelDataset (they match
# the batch shape printout above) plus the shared FiLM conditioning width.
emotion_dim = 2048      # EmotionResNet50.get_embedding feature size
mel_style_dim = 1024    # MelEncoder output channels
mel_h, mel_w = 4, 26    # MelEncoder output height/width (not the spectrogram size)
text_dim = 768          # TransformerEncoder pooled hidden size
cond_dim = 1024         # width every embedding is projected to

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = EmotionDiffusionModel(
    emotion_dim=emotion_dim, mel_style_dim=mel_style_dim,
    mel_h=mel_h, mel_w=mel_w,
    text_dim=text_dim, cond_dim=cond_dim,
)
model = model.to(device)


In [19]:
def _diffusion_batch_loss(model, batch, device, timesteps, mse_loss, l1_loss):
    """Noise one batch, denoise it with ``model`` and return ``0.7 * MSE + 0.3 * L1`` vs. the clean mel."""
    emotion_embedding = batch['emotion_embedding'].to(device)
    mel_embedding = batch['mel_embedding'].to(device)
    text_embedding = batch['text_embedding'].to(device)
    # DiffusionModelDataset stores the clean (ground-truth) spectrogram under 'initial_mel'.
    mel_spec = batch['initial_mel'].to(device)

    # Only the noisiest forward-diffusion step is used as the model input.
    noisy_latent_tensor = model.forward_diffusion(mel_spec, timesteps=timesteps)[-1]
    output = model(emotion_embedding, mel_embedding, text_embedding,
                   noisy_latent_tensor, timesteps=timesteps)
    return 0.7 * mse_loss(output, mel_spec) + 0.3 * l1_loss(output, mel_spec)


def _mean_or_nan(total, count):
    """Average ``total`` over ``count`` batches, or NaN when there were no batches."""
    return total / count if count else float('nan')


def train_model(model, train_loader, val_loader, num_epochs, device, save_dir, save_every=5,
                timesteps=100, lr=1e-3, weight_decay=1e-4):
    """
    Trains the diffusion model with emotional conditioning using composite loss.

    For each batch the clean ``initial_mel`` is noised with
    ``model.forward_diffusion``. The noisiest step is denoised by ``model(...)``,
    and the loss is ``0.7 * MSE + 0.3 * L1`` against the clean spectrogram.

    Args:
        model (nn.Module): The EmotionDiffusionModel.
        train_loader (iterable): Training batches (dicts with 'emotion_embedding',
            'mel_embedding', 'text_embedding', 'initial_mel').
        val_loader (iterable): Validation batches with the same keys.
        num_epochs (int): Number of training epochs.
        device (torch.device): Training device (CPU or GPU).
        save_dir (str): Directory to save checkpoints (created if missing).
        save_every (int): Save a checkpoint every `save_every` epochs.
        timesteps (int): Diffusion steps passed to the model. Defaults to 100.
        lr (float): Adam learning rate. Defaults to 1e-3.
        weight_decay (float): Adam weight decay. Defaults to 1e-4.

    Returns:
        model: Trained model. An epoch whose loader yields no batches reports NaN for that loss.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Move to the device before building the optimizer so it tracks the final parameters.
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    mse_loss = nn.MSELoss()
    l1_loss = nn.L1Loss()

    for epoch in range(1, num_epochs + 1):
        # -------------------- Train --------------------
        model.train()
        train_loss, train_batches = 0.0, 0
        for batch in train_loader:
            loss = _diffusion_batch_loss(model, batch, device, timesteps, mse_loss, l1_loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_batches += 1
        avg_train_loss = _mean_or_nan(train_loss, train_batches)

        # -------------------- Evaluation --------------------
        model.eval()
        val_loss, val_batches = 0.0, 0
        with torch.no_grad():
            for batch in val_loader:
                loss = _diffusion_batch_loss(model, batch, device, timesteps, mse_loss, l1_loss)
                val_loss += loss.item()
                val_batches += 1
        avg_val_loss = _mean_or_nan(val_loss, val_batches)

        # -------------------- Logging --------------------
        print(f"Epoch {epoch:3d} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

        # -------------------- Checkpoint --------------------
        if epoch % save_every == 0:
            checkpoint_path = os.path.join(save_dir, f"model_epoch_{epoch}.pt")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_train_loss,
                'val_loss': avg_val_loss
            }, checkpoint_path)
            print(f"Checkpoint saved to: {checkpoint_path}")

    return model


In [20]:
def forward_diffusion(self, mel_spec, timesteps):
    """Return ``timesteps`` increasingly noisy copies of ``mel_spec``, stacked on dim 0.

    Delegates to ``self.add_noise(mel_spec, timesteps)`` when ``self`` provides
    a callable ``add_noise``. Otherwise copy ``k`` is
    ``mel_spec + N(0, 1) * level_k`` with ``level`` linear from 0 to 1, giving
    shape ``[timesteps, *mel_spec.shape]``.
    """
    # EmotionDiffusionModel defines no add_noise, so calling it unconditionally
    # raised AttributeError; fall back to the linear noise schedule.
    add_noise = getattr(self, "add_noise", None)
    if callable(add_noise):
        return add_noise(mel_spec, timesteps)

    noise_levels = torch.linspace(0, 1, timesteps, device=mel_spec.device)
    return torch.stack([mel_spec + torch.randn_like(mel_spec) * level for level in noise_levels])

In [21]:
# 50 epochs, writing a checkpoint every 5 epochs under ./checkpoints/diffusion_model.
training_kwargs = dict(
    num_epochs=50,
    device=device,
    save_dir="./checkpoints/diffusion_model",
    save_every=5,
)
trained_model = train_model(model, train_dataloader, val_dataloader, **training_kwargs)

: 

In [None]:
def load_best_checkpoint(model, checkpoint_dir, device):
    """Load the ``.pt`` checkpoint with the lowest recorded ``val_loss`` into ``model``.

    Checkpoints without a ``val_loss`` entry (or with NaN) are never selected.
    Ties go to the alphabetically first filename.

    Args:
        model (nn.Module): Model whose ``state_dict`` layout matches the checkpoints.
        checkpoint_dir (str): Directory containing the ``.pt`` files.
        device (torch.device): Device to place the restored model on.

    Returns:
        The same ``model``, with weights loaded, moved to ``device`` and in eval mode.

    Raises:
        FileNotFoundError: if no checkpoint in ``checkpoint_dir`` has a usable ``val_loss``.
    """
    best_loss = float('inf')
    best_path = None

    # Sorted for a deterministic tie-break. The scan loads to CPU: only val_loss
    # is needed here, so there is no reason to put every checkpoint on `device`.
    for filename in sorted(os.listdir(checkpoint_dir)):
        if not filename.endswith(".pt"):
            continue
        path = os.path.join(checkpoint_dir, filename)
        checkpoint = torch.load(path, map_location="cpu")
        val_loss = checkpoint.get("val_loss", float("inf"))
        if val_loss < best_loss:
            best_loss, best_path = val_loss, path
        del checkpoint  # release before reading the next file

    if best_path is None:
        raise FileNotFoundError(f"No valid checkpoints found in {checkpoint_dir!r}.")

    print(f"Loading best checkpoint: {best_path} (val_loss = {best_loss:.4f})")
    checkpoint = torch.load(best_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()
    return model

In [None]:
def test_model_on_first_5(model, val_loader, device):
    """Print predicted/target shapes for the first five samples yielded by ``val_loader``.

    Each sample is noised with ``model.forward_diffusion(..., timesteps=100)``,
    and the noisiest step is passed to ``model`` for denoising. Returns ``None``.
    NOTE(review): these are the dataset's first five rows only if ``val_loader``
    does not shuffle.
    """
    model.eval()
    count = 0

    with torch.no_grad():
        for batch in val_loader:
            # DiffusionModelDataset provides the clean spectrogram as 'initial_mel'.
            for i in range(batch['initial_mel'].size(0)):
                if count >= 5:
                    return  # Only process 5 samples

                emotion_embedding = batch['emotion_embedding'][i:i+1].to(device)
                mel_embedding = batch['mel_embedding'][i:i+1].to(device)
                text_embedding = batch['text_embedding'][i:i+1].to(device)
                mel_spec = batch['initial_mel'][i:i+1].to(device)

                noisy_mel = model.forward_diffusion(mel_spec, timesteps=100)[-1]
                predicted = model(emotion_embedding, mel_embedding, text_embedding,
                                  noisy_mel, timesteps=100)

                print(f"\nSample #{count+1}")
                print(f"Predicted shape: {predicted.shape}")
                print(f"Target shape:    {mel_spec.shape}")

                count += 1

In [None]:
# Rebuild the architecture with the dimensions it was trained with (emotion_dim,
# mel_style_dim, mel_h, mel_w, text_dim and cond_dim from the model-initialisation
# cell) so the checkpoint's state_dict matches, then restore the best checkpoint.
model = EmotionDiffusionModel(
    emotion_dim=emotion_dim,
    mel_style_dim=mel_style_dim,
    mel_h=mel_h,
    mel_w=mel_w,
    text_dim=text_dim,
    cond_dim=cond_dim,
).to(device)

# Load best checkpoint
model = load_best_checkpoint(model, checkpoint_dir="./checkpoints/diffusion_model", device=device)

# Test on first 5 validation samples
test_model_on_first_5(model, val_dataloader, device)

In [None]:
def test_model_on_first_5(model, val_loader, device, save_dir="results"):
    """Denoise the first five samples of ``val_loader`` and save each prediction as ``.npy``.

    Each sample is noised with ``model.forward_diffusion(..., timesteps=100)``,
    and the noisiest step is passed to ``model``. Predictions are written to
    ``save_dir/predicted_mel_<k>.npy`` (k = 1..5); ``save_dir`` is created if
    missing. Returns ``None``. NOTE(review): these are the dataset's first five
    rows only if ``val_loader`` does not shuffle.
    """
    model.eval()
    os.makedirs(save_dir, exist_ok=True)
    count = 0

    with torch.no_grad():
        for batch in val_loader:
            # DiffusionModelDataset provides the clean spectrogram as 'initial_mel'.
            for i in range(batch['initial_mel'].size(0)):
                if count >= 5:
                    return  # Only process 5 samples

                # Extract the i-th sample, keeping a batch axis of size 1.
                emotion_embedding = batch['emotion_embedding'][i:i+1].to(device)
                mel_embedding = batch['mel_embedding'][i:i+1].to(device)
                text_embedding = batch['text_embedding'][i:i+1].to(device)
                mel_spec = batch['initial_mel'][i:i+1].to(device)

                # Forward and reverse diffusion
                noisy_mel = model.forward_diffusion(mel_spec, timesteps=100)[-1]
                predicted = model(emotion_embedding, mel_embedding, text_embedding,
                                  noisy_mel, timesteps=100)

                # Drop the size-1 batch axis and save as a NumPy array.
                predicted_np = predicted.squeeze(0).cpu().numpy()
                save_path = os.path.join(save_dir, f"predicted_mel_{count+1}.npy")
                np.save(save_path, predicted_np)

                print(f"Sample #{count+1} saved to {save_path}")
                count += 1