In [1]:
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


# Dataset Preparation

##### Converting .mp4 files from MELD dataset to .mp3 for audio data

In [None]:
# convert video to audio
from pathlib import Path
import importlib
import src.data_pipeline_helper_functions.conversion as conv
importlib.reload(conv)
from src.data_pipeline_helper_functions.conversion import cnv2mp3

project_root = Path.cwd().parent

# train split
src_train = project_root / "data" / "video_data" / "train_splits"
dst_train = project_root / "data" / "audio_data" / "train_splits"
print("----------------------------------------Converting train split----------------------------------------")
print("src:", src_train)
print("dst:", dst_train)
try:
    cnv2mp3(str(src_train), str(dst_train))
except Exception as e:
    print("conversion raised:", repr(e))


# dev split    
src_dev = project_root / "data" / "video_data" / "dev_splits"
dst_dev = project_root / "data" / "audio_data" / "dev_splits"
print("----------------------------------------Converting dev split----------------------------------------")
print("src:", src_dev)
print("dst:", dst_dev)
try:
    cnv2mp3(str(src_dev), str(dst_dev))
except Exception as e:
    print("conversion raised:", repr(e))
    
# test split    
src_test = project_root / "data" / "video_data" / "test_splits"
dst_test = project_root / "data" / "audio_data" / "test_splits"
print("----------------------------------------Converting test split----------------------------------------")
print("src:", src_test)
print("dst:", dst_test)
try:
    cnv2mp3(str(src_test), str(dst_test))
except Exception as e:
    print("conversion raised:", repr(e))

----------------------------------------Converting train split----------------------------------------
src: c:\Users\Marc\Desktop\Programming\Main-Project\snn\multimodal-snn\data\video_data\train_splits
dst: c:\Users\Marc\Desktop\Programming\Main-Project\snn\multimodal-snn\data\audio_data\train_splits
wrote: dia0_utt0.wav
wrote: dia0_utt1.wav
wrote: dia0_utt10.wav
wrote: dia0_utt11.wav
wrote: dia0_utt12.wav
wrote: dia0_utt13.wav
wrote: dia0_utt2.wav
wrote: dia0_utt3.wav
wrote: dia0_utt4.wav
wrote: dia0_utt5.wav
wrote: dia0_utt6.wav
wrote: dia0_utt7.wav
wrote: dia0_utt8.wav
wrote: dia0_utt9.wav
wrote: dia1000_utt0.wav
wrote: dia1000_utt1.wav
wrote: dia1000_utt2.wav
wrote: dia1000_utt3.wav
wrote: dia1001_utt0.wav
wrote: dia1001_utt1.wav
wrote: dia1001_utt10.wav
wrote: dia1001_utt11.wav
wrote: dia1001_utt2.wav
wrote: dia1001_utt3.wav
wrote: dia1001_utt4.wav
wrote: dia1001_utt5.wav
wrote: dia1001_utt6.wav
wrote: dia1001_utt7.wav
wrote: dia1001_utt8.wav
wrote: dia1001_utt9.wav
wrote: dia100

### Loading the audio Labels and aligning with data

In [None]:
import src.load_labels.audio_labels
importlib.reload(src.load_labels.audio_labels)
from src.load_labels.audio_labels import load_meld_labels

# Usage example:
pickle_path = "data/features/data_emotion.p"
utt_id_to_label, label_index = load_meld_labels(pickle_path)

print(list(utt_id_to_label.keys())[:5])
print(label_index)


['0_0', '0_1', '0_2', '0_3', '0_5']
{'neutral': 0, 'surprise': 1, 'fear': 2, 'sadness': 3, 'joy': 4, 'disgust': 5, 'anger': 6}


In [80]:
# Usage:
import src.load_labels.audio_align
importlib.reload(src.load_labels.audio_align)
from src.load_labels.audio_align import split_and_align_data
train_data, dev_data, test_data = split_and_align_data("data/audio_data", utt_id_to_label)

2025-10-19 17:41:41 [INFO] Processing 9988 audio files in train_splits
2025-10-19 17:41:41 [INFO] Processing 1112 audio files in dev_splits
2025-10-19 17:41:41 [INFO] Processing 2747 audio files in test_splits


