# Augustus: Enhancing Epigraphic Language Models with POS Tagging, Material Classification, and Generative Capabilities 

## Import

In [2]:
import argparse
import json
import logging
import random
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import (
AdamW,
get_linear_schedule_with_warmup,
)
from torch.nn import CrossEntropyLoss
from cltk.tokenizers.lat.lat import LatinWordTokenizer as WordTokenizer # Not used in provided core logic
from cltk.tokenizers.lat.lat import LatinPunktSentenceTokenizer as SentenceTokenizer # Not used
from tensor2tensor.data_generators import text_encoder
import torch.nn as nn
from torch.nn import functional as F
from transformers import BertModel, BertPreTrainedModel, BertConfig
import wandb
import re
from collections import defaultdict
from sklearn.metrics import precision_recall_fscore_support, accuracy_score




In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


## Century Classification Labels

In [None]:
PERIOD_LABELS = {
    "Before_5C_BC": 0,  # year <= -501
    "C5_BC": 1,         # -500 to -401
    "C4_BC": 2,         # -400 to -301
    "C3_BC": 3,         # -300 to -201
    "C2_BC": 4,         # -200 to -101
    "C1_BC": 5,         # -100 to -1
    "C1_AD": 6,         # 0 to 100 (representative_year = 0 included here)
    "C2_AD": 7,         # 101 to 200
    "C3_AD": 8,         # 201 to 300
    "C4_AD": 9,         # 301 to 400
    "C5_AD": 10,        # 401 to 500
    "After_5C_AD": 11,  # year >= 501
    "Unknown": -1
}
NUM_DATE_LABELS = 12

## Model

In [None]:
class LatinBERTForMultiTask(nn.Module):
    def __init__(self, bert_path, num_date_labels, num_material_labels, num_pos_labels):
        super(LatinBERTForMultiTask, self).__init__()
        self.bert = BertModel.from_pretrained(bert_path, add_pooling_layer=True)
        config = self.bert.config

        self.cls_mlm = nn.Linear(config.hidden_size, config.vocab_size)

        self.num_date_labels = num_date_labels
        self.dropout_date = nn.Dropout(config.hidden_dropout_prob)
        self.classifier_date = nn.Linear(config.hidden_size, self.num_date_labels)

        self.num_material_labels = num_material_labels
        self.dropout_material = nn.Dropout(config.hidden_dropout_prob)
        self.classifier_material = nn.Linear(config.hidden_size, self.num_material_labels)

        self.num_pos_labels = num_pos_labels
        self.dropout_pos = nn.Dropout(config.hidden_dropout_prob)
        self.classifier_pos = nn.Linear(config.hidden_size, self.num_pos_labels)

    def forward(
        self, 
        input_ids=None, 
        attention_mask=None, 
        token_type_ids=None, 
        masked_lm_labels=None, 
        date_labels=None, 
        material_labels=None,
        pos_labels=None
    ):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True
        )

        sequence_output = outputs.last_hidden_state 
        pooled_output = outputs.pooler_output     

        mlm_loss, date_loss, material_loss, pos_loss = None, None, None, None
        mlm_logits, date_logits, material_logits, pos_logits = None, None, None, None
        loss_fct = CrossEntropyLoss() 
        loss_fct_token = CrossEntropyLoss(ignore_index=-100) 

        mlm_logits = self.cls_mlm(sequence_output)
        if masked_lm_labels is not None:
            mlm_loss = loss_fct_token(mlm_logits.view(-1, self.bert.config.vocab_size), masked_lm_labels.view(-1))

        if self.num_date_labels > 0 and pooled_output is not None:
            pooled_output_date = self.dropout_date(pooled_output)
            date_logits = self.classifier_date(pooled_output_date)
            if date_labels is not None:
                date_loss = loss_fct(date_logits.view(-1, self.num_date_labels), date_labels.view(-1))

        if self.num_material_labels > 0 and pooled_output is not None:
            pooled_output_material = self.dropout_material(pooled_output)
            material_logits = self.classifier_material(pooled_output_material)
            if material_labels is not None:
                material_loss = loss_fct(material_logits.view(-1, self.num_material_labels), material_labels.view(-1))

        if self.num_pos_labels > 0 and sequence_output is not None:
            sequence_output_pos = self.dropout_pos(sequence_output)
            pos_logits = self.classifier_pos(sequence_output_pos)
            if pos_labels is not None:
                pos_loss = loss_fct_token(pos_logits.view(-1, self.num_pos_labels), pos_labels.view(-1))

        return {
            "mlm_loss": mlm_loss,
            "mlm_logits": mlm_logits,
            "date_loss": date_loss,
            "date_logits": date_logits,
            "material_loss": material_loss,
            "material_logits": material_logits,
            "pos_loss": pos_loss,
            "pos_logits": pos_logits,
            "hidden_states": outputs.hidden_states,
            "pooler_output": pooled_output
        }

## Tokenizer

In [None]:
class LatinTokenizer():
    UNK_TOKEN_IN_CORPUS = "<unk>"
    MODEL_UNK_TOKEN = "[UNK]"

    def __init__(self, encoder):
        self.vocab = {}
        self.reverseVocab = {}
        self.encoder = encoder

        self.vocab["[PAD]"] = 0
        self.vocab[self.MODEL_UNK_TOKEN] = 1 
        self.vocab["[CLS]"] = 2
        self.vocab["[SEP]"] = 3
        self.vocab["[MASK]"] = 4

        for key in self.encoder._subtoken_string_to_id:
            subword_id_in_our_vocab = self.encoder._subtoken_string_to_id[key] + 5
            self.vocab[key] = subword_id_in_our_vocab
            self.reverseVocab[subword_id_in_our_vocab] = key

    def convert_tokens_to_ids(self, tokens):
        ids = []
        for token in tokens:
            if token == "[PAD]":
                ids.append(0)
            elif token == self.MODEL_UNK_TOKEN: 
                ids.append(1)
            elif token == "[CLS]":
                ids.append(2)
            elif token == "[SEP]":
                ids.append(3)
            elif token == "[MASK]":
                ids.append(4)
            else:
                ids.append(self.vocab.get(token, self.vocab[self.MODEL_UNK_TOKEN]))
        return ids

    def convert_ids_to_tokens(self, ids_list):
        tokens = []
        for id_val in ids_list:
            if id_val == 0:
                tokens.append("[PAD]")
            elif id_val == 1:
                tokens.append(self.MODEL_UNK_TOKEN) 
            elif id_val == 2:
                tokens.append("[CLS]")
            elif id_val == 3:
                tokens.append("[SEP]")
            elif id_val == 4:
                tokens.append("[MASK]")
            else:
                tokens.append(self.reverseVocab.get(id_val, self.MODEL_UNK_TOKEN))
        return tokens
    
    def tokenize_word(self, word):
        if self.UNK_TOKEN_IN_CORPUS in word:
            return [self.MODEL_UNK_TOKEN]
        elif word in {"[PAD]", self.MODEL_UNK_TOKEN, "[CLS]", "[SEP]", "[MASK]"}:
            return [word]
        else:
            try:
                sub_ids_from_encoder = self.encoder.encode(word)
                subword_strings = []
                for sub_id in sub_ids_from_encoder:
                    our_subword_id = sub_id + 5
                    subword_string = self.reverseVocab.get(our_subword_id)
                    if subword_string is not None:
                        subword_strings.append(subword_string)
                    else:
                        subword_strings.append(self.MODEL_UNK_TOKEN)
                return subword_strings
            except Exception:
                return [self.MODEL_UNK_TOKEN]

    def tokenize(self, text):
        tokens = []
        raw_words = text.split()
        for word in raw_words:
            tokens.extend(self.tokenize_word(word))
        return tokens

    def __call__(self, text, max_length=128, padding=True, truncation=True,
                 return_attention_mask=True, add_special_tokens=True):

        if isinstance(text, str):
            tokens = self.tokenize(text)
        else:
            tokens = list(text)

        if add_special_tokens:
            if not tokens or tokens[0] != "[CLS]":
                tokens = ["[CLS]"] + tokens
            if not tokens or tokens[-1] != "[SEP]":
                 tokens = tokens + ["[SEP]"]

        if truncation and len(tokens) > max_length:
            tokens = tokens[:max_length - 1] + ["[SEP]"]

        input_ids = self.convert_tokens_to_ids(tokens)
        attention_mask_list = [1] * len(input_ids) if return_attention_mask else None

        if padding:
            pad_length = max_length - len(input_ids)
            if pad_length > 0 :
                input_ids = input_ids + [self.vocab["[PAD]"]] * pad_length
                if attention_mask_list:
                    attention_mask_list = attention_mask_list + [0] * pad_length
            if len(input_ids) > max_length: 
                input_ids = input_ids[:max_length]
                if attention_mask_list:
                    attention_mask_list = attention_mask_list[:max_length]

        result = {"input_ids": input_ids}
        if return_attention_mask:
            result["attention_mask"] = attention_mask_list
        return result

    def get_vocab_size(self):
        return len(self.vocab)

