## Imports

In [None]:
# imports go here
import os
import pickle
import numpy as np
import librosa
import random
import soundfile as sf
import time
import re
import datetime
import math
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.nn.utils import clip_grad_norm_
from torch.optim import AdamW
from torch.utils.data import TensorDataset, DataLoader, Dataset, RandomSampler


## Global variables

In [None]:
# --- Paths ---
# Raw audio root; save_spectrograms walks {TIMIT_ROOT}/{split}/{DR}/{speaker}/*.WAV.wav.
TIMIT_ROOT = "/kaggle/working/data"
# Cached per-speaker log-mel arrays live at {LOGMEL_ROOT}/{split}/{speaker}.npy.
LOGMEL_ROOT = "/kaggle/working/data/logmels/"
# Model/optimizer checkpoints are saved here as "checkpt-<kind>-<epoch>.pt".
CHECKPT_DIR = "/kaggle/working/SV-checkpts/checkpoint5"

# --- Feature extraction ---
num_frames = 180                # Frames per log-mel window (first/last num_frames of each voiced segment)
hop = 0.01                      # Hop length in seconds
window = 0.025                  # Window size in seconds
n_fft = 512                     # Length of windowed signal after padding (FFT size)
sr = 16000                      # Sampling rate (Hz) passed to librosa.load
win_length = int(window * sr)   # Window length in samples (400 with the values above)
hop_length = int(hop * sr)      # Hop length in samples (160 with the values above)
n_mels = 40                     # Number of mel bands
epsilon = 1e-8                  # Added before log10 to avoid log(0); also passed as AdamW eps

# --- Model / sampling ---
n_hidden = 768                  # Hidden size of each LSTM layer
n_projection = 256              # Embedding size after the linear projection
num_layers = 3                  # Number of stacked LSTM layers
n_speakers = 6                  # Distinct speakers per dataset item (N in the loss)
n_utterances_per_speaker = 10   # Utterances per speaker per dataset item (M in the loss)

# --- Training ---
BATCH_SIZE = 16                 # Dataset items per dataloader batch; each item holds N*M windows
NUM_EPOCHS = 5                 # Number of training epochs

force_restart_training = False  # Ignore existing checkpoints and start from epoch 0
save = True                     # Save model and optimizer states after every epoch
load_opts = True                # When resuming, also restore optimizer states
halve_after_every = 12          # Halve both LRs at the start of every epoch whose 1-based index is a multiple of this

## Dataset

In [None]:
def get_spectrograms_for_file(file_path):
  """
  Return log-mel windows for every sufficiently long voiced segment of a file.

  The audio is loaded at `sr` and split into non-silent intervals. Each interval
  longer than `min_length` samples is turned into a log10 mel spectrogram of shape
  (n_mels, n_steps). Two windows are kept per interval: its first `num_frames`
  frames and its last `num_frames` frames.

  Args:
    file_path: path to an audio file readable by librosa.

  Returns:
    A list of numpy arrays of shape (n_mels, num_frames), two per kept interval.
    The list is empty when no interval is long enough.
  """
  # Roughly the samples spanned by num_frames hops plus one analysis window.
  min_length = (num_frames * hop + window) * sr
  y, _ = librosa.load(file_path, sr=sr)
  # Split the audio into non-silent intervals. The reference implementation uses
  # top_db=30 as the silence threshold, while librosa's default is 60.
  intervals = librosa.effects.split(y, top_db=30)
  # The mel filterbank depends only on global settings: build it once, not per interval.
  mel_basis = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels)
  extracted = []
  for begin, end in intervals:
    if end - begin <= min_length:
      continue
    stft = librosa.stft(y=y[begin:end], n_fft=n_fft, win_length=win_length,
                        hop_length=hop_length)
    # Power spectrogram; abs() is needed because the STFT is complex.
    power = np.abs(stft) ** 2
    logmel = np.log10(np.dot(mel_basis, power) + epsilon)
    extracted.append(logmel[:, :num_frames])
    extracted.append(logmel[:, -num_frames:])
  return extracted

