# Imports

In [None]:
!pip install datasets

In [None]:
import os
import shutil
import json
import wandb
from datasets import load_dataset, Dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
)
import pandas as pd
import numpy as np
import torch
import re
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

In [None]:
# Authenticate with Weights & Biases before any run in this notebook is started.
wandb.login()

# Establish Google Drive Connection (if needed)

In [None]:
# Colab-only import; this cell is meant to be skipped outside Google Colab.
from google.colab import drive

# Later cells read and write under /content/drive/MyDrive/..., so mount there.
drive.mount('/content/drive')

# Configuration

In [None]:
# W&B project that every run in this notebook logs to.
PROJECT_NAME   = "RE_SpanBERT_Finetune"
#DATA_FILES     = {"train": "train_re.csv", "validation": "val_re.csv", "test": "test_re.csv"}
# Hugging Face checkpoint used for both the tokenizer and the classifier.
MODEL_NAME     = "SpanBERT/spanbert-large-cased"
# Epoch count for the learning-rate sweep.
NUM_EPOCHS     = 20
# Learning rates iterated over by the sweep cell.
LEARNING_RATES = [5e-5, 3e-5, 2e-5, 1e-5]
# Learning rate used by the fixed-LR (200-epoch and final) runs.
FIXED_LR       = 2e-5
# Candidate batch sizes; the second baseline cell uses the last entry as eval batch size.
BATCH_SIZES    = [8, 16, 24]
# Only recorded in the W&B run configs in this notebook.
LOG_STEPS      = 50

# Data loading & preprocessing

## Load datasets - Challenge

In [None]:
# Build a flat (entity1, entity2, text, relation) table from the raw JSON files:
# for every triple, keep the first sentence of the document that mentions both
# the head and the tail entity; triples without such a sentence are dropped.
DATA_DIR = "/content/drive/MyDrive/project_files/data/raw/train"

if not os.path.isdir(DATA_DIR):
    raise FileNotFoundError(f"Directory not found: {DATA_DIR}. Please verify the path.")

# Naive sentence boundary: whitespace that follows '.', '!' or '?'.
# Compiled once instead of being re-specified for every record.
_SENTENCE_BOUNDARY = re.compile(r'(?<=[.!?])\s+')

rows = []
# os.listdir() returns entries in arbitrary order; sorting keeps the row order
# (and therefore the seeded train/val/test split built from it) reproducible.
for fname in sorted(os.listdir(DATA_DIR)):
    if not fname.endswith(".json"):
        continue
    path = os.path.join(DATA_DIR, fname)
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # A file may hold a single record or a list of records.
    records = data if isinstance(data, list) else [data]

    for rec in records:
        # `or` also covers keys that are present but null in the JSON.
        doc       = rec.get("doc") or ""
        sentences = _SENTENCE_BOUNDARY.split(doc)

        for triple in rec.get("triples") or []:
            head     = triple["head"]
            tail     = triple["tail"]
            relation = triple["relation"]

            # First sentence containing both mentions (plain substring match).
            sentence = next((s for s in sentences if head in s and tail in s), "")
            if not sentence:
                continue

            rows.append({
                "entity1": head,
                "entity2": tail,
                "text":    sentence,
                "relation": relation
            })

df = pd.DataFrame(rows)
df.to_csv("relation_dataset.csv", index=False)
print(f"Built dataset with {len(df)} rows and wrote relation_dataset.csv")
print(df.head())


### New challenge dataset

In [None]:
# Reload the relation table written by the previous cell and show class balance.
df = pd.read_csv("relation_dataset.csv")
print("Counts per relation before split:\n")
print(df["relation"].value_counts())

# Hold out 20 % of the rows, then cut that hold-out in half (validation / test).
train_df, temp_df = train_test_split(df, test_size=0.2, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)

# Relations seen in the initial training part.
train_labels = set(train_df["relation"])

# Rows whose relation never occurs in training are "rogue" and move to train.
val_known = val_df["relation"].isin(train_labels)
test_known = test_df["relation"].isin(train_labels)
val_rogue = val_df[~val_known]
test_rogue = test_df[~test_known]

