## Training

In [None]:
import os
import json
import time
import random
from glob import glob
from tqdm.auto import tqdm
from collections import Counter

import torch
from torch.utils.data import (
    Dataset,
    DataLoader,
    WeightedRandomSampler,
    SubsetRandomSampler,
)

# Transformers
from transformers import (
    GPT2Config,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    Trainer,
    TrainingArguments,
)

################################################################################
# 0. Custom Tokenizer Logic
################################################################################

# Directory holding the GPT-2 tokenizer extended with our special tokens
custom_tokenizer_dir = "custom_tokenizer"

# Reuse a previously saved tokenizer when one is present; otherwise derive it
# from the stock "gpt2" tokenizer, register the special tokens and persist it.
if os.path.exists(custom_tokenizer_dir):
    print(f"Found existing custom tokenizer at {custom_tokenizer_dir}.")
else:
    print(
        "No custom tokenizer found. Creating one from base GPT-2 and adding special tokens..."
    )
    special_tokens = dict(
        bos_token="<BOS>",
        eos_token="<EOS>",
        pad_token="<PAD>",
        additional_special_tokens=["<DOWNWARD>"],
    )
    base_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    num_added = base_tokenizer.add_special_tokens(special_tokens)
    print(f"Added {num_added} special tokens. New vocab size = {len(base_tokenizer)}")

    # Written to disk so that later runs take the branch above
    base_tokenizer.save_pretrained(custom_tokenizer_dir)
    print(f"Custom tokenizer saved to {custom_tokenizer_dir}.")

# Both branches end by loading from disk, so `tokenizer` always comes from
# the saved directory.
tokenizer = GPT2Tokenizer.from_pretrained(custom_tokenizer_dir)
print(f"Loaded custom tokenizer with vocab size = {len(tokenizer)}")

################################################################################
# 1. Utility Functions
################################################################################


def get_tsv_paths(num_classes, sample_first_batch):
    """
    Collect the TSV files that belong to the first ``num_classes`` classes.

    Class names are the keys of ``process_p31_p279/class_counts.json`` taken in
    file order. A TSV file under ``./extracted_paths/<class>/`` belongs to the
    class its parent directory is named after.

    Args:
        num_classes: Number of leading keys of the class-count file to keep.
        sample_first_batch: If true, return exactly one file per class: the
            first file whose name contains ``batch_1`` (but not ``batch_10``,
            ``batch_11``, ...), or, when the class has none, its first file
            in sorted order. If false, return every file of every class.

    Returns:
        List of TSV paths in sorted order (hence grouped by class directory).
    """
    import re  # local import: only needed for the batch-name check below

    with open("process_p31_p279/class_counts.json", "r", encoding="utf-8") as f:
        class_counts = json.load(f)
    starting_entities = set(list(class_counts)[:num_classes])

    # glob() yields files in arbitrary, filesystem-dependent order. Sorting
    # makes the selected files, and every index built on them, reproducible.
    tsv_paths_by_class = {}
    for path in sorted(glob("./extracted_paths/*/*.tsv")):
        class_dir = os.path.basename(os.path.dirname(path))
        if class_dir in starting_entities:
            tsv_paths_by_class.setdefault(class_dir, []).append(path)

    tsv_paths = []
    if sample_first_batch:
        # A plain substring test would also accept "batch_10", "batch_11", ...
        first_batch = re.compile(r"batch_1(?!\d)")
        for paths in tsv_paths_by_class.values():
            batch1_files = [
                p for p in paths if first_batch.search(os.path.basename(p))
            ]
            tsv_paths.append(batch1_files[0] if batch1_files else paths[0])
    else:
        for paths in tsv_paths_by_class.values():
            tsv_paths.extend(paths)

    print(f"Found {len(tsv_paths)} TSV files.")
    return tsv_paths


def load_id2label(num_classes):
    """
    Read the id -> label vocabulary saved for the top ``num_classes`` classes.

    Returns the parsed content of
    ``process_paths/vocab_top_<num_classes>.json``.
    """
    vocab_file = f"process_paths/vocab_top_{num_classes}.json"
    with open(vocab_file, "r", encoding="utf-8") as handle:
        return json.load(handle)


################################################################################
# 2. Datasets
################################################################################


