In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import WeightedRandomSampler

from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, recall_score

import os
import pandas as pd
import numpy as np
import time
import datetime

from scipy.io import loadmat
import glob
import gc
import h5py

# Common path prefix; HDF5Dataset prepends it to the file name it is given
# (e.g. "train_normalized.h5") to form the full path that is opened.
base_dir = "../storage/transform/edaicwoz_"

In [2]:
class HDF5Dataset(Dataset):
    """Map-style dataset backed by an HDF5 file with one group per participant.

    Every top-level key of the file is treated as one sample. Within a sample's
    group, each member other than ``'info'`` becomes a feature tensor, and
    element 1 of the ``'info'`` member becomes the label.
    """

    def __init__(self, h5_path):
        """Resolve the file path and cache the list of participant keys.

        Args:
            h5_path: File name suffix; the module-level ``base_dir`` prefix is
                prepended to it to obtain the path that is opened.
        """
        self.h5_path = base_dir + h5_path
        with h5py.File(self.h5_path, mode='r', libver='latest', swmr=True) as handle:
            self.participants = list(handle.keys())

    def __len__(self):
        """Return the number of participant groups found at construction time."""
        return len(self.participants)

    def __getitem__(self, idx):
        """Load one participant and return ``(features, label)``.

        The file is reopened for every access, so no handle is kept on the
        instance between calls.

        Args:
            idx: Position in ``self.participants``.

        Returns:
            A dict mapping member name to tensor (``'info'`` excluded) and the
            label tensor built from ``group['info'][1]``.
        """
        with h5py.File(self.h5_path, mode='r', libver='latest', swmr=True) as handle:
            group = handle[self.participants[idx]]
            features = {}
            for name in group.keys():
                if name == 'info':
                    continue
                # `[:]` reads the whole member into memory before wrapping it.
                features[name] = torch.from_numpy(group[name][:])
            label = torch.tensor(group['info'][1][()])
        return features, label

In [3]:
def collate_fn(batch):
    """Pad variable-length feature sequences and stack labels into one batch.

    Args:
        batch: Iterable of ``(features, label)`` pairs where ``features`` is a
            dict of tensors sharing the same keys across samples and ``label``
            is a tensor.

    Returns:
        ``(padded_batch, labels)``. For every key ``k`` of the first sample's
        feature dict, ``padded_batch[k]`` is the zero-padded, batch-first
        tensor and ``padded_batch[f'{k}_mask']`` is a boolean tensor that is
        ``True`` at real timesteps and ``False`` at padding. ``labels`` is
        ``torch.stack`` of the individual labels.
    """
    feature_dicts, label_list = zip(*batch)

    collated = {}
    for name in feature_dicts[0].keys():
        per_sample = [sample[name] for sample in feature_dicts]
        stacked = pad_sequence(per_sample, batch_first=True, padding_value=0)
        collated[name] = stacked

        # Start from an all-False (batch, time) grid and switch on the leading
        # `len(seq)` positions of every row.
        valid = torch.zeros_like(stacked[:, :, 0], dtype=torch.bool)
        for row, seq in enumerate(per_sample):
            valid[row, :len(seq)] = True
        collated[f'{name}_mask'] = valid

    return collated, torch.stack(label_list)

