In [12]:
import torch
from torch.utils.data import Dataset

import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from collections import defaultdict
import json
import torch.optim as optim
from torch.nn import TransformerEncoder, TransformerEncoderLayer

In [13]:
class MultimodalDataset(Dataset):
    def __init__(self, base_path, split, bert_feature_size='bert_text_features_128'):
        """
        Initialize the dataset by loading the tensor files.

        :param base_path: The path where the .pt files are stored
        :param split: The data split to load ('train', 'validate', or 'test')
        :param bert_feature_size: The size of the BERT features to load
        """
        self.audio_features = torch.load(f'{base_path}/{split}_audio_features.pt')
        self.facial_features = torch.load(f'{base_path}/{split}_facial_features.pt')
        self.bert_features = torch.load(f'{base_path}/{split}_{bert_feature_size}.pt')
        self.labels = torch.load(f'{base_path}/{split}_labels.pt')

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {
            'audio_features': self.audio_features[idx],
            'facial_features': self.facial_features[idx],
            'bert_features': self.bert_features[idx],
            'label': self.labels[idx]
        }

In [14]:
# Assuming the data is saved in './tensor_data' directory
base_path = './tensor_data'
bert_feature_size = 'bert_text_features_512'  # or 256, 128 based on what is needed

train_dataset = MultimodalDataset(base_path, 'train', bert_feature_size)
val_dataset = MultimodalDataset(base_path, 'validate', bert_feature_size)

# These dataloaders can be passed directly to the ModelTrainer class
train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=256, shuffle=False)

