## Imports

In [21]:
import os
import json
import re
import torch
import requests
from PIL import Image
import torch_geometric
from torch_geometric.data import Data, Dataset
from torch.utils.data import DataLoader
from transformers import RobertaTokenizer, ViTFeatureExtractor

### Dataset Class


In [None]:
class DatasetClass(torch.utils.data.Dataset):
    """Multimodal propagation-tree dataset (graph + per-node text + per-node image).

    Every ``*.json`` file directly inside ``path`` is one sample. All samples are
    fully preprocessed in ``__init__``: tokenized, images downloaded and
    encoded. Construction can therefore be slow and may hit the network.

    Each item is a dict with keys:
        'graph'        -- torch_geometric ``Data`` with ``x`` [num_nodes x 3]
                          (followers_count, following_count, verified) and
                          ``edge_index`` [2 x num_edges] (long).
        'text_inputs'  -- list (one per node) of tokenizer outputs.
        'image_inputs' -- list (one per node) of feature-extractor outputs, or
                          a ``{'pixel_values': randn(1, 3, 224, 224)}`` placeholder.
        'label'        -- 1 if the JSON ``label`` field equals ``'real'``, else 0.
    """

    # Matches any http(s) URL up to the next whitespace. It is NOT restricted to
    # image URLs; non-image downloads are caught when decoding fails.
    _URL_PATTERN = re.compile(r'https?://\S+')

    # Shape of the placeholder used when a node has no usable image.
    _FALLBACK_IMAGE_SHAPE = (1, 3, 224, 224)

    def __init__(self, path, tokenizer, feature_extractor, img_dir):
        """
        Args:
            path (str): directory containing the tree JSON files.
            tokenizer: text tokenizer, called as ``tokenizer(text, padding=True,
                truncation=True, return_tensors='pt')``. ``dataloader`` passes a
                RobertaTokenizer. TODO: may change later.
            feature_extractor: image processor, called as
                ``feature_extractor(image, return_tensors='pt')``.
                ``dataloader`` passes a ViTFeatureExtractor. TODO: may change later.
            img_dir (str): directory where downloaded images are cached;
                created if missing.

        Note: images are currently downloaded from the first URL in each tweet
        and saved to ``img_dir``. This may be replaced later.
        """
        self.path = path
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor
        self.img_dir = img_dir

        os.makedirs(self.img_dir, exist_ok=True)
        self.data = self.load_dataset()

    def extract_urls(self, tweet_text):
        """Return every http(s) URL in ``tweet_text``, in order of appearance."""
        return self._URL_PATTERN.findall(tweet_text)

    # Alias for the name this method was originally declared under.
    extract_url = extract_urls

    def download_image(self, url, filename):
        """Download ``url`` into ``img_dir/filename`` and return the saved path.

        A file already present at that path is reused without re-downloading.
        This is best effort: a non-200 response or a network/IO/URL error
        returns None. Errors are printed, not raised.
        """
        filepath = os.path.join(self.img_dir, filename)
        if os.path.isfile(filepath):
            return filepath
        try:
            response = requests.get(url, timeout=5)
            if response.status_code != 200:
                return None
            with open(filepath, 'wb') as f:  # write binary
                f.write(response.content)
            return filepath
        except (requests.RequestException, OSError, ValueError) as e:
            print(f"Error downloading image {url}: {e}")
            return None

    def load_dataset(self):
        """Parse every ``*.json`` tree file in ``self.path`` into a sample dict."""
        # os.listdir order is arbitrary; sort it for a reproducible sample order.
        return [
            self._load_tree(os.path.join(self.path, filename))
            for filename in sorted(os.listdir(self.path))
            if filename.endswith('.json')
        ]

    def _load_tree(self, filepath):
        """Turn one tree JSON file into a sample dict (see the class docstring)."""
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        label = 1 if data['label'] == 'real' else 0

        node_features = []  # one [followers_count, following_count, verified] row per node
        node_id_map = {}    # node id -> row index, used to resolve edge endpoints
        text_inputs = []
        image_inputs = []

        for idx, node in enumerate(data['nodes']):
            node_id_map[node['id']] = idx
            tweet_text = node['tweet_text']

            text_inputs.append(self.tokenizer(
                tweet_text,
                padding=True,
                truncation=True,
                return_tensors='pt',
            ))
            image_inputs.append(self._encode_node_image(node['id'], tweet_text))
            node_features.append(torch.tensor([
                node['followers_count'],
                node['following_count'],
                node['verified'],
            ], dtype=torch.float))

        # torch.stack raises on an empty list, so a tree without nodes gets an
        # explicit [0 x 3] feature matrix instead.
        x = (torch.stack(node_features) if node_features
             else torch.empty((0, 3), dtype=torch.float))
        graph_data = Data(
            x=x,                                                              # [num_nodes x num_features]
            edge_index=self._build_edge_index(data['edges'], node_id_map),    # [2 x num_edges]
        )

        return {
            'graph': graph_data,
            'text_inputs': text_inputs,
            'image_inputs': image_inputs,
            'label': label,
        }

    def _encode_node_image(self, node_id, tweet_text):
        """Encode the image behind the tweet's first URL, or return a random placeholder."""
        image_urls = self.extract_urls(tweet_text)
        image_path = None
        if image_urls:
            image_path = self.download_image(image_urls[0], f"{node_id}.jpg")

        if image_path:
            try:
                # convert('RGB'): grayscale/RGBA/palette images would otherwise hand
                # the extractor a channel count other than 3. The with-block closes the file.
                with Image.open(image_path) as img:
                    return self.feature_extractor(img.convert('RGB'), return_tensors='pt')
            except OSError as e:
                # PIL raises UnidentifiedImageError (an OSError subclass) when the
                # URL served non-image content such as an HTML page.
                print(f"Error reading image {image_path}: {e}")

        # Fallback: random tensor when no usable image is available.
        # NOTE(review): random noise makes samples non-reproducible; zeros may be preferable.
        return {'pixel_values': torch.randn(*self._FALLBACK_IMAGE_SHAPE)}

    @staticmethod
    def _build_edge_index(edges, node_id_map):
        """Build a [2 x num_edges] long tensor of (source, target) node indices.

        Edges whose 'source' or 'target' id is not a known node are skipped.
        An empty edge list yields a [2 x 0] tensor. ``torch.tensor([]).t()``
        would instead give a 1-D float tensor, which is not the
        [2 x num_edges] long layout ``Data.edge_index`` expects.
        """
        pairs = [
            (node_id_map[edge['source']], node_id_map[edge['target']])
            for edge in edges
            if edge['source'] in node_id_map and edge['target'] in node_id_map
        ]
        if not pairs:
            return torch.empty((2, 0), dtype=torch.long)
        # Transpose to column layout: row 0 holds sources, row 1 holds targets.
        return torch.tensor(pairs, dtype=torch.long).t().contiguous()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
        

