In [None]:
from src.requirements import *
from src.ssl_model import *
from src.tokenizer import *

In [None]:
@torch.no_grad()
def extract_embedding(model, waveform):
    """Mean-pool a model's frame-level features into a single utterance vector.

    Args:
        model: object exposing ``extract_features(waveform)``; presumably the SSL
            model loaded later in this notebook — TODO confirm expected input shape.
        waveform: input forwarded unchanged to ``model.extract_features``.

    Returns:
        The features averaged over dim 1, with a leading size-1 dim removed
        (``squeeze(0)`` is a no-op when dim 0 is not of size 1).
    """
    pooled = model.extract_features(waveform).mean(dim=1)
    return pooled.squeeze(0)

In [None]:
def cosine_distance(a, b):
    """Return ``1 - cosine_similarity(a, b)``; for 1-D inputs the result has shape (1,)."""
    # Add a leading batch dim so F.cosine_similarity compares along the feature axis.
    a_batched, b_batched = a.unsqueeze(0), b.unsqueeze(0)
    similarity = F.cosine_similarity(a_batched, b_batched)
    return 1 - similarity

In [None]:
def abx_decision(A, B, X):
    """Return a bool tensor that is True when X is strictly closer to A than to B.

    Closeness is measured with ``cosine_distance``; a tie counts as "not closer to A".
    """
    return cosine_distance(X, A) < cosine_distance(X, B)

In [None]:
def abx_test(model, triplets, device):
    """Score an ABX discrimination test and return the fraction answered correctly.

    Args:
        model: model passed to ``extract_embedding`` for every waveform; it is
            switched to eval mode first.
        triplets: sequence of ``(wave_A, wave_B, wave_X, label)`` tuples where
            ``label == 0`` means X is another recording of A's transcript and any
            other label means X matches B (as built by ``generate_abx_triplets``).
        device: device each waveform is moved to before embedding.

    Returns:
        Accuracy in ``[0, 1]``.

    Raises:
        ValueError: if ``triplets`` is empty (accuracy is undefined).
    """
    if len(triplets) == 0:
        raise ValueError("abx_test needs at least one triplet")

    model.eval()
    correct = 0

    for wave_A, wave_B, wave_X, label in triplets:
        # Embed with the `model` argument (not a notebook global) so the function
        # evaluates whichever model the caller passes in.
        emb_A = extract_embedding(model, wave_A.to(device))
        emb_B = extract_embedding(model, wave_B.to(device))
        emb_X = extract_embedding(model, wave_X.to(device))

        # abx_decision(P, Q, X) is True when X is closer to P, so the side that
        # X actually matches goes first.
        if label == 0:
            pred = abx_decision(emb_A, emb_B, emb_X)
        else:
            pred = abx_decision(emb_B, emb_A, emb_X)

        correct += int(pred)

    return correct / len(triplets)

In [None]:
import csv

def load_tsv_dataset(tsv_path):
    """Read a two-column ``audio_path<TAB>text`` file into a list of dicts.

    Rows that do not have exactly two tab-separated fields are skipped.
    NOTE(review): a header row, if the file has one, is NOT skipped and would be
    returned as an entry — confirm the file is headerless.

    Args:
        tsv_path: path to the UTF-8 TSV file.

    Returns:
        List of ``{"audio_path": str, "text": str}`` dicts in file order, with
        surrounding whitespace stripped from ``text``.
    """
    dataset = []
    # newline="" is what the csv module requires for file objects. QUOTE_NONE stops
    # transcripts containing `"` from being parsed as quoted fields, which would
    # drop the quote characters and can swallow tabs/newlines of following rows.
    with open(tsv_path, "r", encoding="utf-8", newline="") as f:
        reader = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
        for row in reader:
            if len(row) != 2:
                continue
            audio_path, text = row
            dataset.append({
                "audio_path": audio_path,
                "text": text.strip()
            })
    return dataset

In [None]:
# Load the (audio_path, text) rows from which minimal pairs and ABX triplets are built.
dataset = load_tsv_dataset(os.path.join("data", "metadata.tsv"))

