In [1]:
from src.requirements import *
from src.ssl_model import *
from src.tokenizer import *
import csv, random

In [2]:
class TripletDataset(Dataset):
    """Map-style dataset of ABX triplets.

    Each element of ``triplets`` is a ``(path_a, path_b, path_x, label)`` tuple.
    Audio is not read up front: ``load_wave`` is called for A, B and X only when
    an item is requested.
    """

    def __init__(self, triplets):
        # Stored by reference (not copied).
        self.triplets = triplets

    def __len__(self):
        return len(self.triplets)

    def __getitem__(self, idx):
        path_a, path_b, path_x, label = self.triplets[idx]
        # Load in A, B, X order and append the label unchanged.
        waves = [load_wave(p) for p in (path_a, path_b, path_x)]
        return (*waves, label)

def collate_fn(batch):
    """Collate ``(wave_a, wave_b, wave_x, label)`` samples into batch tensors.

    Each wave column is right-padded with zeros to the longest sample in that
    column (``pad_sequence`` default) and given a channel axis. For 1-D waves
    the result is ``(batch, 1, max_len)``. Labels become one tensor.
    """
    column_a, column_b, column_x, label_column = zip(*batch)

    def _pad_with_channel(waves):
        # assumes each wave is 1-D (samples,) — TODO confirm for other callers
        return rnn_utils.pad_sequence(waves, batch_first=True).unsqueeze(1)

    return (
        _pad_with_channel(column_a),
        _pad_with_channel(column_b),
        _pad_with_channel(column_x),
        torch.tensor(label_column),
    )

In [3]:
@torch.no_grad()
def abx_test_batched(model, loader, device):
    """Compute the ABX discrimination score of ``model`` over batches of triplets.

    For each triplet, embeddings for A, B and X are the model's features
    averaged over dim 1. X's cosine similarity to A and to B is then compared.
    The triplet scores:

    * 1 when X is strictly closer to the recording its label names
      (label 0 -> A, any other label -> B);
    * 0 when X is strictly closer to the other recording;
    * 0.5 on an exact tie.

    Ties get half credit so that a model whose embeddings collapse to a
    constant scores chance (50%) rather than a perfect 100%.

    Args:
        model: object with ``eval()`` and ``extract_features(wave)``; features
            are assumed to be (batch, frames, dim) — TODO confirm.
        loader: iterable of ``(batch_a, batch_b, batch_x, labels)`` tensors.
        device: device the batches are moved to.

    Returns:
        Mean triplet score in [0, 1].

    Raises:
        ValueError: if ``loader`` yields no triplets.
    """
    model.eval()
    score_sum = 0.0
    total = 0

    for batch_a, batch_b, batch_x, labels in loader:
        batch_a, batch_b, batch_x = batch_a.to(device), batch_b.to(device), batch_x.to(device)
        labels = labels.to(device)

        # NOTE(review): if batches are zero-padded to a common length (as this
        # notebook's collate_fn does), this mean also averages padded frames.
        # An embedding can then depend on its batch-mates; consider a
        # length-masked mean.
        emb_a = model.extract_features(batch_a).mean(dim=1)
        emb_b = model.extract_features(batch_b).mean(dim=1)
        emb_x = model.extract_features(batch_x).mean(dim=1)

        sim_xa = F.cosine_similarity(emb_x, emb_a)
        sim_xb = F.cosine_similarity(emb_x, emb_b)

        # Signed margin toward the labelled match: > 0 correct, < 0 wrong, == 0 tie.
        margin = torch.where(labels == 0, sim_xa - sim_xb, sim_xb - sim_xa)
        scores = (margin > 0).float() + 0.5 * (margin == 0).float()

        score_sum += scores.sum().item()
        total += labels.size(0)

    if total == 0:
        raise ValueError("abx_test_batched: loader yielded no triplets")
    return score_sum / total

