In [1]:
import torch, torch.nn as nn, torch.nn.functional as F
print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0))

True
NVIDIA GeForce RTX 4060 Laptop GPU


In [2]:
import torchaudio
import os, numpy as np
from tqdm import tqdm

In [3]:
from transformers import AutoTokenizer, AutoModel
import json

In [None]:
from aac_datasets import Clotho

# Add download=True for each to download the datasets if not already present
train_dataset = Clotho(root=".", subset="dev")
eval_dataset = Clotho(root=".", subset="eval")
val_dataset = Clotho(root=".", subset="val")

In [5]:
item = train_dataset[0]
audio, captions = item["audio"], item["captions"]
captions

['A muddled noise of broken channel of the TV',
 'A television blares the rhythm of a static TV.',
 'Loud television static dips in and out of focus',
 'The loud buzz of static constantly changes pitch and volume.',
 'heavy static and the beginnings of a signal on a transistor radio']

In [6]:
len(train_dataset), len(eval_dataset), len(val_dataset)

(3839, 1045, 1045)

In [7]:
sample = train_dataset[0]
print("Captions:", sample["captions"])
print("Audio shape:", sample["audio"].shape)

Captions: ['A muddled noise of broken channel of the TV', 'A television blares the rhythm of a static TV.', 'Loud television static dips in and out of focus', 'The loud buzz of static constantly changes pitch and volume.', 'heavy static and the beginnings of a signal on a transistor radio']
Audio shape: torch.Size([1, 1153825])


In [8]:
mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=44100,   # Clotho uses 44.1kHz audio
    n_fft=1024,
    hop_length=512,
    n_mels=64
)
to_db = torchaudio.transforms.AmplitudeToDB(stype='power')

In [9]:
def generate_mel_spectrograms(dataset, split_name):
    if os.path.exists(f"features/mel_{split_name}") and len(os.listdir(f"features/mel_{split_name}")) == len(dataset):
        print(f"✅ Spectrograms already exist. Skipping generation.")
    else:
        print("Generating spectrograms...")
        os.makedirs(f"features/mel_{split_name}", exist_ok=True)
        
        for item in tqdm(dataset):
            waveform = item["audio"]
            
            # Normalize to [1, samples] regardless of input
            if waveform.dim() == 1:  # [samples] → [1, samples]
                waveform = waveform.unsqueeze(0)
            elif waveform.dim() == 2 and waveform.shape[0] == 2:  # [2, samples] → [1, samples]
                waveform = waveform.mean(dim=0, keepdim=True)  # Convert stereo to mono
            elif waveform.dim() == 2 and waveform.shape[0] != 1:
                waveform = waveform[0:1]  # Take first channel
            
            # Now waveform is guaranteed to be [1, samples]
            mel = mel_transform(waveform)  # → [1, 64, time]
            mel_db = to_db(mel)
            
            # Save with consistent shape [1, 64, time]
            assert mel_db.dim() == 3 and mel_db.shape[0] == 1, f"Unexpected shape: {mel_db.shape}"
            np.save(f"features/mel_{split_name}/{item['fname']}.npy", mel_db.numpy())
        
        print("✅ Spectrograms generated!")

In [10]:
generate_mel_spectrograms(train_dataset, "train")
generate_mel_spectrograms(eval_dataset, "eval")
generate_mel_spectrograms(val_dataset, "val")

✅ Spectrograms already exist. Skipping generation.
✅ Spectrograms already exist. Skipping generation.
✅ Spectrograms already exist. Skipping generation.


Caption Processing

In [11]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [12]:
def process_and_save_captions(dataset, split_name, max_len=30):
    save_dir = f"features/captions_{split_name}"
    os.makedirs(save_dir, exist_ok=True)
    
    all_count = 0
    print(f"Processing captions for split: {split_name} ...")

    for item in tqdm(dataset):
        fname = item["fname"]
        captions = item["captions"]  

        for i, cap in enumerate(captions):
            encoded = tokenizer(
                cap.lower(),
                padding="max_length",
                truncation=True,
                max_length=max_len,
                return_tensors="pt"
            )

            # Save encoded caption tensor dict
            torch.save(
                {"input_ids": encoded["input_ids"].squeeze(0),
                 "attention_mask": encoded["attention_mask"].squeeze(0)},
                f"{save_dir}/{fname}_cap{i}.pt"
            )
            all_count += 1

    print(f"✅ Saved {all_count} tokenized captions to {save_dir}/")

In [13]:
process_and_save_captions(train_dataset, "train")
process_and_save_captions(val_dataset, "val")
process_and_save_captions(eval_dataset, "eval")

