In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data

from transformers import AutoModel, AutoTokenizer, AutoFeatureExtractor

In [2]:
class GraphEncoder(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GraphEncoder, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels) #Maybe GATConv ?
        self.conv2 = GCNConv(hidden_channels, out_channels) #Maybe GATConv ?
    
    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x

In [3]:
class TriModalBridgeLayer(nn.Module):
    def __init__(self, hidden_dim, num_heads):
        super().__init__()
        self.self_attn_text = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.self_attn_vision = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.self_attn_graph = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.cross_attn_text = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.cross_attn_vision = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.cross_attn_graph = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.ff_text = nn.Sequential(
            nn.Linear(hidden_dim, 4 * hidden_dim),
            nn.ReLU(),
            nn.Linear(4 * hidden_dim, hidden_dim),
        )
        self.ff_vision = nn.Sequential(
            nn.Linear(hidden_dim, 4 * hidden_dim),
            nn.ReLU(),
            nn.Linear(4 * hidden_dim, hidden_dim),
        )
        self.ff_graph = nn.Sequential(
            nn.Linear(hidden_dim, 4 * hidden_dim),
            nn.ReLU(),
            nn.Linear(4 * hidden_dim, hidden_dim),
        )

    def forward(self, text_emb, vision_emb, graph_emb):
        text_emb, _ = self.self_attn_text(text_emb, text_emb, text_emb)
        vision_emb, _ = self.self_attn_vision(vision_emb, vision_emb, vision_emb)
        graph_emb, _ = self.self_attn_graph(graph_emb, graph_emb, graph_emb)

        # (vision, graph), (text,graph), (text, vision) - all 3 combs
        #TODO: Or maybe concat all 3 ?
        vis_graph_combined = torch.cat([vision_emb, graph_emb], dim=1)
        text_emb, _ = self.cross_attn_text(text_emb, vis_graph_combined, vis_graph_combined)
        txt_graph_combined = torch.cat([text_emb, graph_emb], dim=1)
        vision_emb, _ = self.cross_attn_vision(vision_emb, txt_graph_combined, txt_graph_combined)
        txt_vis_combined = torch.cat([text_emb, vision_emb], dim=1)
        graph_emb, _ = self.cross_attn_graph(graph_emb, txt_vis_combined, txt_vis_combined)
        
        text_emb = self.ff_text(text_emb)
        vision_emb = self.ff_vision(vision_emb)
        graph_emb = self.ff_graph(graph_emb)
        
        return text_emb, vision_emb, graph_emb

In [4]:
class TriModalBridgeTower(nn.Module):
    def __init__(
        self,
        text_model_name: str,
        vision_model_name: str,
        graph_in_channels: int,
        graph_hidden_channels: int,
        graph_out_channels: int,
        hidden_dim: int,
        num_bridge_layers: int = 2,
        num_heads: int = 8
    ):
        super().__init__()
        
        self.text_encoder = AutoModel.from_pretrained(text_model_name)
        self.vision_encoder = AutoModel.from_pretrained(vision_model_name)
        
        self.graph_encoder = GraphEncoder(
            in_channels=graph_in_channels,
            hidden_channels=graph_hidden_channels,
            out_channels=graph_out_channels
        )
        
        self.graph_proj = nn.Linear(graph_out_channels, hidden_dim)
        
        self.bridge_layers = nn.ModuleList([
            TriModalBridgeLayer(hidden_dim, num_heads)
            for _ in range(num_bridge_layers)
        ])
        
        self.classifier = nn.Linear(hidden_dim, 2)  # genuine vs fake ? Depends on the dataset

    def forward(self, text_batch, vision_batch, graph_data):
        """
        text_batch: dict for text model (input_ids, attention_mask, etc.)
        vision_batch: dict for vision model (pixel_values, etc.)
        graph_data: PyG Data object with .x (node features), .edge_index, etc. (assuming the pyg tutorials)
        """
        
        text_outputs = self.text_encoder(**text_batch)
        text_emb = text_outputs.last_hidden_state  # (batch_size, seq_len_text, hidden_dim)
        
        vision_outputs = self.vision_encoder(**vision_batch)
        vision_emb = vision_outputs.last_hidden_state  # (batch_size, seq_len_vision, hidden_dim)
        
        node_emb = self.graph_encoder(graph_data.x, graph_data.edge_index) # I am assuming we have one large graph from the entire twitter dataset i.e the social network graph (probably we generate this using pyspark)
        graph_emb = self.graph_proj(node_emb)
        if graph_emb.size(0) != batch_size:
            graph_emb = graph_emb.expand(batch_size, -1, -1)
        
        for layer in self.bridge_layers:
            text_emb, vision_emb, graph_emb = layer(text_emb, vision_emb, graph_emb)
        
        # Taking the summaries (CLS tokens)
        text_cls = text_emb[:, 0, :]
        vision_cls = vision_emb[:, 0, :]
        graph_cls = graph_emb[:, 0, :]
        
        # TODO: Experiment
        fused = (text_cls + vision_cls + graph_cls) / 3.0
        logits = self.classifier(fused)
        return logits

In [5]:
text_model_name = "roberta-base" # I've seen this model usually generalizes well (but then again no free lunch)
vision_model_name = "google/vit-base-patch16-224-in21k"
model = TriModalBridgeTower(
    text_model_name=text_model_name,
    vision_model_name=vision_model_name,
    graph_in_channels=16,
    graph_hidden_channels=32,
    graph_out_channels=64,
    hidden_dim=768,       
    num_bridge_layers=2,  #TODO: Experiment, maybe more because we need more hierarchy when we add graph data as well ?
    num_heads=8
)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
tokenizer = AutoTokenizer.from_pretrained(text_model_name)
text_input = ["Hello world!", "This is a test."]
text_batch = tokenizer(
    text_input, padding=True, truncation=True, return_tensors="pt"
)

In [7]:
batch_size = len(text_input)
vision_batch = {
    "pixel_values": torch.randn(batch_size, 3, 224, 224)
}

In [8]:
x = torch.rand(4, 16) 
edge_index = torch.tensor([[0, 1, 2, 2],
                           [1, 0, 3, 1]], dtype=torch.long)
graph_data = Data(x=x, edge_index=edge_index)

In [9]:
logits = model(text_batch, vision_batch, graph_data)
print(logits.shape) 
print( logits)

torch.Size([2, 2])
tensor([[-0.0308, -0.0192],
        [-0.0308, -0.0192]], grad_fn=<AddmmBackward0>)