In [4]:
def load_tsv_dataset(tsv_path):
    """Read a two-column ``audio_path<TAB>text`` file into a list of dicts.

    Rows without exactly two fields (blank lines, malformed rows) are skipped.
    The transcript is whitespace-stripped; the audio path is kept verbatim.

    NOTE(review): csv's default quoting is active. A transcript field that
    begins with a double quote is parsed as a quoted field. If transcripts can
    start with ``"``, consider ``quoting=csv.QUOTE_NONE``.

    Args:
        tsv_path: path to a UTF-8 tab-separated file.

    Returns:
        List of ``{"audio_path": str, "text": str}`` in file order.
    """
    dataset = []
    # The csv module requires newline="" so it can handle line endings
    # (and newlines inside quoted fields) itself.
    with open(tsv_path, "r", encoding="utf-8", newline="") as f:
        for row in csv.reader(f, delimiter="\t"):
            if len(row) != 2:
                continue
            audio_path, text = row
            dataset.append({"audio_path": audio_path, "text": text.strip()})
    return dataset

In [5]:
def tokenize_chars(tokenizer, text):
    """Return the token sequence that ``tokenizer.tokenize`` produces for ``text``."""
    tokens = tokenizer.tokenize(text)
    return tokens

In [6]:
def find_minimal_pairs(dataset, tokenizer, max_pairs=2000):
    """Find pairs of transcripts whose token sequences differ at exactly one position.

    Two items form a minimal pair when their token sequences (from
    ``tokenize_chars``) have the same length and differ at exactly one index.
    Identical sequences, such as two recordings of the same text, are not pairs.

    Args:
        dataset: list of dicts with a ``"text"`` key.
        tokenizer: tokenizer passed to ``tokenize_chars``.
        max_pairs: stop as soon as this many pairs have been collected; a
            non-positive value returns an empty list.

    Returns:
        List of ``(i, j, k)`` with dataset indices ``i < j`` and ``k`` the
        differing token position, ordered by ``i`` then ``j``.
    """
    if max_pairs <= 0:
        return []

    tokenized = [tokenize_chars(tokenizer, item["text"]) for item in dataset]

    # Only equal-length sequences can differ at a single position. So group
    # indices by length and compare within a group instead of across all
    # pairs. Groups fill in index order, so each one is sorted ascending.
    by_length = {}
    slot = {}  # dataset index -> its position inside its length group
    for idx, tokens in enumerate(tokenized):
        group = by_length.setdefault(len(tokens), [])
        slot[idx] = len(group)
        group.append(idx)

    pairs = []
    for i, tok_i in enumerate(tokenized):
        # Later same-length items (j > i), ascending, as in a plain i < j scan.
        for j in by_length[len(tok_i)][slot[i] + 1:]:
            k = _single_mismatch(tok_i, tokenized[j])
            if k is not None:
                pairs.append((i, j, k))
            if len(pairs) >= max_pairs:
                return pairs

    return pairs


def _single_mismatch(seq_a, seq_b):
    """Return the only index where equal-length ``seq_a`` and ``seq_b`` differ, else None."""
    mismatch = None
    for k, (tok_a, tok_b) in enumerate(zip(seq_a, seq_b)):
        if tok_a != tok_b:
            if mismatch is not None:
                return None  # second difference: not a minimal pair
            mismatch = k
    return mismatch

In [7]:
def load_wave(path, target_sr=16000):
    """Load audio as a mono, silence-trimmed, peak-normalized 1-D float32 tensor.

    Args:
        path: audio file readable by ``soundfile``.
        target_sr: sample rate of the returned waveform. The signal is resampled
            when the file's native rate differs (default 16 kHz).

    Returns:
        1-D ``torch.float32`` tensor of samples at ``target_sr``.
    """
    # always_2d=True -> array of shape (frames, channels), even for mono files,
    # so no 1-D case needs handling.
    data, sr = sf.read(path, always_2d=True)
    # Downmix to mono by averaging the channels.
    mono = torch.tensor(data, dtype=torch.float32).mean(dim=1)

    # Trim leading/trailing audio quieter than TOP_DB dB below the peak
    # (TOP_DB is defined outside this cell).
    trimmed, _ = librosa.effects.trim(mono.numpy(), top_db=TOP_DB)
    waveform = torch.tensor(trimmed, dtype=torch.float32).unsqueeze(0)

    max_val = waveform.abs().max()
    if max_val > 0:
        waveform = waveform / max_val

    if sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, sr, target_sr)

    return waveform.squeeze(0)