def get_spectrograms_for_speaker(speaker_dir):
  """
  Collect the log-mel windows of every "*.WAV.wav" file in one speaker directory.

  Args:
    speaker_dir: directory holding a single speaker's utterance files.

  Returns:
    A numpy array built from the concatenation of all windows returned by
    get_spectrograms_for_file for the matching files. When nothing matches,
    this is an empty array of shape (0,).
  """
  windows = []
  wav_names = (name for name in os.listdir(speaker_dir) if name.endswith(".WAV.wav"))
  for wav_name in wav_names:
    windows.extend(get_spectrograms_for_file(os.path.join(speaker_dir, wav_name)))
  return np.array(windows)

def save_spectrograms(splits=("TRAIN", "TEST"), timit_root=None, logmel_root=None):
  """
  Compute and cache per-speaker log-mel windows. Intended to be run once.

  Walks `{timit_root}/{split}/{DR}/{speaker}/` (DR presumably being the TIMIT
  dialect-region folders) and runs get_spectrograms_for_speaker on every speaker
  directory. Each result is written to `{logmel_root}/{split}/{speaker}.npy`,
  overwriting existing files. Output directories are created as needed, and
  non-directory entries at the DR/speaker levels are skipped.

  Args:
    splits: split subdirectories to process.
    timit_root: source root; defaults to the global TIMIT_ROOT.
    logmel_root: output root; defaults to the global LOGMEL_ROOT.
  """
  timit_root = TIMIT_ROOT if timit_root is None else timit_root
  logmel_root = LOGMEL_ROOT if logmel_root is None else logmel_root
  for split in splits:
    split_data_dir = os.path.join(timit_root, split)
    split_logmel_dir = os.path.join(logmel_root, split)
    # Opening a file for writing does not create missing parent directories.
    os.makedirs(split_logmel_dir, exist_ok=True)
    for DR in sorted(os.listdir(split_data_dir)):
      DR_dir = os.path.join(split_data_dir, DR)
      if not os.path.isdir(DR_dir):
        continue  # stray file (e.g. a README) next to the DR folders
      for speaker in sorted(os.listdir(DR_dir)):
        speaker_dir = os.path.join(DR_dir, speaker)
        if not os.path.isdir(speaker_dir):
          continue
        extracted = get_spectrograms_for_speaker(speaker_dir)
        out_file = os.path.join(split_logmel_dir, "{}.npy".format(speaker))
        # Context manager guarantees the handle is closed after writing.
        with open(out_file, "wb") as f:
          np.save(f, extracted)

def load_data(splits=("TRAIN", "TEST"), min_samples=4, logmel_root=None):
  """
  Load cached per-speaker log-mel arrays, dropping speakers with too few windows.

  Args:
    splits: subdirectories of `logmel_root` to read.
    min_samples: arrays with fewer than this many entries along their first axis
      (i.e. speakers with fewer windows) are skipped.
    logmel_root: root written by save_spectrograms; defaults to the global LOGMEL_ROOT.

  Returns:
    A dict mapping each split name to a list of numpy arrays, one per kept
    speaker, in sorted filename order.
  """
  logmel_root = LOGMEL_ROOT if logmel_root is None else logmel_root
  data = {}
  for split in splits:
    split_dir = os.path.join(logmel_root, split)
    part = []
    # Sorted so speaker indices are reproducible across runs and filesystems.
    for fname in sorted(os.listdir(split_dir)):
      if not fname.endswith(".npy"):
        continue
      # Passing the path lets numpy open and close the file itself.
      narray = np.load(os.path.join(split_dir, fname))
      if narray.shape[0] < min_samples:
        continue
      part.append(narray)
    data[split] = part
  return data

In [None]:
# One-time preprocessing: compute log-mel windows for every speaker and cache them
# under LOGMEL_ROOT. Re-running recomputes everything and overwrites the .npy files.
save_spectrograms()

## Train model

