In [1]:
from google.colab import drive
drive.mount('/content/drive/', force_remount=True)
import sys
sys.path.append('/content/drive/MyDrive/Colab Notebooks')

Mounted at /content/drive/


In [2]:
!pip install torch transformers datasets
!pip install numpy==1.26.4



In [3]:
from transformers import MT5Tokenizer, MT5ForConditionalGeneration
import torch
import torch.nn as nn
# from datasets import DatasetDict
from transformers import Trainer, TrainingArguments, DataCollatorWithPadding
from sklearn.metrics import classification_report, accuracy_score, f1_score
from sklearn.preprocessing import LabelEncoder
from run_model import load_rels_dataset
import argparse
import numpy as np
import csv

In [4]:
print(np.__version__)

1.26.4


In [19]:
# === load dataset ===
dataset = load_rels_dataset(
    "/content/drive/MyDrive/Colab Notebooks/sample_data/zho.rst.gcdt_dev.rels",
    "/content/drive/MyDrive/Colab Notebooks/sample_data/zho.rst.gcdt_train.rels"
)

# encode labels to integers
label_encoder = LabelEncoder()
label_encoder.fit(dataset["train"]["label"])

def encode_label(example):
    example["label"] = label_encoder.transform([example["label"]])[0]
    return example

dataset = dataset.map(encode_label)
tokenizer = MT5Tokenizer.from_pretrained("google/mt5-base")

def preprocess(example):
    # combine the two spans into a single text input
    text = f"Classify: Arg1: {example['u1']} Arg2: {example['u2']}"
    encoded = tokenizer(text, padding="max_length", truncation=True, max_length=512)
    encoded["label"] = example["label"]
    return encoded

# drop the original columns that are not needed for training
tokenized_dataset = dataset.map(
    preprocess,
    remove_columns=dataset["train"].column_names
)

tokenized_dataset.set_format("torch")

{'label': 'organization', 'type': 'none', 'u1': '疫情 期间 主流 媒体 的 传播 策略 研究', 'u2': '直 至 今天 的 新冠 疫情 常态化 防治 时期 ， 我们 看到 了 <*> 不同 传播 行动 。', 'direction': '1>2'}
{'label': 'organization', 'type': 'none', 'u1': '狂欢 与 凝视 ：', 'u2': '颜值 消费 与 田园 回归', 'direction': '1>2'}


Map:   0%|          | 0/1006 [00:00<?, ? examples/s]

Map:   0%|          | 0/953 [00:00<?, ? examples/s]

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'T5Tokenizer'. 
The class this function is called from is 'MT5Tokenizer'.


Map:   0%|          | 0/1006 [00:00<?, ? examples/s]

Map:   0%|          | 0/953 [00:00<?, ? examples/s]

In [24]:
class MT5Classifier(nn.Module):
    def __init__(self, num_labels, num_languages=1, lang_emb_dim=None):
        super().__init__()
        self.encoder = MT5ForConditionalGeneration.from_pretrained("google/mt5-small").get_encoder()
        hidden_size = self.encoder.config.d_model
        # add language embedding layer
        self.lang_emb_dim = lang_emb_dim or hidden_size
        self.language_embedding = nn.Embedding(num_languages, self.lang_emb_dim)
        self.lang_proj = nn.Linear(self.lang_emb_dim, hidden_size) if self.lang_emb_dim != hidden_size else None

        self.classifier = nn.Linear(hidden_size, num_labels)
        self.loss_fct = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, language_ids=None, labels=None):
       # make sure we always have batch dim
        if input_ids.dim() == 1:
          input_ids = input_ids.unsqueeze(0)
          attention_mask = attention_mask.unsqueeze(0)
        if language_ids is not None:
            language_ids = language_ids.unsqueeze(0)
        if labels is not None:
            labels = labels.unsqueeze(0)

        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = outputs.last_hidden_state.mean(dim=1)
        if language_ids is not None:
            lang_emb = self.language_embedding(language_ids)
            if self.lang_proj:
                lang_emb = self.lang_proj(lang_emb)
            pooled = pooled + lang_emb
        logits = self.classifier(pooled)
        loss = self.loss_fct(logits, labels) if labels is not None else None
        return {"loss": loss, "logits": logits}

