In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch_geometric.data import Batch as GeoBatch
from torch_geometric.data import Data
from torch_geometric.utils import to_dense_batch
from transformers import AutoTokenizer, AutoModel

from src.dataset.astrorag_dataset import AstroturfCampaignMultiModalDataset
from src.modules.graph_encoder import UPFDGraphSageNet

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "mps")

In [3]:
astrorag_dataset = AstroturfCampaignMultiModalDataset(
    json_dir='/Users/navneet/git/research/brag-fake-news-campaigns/dataset1/train',
    model_id='answerdotai/ModernBERT-base')

In [104]:
class DummyMultiModalDataset(Dataset):
    def __init__(self,
                 num_samples: int = 100,
                 text_length: int = 20,
                 text_model_name: str = "bert-base-uncased"):
        """
        Args:
            num_samples (int): Number of samples in the dataset.
            text_length (int): Maximum token length for the text input.
            text_model_name (str): Hugging Face model name used to initialize the tokenizer.
        """
        self.num_samples = num_samples
        self.text_length = text_length

        # Initialize the tokenizer from Hugging Face
        self.tokenizer = AutoTokenizer.from_pretrained(text_model_name)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx: int):
        # -----------
        # Graph Data
        # -----------
        # Create node features: 72 nodes each with 10 features, shape [72, 10]
        x = torch.randn(72, 10, dtype=torch.float)

        # Create edge_index: a tensor with shape [2, 71] representing 71 random edges
        edge_index = torch.randint(0, 72, (2, 71), dtype=torch.long)

        # Create a dummy label tensor (for binary classification) with shape [1]
        y = torch.randint(0, 2, (1,), dtype=torch.long)

        # Create a batch tensor: for a single graph, all nodes have the same batch index, here 0.
        batch = torch.zeros(72, dtype=torch.long)

        # Create a PyTorch Geometric Data object for the graph
        graph_data = Data(x=x, edge_index=edge_index, y=y, batch=batch)

        # -----------
        # Text Data
        # -----------
        # Create a dummy text sample (with index for variability)
        dummy_text = f"This is a dummy sentence number {idx} for testing multimodal input."

        # Tokenize using the specified tokenizer, padding/truncating to self.text_length
        tokenized = self.tokenizer(dummy_text,
                                   max_length=self.text_length,
                                   padding='max_length',
                                   truncation=True,
                                   return_tensors='pt')

        # Remove the extra batch dimension from the tokenized outputs; final shape is [text_length]
        text_input_ids = tokenized['input_ids'].squeeze(0)
        text_attention_mask = tokenized['attention_mask'].squeeze(0)

        # Return the data and label (convert label to a scalar using .item())
        return {
            'text_input_ids': text_input_ids,  # Shape: [text_length]
            'text_attention_mask': text_attention_mask,  # Shape: [text_length]
            'graph_data': graph_data  # Graph Data object with x, edge_index, y, batch
        }, y.item()  # Return the graph label as a scalar integer


In [4]:
# Specify the model name for text tokenization
text_model_name = "answerdotai/ModernBERT-base"

# Instantiate the dataset
# dataset = DummyMultiModalDataset(num_samples=10, text_model_name=text_model_name)

# Get one sample from the dataset
sample, label = astrorag_dataset[0]

print("Text Input IDs:", sample['text_input_ids'])
print("Text Attention Mask:", sample['text_attention_mask'])
print("Graph Data - x shape:", sample['graph_data'].x.shape)
print("Graph Data - edge_index shape:", sample['graph_data'].edge_index.shape)
# print("Graph Data - batch shape:", sample['graph_data'].batch.shape)
print("Graph Data - y shape:", sample['graph_data'].y.shape)

