In [6]:
%pip install datasets torch transformers sentencepiece

Defaulting to user installation because normal site-packages is not writeable
Collecting sentencepiece
  Downloading sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
Installing collected packages: sentencepiece
Successfully installed sentencepiece-0.2.0
Note: you may need to restart the kernel to use updated packages.


Load dataset [neo4j/text2cypher-2024v1](https://huggingface.co/datasets/neo4j/text2cypher-2024v1)

In [1]:
from datasets import load_dataset

# Load the dataset
dataset = load_dataset("neo4j/text2cypher-2024v1")

Dataset columns names

In [25]:
print(f"Number of columns: {dataset.num_columns}, \nColumns names: {dataset.column_names},\nDataset size: {dataset.shape}")

Number of columns: {'train': 6, 'test': 6}, 
Columns names: {'train': ['question', 'schema', 'cypher', 'data_source', 'instance_id', 'database_reference_alias'], 'test': ['question', 'schema', 'cypher', 'data_source', 'instance_id', 'database_reference_alias']},
Dataset size: {'train': (39554, 6), 'test': (4833, 6)}


Print a sample dataset

In [2]:
print(dataset)
print(dataset['train'][0])

DatasetDict({
    train: Dataset({
        features: ['question', 'schema', 'cypher', 'data_source', 'instance_id', 'database_reference_alias'],
        num_rows: 39554
    })
    test: Dataset({
        features: ['question', 'schema', 'cypher', 'data_source', 'instance_id', 'database_reference_alias'],
        num_rows: 4833
    })
})
{'question': 'Which 3 countries have the most entities linked as beneficiaries in filings?', 'schema': 'Node properties:\n- **Country**\n  - `location`: POINT \n  - `code`: STRING Example: "AFG"\n  - `name`: STRING Example: "Afghanistan"\n  - `tld`: STRING Example: "AF"\n- **Filing**\n  - `begin`: DATE_TIME Min: 2000-02-08T00:00:00Z, Max: 2017-09-05T00:00:00Z\n  - `end`: DATE_TIME Min: 2000-02-08T00:00:00Z, Max: 2017-11-03T00:00:00Z\n  - `originator_bank_id`: STRING Example: "cimb-bank-berhad"\n  - `sar_id`: STRING Example: "3297"\n  - `beneficiary_bank`: STRING Example: "Barclays Bank Plc"\n  - `filer_org_name_id`: STRING Example: "the-bank-of-new-york

In [9]:
from torch.utils.data import DataLoader, Dataset

In [10]:
class Text2CypherDataset(Dataset):
    def __init__(self, dataset_split, tokenizer, max_length=512):
        self.dataset = dataset_split
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        data_point = self.dataset[idx]
        question = data_point["question"]  # User's natural language question
        schema = data_point["schema"]  # Database schema details
        database_reference_alias = data_point["database_reference_alias"] # Database alias name, might be useful in subgraph or cross-domain.
        cypher_query = data_point["cypher"]  # Target Cypher query

        # Combine question and schema as input
        input_text = f"Question: {question} Schema: {schema} Database Refenerce Alias: {database_reference_alias}"

        # Tokenize input (question + schema) and output (cypher query)
        inputs = self.tokenizer(input_text, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt")
        outputs = self.tokenizer(cypher_query, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt")

        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "labels": outputs["input_ids"].squeeze(0),
        }

In [11]:
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
train_dataset = Text2CypherDataset(dataset["train"], tokenizer)

In [12]:
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# Sample batch
for batch in train_loader:
    print(batch)
    break

{'input_ids': tensor([[11860,    10, 11677,  ...,    26,     2,     1],
        [11860,    10,  4073,  ...,     0,     0,     0],
        [11860,    10,  4073,  ..., 21342, 22034,     1],
        ...,
        [11860,    10,  9778,  ...,     0,     0,     0],
        [11860,    10,  6792,  ...,  6306,    10,     1],
        [11860,    10,   363,  ...,   226,    40,     1]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]]), 'labels': tensor([[  283, 29572,    41,  ...,     0,     0,     0],
        [  283, 29572,    41,  ...,     0,     0,     0],
        [  283, 29572,    41,  ...,     0,     0,     0],
        ...,
        [  283, 29572,    41,  ...,     0,     0,     0],
        [  283, 29572,    41,  ...,     0,     0,     0],
        [  283, 29572,    41,  ...,     0,     0,     0]])}
