In [1]:
!pip install torch transformers datasets
!pip install numpy==1.26.4

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [1]:
from google.colab import userdata
import os
os.environ['GIT_TOKEN'] = secret_value = userdata.get('GIT_TOKEN')

In [2]:
!git clone https://$GIT_TOKEN@github.com/Abhishek-P/disrpt25-task.git

Cloning into 'disrpt25-task'...
remote: Enumerating objects: 64, done.[K
remote: Counting objects: 100% (64/64), done.[K
remote: Compressing objects: 100% (50/50), done.[K
remote: Total 64 (delta 36), reused 32 (delta 13), pack-reused 0 (from 0)[K
Receiving objects: 100% (64/64), 39.25 KiB | 648.00 KiB/s, done.
Resolving deltas: 100% (36/36), done.


In [3]:
%cd '/content/disrpt25-task'
!ls
!pip install conllu

/content/disrpt25-task
disrptdata.py  mapping_disrpt25.json  requirements.txt	util.py
LICENSE        README.MD	      run_model.py
Collecting conllu
  Downloading conllu-6.0.0-py3-none-any.whl.metadata (21 kB)
Downloading conllu-6.0.0-py3-none-any.whl (16 kB)
Installing collected packages: conllu
Successfully installed conllu-6.0.0


In [4]:
from google.colab import drive
drive.mount('/content/drive/', force_remount=True)
import sys
sys.path.append('/content/disrpt25-task')

Mounted at /content/drive/


In [5]:
from transformers import MT5Tokenizer, MT5ForConditionalGeneration
import torch
import torch.nn as nn
import datasets
from transformers import Trainer, TrainingArguments, DataCollatorWithPadding
from sklearn.metrics import classification_report, accuracy_score, f1_score
import argparse
import numpy as np
import csv

In [6]:
import disrptdata
disrptdata.DATA_DIR = "/content/drive/MyDrive/Colab Notebooks/sample_data"

In [8]:
# === load dataset ===
"""
experiment with two languages: Chinese and English
"""
DATA_DIR = "/content/drive/MyDrive/Colab Notebooks/sample_data"

# load and combine datasets
zho = disrptdata.get_dataset("zho.rst.gcdt")
eng = disrptdata.get_dataset("eng.erst.gum")
combined = disrptdata.get_combined_dataset()
print("Train examples:", combined["train"].num_rows)
print("Dev examples:", combined["dev"].num_rows)

train_dataset = combined['train']
dev_dataset = combined['dev']
train_dataset = train_dataset.class_encode_column('label')
dev_dataset   = dev_dataset.class_encode_column('label')

tokenizer = MT5Tokenizer.from_pretrained("google/mt5-base")

# mappings
lang2id = {'eng': 0, 'zho': 1}
fw2id = {'rst': 0, 'erst': 1}
corpus2id = {'gum': 0, 'gcdt': 1}
dir2id = {'1>2': 0, '1<2': 1}


# preprocess data
def preprocess(example):
    text = f"Classify: Arg1: {example['u1']} Arg2: {example['u2']}"
    encoded = tokenizer(text, padding="max_length", truncation=True, max_length=512)
    encoded["label"] = example["label"]

    # add meta info ids
    encoded["language_ids"] = lang2id[example["lang"]]
    encoded["framework_ids"] = fw2id[example["framework"]]
    encoded["corpus_ids"] = corpus2id[example["corpus"]]
    encoded["direction_ids"] = dir2id[example["direction"]]

    return encoded

# will use batch when dataset size scales up
train_tokenized = train_dataset.map(preprocess, batched=False)
dev_tokenized   = dev_dataset.map(preprocess, batched=False)

train_tokenized.set_format('torch', columns=[
    "input_ids", "attention_mask", "label",
    "language_ids", "framework_ids", "corpus_ids", "direction_ids"
])
dev_tokenized.set_format('torch', columns=[
    "input_ids", "attention_mask", "label",
    "language_ids", "framework_ids", "corpus_ids", "direction_ids"
])

Found the following datasets in the data directory:
Train examples: 4527
Dev examples: 4714


Casting to class labels:   0%|          | 0/4527 [00:00<?, ? examples/s]

Casting to class labels:   0%|          | 0/4714 [00:00<?, ? examples/s]

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'T5Tokenizer'. 
The class this function is called from is 'MT5Tokenizer'.


Map:   0%|          | 0/4527 [00:00<?, ? examples/s]

Map:   0%|          | 0/4714 [00:00<?, ? examples/s]

In [11]:
class MT5Classifier(nn.Module):
    def __init__(self, num_labels, num_languages=2, num_frameworks=2, num_corpora=2, lang_emb_dim=None):
        super().__init__()
        self.encoder = MT5ForConditionalGeneration.from_pretrained("google/mt5-small").get_encoder()
        hidden_size = self.encoder.config.d_model

        # embedding layer for language, corpus, direction, framework
        self.language_embedding = nn.Embedding(num_languages, hidden_size)
        self.framework_embedding = nn.Embedding(num_frameworks, hidden_size)
        self.corpus_embedding = nn.Embedding(num_corpora, hidden_size)
        self.direction_embedding = nn.Embedding(2, hidden_size)  # 0 = 1>2, 1 = 1<2

        # classifier head
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.loss_fct = nn.CrossEntropyLoss()


    def forward(self, input_ids, attention_mask,
            language_ids=None, framework_ids=None, corpus_ids=None, direction_ids=None,
            labels=None):

        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = outputs.last_hidden_state.mean(dim=1)  # shape: [batch_size, hidden_size]

        # add metadata embeddings
        if language_ids is not None:
          pooled += self.language_embedding(language_ids)

        if framework_ids is not None:
          pooled += self.framework_embedding(framework_ids)

        if corpus_ids is not None:
          pooled += self.corpus_embedding(corpus_ids)

        if direction_ids is not None:
          pooled += self.direction_embedding(direction_ids)

        logits = self.classifier(pooled)
        loss = self.loss_fct(logits, labels) if labels is not None else None
        return {"loss": loss, "logits": logits}

In [12]:
num_labels = train_tokenized.features["label"].num_classes
num_languages = len(lang2id)
num_frameworks = len(fw2id)
num_corpora = len(corpus2id)

model = MT5Classifier(
    num_labels=num_labels,
    num_languages=num_languages,
    num_frameworks=num_frameworks,
    num_corpora=num_corpora,
    lang_emb_dim=None
)

config.json:   0%|          | 0.00/553 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.20G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

In [13]:
# === training setup ===
use_cuda = True
device = torch.device("cuda" if torch.cuda.is_available() and use_cuda else "cpu")

training_args = TrainingArguments(
    output_dir="/content/drive/MyDrive/mt5_embedding/mt5_classifier_embedding.results",
    overwrite_output_dir=False,
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=2,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    auto_find_batch_size=True,
)


data_collator = DataCollatorWithPadding(tokenizer)

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=1)
    acc = accuracy_score(labels, preds)
    f1 = f1_score(labels, preds, average="weighted")

    # get label names from the feature
    label_names = train_tokenized.features["label"].names

    report = classification_report(labels, preds, target_names=label_names)
    print("\n=== Classification Report ===")
    print(report)
    return {"accuracy": acc, "f1": f1}


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized,
    eval_dataset=dev_tokenized,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

batch = next(iter(trainer.get_train_dataloader()))
print("Batch input_ids shape:", batch["input_ids"].shape)

  trainer = Trainer(


Batch input_ids shape: torch.Size([2, 512])


In [14]:
trainer.train()
trainer.evaluate()



<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mjw2175[0m ([33mjw2175-georgetown-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch,Training Loss,Validation Loss,Accuracy,F1
1,2.2867,2.369758,0.258379,0.153531
2,2.2251,2.336187,0.263258,0.163887



=== Classification Report ===
               precision    recall  f1-score   support

  alternation       0.00      0.00      0.00        36
  attribution       0.09      0.01      0.01       318
       causal       0.00      0.00      0.00       238
      comment       0.00      0.00      0.00       180
   concession       0.00      0.00      0.00       189
    condition       0.00      0.00      0.00        93
  conjunction       0.43      0.30      0.36       777
     contrast       0.00      0.00      0.00       199
  elaboration       0.26      0.80      0.39       879
  explanation       0.00      0.00      0.00       382
        frame       0.00      0.00      0.00       235
         mode       0.00      0.00      0.00       104
 organization       0.20      0.87      0.32       315
      purpose       0.00      0.00      0.00       161
        query       0.00      0.00      0.00        79
reformulation       0.00      0.00      0.00       156
     temporal       0.00      0.0

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



=== Classification Report ===
               precision    recall  f1-score   support

  alternation       0.00      0.00      0.00        36
  attribution       0.20      0.13      0.16       318
       causal       0.00      0.00      0.00       238
      comment       0.00      0.00      0.00       180
   concession       0.00      0.00      0.00       189
    condition       0.00      0.00      0.00        93
  conjunction       0.44      0.29      0.35       777
     contrast       0.00      0.00      0.00       199
  elaboration       0.26      0.83      0.39       879
  explanation       0.00      0.00      0.00       382
        frame       0.00      0.00      0.00       235
         mode       0.00      0.00      0.00       104
 organization       0.21      0.78      0.33       315
      purpose       0.00      0.00      0.00       161
        query       0.00      0.00      0.00        79
reformulation       0.00      0.00      0.00       156
     temporal       0.00      0.0

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



=== Classification Report ===
               precision    recall  f1-score   support

  alternation       0.00      0.00      0.00        36
  attribution       0.20      0.13      0.16       318
       causal       0.00      0.00      0.00       238
      comment       0.00      0.00      0.00       180
   concession       0.00      0.00      0.00       189
    condition       0.00      0.00      0.00        93
  conjunction       0.44      0.29      0.35       777
     contrast       0.00      0.00      0.00       199
  elaboration       0.26      0.83      0.39       879
  explanation       0.00      0.00      0.00       382
        frame       0.00      0.00      0.00       235
         mode       0.00      0.00      0.00       104
 organization       0.21      0.78      0.33       315
      purpose       0.00      0.00      0.00       161
        query       0.00      0.00      0.00        79
reformulation       0.00      0.00      0.00       156
     temporal       0.00      0.0

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


{'eval_loss': 2.3361873626708984,
 'eval_accuracy': 0.2632583792957149,
 'eval_f1': 0.16388686296075097,
 'eval_runtime': 59.9086,
 'eval_samples_per_second': 78.686,
 'eval_steps_per_second': 39.343,
 'epoch': 2.0}

In [None]:
# === error analysis ===
# get predictions on dev set
pred_out = trainer.predict(dev_tokenized)
logits = pred_out.predictions
labels = pred_out.label_ids

# convert logits to label IDs
preds = np.argmax(logits, axis=1)

# softmax confidence score
probs = torch.softmax(torch.tensor(logits), dim=1).numpy()

# extract missclassified examples
dev_texts = dev_tokenized
formatted_texts = [
    f"Arg1: {u1} | Arg2: {u2}"
    for u1, u2 in zip(dev_texts["u1"], dev_texts["u2"])
]

mis = [
    (text, label_encoder.classes_[true], label_encoder.classes_[pred], probs[i][pred])
    for i, (text, true, pred) in enumerate(zip(formatted_texts, labels, preds))
    if true != pred
]

from google.colab import drive
drive.mount('/content/drive', force_remount=True)

out_path = '/content/drive/MyDrive/mt5_vanilla/misclassifications_metainfo.csv'

# log the results to a csv file
with open(out_path, 'w', newline='', encoding='utf-8') as f:
    writer = csv.writer(f)
    writer.writerow(["text", "true_label", "pred_label", "confidence"])
    writer.writerows(mis)

print(f"✅ Saved {len(mis)} misclassified examples to:\n{out_path}")