Text Input IDs: tensor([50281, 12442,   267, 17680,  3551,   281,  1214, 39596, 20671,  1996,
          436,   807,   432,  1214, 25989,   387, 23556,  9151,  2418,   273,
          253,  6398,  5987,  1358,    85,    15,  1940,    16,    52,  2598,
        18933,    44,    18,    54,    58,    53, 50282, 50283, 50283, 50283,
        50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
        50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
        50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
        50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
        50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
        50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
        50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
        50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
        50283, 50283, 50283, 50283, 50283, 50283

In [5]:
## {"in_channels": 9, "hidden_channels": 64, "num_classes": 2, "dropout": 0.2954021195697293, "lr": 0.0015804240267104938, "weight_decay": 7.64927591337679e-06, "batch_size": 128, "epochs": 200, "focal_alpha": 0.3081135417724518, "focal_gamma": 1.3936990483465523}
def load_pre_trained_graph_encoder(model_path: str, device: str = "cpu") -> UPFDGraphSageNet:
    model_file = torch.load(model_path)
    state_dict = model_file['model_state_dict']
    config = model_file['config']
    model = UPFDGraphSageNet(
        in_channels=config['in_channels'],
        hidden_channels=config['hidden_channels'],
        num_classes=config['num_classes'],
    )
    model.load_state_dict(state_dict)
    model = model.to(device)
    print(f'MOdel loaded with hidden channels: {model.hidden_channels}')
    return model


load_pre_trained_graph_encoder(
    model_path='/Users/navneet/git/research/swarm-guard/models/graph/graph_encoder.pth')

MOdel loaded with hidden channels: 64


UPFDGraphSageNet(
  (conv1): SAGEConv(9, 64, aggr=mean)
  (norm1): LayerNorm(64, affine=True, mode=graph)
  (conv2): SAGEConv(64, 64, aggr=mean)
  (norm2): LayerNorm(64, affine=True, mode=graph)
  (conv3): SAGEConv(64, 64, aggr=mean)
  (norm3): LayerNorm(64, affine=True, mode=graph)
  (classifier): Linear(in_features=64, out_features=2, bias=True)
)

In [28]:
# Number of parameters in the model in millions
num_params = sum(p.numel() for p in load_pre_trained_graph_encoder(
    model_path='/Users/navneet/git/research/swarm-guard/models/graph/graph_encoder.pth').parameters())
print(f"Number of parameters in the model: {num_params / 1e6:.2f}M")

Number of parameters in the model: 0.02M


In [6]:
class CrossModelAttentionBlock(nn.Module):
    def __init__(self, embedding_dim: int, num_heads: int, feed_forward_dim: int):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.mha_text_graph = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=num_heads, batch_first=True)
        self.mha_graph_text = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=num_heads, batch_first=True)

        self.ff_text_graph_1 = nn.Linear(in_features=embedding_dim, out_features=feed_forward_dim)
        self.ff_graph_text_1 = nn.Linear(in_features=embedding_dim, out_features=feed_forward_dim)

        self.ff_text_graph_2 = nn.Linear(in_features=feed_forward_dim, out_features=embedding_dim)
        self.ff_graph_text_2 = nn.Linear(in_features=feed_forward_dim, out_features=embedding_dim)

        self.text_graph_norm = nn.LayerNorm(embedding_dim)
        self.graph_text_norm = nn.LayerNorm(embedding_dim)

    def forward(self, text_embedding, graph_embedding):
        # Cross-attention: text queries attend to graph keys/values, and vice versa.
        mha_text_graph_out, _ = self.mha_text_graph(text_embedding, graph_embedding, graph_embedding)
        mha_graph_text_out, _ = self.mha_graph_text(graph_embedding, text_embedding, text_embedding)

        text_graph_out = F.relu(self.ff_text_graph_1(mha_text_graph_out))
        graph_text_out = F.relu(self.ff_graph_text_1(mha_graph_text_out))

        text_graph_out = self.ff_text_graph_2(text_graph_out)
        graph_text_out = self.ff_graph_text_2(graph_text_out)

        text_out = self.text_graph_norm(text_graph_out + mha_text_graph_out)
        graph_out = self.graph_text_norm(graph_text_out + mha_graph_text_out)

        return text_out, graph_out


