In [1]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
#!cp '/content/drive/MyDrive/EE698R_Project/hindi.zip' '/content/hindi.zip'
# !unzip '/content/hindi.zip' -d '/content/hindi'
#!unzip '/content/drive/MyDrive/EE698R_Project/test.zip' -d '/content/drive/MyDrive/EE698R_Project/'

In [None]:
#!cp -r /content/hindi /content/drive/MyDrive/EE698R_Project/

^C


In [None]:
#generate_synthetic_hindi_audio("/content/drive/MyDrive/EE698R_Project/test/alignments/word_alignments.pkl")

In [2]:
import torch
import torchaudio
import pickle
import os
import random

class BESTSTDWordDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, split="test", segment_duration=1.0, sample_rate=16000):
        self.root_dir = root_dir
        self.audio_dir = os.path.join(root_dir, split, "audio")
        self.pkl_path = os.path.join(root_dir, split, "alignments", "word_alignments.pkl")

        with open(self.pkl_path, "rb") as f:
            self.alignment_dict = pickle.load(f)

        self.path_prefix = "Kathbath/kb_data_clean_m4a/hindi/"
        self.sample_rate = sample_rate
        self.segment_samples = int(segment_duration * sample_rate)

        self.word_to_segments = self._collect_word_segments()
        self.word_list = list(self.word_to_segments.keys())

        #print(f" Loaded {len(self.word_list)} words from split: {split}")
        #for word, segs in list(self.word_to_segments.items())[:5]:
        #   print(f"  - {word}: {len(segs)} segments")

    # def _collect_word_segments(self):
    #     word_to_segments = {}
    #     for word, utts in self.alignment_dict.items():
    #         for full_path, (start, end) in utts.items():
    #             if full_path.startswith(self.path_prefix):
    #                 relative_path = full_path[len(self.path_prefix):]
    #             else:
    #                 relative_path = full_path

    #             audio_path = os.path.join(self.root_dir, relative_path)
    #             word_to_segments.setdefault(word, []).append((audio_path, start, end))
    #     return {k: v for k, v in word_to_segments.items() if len(v) >= 2}

    def _collect_word_segments(self):
      word_to_segments = {}
      for word, utts in self.alignment_dict.items():
          for full_path, (start, end) in utts.items():
              if full_path.startswith(self.path_prefix):
                  relative_path = full_path[len(self.path_prefix):]
              else:
                  relative_path = full_path

              audio_path = os.path.join(self.root_dir, relative_path)
              if not os.path.exists(audio_path):
                  continue  # skip if audio file doesn't exist

              word_to_segments.setdefault(word, []).append((audio_path, start, end))

      # keep only words with at least 2 valid segments
      return {k: v for k, v in word_to_segments.items() if len(v) >= 2}

    def __len__(self):
        return len(self.word_list)

    def __getitem__(self, idx):
        word = self.word_list[idx]
        segs = self.word_to_segments[word]
        a_info, p_info = random.sample(segs, 2)
        return self._load_segment(*a_info), self._load_segment(*p_info), word

    def _load_segment(self, wav_path, start_sec, end_sec):
        waveform, sr = torchaudio.load(wav_path)
        if sr != self.sample_rate:
            waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.sample_rate)(waveform)

        start = int(start_sec * self.sample_rate)
        end = int(end_sec * self.sample_rate)
        segment = waveform[:, start:end]
        return self._pad_or_trim(segment)

    def _pad_or_trim(self, segment):
        if segment.shape[1] < self.segment_samples:
            segment = torch.nn.functional.pad(segment, (0, self.segment_samples - segment.shape[1]))
        return segment[:, :self.segment_samples]
# import torch
# import torchaudio
# import pickle
# import os
# import random
# from collections import defaultdict

# class BESTSTDWordDataset(torch.utils.data.Dataset):
#     def __init__(self, root_dir, split="test", segment_duration=1.0, sample_rate=16000, max_segments_per_word=50):
#         self.root_dir = root_dir
#         self.audio_dir = os.path.join(root_dir, split, "audio")
#         self.pkl_path = os.path.join(root_dir, split, "alignments", "word_alignments.pkl")
#         self.sample_rate = sample_rate
#         self.segment_samples = int(segment_duration * sample_rate)
#         self.max_segments_per_word = max_segments_per_word

#         with open(self.pkl_path, "rb") as f:
#             self.alignment_dict = pickle.load(f)

#         self.path_prefix = "Kathbath/kb_data_clean_m4a/hindi/"
#         self.word_to_segments = self._collect_word_segments()
#         self.word_list = list(self.word_to_segments.keys())

#         print(f"\n Loaded {len(self.word_list)} words from split: {split}")
#         for word, segs in list(self.word_to_segments.items())[:5]:
#             print(f"  - {word}: {len(segs)} segments")

