In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import WeightedRandomSampler

from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, recall_score

import os
import pandas as pd
import numpy as np
import time
import datetime

from scipy.io import loadmat
import glob
import gc
import h5py

import matplotlib.pyplot as plt

# Prefix shared by every split file; HDF5Dataset appends the split file name
# (e.g. "train_normalized.h5") to it.
base_dir = "../storage/transform/edaicwoz_"

# Modality keys read from each participant group, index-aligned with the width of
# one time step of that modality: mods_name[i] has mods_size[i] features per step.
mods_name = [
    'audio_densenet',
    'visual_resnet',
    'audio_vgg16',
]
mods_size = [
    1920,
    2048,
    4096,
]

In [2]:
class HDF5Dataset(Dataset):
    """Map-style dataset over an HDF5 file with one top-level group per participant.

    ``h5_path`` is relative to ``base_dir``. Item ``idx`` is a pair
    ``(features, label)``:

    - ``features`` maps every dataset name in the participant's group, except
      ``info``, to a tensor built from its full contents.
    - ``label`` is a tensor of ``info[1]``.

    The file is opened once here to list the groups and re-opened (read-only,
    SWMR) on every ``__getitem__`` call, so no handle is held between accesses.
    """

    def __init__(self, h5_path):
        full_path = base_dir + h5_path
        self.h5_path = full_path
        with h5py.File(full_path, 'r', libver='latest', swmr=True) as h5:
            self.participants = [name for name in h5.keys()]

    def __len__(self):
        return len(self.participants)

    def __getitem__(self, idx):
        with h5py.File(self.h5_path, 'r', libver='latest', swmr=True) as h5:
            group = h5[self.participants[idx]]
            features = {}
            for key in group.keys():
                if key == 'info':
                    continue
                # [:] materialises the dataset as an in-memory numpy array.
                features[key] = torch.from_numpy(group[key][:])
            label = torch.tensor(group['info'][1][()])
        return features, label

In [3]:
def collate_fn(batch):
    """Pad a list of ``(features_dict, label)`` samples into one batch.

    For every modality key in the first sample's feature dict, the per-sample
    sequences are right-padded with zeros along dim 0 and stacked batch-first
    into ``out[key]``. ``out[f'{key}_mask']`` is a bool tensor of shape
    ``(batch, max_len)`` that is True for real steps and False for padding.
    Sequences may have any rank >= 1; only their lengths (dim 0) matter for the mask.

    Returns:
        (out, labels): the padded dict and ``torch.stack`` of the sample labels.
    """
    features, labels = zip(*batch)

    padded_batch = {}
    for modality in features[0].keys():
        sequences = [item[modality] for item in features]
        padded = pad_sequence(sequences, batch_first=True, padding_value=0)
        padded_batch[modality] = padded

        # Build the mask from lengths instead of slicing padded[:, :, 0], so 1-D
        # per-step features (padded rank 2) work and the mask is bool from the start.
        lengths = torch.tensor([len(seq) for seq in sequences], device=padded.device)
        positions = torch.arange(padded.size(1), device=padded.device)
        padded_batch[f'{modality}_mask'] = positions.unsqueeze(0) < lengths.unsqueeze(1)

    return padded_batch, torch.stack(labels)

In [4]:
class PTSDTransformer(nn.Module):
    def __init__(self, d_model=128):
        super().__init__()
        self.audio_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead=8, dim_feedforward=64, batch_first=True),
            num_layers=2
        )
        
        self.audio_transformer2 = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead=8, dim_feedforward=64, batch_first=True),
            num_layers=2
        )
        
        self.video_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead=8, dim_feedforward=64, batch_first=True),
            num_layers=2
        )
        
        self.audio_proj = nn.Linear(mods_size[0], d_model)
        self.video_proj = nn.Linear(mods_size[1], d_model)
        
        self.audio_proj2 = nn.Linear(mods_size[2], d_model)
        
        self.classifier = nn.Linear(len(mods_size)*d_model, 2)

    def forward(self, audio, video, audio2, audio_mask=None, video_mask=None, audio_mask2=None):
        # audio pathway
        audio = self.audio_proj(audio)
        audio = self.audio_transformer(audio, src_key_padding_mask=audio_mask)
        audio_pooled = audio.mean(dim=1)
        
        # video pathway
        video = self.video_proj(video)
        video = self.video_transformer(video, src_key_padding_mask=video_mask)
        video_pooled = video.mean(dim=1)
                                     
        audio2 = self.audio_proj2(audio2)
        audio2 = self.audio_transformer2(audio2, src_key_padding_mask=audio_mask2)
        audio_pooled2 = audio2.mean(dim=1)
        
        # fusion
        fused = torch.cat([audio_pooled, video_pooled, audio_pooled2], dim=-1)
        return self.classifier(fused)

