# Multimodal audio+image pipeline

This notebook contains a reorganized, split version of the original monolithic pipeline.
Cells are grouped by purpose (settings, extraction, audio cleaning, dataset, model, training, tests).
Keep this structure for easier refactoring, commenting and testing.

In [1]:
# Libraries

import os, glob, subprocess, math
import random
from collections import Counter

import numpy as np
from PIL import Image

from tqdm import tqdm

import torch, torch.nn as nn, torch.nn.functional as F, torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models

import torchaudio
import torchaudio.transforms as T
import torchaudio.functional as AF

from sklearn.manifold import TSNE
import matplotlib.pyplot as plt


In [None]:
# SETTINGS
# Data locations. The defaults are the original local paths; each one can be
# overridden with an environment variable so the notebook can run on another
# machine without editing this cell.
RAW_VIDEOS   = os.environ.get("MM_RAW_VIDEOS", "D:/Downloads/audio/full clips")  # input videos
AUDIO_DIR    = os.environ.get("MM_AUDIO_DIR", "D:/Downloads/audio/aud")          # output 2s wavs
IMAGE_DIR    = os.environ.get("MM_IMAGE_DIR", "D:/Downloads/audio/img")          # output frame jpgs

# Seed every RNG up front so the train/val split, weight initialisation and
# shuffling are repeatable after Restart Kernel -> Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

SAMPLE_RATE  = 48000   # audio sample rate used for extraction and the mel transform
CHUNK_SECONDS= 2       # length of one audio segment in seconds
N_MELS       = 64      # number of mel bands in the spectrogram
BATCH_SIZE   = 20
EPOCHS       = 15
DEVICE       = "cuda" if torch.cuda.is_available() else "cpu"
# Print device info to confirm CUDA availability and details
print('Torch CUDA available:', torch.cuda.is_available())
print('Using device:', DEVICE)
if torch.cuda.is_available():
    # Diagnostics only: a failing query is reported but never stops the notebook.
    try:
        print('CUDA device count:', torch.cuda.device_count())
        if torch.cuda.device_count() > 0:
            try:
                cur = torch.cuda.current_device()
                print('CUDA current device index:', cur)
                print('CUDA device name:', torch.cuda.get_device_name(cur))
            except Exception as e:
                print('Could not query CUDA device name:', e)
    except Exception as e:
        print('CUDA query error:', e)

# Class name (first "_"-separated token of a file name, lower-cased) -> class index.
label_map = {"full": 0, "half": 1, "empty": 2}
# Inverse mapping with capitalised names, used for plot legends.
LABELS = {v: k.capitalize() for k, v in label_map.items()}

os.makedirs(AUDIO_DIR, exist_ok=True)
os.makedirs(IMAGE_DIR, exist_ok=True)


In [3]:
# Video -> audio/frame helpers (ffmpeg wrappers)
def extract_audio(video_path, wav_out, sample_rate=SAMPLE_RATE):
    """Write the audio of ``video_path`` to ``wav_out`` as a single-channel file.

    ffmpeg is invoked with ``-y`` (overwrite), ``-ar sample_rate`` and
    ``-ac 1``; its stdout/stderr are discarded.

    Raises:
        subprocess.CalledProcessError: ffmpeg exited with a non-zero status
            (``check=True``).
    """
    ffmpeg_args = ["ffmpeg", "-y", "-i", video_path]
    ffmpeg_args += ["-ar", str(sample_rate), "-ac", "1"]
    ffmpeg_args.append(wav_out)
    sink = subprocess.DEVNULL
    subprocess.run(ffmpeg_args, stdout=sink, stderr=sink, check=True)