#     def _collect_word_segments(self):
#         word_to_segments = defaultdict(list)
#         for word, utts in self.alignment_dict.items():
#             for full_path, (start, end) in utts.items():
#                 relative_path = full_path[len(self.path_prefix):] if full_path.startswith(self.path_prefix) else full_path
#                 audio_path = os.path.join(self.root_dir, relative_path)
#                 word_to_segments[word].append((audio_path, start, end))

#         # Filter and downsample
#         final_dict = {}
#         for word, segs in word_to_segments.items():
#             if len(segs) >= 2:
#                 if len(segs) > self.max_segments_per_word:
#                     segs = random.sample(segs, self.max_segments_per_word)
#                 final_dict[word] = segs
#         return final_dict

#     def __len__(self):
#         return len(self.word_list)

#     def __getitem__(self, idx):
#         word = self.word_list[idx]
#         segs = self.word_to_segments[word]
#         a_info, p_info = random.sample(segs, 2)
#         return self._load_segment(*a_info), self._load_segment(*p_info), word

#     def _load_segment(self, wav_path, start_sec, end_sec):
#         waveform, sr = torchaudio.load(wav_path)
#         if sr != self.sample_rate:
#             waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.sample_rate)(waveform)
#         start = int(start_sec * self.sample_rate)
#         end = int(end_sec * self.sample_rate)
#         segment = waveform[:, start:end]
#         return self._pad_or_trim(segment)

#     def _pad_or_trim(self, segment):
#         if segment.shape[1] < self.segment_samples:
#             segment = torch.nn.functional.pad(segment, (0, self.segment_samples - segment.shape[1]))
#         return segment[:, :self.segment_samples]




In [None]:
# from torch.utils.data import DataLoader
# import matplotlib.pyplot as plt
# import torchaudio
# from IPython.display import Audio, display

# def show_spectrogram(mel, title):
#     plt.figure(figsize=(10, 3))
#     plt.imshow(mel.numpy(), origin="lower", aspect="auto")
#     plt.title(title)
#     plt.colorbar()
#     plt.tight_layout()
#     plt.show()

# def listen_to_mel(mel, sample_rate=16000, n_mels=96, n_fft=1024):
#     # Reconstruct a rough waveform from the Mel spectrogram for listening.
#     inv_mel = torchaudio.transforms.InverseMelScale(n_stft=n_fft//2 + 1, n_mels=n_mels, sample_rate=sample_rate)
#     griffin_lim = torchaudio.transforms.GriffinLim(n_fft=n_fft)
#     waveform = griffin_lim(inv_mel(mel))
#     return waveform

# def test_loader(dataset):
#     loader = DataLoader(dataset, batch_size=1, shuffle=True)
#     for anchor, positive, word in loader:
#         print(f"\n Word: {word[0]}")
#         print(f"Anchor shape: {anchor.shape}, Positive shape: {positive.shape}")

#         show_spectrogram(anchor[0], title=f"{word[0]} - Anchor")
#         show_spectrogram(positive[0], title=f"{word[0]} - Positive")

#         print("Playing Anchor:")
#         display(Audio(listen_to_mel(anchor[0]).numpy(), rate=16000))
#         print("Playing Positive:")
#         display(Audio(listen_to_mel(positive[0]).numpy(), rate=16000))
#         break  # Only test on a single batch

# # Instantiate the dataset for the test split:
# dataset = BESTSTDWordDataset(root_dir="/content/drive/MyDrive/EE698R_Project/", split="test")
# test_loader(dataset)

In [3]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import Wav2Vec2Model
import torch.nn.functional as F
import os
import random

# ====== Quantizer ======
class VectorQuantizer(nn.Module):
    def __init__(self, codebook_size=512, dim=1024):
        super().__init__()
        self.codebook = nn.Embedding(codebook_size, dim)
        self.codebook.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size)

    def forward(self, x):  # x: [B, T, D]
        B, T, D = x.shape
        x_flat = x.view(-1, D)  # [B*T, D]
        dist = (
            x_flat.pow(2).sum(1, keepdim=True)
            - 2 * x_flat @ self.codebook.weight.t()
            + self.codebook.weight.pow(2).sum(1)
        )
        indices = dist.argmin(dim=1)  # [B*T]
        quantized = self.codebook(indices).view(B, T, D)
        return quantized, indices.view(B, T)

# ====== Loss Functions ======
def contrastive_loss(z1, z2, tau=0.5):
    sim = F.cosine_similarity(z1, z2, dim=-1)  # [B, T]
    loss = -torch.log(torch.exp(sim / tau).mean(dim=-1)).mean()
    return loss

def commitment_loss(z, z_q):
    return F.mse_loss(z_q.detach(), z) + 0.25 * F.mse_loss(z_q, z.detach())

def diversity_loss(quantized, codebook):
    # quantized: [B, T, D], codebook: [K, D]
    B, T, D = quantized.shape
    flatten = quantized.view(-1, D)
    codebook_norm = F.normalize(codebook.weight, dim=-1)
    flatten_norm = F.normalize(flatten, dim=-1)
    sim_matrix = flatten_norm @ codebook_norm.T  # [B*T, K]
    max_sim = sim_matrix.max(dim=-1)[0]  # [B*T]
    return -max_sim.mean()  # encourage spreading

