In [None]:
# This file is to train the model for sarcasm detection using intermodality inconsistency detection

import torch
import torch.nn as nn
import torch.optim as optim
import json
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
import pickle
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
class IntermodalityInconsistencyDetector(nn.Module):
    """Binary sarcasm classifier over per-modality emotion distributions.

    Each modality's emotion vector (text, image, audio) is projected onto a
    fixed emotion -> polarity table to obtain one polarity score per modality.
    The absolute pairwise polarity differences are concatenated with the
    flattened emotion vectors and fed to a small MLP that outputs a sarcasm
    probability.

    Args:
        num_emotions: Length of each per-modality emotion vector.
        embedding_dim: Not used by this module; kept so existing callers that
            pass it keep working.
        emotion_polarity: Optional sequence of ``num_emotions`` polarity
            scores, one per emotion index. Defaults to
            ``DEFAULT_EMOTION_POLARITY`` (7 emotions).

    Raises:
        ValueError: If the polarity table length differs from ``num_emotions``.
    """

    # Default polarity per emotion index.
    DEFAULT_EMOTION_POLARITY = (
        -0.8,  # angry
        -0.7,  # disgust
        -0.6,  # fear
        0.8,   # happy
        -0.5,  # sad
        0.3,   # surprise
        0.0,   # neutral
    )

    def __init__(self, num_emotions=7, embedding_dim=32, emotion_polarity=None):
        super().__init__()
        self.modalities = ['text', 'image', 'audio']
        num_modalities = len(self.modalities)
        num_pairs = num_modalities * (num_modalities - 1) // 2

        if emotion_polarity is None:
            emotion_polarity = self.DEFAULT_EMOTION_POLARITY
        polarity = torch.as_tensor(emotion_polarity, dtype=torch.float32).flatten().clone()
        # Fail fast here instead of with an opaque matmul shape error in forward().
        if polarity.numel() != num_emotions:
            raise ValueError(
                f"emotion_polarity has {polarity.numel()} entries but "
                f"num_emotions={num_emotions}"
            )
        # Buffer: saved in state_dict and moved with .to(), but not trained.
        self.register_buffer('emotion_polarity', polarity)

        self.classifier = nn.Sequential(
            # modalities * emotions flattened features + pairwise polarity diffs
            nn.Linear(num_emotions * num_modalities + num_pairs, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )

    def compute_polarity_mismatch(self, x):
        """Return absolute pairwise polarity differences between modalities.

        Args:
            x: Tensor of shape (batch_size, num_modalities, num_emotions).

        Returns:
            Tensor of shape (batch_size, num_pairs). Columns follow the pair
            order (0, 1), (0, 2), ..., (1, 2), ...; for the three modalities
            this is text-image, text-audio, image-audio.
        """
        # Polarity-weighted sum over emotions: (batch_size, num_modalities).
        polarity_scores = torch.matmul(x, self.emotion_polarity)
        num_modalities = polarity_scores.size(1)
        diffs = [
            torch.abs(polarity_scores[:, i] - polarity_scores[:, j])
            for i in range(num_modalities)
            for j in range(i + 1, num_modalities)
        ]
        return torch.stack(diffs, dim=1)

    def forward(self, x):
        """Predict sarcasm probabilities.

        Args:
            x: Tensor of shape (batch_size, num_modalities, num_emotions), with
                modalities ordered text, image, audio.

        Returns:
            Tuple ``(sarcasm_probs, polarity_diffs)``. ``sarcasm_probs`` has
            shape (batch_size, 1) and holds sigmoid probabilities.
            ``polarity_diffs`` is the output of ``compute_polarity_mismatch``.
        """
        polarity_diffs = self.compute_polarity_mismatch(x)

        # Flatten all modalities' emotion vectors into one feature row per sample.
        emotion_features = x.reshape(x.size(0), -1)
        combined_features = torch.cat([emotion_features, polarity_diffs], dim=1)

        sarcasm_logits = self.classifier(combined_features)
        return torch.sigmoid(sarcasm_logits), polarity_diffs



# Path of the pickled cross-validation split definitions loaded below.
indices_file = "split_indices.p"

def pickle_loader(filename):
    """Load and return the pickled object stored in ``filename``.

    ``encoding="latin1"`` is passed to ``pickle.load``, presumably so that
    pickles written under Python 2 (e.g. containing NumPy arrays) can be read.

    SECURITY: unpickling can execute arbitrary code; only load trusted files.
    """
    # Context manager closes the file handle (the original leaked it).
    with open(filename, 'rb') as fh:
        return pickle.load(fh, encoding="latin1")
# Cross-validation splits; train_model iterates them as (train_indices, val_indices) pairs.
split_indices = pickle_loader(indices_file)
# Fall back to CPU so code importing `device` does not fail on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Emotion index -> polarity score.
# NOTE(review): not referenced by any code in this file; the same values are
# hard-coded in IntermodalityInconsistencyDetector's polarity table, so keep
# the two in sync if either changes.
emotion_to_polarity = {
    0: -0.8,    # angry/anger (strong negative)
    1: -0.7,    # disgust (strong negative but slightly less than anger)
    2: -0.6,    # fear/fearful (negative but less intense than anger/disgust)
    3: 0.8,     # happy/joy (strong positive)
    4: -0.5,    # sad/sadness (moderate negative)
    5: 0.3,     # surprise/surprised (mildly positive - can be positive or negative but often more positive)
    6: 0.0      # neutral/calm (middle point)
}

class MultimodalDataset(Dataset):
    """In-memory dataset pairing per-modality emotion vectors with sarcasm labels.

    Args:
        embedding_dict: Mapping ``sample_id -> {"text": vec, "image": vec,
            "audio": vec}``, e.g.
            ``{"id": {"text": [0,0,0,0,0,0,1], "image": [0,0,0,1,0,0,0],
            "audio": [0,0,0,0,0,0,1]}}``.
        label_dict: Mapping ``sample_id -> {"sarcasm": label}``. It must
            contain every key of ``embedding_dict``; extra keys are ignored.

    Items are ``({"text": t, "image": t, "audio": t}, label)``. Each ``t`` is
    a float32 tensor and ``label`` is an int64 scalar tensor.
    """

    def __init__(self, embedding_dict, label_dict):
        # Sample order is the iteration order of embedding_dict.
        self.ids = list(embedding_dict)
        self.modalities = ['text', 'image', 'audio']

        # One list of float32 tensors per modality, aligned with self.ids.
        self.embeddings = {}
        for modality in self.modalities:
            self.embeddings[modality] = [
                torch.tensor(embedding_dict[sample_id][modality], dtype=torch.float32)
                for sample_id in self.ids
            ]

        self.labels = [
            torch.tensor(label_dict[sample_id]['sarcasm'], dtype=torch.long)
            for sample_id in self.ids
        ]

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
        features = {
            modality: self.embeddings[modality][index]
            for modality in self.modalities
        }
        return features, self.labels[index]

def get_dataloader(dataset, indices, batch_size, shuffle):
    """Return a DataLoader over the samples of ``dataset`` selected by ``indices``.

    Args:
        dataset: Any indexable dataset.
        indices: Positions in ``dataset`` to include, in the given order.
        batch_size: Samples per batch.
        shuffle: Whether to reshuffle the subset every epoch.
    """
    return DataLoader(
        torch.utils.data.Subset(dataset, indices),
        batch_size=batch_size,
        shuffle=shuffle,
    )

def train_epoch(model, train_loader, optimizer, criterion, device):
    """Run one training epoch and return the mean per-batch loss.

    For each batch, the 'text', 'image' and 'audio' tensors are moved to
    ``device`` and stacked along dim 1. ``criterion`` is applied to the
    model's sarcasm probabilities. Added to that is
    ``-0.1 * mean(polarity_diffs * label)``.

    Args:
        model: Module whose forward returns ``(sarcasm_probs, polarity_diffs)``.
        train_loader: Iterable of ``(batch_embeddings, labels)`` with a ``len``.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss taking ``(probs, float_labels)``, e.g. ``nn.BCELoss``.
        device: Device the batches are moved to.

    Returns:
        float: Sum of batch losses divided by ``len(train_loader)``.
    """
    model.train()
    total_loss = 0

    for batch_embeddings, labels in train_loader:
        optimizer.zero_grad()

        # Stack modalities -> (batch_size, num_modalities, num_emotions)
        stacked_embeddings = torch.stack(
            [batch_embeddings[mod].to(device) for mod in ['text', 'image', 'audio']],
            dim=1
        )
        labels = labels.to(device).float()

        sarcasm_probs, polarity_diffs = model(stacked_embeddings)

        # Flatten (B, 1) -> (B,) explicitly. A bare .squeeze() collapses a
        # single-sample batch (e.g. a trailing batch of 1) to a 0-d tensor,
        # and BCELoss then raises because input and target sizes differ.
        sarcasm_loss = criterion(sarcasm_probs.reshape(-1), labels)

        # Reward large polarity differences on sarcastic samples.
        # NOTE(review): with IntermodalityInconsistencyDetector as defined in
        # this file, polarity_diffs depends only on the inputs and a
        # non-trainable buffer, so this term contributes no gradient. It only
        # shifts the reported loss.
        polarity_reg = -torch.mean(polarity_diffs * labels.unsqueeze(1)) * 0.1

        loss = sarcasm_loss + polarity_reg

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(train_loader)

def validate(model, val_loader, device):
    """Evaluate ``model`` on ``val_loader`` and return classification metrics.

    Each batch's 'text', 'image' and 'audio' tensors are moved to ``device``
    and stacked along dim 1 before the forward pass. A sample is predicted
    sarcastic when its probability exceeds 0.5.

    Returns:
        Tuple ``(accuracy, precision, recall, f1, polarity_diffs)``. The
        metrics come from sklearn with binary averaging. ``polarity_diffs`` is
        the ``np.stack`` of the per-batch polarity-difference arrays returned
        by the model, with shape (num_batches, batch_size, ...).

    NOTE: ``np.stack`` requires every batch to have the same size. This holds
    for the batch_size=1 validation loader built in train_model, but fails if
    the last batch is smaller than the others.
    """
    model.eval()
    labels_seen = []
    preds_seen = []
    batch_polarity = []

    with torch.no_grad():
        for batch_embeddings, labels in val_loader:
            on_device = {name: tensor.to(device) for name, tensor in batch_embeddings.items()}
            inputs = torch.stack(
                [on_device[name] for name in ('text', 'image', 'audio')],
                dim=1,
            )

            probs, polarity = model(inputs)
            hard_preds = (probs.squeeze() > 0.5).float()

            # squeeze() makes a single-sample batch 0-d; atleast_1d turns it
            # back into a length-1 array so extend() works uniformly.
            labels_seen.extend(np.atleast_1d(labels.cpu().numpy()))
            preds_seen.extend(np.atleast_1d(hard_preds.cpu().numpy()))
            batch_polarity.append(polarity.cpu().numpy())

    y_true = np.array(labels_seen)
    y_pred = np.array(preds_seen)
    polarity_stack = np.stack(batch_polarity)

    accuracy = accuracy_score(y_true, y_pred)
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average='binary'
    )
    return accuracy, precision, recall, f1, polarity_stack

