In [1]:
!nvidia-smi -L

GPU 0: NVIDIA GeForce RTX 3090 (UUID: GPU-b104c15c-7745-a8d7-39fd-c7570e2f41a5)


In [2]:
!ls

Firm_relation_extraction_LUKE.ipynb  onstart.log  wandb
luke-first-run			     onstart.sh


In [4]:
!rm -r luke-first-run

rm: cannot remove 'luke-first-run': No such file or directory


In this notebook, we are going to fine-tune [`LukeForEntityPairClassification`](https://huggingface.co/transformers/model_doc/luke.html#lukeforentitypairclassification) on a supervised **relation extraction** dataset.

The goal for the model is to predict, given a sentence and the character spans of two entities within the sentence, the relationship between the entities.

* Paper: https://arxiv.org/abs/2010.01057
* Original repository: https://github.com/studio-ousia/luke

In [5]:
!pip install -q transformers 



In [6]:
!pip install -q pytorch-lightning wandb



In [7]:
!pip install -q pandas
!pip install -q sklearn



## Read in data

Let's download the data from the web, hosted on Dropbox.

Each row in the dataframe consists of a news article. 

In [8]:
import pandas as pd

df = pd.read_pickle("https://www.dropbox.com/s/j6jvmzpedgf7jxg/kb_6class_unbalanced_neg_examples.pkl?dl=1")
df.head()

Unnamed: 0,document,entities,entity_spans,relation
0,Salix Pharmaceuticals Ltd (SP) and Pharmatel P...,"[Salix Pharmaceuticals Ltd, Pharmatel Pty Ltd]","[(0, 25), (35, 52)]","[Marketing, StrategicAlliance]"
2,Praxair Inc and Phillips 66 Co formed a strate...,"[Praxair Inc, Phillips 66 Co]","[(0, 11), (16, 30)]",[StrategicAlliance]
3,Apple Computer Inc (ACI) and Samsung Electroni...,"[Apple Computer Inc, Samsung Electronics Co Ltd]","[(0, 18), (29, 55)]",[StrategicAlliance]
4,Robert Koch Institute granted Hoechst Schering...,"[Robert Koch Institute, Hoechst Schering AgrEv...","[(0, 21), (30, 58)]","[Licensing, ResearchandDevelopment, StrategicA..."
5,Thai Advance Innovation Co Ltd and Ai And Robo...,"[Thai Advance Innovation Co, Ai And Robotics V...","[(0, 26), (35, 62)]",[JointVenture]


This is the frequency of our relation types:



In [9]:
df.relation.explode().value_counts()

StrategicAlliance         34707
JointVenture               9074
Marketing                  7267
Licensing                  5170
Manufacturing              5125
ResearchandDevelopment     4936
Name: relation, dtype: int64

To deal with the imbalance, we use weights in the loss function proportional to the inverse relative frequency of each class. The weights are as:

In [10]:
rel_label_weights = 1/(2*df.relation.explode().value_counts()/len(df))
rel_label_weights

StrategicAlliance         0.671579
JointVenture              2.568713
Marketing                 3.207445
Licensing                 4.508414
Manufacturing             4.548000
ResearchandDevelopment    4.722143
Name: relation, dtype: float64

Let's one hot encode them.

In [11]:
rel_label_names = df.relation.explode().value_counts().index.to_list()
def label2ids(labels, rel_label_names=rel_label_names):
  return [1 if rel_name in labels else 0 for rel_name in rel_label_names]
df['rel_one_hot'] = df.relation.apply(label2ids)
df.head()

Unnamed: 0,document,entities,entity_spans,relation,rel_one_hot
0,Salix Pharmaceuticals Ltd (SP) and Pharmatel P...,"[Salix Pharmaceuticals Ltd, Pharmatel Pty Ltd]","[(0, 25), (35, 52)]","[Marketing, StrategicAlliance]","[1, 0, 1, 0, 0, 0]"
2,Praxair Inc and Phillips 66 Co formed a strate...,"[Praxair Inc, Phillips 66 Co]","[(0, 11), (16, 30)]",[StrategicAlliance],"[1, 0, 0, 0, 0, 0]"
3,Apple Computer Inc (ACI) and Samsung Electroni...,"[Apple Computer Inc, Samsung Electronics Co Ltd]","[(0, 18), (29, 55)]",[StrategicAlliance],"[1, 0, 0, 0, 0, 0]"
4,Robert Koch Institute granted Hoechst Schering...,"[Robert Koch Institute, Hoechst Schering AgrEv...","[(0, 21), (30, 58)]","[Licensing, ResearchandDevelopment, StrategicA...","[1, 0, 0, 1, 0, 1]"
5,Thai Advance Innovation Co Ltd and Ai And Robo...,"[Thai Advance Innovation Co, Ai And Robotics V...","[(0, 26), (35, 62)]",[JointVenture],"[0, 1, 0, 0, 0, 0]"


In [12]:
def ids2labels(ids, rel_label_names=rel_label_names):
  labels = []
  for idx, label in enumerate(rel_label_names):
    if ids[idx] == 1:
      labels.append(label)
    
  return labels

In [13]:
df.shape

(46617, 5)

## Define the PyTorch dataset and dataloaders

Next, we define regular PyTorch datasets and corresponding dataloaders. In PyTorch, you need to define a `Dataset` class that inherits from `torch.utils.data.Dataset`, and you need to implement 3 methods: the `init` method (for initializing the dataset with data), the `len` method (which returns the number of elements in the dataset) and the `getitem()` method, which returns a single item from the dataset.

In our case, each item of the dataset consists of a sentence, the spans of 2 entities in the sentence, and a label of the relationship. We use `LukeTokenizer` (available in the Transformers library) to turn these into the inputs expected by the model, which are `input_ids`, `entity_ids`, `attention_mask`, `entity_attention_mask` and `entity_position_ids`.

For more information regarding these inputs, refer to the [docs](https://huggingface.co/transformers/model_doc/luke.html#lukeforentitypairclassification) of `LukeForEntityPairClassification`.


In [14]:
MAX_LEN = 128
THRESHOLD = 0.5
TRAIN_BATCH_SIZE = 8
VALID_BATCH_SIZE, TEST_BATCH_SIZE = 64, 64
LEARNING_RATE = 1e-05
MAX_EPOCHS = 1
IMBALANCE_OVERWEIGHTING = 1.3

In [15]:
from transformers import LukeTokenizer
from torch.utils.data import Dataset, DataLoader
import torch

tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", task="entity_pair_classification")

class RelationExtractionDataset(Dataset):
    """Relation extraction dataset."""

    def __init__(self, data):
        """
        Args:
            data : Pandas dataframe.
        """
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data.iloc[idx]

        text = item.document

        encoding = tokenizer(text, entity_spans=item.entity_spans, padding="max_length", truncation=True, return_tensors="pt",
                            max_length=MAX_LEN)

        for k,v in encoding.items():
          encoding[k] = encoding[k].squeeze()

        encoding["label"] = torch.tensor(item.rel_one_hot)

        return encoding

Here we instantiate the class defined above with 3 objects: a training dataset, a validation dataset and a test set.

In [16]:
from sklearn.model_selection import train_test_split

train_df, test_df = train_test_split(df, test_size=0.2, random_state=42, shuffle=True)
train_df, val_df = train_test_split(train_df, test_size=0.2, random_state=42, shuffle=False)

print("FULL Dataset: {}".format(len(df)))
print("TRAIN Dataset: {}".format(len(train_df)))
print("TEST Dataset: {}".format(len(test_df)))
print("VALIDATION Dataset: {}".format(len(val_df)))

# define the dataset
train_dataset = RelationExtractionDataset(data=train_df)
valid_dataset = RelationExtractionDataset(data=val_df)
test_dataset = RelationExtractionDataset(data=test_df)

FULL Dataset: 46617
TRAIN Dataset: 29834
TEST Dataset: 9324
VALIDATION Dataset: 7459


In [17]:
train_dataset[0].keys()

dict_keys(['input_ids', 'entity_ids', 'entity_position_ids', 'attention_mask', 'entity_attention_mask', 'label'])

Let's define the corresponding dataloaders (which allow us to iterate over the elements of the dataset):

In [18]:
train_dataloader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=VALID_BATCH_SIZE)
test_dataloader = DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE)