# ====== Training Script ======
def train(root_dir="/content/drive/MyDrive/EE698R_Project/", split="test"):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = BESTSTDWordDataset(root_dir=root_dir, split=split)
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2, pin_memory=True)

    # Load XLS-R
    xlsr = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-xlsr-53").to(device)
    xlsr.eval()  # frozen feature extractor

    projector = nn.Linear(1024, 1024).to(device)
    quantizer = VectorQuantizer(codebook_size=512, dim=1024).to(device)

    optimizer = torch.optim.Adam(list(projector.parameters()) + list(quantizer.parameters()), lr=1e-4)
    scaler = torch.amp.GradScaler()

    ckpt_dir = os.path.join(root_dir, "checkpoints")
    os.makedirs(ckpt_dir, exist_ok=True)

    print("Quantizer grad:", any(p.requires_grad for p in quantizer.parameters()))
    print("Projector grad:", any(p.requires_grad for p in projector.parameters()))

    for epoch in range(50):
        for anchor, positive, _ in loader:
            anchor = anchor.squeeze(1).to(device)
            positive = positive.squeeze(1).to(device)

            with torch.no_grad():
                with torch.amp.autocast(device_type='cuda'):
                    z1 = xlsr(anchor, return_dict=True).last_hidden_state
                    z2 = xlsr(positive, return_dict=True).last_hidden_state

            with torch.amp.autocast(device_type='cuda'):
                z1_proj = F.layer_norm(projector(z1), z1.shape[-1:])
                z2_proj = F.layer_norm(projector(z2), z2.shape[-1:])
                q1, tok1 = quantizer(z1_proj)
                q2, tok2 = quantizer(z2_proj)

                loss_c = contrastive_loss(q1, q2)
                loss_vq = commitment_loss(z1_proj, q1) + commitment_loss(z2_proj, q2)
                loss_d = diversity_loss(q1, quantizer.codebook)
                loss = loss_c + 0.1 * loss_vq + 0.1 * loss_d

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        print(f"Epoch {epoch+1:02d} | Contrastive: {loss_c.item():.4f} | Commitment: {loss_vq.item():.4f} | Diversity: {loss_d.item():.4f} | z1_proj mean: {z1_proj.abs().mean():.4f} | q1 mean: {q1.abs().mean():.4f}")

    torch.save(quantizer.state_dict(), os.path.join(ckpt_dir, "quantizer.pt"))
    torch.save(projector.state_dict(), os.path.join(ckpt_dir, "projector.pt"))

    return xlsr, projector, quantizer





In [None]:
xlsr, projector, quantizer = train()

✅ Loaded 2017 words from split: test
  - पुलिस: 47 segments
  - ने: 262 segments
  - महिला: 15 segments
  - की: 485 segments
  - शिकायत: 4 segments


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/1.77k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.27G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.27G [00:00<?, ?B/s]

Quantizer grad: True
Projector grad: True
Epoch 01 | Contrastive: -1.9998 | Commitment: 0.2103 | Diversity: -1.0000 | z1_proj mean: 0.1887 | q1 mean: 0.0016
Epoch 02 | Contrastive: -1.9933 | Commitment: 0.1715 | Diversity: -1.0000 | z1_proj mean: 0.1833 | q1 mean: 0.0031
Epoch 03 | Contrastive: -1.9929 | Commitment: 0.2603 | Diversity: -1.0000 | z1_proj mean: 0.2232 | q1 mean: 0.0027
Epoch 04 | Contrastive: -1.9953 | Commitment: 0.0655 | Diversity: -1.0000 | z1_proj mean: 0.1151 | q1 mean: 0.0038
Epoch 05 | Contrastive: -1.9881 | Commitment: 0.1074 | Diversity: -1.0000 | z1_proj mean: 0.1464 | q1 mean: 0.0024
Epoch 06 | Contrastive: -1.9977 | Commitment: 0.1505 | Diversity: -1.0000 | z1_proj mean: 0.1722 | q1 mean: 0.0050
Epoch 07 | Contrastive: -1.9894 | Commitment: 0.0940 | Diversity: -1.0000 | z1_proj mean: 0.1444 | q1 mean: 0.0033
Epoch 08 | Contrastive: -1.9988 | Commitment: 0.0660 | Diversity: -1.0000 | z1_proj mean: 0.1152 | q1 mean: 0.0039
Epoch 09 | Contrastive: -1.9948 | Comm

In [4]:
# Token Evaluation and Demo Script
import torch
from transformers import Wav2Vec2Model
import torch.nn.functional as F
from torch.nn import Linear
import random
import os
from collections import defaultdict

def jaccard_similarity(seq1, seq2):
    set1, set2 = set(seq1), set(seq2)
    intersection = len(set1.intersection(set2))
    union = len(set1.union(set2))
    return intersection / union if union > 0 else 0.0