Processing captions for split: train ...


100%|██████████| 3839/3839 [00:22<00:00, 171.34it/s]


✅ Saved 19195 tokenized captions to features/captions_train/
Processing captions for split: val ...


100%|██████████| 1045/1045 [00:06<00:00, 160.03it/s]


✅ Saved 5225 tokenized captions to features/captions_val/
Processing captions for split: eval ...


100%|██████████| 1045/1045 [00:07<00:00, 147.88it/s]

✅ Saved 5225 tokenized captions to features/captions_eval/





Dataset Class

In [14]:
from torch.utils.data import Dataset, DataLoader

In [15]:
class AudioTextDataset(Dataset):
    def __init__(self, mel_dir, caption_dir):
        self.mel_dir = mel_dir
        self.caption_dir = caption_dir
        self.caption_files = sorted(os.listdir(caption_dir))
        self.cache = {}  # memory cache

    def __len__(self):
        return len(self.caption_files)
    
    def __getitem__(self, idx):
        cap_file = self.caption_files[idx]
        base = cap_file.split("_cap")[0]
        mel_path = os.path.join(self.mel_dir, f"{base}.npy")

        # Cache mel in memory
        if base not in self.cache:
            mel = np.load(mel_path)
            self.cache[base] = torch.tensor(mel, dtype=torch.float32)
        mel = self.cache[base]

        cap = torch.load(os.path.join(self.caption_dir, cap_file))
        return mel, cap["input_ids"], cap["attention_mask"]

Audio Encoder

In [16]:
class AudioEncoder(nn.Module):
    def __init__(self, embed_dim=512):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(128, 256, 3, stride=2, padding=1), nn.ReLU(),
        )
        self.rnn = nn.LSTM(256, 256, batch_first=True, bidirectional=True)
        self.attn = nn.MultiheadAttention(512, num_heads=4, batch_first=True)
        self.proj = nn.Linear(512, embed_dim)
    
    def forward(self, mel):
        # mel: [B, 1, 64, T]
        x = self.cnn(mel)                    # [B, 256, 8, T/8]
        x = x.mean(2).permute(0, 2, 1)       # [B, T', 256]
        x, _ = self.rnn(x)
        attn_out, _ = self.attn(x, x, x)
        pooled = attn_out.mean(1)            # [B, 512]
        z = self.proj(pooled)
        return F.normalize(z, dim=-1)

Text Encoder

In [17]:
class TextEncoder(nn.Module):
    def __init__(self, embed_dim=512):
        super().__init__()
        self.bert = AutoModel.from_pretrained("bert-base-uncased")
        self.proj = nn.Linear(768, embed_dim)
    def forward(self, input_ids, attention_mask):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_emb = out.last_hidden_state[:, 0, :]  # CLS token
        z = self.proj(cls_emb)
        return F.normalize(z, dim=-1)

In [18]:
# Contrastive loss
def contrastive_loss(z_a, z_t, temperature=0.07):
    sim = z_a @ z_t.T  # cosine similarities
    sim /= temperature
    labels = torch.arange(sim.size(0)).to(sim.device)
    loss_a = F.cross_entropy(sim, labels)
    loss_t = F.cross_entropy(sim.T, labels)
    return (loss_a + loss_t) / 2

Training

In [19]:
device = "cuda" if torch.cuda.is_available() else "cpu"
audio_encoder = AudioEncoder().to(device)
text_encoder = TextEncoder().to(device)

In [20]:
optimizer = torch.optim.AdamW(
    list(audio_encoder.parameters()) + list(text_encoder.parameters()), lr=1e-4
)

In [21]:
def collate_fn(batch):
    # batch is a list of tuples: (mel, input_ids, attn_mask)
    mels, ids, masks = zip(*batch)
    
    # Find max time length in this batch
    max_len = max(m.shape[-1] for m in mels)
    
    # Pad each mel on the right with zeros
    padded_mels = []
    for m in mels:
        pad_len = max_len - m.shape[-1]
        if pad_len > 0:
            m = F.pad(m, (0, pad_len))  
        padded_mels.append(m)
    
    mels = torch.stack(padded_mels)
    ids = torch.stack(ids)
    masks = torch.stack(masks)
    return mels, ids, masks

In [22]:
train_dataset = AudioTextDataset("features/mel_train", "features/captions_train")
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0, collate_fn=collate_fn)

In [23]:
eval_dataset = AudioTextDataset("features/mel_eval", "features/captions_eval")
eval_loader = DataLoader(eval_dataset, batch_size=16, shuffle=False, num_workers=0, collate_fn=collate_fn)