In [None]:
class SpeakerVerificationDataset(Dataset):
  """
  Randomly sampled multi-speaker items for speaker-embedding training.

  Each item ignores its index. It picks `n_speakers` distinct speakers at random
  and, for each, `n_samples_per_speaker` random windows, returning a tensor of
  shape (n_speakers, n_samples_per_speaker, frames, mel). This assumes each
  per-speaker array is (n_windows, mel, frames), as produced by
  get_spectrograms_for_speaker. It also assumes every array has at least
  n_samples_per_speaker windows; otherwise torch.stack fails on mismatched shapes.
  """
  def __init__(self, logmels, n_speakers=n_speakers,
               n_samples_per_speaker=n_utterances_per_speaker, total_examples=80000):
    """
    Args:
      logmels: list of per-speaker numpy arrays.
      n_speakers: distinct speakers per item.
      n_samples_per_speaker: windows drawn per speaker.
      total_examples: the number of examples drawn per epoch (what __len__ reports).

    Raises:
      ValueError: if fewer than `n_speakers` speakers are available, since no item
        could then be built.
    """
    self.logmels = logmels
    self.n_total_speakers = len(self.logmels)
    if self.n_total_speakers < n_speakers:
      raise ValueError(
        "Need at least {} speakers per item but only {} are available".format(
          n_speakers, self.n_total_speakers))
    self.n_speakers = n_speakers
    self.n_samples_per_speaker = n_samples_per_speaker
    self.total_examples = total_examples

  def __len__(self):
    return self.total_examples

  def __getitem__(self, idx):
    """Return a fresh random item; `idx` is deliberately ignored."""
    # Distinct speakers without an unbounded retry loop; random.sample raises
    # ValueError instead of spinning forever if too few speakers exist.
    speakers = random.sample(range(self.n_total_speakers), self.n_speakers)
    data = []
    for speaker in speakers:
      speaker_logmels = self.logmels[speaker]
      # A permutation picks distinct windows in one shot (speakers may have only
      # slightly more windows than needed, so rejection sampling would retry often).
      utter_idxs = np.random.permutation(speaker_logmels.shape[0])[:self.n_samples_per_speaker]
      data.append(torch.from_numpy(speaker_logmels[utter_idxs, :, :]))
    item = torch.stack(data)
    # (speaker, utterance, mel, frames) -> (speaker, utterance, frames, mel)
    return torch.permute(item, (0, 1, 3, 2))

In [None]:
# Only speakers with at least n_utterances_per_speaker windows can fill a full item.
data = load_data(min_samples=n_utterances_per_speaker)
train_dataset = SpeakerVerificationDataset(logmels=data['TRAIN'])
# Each dataloader batch stacks BATCH_SIZE dataset items.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE)

In [None]:
class SpeakerEmbedder(nn.Module):
  """
  LSTM encoder producing one L2-normalized embedding per input window.

  Input: tensor of shape (batch_size*N*M, frames, n_mels).
  Output: tensor of shape (batch_size*N*M, n_projection) with unit-norm rows.
  """
  def __init__(self):
    super(SpeakerEmbedder, self).__init__()
    self.LSTMs = nn.LSTM(input_size=n_mels, hidden_size=n_hidden,
                         num_layers=num_layers, batch_first=True)
    self.FC = nn.Linear(n_hidden, n_projection)

  def forward(self, x):
    LSTMs_out, _ = self.LSTMs(x)
    # (rows, n_timesteps, n_hidden) -> keep only the last time step: (rows, n_hidden)
    last_out = LSTMs_out[:, -1]
    FC_out = self.FC(last_out)
    # Unit-normalize each row. normalize() clamps the norm at a small eps, so an
    # all-zero projection yields zeros instead of NaN from 0/0.
    return nn.functional.normalize(FC_out, p=2, dim=1)