### Detecting outliers

In [None]:
import src.data_pipeline_helper_functions.outlier_detection
importlib.reload(src.data_pipeline_helper_functions.outlier_detection)
from src.data_pipeline_helper_functions.outlier_detection import detect_outliers

# Assuming train_data is defined as (paths_list, labels_list)
filtered_train_data = detect_outliers(train_data)

2025-10-19 17:41:51 [INFO] Extracting heuristic features from 9988 samples




2025-10-19 17:43:09 [INFO] Number of samples before filtering: 9988
2025-10-19 17:43:09 [INFO] Number of samples after filtering: 9496


In [90]:
filtered_dev_data = detect_outliers(dev_data)
filtered_test_data = detect_outliers(test_data)

2025-10-19 17:51:27 [INFO] Extracting heuristic features from 1112 samples




2025-10-19 17:51:46 [INFO] Number of samples before filtering: 1112
2025-10-19 17:51:46 [INFO] Number of samples after filtering: 1056
2025-10-19 17:51:46 [INFO] Extracting heuristic features from 2615 samples




2025-10-19 17:52:32 [INFO] Number of samples before filtering: 2615
2025-10-19 17:52:32 [INFO] Number of samples after filtering: 2486


In [91]:
import shutil

def save_filtered_audio_video(filtered_train_data, audio_base_dir, video_base_dir, output_audio_dir, output_video_dir):
    """
    Save filtered audio and corresponding video files to new directories.

    Args:
        filtered_train_data (list of tuples): (audio_path, label).
        audio_base_dir (str or Path): Base dir of original audio files.
        video_base_dir (str or Path): Base dir containing corresponding videos.
        output_audio_dir (str or Path): Output directory for filtered audio.
        output_video_dir (str or Path): Output directory for filtered video.
    """
    audio_base_dir = Path(audio_base_dir)
    video_base_dir = Path(video_base_dir)
    output_audio_dir = Path(output_audio_dir)
    output_video_dir = Path(output_video_dir)

    for audio_path, label in filtered_train_data:
        audio_rel_path = Path(audio_path).relative_to(audio_base_dir)
        dest_audio_path = output_audio_dir / audio_rel_path
        dest_audio_path.parent.mkdir(parents=True, exist_ok=True)

        # Copy audio
        shutil.copy2(audio_path, dest_audio_path)

        # Corresponding video path, assuming same filename but .mp4 extension
        video_path = video_base_dir / audio_rel_path.with_suffix('.mp4')
        if video_path.exists():
            video_rel_path = video_path.relative_to(video_base_dir)
            dest_video_path = output_video_dir / video_rel_path
            dest_video_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(video_path, dest_video_path)
        else:
            print(f"Warning: No corresponding video found for {audio_path}")

    print(f"Copied filtered audio to {output_audio_dir}")
    print(f"Copied corresponding videos to {output_video_dir}")



In [92]:
def check_matching_filenames(audio_dir, video_dir):
    audio_dir = Path(audio_dir)
    video_dir = Path(video_dir)

    # Get all audio filenames without extension, relative to audio_dir
    audio_files = [f.with_suffix('').relative_to(audio_dir) for f in audio_dir.rglob('*') if f.is_file()]
    audio_set = set(str(f) for f in audio_files)

    # Get all video filenames without extension, relative to video_dir
    video_files = [f.with_suffix('').relative_to(video_dir) for f in video_dir.rglob('*') if f.is_file()]
    video_set = set(str(f) for f in video_files)

    # Compare sets
    missing_in_video = audio_set - video_set
    missing_in_audio = video_set - audio_set

    if not missing_in_video and not missing_in_audio:
        print("All filenames match between audio and video directories.")
        return True
    else:
        if missing_in_video:
            print("Files missing in video directory:")
            for f in missing_in_video:
                print(f)
        if missing_in_audio:
            print("Files missing in audio directory:")
            for f in missing_in_audio:
                print(f)
        return False

##### saving filtered train data

In [86]:
save_filtered_audio_video(
    filtered_train_data=filtered_train_data,
    audio_base_dir="data/audio_data/train_splits",
    video_base_dir="data/video_data/train_splits",
    output_audio_dir="data/filtered_audio/train_splits",
    output_video_dir="data/filtered_video/train_splits"
)


