# Setup

In [None]:

import glob
import multiprocessing as mp
import os
import re
import wave
from functools import partial

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.signal import resample
from sklearn.preprocessing import MinMaxScaler
from tqdm import tqdm

import variables as var

In [None]:
# Load the preprocessed metadata table (presumably produced by a separate
# preprocessing step). Later cells rely on its "Id", "Set", "NAWP" and "ECR" columns.
prep_df = pd.read_csv("prepped_df.csv")
prep_df

In [None]:
def find_matching_ids(directory):
    """List the IDs that have both a video (.npz) and an audio (.wav) file.

    Args:
        directory: Folder to scan (non-recursive, via ``os.listdir``).

    Returns:
        A sorted list of file names with their 4-character extension removed,
        keeping only names that exist with both extensions.
    """
    entries = os.listdir(directory)
    video_stems = {entry[:-4] for entry in entries if entry.endswith(".npz")}
    audio_stems = {entry[:-4] for entry in entries if entry.endswith(".wav")}
    # Only IDs present in both modalities are usable downstream.
    return sorted(video_stems.intersection(audio_stems))

# Keep only rows whose media files (.npz video + .wav audio) both exist on disk.
existed_ids = find_matching_ids(var.TO_PATH)
# The IDs found on disk are strings, while pd.read_csv infers int64 for an
# all-digit "Id" column; compare as strings so such IDs still match instead of
# silently filtering every row out.
has_media = prep_df["Id"].astype(str).isin(existed_ids)
prep_df = prep_df[has_media].reset_index(drop=True)
assert len(prep_df) > 0, f"No rows of prepped_df.csv have matching .npz/.wav files in {var.TO_PATH}"
len(prep_df)

In [None]:
# Peek at a few random rows. Cap the count at the table size (sample() raises
# when n exceeds the number of rows) and fix the seed so the preview is reproducible.
prep_df.sample(n=min(10, len(prep_df)), random_state=0)

# Check target distributions (NAWP, ECR)

In [None]:
# Overlay the density-normalized histograms of both candidate targets.
fig, ax = plt.subplots()
for col in ["NAWP", "ECR"]:
    # Drop missing values explicitly rather than relying on matplotlib's NaN handling.
    ax.hist(prep_df[col].dropna(), density=True, histtype="step", label=col)
ax.set(title="Target distributions", xlabel="Value", ylabel="Density")
ax.legend()
plt.show()

In [None]:
# Value ranges of both targets. The model below ends in a sigmoid, so the
# training target (ECR) presumably needs to lie within [0, 1] — check it here.
prep_df['NAWP'].min(), prep_df['NAWP'].max(), prep_df['ECR'].min(), prep_df['ECR'].max()

# Load data

In [None]:
def load_video_audio(video_file, audio_file):
    """Load one sample's video frames and audio track from disk.

    Args:
        video_file: Path to a ``.npz`` archive containing a ``"video"`` array.
        audio_file: Path to a 16-bit PCM ``.wav`` file.

    Returns:
        ``(video_array, audio_array, sample_rate)`` where ``video_array`` is the
        stored ``"video"`` array, ``audio_array`` is a 1-D ``int16`` array of raw
        samples (interleaved if the file has several channels), and
        ``sample_rate`` is the WAV frame rate in Hz.

    Raises:
        ValueError: If the WAV sample width is not 2 bytes, because the samples
            are decoded as ``int16``.
    """
    # Use the NpzFile as a context manager so its file handle is closed;
    # indexing it reads the array fully into memory before the file closes.
    with np.load(video_file) as video_data:
        video_array = video_data["video"]

    with wave.open(audio_file, 'r') as wf:
        sample_width = wf.getsampwidth()
        if sample_width != 2:
            raise ValueError(
                f"{audio_file}: expected 16-bit PCM audio, got {8 * sample_width}-bit samples"
            )
        sample_rate = wf.getframerate()
        audio_array = np.frombuffer(wf.readframes(wf.getnframes()), dtype=np.int16)

    return video_array, audio_array, sample_rate

# Smoke-test the loader on one random sample (seeded so the pick is reproducible).
sample_id = prep_df["Id"].sample(n=1, random_state=0).iloc[0]
print(f"Sample ID: {sample_id}")

