# Imports

In [100]:
import os
import zipfile
import glob
import pickle
import nltk
import json
from collections import Counter
from datasets import load_dataset, load_from_disk
from pprint import pprint
from PIL import Image
import pandas as pd
import networkx as nx
from nltk.stem import WordNetLemmatizer

import torchvision.transforms as transforms
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from huggingface_hub import hf_hub_download

nltk.download('punkt_tab', quiet=True)
nltk.download('wordnet', quiet=True)
nltk.download('omw-1.4', quiet=True)

True

# Constants

### Dataset Contstants

In [43]:
IMG_SIZE = (224, 224)
VOCAB_SIZE = 5000
BATCH_SIZE = 32
MAX_NODES_PER_QUESTION = 10

# Directory Information
DATA_DIR = "data/"
DATASET_PATH = os.path.join(DATA_DIR, 'dataset/')
IMAGE_PATH = os.path.join(DATA_DIR, 'imgs/')
KG_PATH = os.path.join(DATA_DIR, 'KG/')
VOCABS_PATH = os.path.join(DATA_DIR, 'vocabs/')

# Huggingface Repository Information
repo_id = "BoKelvin/SLAKE"
repo_type = "dataset"
kg_file = "KG.zip"
img_file = "imgs.zip"

# Entity Extraction
MAX_PHRASE_LENGTH = 5 # Look up to 5-grams for entity matching

# Device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


# Dataset Setup

### Data Download

In [44]:
# Utility function for downloading and extracting ZIP file
def download_and_store_ZIP(filename, save_dir):
    print(f"Fetching file {filename} from {repo_id} repo")

    try:
        # Caches the file locally and returns the path to the cached file
        cached_zip_path = hf_hub_download(
          repo_id=repo_id,
          filename=filename,
          repo_type=repo_type
        )
        print(f"{filename} download complete. Cached at: {cached_zip_path}")

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        # Extract the contents
        print(f"Extracting to {save_dir}...")
        with zipfile.ZipFile(cached_zip_path, 'r') as zip_ref:
            zip_ref.extractall(save_dir)

        print("Extraction complete.")
        print(f"{filename} files are located in: {os.path.abspath(save_dir)}")
    except Exception as e:
        print(f"Failed to download or extract {filename}: {e}")

# Scoping to English only
def filter_language(original):
    return original.filter(lambda data: data['q_lang'] == 'en')

# Download and store the dataset
def download_and_store_english_dataset():
    print(f"Downloading dataset from {repo_id} repo")

    # Load from Hugging Face
    original = load_dataset(repo_id)

    # Scope to English Only
    original = filter_language(original)

    # Show the dataset formatting
    pprint(original)

    # Save the original dataset
    if not os.path.exists(DATA_DIR):
        os.makedirs(DATA_DIR)

    if not os.path.exists(DATASET_PATH):
        os.makedirs(DATASET_PATH)

    original.save_to_disk(DATASET_PATH)
    return original

# Download and store the Knowledge Graph files
def download_and_store_KG():
    download_and_store_ZIP(kg_file, DATA_DIR)

# Download and store the image files
def download_and_store_image():
    download_and_store_ZIP(img_file, DATA_DIR)

# Download necessary files
def download_and_store_slake():
    dataset = download_and_store_english_dataset()
    download_and_store_image()
    download_and_store_KG()

    return dataset

### Knowledge Graph

In [97]:
# Build NetworkX graph from the knowledge graph
def build_slake_knowledge_graph(kg_directory):
    # Initialize a Directed Multi-Graph (allows parallel edges)
    G = nx.MultiDiGraph()

    search_path = os.path.join(kg_directory, "en_*.csv")
    csv_files = glob.glob(search_path)

    if not csv_files:
        print(f"No files found matching {search_path}")
        return None
    print(f"Found {len(csv_files)} English KG files. Building graph...")

    for file_path in csv_files:
        filename = os.path.basename(file_path)
        print(f"Processing {file_path}")

        try:
            df = pd.read_csv(file_path, dtype=str, delimiter='#')
            print(f"Total relations in {filename}: {df.size}")

            if len(df.columns) < 3:
                print(f"Error with {filename}: Expected 3 columns, found {len(df.columns)}")
                continue

            df.columns = ['head', 'relation', 'tail'] + list(df.columns[3:])

            # Add edges for the graph
            for _, row in df.iterrows():
                u = row['head'].strip().lower()
                rel = row['relation'].strip().lower()
                v = row['tail'].strip().lower()

                # Add to edge list
                G.add_edge(u, v, relation=rel)

            # G.add_edges_from(edges)

        except Exception as e:
            print(f"Error reading {filename}: {e}")

    print(f"\nGraph Built Successfully!")
    print(f"Nodes: {G.number_of_nodes()}")
    print(f"Edges: {G.number_of_edges()}")

    # Save as pickle for fast load
    if not os.path.exists(KG_PATH):
        os.makedirs(KG_PATH)

    pickle.dump(G, open(os.path.join(KG_PATH, 'slake_kg.pickle'), 'wb'))
    return G