if not (val_rogue.empty and test_rogue.empty):
    print(f"Moving {len(val_rogue)} val + {len(test_rogue)} test rogue rows → train")
    train_df = pd.concat([train_df, val_rogue, test_rogue], ignore_index=True)
    val_df = val_df[val_known]
    test_df = test_df[test_known]

print(f"\nFinal sizes → train: {len(train_df)}, val: {len(val_df)}, test: {len(test_df)}")
print("Relations per split:")
print(f"  train: {train_df['relation'].value_counts().to_dict()}")
print(f"  val  : {val_df['relation'].value_counts().to_dict()}")
print(f"  test : {test_df['relation'].value_counts().to_dict()}")

# Wrap the three frames (with fresh 0..n-1 indices) as a Hugging Face DatasetDict.
raw_datasets = DatasetDict({
    split_name: Dataset.from_pandas(frame.reset_index(drop=True))
    for split_name, frame in (
        ("train", train_df),
        ("validation", val_df),
        ("test", test_df),
    )
})



def make_example(example):
    """Wrap both entity mentions of a row in ``[E1]…[/E1]`` / ``[E2]…[/E2]`` markers.

    Every occurrence of each entity string in ``example["text"]`` is marked in a
    single left-to-right pass, so text inserted for one entity (including the
    marker literals themselves) is never rewritten when the other is processed.

    Args:
        example: Mapping with the keys ``text``, ``entity1``, ``entity2`` and
            ``relation``.

    Returns:
        dict with ``text`` (the marked sentence), ``entity1_label``,
        ``entity2_label`` and ``relation_label``.
    """
    sent = example["text"]
    e1, e2 = example["entity1"], example["entity2"]

    # Entity string -> marked replacement. Empty entities are skipped (they
    # would match at every position); entity1 wins if both strings are equal.
    replacements = {}
    for tag, entity in (("E1", e1), ("E2", e2)):
        if entity and entity not in replacements:
            replacements[entity] = f"[{tag}]{entity}[/{tag}]"

    marked = sent
    if replacements:
        # Longest alternative first: an entity that contains the other one is
        # matched as a whole instead of being split by the shorter one.
        alternatives = sorted(replacements, key=len, reverse=True)
        pattern = re.compile("|".join(re.escape(entity) for entity in alternatives))
        marked = pattern.sub(lambda match: replacements[match.group(0)], sent)

    return {
        "text":           marked,
        "entity1_label":  e1,
        "entity2_label":  e2,
        "relation_label": example["relation"],
    }

# Apply the entity markers to every split; the raw columns are replaced by the
# fields returned from make_example.
token_input_datasets = raw_datasets.map(
    make_example,
    remove_columns=["text", "entity1", "entity2", "relation"],
)

# Label vocabulary: relations of the training split, in sorted order.
unique_rels = sorted(token_input_datasets["train"].unique("relation_label"))
id2label = dict(enumerate(unique_rels))
label2id = {rel: idx for idx, rel in id2label.items()}

def add_label_ids(example):
    """Return a ``labels`` field holding the integer id of the row's relation."""
    relation_name = example["relation_label"]
    return {"labels": label2id[relation_name]}

# Add the integer `labels` column produced by add_label_ids to every split.
final_datasets = token_input_datasets.map(
    add_label_ids,
)

## Own custom dataset

In [None]:
# Pre-split CSVs of the custom dataset on Google Drive.
base_path = "/content/drive/MyDrive/project_files/data/processed/RE_datasets/"
train_df, val_df, test_df = (
    pd.read_csv(base_path + csv_name)
    for csv_name in ("train_re.csv", "val_re.csv", "test_re.csv")
)

print(f"Loaded {len(train_df)} train / {len(val_df)} val / {len(test_df)} test rows")

# Wrap the three frames (with fresh 0..n-1 indices) as a Hugging Face DatasetDict.
raw_datasets = DatasetDict({
    split_name: Dataset.from_pandas(frame.reset_index(drop=True))
    for split_name, frame in (
        ("train", train_df),
        ("validation", val_df),
        ("test", test_df),
    )
})