class EfficientLazyDataset(Dataset):
    """
    Lazily tokenized dataset over tab-separated path files.

    Construction builds an index of (file_idx, byte_offset) for each non-empty
    line in the TSV files. __getitem__ seeks directly to that offset, reads one
    line, maps its fields through ``id2label`` and tokenizes the result as
    ``<BOS> tok1 <DOWNWARD> tok2 ... <EOS>``.

    Args:
        tsv_paths: Paths of the TSV files to index.
        id2label: Mapping from raw field value to replacement text; fields
            without an entry are kept as they are.
        tokenizer: Tokenizer exposing ``bos_token``, ``eos_token`` and
            ``additional_special_tokens`` (its first entry is the separator).
        max_length: Every sample is padded or truncated to this many tokens.
    """

    def __init__(self, tsv_paths, id2label, tokenizer, max_length):
        self.tsv_paths = tsv_paths
        self.id2label = id2label
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.index_mapping = []

        print("Building file offset index for lazy loading...")
        for file_idx, path in enumerate(tqdm(tsv_paths, desc="Indexing TSV files")):
            # Binary mode: offsets are byte positions usable with seek().
            with open(path, "rb") as f:
                offset = 0
                for line in f:
                    if line.strip():
                        self.index_mapping.append((file_idx, offset))
                    offset += len(line)
        print(f"Dataset contains {len(self.index_mapping)} samples.")

    def __len__(self):
        return len(self.index_mapping)

    def __getitem__(self, idx):
        """
        Return ``input_ids``/``attention_mask`` (as produced by the tokenizer)
        plus ``labels`` for the ``idx``-th indexed line, each of length
        ``max_length``. Padding positions in ``labels`` are set to -100.
        """
        file_idx, offset = self.index_mapping[idx]

        with open(self.tsv_paths[file_idx], "rb") as f:
            f.seek(offset)
            line = f.readline().decode("utf-8")

        items = line.strip().split("\t")
        tokens = [self.id2label.get(token, token) for token in items]

        # str.split never returns an empty list, so tokens[0] always exists;
        # the separator is only looked up when there is something to separate.
        if len(tokens) > 1:
            body = self.tokenizer.additional_special_tokens[0].join(tokens)
        else:
            body = tokens[0]
        sequence = self.tokenizer.bos_token + body + self.tokenizer.eos_token

        # Truncate as well as pad: an over-long line would otherwise produce a
        # tensor longer than max_length, which cannot be batched with the
        # padded samples. Truncated samples lose their trailing <EOS>.
        encoding = self.tokenizer(
            sequence,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        item = {k: v.squeeze(0) for k, v in encoding.items()}

        # -100 is the ignore index of the LM loss: without it the model is
        # also trained (and its loss reported) on predicting padding.
        labels = item["input_ids"].clone()
        attention_mask = item.get("attention_mask")
        if attention_mask is not None:
            labels[attention_mask == 0] = -100
        item["labels"] = labels
        return item


class SubsetLazyDataset(EfficientLazyDataset):
    """
    Subset of an already indexed EfficientLazyDataset.

    Shares the parent's file list, label map, tokenizer and max length, and
    keeps only the index entries named by ``subset_indices``. The parent
    constructor is not called, so no file is indexed again.
    """

    def __init__(self, base_dataset, subset_indices):
        for shared_attr in ("tsv_paths", "id2label", "tokenizer", "max_length"):
            setattr(self, shared_attr, getattr(base_dataset, shared_attr))

        parent_index = base_dataset.index_mapping
        self.index_mapping = [parent_index[position] for position in subset_indices]

    def __len__(self):
        # Number of selected entries, not the size of the parent dataset
        return len(self.index_mapping)


################################################################################
# 3. Custom Trainer with Optional Class-Aware Sampling
################################################################################


class MyTrainer(Trainer):
    """
    Trainer whose training DataLoader can optionally balance classes.

    Args:
        sampling_mode: ``"class_aware"`` draws samples with replacement, each
            with probability inversely proportional to the size of its class
            (the name of the directory containing the sample's TSV file).
            Any other value gives a plain shuffled DataLoader.
        **kwargs: Forwarded unchanged to ``Trainer.__init__``.
    """

    def __init__(self, sampling_mode="iid", **kwargs):
        self.sampling_mode = sampling_mode
        super().__init__(**kwargs)

    def _get_dataloader_with_sampling(self, dataset, batch_size, shuffle):
        """
        Returns a DataLoader for the given dataset. If class-aware sampling is selected,
        it applies a WeightedRandomSampler; otherwise, it falls back to simple shuffling.

        ``dataset`` must expose ``index_mapping`` and ``tsv_paths`` when
        class-aware sampling is used. ``shuffle`` is only honoured in the
        non-class-aware case.
        """
        # Honour the dataloader settings from TrainingArguments; a hand-built
        # DataLoader would otherwise silently load data in the main process.
        loader_kwargs = {
            "batch_size": batch_size,
            "collate_fn": self.data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "drop_last": self.args.dataloader_drop_last,
        }

        if self.sampling_mode != "class_aware":
            return DataLoader(dataset, shuffle=shuffle, **loader_kwargs)

        # One class label per sample: the directory holding its TSV file.
        class_labels = [
            os.path.basename(os.path.dirname(dataset.tsv_paths[file_idx]))
            for file_idx, _ in dataset.index_mapping
        ]
        counts = Counter(class_labels)
        weights = [1.0 / counts[label] for label in class_labels]
        print(f"Computed sample weights for {len(weights)} samples.")

        # Use numpy-based sampling if the dataset is very large
        # (torch.multinomial, used by WeightedRandomSampler, rejects more
        # than 2**24 categories).
        if len(weights) > 2**24:
            import numpy as np

            weights_np = np.array(weights, dtype=np.float64)
            weights_np /= weights_np.sum()
            indices = np.random.choice(
                len(weights_np),
                size=len(dataset),
                replace=True,
                p=weights_np,
            )
            # NOTE: the weighted draw happens once, here; every epoch then
            # visits these same indices in a fresh random order. Plain ints
            # keep the sampler output independent of numpy scalar types.
            sampler = SubsetRandomSampler(indices.tolist())
            print("Using numpy-based sampling.")
        else:
            sampler = WeightedRandomSampler(
                weights, num_samples=len(dataset), replacement=True
            )
            print("Using WeightedRandomSampler.")

        return DataLoader(dataset, sampler=sampler, **loader_kwargs)

    def get_train_dataloader(self):
        """Build the training DataLoader according to ``sampling_mode``."""
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        return self._get_dataloader_with_sampling(
            dataset=self.train_dataset,
            batch_size=self.args.per_device_train_batch_size,
            shuffle=True,
        )

    # Removed get_eval_dataloader since we are not using validation.


################################################################################
# 4. Main Training Logic
################################################################################

# Misc flags
TIME_DATA_LOADING = False
fp16 = False
num_workers = 8
max_length = 256
per_device_train_batch_size = 64
logging_steps = 100
save_steps = 1000
save_total_limit = 2

# Experiment grid: one training run per combination of the lists below
num_classes_list = [1000]
epoch_list = [2]
sample_first_batch_list = [True]
sampling_modes = ["class_aware"]
model_sizes = ["large"]
use_cuda = True  # True for GPU, False for CPU

# If you want to resume from a checkpoint, specify here:
load_checkpoint_dir = None
# e.g. None if you want to train from scratch

# Size-specific GPT-2 dimensions used when training from scratch
# (n_inner is 4 * n_embd in every preset).
MODEL_SIZE_PRESETS = {
    "small": {"n_embd": 128, "n_layer": 2, "n_head": 2, "n_inner": 512},
    "medium": {"n_embd": 256, "n_layer": 4, "n_head": 4, "n_inner": 1024},
    "large": {"n_embd": 512, "n_layer": 8, "n_head": 8, "n_inner": 2048},
}

for model_size in model_sizes:
    for num_classes in num_classes_list:
        for num_train_epochs in epoch_list:
            for sample_first_batch in sample_first_batch_list:
                for sampling_mode in sampling_modes:

                    # Output directory structure.
                    # NOTE: the epoch count is not part of the path, so runs
                    # that differ only in num_train_epochs share a directory.
                    output_dir = (
                        f"./model_output_{num_classes}/"
                        f"model_size_{model_size}/"
                        f"sample_first_batch_{sample_first_batch}/"
                        f"sampling_mode_{sampling_mode}"
                    )

                    # Model architecture if training from scratch
                    if model_size not in MODEL_SIZE_PRESETS:
                        raise ValueError(f"Unknown model size: {model_size}")
                    model_arch = {
                        "vocab_size": len(tokenizer),  # Matches the custom tokenizer
                        **MODEL_SIZE_PRESETS[model_size],
                        "n_positions": max_length,
                        "attn_pdrop": 0.1,
                        "resid_pdrop": 0.1,
                        "embd_pdrop": 0.1,
                    }

                    # 1. Get TSV paths + id2label
                    tsv_paths = get_tsv_paths(
                        num_classes=num_classes, sample_first_batch=sample_first_batch
                    )
                    id2label = load_id2label(num_classes=num_classes)

                    # 2. Dataset: use the full dataset for training only.
                    full_dataset = EfficientLazyDataset(
                        tsv_paths, id2label, tokenizer, max_length
                    )
                    train_dataset = (
                        full_dataset  # Use full dataset; no validation split.
                    )

                    # 3. Either load from checkpoint or create from scratch.
                    # The checkpoint is resolved once so that model loading and
                    # trainer.train() agree on whether this run is resuming.
                    resume_checkpoint = None
                    if load_checkpoint_dir:
                        if os.path.exists(load_checkpoint_dir):
                            resume_checkpoint = load_checkpoint_dir
                        else:
                            print(
                                f"\nCheckpoint directory not found: {load_checkpoint_dir}"
                            )

                    if resume_checkpoint:
                        print(f"\nLoading model from checkpoint: {resume_checkpoint}")
                        model = GPT2LMHeadModel.from_pretrained(resume_checkpoint)
                        if model.config.vocab_size != len(tokenizer):
                            raise ValueError(
                                f"Checkpoint vocab_size ({model.config.vocab_size}) does not match "
                                f"the current tokenizer length ({len(tokenizer)}). "
                                "Make sure to add special tokens and resize the model before training."
                            )
                    else:
                        print("\nTraining from scratch.")
                        custom_config = GPT2Config(**model_arch)
                        model = GPT2LMHeadModel(custom_config)
                        model.resize_token_embeddings(len(tokenizer))
                        model.config.pad_token_id = tokenizer.pad_token_id
                        model.config.bos_token_id = tokenizer.bos_token_id
                        model.config.eos_token_id = tokenizer.eos_token_id

                    # 4. Move to device
                    device = "cuda" if use_cuda and torch.cuda.is_available() else "cpu"
                    model.to(device)
                    print(f"Using device: {device}")
                    print(f"Total model parameters: {model.num_parameters()}")

                    # 5. Training args (set eval_strategy to "no" to disable evaluation)
                    training_args = TrainingArguments(
                        output_dir=output_dir,
                        overwrite_output_dir=True,
                        num_train_epochs=num_train_epochs,
                        per_device_train_batch_size=per_device_train_batch_size,
                        eval_strategy="no",  # No evaluation
                        logging_steps=logging_steps,
                        save_steps=save_steps,
                        save_total_limit=save_total_limit,
                        fp16=fp16 and (device == "cuda"),
                        dataloader_num_workers=num_workers,
                        # The Trainer picks its own device; without this it
                        # would use an available GPU even when use_cuda is False.
                        use_cpu=(device == "cpu"),
                    )

                    # 6. Build trainer (only train_dataset is provided)
                    trainer = MyTrainer(
                        model=model,
                        args=training_args,
                        train_dataset=train_dataset,
                        sampling_mode=sampling_mode,
                    )

                    # (Optional) Time data loading
                    if TIME_DATA_LOADING:
                        print("\nTiming the data loading for train dataloader...")
                        train_dataloader = trainer.get_train_dataloader()
                        start_time = time.time()
                        batch_count = 0
                        for _ in tqdm(
                            train_dataloader, desc="Iterating over train batches"
                        ):
                            batch_count += 1
                        total_time = time.time() - start_time
                        print(f"Total batches: {batch_count}")
                        print(f"Time to iterate train set: {total_time:.2f} seconds")
                        # An empty loader has no per-batch average to report
                        if batch_count:
                            print(
                                f"Avg time per batch: {total_time / batch_count:.4f} seconds\n"
                            )

                    # 7. Train (resuming if checkpoint is set)
                    print(
                        f"\nStarting training. Checkpoint: {resume_checkpoint or 'None (scratch)'}"
                    )
                    trainer.train(resume_from_checkpoint=resume_checkpoint)

                    print("\nTraining complete.\n" + "-" * 60)

  from .autonotebook import tqdm as notebook_tqdm


Found existing custom tokenizer at custom_tokenizer.
Loaded custom tokenizer with vocab size = 50261
Found 559 TSV files.
Building file offset index for lazy loading...


Indexing TSV files: 100%|██████████| 559/559 [00:01<00:00, 285.57it/s]


Dataset contains 4701873 samples.

Training from scratch.
Using device: cuda
Total model parameters: 16091904

Starting training. Checkpoint: None (scratch)
Computed sample weights for 4701873 samples.
Using WeightedRandomSampler.


`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


Step,Training Loss
100,6.0236
200,3.721
300,1.9498
400,1.0567
500,0.733
600,0.6101
700,0.5323
800,0.4824
900,0.4452
1000,0.4183


KeyboardInterrupt: 