# Install Requirements and Import Libraries

In [None]:
!pip install evaluate seqeval -qqq # Requirements for the kaggle

In [None]:
# Standard library
import os

# Third-party
import evaluate
import numpy as np
import wandb
from datasets import load_dataset
from huggingface_hub import login
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer
from transformers import AutoTokenizer
from transformers import DataCollatorForTokenClassification

# Authenticate with Weights & Biases and Hugging Face

In [None]:
# Read credentials from the environment instead of pasting them into the notebook
# (on Kaggle, store WANDB_API_KEY / HF_TOKEN as secrets and export them first).
# If a variable is unset, None is passed and each library falls back to its own
# credential lookup / interactive prompt instead of trying an empty string.
wandb.login(key=os.environ.get("WANDB_API_KEY"))
login(token=os.environ.get("HF_TOKEN"))

# Load the Dataset, Tokenizer, Metric, and Model

In [None]:
# Hub identifiers for the training corpus and the pretrained base checkpoint.
dataset_repo = "chuuhtetnaing/myanmar-text-segmentation-dataset"
base_checkpoint = "FacebookAI/xlm-roberta-base"

# Corpus providing the `tokens` / `segment_tags` columns used further down.
ds = load_dataset(dataset_repo)

# Tokenizer of the base checkpoint and the batch collator handed to the Trainer.
tokenizer = AutoTokenizer.from_pretrained(base_checkpoint)
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

# seqeval metric consumed by compute_metrics.
seqeval = evaluate.load("seqeval")

In [None]:
# Build the id <-> tag-name lookups from the label names of `segment_tags`.
segment_tag_names = ds["train"].features["segment_tags"].feature.names
id2label = dict(enumerate(segment_tag_names))
label2id = {tag: idx for idx, tag in id2label.items()}

In [None]:
# Inspect the label-id -> tag-name mapping built from the `segment_tags` names.
id2label

In [None]:
# Inspect the tag-name -> label-id mapping (the inverse of id2label).
label2id

In [None]:
# Token-classification head on the XLM-R base checkpoint, with one output per tag
# in the dataset's `segment_tags` label set and the readable id/name mappings
# stored in the model config.
model = AutoModelForTokenClassification.from_pretrained(
    "FacebookAI/xlm-roberta-base",
    num_labels=len(id2label),  # derived from the dataset instead of hard-coding 2
    id2label=id2label,
    label2id=label2id,
)

# Evaluation and Tokenization Functions

In [None]:
# Tag names indexed by label id; compute_metrics uses it to turn predicted and
# gold ids back into the tag strings seqeval expects.
label_list = ds["train"].features["segment_tags"].feature.names

def compute_metrics(p, label_names=None, metric=None):
    """Score token-classification predictions with seqeval for the Trainer.

    Positions whose gold label is ``-100`` (special tokens and non-first
    sub-tokens, as produced by ``tokenize_and_align_labels``) are dropped before
    scoring, and the remaining label ids are mapped to tag strings.

    Args:
        p: ``(predictions, labels)`` pair, e.g. a ``transformers.EvalPrediction``.
            ``predictions`` are per-token logits; the argmax is taken over axis 2
            (assumed shape ``(batch, seq_len, num_labels)`` -- TODO confirm).
            ``labels`` holds the gold label ids aligned with the predictions.
        label_names: Optional sequence mapping label id -> tag string. Defaults
            to the notebook-level ``label_list``.
        metric: Optional object with a seqeval-style
            ``compute(predictions=..., references=...)`` method returning
            ``overall_*`` keys. Defaults to the notebook-level ``seqeval``.

    Returns:
        Dict with the overall ``precision``, ``recall``, ``f1`` and ``accuracy``.
    """
    # Resolve defaults lazily: Trainer still calls this with one argument, but the
    # function can also be reused/tested without the notebook globals.
    if label_names is None:
        label_names = label_list
    if metric is None:
        metric = seqeval

    logits, gold_ids = p
    predicted_ids = np.argmax(logits, axis=2)

    true_predictions = []
    true_labels = []
    for pred_row, gold_row in zip(predicted_ids, gold_ids):
        # One pass per sequence: keep only positions that carry a real label.
        kept = [(pred, gold) for pred, gold in zip(pred_row, gold_row) if gold != -100]
        true_predictions.append([label_names[pred] for pred, _ in kept])
        true_labels.append([label_names[gold] for _, gold in kept])

    results = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }

In [None]:
def tokenize_and_align_labels(examples):
    """Tokenize pre-split words and align word-level segment tags to sub-tokens.

    Only the first sub-token of each word keeps that word's tag. Special tokens
    (``word_ids`` entry ``None``) and the remaining sub-tokens of a word get
    ``-100``, so they are excluded from scoring in ``compute_metrics`` (and, by
    the usual -100 ignore-index convention, from the loss).

    Args:
        examples: A batch from ``Dataset.map(batched=True)`` with a ``"tokens"``
            column (lists of words) and a ``"segment_tags"`` column (one label id
            per word).

    Returns:
        The tokenizer's batch encoding with an added ``"labels"`` entry.
    """
    encoded = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

    aligned_batch = []
    for row, word_tags in enumerate(examples["segment_tags"]):
        last_word = None
        row_labels = []
        for word in encoded.word_ids(batch_index=row):
            is_first_subtoken = word is not None and word != last_word
            row_labels.append(word_tags[word] if is_first_subtoken else -100)
            last_word = word
        aligned_batch.append(row_labels)

    encoded["labels"] = aligned_batch
    return encoded

In [None]:
# Tokenize every split and attach the word-aligned `labels`; batched=True hands
# tokenize_and_align_labels lists of examples at a time.
tokenized_ds = ds.map(function=tokenize_and_align_labels, batched=True)

In [None]:
# Inspect the tokenized splits; each should now carry a `labels` column.
tokenized_ds

# Train/Fine-Tune the Model