def make_example(example):
    """Wrap both entity mentions of a row in ``[E1]…[/E1]`` / ``[E2]…[/E2]`` markers.

    Every occurrence of each entity string in ``example["RE_sentence"]`` is
    marked in a single left-to-right pass, so text inserted for one entity
    (including the marker literals themselves) is never rewritten when the
    other is processed.

    Args:
        example: Mapping with the keys ``RE_sentence``, ``entity1_label``,
            ``entity2_label`` and ``relation_label``.

    Returns:
        dict with ``text`` (the marked sentence) plus the unchanged
        ``entity1_label``, ``entity2_label`` and ``relation_label``.
    """
    sent = example["RE_sentence"]
    e1 = example["entity1_label"]
    e2 = example["entity2_label"]

    # Entity string -> marked replacement. Empty entities are skipped (they
    # would match at every position); entity1 wins if both strings are equal.
    replacements = {}
    for tag, entity in (("E1", e1), ("E2", e2)):
        if entity and entity not in replacements:
            replacements[entity] = f"[{tag}]{entity}[/{tag}]"

    marked = sent
    if replacements:
        # Longest alternative first: an entity that contains the other one is
        # matched as a whole instead of being split by the shorter one.
        alternatives = sorted(replacements, key=len, reverse=True)
        pattern = re.compile("|".join(re.escape(entity) for entity in alternatives))
        marked = pattern.sub(lambda match: replacements[match.group(0)], sent)

    return {
        "text": marked,
        "entity1_label": e1,
        "entity2_label": e2,
        "relation_label": example["relation_label"],
    }

# Apply the entity markers to every split and drop the columns that are no
# longer needed afterwards.
formatted_datasets = raw_datasets.map(
    make_example,
    remove_columns=["RE_sentence", "relation", "entity1_id", "entity2_id"],
)

# Label vocabulary: relations of the training split, in sorted order.
unique_rels = sorted(formatted_datasets["train"].unique("relation_label"))
id2label = dict(enumerate(unique_rels))
label2id = {rel: idx for idx, rel in id2label.items()}

def add_label_ids(example):
    """Return a ``labels`` field holding the integer id of the row's relation."""
    relation_name = example["relation_label"]
    return {"labels": label2id[relation_name]}

# Add the integer `labels` column produced by add_label_ids to every split.
final_datasets = formatted_datasets.map(add_label_ids)

print(f"Final dataset features: {final_datasets['train'].features}")
print(f"Label mapping:\n{label2id}")

print("\n🔍 Sample examples from training set:\n")
# Last expression of the cell: three rows of the shuffled training split as a DataFrame.
final_datasets["train"].shuffle(seed=42).select(range(3)).to_pandas()


## Insert entity markers and map labels

## Tokenizer and data collator

In [None]:
# Tokenizer matching the pretrained checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Collator handed to the trainers; tokenize_fn leaves padding off (padding=False).
data_collator = DataCollatorWithPadding(tokenizer)


def tokenize_fn(examples):
    """Tokenize the ``text`` field of a batch, truncated to 256 tokens, unpadded."""
    encode_options = {
        "truncation": True,
        "padding": False,
        "max_length": 256,
    }
    return tokenizer(examples["text"], **encode_options)

# Tokenize every split in batches; the tokenizer outputs are added as new columns.
tokenized = final_datasets.map(
    tokenize_fn,
    batched=True,
)

# Baseline Evaluation

## Start W&B run

In [None]:
# Open the W&B run that collects the metrics of the untrained (baseline) model.
wandb.init(
    project=PROJECT_NAME,
    name="baseline_new_training",
    reinit=True,
    config={
        "model": MODEL_NAME,
    }
)

In [None]:
# One row per evaluated split is appended to this table by the baseline loops below.
baseline_table = wandb.Table(columns=["split", "eval_loss", "precision", "recall", "f1", "accuracy"])

# Relation names in label-id order (label2id is built by enumerating the sorted relations).
target_labels = list(label2id.keys())

# Pretrained checkpoint with a classification head sized to the label set;
# it is evaluated as-is, without any fine-tuning.
baseline_model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(target_labels),
    id2label=id2label,
    label2id=label2id,
)