In [4]:
class PTSDTransformer(nn.Module):
    """Two-stream transformer classifier over audio and video sequences.

    Each modality is linearly projected to ``d_model``, encoded by its own
    two-layer ``nn.TransformerEncoder``, pooled over time, and the two pooled
    vectors are concatenated and mapped to two logits.

    Args:
        audio_dim: Size of the last dimension of the audio input.
        video_dim: Size of the last dimension of the video input.
        d_model: Shared embedding width of both encoders.
    """

    def __init__(self, audio_dim=24, video_dim=2048, d_model=128):
        super().__init__()
        self.egemaps_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead=8, dim_feedforward=512, batch_first=True),
            num_layers=2
        )

        self.video_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead=16, dim_feedforward=1024, batch_first=True),
            num_layers=2
        )

        self.audio_proj = nn.Linear(audio_dim, d_model)
        self.video_proj = nn.Linear(video_dim, d_model)

        self.classifier = nn.Linear(2*d_model, 2)

    @staticmethod
    def _masked_mean(sequence, padding_mask=None):
        """Average ``sequence`` over dim 1, skipping positions flagged as padding.

        Args:
            sequence: Tensor of shape ``(batch, time, features)``.
            padding_mask: Optional boolean ``(batch, time)`` tensor using the
                ``src_key_padding_mask`` convention (``True`` = padding).

        Returns:
            Tensor of shape ``(batch, features)``.
        """
        if padding_mask is None:
            return sequence.mean(dim=1)
        padding = padding_mask.unsqueeze(-1)
        # Zero out padded timesteps so they cannot leak into the sum.
        summed = sequence.masked_fill(padding, 0.0).sum(dim=1)
        # Clamp so a fully padded row divides by 1 instead of 0.
        n_valid = (~padding).sum(dim=1).clamp(min=1).to(sequence.dtype)
        return summed / n_valid

    def forward(self, audio, video, audio_mask=None, video_mask=None):
        """Compute class logits for a batch.

        Args:
            audio: ``(batch, time_a, audio_dim)`` tensor.
            video: ``(batch, time_v, video_dim)`` tensor.
            audio_mask: Optional boolean ``(batch, time_a)`` padding mask,
                ``True`` where the timestep is padding. It is passed to the
                encoder as ``src_key_padding_mask`` and also used for pooling.
            video_mask: Same as ``audio_mask`` for the video stream.

        Returns:
            ``(batch, 2)`` tensor of unnormalised logits.
        """
        # audio pathway
        audio = self.audio_proj(audio)
        audio = self.egemaps_transformer(audio, src_key_padding_mask=audio_mask)
        # Pool over real timesteps only; a plain mean would shrink the pooled
        # vector by the amount of padding the batch happened to need.
        audio_pooled = self._masked_mean(audio, audio_mask)

        # video pathway
        video = self.video_proj(video)
        video = self.video_transformer(video, src_key_padding_mask=video_mask)
        video_pooled = self._masked_mean(video, video_mask)

        # fusion
        fused = torch.cat([audio_pooled, video_pooled], dim=-1)
        return self.classifier(fused)

In [5]:
def create_dataloader(h5_path, batch_size=32, pin_memory=False, num_workers=0, shuffle=True):
    """Wrap an ``HDF5Dataset`` in a ``DataLoader`` that uses ``collate_fn``.

    Args:
        h5_path: File name forwarded to ``HDF5Dataset``.
        batch_size: Number of samples per batch.
        pin_memory: Forwarded to ``DataLoader``.
        num_workers: Forwarded to ``DataLoader``.
        shuffle: Whether the loader reshuffles the samples.

    Returns:
        The configured ``DataLoader``.
    """
    loader_options = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    return DataLoader(HDF5Dataset(h5_path), **loader_options)

In [11]:
def create_balanced_dataloader(h5_path, batch_size, collate_fn, pin_memory=False, num_workers=0):
    """Build a ``DataLoader`` that draws samples with inverse-class-frequency weights.

    Every sample gets the weight ``1 / count(its class)`` and a
    ``WeightedRandomSampler`` draws ``len(dataset)`` samples per epoch with
    replacement.

    Args:
        h5_path: File name forwarded to ``HDF5Dataset``.
        batch_size: Number of samples per batch.
        collate_fn: Batch collation callable handed to ``DataLoader``.
        pin_memory: Forwarded to ``DataLoader``.
        num_workers: Forwarded to ``DataLoader``.

    Returns:
        The configured ``DataLoader``.

    Raises:
        ValueError: If the dataset contains no samples.
    """
    dataset = HDF5Dataset(h5_path)

    # Read only the label of every participant. Going through `dataset[i]`
    # would also load every feature array from disk just to discard it.
    # The label expression mirrors HDF5Dataset.__getitem__.
    with h5py.File(dataset.h5_path, 'r', libver='latest', swmr=True) as f:
        all_labels = [
            torch.tensor(f[participant]['info'][1][()]).item()
            for participant in dataset.participants
        ]

    if not all_labels:
        raise ValueError(f"No samples found in {dataset.h5_path!r}; cannot build a balanced sampler.")

    label_tensor = torch.tensor(all_labels)
    class_counts = torch.bincount(label_tensor)

    # Classes absent from the data get an infinite weight here, but no sample
    # ever indexes them, so they never reach the sampler.
    weights_per_class = 1.0 / class_counts.float()

    # One indexing operation instead of a Python list of 0-dim tensors.
    sample_weights = weights_per_class[label_tensor].tolist()

    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        num_workers=num_workers,
    )