def evaluate_token_consistency(root_dir, split="test", num_words=10):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = BESTSTDWordDataset(root_dir=root_dir, split=split)
    word_to_segments = dataset.word_to_segments
    # selected_words = random.sample(list(word_to_segments.keys()), num_words)
    selected_words = list(word_to_segments.keys())

    # Load models
    xlsr = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-xlsr-53").to(device)
    xlsr.eval()
    projector = Linear(1024, 1024).to(device)
    quantizer_ckpt_path = os.path.join(root_dir, "checkpoints")
    projector.load_state_dict(torch.load(os.path.join(quantizer_ckpt_path, "projector.pt")))
    projector.eval()

    # from train import VectorQuantizer
    quantizer = VectorQuantizer(codebook_size=512, dim=1024).to(device)
    quantizer.load_state_dict(torch.load(os.path.join(quantizer_ckpt_path, "quantizer.pt")))
    quantizer.eval()

    similarities = []
    print("\n Jaccard Similarity on Tokens:")
    for word in selected_words:
        segments = word_to_segments[word]
        if len(segments) < 2:
            continue
        seg1_info, seg2_info = random.sample(segments, 2)
        seg1 = dataset._load_segment(*seg1_info).squeeze(0).to(device)
        seg2 = dataset._load_segment(*seg2_info).squeeze(0).to(device)

        with torch.no_grad():
            z1 = xlsr(seg1.unsqueeze(0), return_dict=True).last_hidden_state
            z2 = xlsr(seg2.unsqueeze(0), return_dict=True).last_hidden_state
            z1_proj = F.layer_norm(projector(z1), z1.shape[-1:])
            z2_proj = F.layer_norm(projector(z2), z2.shape[-1:])
            _, tok1 = quantizer(z1_proj)
            _, tok2 = quantizer(z2_proj)

        sim = jaccard_similarity(tok1.squeeze().tolist(), tok2.squeeze().tolist())
        similarities.append(sim)
        print(f"  - {word}: {sim:.4f}")

    print(f"\n Avg Jaccard similarity: {sum(similarities)/len(similarities):.4f}\n")

def token_demo(root_dir, split="test", word="महिला"):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = BESTSTDWordDataset(root_dir=root_dir, split=split)
    xlsr = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-xlsr-53").to(device)
    xlsr.eval()
    projector = Linear(1024, 1024).to(device)
    ckpt_path = os.path.join(root_dir, "checkpoints")
    projector.load_state_dict(torch.load(os.path.join(ckpt_path, "projector.pt")))
    projector.eval()

    quantizer = VectorQuantizer(codebook_size=512, dim=1024).to(device)
    quantizer.load_state_dict(torch.load(os.path.join(ckpt_path, "quantizer.pt")))
    quantizer.eval()

    print(f"\n Token sequences for word: {word}")
    for i, info in enumerate(dataset.word_to_segments[word][:2]):
        seg = dataset._load_segment(*info).squeeze(0).to(device)
        with torch.no_grad():
            z = xlsr(seg.unsqueeze(0), return_dict=True).last_hidden_state
            z_proj = F.layer_norm(projector(z), z.shape[-1:])
            _, tokens = quantizer(z_proj)
        print(f"  Segment {i+1} tokens: {tokens.squeeze().tolist()}\n")


In [5]:
evaluate_token_consistency("/content/drive/MyDrive/EE698R_Project/", split="test")
token_demo("/content/drive/MyDrive/EE698R_Project/", word="महिला")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/1.77k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.27G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.27G [00:00<?, ?B/s]


 Jaccard Similarity on Tokens:
  - पुलिस: 0.3158
  - ने: 0.1818
  - महिला: 0.5000
  - की: 0.3333
  - शिकायत: 0.7143
  - पर: 0.6154
  - केस: 0.4615
  - दर्ज: 0.3077
  - कर: 0.2222
  - शुरू: 0.3571
  - दी: 0.2857
  - है: 0.1818
  - क्योंकि: 0.3333
  - ज्यादातर: 0.2500
  - अपने: 0.4545
  - ऊपर: 0.3846
  - हो: 0.2308
  - रहे: 0.2857
  - को: 0.4000
  - नजरअंदाज: 0.2308
  - जाती: 0.4211
  - हैं: 0.5000
  - रिपोर्ट: 0.5000
  - के: 0.3333
  - पोल: 0.2778
  - खोल: 0.4118
  - खरीदा: 0.3333
  - जिसे: 0.3529
  - उन्होंने: 0.5556
  - हॉलीवुड: 0.3333
  - से: 0.4167
  - मंगलवार: 0.3333
  - तीन: 0.4444
  - तलाक: 0.3333
  - बिल: 0.2143
  - विरोध: 0.5333
  - और: 0.3571
  - पास: 0.5000
  - गया: 0.3333
  - यह: 0.5385
  - किसी: 0.3333
  - भी: 0.3333
  - लंबी: 0.6364
  - होती: 0.4211
  - तस्वीर: 0.3889
  - का: 0.3077
  - शाही: 0.4545
  - दिया: 0.3846
  - मांग: 0.5000
  - लखनऊ: 0.5455
  - सांसद: 0.3889
  - गृह: 0.4375
  - मंत्री: 0.4118
  - सिंह: 0.5000
  - फिल्म: 0.3684
  - में: 0.3077
  - सुशांत: 0.4737
 