def compute_seq_metrics(p):
    """Compute macro precision/recall/F1 and accuracy for a prediction object.

    ``p`` must expose ``predictions`` (argmax is taken over axis 1) and
    ``label_ids``. Returns a dict with the keys ``precision``, ``recall``,
    ``f1`` and ``accuracy``.
    """
    gold = p.label_ids
    predicted = np.argmax(p.predictions, axis=1)
    macro_precision, macro_recall, macro_f1, _support = precision_recall_fscore_support(
        gold, predicted, average="macro", zero_division=0
    )
    return {
        "precision": macro_precision,
        "recall": macro_recall,
        "f1": macro_f1,
        "accuracy": accuracy_score(gold, predicted),
    }

# Evaluate the untrained baseline on validation and test, one table row per split.
for split in ["validation", "test"]:
    # Evaluation-only arguments; report_to=[] disables the trainer's own reporting,
    # so results reach W&B only through the table logged below.
    args = TrainingArguments(
        output_dir=f"./baseline_{split}",
        do_train=False,
        do_eval=True,
        per_device_eval_batch_size=32,
        logging_strategy="no",
        save_strategy="no",
        report_to=[],
    )

    trainer = Trainer(
        model=baseline_model,
        args=args,
        data_collator=DataCollatorWithPadding(tokenizer),
        eval_dataset=tokenized[split],
        tokenizer=tokenizer,
        compute_metrics=compute_seq_metrics,
    )

    # Keys come back prefixed with "eval_".
    metrics = trainer.evaluate()

    baseline_table.add_data(
        split,
        metrics["eval_loss"],
        metrics["eval_precision"],
        metrics["eval_recall"],
        metrics["eval_f1"],
        metrics["eval_accuracy"],
    )

wandb.log({"baseline_metrics_table": baseline_table})


## Evaluate and log on both splits

In [None]:
# Second baseline pass: same model and splits, but with W&B reporting enabled
# and each metric also logged as a "baseline/<split>_<metric>" scalar.
# NOTE(review): rows are appended to the same `baseline_table` that the previous
# cell fills, so running both cells yields duplicate rows per split — confirm intent.
for split in ["validation","test"]:
    args = TrainingArguments(
        output_dir=f"baseline_{split}",
        per_device_eval_batch_size=BATCH_SIZES[-1],
        do_train=False, do_eval=True,
        logging_strategy="no", save_strategy="no",
        report_to=["wandb"],
    )
    trainer = Trainer(
        model=baseline_model, args=args,
        data_collator=data_collator,
        eval_dataset=tokenized[split], tokenizer=tokenizer,
        compute_metrics=compute_seq_metrics,
    )
    res = trainer.evaluate()
    wandb.log({f"baseline/{split}_{k}":v for k,v in res.items()})
    baseline_table.add_data(
        split,
        res['eval_loss'],
        res['eval_precision'],
        res['eval_recall'],
        res['eval_f1'],
        res['eval_accuracy'],
    )
    print(f"Baseline on {split}:",res)

wandb.log({"baseline_metrics_table":baseline_table})

In [None]:
# Release cached GPU memory and close the baseline W&B run.
torch.cuda.empty_cache()
wandb.finish()

# Training

In [None]:
# Learning-rate sweep: one W&B run and one freshly initialised model per value
# in LEARNING_RATES.
# Drop the baseline model first so its memory can be reclaimed.
del baseline_model
torch.cuda.empty_cache()

