# Single Modality AutoEncoder ( V + A )

## Imports 

In [8]:
import h5py
import torch
from torch.utils.data import Dataset

import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from collections import defaultdict
import json
import torch.optim as optim

# Dataset

In [9]:
class CombinedDataset(Dataset):
    def __init__(self, file_path, bert_feature_size='bert_text_features_128', split='train', dtype=torch.float32):
        self.file_path = file_path
        self.bert_feature_size = bert_feature_size
        self.split = split
        self.dtype = dtype 
        self.data_keys = []
        with h5py.File(self.file_path, 'r') as file:
 
            for key in file.keys():
                if file[key].attrs['split'] == self.split:
                    self.data_keys.append(key)

    def __len__(self):
        return len(self.data_keys)

    def __getitem__(self, idx):
        with h5py.File(self.file_path, 'r') as file:

            group_key = self.data_keys[idx]
            group = file[group_key]


            label = group.attrs['label']
            text = group.attrs['text']
            audio_features = torch.from_numpy(group['audio_features_averaged'][()]).type(self.dtype)
            facial_features = torch.from_numpy(group['averaged_facial_features'][()]).type(self.dtype)
            bert_features = torch.from_numpy(group[self.bert_feature_size][()]).type(self.dtype)


            label_to_index = {'Positive': 2, 'Neutral': 1, 'Negative': 0}
            label_index = label_to_index[label]


            sample = {
                'label': label_index,
                'text': text,
                'audio_features': audio_features,
                'facial_features': facial_features,
                'bert_features': bert_features
            }

            return sample

In [10]:
file_path = './combined_features.h5'
datasets = {}
bert_feature_sizes = ['bert_text_features_128', 'bert_text_features_256', 'bert_text_features_512']
splits = ['train', 'validate', 'test']

for feature_size in bert_feature_sizes:
    datasets[feature_size] = {}
    for split in splits:
        dataset_key = f"{split}_{feature_size}"
        datasets[feature_size][split] = CombinedDataset(file_path, bert_feature_size=feature_size, split=split)

### Common Trainer Class 

In [11]:
class ModelTrainer:
    def __init__(self, model, train_dataset, val_dataset, model_name, epochs, save_interval, lr=1e-3, device='cuda'):
        self.model = model.to(device)
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.model_name = model_name
        self.start_epoch = 0
        self.epochs = epochs
        self.save_interval = save_interval
        self.lr = lr
        self.device = device
        self.history = defaultdict(list)
        self.checkpoint_dir = f'modelCheckPoints/{self.model_name}'
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        self.optimizer = None
        
    def save_checkpoint(self, epoch):
        state = {'epoch': epoch, 'state_dict': self.model.state_dict()}
        torch.save(state, f'{self.checkpoint_dir}/{epoch}.pt')

    def load_checkpoint(self):
        checkpoints = [ckpt for ckpt in os.listdir(self.checkpoint_dir) if ckpt.endswith('.pt')]
        if checkpoints:
            latest_checkpoint = max(checkpoints, key=lambda x: int(x.split('.')[0]))
            checkpoint = torch.load(f'{self.checkpoint_dir}/{latest_checkpoint}', map_location=self.device)
            self.model.load_state_dict(checkpoint['state_dict'])
            self.start_epoch = checkpoint['epoch'] + 1  
            print(f"Loaded checkpoint: {latest_checkpoint} at epoch {checkpoint['epoch']}")
        else:
            self.start_epoch = 0  
            print("No checkpoints found, starting from scratch.")

    def save_history(self):
        with open(f'{self.checkpoint_dir}/history.json', 'w') as f:
            json.dump(self.history, f)
            
    def initialize_optimizer(self):
        sample_batch = next(iter(DataLoader(self.train_dataset, batch_size=1, shuffle=True)))
        facial_features = sample_batch['facial_features'].to(self.device)
        audio_features = sample_batch['audio_features'].to(self.device)
        _ = self.model(facial_features, audio_features)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def train_one_epoch(self, dataloader, criterion):
        self.model.train()
        total_loss = 0
        correct_predictions = 0

        for batch in dataloader:
            facial_features = batch['facial_features'].to(self.device)
            audio_features = batch['audio_features'].to(self.device)
            labels = batch['label'].to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(facial_features, audio_features)

            loss = criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct_predictions += (predicted == labels).sum().item()

        avg_loss = total_loss / len(dataloader.dataset)
        accuracy = correct_predictions / len(dataloader.dataset)
        return avg_loss, accuracy

    def validate(self, dataloader, criterion):
        self.model.eval()
        total_loss = 0
        correct_predictions = 0

        with torch.no_grad():
            for batch in dataloader:
                facial_features = batch['facial_features'].to(self.device)
                audio_features = batch['audio_features'].to(self.device)
                labels = batch['label'].to(self.device)

                outputs = self.model(facial_features, audio_features)
                loss = criterion(outputs, labels)
                total_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                correct_predictions += (predicted == labels).sum().item()

        avg_loss = total_loss / len(dataloader.dataset)
        accuracy = correct_predictions / len(dataloader.dataset)
        return avg_loss, accuracy

    def train(self, criterion):
        self.initialize_optimizer()
        self.load_checkpoint()

        train_dataloader = DataLoader(self.train_dataset, batch_size=128, shuffle=True)
        val_dataloader = DataLoader(self.val_dataset, batch_size=128, shuffle=False)

        try:
            for epoch in range(self.start_epoch, self.epochs):
                train_loss, train_acc = self.train_one_epoch(train_dataloader, criterion)
                val_loss, val_acc = self.validate(val_dataloader, criterion)

                self.history['train_loss'].append(train_loss)
                self.history['train_acc'].append(train_acc)
                self.history['val_loss'].append(val_loss)
                self.history['val_acc'].append(val_acc)

                print(f"Epoch {epoch+1}/{self.epochs}, "
                      f"Train Loss: {train_loss:.4f}, "
                      f"Train Accuracy: {train_acc:.4f}, "
                      f"Val Loss: {val_loss:.4f}, "
                      f"Val Accuracy: {val_acc:.4f}")

                if (epoch + 1) % self.save_interval == 0:
                    self.save_checkpoint(epoch + 1)

                self.save_history()

        except KeyboardInterrupt:
            print("\nTraining interrupted by user. Saving last model state...")
            self.save_checkpoint(epoch + 1)
            self.save_history()