In [7]:
def evaluate(model, loader, device):
    """Run ``model`` over ``loader`` without gradients and compute metrics.

    The model is switched to eval mode and left in that mode on return.

    Args:
        model: Module called as ``model(audio, video, audio_mask=..., video_mask=...)``
            and returning ``(batch, 2)`` logits.
        loader: Iterable of ``(features, labels)`` batches where ``features``
            holds ``'audio_egemaps'``, ``'visual_resnet'`` and their
            ``'<name>_mask'`` companions (``True`` = real timestep).
        device: Device the batch tensors are moved to.

    Returns:
        Dict with ``'loss'`` (cross-entropy averaged over samples),
        ``'accuracy'``, ``'f1'``, ``'recall'`` and ``'auc'``. ``'auc'`` is NaN
        when ``roc_auc_score`` raises ``ValueError``.
    """
    modalities = ('audio_egemaps', 'visual_resnet')

    model.eval()
    all_labels = []
    all_preds = []
    all_probs = []
    total_loss = 0.0
    n_samples = 0
    criterion = nn.CrossEntropyLoss()

    with torch.no_grad():
        for features, labels in loader:
            inputs = {
                mod: features[mod].float().to(device, non_blocking=True)
                for mod in modalities
            }
            masks = {
                mod: features[f'{mod}_mask'].to(device, non_blocking=True)
                for mod in modalities
            }
            labels = labels.to(device, non_blocking=True)

            # The collated masks mark real timesteps; the model expects the
            # inverse (True = padding), hence the negation.
            outputs = model(
                inputs['audio_egemaps'],
                inputs['visual_resnet'],
                audio_mask=~masks['audio_egemaps'],
                video_mask=~masks['visual_resnet']
            )

            # `criterion` returns the batch mean; scale it back to a sum so a
            # smaller final batch does not get the same weight as a full one.
            batch_size = labels.size(0)
            loss = criterion(outputs, labels)
            total_loss += loss.item() * batch_size
            n_samples += batch_size

            probs = torch.softmax(outputs, dim=1)[:, 1].cpu().numpy()
            preds = torch.argmax(outputs, dim=1).cpu().numpy()

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds)
            all_probs.extend(probs)

    metrics = {
        'loss': total_loss / n_samples,
        'accuracy': accuracy_score(all_labels, all_preds),
        'f1': f1_score(all_labels, all_preds),
        'recall': recall_score(all_labels, all_preds),
    }

    try:
        metrics['auc'] = roc_auc_score(all_labels, all_probs)
    except ValueError:
        # Raised by roc_auc_score; reported as NaN instead of aborting.
        metrics['auc'] = float('nan')

    return metrics