for lr in LEARNING_RATES:
    # NOTE(review): the run config and names advertise batch_size=4 while the
    # per-device batch size below is 2 — confirm which value is intended.
    wandb.init(
        project=PROJECT_NAME,
        config={
            "model": MODEL_NAME,
            "epochs": NUM_EPOCHS,
            "learning_rate": lr,
            "batch_size": 4,
            "log_steps": LOG_STEPS,
        },
        reinit=True,
        name=f"bs_4_lr_{lr}",
        resume=False,
    )

    metrics_table = wandb.Table(columns=["epoch","train_loss","eval_loss","precision","recall","f1","accuracy"])

    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=len(target_labels), id2label=id2label, label2id=label2id,
    )

    args = TrainingArguments(
        output_dir=f"outputs/bs_4", overwrite_output_dir=True,
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        # Use the swept value; with FIXED_LR here every run of the sweep
        # trained with the same learning rate.
        learning_rate=lr,
        logging_strategy="epoch",
        eval_strategy="epoch",
        save_strategy="no",
        report_to=["wandb"],
        load_best_model_at_end=False,
        fp16=True,
    )

    trainer = Trainer(
        model=model, args=args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["validation"],
        data_collator=data_collator, tokenizer=tokenizer,
        compute_metrics=compute_seq_metrics,
    )

    trainer.train()

    # Training loss and evaluation metrics live in separate log_history
    # entries, so the training loss is looked up by the shared global step.
    train_loss_by_step = {
        log["step"]: log["loss"]
        for log in trainer.state.log_history
        if "loss" in log and "step" in log
    }
    for log in trainer.state.log_history:
        if all(k in log for k in ["epoch","eval_loss"]):
            metrics_table.add_data(
                log["epoch"], train_loss_by_step.get(log.get("step")), log["eval_loss"],
                log.get("eval_precision"), log.get("eval_recall"), log.get("eval_f1"), log.get("eval_accuracy"),
            )
    wandb.log({"metrics_table":metrics_table})

    # Metrics of the final model state on the validation split.
    final_metrics = trainer.evaluate()
    wandb.log({f"final/bs_4_{k}":v for k,v in final_metrics.items()})

    # Confusion matrix of the validation predictions.
    preds_out = trainer.predict(tokenized["validation"])
    preds = np.argmax(preds_out.predictions,axis=1)
    cm = wandb.plot.confusion_matrix(probs=None, y_true=preds_out.label_ids, preds=preds, class_names=target_labels)
    wandb.log({"confusion_matrix":cm})

    wandb.finish()

# 200 epoch training

## Setup

In [None]:
# W&B run for the long (200-epoch) training at the fixed learning rate.
wandb.init(
    project=PROJECT_NAME,
    name="train_new_200e_lr2e-5_16b_3",
    reinit=True,
    config={
        "model": MODEL_NAME,
        "epochs": 200,
        "learning_rate": FIXED_LR,
        "batch_size": 16,
        "log_steps": LOG_STEPS,
    }
)

# Filled with one row per evaluation entry after training (see the Run cell).
train200_table = wandb.Table(columns=["epoch","train_loss","eval_loss","precision","recall","f1","accuracy"])

# Fresh model from the pretrained checkpoint with a head sized to the label set.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(target_labels),
    id2label=id2label,
    label2id=label2id,
)

# Log and evaluate once per epoch; no checkpoints are written (save_strategy="no").
args = TrainingArguments(
    output_dir="outputs/train_200e_lr2e-5_16b",
    overwrite_output_dir=True,
    num_train_epochs=200,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=FIXED_LR,
    logging_strategy="epoch",
    eval_strategy="epoch",
    save_strategy="no",
    report_to=["wandb"],
    load_best_model_at_end=False,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_seq_metrics,
)

## Run

In [None]:
# Train, then push the per-epoch history, final validation metrics and a
# validation confusion matrix to the open W&B run.
trainer.train()

# Training loss and evaluation metrics live in separate log_history entries,
# so the training loss is looked up by the shared global step (reading "loss"
# from an evaluation entry always yields None).
train_loss_by_step = {
    log["step"]: log["loss"]
    for log in trainer.state.log_history
    if "loss" in log and "step" in log
}
for log in trainer.state.log_history:
    if "epoch" in log and "eval_loss" in log:
        train200_table.add_data(
            log["epoch"],
            train_loss_by_step.get(log.get("step")),
            log.get("eval_loss"),
            log.get("eval_precision"),
            log.get("eval_recall"),
            log.get("eval_f1"),
            log.get("eval_accuracy"),
        )
wandb.log({"train_200e_metrics_table": train200_table})

# Metrics of the final model state on the validation split.
final_metrics = trainer.evaluate()
wandb.log({f"train_200e_lr2e-5_16b_{k}": v for k, v in final_metrics.items()})
preds_out = trainer.predict(tokenized["validation"])
preds = np.argmax(preds_out.predictions, axis=1)
cm = wandb.plot.confusion_matrix(
    probs=None,
    y_true=preds_out.label_ids,
    preds=preds,
    class_names=target_labels
)
wandb.log({"train_200e_confusion_matrix": cm})