### Vocabs list

In [46]:
class VocabularyBuilder:
    def __init__(self, min_freq=1):
        self.min_freq = min_freq
        self.itos = {0: "<pad>", 1: "<start>", 2: "<end>", 3: "<unk>"}
        self.stoi = {v: k for k, v in self.itos.items()}

    def tokenize(self, text):
        return nltk.word_tokenize(text.lower())
    
    def __len__(self):
        return len(self.stoi)
    
    def build_word_vocabs(self, sentences):
        counter = Counter()
        start_index = len(self.stoi)

        # 1. Count frequencies of all tokens in the tokenized sentences
        for sentence in sentences:
            tokens = self.tokenize(sentence)
            counter.update(tokens)

        # 2. Add words that meet the frequency threshold
        for word, count in counter.items():
            if count >= self.min_freq and word not in self.stoi:
                self.stoi[word] = start_index
                self.itos[start_index] = word
                start_index += 1

        print(f"Vocabulary Built. Vocabulary Size: {len(self.stoi)}")

    def numericalize(self, text):
        tokens = self.tokenize(text)
        return [
            self.stoi[token] if token in self.stoi else self.stoi["<unk>"]
            for token in tokens
        ]

### Utilities

In [47]:
# Build vocabularies for questions and answers
def build_vocabs(dataset):
    questions = [item['question'] for item in dataset]
    answers = [item['answer'] for item in dataset]

    # Question Vocabulary
    questvocab_builder = VocabularyBuilder(min_freq=1)
    questvocab_builder.build_word_vocabs(questions)
    
    # Answer Vocabulary
    ansvocab_builder = VocabularyBuilder(min_freq=1)
    ansvocab_builder.build_word_vocabs(answers)

    return questvocab_builder, ansvocab_builder

# Save vocabularies to JSON files
def save_vocabs(quest_vocab, ans_vocab):
    if not os.path.exists(VOCABS_PATH):
        os.makedirs(VOCABS_PATH)

    # Save Question Vocabulary
    with open(os.path.join(VOCABS_PATH, 'question_vocab.json'), 'w') as f:
        json.dump({'stoi': quest_vocab.stoi, 'itos': quest_vocab.itos}, f)

    # Save Answer Vocabulary
    with open(os.path.join(VOCABS_PATH, 'answer_vocab.json'), 'w') as f:
        json.dump({'stoi': ans_vocab.stoi, 'itos': ans_vocab.itos}, f)

    print("Vocabularies saved successfully.")

# Mapping from node name to index
def create_node_mapping(graph):
    nodes = list(graph.nodes)
    node_to_index = {node: idx + 1 for idx, node in enumerate(nodes)} # Index 0 reserved for padding
    index_to_node = {idx + 1: node for idx, node in enumerate(nodes)}

    return node_to_index, index_to_node

### Dataset Class