In [8]:
def train_model(model, train_loader, val_loader, test_loader, optimizer, scheduler, criterion, num_epochs, device):
    """Train ``model``, track metrics per epoch, and evaluate the best checkpoint.

    After every epoch the model is evaluated on the train and validation
    loaders, the scheduler is stepped with the validation loss, and the weights
    are written to ``'best_model.pth'`` whenever the validation loss improves.
    When training ends, the best checkpoint saved *during this call* is
    reloaded and evaluated on the test loader; then the metric curves are
    plotted.

    Args:
        model: Module called as ``model(audio, video, audio_mask=..., video_mask=...)``.
        train_loader: Loader of ``(features, labels)`` training batches.
        val_loader: Loader used for model selection and scheduler stepping.
        test_loader: Loader evaluated once at the end.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Scheduler whose ``step`` accepts the validation loss.
        criterion: Loss called as ``criterion(outputs, labels)``.
        num_epochs: Number of passes over ``train_loader``.
        device: Device the batch tensors are moved to.

    Returns:
        ``history`` dict with per-epoch metric lists under ``'train'`` and
        ``'val'`` and the final metric dict under ``'test'``.
    """
    modalities = ('audio_egemaps', 'visual_resnet')
    metric_names = ('loss', 'accuracy', 'f1', 'recall', 'auc')
    checkpoint_path = 'best_model.pth'

    best_val_loss = float('inf')
    # Tracks whether *this* run wrote the checkpoint, so a leftover file from
    # an earlier run is never mistaken for the current best model.
    checkpoint_saved = False
    history = {
        'train': {name: [] for name in metric_names},
        'val': {name: [] for name in metric_names},
        'test': {name: None for name in metric_names}
    }

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for batch_idx, (features, labels) in enumerate(train_loader):
            inputs = {
                mod: features[mod].float().to(device, non_blocking=True)
                for mod in modalities
            }
            masks = {
                mod: features[f'{mod}_mask'].to(device, non_blocking=True)
                for mod in modalities
            }
            labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad()
            # The collated masks mark real timesteps; the model expects the
            # inverse (True = padding), hence the negation.
            outputs = model(
                inputs['audio_egemaps'],
                inputs['visual_resnet'],
                audio_mask=~masks['audio_egemaps'],
                video_mask=~masks['visual_resnet']
            )

            loss = criterion(outputs, labels)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            train_loss += loss.item()

            print(f"batch idx: {batch_idx} done; train loss: {loss.item()}; avg train loss: {train_loss/(batch_idx+1)}")

        # `evaluate` puts the model in eval mode; `model.train()` at the top of
        # the loop switches it back for the next epoch.
        train_metrics = evaluate(model, train_loader, device)
        val_metrics = evaluate(model, val_loader, device)

        scheduler.step(val_metrics['loss'])
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Current learning rate: {current_lr}")

        for metric in metric_names:
            history['train'][metric].append(train_metrics[metric])
            history['val'][metric].append(val_metrics[metric])

        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True

        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"Train - Loss: {train_metrics['loss']:.4f}, Acc: {train_metrics['accuracy']:.4f}, " +
              f"F1: {train_metrics['f1']:.4f}, Recall: {train_metrics['recall']:.4f}, AUC: {train_metrics['auc']:.4f}")
        print(f"Val   - Loss: {val_metrics['loss']:.4f}, Acc: {val_metrics['accuracy']:.4f}, " +
              f"F1: {val_metrics['f1']:.4f}, Recall: {val_metrics['recall']:.4f}, AUC: {val_metrics['auc']:.4f}")

    if checkpoint_saved:
        # map_location keeps the reload working when the checkpoint was written
        # from a different device than the one in use now.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    else:
        # Happens with num_epochs == 0 or when the validation loss never
        # dropped below +inf (e.g. it was NaN every epoch).
        print("No checkpoint was saved during this run; evaluating the current weights.")
    test_metrics = evaluate(model, test_loader, device)
    history['test'] = test_metrics

    print("\nTest Results:")
    print(f"Loss: {test_metrics['loss']:.4f}, Acc: {test_metrics['accuracy']:.4f}, " +
          f"F1: {test_metrics['f1']:.4f}, Recall: {test_metrics['recall']:.4f}, AUC: {test_metrics['auc']:.4f}")

    plot_metrics(history['train'], history['val'], history['test'])

    return history

In [9]:
def plot_metrics(train_metrics, val_metrics, test_metrics):
    """Plot train/validation curves per metric, with the test value as a line.

    One subplot is drawn for each of ``loss``, ``accuracy``, ``f1``,
    ``recall`` and ``auc`` on a 2x3 grid.

    Args:
        train_metrics: Dict mapping metric name to a per-epoch list.
        val_metrics: Dict with the same layout as ``train_metrics``.
        test_metrics: Dict mapping metric name to a single value; a ``None``
            value suppresses the horizontal test line for that metric.
    """
    # The notebook's import cell does not bring `plt` into scope, so without
    # this import the call at the end of training fails with NameError.
    import matplotlib.pyplot as plt

    metric_names = ['loss', 'accuracy', 'f1', 'recall', 'auc']
    epochs = range(1, len(train_metrics['loss']) + 1)

    fig = plt.figure(figsize=(20, 12))

    for position, metric in enumerate(metric_names, 1):
        ax = fig.add_subplot(2, 3, position)
        ax.plot(epochs, train_metrics[metric], label='Train')
        ax.plot(epochs, val_metrics[metric], label='Validation')

        if test_metrics[metric] is not None:
            ax.axhline(y=test_metrics[metric], color='r', linestyle='--', label='Test')

        ax.set_title(metric.capitalize())
        ax.set_xlabel('Epochs')
        ax.legend()

    fig.tight_layout()
    plt.show()

