In [None]:
!pip install transformers datasets evaluate sacrebleu

In [None]:
!wget https://pmb.let.rug.nl/releases/exp_data_4.0.0.zip
!unzip exp_data_4.0.0.zip

In [None]:
!git clone https://github.com/WPoelman/ud-boxer.git ud_boxer_repo
!pip install -r ud_boxer_repo/requirements/requirements.txt

In [None]:
# https://github.com/huggingface/transformers/blob/main/examples/pytorch/translation/run_translation.py

In [6]:
import sys
sys.path.append('/content/ud_boxer_repo')

In [22]:
import re
import os
import random
from collections import defaultdict

import numpy as np
import torch
import evaluate
from datasets import Dataset, DatasetDict
from transformers import (
    T5ForConditionalGeneration, AutoTokenizer,
    DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, 
    Seq2SeqTrainer
)

from ud_boxer_repo.ud_boxer.sbn import SBNGraph
from ud_boxer_repo.ud_boxer.helpers import smatch_score

In [10]:
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [12]:
space_re = re.compile('\s+')

parts = ['train', 'dev', 'test']
data_dir = 'exp_data_4.0.0'
data = {p: [] for p in parts}

for lang in os.listdir(data_dir):
  gold_folder = os.path.join(data_dir, lang, 'gold')
  for part in parts:
    with open(os.path.join(gold_folder, f'{part}.txt.raw')) as f:
      raw_sents = f.read().strip().split('\n')
    
    with open(os.path.join(gold_folder, f'{part}.txt.sbn')) as f:
      all_drg = f.read().strip()
      all_drg = all_drg.split('\n\n')
      all_drg = [
        '\n'.join(
          space_re.sub(' ', line.split('%')[0]).strip()
          for line in raw_drg.split('\n') 
          if not line.startswith('%%%')
        ).strip()
        for raw_drg in all_drg
      ]
    
    assert len(raw_sents) == len(all_drg)

    part_data = [
        {'lang': lang, 'text': raw_text, 'drg': drg}
        for raw_text, drg in zip(raw_sents, all_drg)
    ]
    data[part].extend(part_data)

In [15]:
stats = defaultdict(int)
for part, part_data in data.items():
  for info in part_data:
    stats[info['lang']] += 1

for l, num in stats.items():
  print(l, num)

nl 1467
it 1686
de 2844
en 9885


In [17]:
data['train'][0]

{'lang': 'nl',
 'text': 'Ze zei dat ze gelukkig was.',
 'drg': 'female.n.02\nsay.v.01 Agent -1 Proposition +1 Time +1\ntime.n.08 TPR now\nATTRIBUTION -1\nfemale.n.02 EQU -3\nhappy.a.01 Time -2 Experiencer -1\ntime.n.08 TPR now EQU -3'}

In [56]:
ds = {part: Dataset.from_list(part_data) for part, part_data in data.items()}
ds = DatasetDict(ds)
ds = ds.shuffle(seed=SEED)

In [29]:
ds

DatasetDict({
    train: Dataset({
        features: ['lang', 'text', 'drg'],
        num_rows: 10630
    })
    dev: Dataset({
        features: ['lang', 'text', 'drg'],
        num_rows: 2705
    })
    test: Dataset({
        features: ['lang', 'text', 'drg'],
        num_rows: 2547
    })
})

In [30]:
model = T5ForConditionalGeneration.from_pretrained("google/byt5-small")
tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")

Downloading:   0%|          | 0.00/698 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.20G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.59k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.50k [00:00<?, ?B/s]

In [31]:
max_len = 512

def process(examples):
  model_inputs = tokenizer(examples['text'], max_length=max_len, truncation=True)
  labels = tokenizer(examples['drg'], max_length=max_len, truncation=True)
  model_inputs['labels'] = labels['input_ids']
  return model_inputs

ds = ds.map(process, batched=True)

  0%|          | 0/11 [00:00<?, ?ba/s]

  0%|          | 0/3 [00:00<?, ?ba/s]

  0%|          | 0/3 [00:00<?, ?ba/s]

In [32]:
metric = evaluate.load("chrf")

def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [[label.strip()] for label in labels]
    return preds, labels

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"chrf": result["score"]}

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}

    return result

Downloading builder script:   0%|          | 0.00/9.01k [00:00<?, ?B/s]

In [33]:
label_pad_token_id = -100
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=label_pad_token_id
)

In [34]:
training_args = Seq2SeqTrainingArguments(
    output_dir='results',
    report_to='none',
    evaluation_strategy='epoch',
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    save_strategy='epoch',
    save_total_limit=3,
    predict_with_generate=True,
    generation_max_length=max_len,
    generation_num_beams=3
)

trainer = Seq2SeqTrainer(
  model=model,
  args=training_args,
  train_dataset=ds['train'],
  eval_dataset=ds['dev'],
  tokenizer=tokenizer,
  data_collator=data_collator,
  compute_metrics=compute_metrics
  # callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
)

max_steps is given, it will override any value given in num_train_epochs


In [None]:
trainer.train()

In [57]:
test_example = ds['test'][0]
print(test_example['text'])
print(test_example['drg'])

Tom sighed.
male.n.02 Name "Tom"
sigh.v.01 Agent -1 Time +1
time.n.08 TPR now


In [None]:
outputs = model.generate(
    torch.tensor([test_example['input_ids']]).cuda(), 
    max_new_tokens=512,
    num_beams=3
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

In [None]:
preds = trainer.predict(ds['test'])
decoded_preds = tokenizer.batch_decode(preds[0], skip_special_tokens=True)

In [51]:
y_true = defaultdict(list)
y_pred = defaultdict(list)

for y_true_drg, y_pred_drg in zip(ds['test'], decoded_preds):
  lang = y_true_drg['lang']

  try:
    y_true_penman = SBNGraph().from_string(y_true_drg['drg']).to_penman_string()
  except Exception as e:
    print('error in GS', [y_true_drg], e, '', sep='\n')
    continue

  try:
    y_pred_penman = SBNGraph().from_string(y_pred_drg).to_penman_string()
  except Exception as e:
    # print('error in pred', [y_pred_drg], e, '', sep='\n')
    continue
  
  y_true[lang].append(y_true_penman)
  y_pred[lang].append(y_pred_penman)

In [52]:
total_scores = defaultdict(lambda: defaultdict(list))

for lang in y_true.keys():
  y_true_lang = y_true[lang]
  y_pred_lang = y_pred[lang]
  for yt, yp in zip(y_true_lang, y_pred_lang):
    with open('tempgold', "w") as gold_f:
      gold_f.write(yt)

    with open('temppred', "w") as pred_f:
        pred_f.write(yp)

    scores = smatch_score('/content/tempgold', '/content/temppred')
    for k, v in scores.items():
      total_scores[lang][k].append(v)

In [None]:
final_scores = {
    lang: {k: sum(v) / len(v) for k, v in lang_scores.items()}
    for lang, lang_scores in total_scores.items()
}
final_scores