In [None]:
import torch
import torch.nn.functional as F
from torch.nn import Linear
from transformers import Wav2Vec2Model
from tqdm import tqdm
import os
import random
from collections import defaultdict


def compute_embedding(segment, xlsr, projector, quantizer):
    with torch.no_grad():
        z = xlsr(segment.unsqueeze(0), return_dict=True).last_hidden_state
        z_proj = F.layer_norm(projector(z), z.shape[-1:])
        z_q, _ = quantizer(z_proj)
        return z_q.mean(dim=1)  # Mean pooling to get fixed vector [1, D]

def evaluate_retrieval(root_dir, split="test", max_queries=50):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = BESTSTDWordDataset(root_dir=root_dir, split=split)
    word_to_segments = dataset.word_to_segments

    xlsr = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-xlsr-53").to(device)
    xlsr.eval()

    projector = Linear(1024, 1024).to(device)
    quantizer = VectorQuantizer(codebook_size=512, dim=1024).to(device)
    ckpt_path = os.path.join(root_dir, "checkpoints")
    projector.load_state_dict(torch.load(os.path.join(ckpt_path, "projector.pt")))
    quantizer.load_state_dict(torch.load(os.path.join(ckpt_path, "quantizer.pt")))
    projector.eval()
    quantizer.eval()

    # 1. Create embeddings for all segments
    index = []  # (embedding, word)
    print(" Indexing segments...")
    for word, segments in tqdm(word_to_segments.items()):
        for info in segments:
            seg = dataset._load_segment(*info).squeeze(0).to(device)
            emb = compute_embedding(seg, xlsr, projector, quantizer)
            index.append((emb.squeeze(), word))

    # 2. Perform retrieval and compute MRR
    all_words = list(word_to_segments.keys())
    mrr_total = 0
    n_queries = 0

    print("\n Running retrieval on sample queries...")
    for word in random.sample(all_words, min(len(all_words), max_queries)):
        if len(word_to_segments[word]) < 2:
            continue
        query_seg = dataset._load_segment(*random.choice(word_to_segments[word])).squeeze(0).to(device)
        query_emb = compute_embedding(query_seg, xlsr, projector, quantizer)

        # Similarity search
        sims = [(F.cosine_similarity(query_emb, emb.unsqueeze(0)).item(), tgt_word) for emb, tgt_word in index]
        ranked = sorted(sims, key=lambda x: -x[0])

        # Find rank of correct
        for rank, (_, retrieved_word) in enumerate(ranked):
            if retrieved_word == word:
                mrr_total += 1.0 / (rank + 1)
                break
        n_queries += 1

    mrr = mrr_total / n_queries if n_queries else 0
    print(f"\n Mean Reciprocal Rank (MRR): {mrr:.4f} over {n_queries} queries")