Copied filtered audio to data\filtered_audio\train_splits
Copied corresponding videos to data\filtered_video\train_splits


In [88]:
# Example usage
check_matching_filenames('data/filtered_audio/train_splits', 'data/filtered_video/train_splits')


All filenames match between audio and video directories.


True

##### saving filtered dev data

In [93]:
save_filtered_audio_video(
    filtered_train_data=filtered_dev_data,
    audio_base_dir="data/audio_data/dev_splits",
    video_base_dir="data/video_data/dev_splits",
    output_audio_dir="data/filtered_audio/dev_splits",
    output_video_dir="data/filtered_video/dev_splits"
)


Copied filtered audio to data\filtered_audio\dev_splits
Copied corresponding videos to data\filtered_video\dev_splits


In [94]:
# Example usage
check_matching_filenames('data/filtered_audio/dev_splits', 'data/filtered_video/dev_splits')

All filenames match between audio and video directories.


True

##### saving filtered test data

In [95]:
save_filtered_audio_video(
    filtered_train_data=filtered_test_data,
    audio_base_dir="data/audio_data/test_splits",
    video_base_dir="data/video_data/test_splits",
    output_audio_dir="data/filtered_audio/test_splits",
    output_video_dir="data/filtered_video/test_splits"
)

Copied filtered audio to data\filtered_audio\test_splits
Copied corresponding videos to data\filtered_video\test_splits


In [96]:
# Example usage
check_matching_filenames('data/filtered_audio/test_splits', 'data/filtered_video/test_splits')

All filenames match between audio and video directories.


True

### Audio and video Feature Extraction and Siamese NN setup

In [1]:
import torch, torchvision
print(torch.__version__)
print(torch.version.cuda)      # should be 11.8
print(torch.cuda.is_available())


import sys
import logging

# Logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    stream=sys.stdout
)
logger = logging.getLogger(__name__)

2.6.0+cu124
12.4
True