In [None]:
# Tokenizer used by find_minimal_pairs to compare transcripts position by position.
tokenizer = Tokenizer.load(os.path.join("data", "tokenizer.json"))

In [None]:
def tokenize_chars(tokenizer, text):
    """Tokenize ``text`` via ``tokenizer.tokenize`` and return its result unchanged."""
    tokens = tokenizer.tokenize(text)
    return tokens

In [None]:
def _single_diff_position(seq_a, seq_b):
    """Return the only index where equal-length ``seq_a`` and ``seq_b`` differ, else None.

    None is returned both for identical sequences and for two or more mismatches;
    the scan stops as soon as a second mismatch is found.
    """
    found = None
    for k, (x, y) in enumerate(zip(seq_a, seq_b)):
        if x != y:
            if found is not None:
                return None
            found = k
    return found


def find_minimal_pairs(dataset, tokenizer, max_pairs=2000):
    """Find utterance pairs whose token sequences differ in exactly one position.

    Args:
        dataset: list of dicts with a ``"text"`` key (as returned by ``load_tsv_dataset``).
        tokenizer: tokenizer forwarded to ``tokenize_chars``.
        max_pairs: stop after this many pairs. ``None`` means no limit; a value
            ``<= 0`` returns an empty list.

    Returns:
        List of ``(i, j, k)`` tuples: ``i < j`` are dataset indices and ``k`` is the
        single token position where their sequences differ. Pairs appear in the
        same order as an exhaustive ``for i: for j > i`` scan would produce them.
    """
    if max_pairs is not None and max_pairs <= 0:
        return []

    tokenized = [tokenize_chars(tokenizer, item["text"]) for item in dataset]

    # Only equal-length sequences can differ in exactly one position, so group
    # indices by length and compare i only with later members of its group.
    # Groups are filled in index order, so pair order (and therefore which pairs
    # survive the max_pairs cut) matches the exhaustive scan.
    groups = {}
    rank = []  # rank[i] = position of index i inside its length group
    for idx, tokens in enumerate(tokenized):
        group = groups.setdefault(len(tokens), [])
        rank.append(len(group))
        group.append(idx)

    pairs = []
    for i, tok_i in enumerate(tokenized):
        for j in groups[len(tok_i)][rank[i] + 1:]:
            k = _single_diff_position(tok_i, tokenized[j])
            if k is not None:
                pairs.append((i, j, k))
                if max_pairs is not None and len(pairs) >= max_pairs:
                    return pairs

    return pairs

In [None]:
def load_wave(path, target_sr=16000):
    """Load an audio file as a mono, silence-trimmed, peak-normalised tensor.

    Args:
        path: audio file readable by ``soundfile``.
        target_sr: sample rate the audio is resampled to (default 16 kHz).

    Returns:
        Float32 tensor of shape ``(1, 1, num_samples)`` at ``target_sr``.
    """
    # always_2d=True gives (frames, channels) for mono and multi-channel files alike.
    audio, sr = sf.read(path, always_2d=True)
    # (frames, channels) -> (channels, frames) -> average channels to mono (T,).
    mono = torch.tensor(audio.T, dtype=torch.float32).mean(dim=0)

    # Trim leading/trailing silence. TOP_DB is not defined in this notebook's
    # visible cells; presumably it comes from the star imports — TODO confirm.
    trimmed, _ = librosa.effects.trim(mono.numpy(), top_db=TOP_DB)
    if trimmed.size > 0:
        mono = torch.tensor(trimmed, dtype=torch.float32)
    # Otherwise the whole clip is below the threshold and librosa returned an
    # empty array; keep the untrimmed signal rather than crash on `.max()` below.

    waveform = mono.unsqueeze(0)  # (1, T)
    if waveform.numel() > 0:
        peak = waveform.abs().max()
        if peak > 0:
            waveform = waveform / peak

    if sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, sr, target_sr)

    return waveform.unsqueeze(0)  # (1, 1, T)

In [None]:
import random