Let's verify an example of a batch:

In [19]:
batch = next(iter(train_dataloader))
tokenizer.decode(batch["input_ids"][1])

'<s> <ent> Silcock Express Holdings Ltd <ent>, a unit of Tibbett & Britten Group PLC, and British state-owned <ent2>  Railfreight Distribution <ent2>  (RfD) have agreed to form a joint venture to manage a car terminal which was linked to the Channel Tunnel rail network in Europe. The new company was called Autotrax Ltd. Financial terms were not disclosed.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>'

In [20]:
ids2labels(batch["label"][0])

['StrategicAlliance']

## Define a PyTorch LightningModule

Let's define the model as a PyTorch LightningModule. A `LightningModule` is actually an `nn.Module`, but with some extra functionality.

For more information regarding how to define this, see the [docs](https://pytorch-lightning.readthedocs.io/en/latest/?_ga=2.56317931.1395871250.1622709933-1738348008.1615553774) of PyTorch Lightning.

In [72]:
import torch, gc
gc.collect()
torch.cuda.empty_cache()

In [28]:
from transformers import LukeForEntityPairClassification, AdamW
import pytorch_lightning as pl
from sklearn import metrics

rel_label_weights = rel_label_weights**IMBALANCE_OVERWEIGHTING
class_weights = torch.cuda.FloatTensor(rel_label_weights.to_list())

class LUKE(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.model = LukeForEntityPairClassification.from_pretrained("studio-ousia/luke-base", num_labels=len(rel_label_names))

    def forward(self, input_ids, entity_ids, entity_position_ids, attention_mask, entity_attention_mask):     
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, entity_ids=entity_ids, 
                             entity_attention_mask=entity_attention_mask, entity_position_ids=entity_position_ids)
        return outputs
    
    def common_step(self, batch, batch_idx):
        labels = batch['label'].float()
        del batch['label']
        outputs = self(**batch)
        logits = outputs.logits
        
        criterion = torch.nn.BCEWithLogitsLoss(weight=class_weights) # multi-label classification with weighted classes
        loss = criterion(logits, labels)
        preds = (torch.sigmoid(logits)>THRESHOLD).float()
        
        # targets = labels.cpu().detach().numpy().tolist()
        # outputs = predictions.cpu().detach().numpy().tolist()
        # with torch.no_grad():
        #   f1_scores_micro, f1_scores_macro = [], []
        #   for target, output in zip(targets, outputs):
        #     print('Target: ', target, ' Output: ', output)
        #     f1_scores_micro.append(metrics.f1_score(targets, outputs, average='micro'))
        #     f1_scores_macro.append(metrics.f1_score(targets, outputs, average='macro'))
        #   f1_score_micro = sum(f1_scores_micro)/len(f1_scores_micro)
        #   f1_score_macro = sum(f1_scores_macro)/len(f1_scores_macro)
        # self.log("training_f1_score_macro", f1_score_macro, on_epoch=True)

        return {'loss': val_loss, 'preds': preds, 'labels': labels}
      
    def training_step(self, batch, batch_idx):
        output = self.common_step(batch, batch_idx)
        loss = output['loss']
        # logs metrics for each training_step,
        # and the average across the epoch
        self.log("training_loss", loss)

        return loss

    def validation_step(self, batch, batch_idx):
        output = self.common_step(batch, batch_idx)     
        loss = output['loss']
        self.log("validation_loss", loss, on_epoch=True)
        
        preds = output['preds']
        labels = output['labels']
        
        return {"loss": val_loss, "preds": preds, "labels": labels}
    
    def validation_epoch_end(self, outputs):
            for i, output in enumerate(outputs):
                preds = torch.cat([x["preds"] for x in output]).detach().cpu().numpy()
                labels = torch.cat([x["labels"] for x in output]).detach().cpu().numpy()
                loss = torch.stack([x["loss"] for x in output]).mean()
                self.log(f"val_loss_{split}", loss, prog_bar=True)
                split_metrics = {
                    f"{k}_{split}": v for k, v in self.metric.compute(predictions=preds, references=labels).items()
                }
                self.log_dict(split_metrics, prog_bar=True)
            return loss
        
    def test_step(self, batch, batch_idx):
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=LEARNING_RATE)
        return optimizer

    def train_dataloader(self):
        return train_dataloader

    def val_dataloader(self):
        return valid_dataloader

    def test_dataloader(self):
        return test_dataloader
    
    def test_epoch_end(self, outputs):
        avg_acc = 100 * self.test_correct_counter / self.test_total_counter
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()

        self.test_correct_counter = 0
        self.test_total_counter = 0

        tensorboard_logs = {'avg_acc': avg_acc, 'test_loss': avg_loss}
        return {'avg_test_loss': avg_loss, 'avg_acc': avg_acc, 'log': tensorboard_logs}