video_loaded, audio_loaded, sr_loaded = load_video_audio(
    video_file=os.path.join(var.TO_PATH, f"{sample_id}.npz"),
    audio_file=os.path.join(var.TO_PATH, f"{sample_id}.wav"),
)
print(f"Loaded Video Shape: {video_loaded.shape}")  # (num_frames, height, width, 3)
print(f"Loaded Audio Shape: {audio_loaded.shape}, Sample Rate: {sr_loaded} Hz")

# NOTE(review): if the frames were written with OpenCV they are likely BGR,
# while matplotlib expects RGB — confirm the channel order of the stored arrays.
fig, ax = plt.subplots()
ax.imshow(video_loaded[0])
ax.set(title=f"{sample_id}: first frame")
ax.axis("off")
plt.show()

# Create dataset

In [None]:
import torch
import numpy as np
import wave
import os

class VideoAudioDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding one video/audio/label sample per metadata row.

    Each row's ``"Id"`` names a file pair ``<prep_path>/<Id>.npz`` and
    ``<prep_path>/<Id>.wav`` that is read lazily in ``__getitem__``. No padding
    or truncation happens here; variable lengths are left to the collate function.

    Args:
        dataframe: Metadata table with an ``"Id"`` column and the label column.
        prep_path: Directory containing the ``.npz`` / ``.wav`` files.
        max_frames: Target frame count; stored as an attribute, not used by this class.
        max_audio_samples: Target audio length; stored as an attribute, not used by this class.
        label_col: Column used as the regression target (default ``"ECR"``).
    """

    def __init__(self, dataframe, prep_path, max_frames, max_audio_samples, label_col="ECR"):
        self.dataframe = dataframe
        self.prep_path = prep_path
        self.label_col = label_col
        # .values yields positional arrays, so a non-contiguous index left over
        # from filtering the frame does not affect lookup by position.
        self.ids = dataframe["Id"].values
        self.labels = dataframe[label_col].values
        self.max_frames = max_frames
        self.max_audio_samples = max_audio_samples

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as a dict of tensors plus metadata."""
        sample_id = self.ids[idx]
        label = self.labels[idx]

        video_path = os.path.join(self.prep_path, f"{sample_id}.npz")
        audio_path = os.path.join(self.prep_path, f"{sample_id}.wav")

        video_array, audio_array, sample_rate = load_video_audio(video_path, audio_path)

        # Channel-last frames -> channel-first: (num_frames, C, H, W).
        video_tensor = torch.tensor(video_array, dtype=torch.float32).permute(0, 3, 1, 2)
        # NOTE(review): neither stream is rescaled here — the audio is the raw
        # int16 amplitude cast to float (roughly +/-32768); consider normalizing.
        audio_tensor = torch.tensor(audio_array, dtype=torch.float32)
        label_tensor = torch.tensor(label, dtype=torch.float32)

        return {
            "id": sample_id,
            "video": video_tensor,
            "audio": audio_tensor,
            "sample_rate": sample_rate,
            "label": label_tensor,
        }

def collate_fn(batch, max_frames, max_audio_samples):
    """Collate variable-length samples into fixed-size batch tensors.

    Each video is zero-padded at the end (or truncated) to ``max_frames``
    frames, and each audio track to ``max_audio_samples`` samples, so that they
    can be stacked.

    Args:
        batch: List of sample dicts with keys ``"id"``, ``"video"``
            (num_frames, C, H, W), ``"audio"`` (num_samples,), ``"sample_rate"``
            and ``"label"`` (0-d tensor).
        max_frames: Target number of frames per video.
        max_audio_samples: Target number of audio samples per clip.

    Returns:
        Dict with ``"video"`` (batch, max_frames, C, H, W), ``"audio"``
        (batch, max_audio_samples), ``"label"`` (batch,), and ``"id"`` /
        ``"sample_rate"`` as plain lists in batch order.
    """

    def _fit_length(tensor, target_len):
        # Zero-pad along dim 0 up to target_len, or truncate to it.
        # new_zeros matches the input's dtype and device, so padding never
        # changes the tensor type (torch.zeros would always be float32).
        missing = target_len - tensor.shape[0]
        if missing > 0:
            padding = tensor.new_zeros((missing, *tensor.shape[1:]))
            return torch.cat([tensor, padding], dim=0)
        return tensor[:target_len]

    videos = [_fit_length(sample["video"], max_frames) for sample in batch]
    audios = [_fit_length(sample["audio"], max_audio_samples) for sample in batch]

    return {
        "id": [sample["id"] for sample in batch],
        "video": torch.stack(videos),  # (batch_size, max_frames, C, H, W)
        "audio": torch.stack(audios),  # (batch_size, max_audio_samples)
        "sample_rate": [sample["sample_rate"] for sample in batch],
        "label": torch.stack([sample["label"] for sample in batch]),  # (batch_size,)
    }