In [6]:
!pip install dtw --quiet

  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for dtw (setup.py) ... [?25l[?25hdone


In [10]:
import os
import random
import torch
import torch.nn.functional as F
import numpy as np
from transformers import Wav2Vec2Model
from torch.nn import Linear
from dtw import accelerated_dtw
from tqdm import tqdm


def dtw_token_distance(seq1, seq2):
    x = torch.tensor(seq1).view(-1, 1).float()
    y = torch.tensor(seq2).view(-1, 1).float()
    dist, _, _, _ = accelerated_dtw(x.numpy(), y.numpy(), dist='euclidean')
    return dist


def evaluate_retrieval_dtw(root_dir, split="test", num_queries=20):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = BESTSTDWordDataset(root_dir=root_dir, split=split)
    word_to_segments = dataset.word_to_segments

    # Load models
    xlsr = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-xlsr-53").to(device)
    xlsr.eval()
    projector = Linear(1024, 1024).to(device)
    ckpt_path = os.path.join(root_dir, "checkpoints")
    projector.load_state_dict(torch.load(os.path.join(ckpt_path, "projector.pt")))
    projector.eval()

    quantizer = VectorQuantizer(codebook_size=512, dim=1024).to(device)
    quantizer.load_state_dict(torch.load(os.path.join(ckpt_path, "quantizer.pt")))
    quantizer.eval()

    # Step 1: Build the database (indexed segments)
    print("\n Indexing segments...")
    index = []
    index_meta = []
    for word, segs in tqdm(word_to_segments.items()):
        for info in segs:
            try:
                seg = dataset._load_segment(*info).squeeze(0).to(device)
            except FileNotFoundError:
                continue
            with torch.no_grad():
                z = xlsr(seg.unsqueeze(0), return_dict=True).last_hidden_state
                z_proj = F.layer_norm(projector(z), z.shape[-1:])
                _, tokens = quantizer(z_proj)
            index.append(tokens.squeeze().tolist())
            index_meta.append((word, info))

    # Step 2: Query and retrieve
    print("\n Running retrieval on sample queries...")
    mrr_total = 0
    query_words = random.sample(list(word_to_segments.keys()), num_queries)
    for word in query_words:
        segs = word_to_segments[word]
        if len(segs) < 2:
            continue
        query_info = random.choice(segs)
        try:
            query_seg = dataset._load_segment(*query_info).squeeze(0).to(device)
        except FileNotFoundError:
            continue

        with torch.no_grad():
            zq = xlsr(query_seg.unsqueeze(0), return_dict=True).last_hidden_state
            zq_proj = F.layer_norm(projector(zq), zq.shape[-1:])
            _, query_tokens = quantizer(zq_proj)
        query_tokens = query_tokens.squeeze().tolist()

        distances = [dtw_token_distance(query_tokens, tokens) for tokens in index]
        sorted_indices = np.argsort(distances)

        for rank, idx in enumerate(sorted_indices, start=1):
            if index_meta[idx][0] == word:
                mrr_total += 1.0 / rank
                break

    mrr = mrr_total / num_queries
    print(f"\n Mean Reciprocal Rank (DTW-based MRR): {mrr:.4f} over {num_queries} queries")

In [9]:
import os
import random
import torch
import torch.nn.functional as F
import numpy as np
import pickle
from transformers import Wav2Vec2Model
from torch.nn import Linear
from dtw import accelerated_dtw
from tqdm import tqdm
from collections import defaultdict

def dtw_token_distance(seq1, seq2):
    x = torch.tensor(seq1).view(-1, 1).float()
    y = torch.tensor(seq2).view(-1, 1).float()
    dist, _, _, _ = accelerated_dtw(x.numpy(), y.numpy(), dist='euclidean')
    return dist

def build_and_save_index(root_dir, split="test", index_file="retrieval_index.pkl"):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = BESTSTDWordDataset(root_dir=root_dir, split=split)
    word_to_segments = dataset.word_to_segments

    # Load models
    xlsr = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-xlsr-53").to(device)
    xlsr.eval()
    projector = Linear(1024, 1024).to(device)
    ckpt_path = os.path.join(root_dir, "checkpoints")
    projector.load_state_dict(torch.load(os.path.join(ckpt_path, "projector.pt")))
    projector.eval()

    quantizer = VectorQuantizer(codebook_size=512, dim=1024).to(device)
    quantizer.load_state_dict(torch.load(os.path.join(ckpt_path, "quantizer.pt")))
    quantizer.eval()

    # Build index
    index = []
    print("Building and saving segment index...")
    for word, segs in tqdm(word_to_segments.items()):
        for info in segs:
            try:
                seg = dataset._load_segment(*info).squeeze(0).to(device)
                with torch.no_grad():
                    z = xlsr(seg.unsqueeze(0), return_dict=True).last_hidden_state
                    z_proj = F.layer_norm(projector(z), z.shape[-1:])
                    _, tokens = quantizer(z_proj)
                index.append((tokens.squeeze().tolist(), word))
            except Exception as e:
                print(f"Skipping segment due to error: {e}")
                continue

    with open(os.path.join(root_dir, index_file), "wb") as f:
        pickle.dump(index, f)

    print(f"Saved {len(index)} indexed segments to {index_file}")
    return index, dataset

def compute_mrr_from_index(index, dataset, num_queries=20):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    word_to_segments = dataset.word_to_segments

    # Load models
    xlsr = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-xlsr-53").to(device)
    xlsr.eval()
    projector = Linear(1024, 1024).to(device)
    ckpt_path = os.path.join(dataset.root_dir, "checkpoints")
    projector.load_state_dict(torch.load(os.path.join(ckpt_path, "projector.pt")))
    projector.eval()

    quantizer = VectorQuantizer(codebook_size=512, dim=1024).to(device)
    quantizer.load_state_dict(torch.load(os.path.join(ckpt_path, "quantizer.pt")))
    quantizer.eval()

    print("\nRunning DTW-based MRR evaluation...")
    total_mrr = 0
    count = 0
    for word in random.sample(list(word_to_segments.keys()), num_queries):
        segs = word_to_segments[word]
        if len(segs) < 2:
            continue

        query_info = random.choice(segs)
        query_seg = dataset._load_segment(*query_info).squeeze(0).to(device)

        with torch.no_grad():
            zq = xlsr(query_seg.unsqueeze(0), return_dict=True).last_hidden_state
            zq_proj = F.layer_norm(projector(zq), zq.shape[-1:])
            _, query_tokens = quantizer(zq_proj)
        query_tokens = query_tokens.squeeze().tolist()

        distances = [dtw_token_distance(query_tokens, tokens) for tokens, _ in index]
        sorted_indices = np.argsort(distances)

        for rank, idx in enumerate(sorted_indices, 1):
            if index[idx][1] == word:
                total_mrr += 1.0 / rank
                break
        count += 1

    mrr = total_mrr / count if count else 0
    print(f"\nMean Reciprocal Rank (DTW): {mrr:.4f} over {count} queries")

# index, dataset = build_and_save_index("/content/drive/MyDrive/EE698R_Project/")
# compute_mrr_from_index(index, dataset)

In [8]:
import torch
import torchaudio
import torch.nn.functional as F
from transformers import Wav2Vec2Model
from torch.nn import Linear
import pickle
import os
import numpy as np
from dtw import accelerated_dtw

# Compute DTW distance between token sequences
def dtw_token_distance(seq1, seq2):
    x = torch.tensor(seq1).view(-1, 1).float()
    y = torch.tensor(seq2).view(-1, 1).float()
    dist, _, _, _ = accelerated_dtw(x.numpy(), y.numpy(), dist='euclidean')
    return dist

# Load and process a waveform
def load_and_preprocess(wav_path, target_sr=16000, max_sec=1.0):
    waveform, sr = torchaudio.load(wav_path)
    if sr != target_sr:
        waveform = torchaudio.transforms.Resample(sr, target_sr)(waveform)
    segment_len = int(target_sr * max_sec)
    if waveform.shape[1] < segment_len:
        waveform = F.pad(waveform, (0, segment_len - waveform.shape[1]))
    return waveform[:, :segment_len]

# Main function for query-by-example
def query_by_example(wav_path, index_path, root_dir, top_k=5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load XLS-R + projector + quantizer
    xlsr = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-xlsr-53").to(device).eval()
    projector = Linear(1024, 1024).to(device)
    projector.load_state_dict(torch.load(os.path.join(root_dir, "checkpoints", "projector.pt")))
    projector.eval()
    quantizer = VectorQuantizer(codebook_size=512, dim=1024).to(device)
    quantizer.load_state_dict(torch.load(os.path.join(root_dir, "checkpoints", "quantizer.pt")))
    quantizer.eval()

    # Load saved token index
    with open(index_path, "rb") as f:
        index = pickle.load(f)  # list of (token_seq, label)

    # Preprocess query audio
    waveform = load_and_preprocess(wav_path).squeeze(0).to(device)

    # Generate token sequence
    with torch.no_grad():
        z = xlsr(waveform.unsqueeze(0), return_dict=True).last_hidden_state
        z_proj = F.layer_norm(projector(z), z.shape[-1:])
        _, query_tokens = quantizer(z_proj)
    query_tokens = query_tokens.squeeze().tolist()

    # Compute DTW distances
    print(f"\nRunning query on {wav_path} ...")
    distances = [(dtw_token_distance(query_tokens, tokens), word) for tokens, word in index]
    top_matches = sorted(distances, key=lambda x: x[0])[:top_k]

    print(f"\nTop-{top_k} matches:")
    for rank, (dist, word) in enumerate(top_matches, 1):
        print(f"  {rank}. {word} (DTW distance = {dist:.2f})")

# query_by_example("query.wav", "/content/drive/MyDrive/EE698R_Project/retrieval_index.pkl", "/content/drive/MyDrive/EE698R_Project/")

In [7]:
import torch
import torchaudio
import torch.nn.functional as F
from transformers import Wav2Vec2Model
from torch.nn import Linear
import pickle
import os
import numpy as np
from scipy.spatial.distance import cosine

# Load and process a waveform
def load_and_preprocess(wav_path, target_sr=16000, max_sec=1.0):
    waveform, sr = torchaudio.load(wav_path)
    if sr != target_sr:
        waveform = torchaudio.transforms.Resample(sr, target_sr)(waveform)
    segment_len = int(target_sr * max_sec)
    if waveform.shape[1] < segment_len:
        waveform = F.pad(waveform, (0, segment_len - waveform.shape[1]))
    return waveform[:, :segment_len]

# Main function for cosine similarity query-by-example
def query_by_example_cosine(wav_path, index_path, root_dir, top_k=5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load XLS-R + projector + quantizer
    xlsr = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-xlsr-53").to(device).eval()
    projector = Linear(1024, 1024).to(device)
    projector.load_state_dict(torch.load(os.path.join(root_dir, "checkpoints", "projector.pt")))
    projector.eval()
    quantizer = VectorQuantizer(codebook_size=512, dim=1024).to(device)
    quantizer.load_state_dict(torch.load(os.path.join(root_dir, "checkpoints", "quantizer.pt")))
    quantizer.eval()

    # Load saved index: list of (embedding_tensor, word_label)
    with open(index_path, "rb") as f:
        index = pickle.load(f)

    # Preprocess query audio
    waveform = load_and_preprocess(wav_path).squeeze(0).to(device)

    # Generate embedding for query
    with torch.no_grad():
        z = xlsr(waveform.unsqueeze(0), return_dict=True).last_hidden_state
        z_proj = F.layer_norm(projector(z), z.shape[-1:])
        z_q, _ = quantizer(z_proj)
        query_embedding = z_q.mean(dim=1).squeeze().cpu().numpy()  # [D]

    # Compute cosine similarities
    print(f"\n Running cosine similarity query on {wav_path} ...")
    similarities = []
    for db_embedding, word in index:
        db_embedding = np.array(db_embedding)
        sim = 1 - cosine(query_embedding, db_embedding)
        similarities.append((sim, word))

    top_matches = sorted(similarities, key=lambda x: -x[0])[:top_k]

    print(f"\n Top-{top_k} matches:")
    for rank, (sim, word) in enumerate(top_matches, 1):
        print(f"  {rank}. {word} (Cosine similarity = {sim:.4f})")

In [None]:
evaluate_retrieval("/content/drive/MyDrive/EE698R_Project/", split="valid")

 Indexing segments...


100%|██████████| 3118/3118 [16:36<00:00,  3.13it/s]



 Running retrieval on sample queries...

 Mean Reciprocal Rank (MRR): 0.8788 over 50 queries


In [None]:
evaluate_retrieval("/content/drive/MyDrive/EE698R_Project/", split="test")

 Indexing segments...


100%|██████████| 2017/2017 [09:34<00:00,  3.51it/s]



 Running retrieval on sample queries...

 Mean Reciprocal Rank (MRR): 0.9090 over 50 queries


In [None]:
evaluate_retrieval_dtw("/content/drive/MyDrive/EE698R_Project/", split="test")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/1.77k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.27G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.27G [00:00<?, ?B/s]


 Indexing segments...



  0%|          | 0/2017 [00:00<?, ?it/s][A
  0%|          | 1/2017 [00:38<21:20:44, 38.12s/it][A
  0%|          | 2/2017 [04:14<80:06:53, 143.13s/it][A
  0%|          | 3/2017 [04:22<45:30:06, 81.33s/it] [A
  0%|          | 4/2017 [09:03<89:36:05, 160.24s/it][A
  0%|          | 5/2017 [09:03<57:17:15, 102.50s/it][A
  0%|          | 6/2017 [10:54<58:46:06, 105.20s/it][A
  0%|          | 7/2017 [10:54<39:33:46, 70.86s/it] [A
  0%|          | 8/2017 [10:55<27:12:59, 48.77s/it][A
  0%|          | 9/2017 [11:42<26:49:21, 48.09s/it][A
  0%|          | 10/2017 [11:51<20:03:54, 35.99s/it][A
  1%|          | 11/2017 [11:58<15:04:44, 27.06s/it][A
  1%|          | 12/2017 [17:06<62:47:09, 112.73s/it][A
  1%|          | 13/2017 [17:08<44:02:39, 79.12s/it] [A
  1%|          | 14/2017 [17:09<30:50:15, 55.42s/it][A
  1%|          | 15/2017 [17:21<23:38:09, 42.50s/it][A
  1%|          | 16/2017 [17:22<16:39:42, 29.98s/it][A
  1%|          | 17/2017 [17:50<16:14:28, 29.23s/it][A
  1%


 Running retrieval on sample queries...

 Mean Reciprocal Rank (DTW-based MRR): 1.0000 over 20 queries


In [None]:
evaluate_retrieval_dtw("/content/drive/MyDrive/EE698R_Project/", split="valid")


 Indexing segments...


  1%|          | 22/3118 [35:18<82:48:47, 96.29s/it]


KeyboardInterrupt: 

In [None]:
index, dataset = build_and_save_index("/content/drive/MyDrive/EE698R_Project/", split="valid")

Building and saving segment index...


100%|██████████| 3118/3118 [25:30<00:00,  2.04it/s]

Saved 31410 indexed segments to retrieval_index.pkl





In [None]:
# DTW
query_by_example("./अनुमान_10.wav", "/content/drive/MyDrive/EE698R_Project/retrieval_index.pkl", "/content/drive/MyDrive/EE698R_Project/")


Running query on ./अनुमान_10.wav ...

Top-5 matches:
  1. हास्य (DTW distance = 951.00)
  2. सकता (DTW distance = 965.00)
  3. रेलवे (DTW distance = 970.00)
  4. स्पेशल (DTW distance = 1025.00)
  5. यह (DTW distance = 1036.00)


In [None]:
query_by_example_cosine(
    wav_path="./अनुमान_10.wav",
    index_path="/content/drive/MyDrive/EE698R_Project/retrieval_index.pkl",
    root_dir="/content/drive/MyDrive/EE698R_Project",
    top_k=5
)


 Running cosine similarity query on ./अनुमान_10.wav ...


ValueError: shapes (1024,) and (49,) not aligned: 1024 (dim 0) != 49 (dim 0)