In [5]:
def create_dataloader(h5_path, batch_size=32, pin_memory=False, num_workers=0, shuffle=True):
    """Wrap ``HDF5Dataset(h5_path)`` in a DataLoader that pads batches with ``collate_fn``.

    ``h5_path`` is relative to ``base_dir``. Every other argument is forwarded
    unchanged to ``torch.utils.data.DataLoader``.
    """
    loader_options = {
        'batch_size': batch_size,
        'collate_fn': collate_fn,
        'shuffle': shuffle,
        'pin_memory': pin_memory,
        'num_workers': num_workers,
    }
    return DataLoader(HDF5Dataset(h5_path), **loader_options)

In [6]:
def create_balanced_dataloader(h5_path, batch_size, collate_fn, pin_memory=False, num_workers=0):
    """Build a DataLoader whose sampler draws every class with equal total probability.

    Each sample is weighted by ``1 / count(its class)``. A ``WeightedRandomSampler``
    then draws ``len(dataset)`` indices with replacement, so minority-class participants
    are oversampled and majority-class ones may be skipped within an epoch.

    Args:
        h5_path: split file relative to ``base_dir`` (passed to ``HDF5Dataset``).
        batch_size, collate_fn, pin_memory, num_workers: forwarded to ``DataLoader``.

    Raises:
        ValueError: if the file contains no participants.
    """
    dataset = HDF5Dataset(h5_path)

    print("[Start] create balanced dataset")

    # Read only the label (info[1]) per participant. Going through dataset[i] would
    # also load every audio/video feature array just to throw it away.
    with h5py.File(dataset.h5_path, 'r', libver='latest', swmr=True) as f:
        all_labels = torch.tensor(
            [int(f[p]['info'][1][()]) for p in dataset.participants],
            dtype=torch.long,
        )

    if all_labels.numel() == 0:
        raise ValueError(f"no participants found in {dataset.h5_path}")

    class_counts = torch.bincount(all_labels)

    # A label value absent from the data gets count 0 -> weight inf, but that entry is
    # never indexed because no sample carries that label.
    weights_per_class = 1.0 / class_counts.float()

    sample_weights = weights_per_class[all_labels]

    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )

    print("[End] create balanced dataset")

    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        num_workers=num_workers,
    )

