# Importing Libraries

In this notebook, we use several libraries to build and train a multimodal (image and text) question-answering model. The key libraries are:

- **os**: For interacting with the operating system, such as file path management.
- **torch**: The PyTorch library for building and training deep learning models.
- **torch.nn**: Provides modules and classes for building neural networks.
- **pandas**: For data manipulation and analysis.
- **torch.utils.data**: For creating custom datasets and data loaders.
- **transformers**: Hugging Face library for pre-trained transformer models like ViT and PhoBERT.
- **PIL (Pillow)**: For image processing.
- **sklearn.preprocessing**: For preprocessing tasks like label encoding.

These libraries will help us preprocess data, define the model architecture, and train the model efficiently.

In [None]:
import os
import torch
import torch.nn as nn
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from transformers import ViTModel, ViTImageProcessor, AutoTokenizer, AutoModel
from PIL import Image
from sklearn.preprocessing import LabelEncoder


## Setting Up Device

In this section, we select the compute device: a CUDA GPU if one is available, otherwise the CPU. Deep learning models typically train and run much faster on a GPU.

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## Loading Pre-trained Models

In this section, we load the pre-trained ViT (Vision Transformer) and PhoBERT models, which serve as the backbones of the system: ViT extracts visual features from images, and PhoBERT, a language model pre-trained on Vietnamese text, encodes the questions.

In [None]:
# Load ViT model and image processor
vit_path = "google/vit-base-patch16-224-in21k"
feature_extractor = ViTImageProcessor.from_pretrained(vit_path)
model_vit = ViTModel.from_pretrained(vit_path).to(device)

# Load PhoBERT model and tokenizer
phobert_path = "vinai/phobert-base"
tokenizer = AutoTokenizer.from_pretrained(phobert_path)
model_bert = AutoModel.from_pretrained(phobert_path).to(device)
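
As a quick sanity check on the shapes these backbones produce, the sketch below works out the ViT sequence length from the checkpoint name (`patch16-224`: 16×16 patches on a 224×224 image). The helper `vit_num_tokens` is hypothetical and only for illustration. The hidden size of 768 is the standard width of the base-size ViT and PhoBERT checkpoints.

```python
def vit_num_tokens(image_size: int = 224, patch_size: int = 16) -> int:
    """Number of tokens in ViT's last_hidden_state: one per patch plus [CLS]."""
    patches_per_side = image_size // patch_size
    return patches_per_side ** 2 + 1  # +1 for the [CLS] token


num_tokens = vit_num_tokens()
hidden_size = 768  # base-size ViT and PhoBERT both use 768-dim hidden states

# Expected shape of model_vit(...).last_hidden_state for a batch of 8 images
vit_output_shape = (8, num_tokens, hidden_size)
print(vit_output_shape)  # (8, 197, 768)
```

Both backbones share the same hidden size, which is why the fusion model later in the notebook can use 768 as its fusion dimension.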

## Defining the Custom Dataset

In this section, we define a custom `Dataset` class for multimodal question answering. It reads a CSV file with `image_path`, `question`, and `answer` columns, encodes the answers as integer class labels, and returns the preprocessed image tensor, the tokenized question, and the label for each sample.

In [None]:

class QADataset(Dataset):
    def __init__(self, csv_file):
        self.data = pd.read_csv(csv_file)
        self.le = LabelEncoder()
        self.data['encoded_answer'] = self.le.fit_transform(self.data['answer'])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        
        # Process image
        image = Image.open(row['image_path']).convert("RGB")
        image_inputs = feature_extractor(images=image, return_tensors="pt")
        
        # Process question text
        text_inputs = tokenizer(row['question'], 
                                return_tensors="pt", 
                                truncation=True, 
                                padding='max_length', 
                                max_length=128)
        
        label = torch.tensor(row['encoded_answer'], dtype=torch.long)
        
        return {
            'image_pixel_values': image_inputs['pixel_values'].squeeze(0),
            'input_ids': text_inputs['input_ids'].squeeze(0),
            'attention_mask': text_inputs['attention_mask'].squeeze(0),
            'label': label
        }
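
One caveat with the dataset class above: every `QADataset` instance fits its own `LabelEncoder`, so the same answer can map to different indices in different splits. The plain-Python sketch below illustrates the effect with made-up answers. `fit_classes` and `encode` are hypothetical helpers that mimic `LabelEncoder`'s sorted-unique class ordering.

```python
def fit_classes(answers):
    """Mimic LabelEncoder.fit: classes are the sorted unique values."""
    return sorted(set(answers))


def encode(answers, classes):
    """Mimic LabelEncoder.transform: map each answer to its class index."""
    index = {c: i for i, c in enumerate(classes)}
    return [index[a] for a in answers]


train_answers = ["guitar", "piano", "violin", "piano"]  # made-up data
val_answers = ["piano", "violin"]

train_classes = fit_classes(train_answers)
own_val_classes = fit_classes(val_answers)

val_own = encode(val_answers, own_val_classes)   # encoder fitted on the val split
val_shared = encode(val_answers, train_classes)  # encoder fitted on the train split
print(val_own, val_shared)  # [0, 1] [1, 2]
```

Encoding every split with the classes fitted on the training data keeps the label indices aligned with the classifier's output units.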

## Custom Collate Function

In this section, we define a custom collate function to handle batching for the multimodal dataset. It stacks the image pixel values, input IDs, attention masks, and labels of the individual samples into batch tensors for training and evaluation. The same cell also builds the training, validation, and test datasets and wraps them in data loaders.

In [None]:
def collate_fn(batch):
    image_pixel_values = torch.stack([item['image_pixel_values'] for item in batch])
    input_ids = torch.stack([item['input_ids'] for item in batch])
    attention_mask = torch.stack([item['attention_mask'] for item in batch])
    labels = torch.stack([item['label'] for item in batch])
    
    return {
        'image_pixel_values': image_pixel_values,
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels
    }

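# Load datasets
train_dataset = QADataset("/kaggle/input/data-instruments/instrument_train/instrument_train.csv")
val_dataset = QADataset("/kaggle/input/data-instruments/instrument_val/instrument_val.csv")
test_dataset = QADataset("/kaggle/input/data-instruments/instrument_test/instrument_test.csv")

# Each QADataset fits its own LabelEncoder, so re-encode the validation and test
# answers with the training encoder to keep class indices consistent across splits.
# Note: transform() raises ValueError for answers that never occur in the training set.
for split_dataset in (val_dataset, test_dataset):
    split_dataset.le = train_dataset.le
    split_dataset.data['encoded_answer'] = train_dataset.le.transform(split_dataset.data['answer'])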

# DataLoader
batch_size = 8
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

# Save the label encoder classes for inference
label_classes = train_dataset.le.classes_
print(f"Number of classes: {len(label_classes)}")
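
To make the batching step concrete, here is a small sketch that runs the same stacking logic on dummy samples. The sample shapes are assumptions based on the preprocessing above (3×224×224 images, questions padded to 128 tokens), and no real data is loaded.

```python
import torch


def collate_fn(batch):
    """Stack per-sample tensors into batch tensors (same logic as above)."""
    return {
        'image_pixel_values': torch.stack([b['image_pixel_values'] for b in batch]),
        'input_ids': torch.stack([b['input_ids'] for b in batch]),
        'attention_mask': torch.stack([b['attention_mask'] for b in batch]),
        'labels': torch.stack([b['label'] for b in batch]),
    }


# Dummy samples with the shapes QADataset.__getitem__ is expected to return
samples = [
    {
        'image_pixel_values': torch.zeros(3, 224, 224),
        'input_ids': torch.zeros(128, dtype=torch.long),
        'attention_mask': torch.ones(128, dtype=torch.long),
        'label': torch.tensor(i, dtype=torch.long),
    }
    for i in range(4)
]

batch = collate_fn(samples)
shapes = {k: tuple(v.shape) for k, v in batch.items()}
print(shapes)
```

Stacking works here only because every sample already has a fixed size, which is what `padding='max_length'` guarantees for the text inputs.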

## Cross-Attention Module

This section defines the Cross-Attention module, a key component of the multimodal fusion model. Cross-attention lets one modality attend to relevant features of the other (e.g., vision and text) by computing attention scores between queries from one modality and keys from the other, then using those scores to weight the value tensors. This enables the model to integrate visual and textual information for downstream tasks.

In [None]:
class CrossAttention(nn.Module):
    def __init__(self, dim, num_heads=8, dropout=0.1):
        super(CrossAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, q, k, v, mask=None):
        batch_size = q.shape[0]
        
        # Project and reshape for multi-head attention
        q = self.q_proj(q).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(k).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(v).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        
        # Calculate attention scores
        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        
        if mask is not None:
            mask = mask.unsqueeze(1).unsqueeze(2)  # [B, K] -> [B, 1, 1, K]: broadcast over heads and queries
            attn = attn.masked_fill(mask == 0, -1e9)
        
        attn = torch.softmax(attn, dim=-1)
        attn = self.dropout(attn)
        
        # Apply attention to values
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)
        
        # Output projection
        out = self.out_proj(out)
        return out
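
The masking step is the part of this module that is easiest to get wrong, so here is a minimal sketch of it in isolation. The tensor sizes and mask values are made up. The mask polarity (1 = real token, 0 = padding) follows the tokenizer's `attention_mask`.

```python
import torch

batch, heads, q_len, k_len = 2, 4, 3, 5
scores = torch.randn(batch, heads, q_len, k_len)

# attention_mask as produced by the tokenizer: 1 = real token, 0 = padding
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])

# [B, K] -> [B, 1, 1, K] so it broadcasts over heads and query positions
mask4d = mask.unsqueeze(1).unsqueeze(2)
weights = torch.softmax(scores.masked_fill(mask4d == 0, -1e9), dim=-1)

padded_weight = weights[0, :, :, 3:].max().item()  # weight placed on padded keys
row_sums = weights.sum(dim=-1)                      # each row still sums to 1
print(padded_weight, row_sums.shape)
```

Filling masked scores with a large negative number before the softmax drives their weights to zero, so padded tokens contribute nothing to the output.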

## Transformer Fusion Layer

This section defines the Transformer Fusion Layer, which integrates visual and textual features. It applies cross-attention in both directions (visual features attend to text tokens, and text tokens attend to the visual features), each followed by a residual connection, layer normalization, and a feed-forward network (FFN). This layer is the core building block of the multimodal fusion model.

In [None]:
class TransformerFusionLayer(nn.Module):
    def __init__(self, dim, num_heads=8, dropout=0.1):
        super(TransformerFusionLayer, self).__init__()
        self.vis_to_text_attn = CrossAttention(dim, num_heads, dropout)
        self.text_to_vis_attn = CrossAttention(dim, num_heads, dropout)
        
        self.vis_norm1 = nn.LayerNorm(dim)
        self.vis_norm2 = nn.LayerNorm(dim)
        self.text_norm1 = nn.LayerNorm(dim)
        self.text_norm2 = nn.LayerNorm(dim)
        
        self.vis_ffn = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(4 * dim, dim),
            nn.Dropout(dropout)
        )
        
        self.text_ffn = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(4 * dim, dim),
            nn.Dropout(dropout)
        )
        
    def forward(self, vis_feats, text_feats, text_mask=None):
        # Cross-attention from vision to text
        vis_attn_output = self.vis_to_text_attn(vis_feats, text_feats, text_feats, text_mask)
        vis_feats = self.vis_norm1(vis_feats + vis_attn_output)
        
        # FFN for vision features
        vis_ffn_output = self.vis_ffn(vis_feats)
        vis_output = self.vis_norm2(vis_feats + vis_ffn_output)
        
        # Cross-attention from text to vision
        text_attn_output = self.text_to_vis_attn(text_feats, vis_feats, vis_feats)
        text_feats = self.text_norm1(text_feats + text_attn_output)
        
        # FFN for text features
        text_ffn_output = self.text_ffn(text_feats)
        text_output = self.text_norm2(text_feats + text_ffn_output)
        
        return vis_output, text_output
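
For comparison, the vision-to-text half of this layer can be sketched with PyTorch's built-in `nn.MultiheadAttention`. This is a stand-in for the custom `CrossAttention` class, not the notebook's implementation, and the small dimensions are chosen only to keep the example fast.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim, num_heads = 64, 8
attn = nn.MultiheadAttention(dim, num_heads, dropout=0.0, batch_first=True)
norm = nn.LayerNorm(dim)

vis = torch.randn(2, 1, dim)   # one visual token per sample
text = torch.randn(2, 6, dim)  # six text tokens per sample
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1, 1]])

# nn.MultiheadAttention expects True where a key should be IGNORED,
# the opposite polarity of the tokenizer's attention_mask.
key_padding_mask = attention_mask == 0

attn_out, attn_weights = attn(vis, text, text, key_padding_mask=key_padding_mask)
vis_fused = norm(vis + attn_out)  # residual connection, then LayerNorm (post-norm)

print(vis_fused.shape, attn_weights.shape)
```

The returned attention weights are averaged over heads by default, and the padded text positions of the first sample receive zero weight.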

## Multimodal Fusion Model

In this section, we define the multimodal fusion model. It projects ViT's `[CLS]` embedding and PhoBERT's token embeddings into a shared dimension, passes them through a stack of fusion layers, and classifies the concatenation of the fused visual vector and the first fused text token. The output is one logit per candidate answer.

In [None]:
class FeatureProjector(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(FeatureProjector, self).__init__()
        self.proj = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
        
    def forward(self, x):
        return self.norm(self.proj(x))
    
class FusionQAModel(nn.Module):
    def __init__(self, output_dim, fusion_dim=768, num_fusion_layers=2):
        super(FusionQAModel, self).__init__()
        
        # Project features to same dimension for fusion
        self.vis_projector = FeatureProjector(768, fusion_dim)  # ViT hidden size -> fusion_dim
        self.text_projector = FeatureProjector(768, fusion_dim) # PhoBERT hidden size -> fusion_dim
        
        # Fusion layers
        self.fusion_layers = nn.ModuleList([
            TransformerFusionLayer(fusion_dim) for _ in range(num_fusion_layers)
        ])
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(fusion_dim * 2, fusion_dim),
            nn.LayerNorm(fusion_dim),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(fusion_dim, output_dim)
        )
        
    def forward(self, vis_feats, text_feats, text_attention_mask=None):
        # Extract CLS token from ViT
        vis_cls = vis_feats[:, 0]
        
        # Keep all token embeddings from PhoBERT
        text_tokens = text_feats
        
        # Project features to same dimension
        vis_proj = self.vis_projector(vis_cls).unsqueeze(1)  # [B, 1, fusion_dim]
        text_proj = self.text_projector(text_tokens)         # [B, seq_len, fusion_dim]
        
        # Apply fusion layers
        for fusion_layer in self.fusion_layers:
            vis_proj, text_proj = fusion_layer(vis_proj, text_proj, text_attention_mask)
        
        # Extract CLS tokens after fusion
        vis_cls_fused = vis_proj.squeeze(1)
        text_cls_fused = text_proj[:, 0]
        
        # Combine modalities for classification
        combined_features = torch.cat([vis_cls_fused, text_cls_fused], dim=1)
        logits = self.classifier(combined_features)
        
        return logits
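
One property of this design is worth keeping in mind: because only ViT's `[CLS]` embedding is passed into the fusion layers, the text-to-vision attention has a single key. Softmax over one score is always 1, so (ignoring dropout) every text token receives the same projected visual vector. The sketch below checks this with random scores. The tensor sizes are arbitrary.

```python
import torch

batch, heads, text_len = 2, 8, 5
num_visual_tokens = 1  # FusionQAModel keeps only the ViT [CLS] token

# Attention scores of text queries against the visual key(s): [B, H, Lq, Lk]
scores = torch.randn(batch, heads, text_len, num_visual_tokens)
weights = torch.softmax(scores, dim=-1)

all_ones = bool(torch.allclose(weights, torch.ones_like(weights)))
print(all_ones)  # True
```

Passing all patch embeddings instead of only `[CLS]` would give this attention something to select among, at a higher compute cost and with a change to how the visual features are pooled after fusion. Whether that helps on this dataset is an empirical question.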

## Complete Question Answering Model

In this section, we define the `CompleteQAModel` class, which combines the pre-trained Vision Transformer (ViT), PhoBERT, and the FusionQAModel. This model integrates visual and textual features using the fusion mechanism and outputs predictions for multimodal question answering tasks.

In [None]:
class CompleteQAModel(nn.Module):
    def __init__(self, vit_model, bert_model, fusion_model):
        super(CompleteQAModel, self).__init__()
        self.vit_model = vit_model
        self.bert_model = bert_model
        self.fusion_model = fusion_model
        
    def forward(self, image_pixel_values, input_ids, attention_mask):
        # Get ViT features
        vis_outputs = self.vit_model(pixel_values=image_pixel_values)
        vis_feats = vis_outputs.last_hidden_state
        
        # Get PhoBERT features
        text_outputs = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)
        text_feats = text_outputs.last_hidden_state
        
        # Fusion and classification
        logits = self.fusion_model(vis_feats, text_feats, attention_mask)
        return logits

## Validate Model

In this section, we define a function that measures the model's performance on the validation dataset. It returns the average validation loss and the accuracy (as a percentage), which are used to monitor training and to detect overfitting to the training data.

In [None]:
# Validation function
def validate_model(model, dataloader, criterion):
    model.eval()
    val_loss = 0.0
    correct, total = 0, 0

    with torch.no_grad():
        for batch in dataloader:
            image_pixel_values = batch['image_pixel_values'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            outputs = model(image_pixel_values, input_ids, attention_mask)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    avg_val_loss = val_loss / len(dataloader)
    val_accuracy = 100 * correct / total
    return avg_val_loss, val_accuracy
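
Note that the function above averages the per-batch mean losses, which weights a smaller final batch as heavily as a full one. The plain-Python sketch below uses made-up numbers to show how this differs from a per-sample average. For evenly sized batches the two agree.

```python
# Made-up per-batch mean losses and batch sizes; the last batch is smaller
batch_losses = [1.0, 1.0, 4.0]
batch_sizes = [8, 8, 2]

# What validate_model computes: the mean of the per-batch means
mean_of_means = sum(batch_losses) / len(batch_losses)

# Per-sample average: weight each batch mean by its batch size
per_sample_mean = (
    sum(loss * n for loss, n in zip(batch_losses, batch_sizes)) / sum(batch_sizes)
)

print(mean_of_means, round(per_sample_mean, 4))  # 2.0 1.3333
```

With a batch size of 8 the difference is usually small, but it can matter when comparing runs that use different batch sizes.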

## Evaluate Model

In this section, we define a function to evaluate the model on the test dataset. The function prints the overall accuracy and returns the predictions and ground-truth labels for further analysis. This step assesses how well the model generalizes to data it has not seen during training or validation.

In [None]:
def evaluate_model(model, dataloader):
    model.eval()
    correct, total = 0, 0
    all_preds, all_labels = [], []
    
    with torch.no_grad():
        for batch in dataloader:
            image_pixel_values = batch['image_pixel_values'].to(device)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            outputs = model(image_pixel_values, input_ids, attention_mask)
            _, predicted = torch.max(outputs, 1)
            
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    
    return all_preds, all_labels
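
The returned predictions and labels can be broken down further, for example into per-class accuracy. The sketch below is one possible way to do that in plain Python. `per_class_accuracy` is a hypothetical helper, and the predictions, labels, and class names are made up.

```python
from collections import Counter


def per_class_accuracy(preds, labels, class_names):
    """Return {class_name: accuracy} for every class that appears in labels."""
    total = Counter(labels)
    correct = Counter(l for p, l in zip(preds, labels) if p == l)
    return {class_names[c]: correct[c] / total[c] for c in sorted(total)}


# Made-up values standing in for evaluate_model's output
demo_class_names = ["guitar", "piano", "violin"]
demo_preds = [0, 1, 1, 2, 2, 0]
demo_labels = [0, 1, 2, 2, 2, 1]

report = per_class_accuracy(demo_preds, demo_labels, demo_class_names)
print(report)
```

In the notebook, the same helper could be called with the lists returned by `evaluate_model` and with `label_classes` as the class names.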

## Fusion Model, Complete Model, Optimizer, and Training

In this section, we initialize the fusion model and the complete model, freeze the pre-trained ViT and PhoBERT backbones, configure the optimizer, loss function, and learning-rate scheduler, and run the training loop. If a checkpoint exists, training resumes from the saved state. Each training step performs a forward pass, computes the loss, backpropagates, and updates the parameters. At the end of each epoch, the model is evaluated on the validation set, the learning rate is halved if the validation loss has stopped improving, and a checkpoint is saved. After training, the final model is saved and evaluated on the test set.

In [None]:
# Checkpoint path
checkpoint_path = "fusion_checkpoint.pth"

# Initialize the model
output_dim = len(train_dataset.le.classes_)
fusion_model = FusionQAModel(output_dim=output_dim, fusion_dim=768, num_fusion_layers=2).to(device)
complete_model = CompleteQAModel(model_vit, model_bert, fusion_model).to(device)

# Freeze base models (optional - can be adjusted based on your dataset size)
for param in model_vit.parameters():
    param.requires_grad = False
    
for param in model_bert.parameters():
    param.requires_grad = False

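# Optimizer and loss
criterion = nn.CrossEntropyLoss()
# The backbone groups are frozen above and receive no gradients, so AdamW skips them;
# their 1e-5 learning rate only takes effect if the backbones are unfrozen.
optimizer = torch.optim.AdamW([
    {'params': fusion_model.parameters(), 'lr': 3e-4},
    {'params': model_vit.parameters(), 'lr': 1e-5},
    {'params': model_bert.parameters(), 'lr': 1e-5}
], weight_decay=1e-3)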

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

# Load checkpoint if exists
start_epoch = 0
train_losses, val_losses = [], []

if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    complete_model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    train_losses = checkpoint['train_losses']
    val_losses = checkpoint['val_losses']
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resuming training from epoch {start_epoch + 1}")  # 1-based, matching the training logs
else:
    print("Starting new training...")

# Training loop
num_epochs = 15
log_interval = 10  # Log every N batches

print("Starting training...")
for epoch in range(start_epoch, num_epochs):
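    complete_model.train()
    # The backbones are frozen above, so keep them in eval mode to disable their dropout.
    # Remove these two lines if you unfreeze the backbones for fine-tuning.
    model_vit.eval()
    model_bert.eval()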
    epoch_loss = 0.0
    
    for batch_idx, batch in enumerate(train_dataloader):
        # Move data to device
        image_pixel_values = batch['image_pixel_values'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        # Forward pass
        optimizer.zero_grad()
        outputs = complete_model(image_pixel_values, input_ids, attention_mask)
        loss = criterion(outputs, labels)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Record loss
        epoch_loss += loss.item()
        train_losses.append(loss.item())
        
        # Logging
        if (batch_idx + 1) % log_interval == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(train_dataloader)}], Loss: {loss.item():.4f}')
    
    # End of epoch validation
    val_loss, val_acc = validate_model(complete_model, val_dataloader, criterion)
    val_losses.append(val_loss)
    
    # Update learning rate
    scheduler.step(val_loss)
    
    # Print epoch stats
    print(f'Epoch [{epoch+1}/{num_epochs}] complete - Train Loss: {epoch_loss/len(train_dataloader):.4f}, '
            f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.2f}%')
    
    # Save checkpoint
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': complete_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'train_losses': train_losses,
        'val_losses': val_losses,
    }
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved at epoch {epoch+1}")

print("Training complete!")

# Save final model
final_model_path = "final_fusion_qa_model.pth"
torch.save({
    'model_state_dict': complete_model.state_dict(),
    'label_encoder_classes': label_classes.tolist(),  # plain list instead of a NumPy object array, easier to load safely
    'config': {
        'fusion_dim': 768,
        'num_fusion_layers': 2,
        'output_dim': output_dim
    }
}, final_model_path)
print(f"Final model saved at {final_model_path}")
print("Evaluating on test set...")
all_preds, all_labels = evaluate_model(complete_model, test_dataloader)
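
To see how the scheduler configured above reacts to a stalled validation loss, the sketch below feeds `ReduceLROnPlateau` a made-up sequence of identical losses. It uses a single dummy parameter instead of the real model, and the `demo_` names are only for illustration.

```python
import torch

# A single dummy parameter is enough to build an optimizer for the demo
demo_param = torch.nn.Parameter(torch.zeros(1))
demo_optimizer = torch.optim.AdamW([demo_param], lr=1e-3)
demo_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    demo_optimizer, mode='min', factor=0.5, patience=2
)

lrs = []
for val_loss in [1.0, 1.0, 1.0, 1.0]:  # made-up validation losses that never improve
    demo_scheduler.step(val_loss)
    lrs.append(demo_optimizer.param_groups[0]['lr'])

print(lrs)
```

The first call records the best loss, the next two non-improving epochs are tolerated by `patience=2`, and the learning rate is halved on the third consecutive epoch without improvement.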