In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import roc_auc_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
brset_embed = pd.read_csv('embeddings.csv')
brset_split = pd.read_csv('split.csv')

In [4]:
text_column_names = brset_embed.columns[brset_embed.columns.str.match('text_\d+')]
image_column_names = brset_embed.columns[brset_embed.columns.str.match('image_\d+')]
text_columns = brset_embed[text_column_names]
image_columns = brset_embed[image_column_names]

In [5]:
text_embed = torch.tensor(text_columns.values)
image_embed = torch.tensor(image_columns.values)
y = torch.tensor(brset_embed['DR_2'].values)

### Training Function

In [6]:
def train(model, train_loader, val_loader, criterion, optimizer, num_epochs=10, verbose=True, scheduler=None):
    model.to(device)
    history = {'train_loss': [], 'val_loss': [], 'val_auc': [], 'val_accuracy': [], 'val_f1': []}
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for X, y in train_loader:
            X = X.to(device).float()
            y = y.to(device).float()
            optimizer.zero_grad()
            val_logits = model(X)
            loss = criterion(val_logits, y.unsqueeze(1))
            loss.backward()
            train_loss += loss.item()
            optimizer.step()
        train_loss /= len(train_loader)
        history['train_loss'].append(train_loss)

        model.eval()
        with torch.no_grad():
            val_logits = []
            val_labels = []
            val_loss = 0
            for X, y in val_loader:
                X = X.to(device).float()
                y = y.to(device).float()
                val_labels.extend(y.tolist())
                y_pred = model(X)
                val_logits.append(y_pred.cpu().numpy())
                loss = criterion(y_pred, y.unsqueeze(1))
                val_loss += loss.item()
            val_loss /= len(val_loader)
            history['val_loss'].append(val_loss)
            val_logits = np.concatenate(val_logits)
            val_preds = nn.Sigmoid()(torch.tensor(val_logits)).cpu().numpy()
            auc = roc_auc_score(val_labels, val_preds)
            history['val_auc'].append(auc)
            accuracy = accuracy_score(val_labels, val_preds > 0.5)
            history['val_accuracy'].append(accuracy)
            f1 = f1_score(val_labels, val_preds > 0.5)
            history['val_f1'].append(f1)
            if scheduler is not None:
                scheduler.step(val_loss)
                last_lr = scheduler.get_last_lr()[0]
            else:
                last_lr = optimizer.param_groups[0]['lr']
            if verbose:
                print(f'Epoch {epoch+1}/{num_epochs}, train loss: {train_loss:.4f}, val loss: {val_loss:.4f}, val auc: {auc:.4f}, val accuracy: {accuracy:.4f}, val f1: {f1:.4f}, LR: {last_lr}')
    return history


def get_probs(model, loader):
    model.eval()
    model.to(device)
    y_hat = torch.tensor([]).to(device)
    with torch.no_grad():
        for X,_ in loader:
            X = X.to(device).float()
            y_hat = torch.cat((y_hat, model(X)))
    return y_hat.cpu().numpy().flatten()

# Simple Dataset to support embeddings
class SimpleDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