In [12]:
# Run configuration. Tunable values are named here instead of being scattered
# through the calls below.
SEED = 42
BATCH_SIZE = 1
LEARNING_RATE = 0.0005
NUM_EPOCHS = 10

# Weight initialisation and the weighted sampler both draw random numbers;
# seeding makes a Restart & Run All reproduce the same run.
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PTSDTransformer().to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Cuts the learning rate by 10x once the validation loss has stopped improving
# for more than `patience` epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=2
)
criterion = nn.CrossEntropyLoss()

# Training draws class-balanced samples; dev/test keep their natural order.
train_loader = create_balanced_dataloader("train_normalized.h5", batch_size=BATCH_SIZE, collate_fn=collate_fn)
val_loader = create_dataloader("dev_normalized.h5", batch_size=BATCH_SIZE, shuffle=False)
test_loader = create_dataloader("test_normalized.h5", batch_size=BATCH_SIZE, shuffle=False)

history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion,
    num_epochs=NUM_EPOCHS,
    device=device
)

batch idx: 0 done; train loss: 0.6348230242729187; avg train loss: 0.6348230242729187
batch idx: 1 done; train loss: 0.9704708456993103; avg train loss: 0.8026469349861145
batch idx: 2 done; train loss: 0.08553390204906464; avg train loss: 0.5636092573404312
batch idx: 3 done; train loss: 0.013515986502170563; avg train loss: 0.42608593963086605
batch idx: 4 done; train loss: 4.5282063484191895; avg train loss: 1.2465100213885307
batch idx: 5 done; train loss: 0.4734707176685333; avg train loss: 1.1176701374351978
batch idx: 6 done; train loss: 0.1527927815914154; avg train loss: 0.9798305151718003
batch idx: 7 done; train loss: 5.227176189422607; avg train loss: 1.5107487244531512
batch idx: 8 done; train loss: 0.06630849093198776; avg train loss: 1.350255365173022
batch idx: 9 done; train loss: 3.452159881591797; avg train loss: 1.5604458168148994
batch idx: 10 done; train loss: 4.129878520965576; avg train loss: 1.7940306081013246
batch idx: 11 done; train loss: 2.0784449577331543; 

  output = torch._nested_tensor_from_mask(


Current learning rate: 0.0005

Epoch 1/10
Train - Loss: 0.7379, Acc: 0.5951, F1: 0.6887, Recall: 0.8588, AUC: 0.6605
Val   - Loss: 0.9711, Acc: 0.4727, F1: 0.5085, Recall: 0.8824, AUC: 0.6037
batch idx: 0 done; train loss: 0.4728243947029114; avg train loss: 0.4728243947029114
batch idx: 1 done; train loss: 0.22106553614139557; avg train loss: 0.3469449654221535
batch idx: 2 done; train loss: 0.16680647432804108; avg train loss: 0.286898801724116
batch idx: 3 done; train loss: 1.3201426267623901; avg train loss: 0.5452097579836845
batch idx: 4 done; train loss: 2.14467191696167; avg train loss: 0.8651021897792817
batch idx: 5 done; train loss: 0.04506317153573036; avg train loss: 0.728429020072023
batch idx: 6 done; train loss: 2.852041244506836; avg train loss: 1.0318021949912821
batch idx: 7 done; train loss: 0.11675974726676941; avg train loss: 0.917421889025718
batch idx: 8 done; train loss: 0.09845616668462753; avg train loss: 0.8264256976544857
batch idx: 9 done; train loss: 0.01


KeyboardInterrupt



In [35]:
# Run a full garbage-collection pass; the displayed value is the number of
# unreachable objects the collector found.
gc.collect()

0

In [13]:
# Ask PyTorch's CUDA caching allocator to release memory blocks it holds but
# is not currently using.
torch.cuda.empty_cache()