In [1]:
from datasets import load_dataset, Audio, DatasetDict, Dataset, load_from_disk, IterableDatasetDict, interleave_datasets, concatenate_datasets

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

import math
import random

from pprint import pprint as pp

from torchvision.utils import make_grid
from copy import deepcopy




In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'


In [7]:
checkpoint = "Lie_CommonVoice_IT_1_1_epochs.pth"
transformer_weights = torch.load(checkpoint, map_location=device, weights_only=False)
transformer_weights

{'epoch': 0,
 'model_state_dict': OrderedDict([('module.encoder.efficientnet.0.0.weight',
               tensor([[[[-1.8082e-01, -8.8772e-02, -1.9788e-01],
                         [ 2.8449e-02, -1.9688e-01, -7.0861e-02],
                         [ 6.8080e-02,  1.4433e-01, -3.9065e-02]]],
               
               
                       [[[-2.9002e-01, -1.0473e-01,  2.3296e-01],
                         [ 2.5491e-01,  2.2601e-02, -2.2772e-01],
                         [ 2.6379e-01,  3.4542e-01,  1.7845e-01]]],
               
               
                       [[[ 2.1327e-01,  2.0310e-01,  8.0709e-02],
                         [ 1.4639e-01,  2.4996e-01,  6.8514e-02],
                         [ 1.0678e-01,  1.6061e-01,  5.1621e-02]]],
               
               
                       [[[ 2.1159e-01, -1.8501e-01, -2.2527e-02],
                         [ 7.4048e-02,  1.6755e-01,  2.0936e-01],
                         [-3.4707e-01,  2.9960e-01,  3.0280e-01]]],
              

In [5]:
upd_weights = {}

for k,v in transformer_weights['model_state_dict'].items():
    upd_weights[k.replace('module.','')] = v

In [6]:
upd_weights

{'encoder.efficientnet.0.0.weight': tensor([[[[-1.8082e-01, -8.8772e-02, -1.9788e-01],
           [ 2.8449e-02, -1.9688e-01, -7.0861e-02],
           [ 6.8080e-02,  1.4433e-01, -3.9065e-02]]],
 
 
         [[[-2.9002e-01, -1.0473e-01,  2.3296e-01],
           [ 2.5491e-01,  2.2601e-02, -2.2772e-01],
           [ 2.6379e-01,  3.4542e-01,  1.7845e-01]]],
 
 
         [[[ 2.1327e-01,  2.0310e-01,  8.0709e-02],
           [ 1.4639e-01,  2.4996e-01,  6.8514e-02],
           [ 1.0678e-01,  1.6061e-01,  5.1621e-02]]],
 
 
         [[[ 2.1159e-01, -1.8501e-01, -2.2527e-02],
           [ 7.4048e-02,  1.6755e-01,  2.0936e-01],
           [-3.4707e-01,  2.9960e-01,  3.0280e-01]]],
 
 
         [[[-1.9009e-01,  6.9791e-02,  2.2615e-01],
           [-7.8283e-02, -2.4292e-01,  1.7412e-01],
           [-2.5191e-01, -3.3254e-01, -2.3182e-01]]],
 
 
         [[[-7.1229e-02,  2.2056e-01, -2.0919e-01],
           [ 2.3203e-01,  8.1474e-02,  5.4201e-02],
           [-9.7665e-02,  2.3563e-01,  2.8190e-01]]

In [3]:
ds = IterableDatasetDict()
ds["val"] = load_from_disk("StS_Nemo_CommonVoice_val.hf")
ds["train"] = load_from_disk("StS_Nemo_CommonVoice_train.hf")

FileNotFoundError: Directory StS_Nemo_CommonVoice_val.hf not found

In [None]:
def smooth_1d_field(length, epsilon=1.0, kernel_size=33, device='cpu'):
    """
    Generates smoothed 1D noise field (shape: [length])
    """
    field = torch.randn(length, device=device) * epsilon
    field = F.avg_pool1d(field[None, None, :], kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
    return field.squeeze(0).squeeze(0)  # shape: [length]

def smooth_2d_field(height, width, epsilon=1.0, kernel_size=15, device='cpu'):
    """
    Generates smoothed 2D noise field (shape: [height, width])
    """
    field = torch.randn(1, 1, height, width, device=device) * epsilon
    field = F.avg_pool2d(field, kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
    return field.squeeze(0).squeeze(0)  # shape: [height, width]

In [None]:
# def sinusoidal_1d_field(length, num_components=3, epsilon=0.1, device='cpu'):
#     t = torch.linspace(0, 1, length, device=device)
#     field = torch.zeros_like(t)
#     for _ in range(num_components):
#         freq = torch.rand(1).item() * 5 + 1  # freq ∈ [1,6]
#         phase = torch.rand(1).item() * 2 * torch.pi
#         amplitude = torch.randn(1).item() * epsilon
#         field += amplitude * torch.sin(2 * torch.pi * freq * t + phase)
#     return field


def sinusoidal_1d_field(length, epsilon=0.2, num_components=2, device='cpu'):
    """
    Generate a large-scale 1D sinusoidal field with few low-frequency basis components.
    """
    t = torch.linspace(0, 1, length, device=device)
    field = torch.zeros_like(t)
    for _ in range(num_components):
        freq = torch.rand(1).item() * 1.5 + 0.5   # total ≤ 2 cycles
        phase = torch.rand(1).item() * 2 * math.pi
        amplitude = torch.randn(1).item() * epsilon
        field += amplitude * torch.sin(2 * math.pi * freq * t + phase)
    return field


In [None]:
def localized_smooth_field_2d(num_freq_bins, num_time_steps,
                              num_sin=3, 
                              mask_radius_frac=0.25, 
                              epsilon=0.3, 
                              device='cpu'):
    """
    Create a smooth 2D sinusoidal field applied only to a local region (via a soft 2D mask).
    Output: Tensor of shape [F, T]
    """
    F, T = num_freq_bins, num_time_steps

    # === Step 1: Create smooth sinusoidal-ish field
    t = torch.linspace(0, 1, T, device=device)
    f = torch.linspace(0, 1, F, device=device)
    tt, ff = torch.meshgrid(t, f, indexing='xy')  # [T, F] then transpose

    base_field = torch.zeros(T, F, device=device)
    for _ in range(num_sin):
        freq_t = torch.rand(1).item() * 3 + 1
        freq_f = torch.rand(1).item() * 3 + 1
        phase_t = torch.rand(1).item() * 2 * math.pi
        phase_f = torch.rand(1).item() * 2 * math.pi
        amplitude = torch.randn(1).item() * epsilon

        wave = amplitude * torch.sin(2 * math.pi * freq_t * tt + phase_t) * \
                             torch.sin(2 * math.pi * freq_f * ff + phase_f)

        base_field += wave.T

    base_field = base_field.T  # [F, T]

    # === Step 2: Create a soft mask
    center_f = torch.randint(int(F * 0.25), int(F * 0.75), (1,)).to(device)
    center_t = torch.randint(int(T * 0.25), int(T * 0.75), (1,)).to(device)
    sigma_f = int(F * mask_radius_frac)
    sigma_t = int(T * mask_radius_frac)

    f_coords = torch.arange(F, device=device).unsqueeze(1).to(device)
    t_coords = torch.arange(T, device=device).unsqueeze(0).to(device)

    mask = torch.exp(-((f_coords - center_f) ** 2) / (2 * sigma_f ** 2)) * \
           torch.exp(-((t_coords - center_t) ** 2) / (2 * sigma_t ** 2))  # [F, T]

    # === Step 3: Apply mask to the smooth field
    localized_field = base_field * mask

    return localized_field


In [None]:
def generate_lie_generator_fields(num_freq_bins, num_time_steps, epsilon_dict=None, device='cpu'):
    """
    Returns a dict of all 5 Lie generator fields, each of shape [F, T]
    """
    if epsilon_dict is None:
        epsilon_dict = {
            't_stretch': 0.05,
            'f_stretch': 0.05,
            'warp_2d': 0.05,
            'amplitude': 0.1,
            'phase': 0.1,
        }

    # 1. Time stretch: v(t) → broadcast to (F, T)
    v_t = sinusoidal_1d_field(num_time_steps, epsilon=epsilon_dict['t_stretch'], device=device)
    v_t_broadcasted = v_t.unsqueeze(0).expand(num_freq_bins, -1)

    # 2. Frequency stretch: w(f) → broadcast to (F, T)
    w_f = sinusoidal_1d_field(num_freq_bins, epsilon=epsilon_dict['f_stretch'], device=device)
    w_f_broadcasted = w_f.unsqueeze(1).expand(-1, num_time_steps)

    # 3. 2D warp: v_2d and w_2d (already [F, T])
    v_2d = localized_smooth_field_2d(num_freq_bins, num_time_steps, epsilon=epsilon_dict['warp_2d'], device=device)
    w_2d = localized_smooth_field_2d(num_freq_bins, num_time_steps, epsilon=epsilon_dict['warp_2d'], device=device)

    # 4. Amplitude modulation α(f,t)
    alpha = localized_smooth_field_2d(num_freq_bins, num_time_steps, epsilon=epsilon_dict['amplitude'], device=device)

    # 5. Phase modulation β(f,t)
    beta = localized_smooth_field_2d(num_freq_bins, num_time_steps, epsilon=epsilon_dict['phase'], device=device)

    return {
        't_stretch': v_t_broadcasted,  # shape: [F, T]
        'f_stretch': w_f_broadcasted,  # shape: [F, T]
        'warp_2d': (v_2d,w_2d),        # shape: ([F, T],[F, T])
        'amplitude': alpha,            # shape: [F, T]
        'phase': beta                  # shape: [F, T]
    }


In [None]:
num_freq_bins = 80
num_time_steps = 512

fields = generate_lie_generator_fields(num_freq_bins, num_time_steps, device=device)

# for name, tensor in fields.items():
#     print(f"{name}: shape = {tensor.shape}")


In [None]:
plt.imshow(fields['t_stretch'].detach().cpu().squeeze(), aspect='auto', origin='lower')
plt.colorbar()

In [None]:
plt.imshow(fields['f_stretch'].detach().cpu().squeeze(), aspect='auto', origin='lower')
plt.colorbar()

In [None]:
plt.imshow(fields['warp_2d'][0].detach().cpu().squeeze(), aspect='auto', origin='lower')
plt.colorbar()

In [None]:
plt.imshow(fields['warp_2d'][1].detach().cpu().squeeze(), aspect='auto', origin='lower')
plt.colorbar()

In [None]:
plt.imshow(fields['amplitude'].detach().cpu().squeeze(), aspect='auto', origin='lower')
plt.colorbar()

In [None]:
plt.imshow(fields['phase'].detach().cpu().squeeze(), aspect='auto', origin='lower')
plt.colorbar()

In [None]:
def apply_transformation(spectrogram, field, mode='t_stretch'):
    """
    Apply a transformation to a log-Mel spectrogram of shape [F, T].
    Returns: [F, T]
    """
    F_bins, T_steps = spectrogram.shape
    device = spectrogram.device

    # Coordinate grids
    t_coords = torch.linspace(-1, 1, T_steps, device=device)
    f_coords = torch.linspace(-1, 1, F_bins, device=device)
    f_grid, t_grid = torch.meshgrid(f_coords, t_coords, indexing='ij')  # [F, T]

    if mode == 't_stretch':
        if field.ndim == 1:
            field = field.unsqueeze(0).expand(F_bins, -1)  # [F, T]
        delta_t = field / T_steps * 2
        delta_f = torch.zeros_like(delta_t)

    elif mode == 'f_stretch':
        if field.ndim == 1:
            field = field.unsqueeze(1).expand(-1, T_steps)
        delta_t = torch.zeros_like(field)
        delta_f = field / F_bins * 2

    elif mode == 'warp_2d':
        v_field, w_field = field  # both [F, T]
        delta_t = v_field / T_steps * 2
        delta_f = w_field / F_bins * 2

    elif mode == 'amplitude':
        return spectrogram * (1.0 + field)

    elif mode == 'phase':
        return spectrogram  # placeholder

    else:
        raise ValueError(f"Unsupported mode: {mode}")

    warped_t = t_grid + delta_t  # x-axis (time)
    warped_f = f_grid + delta_f  # y-axis (freq)
    grid = torch.stack([warped_t, warped_f], dim=-1)  # [F, T, 2]
    grid = grid.unsqueeze(0)  # [1, F, T, 2]

    # Input spectrogram: [1, 1, F, T]
    spec = spectrogram.unsqueeze(0).unsqueeze(0)

    warped = F.grid_sample(spec, grid, mode='bilinear', padding_mode='border', align_corners=True)
    return warped.squeeze(0).squeeze(0)  # [F, T]


In [None]:
def apply_random_transformations(spectrogram, num_transforms=1, transform_pool=None, device='cpu'):
    """
    Apply N randomly chosen transformations to the input spectrogram.
    """
    if transform_pool is None:
        transform_pool = ['t_stretch', 'f_stretch', 'warp_2d', 'amplitude', 'phase']

    selected = random.sample(transform_pool, num_transforms)
    S = spectrogram.clone().to(device)

    epsilon_dict = {
        't_stretch': 2.0,
        'f_stretch': 2.0,
        'warp_2d': 2.0,
        'amplitude': 1.0,
        'phase': 1.0,
    }

    fields = generate_lie_generator_fields(S.shape[-2], S.shape[-1], epsilon_dict=epsilon_dict, device=device)
    # pp(fields)

    for transform in selected:
        print(f"Applying {transform}")
        field = fields[transform].to(device)
        S = apply_transformation(S, field, mode=transform) 

    return S, fields, selected


In [None]:
idx = 11
sample = ds['train'][idx]

S_n = torch.tensor(sample['in_spectrogram']).to(device)
print(S_n.shape)

plt.figure()
plt.imshow(S_n.detach().cpu().squeeze(), aspect='auto', origin='lower')
plt.colorbar()
plt.show()

In [None]:
# transform_pool = ['t_stretch', 'f_stretch', 'warp_2d', 'amplitude', 'phase']

S_d, fields, transform = apply_random_transformations(S_n, transform_pool = ['amplitude'], device=device)

print(transform, S_d.shape)
print(f"In: {torch.sum(S_n)}, Out: {torch.sum(S_d)}")

# --- Plot everything
plt.figure(figsize=(18, 6))

plt.subplot(1, 4, 1)
plt.imshow(S_n.cpu(), aspect='auto', origin='lower', cmap='magma')
plt.title('Original Spectrogram')
plt.colorbar()

plt.subplot(1, 4, 2)
plt.imshow(S_d.detach().cpu(), aspect='auto', origin='lower', cmap='magma')
plt.title(f'Transformed Spectrogram ({transform[0]})')
plt.colorbar()

S_diff = S_d - S_n
plt.subplot(1, 4, 3)
plt.imshow(S_diff.detach().cpu(), aspect='auto', origin='lower', cmap='magma')
plt.title(f'Difference: ({transform[0]})')
plt.colorbar()


plt.subplot(1, 4, 4)
plt.imshow(fields[transform[0]].cpu(), aspect='auto', origin='lower', cmap='coolwarm')
plt.title(f'Transformation Field ({transform[0]})')
plt.colorbar()

plt.tight_layout()
plt.show()

In [None]:

def apply_random_curriculum_transform(S_n, epoch, max_epochs):
    """
    Apply 1–N progressively harder distortions based on current epoch.
    """
    # Curriculum parameters
    epsilon_base = 0.3
    epsilon_max = 2.0
    num_transforms_max = 4

    # Linearly increase ε and number of transforms
    progress = epoch / max_epochs
    epsilon = epsilon_base + (epsilon_max - epsilon_base) * progress
    epsilon_dict = {
        't_stretch': epsilon,
        'f_stretch': epsilon,
        'warp_2d': epsilon,
        'amplitude': epsilon,
        'phase': epsilon,
    }
    
    num_transforms = 1 + int(progress * (num_transforms_max - 1))

    # Randomly choose transformations
    available_modes = ['t_stretch', 'f_stretch', 'warp_2d', 'amplitude']
    selected = random.sample(available_modes, num_transforms)

    S = S_n.clone().to(device)
    fields = generate_lie_generator_fields(S.shape[-2], S.shape[-1], epsilon_dict=epsilon_dict, device=device)

    for transform in selected:
        # print(f"Applying {transform} with epsilon={epsilon}")
        field = fields[transform]
        S = apply_transformation(S, field, mode=transform) 

    return S, epsilon, selected


In [None]:
from torch.utils.data import Dataset

class SpectrogramAugmentDataset(Dataset):
    def __init__(self, data, max_epochs, transform_epoch_getter):
        """
        data: a list of dicts with 'in_spectrogram': Tensor [F, T]
        max_epochs: total training epochs (for curriculum)
        transform_epoch_getter: function returning current epoch
        """
        self.data = data
        self.max_epochs = max_epochs
        self.get_epoch = transform_epoch_getter
        self.train = True

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        spec = self.data[idx]['in_spectrogram']
        spec = torch.tensor(spec)
        epoch = self.get_epoch()
        if self.train:
            distorted, epsilon, selected = apply_random_curriculum_transform(spec, epoch, self.max_epochs)
        else:
            # distorted, epsilon, selected = apply_random_curriculum_transform(spec, self.max_epochs, self.max_epochs)
            distorted, epsilon, selected = apply_random_curriculum_transform(spec, epoch, self.max_epochs)
            
        return distorted, spec, epsilon, selected


In [None]:
from model import Processor, STSAutoEncoder
from losses import MCDLoss, PerceptualLoss, ContrastiveLoss, AutoencoderLoss



encoder_channels = [32, 16, 24, 40, 80, 112, 192, 320, 1280]

autoencoder = STSAutoEncoder(efficientnet_variant="efficientnet_b0", 
                 temporal_channels=512, 
                 reduced_channels=512, 
                 encoder_channels=encoder_channels, 
                 output_channels=1, 
                 target_F=80, 
                 target_T=512)

# encoder = torch.load("STS_Nemo_encoder_2.pth", map_location=device, weights_only=False) 
# decoder = torch.load("STS_Nemo_decoder_2.pth", map_location=device, weights_only=False)
encoder = torch.load("Lie_0_encoder_1.pth", map_location=device, weights_only=False) 
decoder = torch.load("Lie_0_decoder_1.pth", map_location=device, weights_only=False)


autoencoder.encoder.load_state_dict(encoder)
autoencoder.decoder.load_state_dict(decoder)
autoencoder = autoencoder.to(device)


In [None]:
import nemo.collections.asr as nemo_asr
from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor
from nemo.collections.asr.models import EncDecRNNTModel

asr_model_id = "nvidia/stt_it_conformer_transducer_large"
# asr_model_id = "nvidia/stt_en_conformer_transducer_large"

asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(asr_model_id)

asr_model.cfg.joint['fuse_loss_wer'] = False
asr_model.joint = EncDecRNNTModel.from_config_dict(asr_model.cfg.joint)
asr_model = asr_model.to(device)
asr_model.eval()

In [None]:
from model import Processor

processor = Processor(asr_model=asr_model, device=device)

In [None]:
loss_fn = AutoencoderLoss(perceptual_model=processor, 
                          decoder=autoencoder.decoder, 
                          embedding_dim=1280, 
                          lambda_mse=1e2, 
                          lambda_mcd=1.0, 
                          lambda_perceptual=1e4, 
                          lambda_contrastive=1.0)

loss_fn.contrastive_loss = torch.load('STS_contrastive_loss_3.pth', map_location=device, weights_only=False)

# Step 0/2, epoch 2: loss = 6.0458, MSE 0.0377, MCD 5.9370, perceptual 0.0001, contrastive 0.0710




In [None]:
from torch.utils.data import DataLoader
import torch.optim as optim

# Assume your dataset is a list of dicts: [{'in_spectrogram': Tensor[F, T]}, ...]

current_epoch = [0]  # wrapped in list for mutability
def get_epoch(): return current_epoch[0]

num_epochs = 4

train_ds = SpectrogramAugmentDataset(ds["train"].select(range(30000)), max_epochs=num_epochs, transform_epoch_getter=get_epoch)
train_ds.train = True

val_ds = SpectrogramAugmentDataset(ds["val"].select(range(3000)), max_epochs=num_epochs, transform_epoch_getter=get_epoch)
val_ds.train = False

train_loader = DataLoader(train_ds, batch_size=80, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=40, shuffle=True)


opt = optim.Adam(autoencoder.parameters(), lr=3e-4)

for i, batch in enumerate(val_loader):
    S_n, S_d, epsilon, selected = batch
    print(i)
    print(S_n.shape,S_d.shape)
    print(epsilon, selected)
    break

In [None]:
import torch.optim as optim
from torch.utils.data import DataLoader
import time
import torch.distributed as dist
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader, DistributedSampler
import gc
from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau

def train_model(
    model, train_loader, val_loader, perceptual_model, loss_fn,
    device="cuda", epochs=2, lr=1e-4, freeze_epochs=1
):
    """
    Train the autoencoder (Encoder + Decoder) using MSE, Perceptual, and Contrastive Loss.
    """
    
    model.encoder.to(device)
    model.decoder.to(device)
    loss_fn.to(device)

    losses = []
    val_losses = []
    
    # Define optimizer
    optimizer = optim.AdamW(list(model.parameters())+list(loss_fn.parameters()), lr=lr, weight_decay=1e-5)
    scheduler = ExponentialLR(optimizer, gamma=0.9)

    
    # Training loop
    for epoch in range(epochs):
        model.train()
        total_samples = 0
        running_loss = 0.0
        running_mse = 0.0
        running_mcd = 0.0
        running_perceptual = 0.0
        running_contrastive = 0.0
        
        start_time = time.time()

        print("TRAIN ============================================================================\n")

        if epoch < freeze_epochs:
            print(f"Epoch n.{epoch+1}, freezing encoder \n")
            for param in model.encoder.parameters():
                param.requires_grad = False
        else:
            print(f"Epoch n.{epoch+1}, unfreezing encoder \n")
            for param in model.encoder.parameters():
                param.requires_grad = True
         
        for batch_idx, (S_orig, S_dys, epsilon, selected) in enumerate(train_loader):
            
            S_orig = S_orig.to(device).unsqueeze(1)  # Original spectrograms (B, 1, F, T)
            S_dys = S_dys.to(device).unsqueeze(1)  # Transformed "dysarthric" spectrograms (B, 1, F, T)

            # Step 1: Encoder forward pass on Original spectrogram
            E_f_orig, E_t_orig, encoder_hidden_state_orig = model.encoder(S_orig)  # Extract spectral & temporal features
            E_orig = loss_fn.concat_embeds(E_f_orig, E_t_orig)  # Combined latent representation
            
            # Step 2: Encoder forward pass on Transformed spectrogram
            E_f_dys, E_t_dys, encoder_hidden_state_dys = model.encoder(S_dys)  # Extract spectral & temporal features
            E_dys = loss_fn.concat_embeds(E_f_dys, E_t_dys)  # Combined latent representation
            S_recon = model.decoder(E_f_dys, E_t_dys, encoder_hidden_state_dys)  # Trying to reconstruct normal spectrogram 

            # Step 3: Encoder forward pass on Reconstructed spectrogram
            E_f_recon, E_t_recon, _ = model.encoder(S_recon) # Calculate latent vectors for reconstructed spectrogram
            E_recon = loss_fn.concat_embeds(E_f_recon, E_t_recon)


            # Step 4: Prepare contrastive loss pairs
            batch_size = S_orig.shape[0]
            permutation = torch.randperm(batch_size)
            E_other = E_orig[permutation]  # Shuffle embeddings to create negative samples
            
            # Step 5: Compute Loss between original S_orig and S_recon, reconstructed from dysarthric S_dys
            # Perceptual Loss is L2 difference between ASR model embeddings (E_orig and E_recon), not autoencoder embeddings!
            # MSE and MCD are losses between original S_orig and reconstructed S_recon
            # Contrastive loss is between autoencoder embeddings of (E_orig, E_recon) and (E_other, E_recon)
            loss, mse, mcd, perceptual, contrastive = loss_fn(S_orig, S_recon, E_orig, E_recon, E_other) 

            losses.append(
                {
                    "loss": loss.item(),
                    "mse": mse.item(),
                    "mcd": mcd.item(),
                    "perceptual": perceptual.item(),
                    "contrastive": contrastive.item()
                }
            )
            
            # Step 6: Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_samples += S_orig.size(0)
            running_loss += loss.item() * S_orig.size(0)
            running_mse += mse.item() * S_orig.size(0)
            running_mcd +=mcd.item() * S_orig.size(0)
            running_perceptual += perceptual.item() * S_orig.size(0)
            running_contrastive += contrastive.item() * S_orig.size(0)
            
            if batch_idx%10 == 0:
                print(f"Train step {batch_idx}/{len(train_loader)}, epoch {epoch+1}: loss = {loss:.4f}, MSE {mse:.4f}, MCD {mcd:.4f}, perceptual {perceptual:.4f}, contrastive {contrastive:.4f}")             
                print(f"[ε = {epsilon[0]:.2f}, types = {set([mode for batch in selected for mode in batch])}\n")
            torch.cuda.empty_cache()

        print("VALIDATION ============================================================================\n")

        total_val_samples = 0
        val_loss = 0
        val_mse = 0.0
        val_mcd = 0.0
        val_perceptual = 0.0
        val_contrastive = 0.0    
        
        model.encoder.eval()
        model.decoder.eval()
        
        with torch.no_grad():
            for batch_idx, (S_orig, S_dys, epsilon, selected) in enumerate(val_loader):
                S_orig = S_orig.to(device).unsqueeze(1)  # Original spectrograms (B, 1, F, T)
                S_dys = S_dys.to(device).unsqueeze(1)  # Transformed "dysarthric" spectrograms (B, 1, F, T)
    
                # Val, Step 1: Encoder forward pass on Original spectrogram
                E_f_orig, E_t_orig, encoder_hidden_state_orig = model.encoder(S_orig)  # Extract spectral & temporal features
                E_orig = loss_fn.concat_embeds(E_f_orig, E_t_orig)  # Combined latent representation
                
                # Val, Step 2: Encoder forward pass on Transformed spectrogram
                E_f_dys, E_t_dys, encoder_hidden_state_dys = model.encoder(S_dys)  # Extract spectral & temporal features
                E_dys = loss_fn.concat_embeds(E_f_dys, E_t_dys)  # Combined latent representation
                S_recon = model.decoder(E_f_dys, E_t_dys, encoder_hidden_state_dys)  # Trying to reconstruct normal spectrogram 
    
                # Val, Step 3: Encoder forward pass on Reconstructed spectrogram
                E_f_recon, E_t_recon, _ = model.encoder(S_recon) # Calculate latent vectors for reconstructed spectrogram
                E_recon = loss_fn.concat_embeds(E_f_recon, E_t_recon)
    
                # Val, Step 4: Prepare contrastive loss pairs
                batch_size = S_orig.shape[0]
                permutation = torch.randperm(batch_size)
                E_other = E_orig[permutation]  # Shuffle embeddings to create negative samples

                # Val, Step 5: Compute loss
                loss, mse, mcd, perceptual, contrastive = loss_fn(S_orig, S_recon, E_orig, E_recon, E_other)  # Autoencoder loss     
                val_losses.append(
                    {
                        "loss": loss.item(),
                        "mse": mse.item(),
                        "mcd": mcd.item(),
                        "perceptual": perceptual.item(),
                        "contrastive": contrastive.item()
                    }
                )      

                total_val_samples += S_orig.size(0)
                val_loss += loss.item() * S_orig.size(0)      
                val_mse += mse.item() * S_orig.size(0)
                val_mcd +=mcd.item() * S_orig.size(0)
                val_perceptual += perceptual.item() * S_orig.size(0)
                val_contrastive += contrastive.item() * S_orig.size(0)
                if batch_idx%10 == 0:
                    print(f"Val step {batch_idx}/{len(val_loader)}, epoch {epoch+1}: loss = {loss:.4f}, MSE {mse:.4f}, MCD {mcd:.4f}, perceptual {perceptual:.4f}, contrastive {contrastive:.4f}")
                    print(f"[ε = {epsilon[0]:.2f}, types = {set([mode for batch in selected for mode in batch])}]\n\n")
                    
        scheduler.step()
        
        torch.save(model.encoder.state_dict(), f"Lie_1_encoder_{epoch}.pth")
        torch.save(model.decoder.state_dict(), f"Lie_1_decoder_{epoch}.pth")

        
        # Print Training Stats
        avg_loss = running_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        print("============================================================================\n")
        print(f"Epoch [{epoch+1}/{epochs}] - Loss: {avg_loss:.4f} - Val loss: {avg_val_loss:.4f} - Time: {time.time() - start_time:.2f}s")
        print(f"Train MSE {running_mse/len(train_loader):.4f}, Train MCD {running_mcd/len(train_loader):.4f}, Train perceptual {running_perceptual/len(train_loader):.4f}, Train contrastive {running_contrastive/len(train_loader):.4f}")
        print(f"Val MSE {val_mse/len(val_loader):.4f}, Val MCD {val_mcd/len(val_loader):.4f}, Val perceptual {val_perceptual/len(val_loader):.4f}, Val contrastive {val_contrastive/len(val_loader):.4f}")
        print(f"LR: {optimizer.param_groups[0]['lr']:.5f}\n\n")
        
        gc.collect()
        torch.cuda.empty_cache()

    print("Training Complete!")

    return losses, val_losses, loss_fn



In [None]:
losses, val_losses, loss_fn = train_model(autoencoder, 
                                          train_loader, 
                                          val_loader, 
                                          processor, 
                                          loss_fn,
                                          device="cuda", 
                                          epochs=num_epochs, 
                                          lr=1e-3,
                                          freeze_epochs=num_epochs//2)