def extract_frame(video_path, image_out, timestamp):
    """Save a single video frame taken at ``timestamp`` to ``image_out``.

    ``timestamp`` is formatted with millisecond precision and passed to
    ffmpeg's ``-ss``; exactly one frame is written (``-vframes 1``) and an
    existing output is overwritten (``-y``). ffmpeg's output is discarded.

    Raises:
        subprocess.CalledProcessError: ffmpeg exited with a non-zero status
            (``check=True``).
    """
    seek_position = f"{timestamp:.3f}"
    ffmpeg_args = [
        "ffmpeg", "-y",
        "-i", video_path,
        "-ss", seek_position,
        "-vframes", "1",
        image_out,
    ]
    subprocess.run(
        ffmpeg_args,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        check=True,
    )

def fast_split(video_path, out_audio_dir, out_frame_dir, chunk_seconds=None, sample_rate=None):
    """Split one video into fixed-length mono WAV segments plus one JPEG per segment.

    Outputs are named ``<video stem>_segNNN.wav`` and ``<video stem>_frameNNN.jpg``.
    Both sequences start at ``000`` so that segment N and frame N share an index
    (AudioImageDataset pairs them by that index).

    Args:
        video_path: Path of the input video.
        out_audio_dir: Directory that receives the WAV segments.
        out_frame_dir: Directory that receives the JPEG frames.
        chunk_seconds: Segment length in seconds; ``None`` uses ``CHUNK_SECONDS``.
        sample_rate: Output audio sample rate; ``None`` uses ``SAMPLE_RATE``.

    Raises:
        subprocess.CalledProcessError: either ffmpeg call exited non-zero.
    """
    if chunk_seconds is None:
        chunk_seconds = CHUNK_SECONDS
    if sample_rate is None:
        sample_rate = SAMPLE_RATE

    base = os.path.splitext(os.path.basename(video_path))[0]
    audio_pattern = os.path.join(out_audio_dir, f"{base}_seg%03d.wav")
    frame_pattern = os.path.join(out_frame_dir, f"{base}_frame%03d.jpg")

    # -y: overwrite outputs from an earlier run instead of stopping at ffmpeg's
    #     "File exists, overwrite?" prompt.
    # -vn: audio only; the segment muxer numbers its files from 0 by default.
    subprocess.run(
        ["ffmpeg", "-y", "-i", video_path, "-vn",
         "-f", "segment", "-segment_time", str(chunk_seconds),
         "-ar", str(sample_rate), "-ac", "1", "-c:a", "pcm_s16le", audio_pattern],
        check=True,
    )
    # One frame per chunk. The image2 muxer numbers files from 1 by default, which
    # would shift every frame one segment late and leave seg000 without a partner;
    # -start_number 0 aligns the two sequences.
    subprocess.run(
        ["ffmpeg", "-y", "-i", video_path,
         "-vf", f"fps=1/{chunk_seconds}", "-qscale:v", "2",
         "-start_number", "0", frame_pattern],
        check=True,
    )

def prepare_dataset_from_videos(raw_videos=RAW_VIDEOS):
    """Run ``fast_split`` on every ``.mov`` / ``.mp4`` file found in ``raw_videos``.

    The extension check is case-insensitive and every file is visited exactly
    once. (Globbing ``*.MOV`` and ``*.mov`` separately returns the same files
    twice on case-insensitive filesystems, which made each video be processed
    twice.) Files are processed in sorted order so runs are repeatable.

    Segments are written to ``AUDIO_DIR`` and frames to ``IMAGE_DIR``. A warning
    is printed when no video is found.
    """
    video_exts = {".mov", ".mp4"}
    candidates = set(glob.glob(os.path.join(raw_videos, "*")))
    video_files = sorted(
        path for path in candidates
        if os.path.isfile(path) and os.path.splitext(path)[1].lower() in video_exts
    )
    if not video_files:
        print(f"[WARN] No videos in {raw_videos}")
        return
    for vf in video_files:
        print("[PROCESSING]", vf)
        fast_split(vf, AUDIO_DIR, IMAGE_DIR)