wandb.finish()

# Test

In [None]:
# If running immediately after training, `trainer.model` is already the final model.
# Otherwise, load from the output directory of your best run:
# model = AutoModelForSequenceClassification.from_pretrained("outputs/train_200e_lr2e-5_16b", num_labels=len(target_labels), id2label=id2label, label2id=label2id)
# NOTE(review): the training cells above use save_strategy="no", so that directory
# may not contain a saved model — confirm before relying on the commented-out load.
model = trainer.model  # reuse the model from the last training cell



## Inference only trainer

In [None]:
def _test_metrics(p):
    """Return accuracy and macro precision/recall/F1 for a prediction object.

    ``p`` must expose ``predictions`` (argmax is taken over axis 1) and
    ``label_ids``. The argmax is computed once and shared by all metrics.
    """
    pred_ids = np.argmax(p.predictions, axis=1)
    precision, recall, f1, _ = precision_recall_fscore_support(
        p.label_ids, pred_ids, average='macro', zero_division=0
    )
    return {
        "accuracy": accuracy_score(p.label_ids, pred_ids),
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


# Evaluation-only arguments. report_to=[] matches the first baseline cell:
# test metrics are logged manually in a later cell once a dedicated W&B run
# has been opened, so the trainer must not start a run of its own.
test_args = TrainingArguments(
    output_dir="inference",
    per_device_eval_batch_size= 16,
    do_train=False,
    do_eval=True,
    logging_strategy="no",
    save_strategy="no",
    report_to=[],
)
inference_trainer = Trainer(
    model=model,
    args=test_args,
    data_collator=data_collator,
    eval_dataset=tokenized["test"],
    tokenizer=tokenizer,
    compute_metrics=_test_metrics,
)

In [None]:
# Evaluate on the test split (the trainer's eval_dataset); keys are prefixed with "eval_".
test_results = inference_trainer.evaluate()
print("Test metrics:", test_results)

In [None]:
# Dedicated W&B run for the test results; only the "eval_*" entries are logged,
# re-keyed under a "test/" prefix.
wandb.init(project=PROJECT_NAME, name="test_evaluation_1", reinit=True)
wandb.log({f"test/{k}": v for k, v in test_results.items() if k.startswith("eval_")})

In [None]:
# Predict on the test split and take the highest-scoring class per example.
preds_out = inference_trainer.predict(tokenized["test"])
pred_ids = np.argmax(preds_out.predictions, axis=1)

# Pair every example's entities with its predicted relation name.
test_split = tokenized["test"]
triples = [
    {
        "head": head_entity,
        "relation": id2label[pred_id],
        "tail": tail_entity,
    }
    for head_entity, tail_entity, pred_id in zip(
        test_split["entity1_label"], test_split["entity2_label"], pred_ids
    )
]

# Write the triples together with the label inventory to a local JSON file.
json_path = "predictions.json"
output = {"triples": triples, "label_set": target_labels}
with open(json_path, "w") as fp:
    json.dump(output, fp, indent=4)

In [None]:
# Attach the predictions file to the open W&B run, then close the run.
wandb.save(json_path)
wandb.finish()
print(f"Wrote {len(triples)} triples and logged test metrics to W&B")

# Final training

In [None]:
# W&B run for the final 20-epoch training at the fixed learning rate.
wandb.init(
    project=PROJECT_NAME,
    name="labels_new_train_20e_lr2e-5_bs16",
    reinit=True,
    config={
        "model": MODEL_NAME,
        "epochs": 20,
        "learning_rate": FIXED_LR,
        "batch_size": 16,
        "log_steps": LOG_STEPS,
    }
)


In [None]:
# Fresh model from the pretrained checkpoint with a head sized to the label mapping.
final_model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id,
)


