In [1]:
import json

# Load lookup tables from JSON files in the working directory.
# NOTE(review): schemas are inferred from file/variable names only — presumably
# en_description maps entity ID -> English description and entityid2label maps
# entity ID -> label; verify against the preprocessing that wrote them.
with open("en_description.json", "r", encoding="utf-8") as f:
    en_description = json.load(f)
# Class ID -> integer count (displayed in the next cell).
with open("process_p31_p279/class_counts.json", "r", encoding="utf-8") as f:
    class_counts = json.load(f)
with open("entityid2label.json", "r", encoding="utf-8") as f:
    entityid2label = json.load(f)

In [5]:
# Display the class-ID -> count mapping loaded above (keys in file order).
class_counts

{'Q13442814': 42725547,
 'Q5': 11112496,
 'Q16521': 3749497,
 'Q4167836': 3464955,
 'Q7318358': 2096909,
 'Q113145171': 1274348,
 'Q7187': 1224818,
 'Q8054': 1006995,
 'Q4167410': 806428,
 'Q3305213': 657812,
 'Q79007': 653456,
 'Q101352': 597294,
 'Q871232': 511820,
 'Q2668072': 503090,
 'Q8502': 444545,
 'Q30612': 391846,
 'Q4022': 330332,
 'Q11266439': 327511,
 'Q486972': 322197,
 'Q43305660': 284920,
 'Q54050': 284650,
 'Q13433827': 263120,
 'Q482994': 252130,
 'Q11424': 244392,
 'Q39614': 237494,
 'Q3863': 236348,
 'Q9842': 221384,
 'Q41176': 217221,
 'Q23397': 212816,
 'Q16970': 207449,
 'Q3331189': 206135,
 'Q47150325': 201330,
 'Q13406463': 194815,
 'Q4830453': 193870,
 'Q2782326': 187418,
 'Q532': 180843,
 'Q1080794': 180637,
 'Q19389637': 176357,
 'Q47521': 174302,
 'Q21191270': 165992,
 'Q105543609': 163453,
 'Q29654788': 161458,
 'Q21014462': 153639,
 'Q7725634': 152379,
 'Q59199015': 148039,
 'Q115595777': 140371,
 'Q18593264': 132380,
 'Q23442': 130154,
 'Q61443690': 1291

## Training

In [None]:
import os
import json
import random
import numpy as np
from glob import glob
from tqdm.auto import tqdm
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from transformers import (
    GPT2Config,
    GPT2LMHeadModel,
    Trainer,
    TrainingArguments,
)


################################################################################
# 0. Define Special Tokens
################################################################################

# Reserved IDs 0..4; real vocabulary tokens are shifted to start at 5.
#   0 -> PAD, 1 -> BOS, 2 -> EOS, 3 -> forward direction, 4 -> backward direction
(
    PAD_TOKEN_ID,
    BOS_TOKEN_ID,
    EOS_TOKEN_ID,
    DIR_FWD_TOKEN_ID,
    DIR_BWD_TOKEN_ID,
) = range(5)

################################################################################
# 1. Load and Process .tsv Paths into a NumPy Object Array
################################################################################


def load_paths(num_classes=10):
    """
    Load paths whose starting class is among the first `num_classes` classes.

    Steps:
      - reads the vocabulary `process_paths/vocab_{num_classes}.json` and the
        class table `process_p31_p279/class_counts.json`,
      - takes the first `num_classes` keys of the class table (in file order)
        as starting entities,
      - reads every `./extracted_paths/<class_id>/*.tsv` file for those classes,
        in sorted file order so the result is reproducible,
      - converts each tab-separated line of string IDs to integer indices
        *shifted by 5* to avoid collisions with special tokens 0..4.

    Blank lines in the .tsv files are skipped.

    Args:
      num_classes: how many leading entries of class_counts.json to use; also
        selects which vocab file is read.

    Returns:
      - all_paths: 1D NumPy object array of shape (N,)
                   each element is a np.array of dtype int32 representing one path
      - id2idx:    dict mapping string ID -> int index (starting at 5)
      - idx2id:    dict mapping int index (>=5) -> string ID
      - id2label:  the decoded vocab JSON (string ID -> label)

    Raises:
      KeyError: if a .tsv line contains an ID missing from the vocabulary
        (the message names the file and line number).
    """
    with open(f"process_paths/vocab_{num_classes}.json", "r", encoding="utf-8") as f:
        id2label = json.load(f)

    with open("process_p31_p279/class_counts.json", "r", encoding="utf-8") as f:
        class_counts = json.load(f)

    # Shift real token IDs by +5 so 0..4 stay free for
    # [PAD, BOS, EOS, DIR_FWD, DIR_BWD].
    id2idx = {id_str: i + 5 for i, id_str in enumerate(id2label)}
    idx2id = {v: k for k, v in id2idx.items()}  # inverse map

    # json.load returns a dict that keeps on-disk key order, so this takes the
    # first `num_classes` entries as written in class_counts.json.
    starting_entities = set(list(class_counts.keys())[:num_classes])

    # glob() order is filesystem-dependent; sort for a reproducible path order.
    tsv_paths = sorted(
        path
        for path in glob("./extracted_paths/*/*.tsv")
        if os.path.basename(os.path.dirname(path)) in starting_entities
    )

    path_list = []
    for tsv_path in tqdm(tsv_paths, desc="Reading TSV files"):
        with open(tsv_path, "r", encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                stripped = line.strip()
                # "".split("\t") == [""], so blank lines must be skipped
                # before the vocabulary lookup or they raise KeyError('').
                if not stripped:
                    continue
                try:
                    path_indices = [id2idx[x] for x in stripped.split("\t")]
                except KeyError as exc:
                    raise KeyError(
                        f"Unknown ID {exc.args[0]!r} in {tsv_path}:{line_no}"
                    ) from None
                path_list.append(np.array(path_indices, dtype=np.int32))

    # np.array(list_of_arrays, dtype=object) silently builds a 2-D array when
    # all paths share the same length; fill a pre-allocated 1-D object array
    # element by element so the (N,) contract always holds.
    all_paths = np.empty(len(path_list), dtype=object)
    for i, path_arr in enumerate(path_list):
        all_paths[i] = path_arr

    print(f"Loaded {len(all_paths)} paths total.")

    return all_paths, id2idx, idx2id, id2label


################################################################################
# 2. Shuffle all_paths Globally in Memory
################################################################################


def shuffle_paths(all_paths, seed=None):
    """
    Randomly permute the array of paths in place along axis 0 (global shuffle).

    Args:
        all_paths: NumPy array to shuffle; it is modified in place.
        seed: optional seed. When None (default), NumPy's global RNG is used via
            ``np.random.shuffle``, as before. When given, a private
            ``np.random.default_rng(seed)`` performs the shuffle, so the
            permutation is reproducible and global RNG state is untouched.

    Returns:
        The same array object, now shuffled.
    """
    print("Shuffling paths globally...")
    if seed is None:
        np.random.shuffle(all_paths)
    else:
        np.random.default_rng(seed).shuffle(all_paths)
    return all_paths


################################################################################
# 3. Define a Dataset with On-the-Fly Crop and Direction (Fwd/Bwd) Augmentation
################################################################################


class MultiAugmentPathsDataset(Dataset):
    """
    Map-style dataset that stores each path once and augments it per access.

    During __getitem__:
      1) If the path has at least `min_crop_length` tokens, randomly crop it to a
         contiguous sub-path with probability p_crop.
      2) If `max_length` is set and the sequence would not fit, take a random
         contiguous window so that the final sequence has `max_length` tokens.
      3) Reverse the path with probability p_dir (backward direction).

    The emitted sequence is [DIR, BOS] + path + [EOS], where DIR is
    DIR_FWD_TOKEN_ID or DIR_BWD_TOKEN_ID depending on step 3.

    Args:
        paths_np_array (np.ndarray):
            1D NumPy object array of np.array(dtype=int32).
        min_crop_length (int):
            Minimum length of a sub-path if we do a crop. Values below 1 are
            treated as 1, so a crop never produces an empty path.
        p_crop (float):
            Probability of performing a random crop on a given sample.
            If p_crop==0.0, no cropping is done.
        p_dir (float):
            Probability that a sample will be reversed (backward).
            If p_dir==0.0, all samples are forward only.
        max_length (int or None):
            Maximum total sequence length including the 3 special tokens
            (e.g. the model's n_positions). Must be >= 4. None (default)
            applies no limit.

    Raises:
        ValueError: if max_length is given and smaller than 4.
    """

    def __init__(
        self, paths_np_array, min_crop_length=2, p_crop=0.5, p_dir=0.5, max_length=None
    ):
        if max_length is not None and max_length < 4:
            raise ValueError(
                f"max_length must be >= 4 (3 special tokens + 1 path token), got {max_length}"
            )
        self.paths = paths_np_array
        self.min_crop_length = min_crop_length
        self.p_crop = p_crop
        self.p_dir = p_dir
        self.max_length = max_length

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """
        Returns a dict with 'input_ids' and 'labels' (the same int32 array).
        """
        path_array = self.paths[idx]

        # 1) Possibly random-crop. The length check is evaluated first (short
        #    circuit), so RNG draws happen in the same order as before.
        if len(path_array) >= self.min_crop_length and random.random() < self.p_crop:
            cropped = self.random_subpath(path_array)
            if cropped is not None:
                path_array = cropped

        # 2) Keep the final sequence inside the context window, if configured.
        if self.max_length is not None:
            budget = self.max_length - 3  # room left after DIR, BOS and EOS
            if len(path_array) > budget:
                start = random.randint(0, len(path_array) - budget)
                path_array = path_array[start : start + budget]

        # Convert to list so we can optionally reverse it
        path_list = path_array.tolist()

        # 3) Possibly reverse direction if p_dir > 0
        if random.random() < self.p_dir:
            direction_token = DIR_BWD_TOKEN_ID
            path_list.reverse()
        else:
            direction_token = DIR_FWD_TOKEN_ID

        # Final sequence = [DIR, BOS] + path + [EOS]
        seq = [direction_token, BOS_TOKEN_ID] + path_list + [EOS_TOKEN_ID]
        seq_array = np.array(seq, dtype=np.int32)

        return {"input_ids": seq_array, "labels": seq_array}

    def random_subpath(self, path):
        """
        Returns a random contiguous sub-path of `path` with at least
        max(1, min_crop_length) elements, or None if the path is too short.
        """
        min_len = max(1, self.min_crop_length)  # never return an empty crop
        length = len(path)
        if length < min_len:
            return None

        start = random.randint(0, length - min_len)
        end = random.randint(start + min_len, length)
        return path[start:end]


################################################################################
# 4. Custom Data Collator for Padding
################################################################################


class MyDataCollator:
    """
    Right-pads a batch of variable-length examples into rectangular tensors.

    Each example must be a mapping with "input_ids" and "labels" sequences
    (anything ``torch.tensor`` accepts, e.g. 1-D NumPy integer arrays). Token
    layout used in this notebook:
      PAD_TOKEN_ID = 0,
      BOS_TOKEN_ID = 1,
      EOS_TOKEN_ID = 2,
      DIR_FWD_TOKEN_ID = 3,
      DIR_BWD_TOKEN_ID = 4,
      and real tokens start at 5+.

    Args:
        pad_token_id: value used to pad ``input_ids``.
        label_pad_token_id: value used to pad ``labels``; the default -100 is
            the target index PyTorch's cross-entropy loss ignores.

    Returns (from __call__):
        dict with LongTensors "input_ids", "attention_mask" and "labels", each of
        shape (batch_size, longest_sequence_length).
    """

    def __init__(self, pad_token_id=PAD_TOKEN_ID, label_pad_token_id=-100):
        self.pad_token_id = pad_token_id
        self.label_pad_token_id = label_pad_token_id

    def __call__(self, batch):
        input_ids_list = [torch.tensor(d["input_ids"], dtype=torch.long) for d in batch]
        labels_list = [torch.tensor(d["labels"], dtype=torch.long) for d in batch]

        input_ids = pad_sequence(
            input_ids_list, batch_first=True, padding_value=self.pad_token_id
        )
        # Pad labels with label_pad_token_id (-100 by default) so they're ignored
        labels = pad_sequence(
            labels_list, batch_first=True, padding_value=self.label_pad_token_id
        )

        # Derive the mask from true sequence lengths rather than from
        # `input_ids != pad_token_id`: a real token that happens to equal the
        # pad ID must still be attended to.
        lengths = torch.tensor([t.size(0) for t in input_ids_list], dtype=torch.long)
        positions = torch.arange(input_ids.size(1))
        attention_mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).long()

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }


################################################################################
# 5. Main: Putting it All Together
################################################################################

# A. Configuration and seeding
SEED = 42
N_POSITIONS = 64  # GPT-2 context length; must fit [DIR, BOS] + path + [EOS]

# Seed every RNG used here: the global shuffle below uses NumPy's global RNG,
# dataset augmentation uses Python's `random`, and model init uses torch.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# B. Load data
all_paths, id2idx, idx2id, id2label = load_paths(num_classes=100)

# Fail fast if any path cannot fit into the context window. Otherwise training
# would only crash (position-embedding index out of range) once an over-long,
# uncropped sample is eventually drawn.
max_path_len = max((len(p) for p in all_paths), default=0)
if max_path_len + 3 > N_POSITIONS:
    raise ValueError(
        f"Longest path has {max_path_len} tokens; with DIR/BOS/EOS it needs "
        f"{max_path_len + 3} positions but n_positions={N_POSITIONS}."
    )

# C. Shuffle globally
all_paths = shuffle_paths(all_paths)

# D. Build Dataset (each path stored once; crop and direction are sampled
#    every time an item is accessed)
dataset = MultiAugmentPathsDataset(
    paths_np_array=all_paths,
    min_crop_length=2,
    p_crop=0.5,  # 50% chance of cropping
    p_dir=0.5,  # 50% chance of reversing
)

