In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import roc_auc_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.metrics import precision_recall_curve

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
brset_embed = pd.read_csv('embeddings.csv')
brset_split = pd.read_csv('split.csv')

In [4]:
text_column_names = brset_embed.columns[brset_embed.columns.str.match('text_\d+')]
image_column_names = brset_embed.columns[brset_embed.columns.str.match('image_\d+')]
text_columns = brset_embed[text_column_names]
image_columns = brset_embed[image_column_names]

In [5]:
text_embed = torch.tensor(text_columns.values)
image_embed = torch.tensor(image_columns.values)
y = torch.tensor(brset_embed['DR_2'].values)

### Training Function

In [6]:
def train(model, train_loader, val_loader, criterion, optimizer, num_epochs=10, verbose=True, scheduler=None):
    model.to(device)
    history = {'train_loss': [], 'val_loss': [], 'val_auc': [], 'val_accuracy': [], 'val_f1': []}
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for X, y in train_loader:
            X = X.to(device).float()
            y = y.to(device).float()
            optimizer.zero_grad()
            val_logits = model(X)
            loss = criterion(val_logits, y.unsqueeze(1))
            loss.backward()
            train_loss += loss.item()
            optimizer.step()
        train_loss /= len(train_loader)
        history['train_loss'].append(train_loss)

        model.eval()
        with torch.no_grad():
            val_logits = []
            val_labels = []
            val_loss = 0
            for X, y in val_loader:
                X = X.to(device).float()
                y = y.to(device).float()
                val_labels.extend(y.tolist())
                y_pred = model(X)
                val_logits.append(y_pred.cpu().numpy())
                loss = criterion(y_pred, y.unsqueeze(1))
                val_loss += loss.item()
            val_loss /= len(val_loader)
            history['val_loss'].append(val_loss)
            val_logits = np.concatenate(val_logits)
            val_preds = nn.Sigmoid()(torch.tensor(val_logits)).cpu().numpy()
            auc = roc_auc_score(val_labels, val_preds)
            history['val_auc'].append(auc)
            accuracy = accuracy_score(val_labels, val_preds > 0.5)
            history['val_accuracy'].append(accuracy)
            f1 = f1_score(val_labels, val_preds > 0.5)
            history['val_f1'].append(f1)
            if scheduler is not None:
                scheduler.step(val_loss)
                last_lr = scheduler.get_last_lr()[0]
            else:
                last_lr = optimizer.param_groups[0]['lr']
            if verbose:
                print(f'Epoch {epoch+1}/{num_epochs}, train loss: {train_loss:.4f}, val loss: {val_loss:.4f}, val auc: {auc:.4f}, val accuracy: {accuracy:.4f}, val f1: {f1:.4f}, LR: {last_lr}')
    return history


def get_probs(model, loader):
    model.eval()
    model.to(device)
    y_hat = torch.tensor([]).to(device)
    with torch.no_grad():
        for X,_ in loader:
            X = X.to(device).float()
            y_hat = torch.cat((y_hat, model(X)))
    return y_hat.cpu().numpy().flatten()

def get_optimal_f1_threshold(y_true, y_pred):
    epsilon = 1e-10
    precision, recall, thresholds = precision_recall_curve(y_true, y_pred)
    f1 = 2 * precision * recall / (precision + recall + epsilon)
    return thresholds[np.argmax(f1)]

# Simple Dataset to support embeddings
class SimpleDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

