In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Batch
from torch_geometric.utils import to_dense_batch
from transformers import AutoModel

from src.dataset.astroturf_dataset import AstroturfCampaignMultiModalDataset
from src.modules.graph_encoder import UPFDGraphSageNet

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class CrossModelAttentionBlock(nn.Module):
    def __init__(self, embedding_dim: int, num_heads: int, feed_forward_dim: int):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.mha = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=num_heads, batch_first=True)
        self.ff_1 = nn.Linear(in_features=embedding_dim, out_features=feed_forward_dim)
        self.ff_2 = nn.Linear(in_features=feed_forward_dim, out_features=embedding_dim)
        self.norm = nn.LayerNorm(embedding_dim)

    def forward(self, embedding_a, embedding_b):
        mha_out, _ = self.mha(embedding_a, embedding_b, embedding_b)
        out = F.relu(self.ff_1(mha_out))
        out = self.ff_2(out)
        final_out = self.norm(out + mha_out)
        return final_out


class MultiModalModelForClassification(nn.Module):
    def __init__(self,
                 text_encoder: nn.Module,
                 graph_encoder: nn.Module,
                 vision_encoder: nn.Module,
                 self_attention_heads: int,
                 embedding_dim: int,
                 num_cross_modal_attention_blocks: int,
                 num_cross_modal_attention_heads: int,
                 self_attn_ff_dim: int,
                 num_cross_modal_attention_ff_dim: int,
                 output_channels: int):
        super().__init__()

        # Use the provided encoders and freeze them for PEFT.
        self.text_encoder = text_encoder
        for param in self.text_encoder.parameters():
            param.requires_grad = False

        self.graph_encoder = graph_encoder
        for param in self.graph_encoder.parameters():
            param.requires_grad = False

        self.vision_encoder = vision_encoder
        for param in self.vision_encoder.parameters():
            param.requires_grad = False

        # Assuming the text encoder has a config with hidden_size.
        self.text_embedding_size = self.text_encoder.config.hidden_size

        # Assuming the vision encoder (Vision transformer) has a config with hidden_size.
        self.vision_embedding_size = self.vision_encoder.config.hidden_size

        self.embedding_dim = embedding_dim

        ############ PROJECTION ############
        self.text_projection = nn.Linear(in_features=self.text_embedding_size, out_features=embedding_dim)
        # Adjust the in_features for the graph projection if needed.
        self.graph_projection = nn.Linear(in_features=graph_encoder.hidden_channels, out_features=embedding_dim)

        self.vision_projection = nn.Linear(in_features=self.vision_embedding_size, out_features=embedding_dim)

        ############ SELF ATTENTION ############
        self.text_self_attention = nn.MultiheadAttention(embed_dim=embedding_dim,
                                                         num_heads=self_attention_heads,
                                                         batch_first=True)
        self.graph_self_attention = nn.MultiheadAttention(embed_dim=embedding_dim,
                                                          num_heads=self_attention_heads,
                                                          batch_first=True)
        self.vision_self_attention = nn.MultiheadAttention(embed_dim=embedding_dim,
                                                           num_heads=self_attention_heads,
                                                           batch_first=True)
        self.text_self_attention_norm = nn.LayerNorm(embedding_dim)
        self.graph_self_attention_norm = nn.LayerNorm(embedding_dim)
        self.vision_self_attention_norm = nn.LayerNorm(embedding_dim)
        self.text_self_attention_ff1 = nn.Linear(in_features=embedding_dim, out_features=self_attn_ff_dim)
        self.text_self_attention_ff2 = nn.Linear(in_features=self_attn_ff_dim, out_features=embedding_dim)

        self.graph_self_attention_ff1 = nn.Linear(in_features=embedding_dim, out_features=self_attn_ff_dim)
        self.graph_self_attention_ff2 = nn.Linear(in_features=self_attn_ff_dim, out_features=embedding_dim)

        self.vision_self_attention_ff1 = nn.Linear(in_features=embedding_dim, out_features=self_attn_ff_dim)
        self.vision_self_attention_ff2 = nn.Linear(in_features=self_attn_ff_dim, out_features=embedding_dim)

        self.text_self_attention_ff_norm = nn.LayerNorm(embedding_dim)
        self.graph_self_attention_ff_norm = nn.LayerNorm(embedding_dim)
        self.vision_self_attention_ff_norm = nn.LayerNorm(embedding_dim)

        ############ CROSS MODAL ATTENTION ############
        self.cross_modal_attention_blocks = nn.ModuleList([
            CrossModelAttentionBlock(embedding_dim=embedding_dim,
                                     num_heads=num_cross_modal_attention_heads,
                                     feed_forward_dim=num_cross_modal_attention_ff_dim)
            for _ in range(num_cross_modal_attention_blocks)
        ])

        ############ OUTPUT LAYER ############

        # Gated Fusion
        self.gate_fc = nn.Linear(embedding_dim * 3, 3)
        self.post_fusion_norm = nn.LayerNorm(embedding_dim)
        self.classifier = nn.Linear(embedding_dim, output_channels)

    def forward(self, text_input_ids, text_attention_mask, graph_data, pixel_values):
        text_embedding = self.text_encoder(input_ids=text_input_ids, attention_mask=text_attention_mask)[0]
        _, node_embeddings, _ = self.graph_encoder(graph_data.x, graph_data.edge_index, graph_data.batch)
        dense_graph_embeddings, mask = to_dense_batch(node_embeddings, graph_data.batch)
        vision_embedding = self.vision_encoder(pixel_values=pixel_values).last_hidden_state

        ############ PROJECTION ############
        projected_text_embedding = self.text_projection(text_embedding)
        projected_graph_embedding = self.graph_projection(dense_graph_embeddings)
        projected_vision_embedding = self.vision_projection(vision_embedding)

        ############ SELF ATTENTION ############
        text_self_attn_out, _ = self.text_self_attention(projected_text_embedding,
                                                         projected_text_embedding,
                                                         projected_text_embedding)
        graph_self_attn_out, _ = self.graph_self_attention(projected_graph_embedding,
                                                           projected_graph_embedding,
                                                           projected_graph_embedding,
                                                           key_padding_mask=~mask)
        vision_self_attn_out, _ = self.vision_self_attention(projected_vision_embedding,
                                                             projected_vision_embedding,
                                                             projected_vision_embedding)
        text_self_attn_out = self.text_self_attention_norm(text_self_attn_out + projected_text_embedding)
        graph_self_attn_out = self.graph_self_attention_norm(graph_self_attn_out + projected_graph_embedding)
        vision_self_attn_out = self.vision_self_attention_norm(vision_self_attn_out + projected_vision_embedding)

        text_ff_out = F.relu(self.text_self_attention_ff1(text_self_attn_out))
        graph_ff_out = F.relu(self.graph_self_attention_ff1(graph_self_attn_out))
        vision_ff_out = F.relu(self.vision_self_attention_ff1(vision_self_attn_out))
        text_ff_out = self.text_self_attention_ff2(text_ff_out)
        graph_ff_out = self.graph_self_attention_ff2(graph_ff_out)
        vision_ff_out = self.vision_self_attention_ff2(vision_ff_out)
        text_ff_out = self.text_self_attention_ff_norm(text_self_attn_out + text_ff_out)
        graph_ff_out = self.graph_self_attention_ff_norm(graph_self_attn_out + graph_ff_out)
        vision_ff_out = self.vision_self_attention_ff_norm(vision_self_attn_out + vision_ff_out)

        ############ CROSS MODAL ATTENTION ############
        projected_text_embedding, projected_graph_embedding, projected_vision_embedding = text_ff_out, graph_ff_out, vision_ff_out
        for block in self.cross_modal_attention_blocks:
            projected_text_embedding_new = block(projected_text_embedding, projected_graph_embedding)
            projected_graph_embedding_new = block(projected_graph_embedding, projected_text_embedding)
            projected_text_embedding_new = self.text_self_attention_ff_norm(
                projected_text_embedding + projected_text_embedding_new)
            projected_text_embedding, projected_graph_embedding, projected_vision_embedding = (
                projected_text_embedding_new,
                projected_graph_embedding_new,
                projected_vision_embedding)

        ############ OUTPUT LAYER ############
        global_text_embedding = torch.mean(projected_text_embedding, dim=1)
        global_graph_embedding = torch.mean(projected_graph_embedding, dim=1)
        global_vision_embedding = torch.mean(projected_vision_embedding, dim=1)
        gated_out = self.gate_fc(
            torch.cat((global_text_embedding, global_graph_embedding, global_vision_embedding), dim=-1))
        gates = F.softmax(gated_out, dim=-1)
        alpha, beta, gamma = gates[:, 0:1], gates[:, 1:2], gates[:, 2:3]
        fused_embedding = (alpha * global_text_embedding) + (beta * global_graph_embedding) + (
                gamma * global_vision_embedding)
        fused_embedding = self.post_fusion_norm(fused_embedding)

        logits = self.classifier(fused_embedding)
        return logits

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "mps")