In [2]:
from transformers import Wav2Vec2Processor, Wav2Vec2Model
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load pretrained models
wav2vec_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
wav2vec_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base").to(device).eval()
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
clip_model = clip_model.eval()

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def extract_audio_embedding_segment(waveform_segment_1d, sr):
    inputs = wav2vec_processor(waveform_segment_1d, sampling_rate=sr, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = wav2vec_model(**{k: v.to(device) for k, v in inputs.items()})
    hidden_states = outputs.last_hidden_state.mean(dim=1).squeeze(0)
    embedding = hidden_states / hidden_states.norm()
    return embedding.float()

In [4]:
import cv2
from PIL import Image

# Video feature extraction for segment (list of frames)
def extract_video_embedding_segment(frames):
    if len(frames) == 0:
        return None
    processed = torch.cat([clip_preprocess(Image.fromarray(cv2.cvtColor(f, cv2.COLOR_BGR2RGB))).unsqueeze(0).to(device) for f in frames])
    with torch.no_grad():
        feats = clip_model.encode_image(processed)
    feats = feats.float()
    feats /= feats.norm(dim=-1, keepdim=True)
    avg_feat = feats.mean(dim=0)
    avg_feat /= avg_feat.norm()
    return avg_feat.float()

In [5]:
import numpy as np

def sliding_window_audio(y, sr, win_sec=1.0, hop_sec=0.5):
    win_samples = int(win_sec * sr)
    hop_samples = int(hop_sec * sr)
    segments = []
    n = len(y)
    if n < win_samples:
        pad = np.zeros(win_samples - n, dtype=y.dtype)
        seg = np.concatenate([y, pad], axis=0)
        segments.append(seg)
    else:
        for start in range(0, n - win_samples + 1, hop_samples):
            segments.append(y[start:start+win_samples])
        if len(segments) == 0:
            segments.append(y[-win_samples:])
    return segments

In [6]:
import cv2
import threading
import queue
from typing import Optional, Tuple, List

class ThreadedVideoReader:
    def __init__(self, path: str, queue_size: int = 96, drop_oldest: bool = True):
        self.path = path
        self.cap = cv2.VideoCapture(path)
        if not self.cap.isOpened():
            raise RuntimeError(f"Failed to open video: {path}")
        self.fps = self.cap.get(cv2.CAP_PROP_FPS) or 30.0
        self.frame_count = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
        self.q = queue.Queue(maxsize=queue_size)
        self.drop_oldest = drop_oldest
        self.stopped = False
        self.reader_thread = threading.Thread(target=self._reader, daemon=True)
        self.idx = 0

    def start(self):
        self.reader_thread.start()
        return self

    def _reader(self):
        try:
            while not self.stopped:
                ret, frame = self.cap.read()
                if not ret:
                    break
                try:
                    self.q.put((self.idx, frame), block=False)
                except queue.Full:
                    if self.drop_oldest:
                        try:
                            _ = self.q.get_nowait()
                        except queue.Empty:
                            pass
                        try:
                            self.q.put((self.idx, frame), block=False)
                        except queue.Full:
                            pass
                    else:
                        self.q.put((self.idx, frame), block=True)
                self.idx += 1
        finally:
            self.stopped = True
            self.cap.release()
            try:
                self.q.put_nowait((None, None))
            except queue.Full:
                pass

    def read_next(self, timeout: Optional[float] = 0.5) -> Optional[Tuple[int, any]]:
        if self.stopped and self.q.empty():
            return None
        try:
            item = self.q.get(timeout=timeout)
            idx, frame = item
            if idx is None:
                return None
            return item
        except queue.Empty:
            if self.stopped:
                return None
            return None

    def stop(self):
        self.stopped = True
        if self.cap:
            self.cap.release()

def iter_frame_windows(reader: ThreadedVideoReader, win_sec: float = 1.0, hop_sec: float = 0.5):
    fps = reader.fps
    win_frames = max(1, int(round(win_sec * fps)))
    hop_frames = max(1, int(round(hop_sec * fps)))
    buffer: List[Tuple[int, any]] = []

    # Prefill
    while len(buffer) < win_frames:
        item = reader.read_next()
        if item is None:
            break
        buffer.append(item)

    # Emit windows and advance by hop
    while len(buffer) >= win_frames:
        frames = [f for _, f in buffer[:win_frames]]
        yield frames
        advance = min(hop_frames, len(buffer))
        buffer = buffer[advance:]
        while len(buffer) < win_frames:
            item = reader.read_next()
            if item is None:
                break
            buffer.append(item)


In [7]:
import librosa

# keep existing extract_audio_embedding_segment, extract_video_embedding_segment

def extract_aligned_spike_trains(
    audio_path,
    video_path,
    encoder,
    win_sec=1.0,
    hop_sec=0.5,
    sr=16000,
    max_segments: int = None,   # optional cap
):
    # Audio segments (full waveform → window slices)
    y, _ = librosa.load(audio_path, sr=sr, mono=True)
    audio_segments = sliding_window_audio(y, sr, win_sec, hop_sec)

    # Video segments via threaded prefetcher + iterator
    rdr = ThreadedVideoReader(video_path, queue_size=96, drop_oldest=True).start()
    video_segments = []
    try:
        for frames in iter_frame_windows(rdr, win_sec=win_sec, hop_sec=hop_sec):
            video_segments.append(frames)
            if max_segments is not None and len(video_segments) >= max_segments:
                break
    finally:
        rdr.stop()

    # Align counts
    S = min(len(audio_segments), len(video_segments))
    if S == 0:
        raise RuntimeError(f"No aligned segments for {audio_path} and {video_path}")

    if max_segments is not None:
        S = min(S, max_segments)

    audio_spikes_seq, video_spikes_seq = [], []
    for i in range(S):
        aud_seg = torch.tensor(audio_segments[i], dtype=torch.float32)
        vid_seg = video_segments[i]
        if len(vid_seg) == 0:
            continue

        # Embeddings per segment
        audio_emb = extract_audio_embedding_segment(aud_seg, sr)
        video_emb = extract_video_embedding_segment(vid_seg)
        if audio_emb is None or video_emb is None:
            continue

        # Spike trains [T, F]
        audio_spikes = encoder.encode(audio_emb).float()
        video_spikes = encoder.encode(video_emb).float()
        audio_spikes_seq.append(audio_spikes.unsqueeze(0))  # [1, T, Fa]
        video_spikes_seq.append(video_spikes.unsqueeze(0))  # [1, T, Fv]

    if len(audio_spikes_seq) == 0 or len(video_spikes_seq) == 0:
        raise RuntimeError(f"No valid spike segments produced for {audio_path} and {video_path}")

    audio_spikes_seq = torch.cat(audio_spikes_seq, dim=0)  # [S, T, Fa]
    video_spikes_seq = torch.cat(video_spikes_seq, dim=0)  # [S, T, Fv]
    return audio_spikes_seq, video_spikes_seq


In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import snntorch as snn
from snntorch import surrogate

# Ensure these functions are imported from your pipeline module
# from your_module import extract_aligned_spike_trains

class SpikingSiameseNetwork(nn.Module):
    def __init__(self, audio_dim, video_dim, encoder, embed_dim=128, num_steps=25):
        super().__init__()
        self.num_steps = num_steps
        self.encoder = encoder

        # Use the correct arg name for your snntorch version (spike_grad is most common)
        self.surrogate_func = surrogate.fast_sigmoid(slope=25)

        # Audio branch
        self.audio_fc1 = nn.Linear(audio_dim, 256)
        self.audio_lif1 = snn.Leaky(beta=0.9, spike_grad=self.surrogate_func)
        self.audio_fc2 = nn.Linear(256, embed_dim)
        self.audio_lif2 = snn.Leaky(beta=0.9, spike_grad=self.surrogate_func)

        # Video branch
        self.video_fc1 = nn.Linear(video_dim, 256)
        self.video_lif1 = snn.Leaky(beta=0.9, spike_grad=self.surrogate_func)
        self.video_fc2 = nn.Linear(256, embed_dim)
        self.video_lif2 = snn.Leaky(beta=0.9, spike_grad=self.surrogate_func)

    # Single-sample segment: audio_spikes/video_spikes: [T, feat]
    def forward_once_segment(self, audio_spikes, video_spikes):
        audio_spikes = audio_spikes.float()
        video_spikes = video_spikes.float()

        mem_audio1 = self.audio_lif1.init_leaky()
        mem_audio2 = self.audio_lif2.init_leaky()
        mem_video1 = self.video_lif1.init_leaky()
        mem_video2 = self.video_lif2.init_leaky()

        for t in range(self.num_steps):
            cur_audio = self.audio_fc1(audio_spikes[t])
            spk_audio1, mem_audio1 = self.audio_lif1(cur_audio, mem_audio1)
            cur_audio2 = self.audio_fc2(spk_audio1)
            spk_audio2, mem_audio2 = self.audio_lif2(cur_audio2, mem_audio2)

            cur_video = self.video_fc1(video_spikes[t])
            spk_video1, mem_video1 = self.video_lif1(cur_video, mem_video1)
            cur_video2 = self.video_fc2(spk_video1)
            spk_video2, mem_video2 = self.video_lif2(cur_video2, mem_video2)

        combined = torch.cat([spk_audio2, spk_video2], dim=0)
        combined = F.normalize(combined.unsqueeze(0), p=2, dim=1).squeeze(0)
        return combined

    # Single-sample sequence: audio_spikes_seq/video_spikes_seq: [S, T, feat]
    def forward_sequence(self, audio_spikes_seq, video_spikes_seq):
        embeddings = []
        num_segments = audio_spikes_seq.shape[0]
        for seg in range(num_segments):
            emb = self.forward_once_segment(audio_spikes_seq[seg], video_spikes_seq[seg])
            embeddings.append(emb.unsqueeze(0))
        embeddings = torch.cat(embeddings, dim=0)  # [S, embed_dim]
        final_emb = embeddings.mean(dim=0)         # [embed_dim]
        return final_emb

    # Batched segment: audio_spikes/video_spikes: [B, T, feat]
    def forward_once_segment_batched(self, audio_spikes, video_spikes):
        audio_spikes = audio_spikes.float()
        video_spikes = video_spikes.float()

        B, T, _ = audio_spikes.shape
        mem_audio1 = self.audio_lif1.init_leaky(batch_size=B)
        mem_audio2 = self.audio_lif2.init_leaky(batch_size=B)
        mem_video1 = self.video_lif1.init_leaky(batch_size=B)
        mem_video2 = self.video_lif2.init_leaky(batch_size=B)

        for t in range(T):
            cur_audio = self.audio_fc1(audio_spikes[:, t, :])
            spk_audio1, mem_audio1 = self.audio_lif1(cur_audio, mem_audio1)
            cur_audio2 = self.audio_fc2(spk_audio1)
            spk_audio2, mem_audio2 = self.audio_lif2(cur_audio2, mem_audio2)

            cur_video = self.video_fc1(video_spikes[:, t, :])
            spk_video1, mem_video1 = self.video_lif1(cur_video, mem_video1)
            cur_video2 = self.video_fc2(spk_video1)
            spk_video2, mem_video2 = self.video_lif2(cur_video2, mem_video2)

        combined = torch.cat([spk_audio2, spk_video2], dim=1)  # [B, 2*embed_dim]
        combined = F.normalize(combined, p=2, dim=1)           # [B, 2*embed_dim]
        return combined

    # Batched sequence: a_seq/v_seq: [B, S, T, feat]
    def forward_sequence_batched(self, a_seq, v_seq):
        a_seq = a_seq.float()
        v_seq = v_seq.float()

        B, S, T, F = a_seq.shape
        embs = []
        for s in range(S):
            emb = self.forward_once_segment_batched(a_seq[:, s], v_seq[:, s])  # [B, 2*embed_dim]
            embs.append(emb.unsqueeze(1))
        embs = torch.cat(embs, dim=1)  # [B, S, 2*embed_dim]
        final = embs.mean(dim=1)        # [B, 2*embed_dim]
        return final

    # Path-based forward for quick testing; still valid
    def forward(self, audio_path1, video_path1, audio_path2, video_path2):
        audio_spikes1, video_spikes1 = extract_aligned_spike_trains(audio_path1, video_path1, self.encoder)
        audio_spikes2, video_spikes2 = extract_aligned_spike_trains(audio_path2, video_path2, self.encoder)
        emb1 = self.forward_sequence(audio_spikes1, video_spikes1)
        emb2 = self.forward_sequence(audio_spikes2, video_spikes2)
        return emb1, emb2


In [9]:
from src.model_training.spike_encoder import DynamicSpikeEncoder
encoder = DynamicSpikeEncoder(T=25, use_dynamic_encoding=True)

In [10]:
model = SpikingSiameseNetwork(audio_dim=768, video_dim=512, encoder=encoder, embed_dim=128, num_steps=25).to(device).float()

In [11]:
from src.model_training.loss import ContrastiveLoss
loss_fn = ContrastiveLoss(margin=1.0)

In [12]:
audio_path1 = "data/filtered_audio/train_splits/dia0_utt0.wav"
video_path1 = "data/filtered_video/train_splits/dia0_utt0.mp4"
audio_path2 = "data/filtered_audio/train_splits/dia0_utt1.wav"
video_path2 = "data/filtered_video/train_splits/dia0_utt1.mp4"
label = torch.tensor(1.0, device=device)

emb1, emb2 = model(audio_path1, video_path1, audio_path2, video_path2)
loss = loss_fn(emb1, emb2, label)
print(f"Loss: {loss.item()}")

Loss: 2.559999989770745e-10


In [13]:
import importlib
import src.load_labels.aud_vid_pairs
importlib.reload(src.load_labels.aud_vid_pairs)
from src.load_labels.aud_vid_pairs import build_audio_video_pairs_csvs

In [None]:
build_audio_video_pairs_csvs(
    audio_base="data/filtered_audio",
    video_base="data/filtered_video",
    emotion_pickle="data/features/data_emotion.p",
    out_base_dir="data/metadata",
    win_sec=1.0,
    hop_sec=0.5,
    include_meta=True,
)

2025-10-19 22:25:29 [INFO] [train_splits] Scanning split directories: A=data\filtered_audio\train_splits, V=data\filtered_video\train_splits


2025-10-19 22:25:30 [INFO] [train_splits] audio=9496, video=9496, common=9496
2025-10-19 22:31:36 [INFO] [train_splits] probed items=9496, bad_audio=0, bad_video=0
2025-10-19 22:31:36 [INFO] [train_splits] valid items with ≥1 aligned segment: 9496


In [None]:
from typing import Dict, Any, List, Tuple, Optional
from torch.utils.data import Dataset

# import your embedding and encoder utilities from your pipeline
# from your_module import (
#     extract_audio_embedding_segment,
#     extract_video_embedding_segment,
#     sliding_window_audio,
#     ThreadedVideoReader,
#     iter_frame_windows,
#     encoder,
# )

class AVPairSegmentsDataset(Dataset):
    """
    Supports both CSV schemas:
      - 5 columns: audio_path1, video_path1, audio_path2, video_path2, label
      - meta CSV: + fps1, win_sec1, hop_sec1, seg_aligned1, fps2, win_sec2, hop_sec2, seg_aligned2
    Returns:
      {'a1':[S1,T,Fa], 'v1':[S1,T,Fv], 'a2':[S2,T,Fa], 'v2':[S2,T,Fv], 'label': int, 'meta': {...}}
    """
    def __init__(
        self,
        csv_file: str,
        sr: int = 16000,
        win_sec: float = 1.0,
        hop_sec: float = 0.5,
        max_segments: Optional[int] = None,
        use_threaded_video: bool = True,
    ):
        self.rows: List[Dict[str, Any]] = []
        with open(csv_file, "r", newline="", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            for r in reader:
                self.rows.append(r)
        self.sr = sr
        self.default_win = win_sec
        self.default_hop = hop_sec
        self.max_segments = max_segments
        self.use_threaded_video = use_threaded_video  # toggles ThreadedVideoReader vs legacy loader [web:98][web:100]

    def _get_params(self, r: Dict[str, Any], side: int):
        # Side-specific parameters with fallback to defaults if meta columns are absent [web:84]
        win = float(r.get(f"win_sec{side}", self.default_win))
        hop = float(r.get(f"hop_sec{side}", self.default_hop))
        seg_cap = r.get(f"seg_aligned{side}")
        seg_aligned = int(seg_cap) if seg_cap is not None and str(seg_cap).isdigit() else None
        return win, hop, seg_aligned

    def _load_audio_segments(self, audio_path: str, win: float, hop: float) -> List[np.ndarray]:
        y, _ = librosa.load(audio_path, sr=self.sr, mono=True)
        return sliding_window_audio(y, self.sr, win, hop)  # reuses your function [web:90]

    def _iter_video_segments(self, video_path: str, win: float, hop: float):
        if not self.use_threaded_video:
            # Fallback to legacy full decode; not recommended for long clips [web:98]
            frames_by_window = sliding_window_video_frames(video_path, win, hop)
            for frames in frames_by_window:
                yield frames
            return
        rdr = ThreadedVideoReader(video_path, queue_size=96, drop_oldest=True).start()  # threaded prefetch [web:98]
        try:
            for frames in iter_frame_windows(rdr, win_sec=win, hop_sec=hop):  # windowed, overlapping [web:98]
                yield frames
        finally:
            rdr.stop()

    def _segments_to_spikes_av(
        self,
        audio_segments: List[np.ndarray],
        video_windows_iter,
        seg_aligned: Optional[int] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        a_seq: List[torch.Tensor] = []
        v_seq: List[torch.Tensor] = []
        # Iterate video windows on the fly; align with audio segs count and optional seg cap [web:84]
        limit = seg_aligned if seg_aligned is not None else len(audio_segments)
        if self.max_segments is not None:
            limit = min(limit, self.max_segments)

        count = 0
        for i, vid_seg in enumerate(video_windows_iter):
            if i >= len(audio_segments):
                break
            if count >= limit:
                break
            aud_seg = torch.tensor(audio_segments[i], dtype=torch.float32)
            if len(vid_seg) == 0:
                continue
            a_emb = extract_audio_embedding_segment(aud_seg, self.sr)   # [Fa]
            v_emb = extract_video_embedding_segment(vid_seg)             # [Fv]
            if a_emb is None or v_emb is None:
                continue
            a_spk = encoder.encode(a_emb)  # [T, Fa]
            v_spk = encoder.encode(v_emb)  # [T, Fv]
            a_seq.append(a_spk.unsqueeze(0))
            v_seq.append(v_spk.unsqueeze(0))
            count += 1

        if len(a_seq) == 0 or len(v_seq) == 0:
            raise RuntimeError("No valid aligned segments after encoding.")  # [web:84]
        a_seq = torch.cat(a_seq, dim=0)  # [S, T, Fa]
        v_seq = torch.cat(v_seq, dim=0)  # [S, T, Fv]
        return a_seq, v_seq

    def _build_pair(self, audio_path: str, video_path: str, win: float, hop: float, seg_aligned: Optional[int]):
        a_segs = self._load_audio_segments(audio_path, win, hop)  # [S_a] arrays [web:90]
        v_iter = self._iter_video_segments(video_path, win, hop)  # generator of frame lists [web:98]
        return self._segments_to_spikes_av(a_segs, v_iter, seg_aligned=seg_aligned)

    def __len__(self) -> int:
        return len(self.rows)  # [web:84]

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        r = self.rows[idx]
        a1p, v1p = r["audio_path1"], r["video_path1"]
        a2p, v2p = r["audio_path2"], r["video_path2"]
        y = int(r["label"])

        win1, hop1, seg1 = self._get_params(r, 1)
        win2, hop2, seg2 = self._get_params(r, 2)

        a1, v1 = self._build_pair(a1p, v1p, win1, hop1, seg1)  # [S1,T,Fa/Fv] [web:84]
        a2, v2 = self._build_pair(a2p, v2p, win2, hop2, seg2)

        return {
            "a1": a1, "v1": v1,
            "a2": a2, "v2": v2,
            "label": y,
            "meta": {
                "a1": a1p, "v1": v1p, "a2": a2p, "v2": v2p,
                "win1": win1, "hop1": hop1, "seg1": seg1,
                "win2": win2, "hop2": hop2, "seg2": seg2,
            }
        }  # [web:84]



In [None]:
def collate_av_pair_sequences(batch: List[Dict[str, Any]], pad_value: float = 0.0) -> Dict[str, Any]:
    # Find max segments in batch for each branch
    max_S1 = max(b["a1"].shape[0] for b in batch)
    max_S2 = max(b["a2"].shape[0] for b in batch)

    a1_list, v1_list, a2_list, v2_list, y_list, metas = [], [], [], [], [], []
    for b in batch:
        a1, v1 = b["a1"], b["v1"]  # [S, T, Fa], [S, T, Fv]
        a2, v2 = b["a2"], b["v2"]
        # Pad on S dimension
        def pad_S(x, maxS):
            S, T, F = x.shape
            if S == maxS:
                return x
            pad_tensor = torch.zeros((maxS - S, T, x.shape[2]), dtype=x.dtype)
            return torch.cat([x, pad_tensor.fill_(pad_value)], dim=0)
        a1_list.append(pad_S(a1, max_S1))
        v1_list.append(pad_S(v1, max_S1))
        a2_list.append(pad_S(a2, max_S2))
        v2_list.append(pad_S(v2, max_S2))
        y_list.append(int(b["label"]))
        metas.append(b["meta"])

    a1 = torch.stack(a1_list, dim=0)  # [B, S1, T, Fa]
    v1 = torch.stack(v1_list, dim=0)  # [B, S1, T, Fv]
    a2 = torch.stack(a2_list, dim=0)  # [B, S2, T, Fa]
    v2 = torch.stack(v2_list, dim=0)  # [B, S2, T, Fv]
    y = torch.tensor(y_list, dtype=torch.float32)  # contrastive expects float labels often
    return {"a1": a1, "v1": v1, "a2": a2, "v2": v2, "y": y, "meta": metas}


In [None]:
from torch.utils.data import DataLoader

train_ds = AVPairSegmentsDataset(
    csv_file="data/metadata/train_splits_pairs_with_metadata.csv",
    sr=16000,
    win_sec=1.0,
    hop_sec=0.5,
    max_segments=16,
    use_threaded_video=True
)

train_loader = DataLoader(
    train_ds,
    batch_size=4,
    shuffle=True,
    num_workers=4,        # overlap per-sample I/O/compute
    pin_memory=True,      # for CUDA transfers
    collate_fn=collate_av_pair_sequences,
    persistent_workers=True,
)