In [36]:
class ModelTrainer:
    def __init__(self, model, train_dataset, val_dataset, model_name, epochs, save_interval, lr=1e-3, device='cuda'):
        self.model = model.to(device)
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.model_name = model_name
        self.start_epoch = 0
        self.epochs = epochs
        self.save_interval = save_interval
        self.lr = lr
        self.device = device
        self.history = defaultdict(list)
        self.checkpoint_dir = f'modelCheckPoints/{self.model_name}'
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def save_checkpoint(self, epoch):
        state = {'epoch': epoch, 'state_dict': self.model.state_dict()}
        torch.save(state, f'{self.checkpoint_dir}/{epoch}.pt')

    def load_checkpoint(self):
        checkpoints = [ckpt for ckpt in os.listdir(self.checkpoint_dir) if ckpt.endswith('.pt')]
        if checkpoints:
            latest_checkpoint = max(checkpoints, key=lambda x: int(x.split('.')[0]))
            checkpoint = torch.load(f'{self.checkpoint_dir}/{latest_checkpoint}', map_location=self.device)
            self.model.load_state_dict(checkpoint['state_dict'])
            self.start_epoch = checkpoint['epoch'] + 1
            print(f"Loaded checkpoint: {latest_checkpoint} at epoch {checkpoint['epoch']}")
        else:
            print("No checkpoints found, starting from scratch.")

    def save_history(self):
        with open(f'{self.checkpoint_dir}/history.json', 'w') as f:
            json.dump(self.history, f)

    def train_one_epoch(self, dataloader, criterion, max_grad_norm=1.0):
        self.model.train()
        total_loss = 0
        correct_predictions = 0
    
        for batch in dataloader:
            audio = batch['audio'].to(self.device)
            vision = batch['vision'].to(self.device)
            text_bert = batch['text_bert'].to(self.device)
            labels = batch['label'].to(self.device)
    
            self.optimizer.zero_grad()
            outputs = self.model(audio, vision, text_bert)
    
            loss = criterion(outputs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
            self.optimizer.step()
    
            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct_predictions += (predicted == labels).sum().item()
    
        avg_loss = total_loss / len(dataloader.dataset)
        accuracy = correct_predictions / len(dataloader.dataset)
        return avg_loss, accuracy

    def validate(self, dataloader, criterion):
        self.model.eval()
        total_loss = 0
        correct_predictions = 0
    
        with torch.no_grad():
            for batch in dataloader:
                audio = batch['audio'].to(self.device)
                vision = batch['vision'].to(self.device)
                text_bert = batch['text_bert'].to(self.device)
                labels = batch['label'].to(self.device)
    
                outputs = self.model(audio, vision, text_bert)
                loss = criterion(outputs, labels)
                total_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                correct_predictions += (predicted == labels).sum().item()
    
        avg_loss = total_loss / len(dataloader.dataset)
        accuracy = correct_predictions / len(dataloader.dataset)
        return avg_loss, accuracy

    def train(self, criterion):
        self.load_checkpoint()
        train_dataloader = DataLoader(self.train_dataset, batch_size=256, shuffle=True)
        val_dataloader = DataLoader(self.val_dataset, batch_size=256, shuffle=False)

        for epoch in range(self.start_epoch, self.epochs):
            train_loss, train_acc = self.train_one_epoch(train_dataloader, criterion)
            val_loss, val_acc = self.validate(val_dataloader, criterion)
    
            self.history['train_loss'].append(train_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_loss'].append(val_loss)
            self.history['val_acc'].append(val_acc)
    
            print(f"Epoch {epoch+1}/{self.epochs}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.4f}")
    
            if (epoch + 1) % self.save_interval == 0:
                self.save_checkpoint(epoch + 1)
    
            self.save_history()

In [37]:
class MultimodalAttentionClassifier(nn.Module):
    def __init__(self, audio_feature_dim, facial_feature_dim, bert_feature_dim, num_classes, nhead, num_encoder_layers, dim_feedforward, dropout=0.1):
        super(MultimodalAttentionClassifier, self).__init__()

        # Original total feature dimension
        original_total_dim = audio_feature_dim + facial_feature_dim + bert_feature_dim

        # Find a new total dimension that is divisible by the number of heads
        # This is a simple approach: you might want to fine-tune this depending on your model's needs
        transformer_dim = original_total_dim + (nhead - (original_total_dim % nhead)) if original_total_dim % nhead != 0 else original_total_dim

        self.embedding = nn.Linear(original_total_dim, transformer_dim)

        # Transformer Encoder Layer
        encoder_layer = TransformerEncoderLayer(d_model=transformer_dim, nhead=nhead,
                                                dim_feedforward=dim_feedforward, dropout=dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

        # Classifier
        self.classifier = nn.Linear(transformer_dim, num_classes)

    def forward(self, audio_features, facial_features, bert_features):
        # Concatenate all features
        combined_features = torch.cat((audio_features, facial_features, bert_features), dim=1)

        # Embedding the input features to match transformer dimension
        embedded_features = self.embedding(combined_features).unsqueeze(1)  # Add sequence dimension

        # Transformer encoder
        transformed_features = self.transformer_encoder(embedded_features).squeeze(1)

        # Classification
        logits = self.classifier(transformed_features)
        return logits

In [38]:
# Assuming train_dataloader is already defined and loads the correct dataset
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# Get a single batch to infer feature dimensions
sample_batch = next(iter(train_dataloader))
audio_feature_dim = sample_batch['audio_features'].shape[1]
facial_feature_dim = sample_batch['facial_features'].shape[1]
bert_feature_dim = sample_batch['bert_features'].shape[1]

print(audio_feature_dim,facial_feature_dim,bert_feature_dim)

45 2048 768


In [29]:
# Define the model
model = MultimodalAttentionClassifier(
    audio_feature_dim=audio_feature_dim,
    facial_feature_dim=facial_feature_dim,
    bert_feature_dim=bert_feature_dim,
    num_classes=3,  # Assuming 3 classes for classification
    nhead=8,  # Number of attention heads
    num_encoder_layers=6,  # Number of transformer encoder layers
    dim_feedforward=2048  # Dimension of feedforward network in transformer
)

# Assuming you have a criterion and other parameters defined for ModelTrainer
epochs = 100
save_interval = 5
model_name = "multimodal_attention_classifier"
criterion = torch.nn.CrossEntropyLoss()
trainer.lr = 0.01 

# Initialize the trainer
trainer = ModelTrainer(
    model=model,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    model_name=model_name,
    epochs=epochs,
    save_interval=save_interval
)

# Start training
trainer.train(criterion)



No checkpoints found, starting from scratch.
Epoch 1/100, Train Loss: 0.0172, Train Accuracy: 0.3256, Val Loss: 0.0087, Val Accuracy: 0.3364
Epoch 2/100, Train Loss: 0.0089, Train Accuracy: 0.3329, Val Loss: 0.0087, Val Accuracy: 0.3371
Epoch 3/100, Train Loss: 0.0088, Train Accuracy: 0.3322, Val Loss: 0.0087, Val Accuracy: 0.3371
Epoch 4/100, Train Loss: 0.0089, Train Accuracy: 0.3262, Val Loss: 0.0090, Val Accuracy: 0.3364
Epoch 5/100, Train Loss: 0.0090, Train Accuracy: 0.3316, Val Loss: 0.0087, Val Accuracy: 0.3371
Epoch 6/100, Train Loss: 0.0088, Train Accuracy: 0.3343, Val Loss: 0.0090, Val Accuracy: 0.3371
Epoch 7/100, Train Loss: 0.0088, Train Accuracy: 0.3269, Val Loss: 0.0086, Val Accuracy: 0.3371
Epoch 8/100, Train Loss: 0.0088, Train Accuracy: 0.3324, Val Loss: 0.0093, Val Accuracy: 0.3364
Epoch 9/100, Train Loss: 0.0089, Train Accuracy: 0.3344, Val Loss: 0.0088, Val Accuracy: 0.3364
Epoch 10/100, Train Loss: 0.0089, Train Accuracy: 0.3217, Val Loss: 0.0086, Val Accuracy: 0

In [39]:
# Define the model
model = MultimodalAttentionClassifier(
    audio_feature_dim=audio_feature_dim,
    facial_feature_dim=facial_feature_dim,
    bert_feature_dim=bert_feature_dim,
    num_classes=3,  # Assuming 3 classes for classification
    nhead=8,  # Number of attention heads
    num_encoder_layers=6,  # Number of transformer encoder layers
    dim_feedforward=2048  # Dimension of feedforward network in transformer
)

# Assuming you have a criterion and other parameters defined for ModelTrainer
epochs = 100
save_interval = 5
model_name = "multimodal_attention_classifier2"
criterion = torch.nn.CrossEntropyLoss()

# Initialize the trainer
trainer2 = ModelTrainer(
    model=model,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    model_name=model_name,
    epochs=epochs,
    save_interval=save_interval,
    lr=0.001
)

# Start training
trainer2.train(criterion)

No checkpoints found, starting from scratch.
Epoch 1/100, Train Loss: 0.0262, Train Accuracy: 0.3269, Val Loss: 0.0083, Val Accuracy: 0.3264
Epoch 2/100, Train Loss: 0.0052, Train Accuracy: 0.3306, Val Loss: 0.0051, Val Accuracy: 0.3364
Epoch 3/100, Train Loss: 0.0045, Train Accuracy: 0.3239, Val Loss: 0.0048, Val Accuracy: 0.3264
Epoch 4/100, Train Loss: 0.0045, Train Accuracy: 0.3393, Val Loss: 0.0047, Val Accuracy: 0.3364
Epoch 5/100, Train Loss: 0.0044, Train Accuracy: 0.3369, Val Loss: 0.0048, Val Accuracy: 0.3264
Epoch 6/100, Train Loss: 0.0045, Train Accuracy: 0.3286, Val Loss: 0.0049, Val Accuracy: 0.3371
Epoch 7/100, Train Loss: 0.0045, Train Accuracy: 0.3283, Val Loss: 0.0051, Val Accuracy: 0.3371
Epoch 8/100, Train Loss: 0.0045, Train Accuracy: 0.3290, Val Loss: 0.0047, Val Accuracy: 0.3264
Epoch 9/100, Train Loss: 0.0044, Train Accuracy: 0.3293, Val Loss: 0.0048, Val Accuracy: 0.3371
Epoch 10/100, Train Loss: 0.0044, Train Accuracy: 0.3329, Val Loss: 0.0049, Val Accuracy: 0