class LossModule(nn.Module):
  """
  Softmax loss over a scaled similarity matrix between embeddings and speaker centroids.

  For embedding e_ji (speaker j, utterance i) and centroid c_k of speaker k:
    S[j, i, k] = w * <e_ji, c_k> + b
  For k == j, the centroid is recomputed without e_ji (leave-one-out). The loss is
    sum over (batch, j, i) of  -S[j, i, j] + log(sum_k exp(S[j, i, k])).
  """
  def __init__(self):
    super(LossModule, self).__init__()
    # Learnable scale and bias of the similarity matrix, initialised to (10, -5).
    # NOTE(review): the source of these values was left blank; presumably the
    # GE2E paper (Wan et al., 2018) -- confirm.
    self.w = nn.Parameter(torch.tensor(10.0), requires_grad=True)
    self.b = nn.Parameter(torch.tensor(-5.0), requires_grad=True)

  def forward(self, embeddings):
    """
    Args:
      embeddings: tensor of shape (batch_size, N, M, n_projection). The batch size
        is taken from the tensor, so a short final batch is handled.

    Returns:
      0-dim tensor: the loss summed over every embedding in the batch, on the same
      device as `embeddings`.

    Raises:
      ValueError: if M < 2, since the leave-one-out centroid needs another utterance.
    """
    N = embeddings.shape[1]
    M = embeddings.shape[2]
    if M < 2:
      raise ValueError("Need at least 2 utterances per speaker, got {}".format(M))
    # Full centroids: (batch, N, D)
    centroids = torch.mean(embeddings, dim=2)
    # Leave-one-out centroid of each embedding's own speaker: (batch, N, M, D)
    own_centroids = (M * centroids.unsqueeze(2) - embeddings) / (M - 1)
    # Dot products with every speaker's full centroid: (batch, N, M, N)
    sims = torch.einsum("bjid,bkd->bjik", embeddings, centroids)
    # Dot products with the leave-one-out own centroid: (batch, N, M)
    own_sims = torch.sum(embeddings * own_centroids, dim=-1)
    # Replace the k == j entries with the leave-one-out similarities.
    same_speaker = torch.eye(N, dtype=torch.bool, device=embeddings.device).view(1, N, 1, N)
    sims = torch.where(same_speaker, own_sims.unsqueeze(-1), sims)
    S = self.w * sims + self.b
    positive = self.w * own_sims + self.b
    # logsumexp is the numerically stable form of log(sum(exp(S))).
    return torch.sum(torch.logsumexp(S, dim=-1)) - torch.sum(positive)

In [None]:
# Fresh models; each gets its own AdamW optimizer. `epsilon` doubles as AdamW's eps.
embedder = SpeakerEmbedder()
lossmodule = LossModule()
embedder_optimizer = AdamW(params=embedder.parameters(), lr=1e-3, eps=epsilon)
lossmodule_optimizer = AdamW(params=lossmodule.parameters(), lr=1e-3, eps=epsilon)

In [None]:
# Pick the compute device and, on a GPU, move both trainable modules onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
  print("Using GPU: {}".format(torch.cuda.get_device_name(0)))
  embedder.cuda()
  lossmodule.cuda()
else:
  print("No GPUs available, using CPU")

In [None]:
def format_time(elapsed):
  """Format a duration in seconds as str(datetime.timedelta), rounded to whole seconds."""
  whole_seconds = int(round(elapsed))
  return "{}".format(datetime.timedelta(seconds=whole_seconds))

def get_max_checkpt(checkpt_dir):
  """
  Return the highest epoch N among files named exactly "checkpt-embedder-N.pt".

  Names that only start with the pattern (e.g. "checkpt-embedder-3.pt.bak") are
  ignored. Returns 0 when the directory does not exist or contains no embedder
  checkpoint, i.e. on a fresh run.
  """
  if not os.path.isdir(checkpt_dir):
    return 0
  # fullmatch + escaped dot: an unanchored match would accept "...-3.pt.bak" and
  # then fail to parse the epoch number.
  matches = (re.fullmatch(r"checkpt-embedder-([0-9]+)\.pt", name)
             for name in os.listdir(checkpt_dir))
  return max((int(m.group(1)) for m in matches if m), default=0)