# DynamicEncoder (Both modalities use the same)

In [12]:
class DynamicEncoder(nn.Module):
    def __init__(self, encoded_size=128, dropout_rate=0.5):
        super(DynamicEncoder, self).__init__()
        self.encoded_size = encoded_size
        self.dropout_rate = dropout_rate
        self.encoder = None

    def forward(self, x):
        if self.encoder is None:
            input_size = x.size(1)
            self.encoder = nn.Sequential(
                nn.Linear(input_size, input_size // 2), 
                nn.ReLU(),
                nn.Dropout(self.dropout_rate),
                nn.Linear(input_size // 2, self.encoded_size),
                nn.ReLU()
            ).to(x.device)
        return self.encoder(x)

In [13]:
class CombinedVideoAudioClassifier(nn.Module):
    def __init__(self, num_classes, encoded_audio_size=128, encoded_video_size=128, dropout_rate=0.5):
        super(CombinedVideoAudioClassifier, self).__init__()
        self.audio_encoder = DynamicEncoder(encoded_size=encoded_audio_size, dropout_rate=dropout_rate)
        self.video_encoder = DynamicEncoder(encoded_size=encoded_video_size, dropout_rate=dropout_rate)
        self.classifier = None
        self.encoded_audio_size = encoded_audio_size
        self.encoded_video_size = encoded_video_size
        self.dropout_rate = dropout_rate
        self.num_classes = num_classes

    def forward(self, video_features, audio_features):
        encoded_audio = self.audio_encoder(audio_features)
        encoded_video = self.video_encoder(video_features)
        
        if self.classifier is None:
            combined_feature_size = self.encoded_audio_size + self.encoded_video_size
            self.classifier = nn.Sequential(
                nn.Linear(combined_feature_size, 512),
                nn.ReLU(),
                nn.Dropout(self.dropout_rate),
                nn.Linear(512, self.num_classes)
            ).to(audio_features.device)

        combined_features = torch.cat((encoded_audio, encoded_video), dim=1)
        return self.classifier(combined_features)

In [14]:
train_dataset_512 = datasets['bert_text_features_512']['train']
validate_dataset_512 = datasets['bert_text_features_512']['validate']

device = torch.device("cuda" if torch.cuda.is_available() else "CPU")
criterion = nn.CrossEntropyLoss()

num_classes = 3  
bert_feature_size = 512
CombinedClassifier_A1_V1 = CombinedVideoAudioClassifier(bert_feature_size, num_classes,).to(device)
modelName = 'CombinedClassifier_A1_V1'

trainer = ModelTrainer(CombinedClassifier_A1_V1, train_dataset_512, validate_dataset_512, modelName, epochs=25, save_interval=5, device=device)
trainer.train(criterion)

No checkpoints found, starting from scratch.
Epoch 1/25, Train Loss: 0.0111, Train Accuracy: 0.3483, Val Loss: 0.0086, Val Accuracy: 0.3870
Epoch 2/25, Train Loss: 0.0085, Train Accuracy: 0.4208, Val Loss: 0.0083, Val Accuracy: 0.4319
Epoch 3/25, Train Loss: 0.0083, Train Accuracy: 0.4421, Val Loss: 0.0082, Val Accuracy: 0.4448
Epoch 4/25, Train Loss: 0.0083, Train Accuracy: 0.4519, Val Loss: 0.0081, Val Accuracy: 0.4490
Epoch 5/25, Train Loss: 0.0082, Train Accuracy: 0.4657, Val Loss: 0.0081, Val Accuracy: 0.4661
Epoch 6/25, Train Loss: 0.0081, Train Accuracy: 0.4677, Val Loss: 0.0081, Val Accuracy: 0.4597
Epoch 7/25, Train Loss: 0.0081, Train Accuracy: 0.4689, Val Loss: 0.0081, Val Accuracy: 0.4469
Epoch 8/25, Train Loss: 0.0080, Train Accuracy: 0.4881, Val Loss: 0.0081, Val Accuracy: 0.4547
Epoch 9/25, Train Loss: 0.0079, Train Accuracy: 0.4858, Val Loss: 0.0080, Val Accuracy: 0.4590
Epoch 10/25, Train Loss: 0.0079, Train Accuracy: 0.4962, Val Loss: 0.0079, Val Accuracy: 0.4761
Epoc