In [7]:
# Copied from https://github.com/luisnakayama/BRSET/blob/main/src/FocalLoss.py
class BinaryFocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(BinaryFocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        bce_loss = nn.BCEWithLogitsLoss(reduction='none')(inputs, targets)
        p_t = torch.exp(-bce_loss)
        focal_loss = self.alpha * (1 - p_t) ** self.gamma * bce_loss

        if self.reduction == 'mean':
            return torch.mean(focal_loss)
        elif self.reduction == 'sum':
            return torch.sum(focal_loss)
        else:
            return focal_loss

### Image Only Model - Embedding data split

In [8]:
from sklearn.model_selection import train_test_split

train_embed_idx = brset_split[brset_split['embeddings_split'] == 'train'].index
test_embed_idx = brset_split[brset_split['embeddings_split'] == 'test'].index
train_img_emsplit = image_embed[train_embed_idx]
test_img_emsplit = image_embed[test_embed_idx]
train_y_emsplit = y[train_embed_idx]
test_y_emsplit = y[test_embed_idx]

train_img_emsplit, val_img_emsplit, train_y_emsplit, val_y_emsplit = train_test_split(train_img_emsplit, train_y_emsplit, 
                                                                                      test_size=len(test_embed_idx)/len(train_embed_idx),
                                                                                      random_state=42)

print(train_img_emsplit.shape, val_img_emsplit.shape, test_img_emsplit.shape)

torch.Size([9758, 1536]) torch.Size([3254, 1536]) torch.Size([3254, 1536])


In [9]:
image_emsplit_train_dataset = SimpleDataset(train_img_emsplit, train_y_emsplit)
image_emsplit_val_dataset = SimpleDataset(val_img_emsplit, val_y_emsplit)
image_emsplit_test_dataset = SimpleDataset(test_img_emsplit, test_y_emsplit)

image_emsplit_train_loader = DataLoader(image_emsplit_train_dataset, batch_size=32, shuffle=True)
image_emsplit_val_loader = DataLoader(image_emsplit_val_dataset, batch_size=32, shuffle=False)
image_emsplit_test_loader = DataLoader(image_emsplit_test_dataset, batch_size=32, shuffle=False)

In [10]:
# Train image only model
image_only_model_emsplit = nn.Sequential(
    nn.Linear(1536, 256),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(256, 1),
    # nn.Sigmoid()
)

In [11]:
p1 = sum(train_y_emsplit)/len(train_y_emsplit)
p0 = 1 - p1
pos_weight = torch.tensor(p0/p1).to(device)

# criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
criterion = BinaryFocalLoss(alpha=1-p1, gamma=2)
optimizer = optim.Adam(image_only_model_emsplit.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
history = train(image_only_model_emsplit, image_emsplit_train_loader, image_emsplit_val_loader, criterion, optimizer, num_epochs=50, scheduler=scheduler)

  pos_weight = torch.tensor(p0/p1).to(device)
  _torch_pytree._register_pytree_node(


Epoch 1/50, train loss: 0.0627, val loss: 0.0376, val auc: 0.9234, val accuracy: 0.9468, val f1: 0.5014, LR: 0.001
Epoch 2/50, train loss: 0.0400, val loss: 0.0422, val auc: 0.9221, val accuracy: 0.9428, val f1: 0.3716, LR: 0.001
Epoch 3/50, train loss: 0.0366, val loss: 0.0421, val auc: 0.9381, val accuracy: 0.9471, val f1: 0.6195, LR: 0.001
Epoch 4/50, train loss: 0.0362, val loss: 0.0339, val auc: 0.9409, val accuracy: 0.9536, val f1: 0.5973, LR: 0.001
Epoch 5/50, train loss: 0.0359, val loss: 0.0348, val auc: 0.9428, val accuracy: 0.9518, val f1: 0.5501, LR: 0.001
Epoch 6/50, train loss: 0.0348, val loss: 0.0340, val auc: 0.9446, val accuracy: 0.9527, val f1: 0.5497, LR: 0.001
Epoch 7/50, train loss: 0.0332, val loss: 0.0351, val auc: 0.9420, val accuracy: 0.9542, val f1: 0.5803, LR: 0.001
Epoch 8/50, train loss: 0.0324, val loss: 0.0330, val auc: 0.9449, val accuracy: 0.9530, val f1: 0.5405, LR: 0.001
Epoch 9/50, train loss: 0.0318, val loss: 0.0325, val auc: 0.9440, val accuracy:

In [12]:
# Evaluate image only model on test set
y_probs = get_probs(image_only_model_emsplit, image_emsplit_test_loader)
y_preds = (np.array(y_probs) > 0.5).astype(int)
image_only_roc = roc_auc_score(test_y_emsplit.numpy(), y_probs)
image_only_accuracy = accuracy_score(test_y_emsplit.numpy(), y_preds)
image_only_f1 = f1_score(test_y_emsplit.numpy(), y_preds)
print(f'Image Only ROC: {image_only_roc}, Accuracy: {image_only_accuracy}, F1: {image_only_f1}')

threshold = get_optimal_f1_threshold(test_y_emsplit.numpy(), y_probs)
y_preds = (np.array(y_probs) > threshold).astype(int)
image_only_accuracy = accuracy_score(test_y_emsplit.numpy(), y_preds)
image_only_f1 = f1_score(test_y_emsplit.numpy(), y_preds)
print(f'Image Only Accuracy: {image_only_accuracy}, F1: {image_only_f1}')

Image Only ROC: 0.9477691788939887, Accuracy: 0.9640442532267978, F1: 0.6443768996960486
Image Only Accuracy: 0.9683466502765826, F1: 0.7178082191780821


### Image Only Model - Resplit Data

In [13]:
train_idx = brset_split[brset_split['split'] == 'train'].index
val_idx = brset_split[brset_split['split'] == 'val'].index
test_idx = brset_split[brset_split['split'] == 'test'].index

image_train = image_embed[train_idx]
image_val = image_embed[val_idx]
image_test = image_embed[test_idx]

y_train = y[train_idx]
y_val = y[val_idx]
y_test = y[test_idx]

# DataSet
image_train_dataset = SimpleDataset(image_train, y_train)
image_val_dataset = SimpleDataset(image_val, y_val)
image_test_dataset = SimpleDataset(image_test, y_test)

# DataLoader
image_train_loader = DataLoader(image_train_dataset, batch_size=32, shuffle=True)
image_val_loader = DataLoader(image_val_dataset, batch_size=32, shuffle=False)
image_test_loader = DataLoader(image_test_dataset, batch_size=32, shuffle=False)

In [14]:
image_only_model = nn.Sequential(
    nn.Linear(1536, 256),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(256, 1),
    # nn.Sigmoid()
)

In [15]:
criterion = BinaryFocalLoss(alpha=1-p1, gamma=2)
optimizer = optim.Adam(image_only_model.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
history = train(image_only_model, image_train_loader, image_val_loader, criterion, optimizer, num_epochs=50, scheduler=scheduler)

Epoch 1/50, train loss: 0.0574, val loss: 0.0364, val auc: 0.9224, val accuracy: 0.9518, val f1: 0.4678, LR: 0.001
Epoch 2/50, train loss: 0.0430, val loss: 0.0392, val auc: 0.9187, val accuracy: 0.9472, val f1: 0.3485, LR: 0.001
Epoch 3/50, train loss: 0.0384, val loss: 0.0342, val auc: 0.9333, val accuracy: 0.9533, val f1: 0.4967, LR: 0.001
Epoch 4/50, train loss: 0.0378, val loss: 0.0357, val auc: 0.9316, val accuracy: 0.9545, val f1: 0.6318, LR: 0.001
Epoch 5/50, train loss: 0.0375, val loss: 0.0347, val auc: 0.9325, val accuracy: 0.9567, val f1: 0.5220, LR: 0.001
Epoch 6/50, train loss: 0.0337, val loss: 0.0320, val auc: 0.9349, val accuracy: 0.9588, val f1: 0.6510, LR: 0.001
Epoch 7/50, train loss: 0.0336, val loss: 0.0367, val auc: 0.9381, val accuracy: 0.9592, val f1: 0.6581, LR: 0.001
Epoch 8/50, train loss: 0.0323, val loss: 0.0321, val auc: 0.9405, val accuracy: 0.9598, val f1: 0.5677, LR: 0.001
Epoch 9/50, train loss: 0.0322, val loss: 0.0300, val auc: 0.9405, val accuracy:

In [16]:
y_probs = get_probs(image_only_model, image_test_loader)
y_preds = (np.array(y_probs) > 0.5).astype(int)
image_only_roc = roc_auc_score(y[test_idx].numpy(), y_probs)
image_only_accuracy = accuracy_score(y[test_idx].numpy(), y_preds)
image_only_f1 = f1_score(y[test_idx].numpy(), y_preds)
print(f'Image Only ROC: {image_only_roc}, Accuracy: {image_only_accuracy}, F1: {image_only_f1}')

threshold = get_optimal_f1_threshold(y[test_idx].numpy(), y_probs)
y_preds = (np.array(y_probs) > threshold).astype(int)
image_only_accuracy = accuracy_score(y[test_idx].numpy(), y_preds)
image_only_f1 = f1_score(y[test_idx].numpy(), y_preds)
print(f'Image Only Accuracy: {image_only_accuracy}, F1: {image_only_f1}')

Image Only ROC: 0.9571335873519634, Accuracy: 0.9612665232093452, F1: 0.6012658227848101
Image Only Accuracy: 0.9631109744850906, F1: 0.7196261682242989


### Text Only Model

In [17]:
# split text_embed into train and test based on brset_embed['split']
text_train = text_embed[train_idx]
text_val = text_embed[val_idx]
text_test = text_embed[test_idx]


In [18]:
text_train_dataset = SimpleDataset(text_train, y_train)
text_val_dataset = SimpleDataset(text_val, y_val)
text_test_dataset = SimpleDataset(text_test, y_test)

text_train_loader = DataLoader(text_train_dataset, batch_size=32, shuffle=True)
text_val_loader = DataLoader(text_val_dataset, batch_size=32, shuffle=False)
text_test_loader = DataLoader(text_test_dataset, batch_size=32, shuffle=False)

In [19]:
text_only_model = nn.Sequential(
    nn.Linear(4096, 256),
    nn.ReLU(),
    nn.Dropout(0.2),
    # nn.BatchNorm1d(256),
    nn.Linear(256, 1),
    # nn.Sigmoid()
)

In [20]:
# criterion = nn.BCEWithLogitsLoss()
criterion = BinaryFocalLoss(alpha=1-p1, gamma=2)
optimizer = optim.Adam(text_only_model.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
history = train(text_only_model, text_train_loader, text_val_loader, criterion, optimizer, num_epochs=50, scheduler=scheduler)

Epoch 1/50, train loss: 0.0736, val loss: 0.0320, val auc: 0.9431, val accuracy: 0.9527, val f1: 0.4380, LR: 0.001
Epoch 2/50, train loss: 0.0348, val loss: 0.0323, val auc: 0.9517, val accuracy: 0.9619, val f1: 0.7328, LR: 0.001
Epoch 3/50, train loss: 0.0305, val loss: 0.0266, val auc: 0.9576, val accuracy: 0.9699, val f1: 0.7216, LR: 0.001
Epoch 4/50, train loss: 0.0269, val loss: 0.0243, val auc: 0.9544, val accuracy: 0.9773, val f1: 0.8093, LR: 0.001
Epoch 5/50, train loss: 0.0332, val loss: 0.0288, val auc: 0.9563, val accuracy: 0.9668, val f1: 0.6805, LR: 0.001
Epoch 6/50, train loss: 0.0318, val loss: 0.0358, val auc: 0.9591, val accuracy: 0.9647, val f1: 0.6326, LR: 0.001
Epoch 7/50, train loss: 0.0316, val loss: 0.0291, val auc: 0.9585, val accuracy: 0.9659, val f1: 0.6520, LR: 0.001
Epoch 8/50, train loss: 0.0326, val loss: 0.0331, val auc: 0.9596, val accuracy: 0.9628, val f1: 0.6134, LR: 0.001
Epoch 9/50, train loss: 0.0338, val loss: 0.0274, val auc: 0.9602, val accuracy:

In [21]:
# Evaluate text only model on test set
y_probs = get_probs(text_only_model, text_test_loader)
y_preds = (np.array(y_probs) > 0.5).astype(int)
text_only_roc = roc_auc_score(y_test.numpy(), y_probs)
text_only_accuracy = accuracy_score(y_test.numpy(), y_preds)
text_only_f1 = f1_score(y_test.numpy(), y_preds)
print(f'Text Only ROC: {text_only_roc}, Accuracy: {text_only_accuracy}, F1: {text_only_f1}')

threshold = get_optimal_f1_threshold(y_test.numpy(), y_probs)
y_preds = (np.array(y_probs) > threshold).astype(int)
text_only_accuracy = accuracy_score(y_test.numpy(), y_preds)
text_only_f1 = f1_score(y_test.numpy(), y_preds)
print(f'Text Only Accuracy: {text_only_accuracy}, F1: {text_only_f1}')

Text Only ROC: 0.9761734522853989, Accuracy: 0.9342145711650784, F1: 0.0
Text Only Accuracy: 0.9787888103289272, F1: 0.8345323741007193


### Simple Early Fusion Model

In [22]:
simple_early_fusion_model = nn.Sequential(
    nn.Linear(4096+1536, 512),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(512, 1),
    # nn.Sigmoid()
)

combined_train = torch.cat((text_train, image_train), dim=1)
combined_val = torch.cat((text_val, image_val), dim=1)
combined_test = torch.cat((text_test, image_test), dim=1)

combined_train_dataset = SimpleDataset(combined_train, y_train)
combined_val_dataset = SimpleDataset(combined_val, y_val)
combined_test_dataset = SimpleDataset(combined_test, y_test)

combined_train_loader = DataLoader(combined_train_dataset, batch_size=32, shuffle=True)
combined_val_loader = DataLoader(combined_val_dataset, batch_size=32, shuffle=False)
combined_test_loader = DataLoader(combined_test_dataset, batch_size=32, shuffle=False)

In [23]:
criterion = BinaryFocalLoss(alpha=1-p1, gamma=2)
optimizer = optim.Adam(simple_early_fusion_model.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
history = train(simple_early_fusion_model, combined_train_loader, combined_val_loader, criterion, optimizer, num_epochs=50, scheduler=scheduler)

Epoch 1/50, train loss: 0.1764, val loss: 0.0254, val auc: 0.9623, val accuracy: 0.9662, val f1: 0.6584, LR: 0.001
Epoch 2/50, train loss: 0.0275, val loss: 0.0268, val auc: 0.9644, val accuracy: 0.9717, val f1: 0.7745, LR: 0.001
Epoch 3/50, train loss: 0.0257, val loss: 0.0259, val auc: 0.9679, val accuracy: 0.9616, val f1: 0.5847, LR: 0.001
Epoch 4/50, train loss: 0.0270, val loss: 0.0237, val auc: 0.9671, val accuracy: 0.9681, val f1: 0.6848, LR: 0.001
Epoch 5/50, train loss: 0.0284, val loss: 0.0298, val auc: 0.9673, val accuracy: 0.9647, val f1: 0.6302, LR: 0.001
Epoch 6/50, train loss: 0.0250, val loss: 0.0258, val auc: 0.9688, val accuracy: 0.9687, val f1: 0.6871, LR: 0.001
Epoch 7/50, train loss: 0.0244, val loss: 0.0251, val auc: 0.9648, val accuracy: 0.9671, val f1: 0.6667, LR: 0.001
Epoch 8/50, train loss: 0.0221, val loss: 0.0271, val auc: 0.9683, val accuracy: 0.9622, val f1: 0.7545, LR: 0.001
Epoch 9/50, train loss: 0.0260, val loss: 0.0191, val auc: 0.9705, val accuracy:

In [24]:
y_probs = get_probs(simple_early_fusion_model, combined_test_loader)
y_preds = (np.array(y_probs) > 0.5).astype(int)
early_fusion_roc = roc_auc_score(y_test.numpy(), y_probs)
early_fusion_accuracy = accuracy_score(y_test.numpy(), y_preds)
early_fusion_f1 = f1_score(y_test.numpy(), y_preds)
print(f'Early Fusion ROC: {early_fusion_roc}, Accuracy: {early_fusion_accuracy}, F1: {early_fusion_f1}')

threshold = get_optimal_f1_threshold(y_test.numpy(), y_probs)
y_preds = (np.array(y_probs) > threshold).astype(int)
early_fusion_accuracy = accuracy_score(y_test.numpy(), y_preds)
early_fusion_f1 = f1_score(y_test.numpy(), y_preds)
print(f'Early Fusion Accuracy: {early_fusion_accuracy}, F1: {early_fusion_f1}')

Early Fusion ROC: 0.9857199090945434, Accuracy: 0.9342145711650784, F1: 0.0
Early Fusion Accuracy: 0.9797110359667999, F1: 0.8382352941176471


### Simple Late Fusion Model

In [25]:
# Define separate modules for text and image processing
class TextModule(nn.Module):
    def __init__(self):
        super(TextModule, self).__init__()
        self.fc1 = nn.Linear(4096, 256)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return x

class ImageModule(nn.Module):
    def __init__(self):
        super(ImageModule, self).__init__()
        self.fc1 = nn.Linear(1536, 256)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return x

class SimpleLateFusionModel(nn.Module):
    def __init__(self):
        super(SimpleLateFusionModel, self).__init__()
        self.text_module = TextModule()
        self.image_module = ImageModule()
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, 1)
        # self.output = nn.Sigmoid()

    def forward(self, combined_data):
        text_data = combined_data[:, :4096]
        image_data = combined_data[:, 4096:]
        text_features = self.text_module(text_data)
        image_features = self.image_module(image_data)
        combined_features = torch.cat((text_features, image_features), dim=1)
        x = self.fc1(combined_features)
        x = torch.relu(x)
        x = self.fc2(x)
        # x = self.output(x)
        return x

late_fusion_model = SimpleLateFusionModel()

In [26]:
criterion = BinaryFocalLoss(alpha=1-p1, gamma=2)
optimizer = optim.Adam(late_fusion_model.parameters(), lr=0.0001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
history = train(late_fusion_model, combined_train_loader, combined_val_loader, criterion, optimizer, num_epochs=50, scheduler=scheduler)

Epoch 1/50, train loss: 0.0402, val loss: 0.0275, val auc: 0.9575, val accuracy: 0.9631, val f1: 0.6296, LR: 0.0001
Epoch 2/50, train loss: 0.0258, val loss: 0.0225, val auc: 0.9648, val accuracy: 0.9717, val f1: 0.7444, LR: 0.0001
Epoch 3/50, train loss: 0.0215, val loss: 0.0358, val auc: 0.9677, val accuracy: 0.9601, val f1: 0.5578, LR: 0.0001
Epoch 4/50, train loss: 0.0204, val loss: 0.0192, val auc: 0.9704, val accuracy: 0.9773, val f1: 0.8032, LR: 0.0001
Epoch 5/50, train loss: 0.0182, val loss: 0.0191, val auc: 0.9705, val accuracy: 0.9782, val f1: 0.8076, LR: 0.0001
Epoch 6/50, train loss: 0.0163, val loss: 0.0184, val auc: 0.9709, val accuracy: 0.9794, val f1: 0.8184, LR: 0.0001
Epoch 7/50, train loss: 0.0181, val loss: 0.0173, val auc: 0.9722, val accuracy: 0.9807, val f1: 0.8372, LR: 0.0001
Epoch 8/50, train loss: 0.0157, val loss: 0.0196, val auc: 0.9708, val accuracy: 0.9782, val f1: 0.8337, LR: 0.0001
Epoch 9/50, train loss: 0.0159, val loss: 0.0218, val auc: 0.9723, val a

In [27]:
y_probs = get_probs(late_fusion_model, combined_test_loader)
y_preds = (np.array(y_probs) > 0.5).astype(int)
late_fusion_roc = roc_auc_score(y_test.numpy(), y_probs)
late_fusion_accuracy = accuracy_score(y_test.numpy(), y_preds)
late_fusion_f1 = f1_score(y_test.numpy(), y_preds)
print(f'Late Fusion ROC: {late_fusion_roc}, Accuracy: {late_fusion_accuracy}, F1: {late_fusion_f1}')

threshold = get_optimal_f1_threshold(y_test.numpy(), y_probs)
y_preds = (np.array(y_probs) > threshold).astype(int)
late_fusion_accuracy = accuracy_score(y_test.numpy(), y_preds)
late_fusion_f1 = f1_score(y_test.numpy(), y_preds)
print(f'Late Fusion Accuracy: {late_fusion_accuracy}, F1: {late_fusion_f1}')

Late Fusion ROC: 0.9867116888548557, Accuracy: 0.984014755610206, F1: 0.8645833333333333
Late Fusion Accuracy: 0.9864740239778665, F1: 0.8916256157635468


### Attention Model

In [28]:
class ImageAttentionModule(nn.Module):
    def __init__(self):
        super(ImageAttentionModule, self).__init__()
        self.fc1 = nn.Linear(1536, 256)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        self.attention = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.dropout(x)

        # Reshape x from [batch_size, 256] to [batch_size, 16, 16] for attention
        x = x.view(-1, 16, 16)

        # Apply attention
        attn_output, _ = self.attention(x, x, x)

        # Flatten the output for the final fully connected layer
        x = attn_output.reshape(-1, 256)  # Reshape back to original shape after attention
        return x

class TextAttentionModule(nn.Module):
    def __init__(self):
        super(TextAttentionModule, self).__init__()
        self.fc1 = nn.Linear(4096, 256)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        self.attention = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)

        x = x.view(-1, 16, 16)

        # Apply attention
        attn_output, _ = self.attention(x, x, x)

        # Flatten the output for the final fully connected layer
        x = attn_output.reshape(-1, 256)  # Reshape back to original shape after attention
        return x

class AttentionFusionModel(nn.Module):
    def __init__(self):
        super(AttentionFusionModel, self).__init__()
        self.text_attention = TextAttentionModule()
        self.image_attention = ImageAttentionModule()
        self.cross_attention = nn.MultiheadAttention(embed_dim=256, num_heads=4, batch_first=True)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, 1)
        # self.output = nn.Sigmoid()

    def forward(self, combined_data):
        text_data = combined_data[:, :4096]
        image_data = combined_data[:, 4096:]
        text_output = self.text_attention(text_data)
        image_output = self.image_attention(image_data)
        # combined_features = torch.cat((text_output, image_output), dim=1)
        combined_features, _ = self.cross_attention(text_output.unsqueeze(1), image_output.unsqueeze(1), image_output.unsqueeze(1))
        combined_features = combined_features.squeeze(1)
        x = self.fc1(combined_features)

        x = torch.relu(x)
 
        x = self.fc2(x)
        # x = self.output(x)
        return x
    
attention_fusion_model = AttentionFusionModel()

In [29]:
criterion = BinaryFocalLoss(alpha=1-p1, gamma=2)
optimizer = optim.Adam(attention_fusion_model.parameters(), lr=0.0001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
history = train(attention_fusion_model, combined_train_loader, combined_val_loader, criterion, optimizer, num_epochs=50, scheduler=scheduler)

Epoch 1/50, train loss: 0.0604, val loss: 0.0451, val auc: 0.8778, val accuracy: 0.9404, val f1: 0.1849, LR: 0.0001
Epoch 2/50, train loss: 0.0431, val loss: 0.0388, val auc: 0.9155, val accuracy: 0.9496, val f1: 0.4533, LR: 0.0001
Epoch 3/50, train loss: 0.0383, val loss: 0.0354, val auc: 0.9235, val accuracy: 0.9561, val f1: 0.5600, LR: 0.0001
Epoch 4/50, train loss: 0.0352, val loss: 0.0363, val auc: 0.9283, val accuracy: 0.9515, val f1: 0.4397, LR: 0.0001
Epoch 5/50, train loss: 0.0352, val loss: 0.0410, val auc: 0.9310, val accuracy: 0.9518, val f1: 0.4291, LR: 0.0001
Epoch 6/50, train loss: 0.0344, val loss: 0.0348, val auc: 0.9347, val accuracy: 0.9561, val f1: 0.6416, LR: 0.0001
Epoch 7/50, train loss: 0.0321, val loss: 0.0337, val auc: 0.9394, val accuracy: 0.9570, val f1: 0.5238, LR: 0.0001
Epoch 8/50, train loss: 0.0319, val loss: 0.0306, val auc: 0.9391, val accuracy: 0.9628, val f1: 0.6721, LR: 0.0001
Epoch 9/50, train loss: 0.0310, val loss: 0.0306, val auc: 0.9420, val a

In [30]:
y_probs = get_probs(attention_fusion_model, combined_test_loader)
y_preds = (np.array(y_probs) > 0.5).astype(int)
attention_fusion_roc = roc_auc_score(y_test.numpy(), y_probs)
attention_fusion_accuracy = accuracy_score(y_test.numpy(), y_preds)
attention_fusion_f1 = f1_score(y_test.numpy(), y_preds)
print(f'Attention Fusion ROC: {attention_fusion_roc}, Accuracy: {attention_fusion_accuracy}, F1: {attention_fusion_f1}')

threshold = get_optimal_f1_threshold(y_test.numpy(), y_probs)
y_preds = (np.array(y_probs) > threshold).astype(int)
attention_fusion_accuracy = accuracy_score(y_test.numpy(), y_preds)
attention_fusion_f1 = f1_score(y_test.numpy(), y_preds)
print(f'Attention Fusion Accuracy: {attention_fusion_accuracy}, F1: {attention_fusion_f1}')

Attention Fusion ROC: 0.9535139756375838, Accuracy: 0.9624961573931755, F1: 0.6369047619047619
Attention Fusion Accuracy: 0.9674146941284968, F1: 0.7282051282051282