In [7]:
def evaluate(model, loader, device):
    """Run ``model`` over ``loader`` without gradients and return classification metrics.

    Each batch's modalities (``mods_name``) are cast to float and moved to ``device``.
    The collated masks (True = real step) are inverted before being passed as
    padding masks.

    Returns:
        dict with keys:
        - ``loss``: per-sample mean cross-entropy.
        - ``accuracy``, ``f1``, ``recall``: binary, positive class 1. f1 and recall are
          0 when undefined.
        - ``auc``: ROC-AUC of the class-1 softmax probability, or NaN when
          ``roc_auc_score`` raises ``ValueError`` (e.g. only one class in the labels).

    Leaves the model in eval mode.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    all_labels = []
    all_preds = []
    all_probs = []
    total_loss = 0.0
    n_samples = 0
    # Sum per batch and divide by the sample count at the end; a mean of batch means
    # over-weights a short final batch.
    criterion = nn.CrossEntropyLoss(reduction='sum')

    mods = mods_name

    with torch.no_grad():
        for features, labels in loader:
            inputs = {
                mod: features[mod].float().to(device, non_blocking=True)
                for mod in mods
            }
            masks = {
                mod: features[f'{mod}_mask'].to(device, non_blocking=True)
                for mod in mods
            }
            labels = labels.to(device, non_blocking=True)

            outputs = model(
                inputs[mods[0]],
                inputs[mods[1]],
                inputs[mods[2]],
                audio_mask=~masks[mods[0]],
                video_mask=~masks[mods[1]],
                audio_mask2=~masks[mods[2]]
            )

            total_loss += criterion(outputs, labels).item()
            n_samples += labels.size(0)

            probs = torch.softmax(outputs, dim=1)[:, 1].cpu().numpy()
            preds = torch.argmax(outputs, dim=1).cpu().numpy()

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds)
            all_probs.extend(probs)

    if n_samples == 0:
        raise ValueError("evaluate() received a loader that yielded no samples")

    metrics = {
        'loss': total_loss / n_samples,
        'accuracy': accuracy_score(all_labels, all_preds),
        # zero_division=0 keeps the previous value (0) when there are no positive
        # predictions/labels, without emitting UndefinedMetricWarning every epoch.
        'f1': f1_score(all_labels, all_preds, zero_division=0),
        'recall': recall_score(all_labels, all_preds, zero_division=0),
    }

    try:
        metrics['auc'] = roc_auc_score(all_labels, all_probs)
    except ValueError:
        metrics['auc'] = float('nan')

    return metrics

In [8]:
def train_model(model, train_loader, val_loader, test_loader, optimizer, scheduler, criterion, num_epochs, device,
                checkpoint_path='best_model.pth', log_every=1):
    """Train ``model``, keep the best-validation weights, then report test metrics.

    Each epoch:

    1. One optimisation pass over ``train_loader``, with the gradient norm clipped to 1.0.
    2. ``evaluate`` on ``train_loader`` and ``val_loader``.
    3. ``scheduler.step(val_loss)``.

    Whenever validation loss improves, a copy of the weights is kept in memory and saved
    to ``checkpoint_path``. After the last epoch, the best weights are restored if any
    epoch improved. The model is then evaluated on ``test_loader`` and the curves are
    plotted.

    Args:
        checkpoint_path: file the best ``state_dict`` is written to.
        log_every: print a per-batch loss line every ``log_every`` batches; 0/None
            disables it. The default 1 logs every batch, as before.

    Returns:
        history dict: ``history['train'][m]`` and ``history['val'][m]`` are per-epoch
        lists for m in loss/accuracy/f1/recall/auc. ``history['test']`` is the final
        test metrics dict.
    """
    best_val_loss = float('inf')
    best_state = None
    history = {
        'train': {'loss': [], 'accuracy': [], 'f1': [], 'recall': [], 'auc': []},
        'val': {'loss': [], 'accuracy': [], 'f1': [], 'recall': [], 'auc': []},
        'test': {'loss': None, 'accuracy': None, 'f1': None, 'recall': None, 'auc': None}
    }

    mods = mods_name

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for batch_idx, (features, labels) in enumerate(train_loader):
            inputs = {
                mod: features[mod].float().to(device, non_blocking=True)
                for mod in mods
            }
            masks = {
                mod: features[f'{mod}_mask'].to(device, non_blocking=True)
                for mod in mods
            }
            labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad()
            # Collated masks are True for real steps; the model's padding masks are True for padding.
            outputs = model(
                inputs[mods[0]],
                inputs[mods[1]],
                inputs[mods[2]],
                audio_mask=~masks[mods[0]],
                video_mask=~masks[mods[1]],
                audio_mask2=~masks[mods[2]]
            )

            loss = criterion(outputs, labels)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            train_loss += loss.item()

            if log_every and (batch_idx + 1) % log_every == 0:
                print(f"batch idx: {batch_idx} done; train loss: {loss.item()}; avg train loss: {train_loss/(batch_idx+1)}")

        # NOTE: if train_loader uses a with-replacement sampler, these "train" metrics are
        # computed on a fresh resample, not on each training example exactly once.
        train_metrics = evaluate(model, train_loader, device)
        val_metrics = evaluate(model, val_loader, device)

        scheduler.step(val_metrics['loss'])
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Current learning rate: {current_lr}")

        for metric in ['loss', 'accuracy', 'f1', 'recall', 'auc']:
            history['train'][metric].append(train_metrics[metric])
            history['val'][metric].append(val_metrics[metric])

        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            # Clone: state_dict() tensors share storage with the live parameters.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save(best_state, checkpoint_path)

        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"Train - Loss: {train_metrics['loss']:.4f}, Acc: {train_metrics['accuracy']:.4f}, " +
              f"F1: {train_metrics['f1']:.4f}, Recall: {train_metrics['recall']:.4f}, AUC: {train_metrics['auc']:.4f}")
        print(f"Val   - Loss: {val_metrics['loss']:.4f}, Acc: {val_metrics['accuracy']:.4f}, " +
              f"F1: {val_metrics['f1']:.4f}, Recall: {val_metrics['recall']:.4f}, AUC: {val_metrics['auc']:.4f}")

    # Restore from memory rather than re-reading the file. If no epoch improved (NaN loss
    # or num_epochs == 0), the file on disk could be missing or stale from an earlier run.
    if best_state is not None:
        model.load_state_dict(best_state)
    else:
        print("Warning: validation loss never improved; evaluating the final weights.")

    test_metrics = evaluate(model, test_loader, device)
    history['test'] = test_metrics

    print("\nTest Results:")
    print(f"Loss: {test_metrics['loss']:.4f}, Acc: {test_metrics['accuracy']:.4f}, " +
          f"F1: {test_metrics['f1']:.4f}, Recall: {test_metrics['recall']:.4f}, AUC: {test_metrics['auc']:.4f}")

    plot_metrics(history['train'], history['val'], history['test'])

    return history

In [9]:
def plot_metrics(train_metrics, val_metrics, test_metrics):
    """Draw one panel per metric (loss, accuracy, f1, recall, auc) on a 2x3 grid.

    Each panel plots the per-epoch train and validation values. When the test value
    for that metric is not None, it is drawn as a red dashed horizontal line. The
    sixth grid cell stays empty. The figure is shown with ``plt.show()``.
    """
    metric_names = ('loss', 'accuracy', 'f1', 'recall', 'auc')
    epoch_axis = range(1, len(train_metrics['loss']) + 1)

    fig = plt.figure(figsize=(20, 12))
    for position, name in enumerate(metric_names, start=1):
        ax = fig.add_subplot(2, 3, position)
        ax.plot(epoch_axis, train_metrics[name], label='Train')
        ax.plot(epoch_axis, val_metrics[name], label='Validation')

        test_value = test_metrics[name]
        if test_value is not None:
            ax.axhline(y=test_value, color='r', linestyle='--', label='Test')

        ax.set_title(name.capitalize())
        ax.set_xlabel('Epochs')
        ax.legend()

    fig.tight_layout()
    plt.show()

In [10]:
# Run configuration. Seeding torch (CPU + all CUDA devices) and numpy makes weight init
# and the WeightedRandomSampler draws repeat across kernel restarts. Some GPU kernels
# may still be non-deterministic unless cuDNN determinism is also enabled.
SEED = 42
LEARNING_RATE = 0.0005
BATCH_SIZE = 1
NUM_EPOCHS = 10

torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PTSDTransformer().to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Cut the LR 10x once validation loss has not improved for more than `patience` epochs;
# train_model steps it with the validation loss.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=2
)
criterion = nn.CrossEntropyLoss()

# Training draws class-balanced samples with replacement; dev/test are read in a fixed order.
train_loader = create_balanced_dataloader("train_normalized.h5", batch_size=BATCH_SIZE, collate_fn=collate_fn)
val_loader = create_dataloader("dev_normalized.h5", batch_size=BATCH_SIZE, shuffle=False)
test_loader = create_dataloader("test_normalized.h5", batch_size=BATCH_SIZE, shuffle=False)

history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion,
    num_epochs=NUM_EPOCHS,
    device=device
)

[Start] create balanced dataset
[End] create balanced dataset
batch idx: 0 done; train loss: 0.9249085783958435; avg train loss: 0.9249085783958435
batch idx: 1 done; train loss: 0.5163732171058655; avg train loss: 0.7206408977508545
batch idx: 2 done; train loss: 0.0026033578906208277; avg train loss: 0.4812950511307766
batch idx: 3 done; train loss: 0.5895087718963623; avg train loss: 0.508348481322173
batch idx: 4 done; train loss: 0.3536399304866791; avg train loss: 0.47740677115507424
batch idx: 5 done; train loss: 0.005417665466666222; avg train loss: 0.3987419202070062
batch idx: 6 done; train loss: 0.002528686309233308; avg train loss: 0.34214002965018153
batch idx: 7 done; train loss: 0.14559443295001984; avg train loss: 0.3175718300626613
batch idx: 8 done; train loss: 0.0038631348870694637; avg train loss: 0.28271530837648445
batch idx: 9 done; train loss: 3.758331060409546; avg train loss: 0.6302768835797906
batch idx: 10 done; train loss: 0.0020306934602558613; avg train l

  output = torch._nested_tensor_from_mask(


Current learning rate: 0.0005

Epoch 1/10
Train - Loss: 2.0504, Acc: 0.5521, F1: 0.1978, Recall: 0.1098, AUC: 0.6951
Val   - Loss: 1.4972, Acc: 0.6909, F1: 0.0000, Recall: 0.0000, AUC: 0.5062
batch idx: 0 done; train loss: 8.555012702941895; avg train loss: 8.555012702941895
batch idx: 1 done; train loss: 0.0068819401785731316; avg train loss: 4.280947321560234
batch idx: 2 done; train loss: 4.128085136413574; avg train loss: 4.22999325984468
batch idx: 3 done; train loss: 2.3777360916137695; avg train loss: 3.766928967786953
batch idx: 4 done; train loss: 0.003950174432247877; avg train loss: 3.0143332091160118
batch idx: 5 done; train loss: 1.5470702648162842; avg train loss: 2.769789385066057
batch idx: 6 done; train loss: 0.1305990070104599; avg train loss: 2.392762188200972
batch idx: 7 done; train loss: 0.015805209055542946; avg train loss: 2.0956425658077933
batch idx: 8 done; train loss: 0.030279556289315224; avg train loss: 1.8661577869724069
batch idx: 9 done; train loss: 1.0

NameError: name 'plt' is not defined

In [None]:
# Run a full garbage-collection pass so unreachable Python objects held by the kernel
# are freed. The returned count of unreachable objects is displayed as the cell output.
gc.collect()

In [None]:
# Release unoccupied cached blocks held by PyTorch's CUDA caching allocator so the memory
# is visible as free to other GPU processes / nvidia-smi. Memory still owned by live
# tensors is not released; drop references (and run gc) first.
torch.cuda.empty_cache()