class MultiModalModelForClassification(nn.Module):
    def __init__(self,
                 text_encoder: nn.Module,
                 graph_encoder: nn.Module,
                 self_attention_heads: int,
                 embedding_dim: int,
                 num_cross_modal_attention_blocks: int,
                 num_cross_modal_attention_heads: int,
                 self_attn_ff_dim: int,
                 num_cross_modal_attention_ff_dim: int,
                 output_channels: int):
        super().__init__()

        # Use the provided encoders and freeze them for PEFT.
        self.text_encoder = text_encoder
        for param in self.text_encoder.parameters():
            param.requires_grad = False

        self.graph_encoder = graph_encoder
        for param in self.graph_encoder.parameters():
            param.requires_grad = False

        # Assuming the text encoder has a config with hidden_size.
        self.text_embedding_size = self.text_encoder.config.hidden_size
        self.embedding_dim = embedding_dim

        ############ PROJECTION ############
        self.text_projection = nn.Linear(in_features=self.text_embedding_size, out_features=embedding_dim)
        # Adjust the in_features for the graph projection if needed.
        self.graph_projection = nn.Linear(in_features=graph_encoder.hidden_channels, out_features=embedding_dim)

        ############ SELF ATTENTION ############
        self.text_self_attention = nn.MultiheadAttention(embed_dim=embedding_dim,
                                                         num_heads=self_attention_heads,
                                                         batch_first=True)
        self.graph_self_attention = nn.MultiheadAttention(embed_dim=embedding_dim,
                                                          num_heads=self_attention_heads,
                                                          batch_first=True)
        self.text_self_attention_norm = nn.LayerNorm(embedding_dim)
        self.graph_self_attention_norm = nn.LayerNorm(embedding_dim)
        self.text_self_attention_ff1 = nn.Linear(in_features=embedding_dim, out_features=self_attn_ff_dim)
        self.text_self_attention_ff2 = nn.Linear(in_features=self_attn_ff_dim, out_features=embedding_dim)

        self.graph_self_attention_ff1 = nn.Linear(in_features=embedding_dim, out_features=self_attn_ff_dim)
        self.graph_self_attention_ff2 = nn.Linear(in_features=self_attn_ff_dim, out_features=embedding_dim)

        self.text_self_attention_ff_norm = nn.LayerNorm(embedding_dim)
        self.graph_self_attention_ff_norm = nn.LayerNorm(embedding_dim)

        ############ CROSS MODAL ATTENTION ############
        self.cross_modal_attention_blocks = nn.ModuleList([
            CrossModelAttentionBlock(embedding_dim=embedding_dim,
                                     num_heads=num_cross_modal_attention_heads,
                                     feed_forward_dim=num_cross_modal_attention_ff_dim)
            for _ in range(num_cross_modal_attention_blocks)
        ])

        ############ OUTPUT LAYER ############
        self.output_pre_norm = nn.LayerNorm(embedding_dim * 2)
        self.output_ff = nn.Linear(embedding_dim * 2, output_channels)

    def forward(self, text_input_ids, text_attention_mask, graph_data):
        text_embedding = self.text_encoder(input_ids=text_input_ids, attention_mask=text_attention_mask)[0]
        _, node_embeddings, _ = self.graph_encoder(graph_data.x, graph_data.edge_index, graph_data.batch)
        dense_graph_embeddings, mask = to_dense_batch(node_embeddings, graph_data.batch)

        ############ PROJECTION ############
        projected_text_embedding = self.text_projection(text_embedding)
        projected_graph_embedding = self.graph_projection(dense_graph_embeddings)

        ############ SELF ATTENTION ############
        text_self_attn_out, _ = self.text_self_attention(projected_text_embedding,
                                                         projected_text_embedding,
                                                         projected_text_embedding)
        graph_self_attn_out, _ = self.graph_self_attention(projected_graph_embedding,
                                                           projected_graph_embedding,
                                                           projected_graph_embedding,
                                                           key_padding_mask=~mask)
        text_self_attn_out = self.text_self_attention_norm(text_self_attn_out + projected_text_embedding)
        graph_self_attn_out = self.graph_self_attention_norm(graph_self_attn_out + projected_graph_embedding)

        text_ff_out = F.relu(self.text_self_attention_ff1(text_self_attn_out))
        graph_ff_out = F.relu(self.graph_self_attention_ff1(graph_self_attn_out))
        text_ff_out = self.text_self_attention_ff2(text_ff_out)
        graph_ff_out = self.graph_self_attention_ff2(graph_ff_out)
        text_ff_out = self.text_self_attention_ff_norm(text_self_attn_out + text_ff_out)
        graph_ff_out = self.graph_self_attention_ff_norm(graph_self_attn_out + graph_ff_out)

        ############ CROSS MODAL ATTENTION ############
        projected_text_embedding, projected_graph_embedding = text_ff_out, graph_ff_out
        for block in self.cross_modal_attention_blocks:
            projected_text_embedding, projected_graph_embedding = block(projected_text_embedding,
                                                                        projected_graph_embedding)

        ############ OUTPUT LAYER ############
        global_text_embedding = torch.mean(projected_text_embedding, dim=1)
        global_graph_embedding = torch.mean(projected_graph_embedding, dim=1)
        combined_embedding = torch.cat((global_text_embedding, global_graph_embedding), dim=-1)
        combined_embedding = self.output_pre_norm(combined_embedding)
        output = self.output_ff(combined_embedding)
        return output

In [7]:
text_encoder = AutoModel.from_pretrained('answerdotai/ModernBERT-base').to(device)
graph_encoder = load_pre_trained_graph_encoder(
    model_path='/Users/navneet/git/research/swarm-guard/models/graph/graph_encoder.pth',
    device=device
)
model = MultiModalModelForClassification(
    text_encoder=text_encoder,
    graph_encoder=graph_encoder,
    self_attention_heads=8,
    embedding_dim=256,
    num_cross_modal_attention_blocks=6,
    num_cross_modal_attention_heads=8,
    self_attn_ff_dim=512,
    num_cross_modal_attention_ff_dim=512,
    output_channels=2
).to(device)