def generate_abx_triplets(dataset, tokenizer, max_triplets=500, seed=None):
    """Build ABX triplets from minimal pairs of transcripts.

    For each minimal pair (A, B) from ``find_minimal_pairs``, X is a different
    recording of A's transcript when one exists (label 0), otherwise a different
    recording of B's transcript (label 1). Pairs with neither are skipped.

    Args:
        dataset: list of ``{"audio_path", "text"}`` dicts.
        tokenizer: tokenizer forwarded to ``find_minimal_pairs``.
        max_triplets: maximum number of triplets returned; ``<= 0`` returns none.
        seed: optional seed for choosing X. ``None`` (default) draws from the
            global ``random`` module, as before.

    Returns:
        List of ``(wave_A, wave_B, wave_X, label)`` tuples; waveforms come from
        ``load_wave``.
    """
    triplets = []
    if max_triplets <= 0:
        return triplets

    # A private RNG makes X selection reproducible without touching the global
    # random state; without a seed keep using the module-level functions.
    rng = random.Random(seed) if seed is not None else random

    minimal_pairs = find_minimal_pairs(dataset, tokenizer)

    # Index transcripts -> utterances so X candidates can be looked up by text.
    text_to_entries = {}
    for entry in dataset:
        text_to_entries.setdefault(entry["text"], []).append(entry)

    for idx_A, idx_B, _ in minimal_pairs:
        A = dataset[idx_A]
        B = dataset[idx_B]

        # Other recordings of the same transcript (the pair's own file excluded).
        candidates_A = [d for d in text_to_entries[A["text"]] if d["audio_path"] != A["audio_path"]]
        candidates_B = [d for d in text_to_entries[B["text"]] if d["audio_path"] != B["audio_path"]]

        if candidates_A:
            X = rng.choice(candidates_A)
            label = 0  # X matches A
        elif candidates_B:
            X = rng.choice(candidates_B)
            label = 1  # X matches B
        else:
            continue

        wave_A = load_wave(A["audio_path"])
        wave_B = load_wave(B["audio_path"])
        wave_X = load_wave(X["audio_path"])

        triplets.append((wave_A, wave_B, wave_X, label))

        if len(triplets) >= max_triplets:
            break

    return triplets

In [None]:
triplets = generate_abx_triplets(dataset, tokenizer, max_triplets=100)
# Fail with an explanatory message instead of a bare IndexError on `triplets[0]`.
if not triplets:
    raise RuntimeError(
        "No ABX triplets generated: no minimal pairs were found, or no pair has "
        "a second recording of either transcript to use as X"
    )

A, B, X, label = triplets[0]
print("Triplet label:", label)
print("Wave shapes:", A.shape, B.shape, X.shape)

In [None]:
# Sanity check: print the transcripts of the first minimal pair (texts differing in
# exactly one token). Raises IndexError if find_minimal_pairs returns no pairs.
idx_A, idx_B, _ = find_minimal_pairs(dataset, tokenizer, max_pairs=1)[0]
print(dataset[idx_A]["text"], "vs", dataset[idx_B]["text"])

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
update_ver = 2_000  # checkpoint file suffix; presumably the training update count
checkpoint_path = os.path.join('models', 'ssl_model', f'ssl_model_prototype_{update_ver}.pth')
# map_location puts the saved tensors directly on `device`, so a checkpoint written
# from a GPU run still loads on a CPU-only machine instead of failing in torch.load.
checkpoint_dict = torch.load(checkpoint_path, map_location=device)
ssl_state_dict = checkpoint_dict['model_state_dict']
ssl_model = SSLModel().to(device)
ssl_model.load_state_dict(ssl_state_dict, strict=True)

In [None]:
# ABX accuracy of the loaded checkpoint on the generated triplets, shown in percent.
result = abx_test(ssl_model, triplets, device)
result * 100