In [None]:
# Fine-tuning configuration: evaluate and checkpoint every 1,000 steps, keep the
# best checkpoint by eval F1, log to W&B and push checkpoints to a private Hub repo.
training_args = TrainingArguments(
    output_dir="myanmar_text_segmentation_model",
    # Optimisation.
    learning_rate=2e-5,
    per_device_train_batch_size=25,
    per_device_eval_batch_size=25,
    num_train_epochs=5,
    weight_decay=0.01,
    seed=42,  # explicit for reproducibility (matches the library default)
    # Evaluation / checkpoint cadence; save_steps is kept equal to eval_steps so
    # load_best_model_at_end can pick among evaluated checkpoints.
    eval_strategy="steps",
    eval_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=2,
    save_safetensors=True,
    # Select the best checkpoint by segmentation F1 (compute_metrics' "f1")
    # rather than the default eval loss, so `best_metric` really is an F1 score.
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    # Reporting and Hub upload.
    report_to="wandb",
    push_to_hub=True,
    hub_private_repo=True,
    hub_strategy="all_checkpoints",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds["train"],
    eval_dataset=tokenized_ds["test"],
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
# Run fine-tuning; keep the returned TrainOutput (it includes the final training
# loss) for later use, and show it as the cell output.
train_output = trainer.train()
train_output

# Add the Training and Evaluation Results to the Model Card (README) Manually

In [1]:
# import json
# from huggingface_hub import hf_hub_download, HfApi

# REPO_ID = "chuuhtetnaing/myanmar_text_segmentation_model"
# CHECKPOINT = "checkpoint-54000"  # Your latest checkpoint

# # Download trainer_state.json from the latest checkpoint
# print(f"Downloading trainer_state.json from {CHECKPOINT}...")
# state_file = hf_hub_download(
#     repo_id=REPO_ID,
#     filename=f"{CHECKPOINT}/trainer_state.json"
# )

# with open(state_file, "r") as f:
#     state = json.load(f)

# log_history = state["log_history"]

# # Separate training and eval logs, then merge by step
# train_logs = {log["step"]: log for log in log_history if "loss" in log and "eval_loss" not in log}
# eval_logs = {log["step"]: log for log in log_history if "eval_loss" in log}

# # Get all eval steps (1000, 2000, ..., 54000)
# eval_steps = sorted(eval_logs.keys())

# # Build the markdown table
# table_rows = []
# table_rows.append("| Step | Training Loss | Validation Loss | Precision | Recall | F1 | Accuracy |")
# table_rows.append("|------|---------------|-----------------|-----------|--------|------|----------|")

# for step in eval_steps:
#     eval_log = eval_logs[step]

#     # Find the closest training loss (usually logged at same or nearby step)
#     train_loss = train_logs.get(step, {}).get("loss", None)
#     if train_loss is None:
#         # Look for nearby steps
#         for s in range(step, step - 1000, -100):
#             if s in train_logs:
#                 train_loss = train_logs[s]["loss"]
#                 break

#     train_loss_str = f"{train_loss:.4f}" if train_loss else "N/A"
#     eval_loss_str = f"{eval_log.get('eval_loss', 0):.4f}"
#     precision_str = f"{eval_log.get('eval_precision', 0):.4f}"
#     recall_str = f"{eval_log.get('eval_recall', 0):.4f}"
#     f1_str = f"{eval_log.get('eval_f1', 0):.4f}"
#     accuracy_str = f"{eval_log.get('eval_accuracy', 0):.4f}"

#     table_rows.append(f"| {step} | {train_loss_str} | {eval_loss_str} | {precision_str} | {recall_str} | {f1_str} | {accuracy_str} |")

# # Print the table
# print("\n=== Training Results Table ===\n")
# print("\n".join(table_rows))

# # Get final metrics
# best_checkpoint = state.get("best_model_checkpoint", "N/A")
# best_metric = state.get("best_metric", "N/A")

# # Find final train loss from the last training log
# final_train_loss = None
# for log in reversed(log_history):
#     if "train_loss" in log:
#         final_train_loss = log["train_loss"]
#         break
#     elif "loss" in log and "eval_loss" not in log:
#         final_train_loss = log["loss"]
#         break

# # If still None, use the last recorded training loss from train_logs
# if final_train_loss is None and train_logs:
#     last_train_step = max(train_logs.keys())
#     final_train_loss = train_logs[last_train_step].get("loss")

# print("\n\n=== Additional Info ===")
# print(f"Best checkpoint: {best_checkpoint}")
# print(f"Best metric value: {best_metric}")
# print(f"Total steps: {state.get('global_step', 'N/A')}")
# print(f"Total epochs: {state.get('epoch', 'N/A')}")
# print(f"Final training loss: {final_train_loss}")

# # Debug: show last few log entries
# print("\n=== Last 3 log entries (for debugging) ===")
# for log in log_history[-3:]:
#     print(log)

# # If final_train_loss is still None, allow manual input
# if final_train_loss is None:
#     print("\n⚠️  Could not find final training loss in logs.")
#     manual_loss = input("Enter final training loss manually (from TrainOutput), or press Enter to skip: ").strip()
#     if manual_loss:
#         try:
#             final_train_loss = float(manual_loss)
#         except ValueError:
#             print("Invalid number, using N/A")

# # Generate full README content
# readme_content = f"""---
# license: apache-2.0
# base_model: FacebookAI/xlm-roberta-base
# tags:
#   - token-classification
#   - myanmar
#   - text-segmentation
# language:
#   - my
#   - en
# datasets:
#   - chuuhtetnaing/myanmar-text-segmentation-dataset
# metrics:
#   - f1
#   - precision
#   - recall
#   - accuracy
# ---

# # Myanmar Text Segmentation Model

# Fine-tuned [FacebookAI/xlm-roberta-base](https://huggingface.co/FacebookAI/xlm-roberta-base) for Myanmar text segmentation (word boundary detection) using token classification.

# ## Training Results

# {chr(10).join(table_rows)}

# ## Training Details

# | Parameter | Value |
# |-----------|-------|
# | Base Model | FacebookAI/xlm-roberta-base |
# | Total Steps | {state.get('global_step', 'N/A')} |
# | Epochs | {state.get('epoch', 'N/A')} |
# | Final Training Loss | {f"{final_train_loss:.6f}" if final_train_loss is not None else 'N/A'} |
# | Best Checkpoint | {best_checkpoint} |
# | Best F1 Score | {f"{best_metric:.4f}" if isinstance(best_metric, float) else best_metric} |
# | Learning Rate | 2e-5 |
# | Batch Size | 25 |
# | Weight Decay | 0.01 |

# ## Usage

# ```python
# from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline

# model = AutoModelForTokenClassification.from_pretrained("chuuhtetnaing/myanmar_text_segmentation_model")
# tokenizer = AutoTokenizer.from_pretrained("chuuhtetnaing/myanmar_text_segmentation_model")

# # Using pipeline
# nlp = pipeline("token-classification", model=model, tokenizer=tokenizer)
# tokens = nlp("အချစ်ဆိုတာလူတွေရှင်သန်ဖို့သဘာဝကပေးတဲ့လက်နက်လား၊ဒါမှမဟုတ်ယဉ်ကျေးမှုအရတီထွင်ထားတဲ့စိတ်ကူးယဉ်မှုသက်သက်လား။")

# segmented_text = []
# for item in tokens:
#     if item["entity"] == "B":
#         segmented_text.append(item["word"])
#     else:  # 'I' - append to previous word
#         segmented_text[-1] += item["word"]
# segmented_text = " ".join(segmented_text)

# print(segmented_text)
# ```

# ## Label Mapping

# - `B`: Beginning of a word/segment
# - `I`: Inside a word/segment (continuation)

# ## Dataset

# Trained on [chuuhtetnaing/myanmar-text-segmentation-dataset](https://huggingface.co/datasets/chuuhtetnaing/myanmar-text-segmentation-dataset) dataset.
# """

# # Save README locally
# readme_path = "README.md"
# with open(readme_path, "w", encoding="utf-8") as f:
#     f.write(readme_content)

# print(f"\n\n=== README saved to {readme_path} ===")


In [2]:
# # Upload to Hub
# api = HfApi()
# api.upload_file(
#     path_or_fileobj=readme_path,
#     path_in_repo="README.md",
#     repo_id=REPO_ID,
#     commit_message="Update README with training results table"
# )
# print(f"✅ README uploaded to {REPO_ID}")