In [4]:
# IMAGE TRANSFORM and audio cleaning constants
# Grayscale conversion that still emits 3 channels, so the output matches the
# 3-element mean/std given to Normalize below.
gray = transforms.Grayscale(num_output_channels=3)
# Frame preprocessing: resize to 224x224 -> grayscale (3 channels) -> tensor -> normalize.
# NOTE(review): the mean/std values look like the usual ImageNet statistics,
# presumably chosen to match the pretrained ResNet weights — confirm.
transform_image = transforms.Compose([
    transforms.Resize((224,224)),
    gray,
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

# Audio cleaning for a 2s chunk (band-pass + quiet-frame spectral subtraction)
SR        = SAMPLE_RATE  # rate clean_chunk resamples to before filtering
BAND_LO   = 12000  # lower band edge, compared against rfftfreq(NFFT, 1/sr) bin frequencies
BAND_HI   = 18000  # upper band edge; bins from searchsorted(BAND_HI) upwards are outside the band
NFFT      = 2048  # FFT size of the STFT; also used as the Hann window length
HOP       = 512  # STFT hop length in samples
OVERSUB   = 1.2  # multiplier applied to the noise PSD before it is subtracted
QUIET_PCT = 0.20  # fraction of frames (lowest in-band power) used to estimate the noise PSD

def _stft(x: torch.Tensor) -> torch.Tensor:
    if x.dim() == 2: x = x.squeeze(0)
    return torch.stft(x, n_fft=NFFT, hop_length=HOP, win_length=NFFT,
                      window=torch.hann_window(NFFT), return_complex=True, center=True)

def _istft(S: torch.Tensor, length: int) -> torch.Tensor:
    win = torch.hann_window(NFFT, device=S.device, dtype=torch.float32)
    try:
        y = torch.istft(S, n_fft=NFFT, hop_length=HOP, win_length=NFFT, window=win, length=length, center=True)
    except (TypeError, RuntimeError):
        y = torch.istft(torch.view_as_real(S), n_fft=NFFT, hop_length=HOP, win_length=NFFT, window=win, length=length, center=True)
    return y

def bandpass_chunk(waveform: torch.Tensor, sr: int) -> torch.Tensor:
    """Keep only the ``[BAND_LO, BAND_HI)`` band of ``waveform`` via STFT masking.

    The signal is transformed with an STFT, every frequency bin outside the
    band is set to zero, and the result is transformed back and mean-centred.

    Args:
        waveform: 1-D signal or a ``(1, samples)`` tensor (the pipeline always
            passes mono audio).
        sr: Sample rate of ``waveform``; used to map bins to frequencies.

    Returns:
        The filtered signal with a new leading dimension (``(1, samples)`` for
        mono input), as float32, with NaN/inf replaced by 0.
    """
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)
    x = waveform.squeeze(0).to(torch.float32)

    # Build the window on the input's device, consistent with _istft.
    window = torch.hann_window(NFFT, dtype=torch.float32, device=x.device)
    S = torch.stft(x, n_fft=NFFT, hop_length=HOP, win_length=NFFT,
                   window=window, return_complex=True, center=True)

    # Frequency bins are the second-to-last axis and samples the last axis;
    # indexing from the end keeps both correct if a channel axis is present.
    freqs = np.fft.rfftfreq(NFFT, d=1.0 / sr)
    lo = int(np.searchsorted(freqs, BAND_LO))
    hi = min(int(np.searchsorted(freqs, BAND_HI)), S.shape[-2])

    S_bp = torch.zeros_like(S)
    S_bp[..., lo:hi, :] = S[..., lo:hi, :]

    y = _istft(S_bp, length=x.shape[-1])
    y = y - y.mean()
    y = torch.nan_to_num(y, nan=0.0, posinf=0.0, neginf=0.0)
    return y.unsqueeze(0)

def spectral_subtract_quiet_frames(waveform: torch.Tensor, sr: int) -> torch.Tensor:
    """Spectral subtraction with the noise PSD estimated from the quietest frames.

    Frames are ranked by their mean power inside ``[BAND_LO, BAND_HI)``; the
    ``QUIET_PCT`` fraction with the lowest power (at least one frame) is
    averaged per frequency bin to form the noise PSD. ``OVERSUB`` times that
    PSD is subtracted from every frame's power (clamped at 0), the original
    phase is re-applied, and the signal is resynthesised and mean-centred.

    Args:
        waveform: Signal tensor; a leading dimension of size 1 is squeezed.
        sr: Sample rate, used to map STFT bins to frequencies.

    Returns:
        The cleaned signal with a new leading dimension, NaN/inf replaced by 0.
    """
    signal = waveform.squeeze(0)
    n_samples = signal.shape[-1]

    spec = _stft(signal)
    magnitude = spec.abs()
    power = magnitude ** 2

    # Rows of the spectrogram that fall inside the analysis band.
    bin_freqs = np.fft.rfftfreq(NFFT, d=1.0/sr)
    row_lo = max(int(np.searchsorted(bin_freqs, BAND_LO)), 0)
    row_hi = min(int(np.searchsorted(bin_freqs, BAND_HI)), magnitude.shape[0])

    # Pick the frames with the lowest in-band power (top-k of the negated power).
    frame_band_power = power[row_lo:row_hi].mean(dim=0)
    n_quiet = max(1, int(round(QUIET_PCT * frame_band_power.numel())))
    quiet_idx = torch.topk(-frame_band_power, n_quiet).indices
    is_quiet = torch.zeros_like(frame_band_power, dtype=torch.bool)
    is_quiet[quiet_idx] = True

    # Per-bin noise estimate and over-subtraction.
    noise_psd = power[:, is_quiet].mean(dim=1, keepdim=True)
    cleaned_power = torch.clamp(power - OVERSUB * noise_psd, min=0.0)
    cleaned_mag = torch.sqrt(cleaned_power + 1e-12)
    cleaned_spec = cleaned_mag * torch.exp(1j * spec.angle())

    out = _istft(cleaned_spec, length=n_samples)
    out = out - out.mean()
    out = torch.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0)
    return out.unsqueeze(0)