#### Collate Function to Batch the Multimodal Inputs
This code does the following:
* Collect the individual graph data objects (`graph`) into a list, one per sample.
* Pad the text sequences (`input_ids`, `attention_mask`) so that all sequences in the batch have the same length.
* Stack the image tensors (`pixel_values`) into a single batch tensor.
* Combine the labels into a single tensor.




In [None]:
def collate_fn(batch, *, pad_token_id=1):
    """
    Merge a list of dataset samples into one multimodal batch.

    Text and image inputs are stored per node, so they are flattened across
    all samples in batch order. Row k of the text/image tensors belongs to the
    k-th node overall; use each graph's node count to map rows back to samples.

    Args:
    - batch (list): processed entries, each a dict with
        graph: a PyTorch Geometric Data object representing a graph.
        text_inputs: list of tokenized texts, each holding 'input_ids' and
            'attention_mask' with a leading batch dimension of size 1.
        image_inputs: list of dicts holding 'pixel_values' with a leading
            batch dimension of size 1.
        label: the label for the sample.
    - pad_token_id (int, keyword-only): value used to pad 'input_ids'.
        Defaults to 1, the RoBERTa ``<pad>`` id (0 is ``<s>``, which would be
        wrong as padding). Pass ``tokenizer.pad_token_id`` for other tokenizers.

    Returns:
    dict with:
    - 'graphs': list of Data.
    - 'text_batch': {'input_ids', 'attention_mask'}, each [total_nodes x max_len].
    - 'image_batch': {'pixel_values': stacked images}.
    - 'labels': 1-D tensor.
    """
    pad = torch.nn.utils.rnn.pad_sequence

    graphs_batch = [item['graph'] for item in batch]
    labels_batch = torch.tensor([item['label'] for item in batch])

    text_inputs = [t for item in batch for t in item['text_inputs']]
    image_inputs = [i for item in batch for i in item['image_inputs']]

    # squeeze(0) drops only the leading size-1 batch dim. A bare squeeze()
    # would also collapse a length-1 sequence into a 0-d tensor.
    text_batch = {
        'input_ids': pad(
            [t['input_ids'].squeeze(0) for t in text_inputs],
            batch_first=True,
            padding_value=pad_token_id,
        ),
        # Padded positions must be masked out, so the mask is always padded with 0.
        'attention_mask': pad(
            [t['attention_mask'].squeeze(0) for t in text_inputs],
            batch_first=True,
            padding_value=0,
        ),
    }

    # Stack image pixel values into a single tensor
    image_batch = {
        'pixel_values': torch.stack([i['pixel_values'].squeeze(0) for i in image_inputs]),
    }

    return {
        'graphs': graphs_batch,
        'text_batch': text_batch,
        'image_batch': image_batch,
        'labels': labels_batch,
    }


#### DataLoader

In [None]:
def dataloader(train_path, test_path, batch_size=32, *,
               tokenizer_name='roberta-base',
               feature_extractor_name='google/vit-base-patch16-224-in21k'):
    """
    Build the train and test DataLoaders over DatasetClass.

    Downloaded images for each split are cached in an ``images``
    subdirectory of that split's path.

    Args:
    - train_path (str): directory of training tree JSON files.
    - test_path (str): directory of test tree JSON files.
    - batch_size (int): samples per batch (default 32).
    - tokenizer_name (str, keyword-only): pretrained RobertaTokenizer name or path.
    - feature_extractor_name (str, keyword-only): pretrained ViTFeatureExtractor
        name or path.

    Returns:
    (train_loader, test_loader). Train batches are shuffled; test batches are not.
    """
    # Loaded once and shared by both splits.
    tokenizer = RobertaTokenizer.from_pretrained(tokenizer_name)
    feature_extractor = ViTFeatureExtractor.from_pretrained(feature_extractor_name)

    def make_loader(split_path, shuffle):
        # Build the dataset for one split and wrap it in a DataLoader.
        dataset = DatasetClass(path=split_path,
                               tokenizer=tokenizer,
                               feature_extractor=feature_extractor,
                               img_dir=os.path.join(split_path, 'images'))
        return DataLoader(dataset,
                          batch_size=batch_size,
                          shuffle=shuffle,
                          collate_fn=collate_fn)

    train_loader = make_loader(train_path, shuffle=True)
    test_loader = make_loader(test_path, shuffle=False)

    return train_loader, test_loader