# Demo DataLoader over all rows; padding/truncation to fixed sizes happens in collate_fn.
# functools.partial is used instead of a lambda because a partial can be pickled,
# which DataLoader may need for multi-worker loading.
dataloader = torch.utils.data.DataLoader(
    VideoAudioDataset(prep_df, var.TO_PATH, max_frames=var.TARGET_N_FRAME, max_audio_samples=var.TARGET_AUDIO_LENGTH),
    batch_size=2,
    shuffle=True,
    collate_fn=partial(collate_fn, max_frames=var.TARGET_N_FRAME, max_audio_samples=var.TARGET_AUDIO_LENGTH),
)

# Sample retrieval
sample_batch = next(iter(dataloader))
print(sample_batch["video"].shape)  # Expected: (batch_size, num_frames, C, H, W)
print(sample_batch["audio"].shape)  # Expected: (batch_size, audio_length)
print(sample_batch["label"].shape)  # Expected: (batch_size,)


# Define model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleVideoAudioModel(nn.Module):
    """Late-fusion regressor over a video clip and its audio waveform.

    A 3D-CNN branch encodes the video and a 1D-CNN branch encodes the raw audio.
    The two flattened feature vectors are concatenated, and a linear layer
    followed by a sigmoid produces one value per sample.

    Args:
        video_shape: Shape of a single video sample, ``(num_frames, C, H, W)``.
            ``C`` must be 3 (the first ``Conv3d`` has 3 input channels).
        audio_length: Number of samples in a single (padded) audio clip.

    Both sizes must match the inputs given to ``forward``, because the linear
    layer's input width is derived from them at construction time.
    """

    def __init__(self,
                 video_shape=(var.TARGET_N_FRAME, 3, var.TARGET_FRAME_SIZE[0], var.TARGET_FRAME_SIZE[1]),
                 audio_length=var.TARGET_AUDIO_LENGTH):
        super().__init__()

        # Video CNN (3D conv for spatiotemporal features). Convs keep all sizes
        # (kernel 3, padding 1, stride 1). Pooling is (1, 2, 2), so the time axis
        # keeps its length and only H, W are halved at each stage.
        self.video_cnn = nn.Sequential(
            nn.Conv3d(3, 16, kernel_size=(3, 3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool3d((1, 2, 2)),

            nn.Conv3d(16, 32, kernel_size=(3, 3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool3d((1, 2, 2)),

            nn.Conv3d(32, 64, kernel_size=(3, 3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool3d((1, 2, 2)),
        )

        # Audio 1D CNN over the raw waveform (each stage: stride-2 conv, then 2x pooling).
        self.audio_cnn = nn.Sequential(
            nn.Conv1d(1, 16, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
            nn.MaxPool1d(2),

            nn.Conv1d(16, 32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
            nn.MaxPool1d(2),

            nn.Conv1d(32, 64, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
            nn.MaxPool1d(2),
        )

        # Derive the flattened feature widths by pushing zero tensors of the
        # configured sizes through both branches once. This runs the full video
        # CNN on one full-size clip during construction.
        dummy_video = torch.zeros(1, *video_shape)
        dummy_audio = torch.zeros(1, audio_length)  # (batch, num_samples); audio_forward adds the channel axis

        with torch.no_grad():
            video_out = self.video_forward(dummy_video)
            audio_out = self.audio_forward(dummy_audio)

        feature_dim = video_out.shape[1] + audio_out.shape[1]

        # Fully connected layer on the concatenated features.
        self.fc = nn.Linear(feature_dim, 1)

        # Final squashing to (0, 1).
        # NOTE(review): targets outside [0, 1] can never be matched by this
        # output — confirm the label range.
        self.sigmoid = nn.Sigmoid()

    def video_forward(self, video):
        """Encode ``video`` of shape (batch, num_frames, C, H, W) into (batch, features)."""
        if video.dim() != 5:
            raise ValueError(f"expected video of shape (batch, num_frames, C, H, W), got {tuple(video.shape)}")
        # Conv3d expects (batch, C, depth, H, W): move channels ahead of time.
        video = video.permute(0, 2, 1, 3, 4)
        video_features = self.video_cnn(video)
        return video_features.flatten(start_dim=1)

    def audio_forward(self, audio):
        """Encode ``audio`` of shape (batch, num_samples) into (batch, features)."""
        if audio.dim() != 2:
            raise ValueError(f"expected audio of shape (batch, num_samples), got {tuple(audio.shape)}")
        # Conv1d expects (batch, channels, length); the waveform is single-channel.
        audio_features = self.audio_cnn(audio.unsqueeze(1))
        return audio_features.flatten(start_dim=1)

    def forward(self, video, audio):
        """Return one prediction in (0, 1) per sample, shape (batch, 1)."""
        video_features = self.video_forward(video)
        audio_features = self.audio_forward(audio)
        combined_features = torch.cat((video_features, audio_features), dim=1)
        logits = self.fc(combined_features)
        return self.sigmoid(logits)

# Split data

In [None]:
# Partition rows by the "Set" column into the training and test splits.
train_df, test_df = (prep_df.loc[prep_df["Set"].eq(split)] for split in ("train", "test"))

len(train_df), len(test_df)

# Training

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
DEVICE

In [None]:
# Training DataLoader: shuffled batches, each sample padded/truncated by collate_fn.
# functools.partial is used instead of a lambda because a partial can be pickled,
# which DataLoader may need for multi-worker loading.
train_dataloader = torch.utils.data.DataLoader(
    VideoAudioDataset(train_df, var.TO_PATH, max_frames=var.TARGET_N_FRAME, max_audio_samples=var.TARGET_AUDIO_LENGTH),
    batch_size=8,
    shuffle=True,
    collate_fn=partial(collate_fn, max_frames=var.TARGET_N_FRAME, max_audio_samples=var.TARGET_AUDIO_LENGTH),
)

In [None]:
# Create model and optimizer. Seed torch's default generator first so weight
# initialization (and DataLoader shuffling, which also draws from it when no
# generator is given) is reproducible across runs.
SEED = 42
LEARNING_RATE = 1e-4
torch.manual_seed(SEED)
model = SimpleVideoAudioModel()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Mean-squared-error regression against the dataset's label column.
criterion = nn.MSELoss()

In [None]:
# Example training loop: MSE between the sigmoid output and the label.
num_epochs = 5
model.to(DEVICE)
model.train()

for epoch in range(num_epochs):
    running_loss = 0.0
    num_batches = 0
    for bi, batch in enumerate(train_dataloader):
        video = batch["video"].to(DEVICE)  # (batch, num_frames, C, H, W)
        audio = batch["audio"].to(DEVICE)  # (batch, audio_length)
        labels = batch["label"].to(DEVICE)  # (batch,)

        optimizer.zero_grad()
        # squeeze(1) rather than squeeze(): a final batch of size 1 must stay
        # shape (1,) to match `labels` instead of collapsing to a 0-d tensor.
        outputs = model(video, audio).squeeze(1)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

        if bi % 10 == 0:
            print(labels)
            print(outputs)
            print(f"\tBatch {bi}, Loss: {loss.item():.4f}")

    # Report the mean batch loss over the epoch, not just the last batch's loss
    # (also avoids a NameError on `loss` if the loader yields no batches).
    epoch_loss = running_loss / num_batches if num_batches else float("nan")
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

# Evaluate

In [None]:
# Evaluation DataLoader over the test split. Files are read from var.TO_PATH,
# the directory prep_df was filtered against and the one the training loader uses.
# No shuffling: evaluation order does not affect the metrics, and a fixed order
# keeps per-sample outputs reproducible.
test_dataloader = torch.utils.data.DataLoader(
    VideoAudioDataset(test_df, var.TO_PATH, max_frames=var.TARGET_N_FRAME, max_audio_samples=var.TARGET_AUDIO_LENGTH),
    batch_size=4,
    shuffle=False,
    collate_fn=partial(collate_fn, max_frames=var.TARGET_N_FRAME, max_audio_samples=var.TARGET_AUDIO_LENGTH),
)