In [None]:
# !nvidia-smi -L

In [None]:
# !ls

In [None]:
# !ls LUKE

In [None]:
# !rm -r LUKE # delete old runs

In this notebook, we are going to fine-tune [`LukeForEntityPairClassification`](https://huggingface.co/transformers/model_doc/luke.html#lukeforentitypairclassification) on a supervised **relation extraction** dataset.

The goal for the model is to predict, given a sentence and the character spans of two entities within the sentence, the relationship between the entities.

* Paper: https://arxiv.org/abs/2010.01057
* Original repository: https://github.com/studio-ousia/luke
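
As a rough illustration of the input format, the sketch below computes character spans for two entity mentions. The sentence, the company names, and the `find_span` helper are hypothetical and not taken from the dataset. Spans are `(start, end)` character offsets with an exclusive end, which is the convention `LukeTokenizer` uses for `entity_spans`.

```python
def find_span(text: str, mention: str) -> tuple[int, int]:
    """Return (start, end) character offsets of the first occurrence of `mention`."""
    start = text.find(mention)
    if start == -1:
        raise ValueError(f"{mention!r} not found in text")
    return start, start + len(mention)  # end offset is exclusive

# Invented example sentence with two invented entities
text = "Acme Corp announced a joint research programme with Globex Inc."
entity_spans = [find_span(text, "Acme Corp"), find_span(text, "Globex Inc")]

for start, end in entity_spans:
    print((start, end), text[start:end])
```

Slicing the text with each span should give back exactly the entity mention, which is a cheap sanity check to run on real data before tokenizing.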

In [None]:
!pip install -q transformers pandas scikit-learn pytorch-lightning wandb unidecode

In [None]:
from LUKE_model import RelationExtractionDataset, LUKE
import wandb
import torch
import pytorch_lightning
import pandas as pd

In [None]:
pytorch_lightning.seed_everything(seed=42)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

## Read in data

Let's download the data, which is hosted on Dropbox.

Each row of the dataframe corresponds to one news article.

In [None]:
df = pd.read_pickle("https://www.dropbox.com/s/rdwg8d76ytqqgdy/training-data-29-03-22.pkl?dl=1")
df

Shuffle the data:

In [None]:
df = df.sample(frac=1)

This is the frequency of each relation type:

In [None]:
df.rels.explode().value_counts()

We don't want to consider the less relevant and less consistently applied labels 'supply', 'exploration', and 'technology transfer'. The cell below would remove them; it is currently commented out, so these labels are kept unless you uncomment it:

In [None]:
# df['rels'] = df.rels.apply(lambda rels: [rel for rel in rels if rel not in ['TechnologyTransfer', 'Supply', 'Exploration']])

In [None]:
rel_frequencies = df.rels.explode().value_counts()
rel_frequencies

Let's encode the relationship classes as multi-hot vectors, with one position per class (a document can carry several relations):

In [None]:
rel_label_names = rel_frequencies.index.to_list()

def label2ids(labels, rel_label_names=rel_label_names):
    return [1 if rel_name in labels else 0 for rel_name in rel_label_names]

df['rel_one_hot'] = df.rels.apply(label2ids)

def ids2labels(ids, rel_label_names=rel_label_names):
    labels = []
    for idx, label in enumerate(rel_label_names):
        if ids[idx] == 1:
            labels.append(label)

    return labels

df.sample(10)
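
To make the encoding concrete, here is a small standalone round trip. The label list is illustrative only, and `encode` and `decode` are hypothetical stand-ins for `label2ids` and `ids2labels`, renamed so that running the block does not overwrite the notebook's own variables.

```python
# Illustrative label names; the real list comes from `rel_frequencies.index`
toy_label_names = ["StrategicAlliance", "JointVenture", "ResearchandDevelopment"]

def encode(labels, label_names=toy_label_names):
    """Multi-hot vector: 1 at every position whose label is present."""
    return [1 if name in labels else 0 for name in label_names]

def decode(ids, label_names=toy_label_names):
    """Inverse of `encode`: recover the label names from a multi-hot vector."""
    return [name for name, flag in zip(label_names, ids) if flag == 1]

vector = encode(["ResearchandDevelopment", "StrategicAlliance"])
print(vector)          # [1, 0, 1]
print(decode(vector))  # ['StrategicAlliance', 'ResearchandDevelopment']
print(encode([]))      # [0, 0, 0]
```

The decoded labels come back in the order of the label list rather than the input order, and a document without any relation maps to the all-zero vector.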

## Define the PyTorch dataset and dataloaders

Next, we define regular PyTorch datasets and corresponding dataloaders. In PyTorch, you define a `Dataset` class that inherits from `torch.utils.data.Dataset` and implements three methods: `__init__` (which initializes the dataset with data), `__len__` (which returns the number of elements in the dataset) and `__getitem__` (which returns a single item from the dataset).

In our case, each item of the dataset consists of a sentence, the spans of 2 entities in the sentence, and a label of the relationship. We use `LukeTokenizer` (available in the Transformers library) to turn these into the inputs expected by the model, which are `input_ids`, `entity_ids`, `attention_mask`, `entity_attention_mask` and `entity_position_ids`.

For more information regarding these inputs, refer to the [docs](https://huggingface.co/transformers/model_doc/luke.html#lukeforentitypairclassification) of `LukeForEntityPairClassification`.
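
The three methods can be sketched as follows. `ToyRelationDataset` is a hypothetical stand-in, not the `RelationExtractionDataset` class from `LUKE_model`, and the column names `document`, `spans` and `rel_one_hot` are assumed from the dataframe used in this notebook. The real class would call `LukeTokenizer` inside `__getitem__`; this sketch returns the raw fields instead.

```python
try:
    from torch.utils.data import Dataset
except ImportError:  # fallback so the sketch still runs without PyTorch installed
    Dataset = object

class ToyRelationDataset(Dataset):
    """Hypothetical map-style dataset over a list of row dictionaries."""

    def __init__(self, rows):
        # rows: list of dicts with keys "document", "spans" and "rel_one_hot"
        self.rows = rows

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        row = self.rows[idx]
        # A real implementation would tokenize here and return tensors
        return {
            "text": row["document"],
            "entity_spans": row["spans"],
            "label": row["rel_one_hot"],
        }

# Invented single-row dataset
toy = ToyRelationDataset([
    {"document": "Acme Corp partners with Globex Inc.",
     "spans": [(0, 9), (24, 34)],
     "rel_one_hot": [1, 0, 0]},
])
print(len(toy), toy[0]["label"])
```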


Let's set our hyperparameters:

In [None]:
MAX_LEN = 128
LEARNING_RATE = 1e-5
BATCH_SIZE = 128
MAX_EPOCHS = 10
THRESHOLDS = torch.full((len(rel_label_names),), 0.5).to(device)  # one threshold per class; edit for class-specific thresholds
WEIGHT_DECAY = 0.01
GRAD_CLIP_VAL = 0  # 0 disables gradient clipping
CLASS_WEIGHTS = torch.tensor([1, 1, 1, 1, 1, 4, 1, 4]).to(device)  # per-class loss weights (two classes upweighted 4x); length must equal len(rel_label_names)

Here we split the data into training (80%), validation (10%) and test (10%) sets, and wrap each split in `RelationExtractionDataset`, which is imported from `LUKE_model`.

In [None]:
from sklearn.model_selection import train_test_split

train_size = 0.8
validation_size = 0.1
test_size = 0.1

# first hold out 20% of the data, then split the hold-out evenly into validation and test sets
train_df, holdout_df = train_test_split(df, test_size=1-train_size, shuffle=True)
val_df, test_df = train_test_split(holdout_df, test_size=test_size/(test_size+validation_size), shuffle=False)

print(f"FULL Dataset: {len(df)}")
print(f"TRAIN Dataset: {len(train_df)}")
print(f"VALIDATION Dataset: {len(val_df)}")
print(f"TEST Dataset: {len(test_df)}")

# define the dataset
train_dataset = RelationExtractionDataset(data=train_df)
val_dataset = RelationExtractionDataset(data=val_df)
test_dataset = RelationExtractionDataset(data=test_df)
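
The two-stage split can be checked with plain arithmetic: the first call holds out roughly 20% of the rows, and the second divides that hold-out in half. The dataset size of 1,000 rows is hypothetical; on other sizes, rounding inside `train_test_split` can shift the counts by a row.

```python
train_size, validation_size, test_size = 0.8, 0.1, 0.1
n_rows = 1000  # hypothetical dataset size

# First split: fraction of rows held out for validation and test together
holdout_fraction = 1 - train_size
# Second split: share of the hold-out that becomes the test set
test_share_of_holdout = test_size / (test_size + validation_size)

n_holdout = round(n_rows * holdout_fraction)
n_test = round(n_holdout * test_share_of_holdout)
n_val = n_holdout - n_test
n_train = n_rows - n_holdout

print(n_train, n_val, n_test)  # 800 100 100
```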

In [None]:
train_dataset[0].keys()

## Define a PyTorch LightningModule

Let's define the model as a PyTorch LightningModule. A `LightningModule` is actually an `nn.Module`, but with some extra functionality.

For more information regarding how to define this, see the [docs](https://pytorch-lightning.readthedocs.io/en/latest/) of PyTorch Lightning.

In [None]:
# import gc
# gc.collect()
# torch.cuda.empty_cache()

In [None]:
if torch.cuda.is_available():  # the CUDA queries below fail on CPU-only machines
    print(f'Free CUDA memory: {(torch.cuda.get_device_properties(0).total_memory-torch.cuda.memory_allocated(0))/1e9:.2f}GB')

In [None]:
model = LUKE(num_labels=len(rel_label_names), lr=LEARNING_RATE, batch_size=BATCH_SIZE, class_weights=CLASS_WEIGHTS,
                 thresholds=THRESHOLDS, weight_decay=WEIGHT_DECAY,
                 datasets={'train_dataset': train_dataset, 'val_dataset': val_dataset, 'test_dataset': test_dataset})

## Train the model

Let's train the model. We use early stopping to avoid overfitting the training dataset, and we log everything to Weights and Biases, which gives us charts of the loss and validation metrics over time.

If you haven't already, create an account on the [website](https://wandb.ai/site), log in from a web browser, and run the cell below:

In [None]:
wandb.login()

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

wandb_logger = WandbLogger(project='LUKE')

# for early stopping, see https://pytorch-lightning.readthedocs.io/en/1.0.0/early_stopping.html?highlight=early%20stopping
early_stop_callback = EarlyStopping(
    monitor='val_f1_macro_avg',
    patience=2,
    strict=False,
    verbose=True,
    mode='max',
    min_delta=0 # any increase counts as an improvement; training stops after `patience` validation checks without one
)

checkpoint_callback = ModelCheckpoint(dirpath='LUKE', save_last=True, verbose=True,
                                      monitor='val_f1_macro_avg', mode='max', every_n_epochs=1)

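# `gpus`, `stochastic_weight_avg` and `auto_lr_find` are Trainer arguments from older
# PyTorch Lightning 1.x releases; newer releases no longer accept them.
# `auto_lr_find=True` only takes effect if `trainer.tune(model)` is called.
trainer = Trainer(gpus=1, logger=wandb_logger, callbacks=[early_stop_callback, checkpoint_callback], 
                  max_epochs=MAX_EPOCHS, precision=16, 
                  stochastic_weight_avg=True, auto_lr_find=True, 
                  benchmark=True, deterministic=True,
                  val_check_interval=0.1, gradient_clip_val=GRAD_CLIP_VAL, fast_dev_run=False)
                    # limit_train_batches=0.25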
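
The interaction of `patience`, `min_delta` and `mode='max'` can be illustrated with a simplified simulation. This is a sketch of the general early-stopping rule, not Lightning's implementation, and the F1 values are invented. Because `val_check_interval=0.1` runs validation ten times per epoch, `patience=2` counts two validation checks, not two epochs.

```python
def checks_until_stop(scores, patience=2, min_delta=0.0):
    """Return the 1-based validation check at which training stops, or None."""
    best = float("-inf")
    bad_checks = 0
    for i, score in enumerate(scores, start=1):
        if score > best + min_delta:  # mode='max': higher is better
            best = score
            bad_checks = 0
        else:
            bad_checks += 1
            if bad_checks >= patience:
                return i
    return None  # the score kept improving often enough to never trigger

val_f1 = [0.41, 0.47, 0.52, 0.52, 0.51, 0.55]  # invented validation scores
print(checks_until_stop(val_f1))  # 5
```

In this made-up run the tie at the fourth check and the dip at the fifth use up the patience, so the later improvement to 0.55 is never reached. A larger `patience` would make training more tolerant of such noise.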

Optionally, find a good learning rate with Lightning's learning-rate finder (the cells below are commented out):

In [None]:
# lr_finder = trainer.tuner.lr_find(model)

In [None]:
# lr_finder.suggestion()

Then set the learning rate to the suggested value:

In [None]:
# model.hparams.lr = lr_finder.suggestion()

Let's train the model:

In [None]:
trainer.fit(model)

In [None]:
trainer.test()

In [None]:
checkpoint_callback.best_model_path

In [None]:
# !ls LUKE

## Evaluation

Instead of calling `trainer.test()`, we can also manually evaluate the model on the entire test set:

In [None]:
outputs = trainer.predict(dataloaders=model.test_dataloader())

In [None]:
# model = LUKE.load_from_checkpoint(checkpoint_path=checkpoint_callback.best_model_path)
# model = model.to(device)

In [None]:
from tqdm.notebook import tqdm
from sklearn import metrics  # used by the per-class report below

predictions_total = []
logits_total = []  # holds sigmoid probabilities, despite the name
labels_total = []

model.to(device)
model.eval()  # disable dropout for evaluation

for batch in tqdm(model.test_dataloader()):
    # separate the labels from the model inputs
    labels = batch.pop("label")

    # move the inputs to the same device as the model
    batch = {k: v.to(device) for k, v in batch.items()}

    # forward pass without gradient tracking
    with torch.no_grad():
        outputs = model.model(**batch)
    probabilities = torch.sigmoid(outputs.logits)
    predictions = (probabilities > THRESHOLDS).float()

    logits_total.extend(probabilities.tolist())
    predictions_total.extend(predictions.tolist())
    labels_total.extend(labels.tolist())
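
The thresholding step converts each logit to a probability with a sigmoid and compares it against a per-class threshold, so every class is decided independently. A dependency-free sketch with invented logits and thresholds:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def predict_labels(logits, thresholds):
    """Multi-label decision: 1 where the class probability exceeds its threshold."""
    probs = [sigmoid(z) for z in logits]
    return [1 if p > t else 0 for p, t in zip(probs, thresholds)]

logits = [2.0, -1.0, 0.3]  # invented model outputs for three classes

print(predict_labels(logits, [0.5, 0.5, 0.5]))  # [1, 0, 1]
print(predict_labels(logits, [0.5, 0.2, 0.7]))  # [1, 1, 0]
```

Lowering a class's threshold trades precision for recall on that class, which is one way to use class-specific values in `THRESHOLDS`.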

In [None]:
next(iter(model.test_dataloader()))

In [None]:
# for idx, label_name in enumerate(rel_label_names):
#     labels = [label[idx] for label in labels_total]
#     predictions = [pred[idx] for pred in predictions_total]
#     logits = [logit[idx] for logit in logits_total]
#     precision = metrics.precision_score(labels, predictions)
#     recall = metrics.recall_score(labels, predictions)
#     f1 = metrics.f1_score(labels, predictions)
#     roc_auc = metrics.roc_auc_score(labels, logits)
#     print(label_name)
#     print(f'Precision: {precision:.3f}, Recall: {recall:.3f}')
#     print(f'F1-Score: {f1:.3f}, ROC-AUC: {roc_auc:.3f}\n')
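
For reference, the per-class report can also be computed without scikit-learn. The sketch below uses plain Python in place of `sklearn.metrics`, works on invented labels and predictions, and returns 0.0 where a score is undefined (an assumption, since the treatment of empty classes is a choice). The final value is the macro-averaged F1, which is presumably what `val_f1_macro_avg` tracks.

```python
def per_class_scores(labels, predictions):
    """Precision, recall and F1 for each column of two multi-hot matrices."""
    results = []
    for c in range(len(labels[0])):
        pairs = [(y[c], p[c]) for y, p in zip(labels, predictions)]
        tp = sum(1 for y, p in pairs if y == 1 and p == 1)
        fp = sum(1 for y, p in pairs if y == 0 and p == 1)
        fn = sum(1 for y, p in pairs if y == 1 and p == 0)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        results.append((precision, recall, f1))
    return results

labels = [[1, 0], [1, 1], [0, 1], [1, 0]]       # invented ground truth
predictions = [[1, 0], [0, 1], [0, 1], [1, 1]]  # invented predictions

scores = per_class_scores(labels, predictions)
macro_f1 = sum(f1 for _, _, f1 in scores) / len(scores)

for idx, (precision, recall, f1) in enumerate(scores):
    print(f"class {idx}: precision={precision:.3f} recall={recall:.3f} f1={f1:.3f}")
print(f"macro F1: {macro_f1:.3f}")
```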

## Inference

Here we run the trained model on a single document sampled from the test set.

In [None]:
# NOTE: `tokenizer` is not defined in this notebook; define it before running this cell,
# using the same `LukeTokenizer` checkpoint that was used to encode the training data.
test_doc = test_df[test_df.rels.apply(lambda rels: 'ResearchandDevelopment' in rels)].sample()

text = test_doc.document.iloc[0]
spans = test_doc.spans.iloc[0]  # character-based entity spans

inputs = tokenizer(text, entity_spans=spans, padding="max_length", truncation=True, return_tensors="pt",
                   max_length=MAX_LEN)
inputs = inputs.to(device)

model.to(device)
model.model.eval()

with torch.no_grad():  # no gradients needed for inference
    outputs = model.model(**inputs)
logits = outputs.logits
predicted_classes = (torch.sigmoid(logits)>THRESHOLDS).float()
print("Sentence:", text)
print("Tokens:", tokenizer.decode(inputs["input_ids"][0]))
print("Ground truth labels:", ids2labels(test_doc.rel_one_hot.iloc[0]))
print("Predicted labels:", ids2labels(predicted_classes.squeeze().tolist()))
print("Confidence:", torch.sigmoid(logits).squeeze().tolist())

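To deal with class imbalance, we weight the loss function in proportion to the inverse relative frequency of each class, normalized by the frequency of the most common class. For these weights to take effect, this section has to run before the model is instantiated, since `CLASS_WEIGHTS` is passed to `LUKE` at construction time. The resulting weights are as follows: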

In [None]:
CLASS_WEIGHTS = 1/(rel_frequencies/rel_frequencies.iloc[0])
CLASS_WEIGHTS

The strategic alliance label is 6 times more frequent than the R&D label, so we weight the R&D class 6 times more heavily in the loss function.

We can strengthen the reweighting further by raising these weights to a power:

In [None]:
IMBALANCE_REWEIGHTING = 1
CLASS_WEIGHTS = CLASS_WEIGHTS**IMBALANCE_REWEIGHTING
CLASS_WEIGHTS

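With `IMBALANCE_REWEIGHTING = 1`, as set above, the weights are unchanged. A larger exponent widens the gap: with a suitable value, examples with the R&D label are weighted 10 times more heavily than those with the SA label.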

In [None]:
CLASS_WEIGHTS = torch.tensor(CLASS_WEIGHTS.to_list()).to(device)
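
The weighting scheme amounts to `(most common count / class count) ** exponent`. The sketch below applies it to invented counts, chosen so that the most common class is 6 times more frequent than the rarest. The `class_weights` helper and the `JointVenture` label are hypothetical.

```python
# Invented label counts, not the real dataset frequencies
counts = {"StrategicAlliance": 600, "JointVenture": 300, "ResearchandDevelopment": 100}

def class_weights(counts, exponent=1.0):
    """Inverse relative frequency, normalized by the most common class."""
    most_common = max(counts.values())
    return {name: (most_common / n) ** exponent for name, n in counts.items()}

print(class_weights(counts))
# {'StrategicAlliance': 1.0, 'JointVenture': 2.0, 'ResearchandDevelopment': 6.0}

print(class_weights(counts, exponent=2.0))
# {'StrategicAlliance': 1.0, 'JointVenture': 4.0, 'ResearchandDevelopment': 36.0}
```

An exponent of 1 reproduces plain inverse-frequency weights, values above 1 push more of the loss onto rare classes, and values between 0 and 1 soften the reweighting.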

## Inference on LexisNexis news

In [None]:
model = LUKE.load_from_checkpoint(checkpoint_path='LUKE/epoch=1-step=2316.ckpt')

model = model.to(device)

In [None]:
df = pd.read_pickle('lexis_2017_with_org_preds_spacy.pkl')
df

In [None]:
# NOTE: `tokenizer` is not defined in this notebook; define it before running this cell,
# using the same `LukeTokenizer` checkpoint that was used to encode the training data.
model.model.eval()

doc = df.sample()

text = doc.document.iloc[0]
spans = doc.spans.iloc[0]  # character-based entity spans

inputs = tokenizer(text, entity_spans=spans, padding="max_length", truncation=True, return_tensors="pt",
                   max_length=MAX_LEN)
inputs = inputs.to(device)

with torch.no_grad():  # no gradients needed for inference
    outputs = model.model(**inputs)
logits = outputs.logits
predicted_classes = (torch.sigmoid(logits)>THRESHOLDS).float()
print("Sentence:", text)
print("Companies: ", doc.firms.iloc[0])
print("Tokens:", tokenizer.decode(inputs["input_ids"][0]))
print("Predicted labels:", ids2labels(predicted_classes.squeeze().tolist()))
print("Confidence:", torch.sigmoid(logits).squeeze().tolist())