def get_period_label_for_record(record_dating_info):
    year_from_raw = record_dating_info.get("year_from")
    year_to_raw = record_dating_info.get("year_to")

    try:
        year_from = int(year_from_raw) if year_from_raw is not None and str(year_from_raw).strip() != "" else None
    except ValueError: year_from = None
    try:
        year_to = int(year_to_raw) if year_to_raw is not None and str(year_to_raw).strip() != "" else None
    except ValueError: year_to = None

    representative_year = None
    if year_from is not None and year_to is not None:
        if year_from > year_to: return PERIOD_LABELS["Unknown"]
        representative_year = (year_from + year_to) / 2.0 
    elif year_from is not None:
        representative_year = float(year_from)
    elif year_to is not None:
        representative_year = float(year_to)
    else:
        return PERIOD_LABELS["Unknown"]

    if representative_year is None: return PERIOD_LABELS["Unknown"] 

    if representative_year <= -501: return PERIOD_LABELS["Before_5C_BC"]
    elif -500 <= representative_year <= -401: return PERIOD_LABELS["C5_BC"]
    elif -400 <= representative_year <= -301: return PERIOD_LABELS["C4_BC"]
    elif -300 <= representative_year <= -201: return PERIOD_LABELS["C3_BC"]
    elif -200 <= representative_year <= -101: return PERIOD_LABELS["C2_BC"]
    elif -100 <= representative_year <= -1: return PERIOD_LABELS["C1_BC"]
    elif 0 <= representative_year <= 100: return PERIOD_LABELS["C1_AD"] 
    elif 101 <= representative_year <= 200: return PERIOD_LABELS["C2_AD"]
    elif 201 <= representative_year <= 300: return PERIOD_LABELS["C3_AD"]
    elif 301 <= representative_year <= 400: return PERIOD_LABELS["C4_AD"]
    elif 401 <= representative_year <= 500: return PERIOD_LABELS["C5_AD"]
    elif representative_year >= 501: return PERIOD_LABELS["After_5C_AD"]
    else:
        return PERIOD_LABELS["Unknown"]

## Dataset

In [None]:
class EpigraphDataset(Dataset):
    def __init__(self, source_data_file, tokenizer, max_length, max_unk_percentage, material_to_id, pos_to_id):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.examples = []
        self.material_to_id = material_to_id
        self.pos_to_id = pos_to_id
        self.max_unk_percentage = max_unk_percentage
        self.unk_token_id = tokenizer.vocab[tokenizer.MODEL_UNK_TOKEN]

        num_records_in_file, num_skipped_missing_fields = 0, 0
        num_examples_before_filtering, num_filtered_unk = 0, 0
        num_filtered_date, num_filtered_material, num_filtered_pos_mismatch = 0, 0, 0

        print(f"Reading and processing data from {source_data_file}...")
        with open(source_data_file, 'r', encoding='utf-8') as f:
            all_data = json.load(f)
        num_records_in_file = len(all_data)

        for record in all_data:
            record_id = record.get("record_number")
            text = record.get("parsed_field", "").strip() 
            if not record_id or not text:
                num_skipped_missing_fields += 1
                continue
            
            num_examples_before_filtering += 1
            
            date_label = get_period_label_for_record(record.get("dating", {}))
            material_raw = record.get("material", "Unknown").lower().strip()
            material_label = self.material_to_id.get(material_raw, self.material_to_id["Unknown"])
            
            if date_label == PERIOD_LABELS["Unknown"]:
                num_filtered_date += 1
                continue

            pos_tags_data = record.get("pos_tags", [])
            
            subwords, pos_ids = [], []

            if pos_tags_data and isinstance(pos_tags_data, list):
                words = [item[0] for item in pos_tags_data]
                tags = [item[1] for item in pos_tags_data]
                
                for i, word in enumerate(words):
                    word_subtokens = self.tokenizer.tokenize_word(word)
                    tag = tags[i]
                    tag_id = self.pos_to_id.get(tag, self.pos_to_id["Unknown"])
                    subwords.extend(word_subtokens)
                    pos_ids.extend([tag_id] * len(word_subtokens))
            else:
                num_filtered_pos_mismatch += 1
                subwords = self.tokenizer.tokenize(text)
                pos_ids = [-100] * len(subwords)
                
            subwords = ["[CLS]"] + subwords
            pos_ids = [-100] + pos_ids
            
            if len(subwords) > self.max_length - 1: 
                subwords = subwords[:self.max_length - 1]
                pos_ids = pos_ids[:self.max_length - 1]
            
            subwords.append("[SEP]")
            pos_ids.append(-100)

            input_ids = self.tokenizer.convert_tokens_to_ids(subwords)

            unk_count = input_ids.count(self.unk_token_id)
            unk_percent = (unk_count / len(input_ids)) * 100 if input_ids else 0
            if unk_percent > self.max_unk_percentage:
                num_filtered_unk += 1
                continue
                
            self.examples.append({
                "record_id": record_id,
                "input_ids": input_ids,
                "date_label": date_label,
                "material_label": material_label,
                "pos_labels": pos_ids
            })

        print(f"Total records read: {num_records_in_file}")
        print(f"  Skipped (no ID/text): {num_skipped_missing_fields}")
        print(f"  Records before filtering: {num_examples_before_filtering}")
        print(f"  Filtered (UNK > {self.max_unk_percentage}%): {num_filtered_unk}")
        print(f"  Filtered (Unknown date): {num_filtered_date}")
        print(f"  Records with missing/malformed 'pos_tags' field: {num_filtered_pos_mismatch}")
        print(f"Loaded {len(self.examples)} examples after all filtering.")
        
    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]