In [4]:
astrorag_dataset = AstroturfCampaignMultiModalDataset(
    json_dir='/Users/navneet/git/research/swarm-guard/dataset1/train/graphs',
    image_dir='/Users/navneet/git/research/swarm-guard/dataset1/train/images',
    text_model_id='answerdotai/ModernBERT-base',
    vision_model_id='google/vit-base-patch16-224',
)

In [5]:
# Specify the model name for text tokenization
text_model_name = "answerdotai/ModernBERT-base"

# Instantiate the dataset
# dataset = DummyMultiModalDataset(num_samples=10, text_model_name=text_model_name)

# Get one sample from the dataset
sample, label = astrorag_dataset[0]

print("Text Input IDs:", sample['text_input_ids'])
print("Text Attention Mask:", sample['text_attention_mask'])
print("Graph Data - x shape:", sample['graph_data'].x.shape)
print("Graph Data - edge_index shape:", sample['graph_data'].edge_index.shape)
# print("Graph Data - batch shape:", sample['graph_data'].batch.shape)
print("Graph Data - y shape:", sample['graph_data'].y.shape)
print("Vision Pixel Values:", sample['pixel_values'])

Text Input IDs: tensor([50281, 31600,   212,   171,   118,   226, 33104,   219, 35458,   123,
          168,   216,   224,   167,   228,   227, 49264, 39907,   169,   226,
          110,   168,    99,   215, 26532,   227, 26532,   239,  7775,   222,
           99,   168,  9223, 16857,   113,   115, 26532,   239, 31600,   212,
          171,   118,   226,   187,   187,   158,   239,   231, 32115,   216,
        36178,   219,   169,   226,   122,    27,  5987,  1358,    85,    15,
         1940,    16,    41,    87, 22351,    45, 13511,    19,    41,    79,
          187,   158,   239,   231, 10608,   107,   118, 34123,   168,  9223,
        15074,   244,    27,   288,  2140,    18,   296,    15,  2913,    33,
         5987,  1358,    85,    15,  1940,    16, 20723,    25,   304, 18933,
         2042,    69,   187,   187, 14931,   225,   219, 28774,   107, 34817,
        35061,  7775,   217,   223,   167,   242,   220, 10608,   213,   103,
        33104,   215, 28774,   219, 44633,   122

In [7]:
sample['pixel_values'].shape

torch.Size([3, 224, 224])

In [8]:
## {"in_channels": 9, "hidden_channels": 64, "num_classes": 2, "dropout": 0.2954021195697293, "lr": 0.0015804240267104938, "weight_decay": 7.64927591337679e-06, "batch_size": 128, "epochs": 200, "focal_alpha": 0.3081135417724518, "focal_gamma": 1.3936990483465523}
def load_pre_trained_graph_encoder(model_path: str, device: str = "cpu") -> UPFDGraphSageNet:
    model_file = torch.load(model_path)
    state_dict = model_file['model_state_dict']
    config = model_file['config']
    model = UPFDGraphSageNet(
        in_channels=config['in_channels'],
        hidden_channels=config['hidden_channels'],
        num_classes=config['num_classes'],
    )
    model.load_state_dict(state_dict)
    model = model.to(device)
    print(f'MOdel loaded with hidden channels: {model.hidden_channels}')
    return model


load_pre_trained_graph_encoder(
    model_path='/Users/navneet/git/research/swarm-guard/models/graph/graph_encoder.pth')

MOdel loaded with hidden channels: 64


UPFDGraphSageNet(
  (conv1): SAGEConv(9, 64, aggr=mean)
  (norm1): LayerNorm(64, affine=True, mode=graph)
  (conv2): SAGEConv(64, 64, aggr=mean)
  (norm2): LayerNorm(64, affine=True, mode=graph)
  (conv3): SAGEConv(64, 64, aggr=mean)
  (norm3): LayerNorm(64, affine=True, mode=graph)
  (classifier): Linear(in_features=64, out_features=2, bias=True)
)

In [9]:
# Number of parameters in the model in millions
num_params = sum(p.numel() for p in load_pre_trained_graph_encoder(
    model_path='/Users/navneet/git/research/swarm-guard/models/graph/graph_encoder.pth').parameters())
print(f"Number of parameters in the model: {num_params / 1e6:.2f}M")

MOdel loaded with hidden channels: 64
Number of parameters in the model: 0.02M


In [10]:
text_encoder = AutoModel.from_pretrained('answerdotai/ModernBERT-base').to(device)
graph_encoder = load_pre_trained_graph_encoder(
    model_path='/Users/navneet/git/research/swarm-guard/models/graph/graph_encoder.pth',
    device=device
)
vision_encoder = AutoModel.from_pretrained('google/vit-base-patch16-224').to(device)
model = MultiModalModelForClassification(
    text_encoder=text_encoder,
    graph_encoder=graph_encoder,
    vision_encoder=vision_encoder,
    self_attention_heads=8,
    embedding_dim=256,
    num_cross_modal_attention_blocks=6,
    num_cross_modal_attention_heads=8,
    self_attn_ff_dim=512,
    num_cross_modal_attention_ff_dim=512,
    output_channels=2
).to(device)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


MOdel loaded with hidden channels: 64


In [11]:
# Print the number of parameters in the model
num_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters in the model: {num_params / 1e6:.2f}M")

Number of parameters in the model: 240.58M


In [13]:
# forward pass with the sample
sample, y = astrorag_dataset[0]
text_input_ids = sample['text_input_ids'].unsqueeze(0)
text_attention_mask = sample['text_attention_mask'].unsqueeze(0)
graph_data = sample['graph_data']
#Move the data to the same device as the model
text_input_ids = text_input_ids.to(device)
text_attention_mask = text_attention_mask.to(device)
graph_data.x = graph_data.x.to(device)
graph_data.edge_index = graph_data.edge_index.to(device)
pixel_values = sample['pixel_values'].unsqueeze(0).to(device)
# graph_data.batch = graph_data.batch.to(device)
# Perform a forward pass
model(text_input_ids, text_attention_mask, graph_data, pixel_values)

tensor([[-0.2046,  0.2821]], device='mps:0', grad_fn=<LinearBackward0>)

In [14]:
# print the number of parameters in the model in millions
num_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters in the model: {num_params / 1e6:.2f}M")

# print the number of parameters in the model in millions excluding the text encoder and graph encoder
num_params_excluding_encoders = sum(p.numel() for name, p in model.named_parameters() if
                                    'text_encoder' not in name and 'graph_encoder' not in name)
print(f"Number of parameters in the model excluding encoders: {num_params_excluding_encoders / 1e6:.2f}M")

Number of parameters in the model: 240.58M
Number of parameters in the model excluding encoders: 91.54M


In [15]:
def multimodal_collate_fn(batch):
    # Unzip batch elements into data dictionaries and corresponding labels
    data_dicts, labels = zip(*batch)

    text_input_ids = torch.stack([d['text_input_ids'] for d in data_dicts], dim=0)
    text_attention_mask = torch.stack([d['text_attention_mask'] for d in data_dicts], dim=0)

    pixel_values = torch.stack([d['pixel_values'] for d in data_dicts], dim=0)

    # Create a batched graph using GeoBatch.from_data_list
    graph_data = Batch.from_data_list([d['graph_data'] for d in data_dicts])

    # Convert labels tuple (of ints) into a tensor.
    labels = torch.tensor(labels)

    return {
        'text_input_ids': text_input_ids,
        'text_attention_mask': text_attention_mask,
        'graph_data': graph_data,
        'pixel_values': pixel_values,
        'labels': labels
    }


crieterion = torch.nn.CrossEntropyLoss()
data_loader = torch.utils.data.DataLoader(astrorag_dataset, batch_size=2, shuffle=True,
                                          collate_fn=multimodal_collate_fn)
# iterate through the data loader taking both the features and labels
for batch in data_loader:
    text_input_ids = batch['text_input_ids']
    text_attention_mask = batch['text_attention_mask']
    graph_data = batch['graph_data']
    labels = batch['labels']
    pixel_values = batch['pixel_values']

    # Move the data to the same device as the model
    text_input_ids = text_input_ids.to(device)
    text_attention_mask = text_attention_mask.to(device)
    graph_data.x = graph_data.x.to(device)
    graph_data.edge_index = graph_data.edge_index.to(device)
    graph_data.batch = graph_data.batch.to(device)
    pixel_values = pixel_values.to(device)
    labels = labels.to(device)
    print(f"Pixel Values Shape: {pixel_values.shape}")
    print(f"Text Input IDs Shape: {text_input_ids.shape}")
    output = model(text_input_ids, text_attention_mask, graph_data, pixel_values)
    loss = crieterion(output, labels)

    print("Output shape:", output.shape)
    print("Labels shape:", labels.shape)
    print("Loss:", loss.item())
    break

Pixel Values Shape: torch.Size([2, 3, 224, 224])
Text Input IDs Shape: torch.Size([2, 280])
Output shape: torch.Size([2, 2])
Labels shape: torch.Size([2])
Loss: 0.5239043831825256