In [None]:
model = MT5Classifier(num_labels=len(label_encoder.classes_))

In [25]:
# === training setup ===
use_cuda = True
device = torch.device("cuda" if torch.cuda.is_available() and use_cuda else "cpu")

model = model.to(device)

training_args = TrainingArguments(
    output_dir="/content/drive/MyDrive/Colab Notebooks/baseline_mt5_classifier_2.results",
    overwrite_output_dir=False,
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=2,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    auto_find_batch_size=True,
)

# add 'return_tensor' to help smaller dataset with unmatched dimension issues
data_collator = DataCollatorWithPadding(tokenizer, return_tensors="pt")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=1)
    acc = accuracy_score(labels, preds)
    f1 = f1_score(labels, preds, average="weighted")
    report = classification_report(labels, preds, target_names=label_encoder.classes_)

    print("\n=== Classification Report ===")
    print(report)
    print(f"Accuracy: {acc:.4f}")
    print(f"Weighted F1: {f1:.4f}\n")

    return {
        "accuracy": acc,
        "f1": f1
    }

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["dev"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

batch = next(iter(trainer.get_train_dataloader()))
print("Batch input_ids shape:", batch["input_ids"].shape)

Batch input_ids shape: torch.Size([2, 512])


  trainer = Trainer(


In [None]:
trainer.train()
trainer.evaluate()

In [8]:
# === error analysis ===
# get predictions on dev set
pred_out = trainer.predict(tokenized_dataset["dev"])
logits = pred_out.predictions
labels = pred_out.label_ids

# convert logits to label IDs
preds = np.argmax(logits, axis=1)

# softmax confidence score
probs = torch.softmax(torch.tensor(logits), dim=1).numpy()

# extract missclassified examples
dev_texts = dataset["dev"]
formatted_texts = [
    f"Arg1: {u1} | Arg2: {u2}"
    for u1, u2 in zip(dev_texts["u1"], dev_texts["u2"])
]

mis = [
    (text, label_encoder.classes_[true], label_encoder.classes_[pred], probs[i][pred])
    for i, (text, true, pred) in enumerate(zip(formatted_texts, labels, preds))
    if true != pred
]

from google.colab import drive
drive.mount('/content/drive', force_remount=True)

out_path = '/content/drive/MyDrive/Colab Notebooks/misclassifications.csv'

# log the results to a csv file
with open(out_path, 'w', newline='', encoding='utf-8') as f:
    writer = csv.writer(f)
    writer.writerow(["text", "true_label", "pred_label", "confidence"])
    writer.writerows(mis)

print(f"✅ Saved {len(mis)} misclassified examples to:\n{out_path}")


=== Classification Report ===
               precision    recall  f1-score   support

  alternation       0.00      0.00      0.00        31
  attribution       0.00      0.00      0.00       242
       causal       0.00      0.00      0.00       203
      comment       0.00      0.00      0.00       165
   concession       0.00      0.00      0.00       162
    condition       0.00      0.00      0.00        69
  conjunction       0.00      0.00      0.00       541
     contrast       0.00      0.00      0.00       167
  elaboration       0.19      1.00      0.32       705
  explanation       0.00      0.00      0.00       284
        frame       0.00      0.00      0.00       182
         mode       0.00      0.00      0.00        79
 organization       0.00      0.00      0.00       244
      purpose       0.00      0.00      0.00       136
        query       0.00      0.00      0.00        74
reformulation       0.00      0.00      0.00       137
     temporal       0.00      0.0

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


✅ Logged 3003 misclassified examples to misclassifications.csv


In [9]:
from collections import Counter
print(Counter(dataset["train"]["label"]))

Counter({8: 804, 6: 540, 16: 325, 9: 238, 12: 235, 1: 226, 10: 177, 7: 155, 2: 151, 13: 137, 3: 137, 4: 118, 5: 82, 11: 79, 15: 73, 14: 63, 0: 34})


In [12]:
!pip freeze > /content/requirements.txt

In [13]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
!cp /content/requirements.txt "/content/drive/MyDrive/Colab Notebooks/requirements.txt"

Mounted at /content/drive