def load_latest_checkpt(checkpt_dir=CHECKPT_DIR):
  """
  Restore the newest checkpoint into the global models (and optionally optimizers).

  Returns 0 without loading anything when force_restart_training is set or no
  checkpoint exists. Otherwise it loads the embedder and lossmodule weights and,
  if load_opts is set, both optimizer states. It then returns the checkpoint's
  epoch number, which train_models uses as its starting epoch.
  """
  if force_restart_training:
    return 0
  mx_checkpt = get_max_checkpt(checkpt_dir)
  if mx_checkpt > 0:
    def load(kind):
      path = os.path.join(checkpt_dir, "checkpt-{}-{}.pt".format(kind, mx_checkpt))
      # map_location="cpu" lets GPU-saved checkpoints load on CPU-only machines.
      # load_state_dict then copies module tensors onto the module's device, and
      # the optimizer casts its state to its parameters' device.
      return torch.load(path, map_location="cpu")
    embedder.load_state_dict(load("embedder"))
    lossmodule.load_state_dict(load("lossmodule"))
    if load_opts:
      embedder_optimizer.load_state_dict(load("eopt"))
      lossmodule_optimizer.load_state_dict(load("lopt"))
  return mx_checkpt

In [None]:
def _halve_learning_rates():
  """Halve the learning rate of every param group in both optimizers."""
  for optimizer in (embedder_optimizer, lossmodule_optimizer):
    for param_group in optimizer.param_groups:
      param_group['lr'] /= 2

def _train_one_epoch(epoch_start):
  """Run one pass over train_dataloader; return (summed detached loss, items seen)."""
  epoch_loss = 0.0
  n_items = 0
  for step, batch in enumerate(train_dataloader):
    batch = batch.to(device)
    if step % 40 == 0 and step != 0:
      elapsed = format_time(time.time() - epoch_start)
      print("Batch {} of {}. Elapsed {}".format(step, len(train_dataloader), elapsed))
    # batch: (batch_size, N, M, frames, mel). Use the actual batch size, which can
    # be smaller than BATCH_SIZE for a final partial batch.
    B, N, M = batch.shape[0], batch.shape[1], batch.shape[2]
    embedder_in = batch.reshape(B * N * M, batch.shape[3], batch.shape[4])
    embedder.zero_grad()
    lossmodule.zero_grad()
    embeddings = embedder(embedder_in).reshape(B, N, M, n_projection)
    loss = lossmodule(embeddings)
    loss.backward()
    epoch_loss += loss.detach()
    n_items += B
    clip_grad_norm_(embedder.parameters(), 3.0)
    clip_grad_norm_(lossmodule.parameters(), 1.0)
    embedder_optimizer.step()
    lossmodule_optimizer.step()
  return epoch_loss, n_items

def _save_checkpoint(epoch_num):
  """Save both models and both optimizer states to CHECKPT_DIR, tagged with epoch_num."""
  os.makedirs(CHECKPT_DIR, exist_ok=True)
  for kind, obj in (("embedder", embedder), ("lossmodule", lossmodule),
                    ("eopt", embedder_optimizer), ("lopt", lossmodule_optimizer)):
    path = os.path.join(CHECKPT_DIR, "checkpt-{}-{}.pt".format(kind, epoch_num))
    torch.save(obj.state_dict(), path)

def train_models():
  """
  Train embedder and lossmodule from the latest checkpoint up to NUM_EPOCHS.

  At the start of every epoch whose 1-based index is a multiple of
  halve_after_every, both learning rates are halved. After each epoch the
  per-item average loss is printed and, if `save` is set, models and optimizer
  states are checkpointed as "checkpt-<kind>-<epoch+1>.pt".
  """
  start_epoch = load_latest_checkpt()
  for epoch in range(start_epoch, NUM_EPOCHS):
    print("============ Epoch {} / {} ============".format(epoch+1, NUM_EPOCHS))
    print("Training phase")
    embedder.train()
    lossmodule.train()
    epoch_start = time.time()
    if (epoch+1) % halve_after_every == 0:
      _halve_learning_rates()
    epoch_loss, n_items = _train_one_epoch(epoch_start)
    # Per-item average; guard against an empty dataloader.
    avg_loss = float(epoch_loss) / n_items if n_items else float("nan")
    print("Epoch finished. Average training loss: {}".format(avg_loss))

    if save:
      _save_checkpoint(epoch+1)

In [None]:
# Train (or resume training) up to NUM_EPOCHS, checkpointing after every epoch.
train_models()

## Load the trained model

In [None]:
# Rebuild the models and load the trained embedder weights for inference.
embedder = SpeakerEmbedder()
lossmodule = LossModule()
if torch.cuda.is_available():
  embedder.cuda()
  lossmodule.cuda()