def clean_chunk(waveform: torch.Tensor, sr: int) -> torch.Tensor:
    """Full cleaning chain for one audio chunk.

    Steps: bring the input to shape ``(1, samples)`` (averaging channels of a
    multi-channel 2-D tensor, or adding a leading dimension to a 1-D tensor),
    resample to ``SR`` when ``sr`` differs, band-pass, then apply quiet-frame
    spectral subtraction.

    Returns:
        The output of ``spectral_subtract_quiet_frames`` on the band-passed signal.
    """
    n_dims = waveform.dim()
    if n_dims == 1:
        waveform = waveform.unsqueeze(0)
    elif n_dims == 2 and waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    if sr != SR:
        waveform = torchaudio.functional.resample(waveform, sr, SR)
        sr = SR

    band_limited = bandpass_chunk(waveform, sr)
    return spectral_subtract_quiet_frames(band_limited, sr)


In [5]:
# DATASET
class AudioImageDataset(Dataset):
    """Pairs each ``*_segNNN.wav`` audio segment with its ``*_frameNNN.jpg`` frame.

    A segment is included only when its frame file exists in ``image_dir`` and
    the first ``_``-separated token of its file name (lower-cased) is a key of
    ``label_map``. Files are scanned in sorted order so sample indices are the
    same on every run (``os.listdir`` order is arbitrary).

    Each item is ``(spec, img, label, seg_id)``:
        spec:   normalised dB mel-spectrogram resized to ``(1, 224, 224)``.
        img:    the frame as RGB, passed through ``transform_image`` if given.
        label:  integer class index from ``label_map``.
        seg_id: the audio segment's file name.
    """

    def __init__(self, audio_dir, image_dir, label_map, transform_image=None, sample_rate=SAMPLE_RATE, n_mels=N_MELS, use_filters=True):
        self.audio_dir = audio_dir
        self.image_dir = image_dir
        self.label_map = label_map
        self.transform_image = transform_image
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.use_filters = use_filters
        # NOTE: despite its name, video_ids holds the segment file names; it is
        # what __getitem__ returns as seg_id.
        self.audio_files, self.labels, self.video_ids = [], [], []
        for file in sorted(os.listdir(audio_dir)):
            if not (file.endswith(".wav") and "_seg" in file):
                continue
            frame_file = self._frame_name(file)
            if not os.path.exists(os.path.join(image_dir, frame_file)):
                continue
            base_name = file.split("_seg")[0]
            label_str = base_name.split("_")[0].lower()
            if label_str in label_map:
                self.audio_files.append(file)
                self.labels.append(label_map[label_str])
                self.video_ids.append(file)
        self.mel = T.MelSpectrogram(sample_rate=self.sample_rate, n_fft=1024, hop_length=512, n_mels=self.n_mels)
        self.db  = T.AmplitudeToDB()

    @staticmethod
    def _frame_name(audio_file):
        """Map ``<stem>_segNNN.wav`` to ``<stem>_frameNNN.jpg``.

        Only the last ``_seg`` marker and the trailing ``.wav`` are rewritten, so
        a video stem that itself contains "seg" or ".wav" is left untouched.
        """
        stem, seg_suffix = audio_file.rsplit("_seg", 1)
        return f"{stem}_frame{seg_suffix[:-len('.wav')]}.jpg"

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        audio_file = self.audio_files[idx]
        label      = self.labels[idx]
        seg_id     = self.video_ids[idx]

        a_path = os.path.join(self.audio_dir, audio_file)
        waveform, sr = torchaudio.load(a_path)
        if self.use_filters:
            # clean_chunk returns audio at the global SR.
            waveform = clean_chunk(waveform, sr)
            sr = SR
        # The mel transform was built for self.sample_rate; make the audio match.
        if sr != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)

        spec = self.mel(waveform)
        spec = self.db(spec)
        spec = (spec - spec.mean()) / (spec.std() + 1e-6)
        spec = F.interpolate(spec.unsqueeze(0), size=(224,224), mode="bilinear", align_corners=False)
        # Averages over the audio-channel axis -> (1, 224, 224).
        spec = spec.mean(dim=1)

        img_path = os.path.join(self.image_dir, self._frame_name(audio_file))
        # Context manager closes the file handle as soon as the pixels are copied.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        if self.transform_image:
            img = self.transform_image(img)
        return spec, img, label, seg_id


