# Event and argument extraction

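Setup and imports for running on Google Colab. The cell below is wrapped in a string literal so it does nothing by default; remove the quotes to enable it on Colab.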

In [None]:
"""
!pip install transformers cache_decorator pytorch_lightning

from google.colab import drive
drive.mount('/content/drive')

from drive.MyDrive.historical_events.irproject.models import (
    ArgumentModelWrapper, EventModel, JointModel, RAMSDataModule,
    RAMSArgumentDataModule
)
"""

In [1]:
import os

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from transformers import (
    BartModel, BartTokenizer,
    BertModel, BertTokenizer, BertTokenizerFast
)

from irproject.historical_events import (
    get_rams_data_dict, load_rams_data, sanity_check_preprocessed_data,
    evaluate_arguments_results, evaluate_event_results
)
from irproject.models import (
    ArgumentModelWrapper, EventModel, EventBertModel,
    EventBiLSTMModel, JointModel, 
    RAMSDataModule, RAMSArgumentDataModule,
    EventGenModelWrapper, RAMSEventGenDataModule
)

# Fix the global random seed for reproducible runs; workers=True also seeds
# the dataloader worker processes. The call returns the seed (displayed as 42).
seed_everything(seed=42, workers=True)

Global seed set to 42


42

## RAMS dataset

Load and preprocess the RAMS dataset.

In [None]:
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
for split in ["train", "dev", "test"]:
    docs, dicts = load_rams_data(split=split)
    rams_dict = get_rams_data_dict(
        docs,
        tokenizer,
        split=split,
        map_dicts=dicts,
        span_max_length=3 # max span length for events is 3
    )

Check that the preprocessed RAMS dataset does not contain errors.

In [None]:
# Build the RAMS data module and run the sanity check on each of its splits.
# Note: these checks use the split name "valid", while load_rams_data above
# uses "dev".
dm = RAMSDataModule(batch_size=1)
dm.prepare_data()
dm.setup()

for split_name in ("train", "valid", "test"):
    sanity_check_preprocessed_data(split_name, dm=dm)

## Models training and evaluation

### Event extraction model

In [None]:
# Keep the two EventModel checkpoints with the highest validation F1,
# checked once per epoch.
event_checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/historical_events/event",
    monitor="valid_f1",
    mode="max",
    save_top_k=2,
    every_n_epochs=1,
)

In [None]:
# Train the BERT-based EventModel (num_events=140) on RAMS, logging to
# tb_logs/event and checkpointing through event_checkpoint_callback.
dm = RAMSDataModule(batch_size=2, num_workers=0)
bert = BertModel.from_pretrained("bert-base-cased")
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = EventModel(
    bert=bert, tokenizer=tokenizer, bart_tokenizer=bart_tokenizer, num_events=140
)
logger = TensorBoardLogger("tb_logs", name="event")
trainer = Trainer(
    callbacks=[event_checkpoint_callback],
    logger=logger,
    deterministic=True,  # reproducible together with seed_everything
    gpus=1,
    precision=16,  # 16-bit mixed precision
    gradient_clip_val=1,
)
trainer.fit(model, dm)

In [None]:
# Evaluate the selected EventModel checkpoint on the RAMS test split.
event_ckpt_path = "checkpoints/historical_events/event/epoch=9-step=36649.ckpt"
dm = RAMSDataModule(batch_size=1, num_workers=0)
bert = BertModel.from_pretrained("bert-base-cased")
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
# load_from_checkpoint is a classmethod: calling it on the class avoids
# building an extra EventModel instance that would only be discarded.
model = EventModel.load_from_checkpoint(
    event_ckpt_path,
    bert=bert,
    tokenizer=tokenizer,
    bart_tokenizer=bart_tokenizer,
    num_events=140,
)
logger = TensorBoardLogger("tb_logs", name="event")
# No resume_from_checkpoint: it only concerns resuming trainer.fit; the test
# weights are selected by ckpt_path below.
trainer = Trainer(
    deterministic=True,
    gpus=1,
    precision=16,
    gradient_clip_val=1,
    logger=logger,
    callbacks=[event_checkpoint_callback],
)
trainer.test(
    model,
    dm,
    ckpt_path=event_ckpt_path,
)

### BiLSTM event extraction model

In [None]:
# Keep the two BiLSTM event-model checkpoints with the highest validation F1,
# checked once per epoch.
event_checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/historical_events/event_bilstm",
    monitor="valid_f1",
    mode="max",
    save_top_k=2,
    every_n_epochs=1,
)

In [None]:
# Train the BiLSTM event model on RAMS, logging to tb_logs/event_bilstm.
# freeze_bert=True presumably keeps the BERT weights fixed (TODO confirm).
dm = RAMSDataModule(batch_size=2, num_workers=0)
bert = BertModel.from_pretrained("bert-base-cased")
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = EventBiLSTMModel(
    bert=bert,
    tokenizer=tokenizer,
    bart_tokenizer=bart_tokenizer,
    num_events=140,
    freeze_bert=True,
    dropout_rate=0.5,
)
logger = TensorBoardLogger("tb_logs", name="event_bilstm")
trainer = Trainer(
    callbacks=[event_checkpoint_callback],
    logger=logger,
    deterministic=True,
    gpus=1,
    precision=16,
    gradient_clip_val=1,
)
trainer.fit(model, dm)

In [None]:
# Evaluate the selected BiLSTM event-model checkpoint on the RAMS test split.
bilstm_ckpt_path = (
    "checkpoints/historical_events/event_bilstm/epoch=2-step=10994-v1.ckpt"
)
dm = RAMSDataModule(batch_size=1, num_workers=0)
bert = BertModel.from_pretrained("bert-base-cased")
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
# load_from_checkpoint is a classmethod; call it on the class directly.
model = EventBiLSTMModel.load_from_checkpoint(
    bilstm_ckpt_path,
    bert=bert,
    tokenizer=tokenizer,
    bart_tokenizer=bart_tokenizer,
    num_events=140,
    freeze_bert=True,
    dropout_rate=0.5,
)
# Same experiment name as the BiLSTM training logger, so test metrics are not
# mixed into the plain EventModel's "event" logs.
logger = TensorBoardLogger("tb_logs", name="event_bilstm")
# No resume_from_checkpoint: the test weights are selected by ckpt_path.
trainer = Trainer(
    deterministic=True,
    gpus=1,
    precision=16,
    gradient_clip_val=1,
    logger=logger,
    callbacks=[event_checkpoint_callback],
)
trainer.test(
    model,
    dm,
    ckpt_path=bilstm_ckpt_path,
)

### BERT event extraction model

In [None]:
# Keep the two BERT event-model checkpoints with the highest validation F1,
# checked once per epoch.
event_checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/historical_events/event_bert",
    monitor="valid_f1",
    mode="max",
    save_top_k=2,
    every_n_epochs=1,
)

In [None]:
# Train the EventBertModel on RAMS, logging to tb_logs/event_bert.
dm = RAMSDataModule(batch_size=2, num_workers=0)
bert = BertModel.from_pretrained("bert-base-cased")
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = EventBertModel(
    bert=bert, tokenizer=tokenizer, bart_tokenizer=bart_tokenizer, num_events=140
)
logger = TensorBoardLogger("tb_logs", name="event_bert")
trainer = Trainer(
    callbacks=[event_checkpoint_callback],
    logger=logger,
    deterministic=True,
    gpus=1,
    precision=16,
    gradient_clip_val=1,
)
trainer.fit(model, dm)

In [None]:
# Evaluate the selected EventBertModel checkpoint on the RAMS test split.
bert_ckpt_path = "checkpoints/historical_events/event_bert/epoch=1-step=7329.ckpt"
dm = RAMSDataModule(batch_size=1, num_workers=0)
bert = BertModel.from_pretrained("bert-base-cased")
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
# load_from_checkpoint is a classmethod; call it on the class directly.
model = EventBertModel.load_from_checkpoint(
    bert_ckpt_path,
    bert=bert,
    tokenizer=tokenizer,
    bart_tokenizer=bart_tokenizer,
    num_events=140,
)
logger = TensorBoardLogger("tb_logs", name="event_bert")
# No resume_from_checkpoint: the test weights are selected by ckpt_path.
trainer = Trainer(
    deterministic=True,
    gpus=1,
    precision=16,
    gradient_clip_val=1,
    logger=logger,
    callbacks=[event_checkpoint_callback],
)
trainer.test(
    model,
    dm,
    ckpt_path=bert_ckpt_path,
)

### Generative event extraction model

In [None]:
# Keep the two generative event-model checkpoints with the lowest validation
# loss, checked once per epoch.
event_checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/historical_events/event_gen",
    monitor="valid_loss",
    mode="min",
    save_top_k=2,
    every_n_epochs=1,
)

In [None]:
# Train (or resume training) the generative BART event model, logging to
# tb_logs/event_gen.
event_gen_resume_ckpt = "checkpoints/historical_events/event_gen/epoch=0-step=7328.ckpt"
dm = RAMSEventGenDataModule(
    batch_size=1,
    num_workers=0,
    # On Colab: data_dir="/content/drive/MyDrive/historical_events/datasets/rams",
    pin_memory=False,
)
bart = BartModel.from_pretrained("facebook/bart-base")
bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = EventGenModelWrapper(
    bart=bart,
    bart_tokenizer=bart_tokenizer,
)
logger = TensorBoardLogger("tb_logs", name="event_gen")
trainer = Trainer(
    deterministic=True,
    gpus=1,
    precision=16,
    gradient_clip_val=1,
    logger=logger,
    callbacks=[event_checkpoint_callback],
    # Resume only when the checkpoint is actually present; a missing path would
    # otherwise abort a fresh training run. None means "start from scratch".
    resume_from_checkpoint=(
        event_gen_resume_ckpt if os.path.exists(event_gen_resume_ckpt) else None
    ),
)
trainer.fit(model, dm)

In [None]:
# Evaluate the selected generative event-model checkpoint on the RAMS test split.
event_gen_ckpt_path = "checkpoints/historical_events/event_gen/epoch=1-step=14657.ckpt"
dm = RAMSEventGenDataModule(
    batch_size=1,
    num_workers=0,
    # On Colab: data_dir="/content/drive/MyDrive/historical_events/datasets/rams",
    pin_memory=False,
)
bart = BartModel.from_pretrained("facebook/bart-base")
bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
# load_from_checkpoint is a classmethod; call it on the class directly.
model = EventGenModelWrapper.load_from_checkpoint(
    event_gen_ckpt_path,
    bart=bart,
    bart_tokenizer=bart_tokenizer,
)
logger = TensorBoardLogger("tb_logs", name="event_gen")
# No resume_from_checkpoint: the test weights are selected by ckpt_path.
trainer = Trainer(
    deterministic=True,
    gpus=1,
    precision=16,
    gradient_clip_val=1,
    logger=logger,
    callbacks=[event_checkpoint_callback],
)
trainer.test(
    model,
    dm,
    ckpt_path=event_gen_ckpt_path,
)

In [2]:
# Score event extraction with the project helper. Its inputs are not visible
# here; presumably it reads the predictions written by the test runs above
# (TODO confirm).
evaluate_event_results()

HBox(children=(HTML(value='Evaluating prediction'), FloatProgress(value=0.0, max=871.0), HTML(value='')))

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


### Argument extraction model

In [None]:
# Keep the two argument-model checkpoints with the lowest validation loss,
# checked once per epoch.
# On Colab: dirpath="/content/drive/MyDrive/historical_events/checkpoints/argument"
argument_checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/historical_events/argument",
    monitor="valid_loss",
    mode="min",
    save_top_k=2,
    every_n_epochs=1,
)

In [None]:
# Train (or resume training) the generative BART argument model, logging to
# tb_logs/argument.
argument_resume_ckpt = "checkpoints/historical_events/argument/epoch=0-step=7328.ckpt"
dm = RAMSArgumentDataModule(
    batch_size=1,
    num_workers=0,
    # On Colab: data_dir="/content/drive/MyDrive/historical_events/datasets/rams",
    pin_memory=False,
)
bart = BartModel.from_pretrained("facebook/bart-base")
bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = ArgumentModelWrapper(
    bart=bart,
    bart_tokenizer=bart_tokenizer,
)
logger = TensorBoardLogger("tb_logs", name="argument")
trainer = Trainer(
    deterministic=True,
    gpus=1,
    precision=16,
    gradient_clip_val=1,
    logger=logger,
    callbacks=[argument_checkpoint_callback],
    # Resume only when the checkpoint is actually present; a missing path would
    # otherwise abort a fresh training run. None means "start from scratch".
    resume_from_checkpoint=(
        argument_resume_ckpt if os.path.exists(argument_resume_ckpt) else None
    ),
)
trainer.fit(model, dm)

In [None]:
# Evaluate the selected argument-model checkpoint on the RAMS test split.
argument_ckpt_path = "checkpoints/historical_events/argument/epoch=2-step=21986.ckpt"
dm = RAMSArgumentDataModule(
    batch_size=1,
    num_workers=0,
    # On Colab: data_dir="/content/drive/MyDrive/historical_events/datasets/rams",
    pin_memory=False,
)
bart = BartModel.from_pretrained("facebook/bart-base")
bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
# load_from_checkpoint is a classmethod; call it on the class directly.
model = ArgumentModelWrapper.load_from_checkpoint(
    argument_ckpt_path,
    bart=bart,
    bart_tokenizer=bart_tokenizer,
)
logger = TensorBoardLogger("tb_logs", name="argument")
# No resume_from_checkpoint: the test weights are selected by ckpt_path.
trainer = Trainer(
    deterministic=True,
    gpus=1,
    precision=16,
    gradient_clip_val=1,
    logger=logger,
    callbacks=[argument_checkpoint_callback],
)
trainer.test(
    model,
    dm,
    ckpt_path=argument_ckpt_path,
)

In [None]:
# Score argument extraction with the project helper. Its inputs are not visible
# here; presumably it reads the predictions written by the argument test run
# above (TODO confirm).
evaluate_arguments_results()

### Joint (event + argument) model

This model has a high RAM demand; the session may crash while training it.

In [None]:
# Save only the weights of the joint model once per epoch, into the Google
# Drive checkpoint folder (Colab path). No monitor is configured.
joint_checkpoint_callback = ModelCheckpoint(
    dirpath="/content/drive/MyDrive/historical_events/checkpoints/joint",
    save_weights_only=True,
    every_n_epochs=1,
)

In [None]:
# Train the joint event + argument model on Colab (data paths point at Google
# Drive). No logger is passed, so the Trainer's default logger is used.
rams_data_dir = "/content/drive/MyDrive/historical_events/datasets/rams"
dm = RAMSDataModule(
    batch_size=1,
    num_workers=0,
    data_dir=rams_data_dir,
    pin_memory=False,
)
bert = BertModel.from_pretrained("bert-base-cased")
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
bart = BartModel.from_pretrained("facebook/bart-base")
bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
event_model = EventModel(
    bert=bert,
    tokenizer=tokenizer,
    bart_tokenizer=bart_tokenizer,
    num_events=140,
    data_dir=rams_data_dir,
    ontology_dir="/content/drive/MyDrive/historical_events/datasets",
)
argument_model = ArgumentModelWrapper(bart=bart, bart_tokenizer=bart_tokenizer)
model = JointModel(event_model=event_model, argument_model=argument_model)
trainer = Trainer(
    callbacks=[joint_checkpoint_callback],
    deterministic=True,
    gpus=1,
    precision=16,
    gradient_clip_val=1,
    num_sanity_val_steps=0,  # skip the validation sanity pass before training
)
trainer.fit(model, dm)

## Wikipedia dataset test

In [None]:
import json
import torch
import re

from irproject.historical_events import (
    get_event_names_dict, load_ontology,
    load_rams_data, template2tokens
)
from tqdm.notebook import tqdm

In [None]:
# Lookup tables for the inference loop: event name <-> sentence maps,
# the RAMS index maps and the event ontology (which holds the templates).
evt2sent, sent2evt, _, _ = get_event_names_dict()
docs, dicts = load_rams_data()
evt2idx = dicts["evt2idx"]
ontology_dict = load_ontology()

In [None]:
evt_tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
evt_tokenizer.add_tokens([" <arg>", " <trg>", " <evt>"])
bart1 = BartModel.from_pretrained("facebook/bart-base")
bart2 = BartModel.from_pretrained("facebook/bart-base")
bart_tokenizer1 = BartTokenizer.from_pretrained("facebook/bart-base")
bart_tokenizer2 = BartTokenizer.from_pretrained("facebook/bart-base")

In [None]:
# Load the Wikipedia paragraphs to run the pipeline on.
wiki_dataset_path = "datasets/historical_events/wiki_dataset.json"
with open(wiki_dataset_path, encoding="utf-8") as wiki_file:
    data = json.load(wiki_file)

In [None]:
# Restore the fine-tuned event-generation and argument-generation models.
# load_from_checkpoint is a classmethod, so it is called on the class instead
# of on a throwaway instance. Each model gets its own BART backbone
# (bart1 / bart2), presumably so the two sets of weights are not shared.
evt_model = EventGenModelWrapper.load_from_checkpoint(
    "checkpoints/historical_events/event_gen/epoch=1-step=14657.ckpt",
    bart=bart1,
    bart_tokenizer=bart_tokenizer1,
)

arg_model = ArgumentModelWrapper.load_from_checkpoint(
    "checkpoints/historical_events/argument/epoch=2-step=21986.ckpt",
    bart=bart2,
    bart_tokenizer=bart_tokenizer2,
)

In [None]:
evt_model.to("cuda")
arg_model.to("cuda")

In [None]:
evt_template_in = "This document is about <evt>"

res = {
    "results": []
}

for idx, paragraph in tqdm(
    enumerate(data["paragraphs"]),
    total=len(data["paragraphs"]),
    desc="Processing paragraph",
    leave=False
):
    if paragraph["historical"] == 1:
        text = paragraph["clean_content"]

        context = evt_tokenizer.tokenize(
            text,
            add_prefix_space=True
        )

        if context == []:
            continue

        evt_in = evt_tokenizer.encode_plus(
            evt_template_in, 
            context, 
            add_special_tokens=True,
            add_prefix_space=True,
            max_length=424,
            truncation="only_second",
            padding="max_length"
        )

        evt_input_ids = torch.tensor([evt_in["input_ids"]]).to("cuda")

        evt_res = evt_model.model.generate(
            input_ids=evt_input_ids, 
            do_sample=True, 
            top_k=20, 
            top_p=0.95,
            max_length=30, 
            num_return_sequences=1,
            num_beams=1
        )

        predicted_evt_sent = evt_tokenizer.decode(
            evt_res[0], 
            skip_special_tokens=True
        )

        _RE_COMBINE_WHITESPACE = re.compile(r"\s+")
        predicted_evt_sent = _RE_COMBINE_WHITESPACE.sub(" ", predicted_evt_sent).strip()
        predicted_evt_sent = re.sub("This document is about ", "", predicted_evt_sent)

        if predicted_evt_sent in sent2evt.keys():
            predicted_evt = sent2evt[predicted_evt_sent]
            template = ontology_dict[predicted_evt]["template"]
        else:
            print("No event found:", predicted_evt_sent)
            continue
        
        template_in = template2tokens(
            template, evt_tokenizer
        )

        arg_in = evt_tokenizer.encode_plus(
            template_in, 
            context, 
            add_special_tokens=True,
            add_prefix_space=True,
            max_length=424,
            truncation="only_second",
            padding="max_length"
        )

        arg_input_ids = torch.tensor([arg_in["input_ids"]]).to("cuda")

        arg_res = arg_model.model.generate(
            input_ids=arg_input_ids, 
            do_sample=True, 
            top_k=20, 
            top_p=0.95,
            max_length=30, 
            num_return_sequences=1,
            num_beams=1
        )

        arg_sent = evt_tokenizer.decode(
            arg_res[0],
            skip_special_tokens=True
        )

        d = {
            "text": text,
            "predicted_evt": predicted_evt_sent,
            "predicted_args": arg_sent
        }

        res["results"].append(d)

        if idx % 100 == 0:
            with open(
                "results/historical_events/wiki_dataset/predictions.json", 
                "w", 
                encoding="utf-8"
            ) as f:
                json.dump(res, f, indent=4)


with open(
    "results/historical_events/wiki_dataset/predictions.json", 
    "w", 
    encoding="utf-8"
) as f:
    json.dump(res, f, indent=4)