model = LUKE()

Some weights of the model checkpoint at studio-ousia/luke-base were not used when initializing LukeForEntityPairClassification: ['embeddings.position_ids']
- This IS expected if you are initializing LukeForEntityPairClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LukeForEntityPairClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LukeForEntityPairClassification were not initialized from the model checkpoint at studio-ousia/luke-base and are newly initialized: ['classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Train the model

Let's train the model. We also use early stopping, to avoid overfitting the training dataset. We also log everything to Weights and Biases, which will give us beautiful charts of the loss and accuracy plotted over time.

If you haven't already, you can create an account on the [website](https://wandb.ai/site), then log in in a web browser, and run the cell below: 

In [22]:
import wandb

wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mjakr[0m (use `wandb login --relogin` to force relogin)


True

In [35]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping

wandb_logger = WandbLogger(name='luke-first-run', project='LUKE')
# for early stopping, see https://pytorch-lightning.readthedocs.io/en/1.0.0/early_stopping.html?highlight=early%20stopping
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=2,
    strict=False,
    verbose=False,
    mode='min'
)

trainer = Trainer(gpus=1, logger=wandb_logger, callbacks=[EarlyStopping(monitor='validation_loss')], max_epochs=MAX_EPOCHS)
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                            | Params
----------------------------------------------------------
0 | model | LukeForEntityPairClassification | 274 M 
----------------------------------------------------------
274 M     Trainable params
0         Non-trainable params
274 M     Total params
1,098.045 Total estimated model params size (MB)
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
  rank_zero_warn(
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncatio

1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
2 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens


Validating: 0it [00:00, ?it/s]

1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
2 entities are ignored because their entity spans are invalid due to the truncation of input tokens
2 entities are ignored because their entity spans are invalid due to the truncation of input tokens
2 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
2 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens


In [24]:
ls luke-first-run/version_None/checkpoints/

'epoch=0-step=3729.ckpt'


In [27]:
trainer.test(test_dataloaders=test_dataloader)

  rank_zero_deprecation(
  rank_zero_warn(
Restoring states from the checkpoint path at /workspace/luke-first-run/version_None/checkpoints/epoch=0-step=3729.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at /workspace/luke-first-run/version_None/checkpoints/epoch=0-step=3729.ckpt
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
2 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
2 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{}
--------------------------------------------------------------------------------


[{}]

## Evaluation

Instead of calling `trainer.test()`, we can also manually evaluate the model on the entire test set:

In [33]:
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score

# loaded_model = LUKE.load_from_checkpoint(checkpoint_path="luke-first-run/version_None/checkpoints/epoch=0-step=3729.ckpt")

# loaded_model.model.eval()

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# loaded_model.to(device)

model.to('cuda')
model.model.eval()

predictions_total = []
logits_total = []
labels_total = []
for batch in tqdm(test_dataloader):
    # get the inputs;
    labels = batch["label"]
    del batch["label"]

    # move everything to the GPU
    for k,v in batch.items():
      batch[k] = batch[k].to('cuda')

    # forward pass
    outputs = model.model(**batch)
    logits = outputs.logits
    logits_total.extend(torch.sigmoid(logits).tolist())
    predictions = (torch.sigmoid(logits)>THRESHOLD).float()
    predictions_total.extend(predictions.tolist())
    labels_total.extend(labels.tolist())

  0%|          | 0/146 [00:00<?, ?it/s]

1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
2 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
2 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens


In [34]:
for idx, label_name in enumerate(rel_label_names):
    labels = [label[idx] for label in labels_total]
    predictions = [pred[idx] for pred in predictions_total]
    logits = [logit[idx] for logit in logits_total]
    precision = metrics.precision_score(labels, predictions)
    recall = metrics.recall_score(labels, predictions)
    f1 = metrics.f1_score(labels, predictions)
    roc_auc = metrics.roc_auc_score(labels, logits)
    print(label_name)
    print(f'Precision: {precision:.3f}, Recall: {recall:.3f}')
    print(f'F1-Score: {f1:.3f}, ROC-AUC: {roc_auc:.3f}\n')

StrategicAlliance
Precision: 0.617, Recall: 0.078
F1-Score: 0.138, ROC-AUC: 0.429

JointVenture
Precision: 0.200, Recall: 1.000
F1-Score: 0.334, ROC-AUC: 0.459

Marketing
Precision: 0.000, Recall: 0.000
F1-Score: 0.000, ROC-AUC: 0.555

Licensing
Precision: 0.118, Recall: 0.154
F1-Score: 0.133, ROC-AUC: 0.513

Manufacturing
Precision: 0.114, Recall: 0.678
F1-Score: 0.195, ROC-AUC: 0.511

ResearchandDevelopment
Precision: 0.103, Recall: 0.916
F1-Score: 0.185, ROC-AUC: 0.414



  _warn_prf(average, modifier, msg_start, len(result))


## Inference

Here we test the trained model on a new, unseen sentence.

In [30]:
import torch.nn.functional as F

idx = 8
text = test_df.iloc[idx].document
entity_spans = test_df.iloc[idx].entity_spans  # character-based entity spans 

inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
inputs.to('cuda')

outputs = model.model(**inputs)
logits = outputs.logits
predicted_classes = (torch.sigmoid(logits)>THRESHOLD).float()
print("Sentence:", text)
print("Ground truth label:", ids2labels(test_df.iloc[idx].rel_one_hot))
print("Predicted class idx:", ids2labels(predicted_classes.squeeze().tolist()))
print("Confidence:", torch.sigmoid(logits).squeeze().tolist())

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper__index_select)