In [6]:
# DATA LOADERS + MODEL
def create_loaders(dataset, batch_size=BATCH_SIZE, num_workers=4, seed=42):
    """Split ``dataset`` 80/20 into train/validation sets and wrap both in DataLoaders.

    Args:
        dataset: Any map-style dataset.
        batch_size: Batch size for both loaders.
        num_workers: Worker processes per loader.
        seed: Seed for the split, so the same samples land in the validation
            set on every run. Pass ``None`` to use the global RNG state.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    train_size = int(0.8 * len(dataset))
    val_size   = len(dataset) - train_size

    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_ds, val_ds = random_split(dataset, [train_size, val_size], **split_kwargs)

    # Pinned host memory only helps host-to-GPU copies; skip it on CPU-only machines.
    use_pinned = torch.cuda.is_available()
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=use_pinned)
    val_loader   = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=use_pinned)
    print(f"Train samples: {len(train_ds)}, Val samples: {len(val_ds)}")
    return train_loader, val_loader

class MultiModalResNet(nn.Module):
    """Two ResNet-18 encoders (spectrogram + frame) fused by a linear classifier.

    The audio branch takes 1-channel input, the image branch 3-channel input.
    Both encoders have their final ``fc`` replaced by identity; their outputs
    are concatenated and classified by one linear layer of width ``512*2``.

    Args:
        num_classes: Number of output classes.
        pretrained_image: Load pretrained weights into the image encoder.
        pretrained_audio: Load pretrained weights into the audio encoder.
    """

    def __init__(self, num_classes=3, pretrained_image=True, pretrained_audio=True):
        super().__init__()
        self.audio_model = self._make_resnet18(pretrained_audio)
        rgb_stem = self.audio_model.conv1
        mono_stem = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if pretrained_audio:
            # Keep the pretrained stem instead of a random one: summing the RGB
            # kernels equals feeding the same plane to all three input channels.
            with torch.no_grad():
                mono_stem.weight.copy_(rgb_stem.weight.sum(dim=1, keepdim=True))
        self.audio_model.conv1 = mono_stem
        self.audio_model.fc = nn.Identity()

        self.image_model = self._make_resnet18(pretrained_image)
        self.image_model.fc = nn.Identity()

        self.fc = nn.Linear(512*2, num_classes)

    @staticmethod
    def _make_resnet18(pretrained):
        """Build a ResNet-18, using the ``weights=`` API when torchvision has it."""
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is None:
            # Older torchvision without the weights enum.
            return models.resnet18(pretrained=pretrained)
        return models.resnet18(weights=weights_enum.DEFAULT if pretrained else None)

    def forward(self, audio, image, return_features=False):
        """Return the logits, or ``(audio_feat, image_feat, fused, logits)`` when
        ``return_features`` is true."""
        a = self.audio_model(audio)
        i = self.image_model(image)
        fused = torch.cat([a, i], dim=1)
        out = self.fc(fused)
        if return_features:
            return a, i, fused, out
        return out


In [7]:
# FEATURE VISUALIZATION (t-SNE)
def visualize_features(model, dataloader, title_suffix=""):
    """Plot 2-D t-SNE embeddings of the audio, image and fused features.

    The model is put in eval mode and run over ``dataloader`` without gradients.
    One scatter plot per feature type is drawn, coloured by class label.
    Nothing is plotted (a message is printed instead) when fewer than three
    samples are available, because t-SNE needs its perplexity to be smaller
    than the number of samples.

    Args:
        model: Model whose forward accepts ``return_features=True``.
        dataloader: Yields ``(audio, img, labels, seg_ids)`` batches.
        title_suffix: Text appended to each subplot title.
    """
    model.eval()
    audio_feats, image_feats, fused_feats, labels_all = [], [], [], []
    with torch.no_grad():
        for audio, img, labels, _ in dataloader:
            audio, img = audio.to(DEVICE), img.to(DEVICE)
            a, i, fused, _ = model(audio, img, return_features=True)
            audio_feats.append(a.cpu()); image_feats.append(i.cpu()); fused_feats.append(fused.cpu())
            labels_all.append(labels.cpu())

    n = sum(batch.shape[0] for batch in labels_all)
    if n < 3:
        print(f"[WARN] Only {n} sample(s) available; skipping t-SNE plot {title_suffix}")
        return

    features = {
        "Audio": torch.cat(audio_feats).numpy(),
        "Image": torch.cat(image_feats).numpy(),
        "Fused": torch.cat(fused_feats).numpy(),
    }
    labels_all = torch.cat(labels_all).numpy()
    # tolist() gives plain ints, so the counter prints without numpy wrappers.
    print("Label counts:", Counter(labels_all.tolist()))

    # For n >= 3 this is always below n, as scikit-learn requires.
    perpl = min(30, max(2, n//3))
    tsne = TSNE(n_components=2, random_state=42, perplexity=perpl)

    fig, axes = plt.subplots(1,3, figsize=(18,6))
    for (name, feats), ax in zip(features.items(), axes):
        data2d = tsne.fit_transform(feats)
        for l in sorted(set(labels_all.tolist())):
            idx = labels_all == l
            ax.scatter(data2d[idx,0], data2d[idx,1], alpha=0.6, s=20, label=LABELS[l])
        ax.set_title(f"{name} Features {title_suffix}"); ax.legend()
    plt.tight_layout(); plt.show()
    # This runs once per epoch; release the figure so they do not pile up.
    plt.close(fig)

# TRAINING
def _run_epoch(model, loader, criterion, optimizer=None):
    """Run one pass over ``loader``: a training pass when ``optimizer`` is given,
    otherwise a gradient-free evaluation pass.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen (both 0.0 for an empty loader).
    """
    training = optimizer is not None
    model.train(training)
    loss_sum, correct, total = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for audio, img, labels, _ in loader:
            audio, img, labels = audio.to(DEVICE), img.to(DEVICE), labels.to(DEVICE)
            if training:
                optimizer.zero_grad()
            outputs = model(audio, img)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            # Weight the batch-mean loss by batch size so the epoch mean is per sample.
            loss_sum += loss.item() * labels.size(0)
            _, preds = outputs.max(1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    denom = max(1, total)
    return loss_sum / denom, correct / denom


def train():
    """Train the multimodal classifier on the extracted segments and save its weights.

    Builds the dataset from ``AUDIO_DIR`` / ``IMAGE_DIR``, splits it into train
    and validation loaders, trains for ``EPOCHS`` epochs with Adam (lr=1e-4) and
    cross-entropy, prints per-epoch metrics, plots t-SNE feature embeddings
    before training, after every epoch and after training, and finally writes
    the state dict to ``models/multimodal_model_battery_v3.pth``.

    Returns early with an error message when the dataset is empty.
    """
    dataset = AudioImageDataset(audio_dir=AUDIO_DIR, image_dir=IMAGE_DIR, label_map=label_map, transform_image=transform_image, sample_rate=SAMPLE_RATE, n_mels=N_MELS, use_filters=True)
    if len(dataset) == 0:
        print("[ERROR] No segments found. Did you run prepare_dataset_from_videos()?"); return
    train_loader, val_loader = create_loaders(dataset, batch_size=BATCH_SIZE)
    model = MultiModalResNet(num_classes=len(label_map)).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()

    print("Visualizing features BEFORE training...")
    visualize_features(model, val_loader, title_suffix="(Before)")

    for epoch in tqdm(range(EPOCHS)):
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, optimizer)
        val_loss, val_acc = _run_epoch(model, val_loader, criterion)
        print(f"Epoch [{epoch+1}/{EPOCHS}] "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
        visualize_features(model, val_loader, title_suffix=f"(Epoch {epoch+1})")

    print("Visualizing features AFTER training...")
    visualize_features(model, val_loader, title_suffix="(After)")
    os.makedirs("models", exist_ok=True)
    torch.save(model.state_dict(), os.path.join("models", "multimodal_model_battery_v3.pth"))
    print("Training complete and model saved!")



In [None]:
# MAIN
# Entry point: split every raw video into audio segments and frames, then train.
# NOTE(review): the extraction step runs again each time this cell is executed,
# even when the segments already exist on disk — confirm that is intended.
if __name__ == "__main__":
    prepare_dataset_from_videos(RAW_VIDEOS)
    train()

[PROCESSING] D:/Downloads/audio/full clips\0s_cropped-003_part_000.mov
[PROCESSING] D:/Downloads/audio/full clips\100s_cropped-004_part_000.mov


In [None]:
# QUICK SANITY / SHAPE TESTS and seeds
# `random` is imported in the import cell at the top of the notebook.
torch.manual_seed(42); np.random.seed(42); random.seed(42)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(42)
# Shape test for model forward pass: random weights, batch of 2, no gradients needed.
model = MultiModalResNet(num_classes=3, pretrained_image=False, pretrained_audio=False).to(DEVICE)
model.eval()
a = torch.randn(2,1,224,224).to(DEVICE)
i = torch.randn(2,3,224,224).to(DEVICE)
with torch.no_grad():
    out = model(a,i)
print('Forward output shape:', out.shape)
# Fail loudly if the architecture stops producing one logit per class.
assert tuple(out.shape) == (2, 3), f"unexpected output shape {tuple(out.shape)}"




Forward output shape: torch.Size([2, 3])