## Data Collator for Multi-Task

In [None]:
class LatinDataCollatorForMultiTask:
    def __init__(self, tokenizer, mlm_probability=0.15):
        self.tokenizer = tokenizer
        self.mlm_probability = mlm_probability

    def __call__(self, examples):
        input_ids_list = [e["input_ids"] for e in examples]
        pos_labels_list = [e["pos_labels"] for e in examples]
        date_labels_list = [e["date_label"] for e in examples]
        material_labels_list = [e["material_label"] for e in examples]

        batch_input_ids = self._pad_sequences(input_ids_list, pad_value=self.tokenizer.vocab["[PAD]"])
        batch_pos_labels = self._pad_sequences(pos_labels_list, pad_value=-100) # Pad with -100 for loss ignore
        batch_size, seq_length = batch_input_ids.size()

        attention_mask = (batch_input_ids != self.tokenizer.vocab["[PAD]"]).long()

        mlm_labels = batch_input_ids.clone()
        special_tokens_mask = torch.zeros_like(batch_input_ids, dtype=torch.bool)
        for special_id in [self.tokenizer.vocab[tok] for tok in ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]]:
            special_tokens_mask |= (batch_input_ids == special_id)
            
        probability_matrix = torch.full(mlm_labels.shape, self.mlm_probability)
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        mlm_labels[~masked_indices] = -100  

        indices_replaced = torch.bernoulli(torch.full(mlm_labels.shape, 0.8)).bool() & masked_indices
        batch_input_ids[indices_replaced] = self.tokenizer.vocab["[MASK]"]

        indices_random = torch.bernoulli(torch.full(mlm_labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(5, self.tokenizer.get_vocab_size(), mlm_labels.shape, dtype=torch.long)
        batch_input_ids[indices_random] = random_words[indices_random]

        batch_date_labels = torch.tensor(date_labels_list, dtype=torch.long)
        batch_material_labels = torch.tensor(material_labels_list, dtype=torch.long)

        return {
            "input_ids": batch_input_ids,
            "attention_mask": attention_mask,
            "masked_lm_labels": mlm_labels,      
            "date_labels": batch_date_labels,         
            "material_labels": batch_material_labels, 
            "pos_labels": batch_pos_labels        
        }

    def _pad_sequences(self, sequences, pad_value):
        max_len = max(len(seq) for seq in sequences) if sequences else 0
        if max_len == 0: return torch.empty((len(sequences), 0), dtype=torch.long)
        
        padded_sequences = torch.full((len(sequences), max_len), pad_value, dtype=torch.long)
        for i, seq in enumerate(sequences):
            padded_sequences[i, :len(seq)] = torch.tensor(seq, dtype=torch.long)
        return padded_sequences

## Evaluation

In [None]:
def evaluate(model, eval_dataloader, args_config, pos_class_names, material_class_names):
    model.eval()
    total_eval_loss = 0.0
    total_eval_mlm_loss, total_eval_date_loss, total_eval_material_loss, total_eval_pos_loss = 0.0, 0.0, 0.0, 0.0
    nb_eval_steps = 0

    all_mlm_preds, all_mlm_labels = [], []
    all_date_preds, all_date_labels = [], []
    all_material_preds, all_material_labels = [], []
    all_pos_preds, all_pos_labels = [], []

    for batch in eval_dataloader:
        for key, value in batch.items():
            batch[key] = value.to(device)
            
        with torch.no_grad():
            outputs = model(**batch)

            mlm_loss = outputs.get("mlm_loss")
            date_loss = outputs.get("date_loss")
            material_loss = outputs.get("material_loss")
            pos_loss = outputs.get("pos_loss")

            current_batch_loss = 0
            if mlm_loss is not None: total_eval_mlm_loss += mlm_loss.item(); current_batch_loss += mlm_loss
            if date_loss is not None: total_eval_date_loss += date_loss.item(); current_batch_loss += args_config["date_loss_weight"] * date_loss
            if material_loss is not None: total_eval_material_loss += material_loss.item(); current_batch_loss += args_config["material_loss_weight"] * material_loss
            if pos_loss is not None: total_eval_pos_loss += pos_loss.item(); current_batch_loss += args_config["pos_loss_weight"] * pos_loss
            total_eval_loss += current_batch_loss.item()

            if outputs.get("mlm_logits") is not None:
                mask = (batch["masked_lm_labels"] != -100)
                all_mlm_preds.extend(torch.argmax(outputs["mlm_logits"], dim=-1)[mask].cpu().tolist())
                all_mlm_labels.extend(batch["masked_lm_labels"][mask].cpu().tolist())
            
            if outputs.get("date_logits") is not None: 
                all_date_preds.extend(torch.argmax(outputs["date_logits"], dim=-1).cpu().tolist())
                all_date_labels.extend(batch["date_labels"].cpu().tolist())
            
            if outputs.get("material_logits") is not None:
                all_material_preds.extend(torch.argmax(outputs["material_logits"], dim=-1).cpu().tolist())
                all_material_labels.extend(batch["material_labels"].cpu().tolist())

            if outputs.get("pos_logits") is not None:
                mask = (batch["pos_labels"] != -100)
                all_pos_preds.extend(torch.argmax(outputs["pos_logits"], dim=-1)[mask].cpu().tolist())
                all_pos_labels.extend(batch["pos_labels"][mask].cpu().tolist())

        nb_eval_steps += 1

    avg_loss = lambda total_loss: total_loss / nb_eval_steps if nb_eval_steps > 0 else 0
    results = {
        "eval_loss_combined": avg_loss(total_eval_loss),
        "eval_mlm_loss": avg_loss(total_eval_mlm_loss),
        "eval_date_loss": avg_loss(total_eval_date_loss),
        "eval_material_loss": avg_loss(total_eval_material_loss),
        "eval_pos_loss": avg_loss(total_eval_pos_loss),
        "eval_mlm_perplexity": torch.exp(torch.tensor(avg_loss(total_eval_mlm_loss))).item()
    }

    def get_cls_metrics(labels, preds, num_labels):
        if not labels: return 0, 0, 0, 0
        accuracy = accuracy_score(labels, preds)
        p, r, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted', labels=list(range(num_labels)), zero_division=0)
        return accuracy, p, r, f1
    
    results["eval_mlm_accuracy"] = accuracy_score(all_mlm_labels, all_mlm_preds) if all_mlm_labels else 0
    results["eval_date_accuracy"], results["eval_date_precision"], results["eval_date_recall"], results["eval_date_f1"] = \
        get_cls_metrics(all_date_labels, all_date_preds, args_config["num_date_labels"])
    results["eval_material_accuracy"], results["eval_material_precision"], results["eval_material_recall"], results["eval_material_f1"] = \
        get_cls_metrics(all_material_labels, all_material_preds, args_config["num_material_labels"])
    results["eval_pos_accuracy"], results["eval_pos_precision"], results["eval_pos_recall"], results["eval_pos_f1"] = \
        get_cls_metrics(all_pos_labels, all_pos_preds, args_config["num_pos_labels"])

    return results

## Logging

In [None]:
def log_training_example(args, tokenizer, input_ids_batch, labels_batch, logits_batch, epoch, step, global_step):
    if not hasattr(log_training_example, "log_file_handler"):
        log_file_path = "training_examples.txt"
        log_training_example.log_file_handler = open(log_file_path, "a", encoding="utf-8")
        print(f"Logging training examples to: {log_file_path}")
        if os.path.getsize(log_file_path) == 0:
            log_training_example.log_file_handler.write("Epoch | Step (Batch) | Global Step | Original | Masked Input | Target | Prediction\n")
            log_training_example.log_file_handler.write("----------------------------------------------------------------------------------\n")

    num_examples_to_log = min(2, input_ids_batch.size(0))
    for i in range(num_examples_to_log):
        input_sequence_ids = input_ids_batch[i].tolist()
        label_sequence_ids = labels_batch[i].tolist() 
        logit_sequence = logits_batch[i] 

        reconstructed_original_ids = []
        for k_idx, l_id in enumerate(label_sequence_ids):
            if l_id != -100: reconstructed_original_ids.append(l_id)
            else: reconstructed_original_ids.append(input_sequence_ids[k_idx])

        original_tokens = tokenizer.convert_ids_to_tokens(reconstructed_original_ids)
        masked_input_tokens = tokenizer.convert_ids_to_tokens(input_sequence_ids)

        log_training_example.log_file_handler.write(f"Epoch {epoch} | Step {step} | GStep {global_step}\n")
        log_training_example.log_file_handler.write(f"  Original Approx: {' '.join(original_tokens)}\n")
        log_training_example.log_file_handler.write(f"  Masked Input:    {' '.join(masked_input_tokens)}\n")

        masked_positions = [idx for idx, token_id in enumerate(label_sequence_ids) if token_id != -100]
        if not masked_positions:
            log_training_example.log_file_handler.write("  No tokens masked in this example for MLM prediction.\n")
        else:
            for pos in masked_positions:
                target_token_id = label_sequence_ids[pos]
                target_token = tokenizer.convert_ids_to_tokens([target_token_id])[0]
                predicted_token_id = torch.argmax(logit_sequence[pos]).item()
                predicted_token = tokenizer.convert_ids_to_tokens([predicted_token_id])[0]
                top_3_tokens_ids = torch.topk(logit_sequence[pos], 3).indices.tolist()
                top_3_tokens = tokenizer.convert_ids_to_tokens(top_3_tokens_ids)
                log_training_example.log_file_handler.write(
                    f"  Pos {pos}: Target='{target_token}' ({target_token_id}), Predicted='{predicted_token}' ({predicted_token_id}), Top-3: {top_3_tokens}\n"
                )
        log_training_example.log_file_handler.write("----------------------------------------------------------------------------------\n")
    log_training_example.log_file_handler.flush()

def close_log_file():
    if hasattr(log_training_example, "log_file_handler") and log_training_example.log_file_handler:
        print("Closing training examples log file.")
        log_training_example.log_file_handler.close()
        delattr(log_training_example, "log_file_handler")

## Training

In [None]:
def train(args, model, train_dataset, eval_dataset, tokenizer, pos_class_names, material_class_names):
    train_dataloader = DataLoader(
        train_dataset, batch_size=args["batch_size"], shuffle=True,
        collate_fn=LatinDataCollatorForMultiTask(tokenizer, mlm_probability=args["mlm_probability"])
    )
    eval_dataloader = DataLoader(
        eval_dataset, batch_size=args["batch_size"], shuffle=False,
        collate_fn=LatinDataCollatorForMultiTask(tokenizer, mlm_probability=args["mlm_probability"])
    )

    t_total = len(train_dataloader) // args.get("gradient_accumulation_steps", 1) * args["num_train_epochs"]
    optimizer = AdamW(model.parameters(), lr=args["learning_rate"], eps=args["adam_epsilon"])
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args["warmup_steps"], num_training_steps=t_total)

    print("***** Running training *****")
    print(f"  Num examples = {len(train_dataset)}")
    print(f"  Num Epochs = {args['num_train_epochs']}")
    print(f"  Batch size = {args['batch_size']}")
    print(f"  Total optimization steps = {t_total}")

    global_step = 0
    total_train_loss_accumulator = 0.0
    logging_loss_accumulators = defaultdict(float)
    model.zero_grad()
    best_eval_loss = float('inf')

    for epoch in range(int(args["num_train_epochs"])):
        for step, batch in enumerate(train_dataloader):
            model.train()
            for key, value in batch.items():
                batch[key] = value.to(device)

            outputs = model(**batch)

            mlm_loss = outputs.get("mlm_loss")
            date_loss = outputs.get("date_loss")
            material_loss = outputs.get("material_loss")
            pos_loss = outputs.get("pos_loss")

            combined_loss = 0
            if mlm_loss is not None: 
                combined_loss += mlm_loss
                logging_loss_accumulators["mlm"] += mlm_loss.item()
            if date_loss is not None: 
                weighted_loss = args["date_loss_weight"] * date_loss
                combined_loss += weighted_loss
                logging_loss_accumulators["date"] += weighted_loss.item()
            if material_loss is not None: 
                weighted_loss = args["material_loss_weight"] * material_loss
                combined_loss += weighted_loss
                logging_loss_accumulators["material"] += weighted_loss.item()
            if pos_loss is not None: 
                weighted_loss = args["pos_loss_weight"] * pos_loss
                combined_loss += weighted_loss
                logging_loss_accumulators["pos"] += weighted_loss.item()

            if isinstance(combined_loss, torch.Tensor):
                combined_loss.backward()
                total_train_loss_accumulator += combined_loss.item()
            
            torch.nn.utils.clip_grad_norm_(model.parameters(), args["max_grad_norm"])
            optimizer.step()
            scheduler.step()
            model.zero_grad()
            global_step += 1

            if args["logging_steps"] > 0 and global_step % args["logging_steps"] == 0:
                avg_combined_loss = (total_train_loss_accumulator - getattr(train, 'last_logged_loss', 0)) / args["logging_steps"]
                train.last_logged_loss = total_train_loss_accumulator
                
                print(f"Epoch {epoch}, GStep {global_step}, AvgCombinedLoss: {avg_combined_loss:.4f}, LR: {scheduler.get_last_lr()[0]:.2e}")

                for task in ["mlm", "date", "material", "pos"]:
                    avg_loss = logging_loss_accumulators[task] / args["logging_steps"]
                    print(f"  - Avg {task.upper()} Loss: {avg_loss:.4f}")
                    logging_loss_accumulators[task] = 0.0
                
                if outputs.get("mlm_logits") is not None:
                    log_training_example(args, tokenizer, batch["input_ids"].cpu(), batch["masked_lm_labels"].cpu(), outputs["mlm_logits"].cpu().detach(), epoch, step, global_step)

            if args["save_steps"] > 0 and global_step % args["save_steps"] == 0:
                output_dir = os.path.join(args["output_dir"], f"checkpoint-{global_step}")
                os.makedirs(output_dir, exist_ok=True)
                torch.save(model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
                model.bert.config.to_json_file(os.path.join(output_dir, "config.json"))
                print(f"Saving model checkpoint to {output_dir}")

        eval_metrics = evaluate(model, eval_dataloader, args, pos_class_names, material_class_names)
        print(f"--- Evaluation results after epoch {epoch} ---")
        for key, value in eval_metrics.items():
            print(f"  {key}: {value:.4f}")

        if eval_metrics['eval_loss_combined'] < best_eval_loss:
            best_eval_loss = eval_metrics['eval_loss_combined']
            best_model_dir = os.path.join(args["output_dir"], "best_model")
            os.makedirs(best_model_dir, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(best_model_dir, "pytorch_model.bin"))
            model.bert.config.to_json_file(os.path.join(best_model_dir, "config.json"))
            print(f"Saving best model with eval combined loss {best_eval_loss:.4f} to {best_model_dir}")

    close_log_file()
    final_model_dir = os.path.join(args["output_dir"], "final_model")
    os.makedirs(final_model_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(final_model_dir, "pytorch_model.bin"))
    model.bert.config.to_json_file(os.path.join(final_model_dir, "config.json"))
    print(f"Saving final model to {final_model_dir}")
    return global_step, total_train_loss_accumulator / global_step if global_step > 0 else 0

## Predict

In [None]:
def predict_missing_words(model, tokenizer, text, max_predictions=3):
    model.eval()
    model_bert_component = model.bert if hasattr(model, 'bert') else model
    max_len = model_bert_component.config.max_position_embeddings

    encoded_input = tokenizer(text, max_length=max_len, padding=True, truncation=True, return_attention_mask=True, add_special_tokens=True)
    input_ids = torch.tensor([encoded_input["input_ids"]]).to(device)
    attention_mask = torch.tensor([encoded_input["attention_mask"]]).to(device)
    mask_token_id = tokenizer.vocab["[MASK]"]
    mask_positions = (input_ids[0] == mask_token_id).nonzero(as_tuple=True)[0].tolist()

    if not mask_positions:
        return "No [MASK] tokens found in the text to predict."

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        prediction_scores = outputs["mlm_logits"]

    results = []
    original_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
    for mask_pos in mask_positions:
        logits_for_mask = prediction_scores[0, mask_pos]
        top_k = torch.topk(logits_for_mask, max_predictions)
        predicted_tokens = tokenizer.convert_ids_to_tokens(top_k.indices.tolist())
        results.append({
            "position": mask_pos,
            "predictions": [{"token": tok, "score": score} for tok, score in zip(predicted_tokens, top_k.values.tolist())]
        })
    return {"full_input": " ".join(original_tokens), "predictions_for_masks": results}

## Hyper Parameters

In [None]:
args = {
    "bert_path": "bert_model/",
    "tokenizer_path": "latin-bert/models/subword_tokenizer_latin/latin.subword.encoder",
    "source_data_file": "data/final_results.json",
    "output_dir": "runs/multitask_mlm_date_material_pos",

    "max_seq_length": 128,
    "batch_size": 16,
    "learning_rate": 7e-5,
    "weight_decay": 0.05,
    "adam_epsilon": 1e-8,
    "max_grad_norm": 1.0,
    "num_train_epochs": 10.0,
    "warmup_steps": 8000,
    "mlm_probability": 0.15,
    
    "date_loss_weight": 0.5,
    "material_loss_weight": 0.2, 
    "pos_loss_weight": 0.7,
    
    "logging_steps": 100,
    "save_steps": 500,
    "seed": 42,
    "eval_split": 0.1,
    "max_unk_percentage": 50.0,
    
    "predict": False,
    "predict_text": "dominus [MASK] servum et <unk> est in horto.",
    
    "num_date_labels": NUM_DATE_LABELS,
    "num_material_labels": 0,
    "num_pos_labels": 0
}

## Main

In [None]:
def create_vocabs_from_data(data_file_path):
    print(f"Scanning {data_file_path} to build vocabularies for Material and POS...")
    with open(data_file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    
    material_set = set()
    pos_tag_set = set()
    
    for record in data:
        material = record.get("material")
        if material and isinstance(material, str):
            material_set.add(material.lower().strip())
        
        pos_tags = record.get("pos_tags")
        if pos_tags and isinstance(pos_tags, list):
            for _, tag in pos_tags:
                if tag and isinstance(tag, str):
                    pos_tag_set.add(tag)
    
    material_vocab = {name: i for i, name in enumerate(sorted(list(material_set)))}
    if "Unknown" not in material_vocab:
        material_vocab["Unknown"] = len(material_vocab)

    pos_vocab = {name: i for i, name in enumerate(sorted(list(pos_tag_set)))}
    if "Unknown" not in pos_vocab:
        pos_vocab["Unknown"] = len(pos_vocab)
        
    print(f"Found {len(material_vocab)} unique materials.")
    print(f"Found {len(pos_vocab)} unique POS tags.")
    return material_vocab, pos_vocab

random.seed(args["seed"])
np.random.seed(args["seed"])
torch.manual_seed(args["seed"])
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(args["seed"])

material_to_id, pos_to_id = create_vocabs_from_data(args["source_data_file"])
args["num_material_labels"] = len(material_to_id)
args["num_pos_labels"] = len(pos_to_id)

material_class_names = [name for name, _ in sorted(material_to_id.items(), key=lambda item: item[1])]
pos_class_names = [name for name, _ in sorted(pos_to_id.items(), key=lambda item: item[1])]

encoder = text_encoder.SubwordTextEncoder(args["tokenizer_path"])
tokenizer = LatinTokenizer(encoder)
print(f"Vocabulary size (for MLM): {tokenizer.get_vocab_size()} tokens")

if args.get("predict", False):
    print(f"\nRunning in PREDICTION mode for text: '{args['predict_text']}'")
    predict_model_path = os.path.join(args["output_dir"], "best_model")
    if not os.path.exists(os.path.join(predict_model_path, "pytorch_model.bin")):
        print(f"Warning: Fine-tuned model not found at {predict_model_path}. Using base model.")
        predict_model_path = args["bert_path"]

    model = LatinBERTForMultiTask(
        bert_path=predict_model_path,
        num_date_labels=args["num_date_labels"],
        num_material_labels=args["num_material_labels"],
        num_pos_labels=args["num_pos_labels"]
    )
    if os.path.exists(os.path.join(predict_model_path, "pytorch_model.bin")):
        model.load_state_dict(torch.load(os.path.join(predict_model_path, "pytorch_model.bin"), map_location=device))
    model.to(device)
    
    prediction_results = predict_missing_words(model, tokenizer, args["predict_text"])
    print("\nMLM Predictions:")
    print(json.dumps(prediction_results, indent=2))

else:
    print("\nRunning in TRAINING mode.")
    model = LatinBERTForMultiTask(
        bert_path=args["bert_path"],
        num_date_labels=args["num_date_labels"],
        num_material_labels=args["num_material_labels"],
        num_pos_labels=args["num_pos_labels"]
    )
    model.to(device)

    full_dataset = EpigraphDataset(
        source_data_file=args["source_data_file"],
        tokenizer=tokenizer,
        max_length=args["max_seq_length"],
        max_unk_percentage=args["max_unk_percentage"],
        material_to_id=material_to_id,
        pos_to_id=pos_to_id
    )

    if len(full_dataset) == 0:
        raise ValueError("Dataset is empty after filtering. Check data source and filtering criteria.")

    eval_size = int(len(full_dataset) * args["eval_split"])
    if eval_size == 0 and len(full_dataset) > 1: eval_size = 1
    train_size = len(full_dataset) - eval_size

    if train_size <= 0:
        raise ValueError(f"Train size is not positive. Total: {len(full_dataset)}, Eval: {eval_size}.")

    train_dataset, eval_dataset = random_split(full_dataset, [train_size, eval_size], generator=torch.Generator().manual_seed(args["seed"]))
    print(f"Training on {train_size} examples, evaluating on {eval_size} examples.")

    train(args, model, train_dataset, eval_dataset, tokenizer, pos_class_names, material_class_names)
    print("\nTraining complete.")


Scanning data/final_results.json to build vocabularies for Material and POS...
Found 497 unique materials.
Found 16 unique POS tags.

Vocabulary size (for MLM): 32900 tokens

Running in TRAINING mode.
Reading and processing data from data/final_results.json...
Total records read: 103542
  Skipped (no ID/text): 821
  Records before filtering: 102721
  Filtered (UNK > 50.0%): 114
  Filtered (Unknown date): 12567
  Records with missing/malformed 'pos_tags' field: 0
Loaded 90040 examples after all filtering.
Training on 81036 examples, evaluating on 9004 examples.




***** Running training *****
  Num examples = 81036
  Num Epochs = 10.0
  Batch size = 16
  Total optimization steps = 50650.0
Epoch 0, GStep 100, AvgCombinedLoss: 14.8735, LR: 8.75e-07
  - Avg MLM Loss: 10.4666
  - Avg DATE Loss: 1.1944
  - Avg MATERIAL Loss: 1.2471
  - Avg POS Loss: 1.9653
Logging training examples to: training_examples.txt
Epoch 0, GStep 200, AvgCombinedLoss: 14.3780, LR: 1.75e-06
  - Avg MLM Loss: 10.4041
  - Avg DATE Loss: 1.0037
  - Avg MATERIAL Loss: 1.2101
  - Avg POS Loss: 1.7602
Epoch 0, GStep 300, AvgCombinedLoss: 13.5705, LR: 2.62e-06
  - Avg MLM Loss: 10.2423
  - Avg DATE Loss: 0.7947
  - Avg MATERIAL Loss: 1.0988
  - Avg POS Loss: 1.4348
Epoch 0, GStep 400, AvgCombinedLoss: 12.6711, LR: 3.50e-06
  - Avg MLM Loss: 9.7703
  - Avg DATE Loss: 0.7222
  - Avg MATERIAL Loss: 0.8984
  - Avg POS Loss: 1.2802
Epoch 0, GStep 500, AvgCombinedLoss: 11.7182, LR: 4.37e-06
  - Avg MLM Loss: 9.1706
  - Avg DATE Loss: 0.6819
  - Avg MATERIAL Loss: 0.6704
  - Avg POS Loss: 

## Top-K Accuracy Evaluation

In [None]:
print("--- Re-evaluating the best model for Top-K MLM Accuracy ---")

best_model_path = os.path.join(args["output_dir"], "best_model")
model_weights_path = os.path.join(best_model_path, "pytorch_model.bin")

if not os.path.exists(model_weights_path):
    print(f"Error: Best model not found at {model_weights_path}. Cannot proceed.")
else:
    print(f"Loading best model from: {best_model_path}")
    model_to_evaluate = LatinBERTForMultiTask(
        bert_path=args["bert_path"],
        num_date_labels=args["num_date_labels"],
        num_material_labels=args["num_material_labels"],
        num_pos_labels=args["num_pos_labels"]
    )
    model_to_evaluate.load_state_dict(torch.load(model_weights_path, map_location=device))
    model_to_evaluate.to(device)
    model_to_evaluate.eval()

    print("Re-creating the evaluation dataset to ensure consistency...")
    material_to_id_reeval, pos_to_id_reeval = create_vocabs_from_data(args["source_data_file"])

    full_dataset_reeval = EpigraphDataset(
        source_data_file=args["source_data_file"],
        tokenizer=tokenizer,
        max_length=args["max_seq_length"],
        max_unk_percentage=args["max_unk_percentage"],
        material_to_id=material_to_id_reeval,
        pos_to_id=pos_to_id_reeval
    )

    eval_size = int(len(full_dataset_reeval) * args["eval_split"])
    if eval_size == 0 and len(full_dataset_reeval) > 1: eval_size = 1
    train_size = len(full_dataset_reeval) - eval_size

    _, eval_dataset_reeval = random_split(
        full_dataset_reeval, [train_size, eval_size],
        generator=torch.Generator().manual_seed(args["seed"])
    )

    eval_dataloader_reeval = DataLoader(
        eval_dataset_reeval,
        batch_size=args["batch_size"],
        shuffle=False, 
        collate_fn=LatinDataCollatorForMultiTask(tokenizer, mlm_probability=args["mlm_probability"])
    )
    print(f"Recreated evaluation dataset with {len(eval_dataset_reeval)} examples.")

    mlm_total_masked_tokens = 0
    mlm_correct_top1 = 0
    mlm_correct_top5 = 0
    mlm_correct_top10 = 0

    with torch.no_grad():
        for batch in eval_dataloader_reeval:
            for key, value in batch.items():
                batch[key] = value.to(device)
            
            outputs = model_to_evaluate(**batch)

            if outputs.get("mlm_logits") is not None:
                mlm_logits = outputs["mlm_logits"]
                mlm_labels = batch["masked_lm_labels"]
                mask = (mlm_labels != -100)
                
                batch_masked_count = mask.sum().item()
                if batch_masked_count > 0:
                    mlm_total_masked_tokens += batch_masked_count
                    _, top_10_indices = torch.topk(mlm_logits, k=10, dim=-1)
                    expanded_labels = mlm_labels.unsqueeze(-1)
                    masked_top_k_matches = (expanded_labels == top_10_indices) & mask.unsqueeze(-1)
                    
                    mlm_correct_top1 += masked_top_k_matches[:, :, 0].sum().item()
                    mlm_correct_top5 += masked_top_k_matches[:, :, :5].any(dim=-1).sum().item()
                    mlm_correct_top10 += masked_top_k_matches.any(dim=-1).sum().item()

    print("\n--- MLM Top-K Accuracy Results ---")
    if mlm_total_masked_tokens > 0:
        acc1 = (mlm_correct_top1 / mlm_total_masked_tokens) * 100
        acc5 = (mlm_correct_top5 / mlm_total_masked_tokens) * 100
        acc10 = (mlm_correct_top10 / mlm_total_masked_tokens) * 100
        
        print(f"Total Masked Tokens Evaluated: {mlm_total_masked_tokens}")
        print(f"  Accuracy@1:  {acc1:.2f}%")
        print(f"  Accuracy@5:  {acc5:.2f}%")
        print(f"  Accuracy@10: {acc10:.2f}%")
    else:
        print("No masked tokens were found in the evaluation set to calculate accuracy.")

--- Re-evaluating the best model for Top-K MLM Accuracy ---
Loading best model from: runs/multitask_mlm_date_material_pos\best_model
Re-creating the evaluation dataset to ensure consistency...
Scanning data/final_results.json to build vocabularies for Material and POS...
Found 497 unique materials.
Found 16 unique POS tags.
Reading and processing data from data/final_results.json...
Total records read: 103542
  Skipped (no ID/text): 821
  Records before filtering: 102721
  Filtered (UNK > 50.0%): 114
  Filtered (Unknown date): 12567
  Records with missing/malformed 'pos_tags' field: 0
Loaded 90040 examples after all filtering.
Recreated evaluation dataset with 9004 examples.

--- MLM Top-K Accuracy Results ---
Total Masked Tokens Evaluated: 30148
  Accuracy@1:  68.60%
  Accuracy@5:  80.70%
  Accuracy@10: 84.19%


## Generation experiments

In [None]:
from transformers import T5ForConditionalGeneration, T5Tokenizer, AdamW

T5_MODEL_NAME = "t5-small"
T5_OUTPUT_DIR = "t5_epigraph_generator/"
T5_NUM_EPOCHS = 3
T5_BATCH_SIZE = 8
T5_LEARNING_RATE = 5e-5


class T5DenoisingDataset(Dataset):
    def __init__(self, tokenizer, source_data_file, max_length=128):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.inputs = []
        self.targets = []
        
        print(f"Loading data from {source_data_file} for T5 fine-tuning...")
        with open(source_data_file, 'r', encoding='utf-8') as f:
            all_data = json.load(f)
            
        for record in all_data:
            text = record.get("parsed_field")
            if text and text.strip():
                input_text, target_text = self._create_denoising_example(text)
                self.inputs.append(input_text)
                self.targets.append(target_text)

    def _create_denoising_example(self, text, corruption_rate=0.25):
        words = text.split()
        num_words_to_corrupt = int(len(words) * corruption_rate)
        if num_words_to_corrupt == 0 and len(words) > 0:
            num_words_to_corrupt = 1

        corrupted_indices = sorted(random.sample(range(len(words)), k=num_words_to_corrupt))
        
        input_parts = []
        target_parts = []
        current_mask_id = 0
        last_index = -1

        for i in corrupted_indices:
            if i > last_index + 1:
                input_parts.append(" ".join(words[last_index+1:i]))
            
            input_parts.append(f"<extra_id_{current_mask_id}>")
            
            target_parts.append(f"<extra_id_{current_mask_id}>")
            target_parts.append(words[i])
            
            current_mask_id += 1
            last_index = i
        
        if last_index < len(words) - 1:
            input_parts.append(" ".join(words[last_index+1:]))

        target_parts.append(f"<extra_id_{current_mask_id}>")

        input_text = "infill: " + " ".join(input_parts)
        target_text = " ".join(target_parts)
        
        return input_text, target_text

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        source = self.tokenizer.batch_encode_plus(
            [self.inputs[idx]], max_length=self.max_length, padding='max_length', truncation=True, return_tensors="pt"
        )
        target = self.tokenizer.batch_encode_plus(
            [self.targets[idx]], max_length=self.max_length, padding='max_length', truncation=True, return_tensors="pt"
        )
        
        source_ids = source['input_ids'].squeeze()
        source_mask = source['attention_mask'].squeeze()
        target_ids = target['input_ids'].squeeze()

        return {"input_ids": source_ids, "attention_mask": source_mask, "labels": target_ids}


print("\n--- Starting T5 Fine-Tuning ---")

if os.path.exists(os.path.join(T5_OUTPUT_DIR, 'pytorch_model.bin')):
    print(f"Fine-tuned T5 model already found in {T5_OUTPUT_DIR}. Skipping training.")
else:
    t5_tokenizer = T5Tokenizer.from_pretrained(T5_MODEL_NAME)
    t5_model = T5ForConditionalGeneration.from_pretrained(T5_MODEL_NAME)
    t5_model.to(device)

    train_dataset_t5 = T5DenoisingDataset(t5_tokenizer, args["source_data_file"])
    train_dataloader_t5 = DataLoader(train_dataset_t5, batch_size=T5_BATCH_SIZE, shuffle=True)

    optimizer_t5 = AdamW(t5_model.parameters(), lr=T5_LEARNING_RATE)

    t5_model.train()
    for epoch in range(T5_NUM_EPOCHS):
        print(f"--- T5 Epoch {epoch+1}/{T5_NUM_EPOCHS} ---")
        for i, batch in enumerate(train_dataloader_t5):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            labels[labels == t5_tokenizer.pad_token_id] = -100

            outputs = t5_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            loss = outputs.loss
            
            if (i + 1) % 500 == 0:
                print(f"  Batch {i+1}/{len(train_dataloader_t5)}, Loss: {loss.item():.4f}")

            loss.backward()
            optimizer_t5.step()
            optimizer_t5.zero_grad()

    print("T5 training complete. Saving model...")
    os.makedirs(T5_OUTPUT_DIR, exist_ok=True)
    t5_model.save_pretrained(T5_OUTPUT_DIR)
    t5_tokenizer.save_pretrained(T5_OUTPUT_DIR)
    print(f"T5 model saved to {T5_OUTPUT_DIR}")


print("\n--- T5 Generation Examples ---")

try:
    t5_tokenizer_ft = T5Tokenizer.from_pretrained(T5_OUTPUT_DIR)
    t5_model_ft = T5ForConditionalGeneration.from_pretrained(T5_OUTPUT_DIR)
    t5_model_ft.to(device)
    t5_model_ft.eval()
    print("Fine-tuned T5 model loaded successfully.")

    def generate_epigraph(prompt_text, num_beams=5, max_length=50):
        input_text = f"infill: {prompt_text}"
        
        input_ids = t5_tokenizer_ft.encode(input_text, return_tensors="pt").to(device)
        
        outputs = t5_model_ft.generate(
            input_ids,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
            no_repeat_ngram_size=2
        )
        
        generated_text = t5_tokenizer_ft.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
        return generated_text

    prompt1 = "dis manibus <extra_id_0> et sibi"
    prompt2 = "aurelius <extra_id_0> vixit annis <extra_id_1>"
    prompt3 = "<extra_id_0> coniugi karissimo"

    print("\n--- Prompt 1 ---")
    print(f"Input:  '{prompt1}'")
    print(f"Output: '{generate_epigraph(prompt1)}'")

    print("\n--- Prompt 2 ---")
    print(f"Input:  '{prompt2}'")
    print(f"Output: '{generate_epigraph(prompt2)}'")

    print("\n--- Prompt 3 ---")
    print(f"Input:  '{prompt3}'")
    print(f"Output: '{generate_epigraph(prompt3)}'")

except OSError:
    print(f"Could not load fine-tuned T5 model from {T5_OUTPUT_DIR}. Please run the training block first.")


--- Starting T5 Fine-Tuning ---


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading data from data/final_results.json for T5 fine-tuning...




--- T5 Epoch 1/3 ---
  Batch 500/12841, Loss: 2.9455
  Batch 1000/12841, Loss: 2.7264
  Batch 1500/12841, Loss: 2.1127
  Batch 2000/12841, Loss: 2.1828
  Batch 2500/12841, Loss: 2.0874
  Batch 3000/12841, Loss: 2.1929
  Batch 3500/12841, Loss: 2.4971
  Batch 4000/12841, Loss: 1.6540
  Batch 4500/12841, Loss: 1.7418
  Batch 5000/12841, Loss: 2.0527
  Batch 5500/12841, Loss: 1.9899
  Batch 6000/12841, Loss: 2.5501
  Batch 6500/12841, Loss: 2.2010
  Batch 7000/12841, Loss: 2.0770
  Batch 7500/12841, Loss: 2.0720
  Batch 8000/12841, Loss: 2.2403
  Batch 8500/12841, Loss: 1.8319
  Batch 9000/12841, Loss: 1.7869
  Batch 9500/12841, Loss: 1.9508
  Batch 10000/12841, Loss: 1.3510
  Batch 10500/12841, Loss: 1.9181
  Batch 11000/12841, Loss: 1.3844
  Batch 11500/12841, Loss: 2.1971
  Batch 12000/12841, Loss: 1.6450
  Batch 12500/12841, Loss: 1.7421
--- T5 Epoch 2/3 ---
  Batch 500/12841, Loss: 2.2259
  Batch 1000/12841, Loss: 1.9186
  Batch 1500/12841, Loss: 1.4745
  Batch 2000/12841, Loss: 1.85

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


T5 model saved to t5_epigraph_generator/

--- T5 Generation Examples ---
Fine-tuned T5 model loaded successfully.

--- Prompt 1 ---
Input:  'dis manibus <extra_id_0> et sibi'
Output: 'sibi'

--- Prompt 2 ---
Input:  'aurelius <extra_id_0> vixit annis <extra_id_1>'
Output: 'saturninus xxxv.'

--- Prompt 3 ---
Input:  '<extra_id_0> coniugi karissimo'
Output: ''


In [None]:
TARGET_TOKEN_LENGTH = 20
MAX_ITERATIONS = 25


try:
    if 't5_model_ft' not in locals() or 't5_tokenizer_ft' not in locals():
        print("Loading fine-tuned T5 model for long-form generation...")
        T5_OUTPUT_DIR = "t5_epigraph_generator/" 
        t5_tokenizer_ft = T5Tokenizer.from_pretrained(T5_OUTPUT_DIR)
        t5_model_ft = T5ForConditionalGeneration.from_pretrained(T5_OUTPUT_DIR)
        t5_model_ft.to(device)
        t5_model_ft.eval()
        print("Model loaded successfully.")
except NameError:
    print("Error: T5_OUTPUT_DIR is not defined. Please run the T5 training/loading cell first.")
except OSError:
    print(f"Error: Could not load model from {T5_OUTPUT_DIR}. Please ensure the model was trained and saved.")



def generate_long_epigraph(start_prompt, target_length, model, tokenizer):
    current_text = start_prompt.strip()
    
    unk_token_id = tokenizer.unk_token_id
    bad_words_ids = [[unk_token_id]] if unk_token_id is not None else None
    
    if bad_words_ids:
        print(f"Forbidding generation of '<unk>' token (ID: {unk_token_id}).")

    for i in range(MAX_ITERATIONS):
        current_token_count = len(current_text.split())
        if current_token_count >= target_length:
            print(f"Target length of {target_length} reached. Stopping.")
            break
            
        input_prompt = f"infill: {current_text} <extra_id_0>"
        input_ids = tokenizer.encode(input_prompt, return_tensors="pt").to(model.device)
        
        output_ids = model.generate(
            input_ids,
            max_new_tokens=10,
            do_sample=True,            
            top_k=50,                   
            top_p=0.95,                 
            repetition_penalty=1.2,     
            early_stopping=True,
            bad_words_ids=bad_words_ids 
        )
        
        decoded_output = tokenizer.decode(output_ids[0], skip_special_tokens=False)
        match = re.search(r"<extra_id_0>(.*?)<extra_id_1>", decoded_output)
        
        if match:
            new_chunk = match.group(1).strip()
            if not new_chunk or new_chunk in [".", ","]:
                print("Model generated an empty or punctuation-only chunk. Stopping.")
                break
            
            new_words = new_chunk.split()
            current_text += " " + " ".join(new_words[:2])
        else:
            print("Could not parse model output. Stopping.")
            break

        print(f"Iteration {i+1}: (Length: {len(current_text.split())}) -> {current_text}")

    return " ".join(current_text.split())


if 't5_model_ft' in locals():
    print(f"\n--- Generating {TARGET_TOKEN_LENGTH}-token Epigraphs (with anti-looping) ---")
    
    start_prompts = [
        "dis manibus",
        "hic iacet",
        "iulia filia",
        "imperatori caesari",
        "in hoc tumulo",
        "valeria marci et liberta"
    ]
    
    for i, prompt in enumerate(start_prompts):
        print(f"\n--- Example {i+1} ---")
        print(f"Starting with: '{prompt}'")
        
        final_epigraph = generate_long_epigraph(
            start_prompt=prompt,
            target_length=TARGET_TOKEN_LENGTH,
            model=t5_model_ft,
            tokenizer=t5_tokenizer_ft
        )
        
        print("\n--- FINAL GENERATED EPIGRAPH ---")
        print(final_epigraph)
        print("--------------------------------\n")
else:
    print("\nSkipping long-form generation because the T5 model is not loaded.")


--- Generating 20-token Epigraphs (with anti-looping) ---

--- Example 1 ---
Starting with: 'dis manibus'
Forbidding generation of '<unk>' token (ID: 2).
Iteration 1: (Length: 3) -> dis manibus antonino.




Iteration 2: (Length: 4) -> dis manibus antonino. antonino.
Iteration 3: (Length: 5) -> dis manibus antonino. antonino. sacrum.
Iteration 4: (Length: 6) -> dis manibus antonino. antonino. sacrum. rhoda.
Iteration 5: (Length: 7) -> dis manibus antonino. antonino. sacrum. rhoda. sile
Iteration 6: (Length: 8) -> dis manibus antonino. antonino. sacrum. rhoda. sile fecit.
Iteration 7: (Length: 9) -> dis manibus antonino. antonino. sacrum. rhoda. sile fecit. hermiae.
Iteration 8: (Length: 10) -> dis manibus antonino. antonino. sacrum. rhoda. sile fecit. hermiae. tyche.
Iteration 9: (Length: 11) -> dis manibus antonino. antonino. sacrum. rhoda. sile fecit. hermiae. tyche. fecit.
Iteration 10: (Length: 12) -> dis manibus antonino. antonino. sacrum. rhoda. sile fecit. hermiae. tyche. fecit. hic.
Iteration 11: (Length: 13) -> dis manibus antonino. antonino. sacrum. rhoda. sile fecit. hermiae. tyche. fecit. hic. fecit.
Iteration 12: (Length: 14) -> dis manibus antonino. antonino. sacrum. rhoda. s