In [24]:
num_epochs = 10

In [None]:
ds = AudioTextDataset("features/mel_train", "features/captions_train")
print("Dataset length:", len(ds))
x = ds[0]  
print([t.shape for t in x if torch.is_tensor(t)])

Dataset length: 19195
[torch.Size([1, 64, 1504]), torch.Size([30]), torch.Size([30])]


  cap = torch.load(os.path.join(self.caption_dir, cap_file))


In [26]:
best_loss = float('inf')

In [27]:
for epoch in range(num_epochs):
    audio_encoder.train()
    text_encoder.train()
    train_loss = 0.0
    for mel, input_ids, attn_mask in tqdm(train_loader):
        mel, input_ids, attn_mask = mel.to(device), input_ids.to(device), attn_mask.to(device)
        z_a = audio_encoder(mel)
        z_t = text_encoder(input_ids, attn_mask)
        loss = contrastive_loss(z_a, z_t)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    
    audio_encoder.eval()
    text_encoder.eval()
    eval_loss = 0.0
    with torch.no_grad():
        for mel, input_ids, attn_mask in eval_loader:
            mel, input_ids, attn_mask = mel.to(device), input_ids.to(device), attn_mask.to(device)
            z_a = audio_encoder(mel)
            z_t = text_encoder(input_ids, attn_mask)
            loss = contrastive_loss(z_a, z_t)
            eval_loss += loss.item()
    
    print(f"Epoch {epoch+1}: Train loss={train_loss/len(train_loader):.4f}, Eval loss={eval_loss/len(eval_loader):.4f}")

    # Save checkpoint when eval loss improves
    if eval_loss < best_loss:
        torch.save({
            "audio_encoder": audio_encoder.state_dict(),
            "text_encoder": text_encoder.state_dict()
        }, "new_best_model.pt")
        best_loss = eval_loss


  cap = torch.load(os.path.join(self.caption_dir, cap_file))
100%|██████████| 1200/1200 [04:13<00:00,  4.74it/s]


Epoch 1: Train loss=2.1193, Eval loss=2.4171


100%|██████████| 1200/1200 [04:08<00:00,  4.84it/s]


Epoch 2: Train loss=1.5191, Eval loss=2.3237


100%|██████████| 1200/1200 [04:04<00:00,  4.91it/s]


Epoch 3: Train loss=1.1873, Eval loss=2.3596


100%|██████████| 1200/1200 [04:06<00:00,  4.86it/s]


Epoch 4: Train loss=0.9815, Eval loss=2.3728


100%|██████████| 1200/1200 [04:07<00:00,  4.85it/s]


Epoch 5: Train loss=0.8343, Eval loss=2.5284


100%|██████████| 1200/1200 [04:05<00:00,  4.89it/s]


Epoch 6: Train loss=0.6746, Eval loss=2.6497


100%|██████████| 1200/1200 [04:04<00:00,  4.92it/s]


Epoch 7: Train loss=0.5727, Eval loss=2.6003


100%|██████████| 1200/1200 [04:04<00:00,  4.91it/s]


Epoch 8: Train loss=0.4714, Eval loss=2.7315


100%|██████████| 1200/1200 [04:05<00:00,  4.88it/s]


Epoch 9: Train loss=0.4221, Eval loss=2.7312


100%|██████████| 1200/1200 [04:04<00:00,  4.91it/s]


Epoch 10: Train loss=0.3823, Eval loss=2.7940


Model Evaluation

In [28]:
def compute_embeddings(model_a, model_t, dataset, device, batch_size=16):
    model_a.eval(); model_t.eval()
    audio_embs, text_embs = [], []

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=collate_fn)

    with torch.no_grad():
        for mel, input_ids, attn_mask in tqdm(loader, desc="Embedding val set"):
            mel, input_ids, attn_mask = mel.to(device), input_ids.to(device), attn_mask.to(device)
            z_a = model_a(mel)
            z_t = model_t(input_ids, attn_mask)
            audio_embs.append(z_a.cpu())
            text_embs.append(z_t.cpu())

    return torch.cat(audio_embs), torch.cat(text_embs)

In [67]:
val_dataset = AudioTextDataset("features/mel_val", "features/captions_val")
audio_embs, text_embs = compute_embeddings(audio_encoder, text_encoder, val_dataset, device)

  cap = torch.load(os.path.join(self.caption_dir, cap_file))
Embedding val set: 100%|██████████| 327/327 [00:23<00:00, 14.01it/s]


In [30]:
similarity = audio_embs @ text_embs.T 