In [None]:
# Arguments for the final run: evaluate and checkpoint every epoch, keep a
# single checkpoint on disk and reload the one with the best validation
# macro-F1 when training ends.
final_args = TrainingArguments(
    # Explicit checkpoint location: depending on the installed transformers
    # release, omitting output_dir either raises or silently falls back to a
    # library default. Later cells read it back via final_args.output_dir.
    output_dir="outputs/final_train_20e_lr2e-5_bs16",
    overwrite_output_dir=True,
    num_train_epochs=20,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # Same value as before (2e-5), taken from the shared constant that the
    # W&B run config for this training also records.
    learning_rate=FIXED_LR,
    logging_strategy="epoch",
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_f1",
    greater_is_better=True,
    save_total_limit=1,
    report_to=["wandb"],
)

### Old

In [None]:
def compute_seq_metrics(p):
    """Compute macro precision/recall/F1 and accuracy for a prediction object.

    ``p`` must expose ``predictions`` (argmax is taken over axis 1) and
    ``label_ids``. Returns a dict with the keys ``precision``, ``recall``,
    ``f1`` and ``accuracy``.
    """
    gold = p.label_ids
    predicted = np.argmax(p.predictions, axis=1)
    macro_scores = precision_recall_fscore_support(
        gold, predicted, average="macro", zero_division=0
    )
    # The fourth element (support) is not reported.
    macro_precision, macro_recall, macro_f1 = macro_scores[:3]
    return {
        "precision": macro_precision,
        "recall": macro_recall,
        "f1": macro_f1,
        "accuracy": accuracy_score(gold, predicted),
    }

# Trainer for the final run using the macro-only metric function defined above.
final_trainer = Trainer(
    model=final_model,
    args=final_args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_seq_metrics,
)

### New

In [None]:
def compute_seq_metrics(p):
    """Compute accuracy, macro metrics and per-relation precision/recall/F1.

    ``p`` must expose ``predictions`` (argmax is taken over axis 1) and
    ``label_ids``. The result holds ``accuracy``, ``precision``, ``recall``
    and ``f1`` (macro-averaged) plus ``<relation>_precision``,
    ``<relation>_recall`` and ``<relation>_f1`` for every id in ``label2id``.
    """
    preds = np.argmax(p.predictions, axis=1)
    labels = p.label_ids

    macro_scores = precision_recall_fscore_support(
        labels, preds, average="macro", zero_division=0
    )
    metrics = {
        "accuracy": accuracy_score(labels, preds),
        "precision": macro_scores[0],
        "recall": macro_scores[1],
        "f1": macro_scores[2],
    }

    # Per-class scores are requested for every known label id, in mapping order.
    label_ids = list(label2id.values())
    class_precision, class_recall, class_f1, _support = precision_recall_fscore_support(
        labels, preds, average=None, zero_division=0, labels=label_ids
    )
    for label_id, prec, rec, f1_score in zip(label_ids, class_precision, class_recall, class_f1):
        label_name = id2label[label_id]
        metrics[f"{label_name}_precision"] = prec
        metrics[f"{label_name}_recall"] = rec
        metrics[f"{label_name}_f1"] = f1_score

    return metrics

# Trainer for the final run using the metric function with per-relation scores.
final_trainer = Trainer(
      model=final_model,
      args=final_args,
      train_dataset=tokenized["train"],
      eval_dataset=tokenized["validation"],
      data_collator=data_collator,
      tokenizer=tokenizer,
      compute_metrics=compute_seq_metrics,
  )


In [None]:
# Run the final training; final_args requests reloading the best checkpoint at the end.
final_trainer.train()

In [None]:
# Save the trainer's current model into the run's output directory.
best_ckpt_dir = final_args.output_dir
final_trainer.save_model(best_ckpt_dir)

In [None]:
# Upload the saved model directory to W&B as a model artifact tagged "best".
artifact = wandb.Artifact("spanbert_new_best_checkpoint", type="model")
artifact.add_dir(best_ckpt_dir)
wandb.log_artifact(artifact, aliases=["best"])

In [None]:
# Persist the model and tokenizer on Google Drive in a folder named after the
# best checkpoint.
parent_dir = "/content/drive/MyDrive/project_files/checkpoints/RE/"

best_ckpt = final_trainer.state.best_model_checkpoint

# The trainer state may hold no best checkpoint (None); fall back to a fixed
# folder name instead of failing on `None.rstrip` after a full training run.
best_name = os.path.basename(best_ckpt.rstrip("/")) if best_ckpt else "model"
new_folder = os.path.join(parent_dir, f"best_{best_name}")