def train_model(dataset, split_indices, model, device, batch_size=32):
    """Cross-validate ``model`` over ``split_indices`` and return per-fold results.

    For each ``(train_indices, val_indices)`` fold:

    * The model is reset to the weights it had when this function was called.
    * It is trained with Adam (lr=1e-3, weight_decay=1e-5) and BCELoss.
    * It is validated after every epoch on a batch_size=1 loader.
    * Training stops once validation accuracy has not improved for more than
      20 epochs.

    Whenever accuracy improves, a checkpoint is written to
    ``model/intermodal_fold_<k>.pt``. It holds the CPU state dict, the
    polarity differences from ``validate`` under the legacy key
    'attention_maps', and the metrics.

    Args:
        dataset: Dataset indexed by the fold indices.
        split_indices: Iterable of ``(train_indices, val_indices)`` pairs.
        model: Module to train; it is modified in place.
        device: Device to train on.
        batch_size: Training batch size.

    Returns:
        list[dict]: One dict per fold with keys 'fold', 'best_accuracy',
        'best_epoch', 'precision', 'recall' and 'f1'. Averages across folds
        are printed, not returned.
    """
    import os

    # Snapshot the initial weights so every fold starts from the same point.
    # Without this, fold k+1 continues from fold k's trained weights. Those
    # were fit on samples that lie in fold k+1's validation split, which
    # inflates later folds' scores.
    initial_state = {
        name: tensor.detach().clone() for name, tensor in model.state_dict().items()
    }
    os.makedirs('model', exist_ok=True)  # torch.save does not create directories
    results = []

    for fold, (train_indices, val_indices) in enumerate(split_indices):
        print(f"Starting fold {fold+1}")

        train_loader = get_dataloader(dataset, train_indices, batch_size, shuffle=True)
        val_loader = get_dataloader(dataset, val_indices, batch_size=1, shuffle=False)

        model.load_state_dict(initial_state)
        model = model.to(device)
        optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
        criterion = nn.BCELoss()

        # -1 guarantees the first epoch is recorded as "best", so the best_*
        # metrics below are always bound (even if accuracy is 0).
        best_acc = -1.0
        best_precision = best_recall = best_f1 = 0.0
        best_epoch = 0
        early_stop = 20
        epochs = 0

        while True:
            total_loss = train_epoch(model, train_loader, optimizer, criterion, device)

            # The 5th value is the polarity-difference array; it keeps the
            # historical name/checkpoint key "attention_maps".
            accuracy, precision, recall, f1, attention_maps = validate(
                model, val_loader, device
            )

            if accuracy > best_acc:
                print(f'Fold {fold+1}, Epoch {epochs}')
                print(f'Loss: {total_loss:.4f}, Accuracy: {accuracy:.4f}')
                print(f'Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}')

                best_acc = accuracy
                best_precision = precision
                best_recall = recall
                best_f1 = f1
                best_epoch = epochs

                # Copy weights to CPU for saving without moving the live
                # model back and forth between devices.
                cpu_state = {
                    name: tensor.detach().cpu()
                    for name, tensor in model.state_dict().items()
                }
                torch.save({
                    'model_state_dict': cpu_state,
                    'attention_maps': attention_maps,
                    'metrics': {
                        'accuracy': accuracy,
                        'precision': precision,
                        'recall': recall,
                        'f1': f1
                    }
                }, f'model/intermodal_fold_{fold+1}.pt')

            # Early stopping: no improvement for more than `early_stop` epochs.
            if epochs - best_epoch > early_stop:
                break
            epochs += 1

        results.append({
            'fold': fold + 1,
            'best_accuracy': best_acc,
            'best_epoch': best_epoch,
            'precision': best_precision,
            'recall': best_recall,
            'f1': best_f1
        })

        print(f"Fold {fold+1} complete. Best accuracy: {best_acc} at epoch {best_epoch}")
        print(f" Precision: {best_precision}, Recall: {best_recall}, F1: {best_f1}")

    average = {
        'accuracy': np.mean([r['best_accuracy'] for r in results]),
        'precision': np.mean([r['precision'] for r in results]),
        'recall': np.mean([r['recall'] for r in results]),
        'f1': np.mean([r['f1'] for r in results])
    }

    print("Average metrics")
    print(f"Accuracy: {average['accuracy']:.4f}")
    print(f"Precision: {average['precision']:.4f}")
    print(f"Recall: {average['recall']:.4f}")
    print(f"F1: {average['f1']:.4f}")

    return results