In [31]:
def recall_at_k(sim_matrix, k=10):
    correct = 0
    for i in range(sim_matrix.size(0)):
        topk = sim_matrix[i].topk(k).indices
        if i in topk:  # correct caption retrieved
            correct += 1
    return correct / sim_matrix.size(0)

In [32]:
def mean_average_precision(sim_matrix):
    avg_precisions = []
    for i in range(sim_matrix.size(0)):
        # Sort by descending similarity
        scores = sim_matrix[i].argsort(descending=True)
        # Find rank of the correct match
        rank = (scores == i).nonzero(as_tuple=True)[0].item() + 1
        avg_precisions.append(1.0 / rank)
    return sum(avg_precisions) / len(avg_precisions)

In [33]:
# Audio → Text
r1_a2t = recall_at_k(similarity, 1)
r5_a2t = recall_at_k(similarity, 5)
r10_a2t = recall_at_k(similarity, 10)
map_a2t = mean_average_precision(similarity)

# Text → Audio (transpose)
r1_t2a = recall_at_k(similarity.T, 1)
r5_t2a = recall_at_k(similarity.T, 5)
r10_t2a = recall_at_k(similarity.T, 10)
map_t2a = mean_average_precision(similarity.T)

print(f"Audio→Text Recall@1/5/10: {r1_a2t:.3f}, {r5_a2t:.3f}, {r10_a2t:.3f}, MAP={map_a2t:.3f}")
print(f"Text→Audio Recall@1/5/10: {r1_t2a:.3f}, {r5_t2a:.3f}, {r10_t2a:.3f}, MAP={map_t2a:.3f}")

Audio→Text Recall@1/5/10: 0.008, 0.031, 0.054, MAP=0.028
Text→Audio Recall@1/5/10: 0.007, 0.031, 0.055, MAP=0.027


Model Inference

In [36]:
def audio_to_text(audio_path, k=5):
    wav, sr = torchaudio.load(audio_path)
    if wav.dim() == 2:
        wav = wav.mean(dim=0, keepdim=True)
    mel = mel_transform(wav)
    mel_db = to_db(mel).unsqueeze(0).to(device)

    with torch.no_grad():
        z_a = audio_encoder(mel_db)
        z_a = F.normalize(z_a, dim=-1).cpu()   
    sims = (z_a @ text_embs.T).squeeze(0)     
    topk = sims.topk(k)
    return topk.indices.tolist(), topk.values.tolist()


In [71]:
val_dataset = Clotho(root=".", subset="val")

In [72]:
all_val = []

In [77]:
for i in range(len(val_dataset)):
    sample_val = val_dataset[i]

    for caption in sample_val['captions']:
        all_val.append(caption)

In [78]:
len(all_val)

10450

In [79]:
idxs, scores = audio_to_text(r"CLOTHO_v2.1\clotho_audio_files\validation\zipping backpack and rustling papers.wav", k=5)
for i, (idx, sc) in enumerate(zip(idxs, scores)):
    print(f"Rank {i+1}: Caption: {all_val[idx]}, score={sc:.3f}")

Rank 1: Caption: Paper rustles as someone walks across the dust and sticks on the ground., score=0.933
Rank 2: Caption: A man speaks loudly then a buzzer plays then a man speaks loudly again a flute begins to play., score=0.895
Rank 3: Caption: The monotone voice of a man speaks repeatedly followed by a melody played on a keyboard., score=0.890
Rank 4: Caption: The rain pours over the houses and dies off after a while., score=0.866
Rank 5: Caption: A heavy rain is falling on a windy day., score=0.865


In [64]:
def text_to_audio(query_text, k=5):
    encoded = tokenizer(query_text, padding="max_length",
                        truncation=True, max_length=30, return_tensors="pt")
    with torch.no_grad():
        z_t = text_encoder(encoded["input_ids"].to(device),
                           encoded["attention_mask"].to(device))
        z_t = F.normalize(z_t, dim=-1).cpu()   # ensure CPU
    sims = (z_t @ audio_embs.T).squeeze(0)
    topk = sims.topk(k)
    return topk.indices.tolist(), topk.values.tolist()

In [91]:
idxs, scores = text_to_audio("Sound of wind turbine rotor", k=5)
for i, (idx, sc) in enumerate(zip(idxs, scores)):
    print(f"Rank {i+1}: Audio #{idx}, score={sc:.3f}")

Rank 1: Audio #5162, score=0.864
Rank 2: Audio #5161, score=0.864
Rank 3: Audio #5163, score=0.864
Rank 4: Audio #5164, score=0.864
Rank 5: Audio #5160, score=0.864