# train_models saves one checkpoint per epoch, numbered from 1; load the final one.
embedder_path = os.path.join(CHECKPT_DIR, "checkpt-embedder-{}.pt".format(NUM_EPOCHS))
# map_location="cpu" allows loading GPU-saved weights without CUDA;
# load_state_dict then copies them onto the module's current device.
embedder.load_state_dict(torch.load(embedder_path, map_location="cpu"))
# Inference mode. This has no effect on the current LSTM/Linear stack (no dropout
# is configured), but stays correct if such layers are added.
embedder.eval()

In [None]:
from torch.nn.utils.rnn import pad_sequence
def run_utts_through_model(embedder, utts):
    """
    Embed a batch of log-mel windows with `embedder`, without tracking gradients.

    Args:
        embedder: module taking a (n, frames, mel) tensor; its parameters'
            device is used for the input.
        utts: non-empty list of equally shaped tensors of shape (mel, frames).

    Returns:
        A list with one output row per input window, on the embedder's device.
    """
    batch = torch.stack(utts).permute(0, 2, 1)  # (n, mel, frames) -> (n, frames, mel)
    device = next(embedder.parameters()).device
    # Inference only: no_grad avoids building and holding an autograd graph.
    with torch.no_grad():
        reps = embedder(batch.to(device))
    return list(reps.unbind(0))

In [None]:
def similarity(path1, path2):
    """
    Score how likely two audio files contain the same speaker.

    Both files are cut into log-mel windows with get_spectrograms_for_file and
    embedded with the global `embedder`. The score is the mean dot product over
    all (window of path1, window of path2) pairs. With the unit-normalized
    embeddings SpeakerEmbedder produces, this is the mean cosine similarity. No
    threshold is applied; callers decide how to interpret the score.

    Returns:
        A 0-dim tensor with the mean similarity, or 0 when either file yields no
        windows (no voiced segment long enough).
    """
    utts1 = get_spectrograms_for_file(path1)
    utts2 = get_spectrograms_for_file(path2)
    nutts1 = len(utts1)
    if nutts1 == 0 or len(utts2) == 0:
        return 0
    utts = [torch.from_numpy(utt) for utt in utts1 + utts2]
    with torch.no_grad():
        reps = torch.stack(run_utts_through_model(embedder, utts))
        reps1, reps2 = reps[:nutts1], reps[nutts1:]
        # Mean of all pairwise dot products in a single matrix multiply.
        return (reps1 @ reps2.T).mean()

In [None]:
import csv

# Score every (file1, file2) pair listed in the test list and write one score per line.
TEST_LIST_PATH = '/kaggle/input/test-list-public/test_list_private.csv'
AUDIO_DIR = "/kaggle/input/datatest/wav"
PREDICTIONS_PATH = '/kaggle/working/predictions.txt'

results = []  # formatted scores ("%.4f"), one per scored pair, in file order
scores = []   # the same scores as floats, used to find the maximum

# Read the pair list: space-separated fields; skipinitialspace collapses repeated
# spaces so they do not produce empty file names.
with open(TEST_LIST_PATH, 'r') as csv_file, torch.no_grad():
    reader = csv.reader(csv_file, delimiter=" ", skipinitialspace=True)
    for i, row in enumerate(reader):
        if len(row) < 2:
            continue  # blank or malformed line
        file1, file2 = row[0], row[1]
        result = similarity(f"{AUDIO_DIR}/{file1}", f"{AUDIO_DIR}/{file2}")
        value = result.item() if isinstance(result, torch.Tensor) else float(result)
        scores.append(value)
        results.append(f"{value:.4f}")
        print(f"dòng{i}: {value:.4f}")

# Compare numerically: formatted strings would compare lexicographically and
# misorder negative scores.
if scores:
    max_index = max(range(len(scores)), key=scores.__getitem__)
    print(f"Max = {results[max_index]}, ở vị trí {max_index}")
else:
    print("No pairs were scored.")

# Write the predictions, one formatted score per line.
with open(PREDICTIONS_PATH, 'w') as txt_file:
    txt_file.writelines(line + '\n' for line in results)