In [1]:
!nvidia-smi -L

GPU 0: GeForce RTX 3090 (UUID: GPU-334c16f1-da9a-6664-e565-aa4364d6bc92)


In [2]:
!ls

Firm_relation_extraction_LUKE.ipynb  lightning_logs
LUKE				     onstart.log
lexis_2017_with_org_preds_spacy.pkl  onstart.sh
lexis_entity_extraction_spacy.ipynb  wandb


In [3]:
!ls LUKE

'epoch=1-step=2316.ckpt'


In [4]:
# !rm -r LUKE # delete old runs

In this notebook, we are going to fine-tune [`LukeForEntityPairClassification`](https://huggingface.co/transformers/model_doc/luke.html#lukeforentitypairclassification) on a supervised **relation extraction** dataset.

The goal for the model is to predict, given a sentence and the character spans of two entities within the sentence, the relationship between the entities.

* Paper: https://arxiv.org/abs/2010.01057
* Original repository: https://github.com/studio-ousia/luke

In [5]:
!pip install -q transformers 
!pip install -q pandas
!pip install -q sklearn
!pip install -q pytorch-lightning wandb



In [40]:
!pip install -q unidecode



In [43]:
from transformers import LukeTokenizer
from torch.utils.data import Dataset, DataLoader
import torch
import pandas as pd

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [8]:
rs = 42

## Read in data

Let's download the data from the web, hosted on Dropbox.

Each row in the dataframe consists of a news article. 

In [9]:
df = pd.read_pickle("https://www.dropbox.com/s/rdwg8d76ytqqgdy/training-data-29-03-22.pkl?dl=1")
df

Unnamed: 0,Date,source,document,firms,spans,rels
233892045,1985-01-01,Thomson SDC alliances - Deal Number 233892045,Accel Capital L.P. invested an undisclosed amo...,"[Accel Capital, Thalamus Electronics Inc]","[(0, 13), (53, 77)]","[StrategicAlliance, Manufacturing]"
216749045,1985-01-01,Thomson SDC alliances - Deal Number 216749045,Standard Microsystems Corp. and Mitsubishi Ele...,"[Standard Microsystems Corp, Mitsubishi Electr...","[(0, 26), (32, 56)]","[StrategicAlliance, Licensing]"
234987045,1985-01-01,Thomson SDC alliances - Deal Number 234987045,"Elf Technologies, Inc., the venture capital in...","[Elf Technologies, Santa Clara Systems]","[(0, 16), (169, 188)]",[StrategicAlliance]
216600045,1985-01-01,Thomson SDC alliances - Deal Number 216600045,Stauffer Chemical Co. has been licensed by Mit...,"[Stauffer Chemical Co, Himont Inc]","[(0, 20), (79, 89)]","[StrategicAlliance, Marketing, Manufacturing, ..."
2410877045,1985-01-01,Thomson SDC alliances - Deal Number 2410877045,The Romanian Government (RG) and Control Data ...,"[Romania, Control Data Corp]","[(4, 11), (33, 50)]","[JointVenture, Manufacturing]"
...,...,...,...,...,...,...
257823,NaT,Reuters News Dataset - Article ID http://www.b...,University of Notre Dame football radio anal...,"[Notre Dame, Notre Dame’s]","[(15, 25), (200, 212)]",[]
416616,NaT,Reuters News Dataset - Article ID http://www.b...,"Schaeffler AG, the industrial-bearing maker th...","[Schaeffler AG, Continental AG]","[(0, 13), (91, 105)]",[]
328325,NaT,Reuters News Dataset - Article ID http://www.r...,Barnes & Noble Inc ( BKS.N ) said on Friday it...,"[Barnes & Noble, Lynch]","[(0, 14), (113, 118)]",[]
121959,NaT,Reuters News Dataset - Article ID http://www.b...,(Corrects third paragraph to show output of 25...,"[First Solar Inc ., Edison International]","[(154, 171), (249, 269)]",[]


Shuffle the data:

In [10]:
df = df.sample(frac=1, random_state=rs)

This is the frequency of our relation types:



In [11]:
df.rels.explode().value_counts()

StrategicAlliance         83845
JointVenture              44556
Pending                   38916
Marketing                 22478
Manufacturing             19926
ResearchandDevelopment    13568
Licensing                 12857
TechnologyTransfer         9889
Supply                     5407
Exploration                3152
Terminated                 1578
Name: rels, dtype: int64

We don't want to consider the less relevant and less consistently applied labels of 'supply', 'exploration', and 'technology transfer'. Let's remove them:

In [12]:
df['rels'] = df.rels.apply(lambda rels: [rel for rel in rels if rel not in ['TechnologyTransfer', 'Supply', 'Exploration']])

In [13]:
rel_frequencies = df.rels.explode().value_counts()
rel_frequencies

StrategicAlliance         83845
JointVenture              44556
Pending                   38916
Marketing                 22478
Manufacturing             19926
ResearchandDevelopment    13568
Licensing                 12857
Terminated                 1578
Name: rels, dtype: int64

Let's one hot encode the relationship classes:

In [14]:
rel_label_names = rel_frequencies.index.to_list()

def label2ids(labels, rel_label_names=rel_label_names):
  return [1 if rel_name in labels else 0 for rel_name in rel_label_names]

df['rel_one_hot'] = df.rels.apply(label2ids)

def ids2labels(ids, rel_label_names=rel_label_names):
  labels = []
  for idx, label in enumerate(rel_label_names):
    if ids[idx] == 1:
      labels.append(label)
    
  return labels

df.sample(10)

Unnamed: 0,Date,source,document,firms,spans,rels,rel_one_hot
2171577045,2010-01-05,Thomson SDC alliances - Deal Number 2171577045,Strides Arcolab Ltd and Pfizer Inc formed a st...,"[Strides Arcolab Ltd, Pfizer Inc]","[(0, 19), (24, 34)]",[StrategicAlliance],"[1, 0, 0, 0, 0, 0, 0, 0]"
57660,NaT,Reuters News Dataset - Article ID http://www.b...,Bond traders are giving Marfrig Alimentos SA a...,"[Keystone Foods LLC, McDonald’s Corp.]","[(408, 426), (465, 481)]",[],"[0, 0, 0, 0, 0, 0, 0, 0]"
532522045,1996-01-23,Thomson SDC alliances - Deal Number 532522045,Teikoku Oil Co Ltd and Nippon Oil Co Ltd forme...,"[Teikoku Oil Co Ltd, Nippon Oil Co Ltd]","[(0, 18), (23, 40)]",[JointVenture],"[0, 1, 0, 0, 0, 0, 0, 0]"
1258127045,2001-12-28,Thomson SDC alliances - Deal Number 1258127045,"Kyiv Borispol, Gera International and Airport ...","[Kyiv Borispol, Gera International]","[(0, 13), (15, 33)]","[StrategicAlliance, Pending]","[1, 0, 1, 0, 0, 0, 0, 0]"
361065,NaT,Reuters News Dataset - Article ID http://www.b...,"The New York City Health Department, locked in...","[The New York City Health Department, Coca-Col...","[(0, 35), (99, 112)]",[],"[0, 0, 0, 0, 0, 0, 0, 0]"
131279,NaT,Reuters News Dataset - Article ID http://www.r...,The Asian Development Bank trimmed most of its...,"[The Asian Development Bank, ADB]","[(0, 26), (358, 361)]",[],"[0, 0, 0, 0, 0, 0, 0, 0]"
2015615045,2008-09-25,Thomson SDC alliances - Deal Number 2015615045,Danfoss A/S (DA) and Tianjin Sanhua Refrigerat...,"[Danfoss A/S, Tianjin Sanhua Refrigeration]","[(0, 11), (21, 49)]","[JointVenture, Manufacturing, ResearchandDevel...","[0, 1, 0, 0, 1, 1, 0, 0]"
2449453045,2012-09-11,Thomson SDC alliances - Deal Number 2449453045,newsR.in (NR) and PR Newswire Inc (PN) formed ...,"[newsR.in, PR Newswire Inc]","[(0, 8), (18, 33)]",[StrategicAlliance],"[1, 0, 0, 0, 0, 0, 0, 0]"
5252,NaT,Reuters News Dataset - Article ID http://www.r...,Chicago real estate magnate Sam Zell may have ...,"[Zell, Tribune Co.]","[(32, 36), (92, 103)]",[],"[0, 0, 0, 0, 0, 0, 0, 0]"
1962224045,2008-03-18,Thomson SDC alliances - Deal Number 1962224045,Microsoft Corp (MC) and Yellowpages.com LLC (Y...,"[Microsoft Corp, Yellowpages.com LLC]","[(0, 14), (24, 43)]",[StrategicAlliance],"[1, 0, 0, 0, 0, 0, 0, 0]"


## Define the PyTorch dataset and dataloaders

Next, we define regular PyTorch datasets and corresponding dataloaders. In PyTorch, you need to define a `Dataset` class that inherits from `torch.utils.data.Dataset`, and you need to implement 3 methods: the `init` method (for initializing the dataset with data), the `len` method (which returns the number of elements in the dataset) and the `getitem()` method, which returns a single item from the dataset.

In our case, each item of the dataset consists of a sentence, the spans of 2 entities in the sentence, and a label of the relationship. We use `LukeTokenizer` (available in the Transformers library) to turn these into the inputs expected by the model, which are `input_ids`, `entity_ids`, `attention_mask`, `entity_attention_mask` and `entity_position_ids`.

For more information regarding these inputs, refer to the [docs](https://huggingface.co/transformers/model_doc/luke.html#lukeforentitypairclassification) of `LukeForEntityPairClassification`.


Let's set our hyperparameters:

In [15]:
MAX_LEN = 128
LEARNING_RATE = 1e-5
BATCH_SIZE = 128
MAX_EPOCHS = 10
THRESHOLDS = torch.tensor([0.5 for rel in rel_label_names]).to(device) # can change this for class-specific thresholds
WEIGHT_DECAY= 0.01
GRAD_CLIP_VAL = 0
CLASS_WEIGHTS = torch.tensor([1, 1, 1, 1, 1, 4, 1, 4]).to(device) # equal class weights 

In [16]:
tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", task="entity_pair_classification")

class RelationExtractionDataset(Dataset):
    """Relation extraction dataset."""

    def __init__(self, data, has_labels=True):
        """
        Args:
            data : Pandas dataframe.
        """
        self.data = data
        self.has_labels = has_labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data.iloc[idx]

        text = item.document

        encoding = tokenizer(text, entity_spans=item.spans, padding="max_length", truncation=True, return_tensors="pt",
                            max_length=MAX_LEN)

        for k,v in encoding.items():
          encoding[k] = encoding[k].squeeze()

        if self.has_labels:
            encoding["label"] = torch.tensor(item.rel_one_hot)

        return encoding

Here we instantiate the class defined above with 3 objects: a training dataset, a validation dataset and a test set.

In [17]:
from sklearn.model_selection import train_test_split

train_size = 0.8
validation_size = 0.1
test_size = 0.1

train_df, test_df = train_test_split(df, test_size=1-train_size, random_state=rs, shuffle=True)
val_df, test_df = train_test_split(test_df, test_size=test_size/(test_size+validation_size), random_state=rs, shuffle=False)

print("FULL Dataset: {}".format(len(df)))
print("TRAIN Dataset: {}".format(len(train_df)))
print("TEST Dataset: {}".format(len(test_df)))
print("VALIDATION Dataset: {}".format(len(val_df)))

# define the dataset
train_dataset = RelationExtractionDataset(data=train_df)
val_dataset = RelationExtractionDataset(data=val_df)
test_dataset = RelationExtractionDataset(data=test_df)

FULL Dataset: 264928
TRAIN Dataset: 211942
TEST Dataset: 26493
VALIDATION Dataset: 26493


In [18]:
train_dataset[0].keys()

dict_keys(['input_ids', 'entity_ids', 'entity_position_ids', 'attention_mask', 'entity_attention_mask', 'label'])

## Define a PyTorch LightningModule

Let's define the model as a PyTorch LightningModule. A `LightningModule` is actually an `nn.Module`, but with some extra functionality.

For more information regarding how to define this, see the [docs](https://pytorch-lightning.readthedocs.io/en/latest/?_ga=2.56317931.1395871250.1622709933-1738348008.1615553774) of PyTorch Lightning.

In [19]:
# import gc
# gc.collect()
# torch.cuda.empty_cache()

In [20]:
print(f'Free CUDA memory: {(torch.cuda.get_device_properties(0).total_memory-torch.cuda.memory_allocated(0))/1e9:.2f}GB')

Free CUDA memory: 25.45GB


In [21]:
from transformers import LukeForEntityPairClassification, AdamW
import pytorch_lightning as pl
from sklearn import metrics

class LUKE(pl.LightningModule):

    def __init__(self, lr=LEARNING_RATE, batch_size=BATCH_SIZE, class_weights=CLASS_WEIGHTS,
                 thresholds=THRESHOLDS, weight_decay=WEIGHT_DECAY,
                 train_dataset=train_dataset, val_dataset=val_dataset, test_dataset=test_dataset):
        super().__init__()
        self.model = LukeForEntityPairClassification.from_pretrained("studio-ousia/luke-base", 
                                                                    num_labels=len(rel_label_names))
        self.save_hyperparameters()
        print(self.hparams)


    def forward(self, input_ids, entity_ids, entity_position_ids, attention_mask, entity_attention_mask):     
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, entity_ids=entity_ids, 
                             entity_attention_mask=entity_attention_mask, entity_position_ids=entity_position_ids)
        return outputs
    
    def common_step(self, batch, batch_idx):
        labels = batch['label'].float()
        del batch['label']
        outputs = self(**batch)
        logits = outputs.logits
        
        criterion = torch.nn.BCEWithLogitsLoss(weight=self.hparams.class_weights) # multi-label classification with weighted classes
        loss = criterion(logits, labels)
        preds = (torch.sigmoid(logits)>self.hparams.thresholds).float()
        
        return {'loss': loss, 'preds': preds, 'labels': labels}
      
    def training_step(self, batch, batch_idx):
        output = self.common_step(batch, batch_idx)
        loss = output['loss']
        # logs metrics for each training_step,
        # and the average across the epoch
        self.log("training_loss", loss)

        return loss

    def validation_step(self, batch, batch_idx):
        output = self.common_step(batch, batch_idx)     
        loss = output['loss']
        self.log("val_loss", loss, on_epoch=True)
        
        preds = output['preds']
        labels = output['labels']
        
        return {"loss": loss, "preds": preds, "labels": labels}
    
    def validation_epoch_end(self, outputs):
            for i, output in enumerate(outputs):
                preds = output["preds"].detach().cpu().numpy()
                labels = output["labels"].detach().cpu().numpy()
                loss = output["loss"].mean()
                
                f1_scores = []
                for idx, label_name in enumerate(rel_label_names):
                    label = [label[idx] for label in labels]
                    pred = [pred[idx] for pred in preds]

                    precision = metrics.precision_score(label, pred, zero_division=0)
                    recall = metrics.recall_score(label, pred, zero_division=0)
                    f1 = metrics.f1_score(label, pred, zero_division=0)
                    self.log(f'val_precision_{label_name}', precision)
                    self.log(f'val_recall_{label_name}', recall)
                    self.log(f'val_f1_{label_name}', f1, prog_bar=True)
                    f1_scores.append(f1)
                self.log(f'val_f1_macro_avg', sum(f1_scores)/len(f1_scores), prog_bar=True)
        
    def test_step(self, batch, batch_idx):
        output = self.common_step(batch, batch_idx)     
        loss = output['loss']
        self.log("test_loss", loss, on_epoch=True)
        
        preds = output['preds']
        labels = output['labels']
        
        return {"loss": loss, "preds": preds, "labels": labels}
    
    def test_epoch_end(self, outputs):
            for i, output in enumerate(outputs):
                preds = output["preds"].detach().cpu().numpy()
                labels = output["labels"].detach().cpu().numpy()
                loss = output["loss"].mean()

                f1_scores = []
                for idx, label_name in enumerate(rel_label_names):
                    label = [label[idx] for label in labels]
                    pred = [pred[idx] for pred in preds]
                    precision = metrics.precision_score(label, pred, zero_division=0)
                    recall = metrics.recall_score(label, pred, zero_division=0)
                    f1 = metrics.f1_score(label, pred, zero_division=0)
                    self.log(f'test_precision_{label_name}', precision)
                    self.log(f'test_recall_{label_name}', recall)
                    self.log(f'test_f1_{label_name}', f1, prog_bar=True)
                    f1_scores.append(f1)
                self.log(f'test_f1_macro_avg', sum(f1_scores)/len(f1_scores), prog_bar=True)
    
    def predict_step(self, batch, batch_idx):
        output = self.common_step(batch, batch_idx)     
        loss = output['loss']
        
        preds = output['preds']
        labels = output['labels']
            
        return {"loss": loss, "preds": preds, "labels": labels}
    
    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay,
                          no_deprecation_warning=True)

        return optimizer

    def train_dataloader(self):
        return DataLoader(train_dataset, batch_size=self.hparams.batch_size, shuffle=True, num_workers=16)

    def val_dataloader(self):
        return DataLoader(val_dataset, batch_size=self.hparams.batch_size, num_workers=16)

    def test_dataloader(self):
        return DataLoader(test_dataset, batch_size=self.hparams.batch_size, num_workers=16)

    
model = LUKE()

Some weights of the model checkpoint at studio-ousia/luke-base were not used when initializing LukeForEntityPairClassification: ['embeddings.position_ids']
- This IS expected if you are initializing LukeForEntityPairClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LukeForEntityPairClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LukeForEntityPairClassification were not initialized from the model checkpoint at studio-ousia/luke-base and are newly initialized: ['classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


"batch_size":    128
"class_weights": tensor([1, 1, 1, 1, 1, 4, 1, 4], device='cuda:0')
"lr":            1e-05
"test_dataset":  <__main__.RelationExtractionDataset object at 0x7f884342b040>
"thresholds":    tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
       device='cuda:0')
"train_dataset": <__main__.RelationExtractionDataset object at 0x7f8843413250>
"val_dataset":   <__main__.RelationExtractionDataset object at 0x7f8843413fa0>
"weight_decay":  0.01


## Train the model

Let's train the model. We also use early stopping, to avoid overfitting the training dataset. We also log everything to Weights and Biases, which will give us beautiful charts of the loss and accuracy plotted over time.

If you haven't already, you can create an account on the [website](https://wandb.ai/site), then log in in a web browser, and run the cell below: 

In [22]:
import wandb

wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mjakr[0m (use `wandb login --relogin` to force relogin)


True

In [78]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

wandb_logger = WandbLogger(project='LUKE')

# for early stopping, see https://pytorch-lightning.readthedocs.io/en/1.0.0/early_stopping.html?highlight=early%20stopping
early_stop_callback = EarlyStopping(
    monitor='val_f1_macro_avg',
    patience=2,
    strict=False,
    verbose=True,
    mode='max',
    min_delta=0 # stops training once no more improvement is made
)

checkpoint_callback = ModelCheckpoint(dirpath='LUKE', save_top_k=1, save_last=False, verbose=True,
                                      monitor='val_f1_macro_avg', mode='max', every_n_epochs=1)

trainer = Trainer(gpus=1, logger=wandb_logger, callbacks=[early_stop_callback, checkpoint_callback], 
                  max_epochs=MAX_EPOCHS, precision=16, 
                  stochastic_weight_avg=True, auto_lr_find=True, 
                  benchmark=True, deterministic=True,
                  val_check_interval=0.1, gradient_clip_val=GRAD_CLIP_VAL, fast_dev_run=False)
                    # limit_train_batches=0.25

[34m[1mwandb[0m: Currently logged in as: [33mjakr[0m (use `wandb login --relogin` to force relogin)


  rank_zero_warn(
Using 16bit native Automatic Mixed Precision (AMP)
  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Finding the optimal learning rate:

In [24]:
# lr_finder = trainer.tuner.lr_find(model)

In [25]:
# lr_finder.suggestion()

Setting the learning rate to the optimal rate:

In [26]:
# model.hparams.lr = lr_finder.suggestion()

Let's train the model:

In [27]:
trainer.fit(model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                            | Params
----------------------------------------------------------
0 | model | LukeForEntityPairClassification | 274 M 
----------------------------------------------------------
274 M     Trainable params
0         Non-trainable params
274 M     Total params
549.029   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens


Training: 0it [00:00, ?it/s]

1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
2 entities are ignored because their entity spans are invalid due to the truncation of input tokens


Validation: 0it [00:00, ?it/s]

1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens


Validation: 0it [00:00, ?it/s]

1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens


Validation: 0it [00:00, ?it/s]

1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens


Validation: 0it [00:00, ?it/s]

1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens


Validation: 0it [00:00, ?it/s]

1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens


Validation: 0it [00:00, ?it/s]

1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens


Validation: 0it [00:00, ?it/s]

1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens


Validation: 0it [00:00, ?it/s]

1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens


Validation: 0it [00:00, ?it/s]

1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens


Validation: 0it [00:00, ?it/s]

1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens


Validation: 0it [00:00, ?it/s]

1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens


Validation: 0it [00:00, ?it/s]

1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens


Validation: 0it [00:00, ?it/s]

1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens


Validation: 0it [00:00, ?it/s]

1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens


Validation: 0it [00:00, ?it/s]

1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens


Validation: 0it [00:00, ?it/s]

1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens


In [28]:
trainer.test()

  rank_zero_warn(
Restoring states from the checkpoint path at /workspace/LUKE/epoch=1-step=2316.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at /workspace/LUKE/epoch=1-step=2316.ckpt


Testing: 0it [00:00, ?it/s]

1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
             Test metric                         DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_f1_JointVenture                  0.9874034270564443
          test_f1_Licensing                   0.8324084692977363
        test_f1_Manufacturing                 0.8315207967632101
          test_f1_Marketing                   0.7628345289248974
           test_f1_Pending                    0.9480708497484134
   test_f1_ResearchandDevelopment             0.7430008018636078
      test_f1_StrategicAlliance               0.9926623538603108
         test_f1_Terminated                   0.43140096618357493
          test_f1_macro_avg                   0.8161627742122745
              test_loss                       0.07398442178964615
     test_precision_JointVenture            

[{'test_loss': 0.07398442178964615,
  'test_precision_StrategicAlliance': 0.9927371562056587,
  'test_recall_StrategicAlliance': 0.9927783037001732,
  'test_f1_StrategicAlliance': 0.9926623538603108,
  'test_precision_JointVenture': 0.9850657970757907,
  'test_recall_JointVenture': 0.9903459079698528,
  'test_f1_JointVenture': 0.9874034270564443,
  'test_precision_Pending': 0.935172886215819,
  'test_recall_Pending': 0.9642232050501605,
  'test_f1_Pending': 0.9480708497484134,
  'test_precision_Marketing': 0.8822727192292408,
  'test_recall_Marketing': 0.6897875261046609,
  'test_f1_Marketing': 0.7628345289248974,
  'test_precision_Manufacturing': 0.862597434219344,
  'test_recall_Manufacturing': 0.8184253146209668,
  'test_f1_Manufacturing': 0.8315207967632101,
  'test_precision_ResearchandDevelopment': 0.7771515387457419,
  'test_recall_ResearchandDevelopment': 0.7429151497992077,
  'test_f1_ResearchandDevelopment': 0.7430008018636078,
  'test_precision_Licensing': 0.8787476211389255

In [29]:
checkpoint_callback.best_model_path

'/workspace/LUKE/epoch=1-step=2316.ckpt'

In [15]:
ls LUKE

'epoch=1-step=2316.ckpt'


## Evaluation

Instead of calling `trainer.test()`, we can also manually evaluate the model on the entire test set:

NameError: name 'trainer' is not defined

In [76]:
# model = LUKE.load_from_checkpoint(checkpoint_path=checkpoint_callback.best_model_path)
model = LUKE.load_from_checkpoint(checkpoint_path='LUKE/epoch=1-step=2316.ckpt')

model = model.to(device)

Some weights of the model checkpoint at studio-ousia/luke-base were not used when initializing LukeForEntityPairClassification: ['embeddings.position_ids']
- This IS expected if you are initializing LukeForEntityPairClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LukeForEntityPairClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LukeForEntityPairClassification were not initialized from the model checkpoint at studio-ousia/luke-base and are newly initialized: ['classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


"batch_size":    128
"class_weights": tensor([1, 1, 1, 1, 1, 4, 1, 4])
"lr":            1e-05
"test_dataset":  <__main__.RelationExtractionDataset object at 0x7f90f6a9ad60>
"thresholds":    tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000])
"train_dataset": <__main__.RelationExtractionDataset object at 0x7f9153d90220>
"val_dataset":   <__main__.RelationExtractionDataset object at 0x7f90f23bd8e0>
"weight_decay":  0.01


LUKE(
  (model): LukeForEntityPairClassification(
    (luke): LukeModel(
      (embeddings): LukeEmbeddings(
        (word_embeddings): Embedding(50267, 768, padding_idx=1)
        (position_embeddings): Embedding(514, 768, padding_idx=1)
        (token_type_embeddings): Embedding(1, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (entity_embeddings): LukeEntityEmbeddings(
        (entity_embeddings): Embedding(500000, 256, padding_idx=0)
        (entity_embedding_dense): Linear(in_features=256, out_features=768, bias=False)
        (position_embeddings): Embedding(514, 768)
        (token_type_embeddings): Embedding(1, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): LukeEncoder(
        (layer): ModuleList(
          (0): LukeLayer(
            (attention): LukeAttention(
              (

In [79]:
trainer = Trainer(gpus=1)

  rank_zero_warn(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
trainer.predict(model, dataloaders=model.test_dataloader())

Missing logger folder: /workspace/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [34]:
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score

predictions_total = []
logits_total = []
labels_total = []

for batch in tqdm(model.test_dataloader()):
    # get the inputs;
    labels = batch["label"]
    del batch["label"]

    # move everything to the GPU
    for k,v in batch.items():
      batch[k] = batch[k].to(device)

    # forward pass
    outputs = model.model(**batch)
    logits = outputs.logits
    logits_total.extend(torch.sigmoid(logits).tolist())
    predictions = (torch.sigmoid(logits)>THRESHOLDS).float()
    predictions_total.extend(predictions.tolist())
    labels_total.extend(labels.tolist())
    del batch

  0%|          | 0/207 [00:00<?, ?it/s]

1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens


RuntimeError: CUDA out of memory. Tried to allocate 196.00 MiB (GPU 0; 23.70 GiB total capacity; 20.85 GiB already allocated; 176.56 MiB free; 21.72 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [72]:
next(iter(model.test_dataloader()))

1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens


{'input_ids': tensor([[    0, 50265, 25415,  ...,     1,     1,     1],
        [    0, 50265,   347,  ...,     1,     1,     1],
        [    0, 10653, 16768,  ...,   208,  2562,     2],
        ...,
        [    0,   791,     4,  ...,    26,    11,     2],
        [    0, 50265, 39928,  ...,     1,     1,     1],
        [    0, 50265, 26292,  ...,     1,     1,     1]]), 'entity_ids': tensor([[2, 3],
        [2, 3],
        [2, 3],
        [2, 3],
        [2, 3],
        [2, 3],
        [2, 3],
        [2, 3],
        [2, 3],
        [2, 0],
        [2, 3],
        [2, 3],
        [2, 3],
        [2, 3],
        [2, 3],
        [2, 3],
        [2, 3],
        [2, 0],
        [2, 3],
        [2, 3],
        [2, 3],
        [2, 3],
        [2, 3],
        [2, 3],
        [2, 3],
        [2, 3],
        [2, 3],
        [2, 3],
        [2, 3],
        [2, 3],
        [2, 3],
        [2, 3],
        [2, 3],
        [2, 3],
        [2, 0],
        [2, 3],
        [2, 3],
        [2, 3],
 

In [None]:
for idx, label_name in enumerate(rel_label_names):
    labels = [label[idx] for label in labels_total]
    predictions = [pred[idx] for pred in predictions_total]
    logits = [logit[idx] for logit in logits_total]
    precision = metrics.precision_score(labels, predictions)
    recall = metrics.recall_score(labels, predictions)
    f1 = metrics.f1_score(labels, predictions)
    roc_auc = metrics.roc_auc_score(labels, logits)
    print(label_name)
    print(f'Precision: {precision:.3f}, Recall: {recall:.3f}')
    print(f'F1-Score: {f1:.3f}, ROC-AUC: {roc_auc:.3f}\n')

## Inference

Here we test the trained model on a new, unseen sentence.

In [68]:
import torch.nn.functional as F

test_doc = test_df[test_df.rels.apply(lambda rels: 'ResearchandDevelopment' in rels)].sample()

text = test_doc.document.iloc[0]
spans = test_doc.spans.iloc[0]  # character-based entity spans 

inputs = tokenizer(text, entity_spans=spans, padding="max_length", truncation=True, return_tensors="pt",
                            max_length=MAX_LEN)
inputs.to(device)

outputs = model.model(**inputs)
logits = outputs.logits
predicted_classes = (torch.sigmoid(logits)>THRESHOLDS).float()
print("Sentence:", text)
print("Tokens:", tokenizer.decode(inputs["input_ids"][0]))
print("Ground truth label:", ids2labels(test_doc.rel_one_hot.iloc[0]))
print("Predicted class idx:", ids2labels(predicted_classes.squeeze().tolist()))
print("Confidence:", torch.sigmoid(logits).squeeze().tolist())

Sentence: NexDx Inc (ND) and University of California (UC) formed a strategic alliance to develop and commercialize epigenetics discoveries in rheumatoid arthritis (RA) in United States. The original research was performed in the laboratory of Gary S. Firestein, MD Professor of Medicine at UC San Diego School of Medicine. The findings from this research will help ND to discover novel DNA methylation biomarkers (patterns which can indicate if a patient has RA or not).
Tokens: <s> <ent> NexDx Inc <ent>  (ND) and <ent2>  University of California <ent2>  (UC) formed a strategic alliance to develop and commercialize epigenetics discoveries in rheumatoid arthritis (RA) in United States. The original research was performed in the laboratory of Gary S. Firestein, MD Professor of Medicine at UC San Diego School of Medicine. The findings from this research will help ND to discover novel DNA methylation biomarkers (patterns which can indicate if a patient has RA or not).</s><pad><pad><pad><pad><p

To deal with the imbalance, we use weights in the loss function proportional to the inverse relative frequency of each class, normalizing by the frequency of the most common class. The resulting weights are as follows:

In [None]:
CLASS_WEIGHTS = 1/(rel_frequencies/rel_frequencies.iloc[0])
CLASS_WEIGHTS

The strategic alliance label is 6 times more frequent than the R&D label, therefore we weight the R&D class 6 times heavier in the loss function.

We can even increase the reweighting by exponentiating these weights:

In [None]:
IMBALANCE_REWEIGHTING = 1
CLASS_WEIGHTS = CLASS_WEIGHTS**IMBALANCE_REWEIGHTING
CLASS_WEIGHTS

Now examples with the R&D label are weighted 10 times heavier than those with the SA label.

In [None]:
CLASS_WEIGHTS = torch.tensor(CLASS_WEIGHTS.to_list()).to(device)

## Inference on Lexis Nexis news

In [22]:
model = LUKE.load_from_checkpoint(checkpoint_path='LUKE/epoch=1-step=2316.ckpt')

model = model.to(device)

Some weights of the model checkpoint at studio-ousia/luke-base were not used when initializing LukeForEntityPairClassification: ['embeddings.position_ids']
- This IS expected if you are initializing LukeForEntityPairClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LukeForEntityPairClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LukeForEntityPairClassification were not initialized from the model checkpoint at studio-ousia/luke-base and are newly initialized: ['classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


"batch_size":    128
"class_weights": tensor([1, 1, 1, 1, 1, 4, 1, 4])
"lr":            1e-05
"test_dataset":  <__main__.RelationExtractionDataset object at 0x7f883d68c280>
"thresholds":    tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000])
"train_dataset": <__main__.RelationExtractionDataset object at 0x7f88cdaac730>
"val_dataset":   <__main__.RelationExtractionDataset object at 0x7f88968f2ee0>
"weight_decay":  0.01


In [141]:
df = pd.read_pickle('lexis_2017_with_org_preds_spacy.pkl')
df

Unnamed: 0,title,content,publication,word_count,publication_date,publication_date_text,author,copyright,subject,country,city,person,industry,company,lang,ents_pred,firms_pred,spans_pred
113,"""Anhui Jianghuai Automobile Group Corp., Ltd.""...","(600418) ""Anhui Jianghuai Automobile Group Cor...",['Shanghai Stock Exchange'],459,2017-11-28,"November 28, 2017 Tuesday",,Copyright 2017 Shanghai Stock Exchange All Rig...,"[ALLIANCES & PARTNERSHIPS, STOCK EXCHANGES]","[CHINA, ASIA]","[SHANGHAI, CHINA]",[],[SIC3711 MOTOR VEHICLES & PASSENGER CAR BODIES...,"[ANHUI JIANGHUAI AUTOMOBILE CO LTD, VOLKSWAGEN...",en,"[(Anhui Jianghuai Automobile Group Corp., Ltd....","[Anhui Jianghuai Automobile Group Corp., Ltd.,...","[(10, 54), (156, 175), (180, 228)]"
659,"""Forced Labour"" Case To Go To Trial","Dec 08, 2017( Mondaq: http://www.mondaq.com / ...",['Mondaq'],7046,2017-12-08,"December 8, 2017 Friday","Mr Rick Williams, Tim Pritchard and Auke Visser",Copyright 2017 Newstex LLC All Rights Reserved,"[APPEALS, APPEALS COURTS, APPELLATE DECISIONS,...","[ERITREA, CANADA]","[ASMARA, ERITREA]",[],"[NAICS212221 GOLD ORE MINING, SIC1041 GOLD ORE...","[NEVSUN RESOURCES LTD, TAHOE RESOURCES INC]",en,"[(Newstex, (59, 66)), (Nevsun Resources Ltd., ...","[Newstex, Nevsun Resources Ltd.]","[(59, 66), (330, 351)]"
699,"""GD Power Development Co., Ltd."" (the Company)...",The Company held the 48th session of the 7th d...,['Shanghai Stock Exchange'],803,2017-08-29,"August 29, 2017 Tuesday",,Copyright 2017 Shanghai Stock Exchange All Rig...,"[JOINT VENTURES, STOCK EXCHANGES, AGREEMENTS]","[CHINA, ASIA]","[SHANGHAI, CHINA, BEIJING, CHINA]",[],"[NAICS221122 ELECTRIC POWER DISTRIBUTION, NAIC...","[GD POWER DEVELOPMENT CO LTD, CHINA SHENHUA EN...",en,"[(China Shenhua Energy Company Limited, (489, ...",[China Shenhua Energy Company Limited],"[(489, 525)]"
1667,"""Shanghai Chlor-Alkali Chemical Co., Ltd."" (th...","(600618) ""Shanghai Chlor-Alkali Chemical Co., ...",['Shanghai Stock Exchange'],425,2017-01-12,"January 12, 2017 Thursday",,Copyright 2017 Shanghai Stock Exchange All Rig...,[STOCK EXCHANGES],"[CHINA, ASIA]","[SHANGHAI, CHINA]",[],[NAICS325180 OTHER BASIC INORGANIC CHEMICAL MA...,"[SHANGHAI CHLOR-ALKALI CHEMICAL CO, ALLANFIELD...",en,"[(Shanghai Chlor-Alkali Chemical Co., Ltd., (1...","[Shanghai Chlor-Alkali Chemical Co., Ltd.]","[(10, 50)]"
1821,"""Sinoma Energy Conservation Ltd."" (the Company...",The Company held the 22nd session of the 2nd d...,['Shanghai Stock Exchange'],1008,2017-07-11,"July 11, 2017 Tuesday",,Copyright 2017 Shanghai Stock Exchange All Rig...,"[ENERGY DEVELOPMENT PROGRAMS, ENGINEERING, HOL...","[CHINA, PHILIPPINES, ASIA, EGYPT]","[SHANGHAI, CHINA, WUHAN, HUBEI, CHINA]",[],"[SIC8711 ENGINEERING SERVICES, SIC3241 CEMENT,...","[APO CEMENT CORP, CHINA NATIONAL MATERIALS CO ...",en,"[(Nantong Wanda Boiler Co. Ltd., (234, 263)), ...","[Nantong Wanda Boiler Co. Ltd., Philippine APO...","[(234, 263), (361, 394)]"
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1492128,£30m centre aiming to foster the next Facebook...,A£30M research centre that aims to ensure the ...,"['The Journal (Newcastle, UK)']",3064,2017-09-19,"September 19, 2017 Tuesday",COREENA FORD,Copyright 2017 Newcastle Chronicle & Journal L...,"[INTERNET SOCIAL NETWORKING, ARCHITECTURE, ASS...","[UNITED KINGDOM, ENGLAND]","[NEWCASTLE UPON TYNE, ENGLAND]",[],[NAICS519130 INTERNET PUBLISHING & BROADCASTIN...,"[FACEBOOK INC, GOOGLE INC]",en,"[(Google, (51, 57)), (Facebook, (61, 69)), (Th...","[Google, Facebook, The National Innovation Cen...","[(51, 57), (61, 69), (120, 159), (251, 273), (..."
1492515,Â BreadTalk plans 200 stores in 7 years,"BTM (Thailand) Ltd, a joint venture between Si...",['The Nation (Thailand)'],3332,2017-06-01,"June 1, 2017 Thursday",,"Copyright 2017 Nation News Network Co., Ltd. A...","[BUSINESS OPERATIONS, JOINT VENTURES, MANAGERS...","[THAILAND, SINGAPORE, MYANMAR, MALAYSIA, BRUNE...","[BANGKOK, THAILAND]",[],[NAICS531110 LESSORS OF RESIDENTIAL BUILDINGS ...,"[MINOR FOOD GROUP PCL, AL MUDON INTERNATIONAL ...",en,"[(BTM (Thailand) Ltd, (0, 18)), (BreadTalk Co,...","[BTM (Thailand) Ltd, BreadTalk Co, Minor Food ...","[(0, 18), (60, 72), (77, 93), (109, 118), (188..."
1492517,Â Sansiri partners with Tokyu on Bt2 bn Ekkam...,LISTED property firm Sansiri Plc has set up a ...,['The Nation (Thailand)'],3960,2017-08-09,"August 9, 2017 Wednesday",,"Copyright 2017 Nation News Network Co., Ltd. A...","[HOLDING COMPANIES, JOINT VENTURES, MANAGERS &...","[JAPAN, THAILAND]","[BANGKOK, THAILAND, TOKYO, JAPAN]",[],"[NAICS485112 COMMUTER RAIL SYSTEMS, NAICS23622...","[TOKYU CORP, SANSIRI PCL, TOKYU LIVABLE INC, B...",en,"[(Sansiri Plc, (21, 32)), (Siri TK One Co Ltd,...","[Sansiri Plc, Siri TK One Co Ltd, Tokyu Corp, ...","[(21, 32), (66, 84), (103, 113), (439, 448), (..."
1492884,Čibuk 1 Wind Farm To Receive EUR215.4mn Loan,News: The European Bank for Reconstruction and...,['Business Monitor Online'],1186,2017-10-20,"October 20, 2017 Friday",,Copyright 2017 Business Monitor International ...,"[INTERNATIONAL ECONOMIC ORGANIZATIONS, DEVELOP...","[SERBIA, EUROPE]","[BELGRADE, SERBIA]",[],[NAICS551112 OFFICES OF OTHER HOLDING COMPANIE...,"[INTESA SANPAOLO SPA, GENERAL ELECTRIC CO]",en,[(The European Bank for Reconstruction and Dev...,[The European Bank for Reconstruction and Deve...,"[(6, 58), (60, 64), (70, 107), (222, 226), (31..."


In [158]:
import torch.nn.functional as F

doc = df.sample()

text = doc.document.iloc[0]
spans = doc.spans.iloc[0]  # character-based entity spans 

inputs = tokenizer(text, entity_spans=spans, padding="max_length", truncation=True, return_tensors="pt",
                            max_length=MAX_LEN)
inputs.to(device)

outputs = model.model(**inputs)
logits = outputs.logits
predicted_classes = (torch.sigmoid(logits)>THRESHOLDS).float()
print("Sentence:", text)
print("Companies: ", doc.firms.iloc[0])
print("Tokens:", tokenizer.decode(inputs["input_ids"][0]))
print("Predicted class idx:", ids2labels(predicted_classes.squeeze().tolist()))
print("Confidence:", torch.sigmoid(logits).squeeze().tolist())

Sentence: Toronto:Algonquin Power & Utilities Corp.   has issued the following news release: Algonquin Power & Utilities Corp. (TSX: AQN, NYSE: AQN) (“APUC” or “Algonquin”) announced today that it has entered into an agreement to create a joint venture (“AAGES”) with Seville, Spain-based Abengoa, S.A. (MCE: ABG) (“Abengoa”) to identify, develop, and construct clean energy and water infrastructure assets with a global focus. Concurrently with the creation of the AAGES joint venture, APUC has entered into a definitive agreement to purchase from Abengoa a 25% equity interest in Atlantica Yield plc (NASDAQ: ABY) (“Atlantica”) for a total purchase price of ~US $608 million, based on a price of US $24.25 per ordinary share of Atlantica plus a contingent payment of up to US $0.60 per-share payable two year
Companies:  ['NYSE', 'MCE']
Tokens: <s>Toronto:Algonquin Power & Utilities Corp.   has issued the following news release: Algonquin Power & Utilities Corp. (TSX: AQN, <ent>  NYSE <ent> : AQ