MOdel loaded with hidden channels: 64


In [32]:
# forward pass with the sample
sample, y = astrorag_dataset[0]
text_input_ids = sample['text_input_ids'].unsqueeze(0)
text_attention_mask = sample['text_attention_mask'].unsqueeze(0)
graph_data = sample['graph_data']
#Move the data to the same device as the model
text_input_ids = text_input_ids.to(device)
text_attention_mask = text_attention_mask.to(device)
graph_data.x = graph_data.x.to(device)
graph_data.edge_index = graph_data.edge_index.to(device)
graph_data.batch = graph_data.batch.to(device)
# Perform a forward pass
model(text_input_ids, text_attention_mask, graph_data)

AttributeError: 'NoneType' object has no attribute 'to'

In [33]:
# print the number of parameters in the model in millions
num_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters in the model: {num_params / 1e6:.2f}M")

# print the number of parameters in the model in millions excluding the text encoder and graph encoder
num_params_excluding_encoders = sum(p.numel() for name, p in model.named_parameters() if
                                    'text_encoder' not in name and 'graph_encoder' not in name)
print(f"Number of parameters in the model excluding encoders: {num_params_excluding_encoders / 1e6:.2f}M")

Number of parameters in the model: 156.67M
Number of parameters in the model excluding encoders: 7.64M


In [8]:
def multimodal_collate_fn(batch):
    # Unzip batch elements into data dictionaries and corresponding labels
    data_dicts, labels = zip(*batch)

    text_input_ids = torch.stack([d['text_input_ids'] for d in data_dicts], dim=0)
    text_attention_mask = torch.stack([d['text_attention_mask'] for d in data_dicts], dim=0)

    # Create a batched graph using GeoBatch.from_data_list
    graph_data = GeoBatch.from_data_list([d['graph_data'] for d in data_dicts])

    # Convert labels tuple (of ints) into a tensor.
    labels = torch.tensor(labels)

    return {
        'text_input_ids': text_input_ids,
        'text_attention_mask': text_attention_mask,
        'graph_data': graph_data,
        'labels': labels
    }


crieterion = torch.nn.CrossEntropyLoss()
data_loader = torch.utils.data.DataLoader(astrorag_dataset, batch_size=2, shuffle=True,
                                          collate_fn=multimodal_collate_fn)
# iterate through the data loader taking both the features and labels
for batch in data_loader:
    text_input_ids = batch['text_input_ids']
    text_attention_mask = batch['text_attention_mask']
    graph_data = batch['graph_data']
    labels = batch['labels']

    # Move the data to the same device as the model
    text_input_ids = text_input_ids.to(device)
    text_attention_mask = text_attention_mask.to(device)
    graph_data.x = graph_data.x.to(device)
    graph_data.edge_index = graph_data.edge_index.to(device)
    graph_data.batch = graph_data.batch.to(device)
    labels = labels.to(device)

    output = model(text_input_ids, text_attention_mask, graph_data)
    loss = crieterion(output, labels)

    print("Output shape:", output.shape)
    print("Labels shape:", labels.shape)
    print("Loss:", loss.item())


Output shape: torch.Size([2, 2])
Labels shape: torch.Size([2])
Loss: 0.5210650563240051
Output shape: torch.Size([2, 2])
Labels shape: torch.Size([2])
Loss: 0.46840599179267883
Output shape: torch.Size([2, 2])
Labels shape: torch.Size([2])
Loss: 0.5839701294898987
Output shape: torch.Size([2, 2])
Labels shape: torch.Size([2])
Loss: 0.5308350324630737
Output shape: torch.Size([2, 2])
Labels shape: torch.Size([2])
Loss: 0.41162973642349243
Output shape: torch.Size([2, 2])
Labels shape: torch.Size([2])
Loss: 0.7288163304328918
Output shape: torch.Size([2, 2])
Labels shape: torch.Size([2])
Loss: 0.6156578063964844
Output shape: torch.Size([2, 2])
Labels shape: torch.Size([2])
Loss: 0.6752355098724365
Output shape: torch.Size([2, 2])
Labels shape: torch.Size([2])
Loss: 0.5649727582931519
Output shape: torch.Size([2, 2])
Labels shape: torch.Size([2])
Loss: 0.7341600060462952
Output shape: torch.Size([2, 2])
Labels shape: torch.Size([2])
Loss: 0.6944077014923096
Output shape: torch.Size([2, 2

KeyboardInterrupt: 