In [None]:
class ABXDataset(Dataset):
    """On-the-fly same-file vs. different-file discrimination triplets.

    Each item is ``(anchor, positive, negative)``: two independent random crops of
    row ``idx``'s audio file and one random crop of a different, randomly chosen
    row. Cropping and negative selection use torch's global RNG, so items vary
    between passes unless torch is seeded.
    """

    def __init__(self, metadata_path, segment_len=32000): # e.g., 2 seconds at 16k
        # Tab-separated file with a header row that includes a `path` column.
        self.df = pd.read_csv(metadata_path, sep="\t")
        self.segment_len = segment_len

    def __len__(self):
        return len(self.df)

    def _get_segment(self, path):
        """Return a random ``segment_len``-sample mono crop of ``path``, zero-padded if shorter."""
        waveform, _ = sf.read(path, always_2d=True)
        # NOTE(review): the file's sample rate is discarded, so segment_len only
        # corresponds to 2 s if the audio is already 16 kHz — confirm or resample.
        waveform = torch.tensor(waveform.T, dtype=torch.float32).mean(dim=0)

        # Ensure it's long enough, else pad
        if waveform.shape[0] < self.segment_len:
            waveform = F.pad(waveform, (0, self.segment_len - waveform.shape[0]))

        # Random crop
        start = torch.randint(0, waveform.shape[0] - self.segment_len + 1, (1,)).item()
        return waveform[start : start + self.segment_len]

    def __getitem__(self, idx):
        """Return ``(anchor, positive, negative)`` crops for row ``idx``.

        Raises:
            ValueError: if the metadata has fewer than two rows, since no
                different-file negative exists.
        """
        n_rows = len(self.df)
        if n_rows < 2:
            raise ValueError("ABXDataset needs at least two files to draw a negative")
        if idx < 0:
            idx += n_rows  # normalise so the "different row" check below is valid

        path_a = self.df.iloc[idx]['path']

        # A and X are two independent random crops of the same file
        anchor = self._get_segment(path_a)
        positive = self._get_segment(path_a)

        # B is a uniformly random *different* row: draw from n_rows - 1 slots and
        # shift past idx. One RNG call, and no infinite retry loop when n_rows == 1.
        random_idx = torch.randint(0, n_rows - 1, (1,)).item()
        if random_idx >= idx:
            random_idx += 1

        negative = self._get_segment(self.df.iloc[random_idx]['path'])

        return anchor, positive, negative

@torch.no_grad()
def run_abx_val(model, abx_loader, device):
    """Measure same-file vs. different-file discrimination accuracy.

    A sample counts as correct when the mean-pooled anchor features are strictly
    more cosine-similar to the positive's than to the negative's.

    Args:
        model: model exposing ``extract_features``; switched to eval mode here.
        abx_loader: iterable of ``(anchor, positive, negative)`` batches shaped
            ``[B, Seq]`` (e.g. a DataLoader over ``ABXDataset``).
        device: device each batch is moved to.

    Returns:
        Accuracy in ``[0, 1]`` (also printed).

    Raises:
        ValueError: if the loader yields no samples (accuracy is undefined).
    """
    model.eval()
    correct = 0
    total = 0

    for a, p, n in abx_loader:
        # Move to device and add channel dim [B, 1, Seq]
        a, p, n = (t.to(device).unsqueeze(1) for t in (a, p, n))

        # Mean-pool features over dim 1 -> [B, Hidden]
        feat_a = model.extract_features(a).mean(dim=1)
        feat_p = model.extract_features(p).mean(dim=1)
        feat_n = model.extract_features(n).mean(dim=1)

        sim_pos = F.cosine_similarity(feat_a, feat_p)
        sim_neg = F.cosine_similarity(feat_a, feat_n)

        # Successful if anchor is more similar to positive than negative
        correct += (sim_pos > sim_neg).sum().item()
        total += a.size(0)

    if total == 0:
        raise ValueError("run_abx_val received an empty loader")

    accuracy = correct / total
    print(f"ABX Discrimination Accuracy: {accuracy:.2%}")
    return accuracy

# NOTE(review): `path` is not defined in any cell shown here, so this line raises
# NameError on a fresh kernel unless `path` leaks from the star imports or a deleted
# cell. ABXDataset expects a TSV with a header row containing a `path` column —
# confirm which metadata file is intended and define it explicitly.
abx_dataset = ABXDataset(metadata_path=path, segment_len=16000 * 2)
abx_loader = DataLoader(abx_dataset, batch_size=8, shuffle=False)