In [7]:
# Copied from https://github.com/luisnakayama/BRSET/blob/main/src/FocalLoss.py
class BinaryFocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(BinaryFocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        bce_loss = nn.BCEWithLogitsLoss(reduction='none')(inputs, targets)
        p_t = torch.exp(-bce_loss)
        focal_loss = self.alpha * (1 - p_t) ** self.gamma * bce_loss

        if self.reduction == 'mean':
            return torch.mean(focal_loss)
        elif self.reduction == 'sum':
            return torch.sum(focal_loss)
        else:
            return focal_loss

### Image Only Model - Embedding data split

In [8]:
from sklearn.model_selection import train_test_split

train_embed_idx = brset_split[brset_split['embeddings_split'] == 'train'].index
test_embed_idx = brset_split[brset_split['embeddings_split'] == 'test'].index
train_img_emsplit = image_embed[train_embed_idx]
test_img_emsplit = image_embed[test_embed_idx]
train_y_emsplit = y[train_embed_idx]
test_y_emsplit = y[test_embed_idx]

train_img_emsplit, val_img_emsplit, train_y_emsplit, val_y_emsplit = train_test_split(train_img_emsplit, train_y_emsplit, 
                                                                                      test_size=len(test_embed_idx)/len(train_embed_idx),
                                                                                      random_state=42)

print(train_img_emsplit.shape, val_img_emsplit.shape, test_img_emsplit.shape)

torch.Size([9758, 1536]) torch.Size([3254, 1536]) torch.Size([3254, 1536])


In [9]:
image_emsplit_train_dataset = SimpleDataset(train_img_emsplit, train_y_emsplit)
image_emsplit_val_dataset = SimpleDataset(val_img_emsplit, val_y_emsplit)
image_emsplit_test_dataset = SimpleDataset(test_img_emsplit, test_y_emsplit)

image_emsplit_train_loader = DataLoader(image_emsplit_train_dataset, batch_size=32, shuffle=True)
image_emsplit_val_loader = DataLoader(image_emsplit_val_dataset, batch_size=32, shuffle=False)
image_emsplit_test_loader = DataLoader(image_emsplit_test_dataset, batch_size=32, shuffle=False)

In [10]:
# Train image only model
image_only_model_emsplit = nn.Sequential(
    nn.Linear(1536, 256),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(256, 1),
    # nn.Sigmoid()
)

In [11]:
p1 = sum(train_y_emsplit)/len(train_y_emsplit)
p0 = 1 - p1
pos_weight = torch.tensor(p0/p1).to(device)

# criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
criterion = BinaryFocalLoss(alpha=1-p1, gamma=2)
optimizer = optim.Adam(image_only_model_emsplit.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
history = train(image_only_model_emsplit, image_emsplit_train_loader, image_emsplit_val_loader, criterion, optimizer, num_epochs=50, scheduler=scheduler)

  pos_weight = torch.tensor(p0/p1).to(device)
  _torch_pytree._register_pytree_node(


Epoch 1/50, train loss: 0.0543, val loss: 0.0403, val auc: 0.9212, val accuracy: 0.9392, val f1: 0.3077, LR: 0.001
Epoch 2/50, train loss: 0.0398, val loss: 0.0376, val auc: 0.9331, val accuracy: 0.9474, val f1: 0.4896, LR: 0.001
Epoch 3/50, train loss: 0.0372, val loss: 0.0347, val auc: 0.9392, val accuracy: 0.9511, val f1: 0.5849, LR: 0.001
Epoch 4/50, train loss: 0.0351, val loss: 0.0353, val auc: 0.9372, val accuracy: 0.9511, val f1: 0.5714, LR: 0.001
Epoch 5/50, train loss: 0.0360, val loss: 0.0336, val auc: 0.9412, val accuracy: 0.9518, val f1: 0.5791, LR: 0.001
Epoch 6/50, train loss: 0.0365, val loss: 0.0359, val auc: 0.9395, val accuracy: 0.9465, val f1: 0.4387, LR: 0.001
Epoch 7/50, train loss: 0.0341, val loss: 0.0350, val auc: 0.9439, val accuracy: 0.9545, val f1: 0.5912, LR: 0.001
Epoch 8/50, train loss: 0.0324, val loss: 0.0323, val auc: 0.9459, val accuracy: 0.9548, val f1: 0.6441, LR: 0.001
Epoch 9/50, train loss: 0.0312, val loss: 0.0357, val auc: 0.9470, val accuracy:

In [12]:
# Evaluate image only model on test set
y_probs = get_probs(image_only_model_emsplit, image_emsplit_test_loader)
y_preds = (np.array(y_probs) > 0.5).astype(int)
image_only_roc = roc_auc_score(test_y_emsplit.numpy(), y_probs)
image_only_accuracy = accuracy_score(test_y_emsplit.numpy(), y_preds)
image_only_f1 = f1_score(test_y_emsplit.numpy(), y_preds)
print(f'Image Only ROC: {image_only_roc}, Accuracy: {image_only_accuracy}, F1: {image_only_f1}')

Image Only ROC: 0.9452943824144606, Accuracy: 0.9640442532267978, F1: 0.6422018348623854


### Image Only Model - Resplit Data

In [13]:
train_idx = brset_split[brset_split['split'] == 'train'].index
val_idx = brset_split[brset_split['split'] == 'val'].index
test_idx = brset_split[brset_split['split'] == 'test'].index

image_train = image_embed[train_idx]
image_val = image_embed[val_idx]
image_test = image_embed[test_idx]

y_train = y[train_idx]
y_val = y[val_idx]
y_test = y[test_idx]

# DataSet
image_train_dataset = SimpleDataset(image_train, y_train)
image_val_dataset = SimpleDataset(image_val, y_val)
image_test_dataset = SimpleDataset(image_test, y_test)

# DataLoader
image_train_loader = DataLoader(image_train_dataset, batch_size=32, shuffle=True)
image_val_loader = DataLoader(image_val_dataset, batch_size=32, shuffle=False)
image_test_loader = DataLoader(image_test_dataset, batch_size=32, shuffle=False)

In [14]:
image_only_model = nn.Sequential(
    nn.Linear(1536, 256),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(256, 1),
    # nn.Sigmoid()
)

In [15]:
criterion = BinaryFocalLoss(alpha=1-p1, gamma=2)
optimizer = optim.Adam(image_only_model.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
history = train(image_only_model, image_train_loader, image_val_loader, criterion, optimizer, num_epochs=50, scheduler=scheduler)

Epoch 1/50, train loss: 0.0583, val loss: 0.0394, val auc: 0.9090, val accuracy: 0.9453, val f1: 0.2823, LR: 0.001
Epoch 2/50, train loss: 0.0421, val loss: 0.0354, val auc: 0.9268, val accuracy: 0.9533, val f1: 0.5190, LR: 0.001
Epoch 3/50, train loss: 0.0372, val loss: 0.0334, val auc: 0.9339, val accuracy: 0.9588, val f1: 0.5890, LR: 0.001
Epoch 4/50, train loss: 0.0358, val loss: 0.0332, val auc: 0.9373, val accuracy: 0.9573, val f1: 0.5615, LR: 0.001
Epoch 5/50, train loss: 0.0364, val loss: 0.0334, val auc: 0.9359, val accuracy: 0.9592, val f1: 0.5933, LR: 0.001
Epoch 6/50, train loss: 0.0362, val loss: 0.0322, val auc: 0.9372, val accuracy: 0.9607, val f1: 0.6322, LR: 0.001
Epoch 7/50, train loss: 0.0342, val loss: 0.0318, val auc: 0.9389, val accuracy: 0.9570, val f1: 0.6465, LR: 0.001
Epoch 8/50, train loss: 0.0338, val loss: 0.0413, val auc: 0.9399, val accuracy: 0.9484, val f1: 0.3538, LR: 0.001
Epoch 9/50, train loss: 0.0354, val loss: 0.0518, val auc: 0.9414, val accuracy:

In [16]:
y_probs = get_probs(image_only_model, image_test_loader)
y_preds = (np.array(y_probs) > 0.5).astype(int)
image_only_roc = roc_auc_score(y[test_idx].numpy(), y_probs)
image_only_accuracy = accuracy_score(y[test_idx].numpy(), y_preds)
image_only_f1 = f1_score(y[test_idx].numpy(), y_preds)
print(f'Image Only ROC: {image_only_roc}, Accuracy: {image_only_accuracy}, F1: {image_only_f1}')

Image Only ROC: 0.9567045849440144, Accuracy: 0.962188748847218, F1: 0.616822429906542


### Text Only Model

In [17]:
# split text_embed into train and test based on brset_embed['split']
text_train = text_embed[train_idx]
text_val = text_embed[val_idx]
text_test = text_embed[test_idx]


In [18]:
text_train_dataset = SimpleDataset(text_train, y_train)
text_val_dataset = SimpleDataset(text_val, y_val)
text_test_dataset = SimpleDataset(text_test, y_test)

text_train_loader = DataLoader(text_train_dataset, batch_size=32, shuffle=True)
text_val_loader = DataLoader(text_val_dataset, batch_size=32, shuffle=False)
text_test_loader = DataLoader(text_test_dataset, batch_size=32, shuffle=False)

In [19]:
text_only_model = nn.Sequential(
    nn.Linear(4096, 256),
    nn.ReLU(),
    nn.Dropout(0.2),
    # nn.BatchNorm1d(256),
    nn.Linear(256, 1),
    # nn.Sigmoid()
)

In [20]:
# criterion = nn.BCEWithLogitsLoss()
criterion = BinaryFocalLoss(alpha=1-p1, gamma=2)
optimizer = optim.Adam(text_only_model.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
history = train(text_only_model, text_train_loader, text_val_loader, criterion, optimizer, num_epochs=50, scheduler=scheduler)

Epoch 1/50, train loss: 0.0727, val loss: 0.0411, val auc: 0.9503, val accuracy: 0.9380, val f1: 0.0818, LR: 0.001
Epoch 2/50, train loss: 0.0358, val loss: 0.0286, val auc: 0.9540, val accuracy: 0.9644, val f1: 0.6667, LR: 0.001
Epoch 3/50, train loss: 0.0315, val loss: 0.0266, val auc: 0.9571, val accuracy: 0.9665, val f1: 0.6877, LR: 0.001
Epoch 4/50, train loss: 0.0306, val loss: 0.0331, val auc: 0.9512, val accuracy: 0.9714, val f1: 0.7770, LR: 0.001
Epoch 5/50, train loss: 0.0299, val loss: 0.0280, val auc: 0.9601, val accuracy: 0.9582, val f1: 0.5310, LR: 0.001
Epoch 6/50, train loss: 0.0286, val loss: 0.0347, val auc: 0.9589, val accuracy: 0.9659, val f1: 0.7483, LR: 0.001
Epoch 7/50, train loss: 0.0309, val loss: 0.0289, val auc: 0.9611, val accuracy: 0.9644, val f1: 0.7500, LR: 0.001
Epoch 8/50, train loss: 0.0315, val loss: 0.0300, val auc: 0.9584, val accuracy: 0.9705, val f1: 0.7126, LR: 0.001
Epoch 9/50, train loss: 0.0305, val loss: 0.0415, val auc: 0.9397, val accuracy:

In [21]:
# Evaluate text only model on test set
y_probs = get_probs(text_only_model, text_test_loader)
y_preds = (np.array(y_probs) > 0.5).astype(int)
text_only_roc = roc_auc_score(y_test.numpy(), y_probs)
text_only_accuracy = accuracy_score(y_test.numpy(), y_preds)
text_only_f1 = f1_score(y_test.numpy(), y_preds)
print(f'Text Only ROC: {text_only_roc}, Accuracy: {text_only_accuracy}, F1: {text_only_f1}')

Text Only ROC: 0.9770691293557583, Accuracy: 0.9342145711650784, F1: 0.0


### Simple Early Fusion Model

In [22]:
simple_early_fusion_model = nn.Sequential(
    nn.Linear(4096+1536, 512),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(512, 1),
    # nn.Sigmoid()
)

combined_train = torch.cat((text_train, image_train), dim=1)
combined_val = torch.cat((text_val, image_val), dim=1)
combined_test = torch.cat((text_test, image_test), dim=1)

combined_train_dataset = SimpleDataset(combined_train, y_train)
combined_val_dataset = SimpleDataset(combined_val, y_val)
combined_test_dataset = SimpleDataset(combined_test, y_test)

combined_train_loader = DataLoader(combined_train_dataset, batch_size=32, shuffle=True)
combined_val_loader = DataLoader(combined_val_dataset, batch_size=32, shuffle=False)
combined_test_loader = DataLoader(combined_test_dataset, batch_size=32, shuffle=False)

In [23]:
criterion = BinaryFocalLoss(alpha=1-p1, gamma=2)
optimizer = optim.Adam(simple_early_fusion_model.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
history = train(simple_early_fusion_model, combined_train_loader, combined_val_loader, criterion, optimizer, num_epochs=50, scheduler=scheduler)

Epoch 1/50, train loss: 0.1013, val loss: 0.0319, val auc: 0.9591, val accuracy: 0.9530, val f1: 0.4436, LR: 0.001
Epoch 2/50, train loss: 0.0286, val loss: 0.0328, val auc: 0.9659, val accuracy: 0.9561, val f1: 0.4911, LR: 0.001
Epoch 3/50, train loss: 0.0276, val loss: 0.0224, val auc: 0.9662, val accuracy: 0.9751, val f1: 0.7718, LR: 0.001
Epoch 4/50, train loss: 0.0247, val loss: 0.0466, val auc: 0.9682, val accuracy: 0.9555, val f1: 0.4803, LR: 0.001
Epoch 5/50, train loss: 0.0244, val loss: 0.0206, val auc: 0.9685, val accuracy: 0.9773, val f1: 0.8053, LR: 0.001
Epoch 6/50, train loss: 0.0237, val loss: 0.0621, val auc: 0.9651, val accuracy: 0.9395, val f1: 0.1244, LR: 0.001
Epoch 7/50, train loss: 0.0255, val loss: 0.0204, val auc: 0.9693, val accuracy: 0.9742, val f1: 0.7889, LR: 0.001
Epoch 8/50, train loss: 0.0319, val loss: 0.0405, val auc: 0.9656, val accuracy: 0.9438, val f1: 0.2343, LR: 0.001
Epoch 9/50, train loss: 0.0297, val loss: 0.0272, val auc: 0.9662, val accuracy:

In [24]:
y_probs = get_probs(simple_early_fusion_model, combined_test_loader)
y_preds = (np.array(y_probs) > 0.5).astype(int)
early_fusion_roc = roc_auc_score(y_test.numpy(), y_probs)
early_fusion_accuracy = accuracy_score(y_test.numpy(), y_preds)
early_fusion_f1 = f1_score(y_test.numpy(), y_preds)
print(f'Early Fusion ROC: {early_fusion_roc}, Accuracy: {early_fusion_accuracy}, F1: {early_fusion_f1}')

Early Fusion ROC: 0.9864564401103413, Accuracy: 0.9342145711650784, F1: 0.0


### Simple Late Fusion Model

In [25]:
# Define separate modules for text and image processing
class TextModule(nn.Module):
    def __init__(self):
        super(TextModule, self).__init__()
        self.fc1 = nn.Linear(4096, 256)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return x

class ImageModule(nn.Module):
    def __init__(self):
        super(ImageModule, self).__init__()
        self.fc1 = nn.Linear(1536, 256)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return x

class SimpleLateFusionModel(nn.Module):
    def __init__(self):
        super(SimpleLateFusionModel, self).__init__()
        self.text_module = TextModule()
        self.image_module = ImageModule()
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, 1)
        # self.output = nn.Sigmoid()

    def forward(self, combined_data):
        text_data = combined_data[:, :4096]
        image_data = combined_data[:, 4096:]
        text_features = self.text_module(text_data)
        image_features = self.image_module(image_data)
        combined_features = torch.cat((text_features, image_features), dim=1)
        x = self.fc1(combined_features)
        x = torch.relu(x)
        x = self.fc2(x)
        # x = self.output(x)
        return x

late_fusion_model = SimpleLateFusionModel()

In [26]:
criterion = BinaryFocalLoss(alpha=1-p1, gamma=2)
optimizer = optim.Adam(late_fusion_model.parameters(), lr=0.0001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
history = train(late_fusion_model, combined_train_loader, combined_val_loader, criterion, optimizer, num_epochs=50, scheduler=scheduler)

Epoch 1/50, train loss: 0.0426, val loss: 0.0289, val auc: 0.9558, val accuracy: 0.9598, val f1: 0.6507, LR: 0.0001
Epoch 2/50, train loss: 0.0255, val loss: 0.0227, val auc: 0.9661, val accuracy: 0.9705, val f1: 0.7433, LR: 0.0001
Epoch 3/50, train loss: 0.0207, val loss: 0.0202, val auc: 0.9685, val accuracy: 0.9736, val f1: 0.7749, LR: 0.0001
Epoch 4/50, train loss: 0.0183, val loss: 0.0225, val auc: 0.9690, val accuracy: 0.9742, val f1: 0.8065, LR: 0.0001
Epoch 5/50, train loss: 0.0182, val loss: 0.0182, val auc: 0.9708, val accuracy: 0.9791, val f1: 0.8238, LR: 0.0001
Epoch 6/50, train loss: 0.0180, val loss: 0.0283, val auc: 0.9722, val accuracy: 0.9681, val f1: 0.6770, LR: 0.0001
Epoch 7/50, train loss: 0.0164, val loss: 0.0285, val auc: 0.9725, val accuracy: 0.9699, val f1: 0.7012, LR: 0.0001
Epoch 8/50, train loss: 0.0164, val loss: 0.0212, val auc: 0.9721, val accuracy: 0.9754, val f1: 0.7701, LR: 0.0001
Epoch 9/50, train loss: 0.0161, val loss: 0.0167, val auc: 0.9734, val a

In [27]:
y_probs = get_probs(late_fusion_model, combined_test_loader)
y_preds = (np.array(y_probs) > 0.5).astype(int)
late_fusion_roc = roc_auc_score(y_test.numpy(), y_probs)
late_fusion_accuracy = accuracy_score(y_test.numpy(), y_preds)
late_fusion_f1 = f1_score(y_test.numpy(), y_preds)
print(f'Late Fusion ROC: {late_fusion_roc}, Accuracy: {late_fusion_accuracy}, F1: {late_fusion_f1}')

Late Fusion ROC: 0.9865317846192642, Accuracy: 0.9837073470642483, F1: 0.8616187989556137


### Attention Model

In [28]:
class ImageAttentionModule(nn.Module):
    def __init__(self):
        super(ImageAttentionModule, self).__init__()
        self.fc1 = nn.Linear(1536, 256)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        self.attention = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.dropout(x)

        # Reshape x from [batch_size, 256] to [batch_size, 16, 16] for attention
        x = x.view(-1, 16, 16)

        # Apply attention
        attn_output, _ = self.attention(x, x, x)

        # Flatten the output for the final fully connected layer
        x = attn_output.reshape(-1, 256)  # Reshape back to original shape after attention
        return x

class TextAttentionModule(nn.Module):
    def __init__(self):
        super(TextAttentionModule, self).__init__()
        self.fc1 = nn.Linear(4096, 256)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        self.attention = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)

        x = x.view(-1, 16, 16)

        # Apply attention
        attn_output, _ = self.attention(x, x, x)

        # Flatten the output for the final fully connected layer
        x = attn_output.reshape(-1, 256)  # Reshape back to original shape after attention
        return x

class AttentionFusionModel(nn.Module):
    def __init__(self):
        super(AttentionFusionModel, self).__init__()
        self.text_attention = TextAttentionModule()
        self.image_attention = ImageAttentionModule()
        self.cross_attention = nn.MultiheadAttention(embed_dim=256, num_heads=4, batch_first=True)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, 1)
        # self.output = nn.Sigmoid()

    def forward(self, combined_data):
        text_data = combined_data[:, :4096]
        image_data = combined_data[:, 4096:]
        text_output = self.text_attention(text_data)
        image_output = self.image_attention(image_data)
        # combined_features = torch.cat((text_output, image_output), dim=1)
        combined_features, _ = self.cross_attention(text_output.unsqueeze(1), image_output.unsqueeze(1), image_output.unsqueeze(1))
        combined_features = combined_features.squeeze(1)
        x = self.fc1(combined_features)

        x = torch.relu(x)
 
        x = self.fc2(x)
        # x = self.output(x)
        return x
    
attention_fusion_model = AttentionFusionModel()

In [29]:
criterion = BinaryFocalLoss(alpha=1-p1, gamma=2)
optimizer = optim.Adam(attention_fusion_model.parameters(), lr=0.0001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
history = train(attention_fusion_model, combined_train_loader, combined_val_loader, criterion, optimizer, num_epochs=50, scheduler=scheduler)

Epoch 1/50, train loss: 0.0594, val loss: 0.0469, val auc: 0.8831, val accuracy: 0.9373, val f1: 0.0642, LR: 0.0001
Epoch 2/50, train loss: 0.0437, val loss: 0.0412, val auc: 0.9196, val accuracy: 0.9447, val f1: 0.3077, LR: 0.0001
Epoch 3/50, train loss: 0.0409, val loss: 0.0367, val auc: 0.9202, val accuracy: 0.9524, val f1: 0.4561, LR: 0.0001
Epoch 4/50, train loss: 0.0378, val loss: 0.0351, val auc: 0.9291, val accuracy: 0.9552, val f1: 0.5290, LR: 0.0001
Epoch 5/50, train loss: 0.0356, val loss: 0.0346, val auc: 0.9321, val accuracy: 0.9595, val f1: 0.6333, LR: 0.0001
Epoch 6/50, train loss: 0.0343, val loss: 0.0367, val auc: 0.9343, val accuracy: 0.9604, val f1: 0.6346, LR: 0.0001
Epoch 7/50, train loss: 0.0333, val loss: 0.0322, val auc: 0.9354, val accuracy: 0.9610, val f1: 0.6068, LR: 0.0001
Epoch 8/50, train loss: 0.0328, val loss: 0.0325, val auc: 0.9353, val accuracy: 0.9601, val f1: 0.5695, LR: 0.0001
Epoch 9/50, train loss: 0.0324, val loss: 0.0333, val auc: 0.9414, val a

In [30]:
y_probs = get_probs(attention_fusion_model, combined_test_loader)
y_preds = (np.array(y_probs) > 0.5).astype(int)
attention_fusion_roc = roc_auc_score(y_test.numpy(), y_probs)
attention_fusion_accuracy = accuracy_score(y_test.numpy(), y_preds)
attention_fusion_f1 = f1_score(y_test.numpy(), y_preds)
print(f'Attention Fusion ROC: {attention_fusion_roc}, Accuracy: {attention_fusion_accuracy}, F1: {attention_fusion_f1}')

Attention Fusion ROC: 0.9493238983556445, Accuracy: 0.9591146633876422, F1: 0.56957928802589
