In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
import faster_whisper
import h5py
import whisperx
import numpy as np
import pandas as pd
import utils
import basic_adapter_utils
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset

In [2]:
# HDF5 file that the loading cell reads the per-file embeddings from.
embeddings_file = 'embeddings.h5'

# The model being trained in this notebook, built with its default constructor arguments.
adapter = basic_adapter_utils.LinearAdapter()

In [7]:
# Metadata table; its 'file' column is what the loading cell iterates over.
metadata_csv = 'LibriVox_Kaggle_org.csv'
df = pd.read_csv(metadata_csv)

# Fixed random_state keeps the train/test partition identical across runs.
train, test = train_test_split(df, random_state=42)

In [11]:
# Pick the compute device the same way the training cell does: CUDA when it is
# available, CPU otherwise. The previous hard-coded "cuda:1" assumed a second
# GPU and was overwritten by the training cell before anything used it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the adapter first so the optimizer is built over parameters that already
# live on their final device (the training cell's adapter.to(device) is then a no-op).
adapter.to(device)

# Adam over all adapter parameters; the objective is mean squared error between
# the adapter output and the target embeddings.
optimizer = optim.Adam(adapter.parameters(), lr=0.001)
loss_fn = nn.MSELoss()

In [10]:
# Collect one (input, target) pair per training file; the two lists stay
# index-aligned because both are appended to in the same iteration.
input_embeddings = []
target_embeddings = []

# Iterate the 'file' column directly -- wrapping it in a list comprehension
# only produced a throwaway copy of the same values.
for file_name in train['file']:
    # NOTE(review): the helper is assumed to return an (input, target) pair
    # looked up by file name inside the HDF5 file -- confirm against utils.
    input_emb, target_emb = utils.load_embeddings_and_rir_from_hdf5(embeddings_file, file_name)
    input_embeddings.append(input_emb)
    target_embeddings.append(target_emb)


In [14]:
# Stack the per-file embeddings into single float32 tensors. torch.as_tensor
# with an explicit dtype replaces the legacy torch.FloatTensor constructor and
# yields the same float32 CPU tensor.
# NOTE(review): stacking presumes every per-file embedding has the same shape -- confirm.
input_embeddings = torch.as_tensor(np.array(input_embeddings), dtype=torch.float32)
target_embeddings = torch.as_tensor(np.array(target_embeddings), dtype=torch.float32)

# Pair each input row with its target row.
dataset = TensorDataset(input_embeddings, target_embeddings)

# Seed the shuffling so the batch order (and therefore the training run) is
# repeatable; 42 matches the seed used for the train/test split.
shuffle_generator = torch.Generator()
shuffle_generator.manual_seed(42)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, generator=shuffle_generator)

In [16]:
# Fit the adapter on the training batches and report the mean batch loss per epoch.
num_epochs = 25
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # CUDA when available, else CPU
adapter.to(device)

for epoch in range(num_epochs):
    adapter.train()
    total_loss = 0  # running sum of per-batch losses for this epoch

    for source_batch, target_batch in dataloader:
        # Batches come from the DataLoader; move them next to the model.
        source_batch = source_batch.to(device)
        target_batch = target_batch.to(device)

        # Clear gradients left over from the previous step before computing new ones.
        optimizer.zero_grad()

        # Forward pass and loss against the target embeddings.
        predictions = adapter(source_batch)
        batch_loss = loss_fn(predictions, target_batch)

        # Backpropagate and apply one optimizer update.
        batch_loss.backward()
        optimizer.step()

        total_loss += batch_loss.item()

    # Mean of the per-batch losses (each batch counts equally, whatever its size).
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}")

Epoch [1/25], Average Loss: 0.0645
Epoch [2/25], Average Loss: 0.0644
Epoch [3/25], Average Loss: 0.0643
Epoch [4/25], Average Loss: 0.0642
Epoch [5/25], Average Loss: 0.0642
Epoch [6/25], Average Loss: 0.0640
Epoch [7/25], Average Loss: 0.0640
Epoch [8/25], Average Loss: 0.0640
Epoch [9/25], Average Loss: 0.0639
Epoch [10/25], Average Loss: 0.0638
Epoch [11/25], Average Loss: 0.0637
Epoch [12/25], Average Loss: 0.0637
Epoch [13/25], Average Loss: 0.0636
Epoch [14/25], Average Loss: 0.0636
Epoch [15/25], Average Loss: 0.0635
Epoch [16/25], Average Loss: 0.0635
Epoch [17/25], Average Loss: 0.0634
Epoch [18/25], Average Loss: 0.0634
Epoch [19/25], Average Loss: 0.0633
Epoch [20/25], Average Loss: 0.0633
Epoch [21/25], Average Loss: 0.0632
Epoch [22/25], Average Loss: 0.0631
Epoch [23/25], Average Loss: 0.0631
Epoch [24/25], Average Loss: 0.0631
Epoch [25/25], Average Loss: 0.0630


In [17]:
# Persist only the adapter's learned parameters (its state_dict), not the module object.
# NOTE(review): the file name says "50epochs" while the training cell sets
# num_epochs = 25 -- confirm how many epochs this checkpoint actually saw
# (e.g. the training cell may have been run twice) before relying on the name.
checkpoint_path = 'linear_adapter_50epochs.pth'
torch.save(adapter.state_dict(), checkpoint_path)