os.makedirs(new_folder, exist_ok=True)

final_trainer.save_model(new_folder)

tokenizer.save_pretrained(new_folder)

print(f"Saved best model and tokenizer into {new_folder}")


In [None]:
# Per-epoch history of the final run as a W&B table.
final_table = wandb.Table(columns=["epoch","train_loss","eval_loss","precision","recall","f1","accuracy"])

# Training loss and evaluation metrics live in separate log_history entries,
# so the training loss is looked up by the shared global step (reading "loss"
# from an evaluation entry always yields None).
train_loss_by_step = {
    log["step"]: log["loss"]
    for log in final_trainer.state.log_history
    if "loss" in log and "step" in log
}
for log in final_trainer.state.log_history:
    if all(k in log for k in ["epoch","eval_loss"]):
        final_table.add_data(
            log["epoch"],
            train_loss_by_step.get(log.get("step")),
            log.get("eval_loss"),
            log.get("eval_precision"),
            log.get("eval_recall"),
            log.get("eval_f1"),
            log.get("eval_accuracy"),
        )
wandb.log({"final_training_table": final_table})

# Confusion matrix of the validation predictions.
val_preds = final_trainer.predict(tokenized["validation"])
val_pred_ids = np.argmax(val_preds.predictions, axis=1)
cm = wandb.plot.confusion_matrix(
    probs=None,
    y_true=val_preds.label_ids,
    preds=val_pred_ids,
    class_names=target_labels
)
wandb.log({"final_confusion_matrix": cm})




## Save predictions in files

In [None]:
# Recompute the validation predictions so this cell does not depend on the previous one.
val_preds = final_trainer.predict(tokenized["validation"])
val_pred_ids = np.argmax(val_preds.predictions, axis=1)

# Marker-annotated sentences alongside gold and predicted label ids.
val_texts = tokenized["validation"]["text"]
true_ids = val_preds.label_ids
pred_ids = val_pred_ids

# One row per validation example with relation names instead of ids.
pred_df = pd.DataFrame({
    "sentence": val_texts,
    "actual_relation": [id2label[i] for i in true_ids],
    "predicted_relation": [id2label[i] for i in pred_ids],
})

csv_path = "/content/drive/MyDrive/project_files/data/processed/re_val_predictions.csv"
json_path = "/content/drive/MyDrive/project_files/data/processed/re_val_predictions.json"
pred_df.to_csv(csv_path, index=False)
# lines=True writes one JSON object per line (JSON Lines) despite the .json suffix.
pred_df.to_json(json_path, orient="records", lines=True)

# Attach both files to the open W&B run.
wandb.save(csv_path)
wandb.save(json_path)


In [None]:
# Close the final-training W&B run.
wandb.finish()

In [None]:
# Stage a copy of checkpoint-306 outside the Drive mount; the next cell remounts
# Drive under a different path before the folder is copied back.
local_src = '/content/drive/MyDrive/project_files/checkpoints/RE/checkpoint-306'
temp_dir = '/content/tmp_checkpoint-306'
if os.path.exists(local_src):
    # dirs_exist_ok=True lets the cell be re-run over an existing temp copy.
    shutil.copytree(local_src, temp_dir, dirs_exist_ok=True)
    print(f"Temporary copy of checkpoint created at {temp_dir}")
else:
    # Only reported, not raised: later cells still expect temp_dir to exist.
    print(f"Error: local src {local_src} does not exist")

In [None]:
# Mount Drive again under /content/gdrive (a different mount point than at the top of the notebook).
drive.mount('/content/gdrive', force_remount=True)

In [None]:
# Copy the staged checkpoint from the temp folder into the remounted Drive.
dst = '/content/gdrive/MyDrive/project_files/checkpoints/RE/'
# Build the destination once so the copy target and the printed path agree
# (the previous message showed a doubled slash because dst ends with '/').
dst_checkpoint = os.path.join(dst, 'checkpoint-306')
os.makedirs(dst, exist_ok=True)
shutil.copytree(temp_dir, dst_checkpoint, dirs_exist_ok=True)

print(f"Checkpoint folder copied from temp to Drive at {dst_checkpoint}")