# Example usage
if __name__ == "__main__":
    # Context managers close the JSON files as soon as they are parsed.
    with open("emotion.json", encoding="utf-8") as fh:
        data = json.load(fh)
    with open("sarcasm_data.json", encoding="utf-8") as fh:
        label_data = json.load(fh)

    # Initialize dataset and model (7 emotion categories per modality).
    dataset = MultimodalDataset(data, label_data)
    model = IntermodalityInconsistencyDetector(num_emotions=7)

    # Train model on the module-level cross-validation splits.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    results = train_model(dataset, split_indices, model, device)

Starting fold 1
Fold 1, Epoch 0
Loss: 0.6631, Accuracy: 0.5217
Precision: 0.4815, Recall: 0.2000, F1: 0.2826
Fold 1, Epoch 3
Loss: 0.6521, Accuracy: 0.5290
Precision: 0.5000, Recall: 0.5231, F1: 0.5113
Fold 1, Epoch 4
Loss: 0.6499, Accuracy: 0.5580
Precision: 0.5333, Recall: 0.4923, F1: 0.5120
Fold 1 complete. Best accuracy: 0.5579710144927537 at epoch 4
 Precision: 0.5333333333333333, Recall: 0.49230769230769234, F1: 0.512
Starting fold 2
Fold 2, Epoch 0
Loss: 0.6425, Accuracy: 0.6232
Precision: 0.5952, Recall: 0.7353, F1: 0.6579
Fold 2 complete. Best accuracy: 0.6231884057971014 at epoch 0
 Precision: 0.5952380952380952, Recall: 0.7352941176470589, F1: 0.6578947368421053
Starting fold 3
Fold 3, Epoch 0
Loss: 0.6333, Accuracy: 0.7153
Precision: 0.7027, Recall: 0.7536, F1: 0.7273
Fold 3 complete. Best accuracy: 0.7153284671532847 at epoch 0
 Precision: 0.7027027027027027, Recall: 0.7536231884057971, F1: 0.7272727272727273
Starting fold 4
Fold 4, Epoch 0
Loss: 0.6067, Accuracy: 0.6204
P