# E. Create a GPT-like model
# The vocab must accommodate:
#   - PAD (0), BOS (1), EOS (2), DIR_FWD (3), DIR_BWD (4),
#   - plus all real tokens starting at 5, up to (5 + len(id2label)-1).
vocab_size = len(id2idx) + 5
print("Computed vocab_size:", vocab_size)

config = GPT2Config(
    vocab_size=vocab_size,
    n_positions=N_POSITIONS,
    n_embd=256,  # normally 768 for GPT-2 small
    n_layer=4,  # normally 12
    n_head=4,  # normally 12
    pad_token_id=PAD_TOKEN_ID,  # 0
    bos_token_id=BOS_TOKEN_ID,  # 1
    eos_token_id=EOS_TOKEN_ID,  # 2
)
# NOTE: There are no explicit config fields for the direction tokens;
#       they are just normal tokens in the vocab.

model = GPT2LMHeadModel(config)

# F. Training with Trainer
# `evaluation_strategy` is deprecated/removed in recent transformers releases;
# its default ("no") is what this run wants, so it is simply omitted.
training_args = TrainingArguments(
    output_dir="./model_output",
    overwrite_output_dir=True,
    num_train_epochs=10,
    per_device_train_batch_size=512,
    logging_steps=100,
    save_steps=1000,
    save_total_limit=2,
    seed=SEED,
    # shuffle=True is the default in Trainer for map-style datasets
)

data_collator = MyDataCollator(pad_token_id=PAD_TOKEN_ID)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=data_collator,
)

print(f"Total parameters: {model.num_parameters()}")
print("Number of parameters per layer:")
for name, param in model.named_parameters():
    print(f"{name}: {param.numel()}")

print("Starting training...")
trainer.train()
print("Training complete.")