In [8]:
def generate_abx_triplets(dataset, tokenizer, max_triplets=500, max_pairs=2000, rng=None):
    """Build ABX triplets ``(path_a, path_b, path_x, label)`` from minimal text pairs.

    A and B are recordings whose transcripts differ in exactly one token, as
    found by ``find_minimal_pairs``. X is chosen as follows:

    * another recording (different ``audio_path``) of A's transcript, with
      label 0, when one exists;
    * otherwise another recording of B's transcript, with label 1.

    Pairs where neither transcript has a second recording are skipped.

    Args:
        dataset: list of dicts with ``"audio_path"`` and ``"text"`` keys.
        tokenizer: passed through to ``find_minimal_pairs``.
        max_triplets: stop once this many triplets are collected.
        max_pairs: cap on the minimal pairs searched (default 2000).
        rng: object with a ``choice`` method (e.g. ``random.Random(seed)``) used
            to pick X; defaults to the global ``random`` module.

    Returns:
        List of ``(path_a, path_b, path_x, label)`` tuples.
    """
    chooser = random if rng is None else rng
    minimal_pairs = find_minimal_pairs(dataset, tokenizer, max_pairs=max_pairs)
    triplets = []

    # transcript -> every dataset entry with that exact transcript
    text_to_entries = {}
    for entry in dataset:
        text_to_entries.setdefault(entry["text"], []).append(entry)

    for idx_A, idx_B, _ in minimal_pairs:
        A = dataset[idx_A]
        B = dataset[idx_B]

        candidates_A = [d for d in text_to_entries[A["text"]] if d["audio_path"] != A["audio_path"]]
        candidates_B = [d for d in text_to_entries[B["text"]] if d["audio_path"] != B["audio_path"]]

        if candidates_A:
            X = chooser.choice(candidates_A)
            label = 0  # X matches A
        elif candidates_B:
            X = chooser.choice(candidates_B)
            label = 1  # X matches B
        else:
            continue

        # Store paths only; audio is loaded lazily by the dataset.
        triplets.append((A["audio_path"], B["audio_path"], X["audio_path"], label))

        if len(triplets) >= max_triplets:
            break

    return triplets

In [9]:
# Tokenizer vocabulary and the (audio_path, text) metadata both live under data/.
tokenizer_path = os.path.join("data", "tokenizer.json")
metadata_path = os.path.join("data", "metadata.tsv")
tokenizer = Tokenizer.load(tokenizer_path)
dataset = load_tsv_dataset(metadata_path)

Final Vocabulary Size after filtering: 494
Blank ID: 0


In [10]:
device = "cuda" if torch.cuda.is_available() else "cpu"
update_ver = 9_000  # checkpoint filename suffix (presumably the training update count)
checkpoint_path = os.path.join('models', 'ssl_model', f'ssl_model_prototype_{update_ver}.pth')
# map_location lets a checkpoint saved from GPU tensors load on a CPU-only machine.
# torch.load unpickles its input, so only load trusted checkpoints.
checkpoint_dict = torch.load(checkpoint_path, map_location=device)
ssl_state_dict = checkpoint_dict['model_state_dict']
ssl_model = SSLModel().to(device)
ssl_model.load_state_dict(ssl_state_dict, strict=True)

<All keys matched successfully>

In [11]:
# Seed Python's global RNG. generate_abx_triplets picks each X recording with
# random choice, so without a seed the triplets (and the score) change per run.
ABX_SEED = 0
random.seed(ABX_SEED)
triplets = generate_abx_triplets(dataset, tokenizer, max_triplets=500)
assert triplets, "no ABX triplets generated (no minimal pair with a repeated transcript recording)"

abx_triplet_ds = TripletDataset(triplets)
abx_loader = DataLoader(abx_triplet_ds, batch_size=16, collate_fn=collate_fn)

accuracy = abx_test_batched(ssl_model, abx_loader, device)

print(f"Final ABX Score: {accuracy * 100:.2f}%")

Final ABX Score: 79.60%