In [101]:
class SlakeDataset(Dataset):
    def __init__(self, dataset, question_vocab, answer_vocab, graph, node_to_index, transform=None):
        self.data = dataset
        self.question_vocab = question_vocab
        self.answer_vocab = answer_vocab
        self.graph = graph
        self.node_to_index = node_to_index
        self.transform = transform

        # Pre compute graph nodes for faster access
        self.graph_nodes = set(graph.nodes)

    def __len__(self):
        return len(self.data)
    
    def extract_entities(self, text):
        # Initialize Lemmatizer
        lemmatizer = WordNetLemmatizer()

        tokens = nltk.word_tokenize(text.lower())
        extracted_entities = []

        i = 0
        while i < len(tokens):
            match_found = False
            # Check for n-grams from MAX_PHRASE_LENGTH down to 1
            for length in range(MAX_PHRASE_LENGTH, 0, -1):
                if i + length <= len(tokens):
                    phrase_tokens = tokens[i:i+length]

                    # Strategy 1: Exact Match
                    phrase = ' '.join(phrase_tokens)
                    if phrase in self.graph_nodes:
                        index = self.node_to_index.get(phrase, None)
                        extracted_entities.append(index)
                        i += length
                        match_found = True
                        break

                    # Strategy 2: Lemmatized Match (For singular/plural issues)
                    phrase_lemma_tokens = [lemmatizer.lemmatize(token, pos='n') for token in phrase_tokens]
                    phrase_lemma_str = " ".join(phrase_lemma_tokens)
                    if phrase_lemma_str in self.graph_nodes:
                        index = self.node_to_index.get(phrase_lemma_str, None)
                        extracted_entities.append(index)
                        i += length
                        match_found = True
                        break

            if not match_found:
                i += 1

        return extracted_entities if extracted_entities else [0] # Return placeholder if no entities found
    
    def __getitem__(self, idx):
        item = self.data[idx]

        # 1. Image Processing
        image_path = item['img_name']
        image_path = os.path.join(IMAGE_PATH, image_path)
        image = Image.open(image_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        # 2. Question Processing
        question = item['question']
        question_indices = self.question_vocab.numericalize(question)

        # 3. Answer Processing
        answer = str(item.get('answer', '')) # Answer may be missing in test set
        answer_index = self.answer_vocab.numericalize(answer)

        # 4. Inject Knowledge from KG
        kg_index = self.extract_entities(question)

        return {
            'image': image,
            'question' : torch.tensor(question_indices),
            'answer' : torch.tensor(answer_index),
            'kg_index' : torch.tensor(kg_index),
            # Add original items for reference
            'original_question': question,
            'original_answer': answer,
            # Add ID for tracking
            'id': item['qid']
        }

### Collate function

Questions have different lengths, need to pad properly to make sure the length is constant.

In [49]:
class SlakeCollate:
    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        # Stack images
        images = torch.stack([item['image'] for item in batch])

        # Pad Questions
        questions = [item['question'] for item in batch]
        padded_questions = pad_sequence(questions, batch_first=True, padding_value=self.pad_idx)

        # Pad Answers
        answers = [item['answer'] for item in batch]
        answers = pad_sequence(answers, batch_first=True, padding_value=self.pad_idx)

        # Pad KG Indices
        kg_indices = [item['kg_index'] for item in batch]
        kg_indices = pad_sequence(kg_indices, batch_first=True, padding_value=0)

        return {
            'image': images,
            'question': padded_questions,
            'answer': answers,
            'kg_index': kg_indices,
            'original_questions': [item['original_question'] for item in batch],
            'original_answers': [item['original_answer'] for item in batch],
            'ids': [item['id'] for item in batch]
        }

In [None]:
# Testing Purpose
def test_pipelines(need_download=True):
    if need_download:
        # Download and store the dataset
        dataset = download_and_store_slake()
    else:
        dataset = load_from_disk(DATASET_PATH)

    # Build the knowledge graph
    kg = build_slake_knowledge_graph(KG_PATH)

    # Build vocabularies
    train_data = dataset['train']
    question_vocab, answer_vocab = build_vocabs(train_data)

    # Save vocabularies
    save_vocabs(question_vocab, answer_vocab)

    # Create node mappings
    node_to_index, index_to_node = create_node_mapping(kg)

    # Define image transformations
    transform = transforms.Compose([
        transforms.Resize(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Create dataset and dataloader
    train_dataset = SlakeDataset(train_data, question_vocab, answer_vocab, kg, node_to_index, transform=transform)
    collate_fn = SlakeCollate(pad_idx=0)

    train_loader = DataLoader(
        train_dataset, 
        batch_size=BATCH_SIZE, 
        shuffle=True, 
        collate_fn=collate_fn
    )

    # Test the loader
    for batch in train_loader:
        print("Batch Size:", train_loader.batch_size)
        print("Images shape:", batch['image'].shape)     # (32, 3, 224, 224)
        print("Questions shape:", batch['question'].shape) # (32, max_q_len)
        print("Answer shape:", batch['answer'].shape) # (32, max_q_len)
        print("KG Nodes shape:", batch['kg_index'].shape)   # (32, max_nodes_len)

        # Specific data
        print("Question: ", batch['original_questions'][0])
        print("Tokenized Question: ", batch['question'][0])
        print("Answer: ", batch['original_answers'][0])
        print("Tokenized Answer: ", batch['answer'][0])
        print("KG Nodes: ", batch['kg_index'][0])
        break

In [115]:
loader, kg, n2i, i2n = test_pipelines(False)

Found 3 English KG files. Building graph...
Processing data/KG\en_disease.csv
Total relations in en_disease.csv: 6648
Processing data/KG\en_organ.csv
Total relations in en_organ.csv: 843
Processing data/KG\en_organ_rel.csv
Total relations in en_organ_rel.csv: 309

Graph Built Successfully!
Nodes: 1892
Edges: 2600
Vocabulary Built. Vocabulary Size: 281
Vocabulary Built. Vocabulary Size: 249
Vocabularies saved successfully.
Batch Size: 32
Images shape: torch.Size([32, 3, 224, 224])
Questions shape: torch.Size([32, 17])
Answer shape: torch.Size([32, 5])
KG Nodes shape: torch.Size([32, 2])
Question:  Is this a study of the chest?
Tokenized Question:  tensor([  6,  10, 117, 118,  15,  16, 128,  12,   0,   0,   0,   0,   0,   0,
          0,   0,   0])
Answer:  Yes
Tokenized Answer:  tensor([7, 0, 0, 0, 0])
KG Nodes:  tensor([1722,    0])
