In [1]:
!nvidia-smi -L

GPU 0: NVIDIA GeForce RTX 3090 (UUID: GPU-3e63a993-41e8-47d9-6b2c-2a5f9f227925)


In [2]:
!ls

'Firm_relation_extraction_LUKE (7).ipynb'		        onstart.log
 LUKE							        onstart.sh
 lr_find_temp_model_75555ed7-5ed1-4cb6-bab9-010cd54dd15d.ckpt   wandb


In [3]:
!rm -r LUKE

In this notebook, we are going to fine-tune [`LukeForEntityPairClassification`](https://huggingface.co/transformers/model_doc/luke.html#lukeforentitypairclassification) on a supervised **relation extraction** dataset.

The goal for the model is to predict, given a sentence and the character spans of two entities within the sentence, the relationship between the entities.

* Paper: https://arxiv.org/abs/2010.01057
* Original repository: https://github.com/studio-ousia/luke

In [4]:
!pip install -q transformers 
!pip install -q pandas
!pip install -q sklearn
!pip install -q pytorch-lightning wandb



In [5]:
from transformers import LukeTokenizer
from torch.utils.data import Dataset, DataLoader
import torch
import pandas as pd

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [7]:
rs = 42

## Read in data

Let's download the data from the web, hosted on Dropbox.

Each row in the dataframe consists of a news article. 

In [8]:
df = pd.read_pickle("https://www.dropbox.com/s/rdwg8d76ytqqgdy/training-data-29-03-22.pkl?dl=1")
df

Unnamed: 0,Date,source,document,firms,spans,rels
233892045,1985-01-01,Thomson SDC alliances - Deal Number 233892045,Accel Capital L.P. invested an undisclosed amo...,"[Accel Capital, Thalamus Electronics Inc]","[(0, 13), (53, 77)]","[StrategicAlliance, Manufacturing]"
216749045,1985-01-01,Thomson SDC alliances - Deal Number 216749045,Standard Microsystems Corp. and Mitsubishi Ele...,"[Standard Microsystems Corp, Mitsubishi Electr...","[(0, 26), (32, 56)]","[StrategicAlliance, Licensing]"
234987045,1985-01-01,Thomson SDC alliances - Deal Number 234987045,"Elf Technologies, Inc., the venture capital in...","[Elf Technologies, Santa Clara Systems]","[(0, 16), (169, 188)]",[StrategicAlliance]
216600045,1985-01-01,Thomson SDC alliances - Deal Number 216600045,Stauffer Chemical Co. has been licensed by Mit...,"[Stauffer Chemical Co, Himont Inc]","[(0, 20), (79, 89)]","[StrategicAlliance, Marketing, Manufacturing, ..."
2410877045,1985-01-01,Thomson SDC alliances - Deal Number 2410877045,The Romanian Government (RG) and Control Data ...,"[Romania, Control Data Corp]","[(4, 11), (33, 50)]","[JointVenture, Manufacturing]"
...,...,...,...,...,...,...
257823,NaT,Reuters News Dataset - Article ID http://www.b...,University of Notre Dame football radio anal...,"[Notre Dame, Notre Dame’s]","[(15, 25), (200, 212)]",[]
416616,NaT,Reuters News Dataset - Article ID http://www.b...,"Schaeffler AG, the industrial-bearing maker th...","[Schaeffler AG, Continental AG]","[(0, 13), (91, 105)]",[]
328325,NaT,Reuters News Dataset - Article ID http://www.r...,Barnes & Noble Inc ( BKS.N ) said on Friday it...,"[Barnes & Noble, Lynch]","[(0, 14), (113, 118)]",[]
121959,NaT,Reuters News Dataset - Article ID http://www.b...,(Corrects third paragraph to show output of 25...,"[First Solar Inc ., Edison International]","[(154, 171), (249, 269)]",[]


Shuffle the data:

In [9]:
df = df.sample(frac=1, random_state=rs)

Let's one hot encode the relationship classes.

In [10]:
rel_frequencies = df.rels.explode().value_counts()
rel_label_names = rel_frequencies.index.to_list()

def label2ids(labels, rel_label_names=rel_label_names):
  return [1 if rel_name in labels else 0 for rel_name in rel_label_names]
df['rel_one_hot'] = df.rels.apply(label2ids)

def ids2labels(ids, rel_label_names=rel_label_names):
  labels = []
  for idx, label in enumerate(rel_label_names):
    if ids[idx] == 1:
      labels.append(label)
    
  return labels

df.sample(10)

Unnamed: 0,Date,source,document,firms,spans,rels,rel_one_hot
29812,NaT,Reuters News Dataset - Article ID http://www.r...,Shares of large-cap technology companies fell ...,"[Intel, Microsoft Corp]","[(99, 104), (200, 214)]",[],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
1875301045,2007-05-30,Thomson SDC alliances - Deal Number 1875301045,PSA Corp Ltd (PC) and International Port Holdi...,"[PSA Corp Ltd, International Port Holdings]","[(0, 12), (22, 49)]","[JointVenture, Pending]","[0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]"
3593976045,2020-06-23,Thomson SDC alliances - Deal Number 3593976045,Processminer Inc and Litmus Automation Inc for...,"[Processminer Inc, Litmus Automation Inc]","[(0, 16), (21, 42)]",[StrategicAlliance],"[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
1033381045,1997-10-13,Thomson SDC alliances - Deal Number 1033381045,Select Media Communications and Edu-Active agr...,"[Select Media Communications, Edu-Active]","[(0, 27), (32, 42)]","[StrategicAlliance, Pending]","[1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]"
426734045,1994-09-16,Thomson SDC alliances - Deal Number 426734045,"Campenon Bernard SGE, a unit of Societe Genera...","[Campenon Bernard, Balkan Airlines]","[(0, 16), (129, 144)]",[JointVenture],"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
69559,NaT,Reuters News Dataset - Article ID http://www.b...,Carlyle Group and TPG Capital agreed to buy Au...,"[Carlyle Group, TPG Capital]","[(0, 13), (18, 29)]",[],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
214987045,1990-09-10,Thomson SDC alliances - Deal Number 214987045,Samsung Electronics Co Ltd (SE) and Teradyne I...,"[Samsung Electronics Co Ltd, Teradyne Inc]","[(0, 26), (36, 48)]","[StrategicAlliance, ResearchandDevelopment]","[1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]"
199784,NaT,Reuters News Dataset - Article ID http://www.b...,Home-video spending by U.S. consumers rose 2.5...,"[Digital Entertainment Group, DEG]","[(178, 205), (300, 303)]",[],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
3212704045,2017-10-25,Thomson SDC alliances - Deal Number 3212704045,Cisco Systems Inc and Google Inc formed a stra...,"[Cisco Systems Inc, Google]","[(0, 17), (22, 28)]","[StrategicAlliance, TechnologyTransfer]","[1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]"
217947045,1991-09-16,Thomson SDC alliances - Deal Number 217947045,Retix Corp and BDS Corp signed an agreement wh...,"[Retix Corp, BDS]","[(0, 10), (15, 18)]","[StrategicAlliance, Marketing]","[1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]"


This is the frequency of our relation types:



In [11]:
rel_frequencies

StrategicAlliance         83845
JointVenture              44556
Pending                   38916
Marketing                 22478
Manufacturing             19926
ResearchandDevelopment    13568
Licensing                 12857
TechnologyTransfer         9889
Supply                     5407
Exploration                3152
Terminated                 1578
Name: rels, dtype: int64

To deal with the imbalance, we use weights in the loss function proportional to the inverse relative frequency of each class, normalizing by the frequency of the most common class. The resulting weights are as follows:

In [12]:
class_weights = 1/(rel_frequencies/rel_frequencies.iloc[0])
class_weights

StrategicAlliance          1.000000
JointVenture               1.881789
Pending                    2.154512
Marketing                  3.730092
Manufacturing              4.207819
ResearchandDevelopment     6.179614
Licensing                  6.521350
TechnologyTransfer         8.478613
Supply                    15.506751
Exploration               26.600571
Terminated                53.133714
Name: rels, dtype: float64

The strategic alliance label is 7 times more frequent than the R&D label, therefore we weight the R&D class 7 times heavier in the loss function.

In [13]:
class_weights = torch.tensor(class_weights.to_list()).to(device)

We can even increase the overweighting:

In [14]:
IMBALANCE_OVERWEIGHTING = 4
class_weights = class_weights**IMBALANCE_OVERWEIGHTING

## Define the PyTorch dataset and dataloaders

Next, we define regular PyTorch datasets and corresponding dataloaders. In PyTorch, you need to define a `Dataset` class that inherits from `torch.utils.data.Dataset`, and you need to implement 3 methods: the `init` method (for initializing the dataset with data), the `len` method (which returns the number of elements in the dataset) and the `getitem()` method, which returns a single item from the dataset.

In our case, each item of the dataset consists of a sentence, the spans of 2 entities in the sentence, and a label of the relationship. We use `LukeTokenizer` (available in the Transformers library) to turn these into the inputs expected by the model, which are `input_ids`, `entity_ids`, `attention_mask`, `entity_attention_mask` and `entity_position_ids`.

For more information regarding these inputs, refer to the [docs](https://huggingface.co/transformers/model_doc/luke.html#lukeforentitypairclassification) of `LukeForEntityPairClassification`.


Let's set our hyperparameters:

In [15]:
MAX_LEN = 128
LEARNING_RATE = 1e-05
BATCH_SIZE = 128
THRESHOLD = 0.5
MAX_EPOCHS = 1

In [16]:
tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", task="entity_pair_classification")

class RelationExtractionDataset(Dataset):
    """Relation extraction dataset."""

    def __init__(self, data):
        """
        Args:
            data : Pandas dataframe.
        """
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data.iloc[idx]

        text = item.document

        encoding = tokenizer(text, entity_spans=item.spans, padding="max_length", truncation=True, return_tensors="pt",
                            max_length=MAX_LEN)

        for k,v in encoding.items():
          encoding[k] = encoding[k].squeeze()

        encoding["label"] = torch.tensor(item.rel_one_hot)

        return encoding

Here we instantiate the class defined above with 3 objects: a training dataset, a validation dataset and a test set.

In [17]:
from sklearn.model_selection import train_test_split

train_df, test_df = train_test_split(df, test_size=0.1, random_state=42, shuffle=True)
train_df, val_df = train_test_split(train_df, test_size=0.1, random_state=42, shuffle=False)

print("FULL Dataset: {}".format(len(df)))
print("TRAIN Dataset: {}".format(len(train_df)))
print("TEST Dataset: {}".format(len(test_df)))
print("VALIDATION Dataset: {}".format(len(val_df)))

# define the dataset
train_dataset = RelationExtractionDataset(data=train_df)
valid_dataset = RelationExtractionDataset(data=val_df)
test_dataset = RelationExtractionDataset(data=test_df)

FULL Dataset: 264928
TRAIN Dataset: 214591
TEST Dataset: 26493
VALIDATION Dataset: 23844


In [18]:
train_dataset[0].keys()

dict_keys(['input_ids', 'entity_ids', 'entity_position_ids', 'attention_mask', 'entity_attention_mask', 'label'])

## Define a PyTorch LightningModule

Let's define the model as a PyTorch LightningModule. A `LightningModule` is actually an `nn.Module`, but with some extra functionality.

For more information regarding how to define this, see the [docs](https://pytorch-lightning.readthedocs.io/en/latest/?_ga=2.56317931.1395871250.1622709933-1738348008.1615553774) of PyTorch Lightning.

In [19]:
import torch, gc
gc.collect()
torch.cuda.empty_cache()

In [20]:
from transformers import LukeForEntityPairClassification, AdamW
import pytorch_lightning as pl
from sklearn import metrics

class LUKE(pl.LightningModule):

    def __init__(self, lr=LEARNING_RATE, batch_size=BATCH_SIZE):
        super().__init__()
        self.model = LukeForEntityPairClassification.from_pretrained("studio-ousia/luke-base", num_labels=len(rel_label_names))
        self.save_hyperparameters()

    def forward(self, input_ids, entity_ids, entity_position_ids, attention_mask, entity_attention_mask):     
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, entity_ids=entity_ids, 
                             entity_attention_mask=entity_attention_mask, entity_position_ids=entity_position_ids)
        return outputs
    
    def common_step(self, batch, batch_idx):
        labels = batch['label'].float()
        del batch['label']
        outputs = self(**batch)
        logits = outputs.logits
        
        criterion = torch.nn.BCEWithLogitsLoss(weight=class_weights) # multi-label classification with weighted classes
        loss = criterion(logits, labels)
        preds = (torch.sigmoid(logits)>THRESHOLD).float()

        return {'loss': loss, 'preds': preds, 'labels': labels}
      
    def training_step(self, batch, batch_idx):
        output = self.common_step(batch, batch_idx)
        loss = output['loss']
        # logs metrics for each training_step,
        # and the average across the epoch
        self.log("training_loss", loss)

        return loss

    def validation_step(self, batch, batch_idx):
        output = self.common_step(batch, batch_idx)     
        loss = output['loss']
        self.log("val_loss", loss, on_epoch=True)
        
        preds = output['preds']
        labels = output['labels']
        
        return {"loss": loss, "preds": preds, "labels": labels}
    
    def validation_epoch_end(self, outputs):
            for i, output in enumerate(outputs):
                preds = output["preds"].detach().cpu().numpy()
                labels = output["labels"].detach().cpu().numpy()
                loss = output["loss"].mean()

                for idx, label_name in enumerate(rel_label_names):
                    label = [label[idx] for label in labels]
                    pred = [pred[idx] for pred in preds]

                    precision = metrics.precision_score(label, pred)
                    recall = metrics.recall_score(label, pred)
                    f1 = metrics.f1_score(label, pred)
                    self.log(f'val_precision_{label_name}', precision, prog_bar=True)
                    self.log(f'val_recall_{label_name}', recall, prog_bar=True)
                    self.log(f'val_f1_{label_name}', f1, prog_bar=True)
        
    def test_step(self, batch, batch_idx):
        output = self.common_step(batch, batch_idx)     
        loss = output['loss']
        self.log("test_loss", loss, on_epoch=True)
        
        preds = output['preds']
        labels = output['labels']
        
        return {"loss": loss, "preds": preds, "labels": labels}
    
    def test_epoch_end(self, outputs):
            for i, output in enumerate(outputs):
                preds = output["preds"].detach().cpu().numpy()
                labels = output["labels"].detach().cpu().numpy()
                loss = output["loss"].mean()

                for idx, label_name in enumerate(rel_label_names):
                    label = [label[idx] for label in labels]
                    pred = [pred[idx] for pred in preds]
                    precision = metrics.precision_score(label, pred)
                    recall = metrics.recall_score(label, pred)
                    f1 = metrics.f1_score(label, pred)
                    self.log(f'test_precision_{label_name}', precision)
                    self.log(f'test_recall_{label_name}', recall)
                    self.log(f'test_f1_{label_name}', f1, prog_bar=True)
        
    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.hparams.lr)
        return optimizer

    def train_dataloader(self):
        return DataLoader(train_dataset, batch_size=self.hparams.batch_size, shuffle=True, num_workers=16)

    def val_dataloader(self):
        return DataLoader(valid_dataset, batch_size=self.hparams.batch_size, num_workers=16)

    def test_dataloader(self):
        return DataLoader(test_dataset, batch_size=self.hparams.batch_size, num_workers=16)

    
model = LUKE()

Some weights of the model checkpoint at studio-ousia/luke-base were not used when initializing LukeForEntityPairClassification: ['embeddings.position_ids']
- This IS expected if you are initializing LukeForEntityPairClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LukeForEntityPairClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LukeForEntityPairClassification were not initialized from the model checkpoint at studio-ousia/luke-base and are newly initialized: ['classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Train the model

Let's train the model. We also use early stopping, to avoid overfitting the training dataset. We also log everything to Weights and Biases, which will give us beautiful charts of the loss and accuracy plotted over time.

If you haven't already, you can create an account on the [website](https://wandb.ai/site), then log in in a web browser, and run the cell below: 

In [21]:
import wandb

wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mjakr[0m (use `wandb login --relogin` to force relogin)


True

In [22]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping

wandb_logger = WandbLogger(name='luke-class-weights', project='LUKE')

# for early stopping, see https://pytorch-lightning.readthedocs.io/en/1.0.0/early_stopping.html?highlight=early%20stopping
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=5,
    strict=False,
    verbose=False,
    mode='min'
)

checkpoint_callback = ModelCheckpoint(dirpath='LUKE')

trainer = Trainer(gpus=1, logger=wandb_logger, callbacks=[EarlyStopping(monitor='val_loss'), checkpoint_callback], 
                  max_epochs=MAX_EPOCHS, precision=16, 
                  stochastic_weight_avg=True, auto_lr_find=True, benchmark=True, deterministic=True)
                    # fast_dev_run=True
                    # limit_train_batches=0.25

NameError: name 'ModelCheckpoint' is not defined

Finding the optimal learning rater:

In [None]:
lr_finder = trainer.tuner.lr_find(model)

In [None]:
lr_finder.suggestion()

In [None]:
model.hparams.lr = lr_finder.suggestion()

Let's train the model:

In [None]:
trainer.fit(model)

In [None]:
trainer.test()

In [None]:
checkpoint_callback.best_model_path

In [32]:
!pip -q install dropbox



In [3]:
import dropbox
import os
from tqdm import tqdm

access_token = 'sl.BEuonkM5dN24jg8sgsMvxGGn_Sjdn0Is3DJpg2CRKpUgJQnkzkat6QU9OxNC59hnEK-JAACG4aGKNbNL5FiBfNRmtd1R6RpSM8Lb9Z720LFNpCuNKruvreOM3JE4qYDydyfXh9y8XTE'

file_from = 'LUKE/f4zydpek/checkpoints/epoch=4-step=8384.ckpt'
file_to = '/trained_models/29-3-22-17h-epoch=4-step=8384.ckpt'  # The full path to upload the file to, including the file name

def upload(
    access_token,
    file_path,
    target_path,
    timeout=900,
    chunk_size=4 * 1024 * 1024,
):
    dbx = dropbox.Dropbox(access_token, timeout=timeout)
    with open(file_path, "rb") as f:
        file_size = os.path.getsize(file_path)
        if file_size <= chunk_size:
            print(dbx.files_upload(f.read(), target_path))
        else:
            with tqdm(total=file_size, desc="Uploaded") as pbar:
                upload_session_start_result = dbx.files_upload_session_start(
                    f.read(chunk_size)
                )
                pbar.update(chunk_size)
                cursor = dropbox.files.UploadSessionCursor(
                    session_id=upload_session_start_result.session_id,
                    offset=f.tell(),
                )
                commit = dropbox.files.CommitInfo(path=target_path)
                while f.tell() < file_size:
                    if (file_size - f.tell()) <= chunk_size:
                        print(
                            dbx.files_upload_session_finish(
                                f.read(chunk_size), cursor, commit
                            )
                        )
                    else:
                        dbx.files_upload_session_append(
                            f.read(chunk_size),
                            cursor.session_id,
                            cursor.offset,
                        )
                        cursor.offset = f.tell()
                    pbar.update(chunk_size)

upload(
    access_token,
    file_from,
    file_to,
    timeout=900,
    chunk_size=100 * 1024 * 1024,
)

Uploaded: 3355443200it [04:07, 13533766.16it/s]                    

FileMetadata(client_modified=datetime.datetime(2022, 3, 29, 15, 40, 29), content_hash='a3dda202b0625ed24a81318cbb5005cd53468219babefa223bdc7d3fc2c24ffa', export_info=NOT_SET, file_lock_info=NOT_SET, has_explicit_shared_members=NOT_SET, id='id:c181gNQpVocAAAAAAAAfXQ', is_downloadable=True, media_info=NOT_SET, name='29-3-22-17h-epoch=4-step=8384.ckpt', parent_shared_folder_id=NOT_SET, path_display='/trained_models/29-3-22-17h-epoch=4-step=8384.ckpt', path_lower='/trained_models/29-3-22-17h-epoch=4-step=8384.ckpt', property_groups=NOT_SET, rev='5db5d3fa364f00557fa2f', server_modified=datetime.datetime(2022, 3, 29, 15, 40, 29), sharing_info=NOT_SET, size=3289782593, symlink_info=NOT_SET)





## Evaluation

Instead of calling `trainer.test()`, we can also manually evaluate the model on the entire test set:

In [19]:
# from tqdm.notebook import tqdm
# from sklearn.metrics import accuracy_score

# model = LUKE.load_from_checkpoint(checkpoint_path="LUKE/14udvobu/checkpoints/epoch=3-step=2333.ckpt")

# model.to(device)

# model.model.eval()

# predictions_total = []
# logits_total = []
# labels_total = []

# for batch in tqdm(model.test_dataloader()):
#     # get the inputs;
#     labels = batch["label"]
#     del batch["label"]

#     # move everything to the GPU
#     for k,v in batch.items():
#       batch[k] = batch[k].to('cuda')

#     # forward pass
#     outputs = model.model(**batch)
#     logits = outputs.logits
#     logits_total.extend(torch.sigmoid(logits).tolist())
#     predictions = (torch.sigmoid(logits)>THRESHOLD).float()
#     predictions_total.extend(predictions.tolist())
#     labels_total.extend(labels.tolist())

Some weights of the model checkpoint at studio-ousia/luke-base were not used when initializing LukeForEntityPairClassification: ['embeddings.position_ids']
- This IS expected if you are initializing LukeForEntityPairClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LukeForEntityPairClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LukeForEntityPairClassification were not initialized from the model checkpoint at studio-ousia/luke-base and are newly initialized: ['classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/146 [00:00<?, ?it/s]

1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
2 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
2 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens
1 entities are ignored because their entity spans are invalid due to the truncation of input tokens


In [20]:
# for idx, label_name in enumerate(rel_label_names):
#     labels = [label[idx] for label in labels_total]
#     predictions = [pred[idx] for pred in predictions_total]
#     logits = [logit[idx] for logit in logits_total]
#     precision = metrics.precision_score(labels, predictions)
#     recall = metrics.recall_score(labels, predictions)
#     f1 = metrics.f1_score(labels, predictions)
#     roc_auc = metrics.roc_auc_score(labels, logits)
#     print(label_name)
#     print(f'Precision: {precision:.3f}, Recall: {recall:.3f}')
#     print(f'F1-Score: {f1:.3f}, ROC-AUC: {roc_auc:.3f}\n')

StrategicAlliance
Precision: 0.921, Recall: 0.993
F1-Score: 0.956, ROC-AUC: 0.980

JointVenture
Precision: 0.979, Recall: 0.988
F1-Score: 0.983, ROC-AUC: 0.998

Marketing
Precision: 0.821, Recall: 0.725
F1-Score: 0.770, ROC-AUC: 0.943

Licensing
Precision: 0.752, Recall: 0.859
F1-Score: 0.802, ROC-AUC: 0.967

Manufacturing
Precision: 0.816, Recall: 0.809
F1-Score: 0.812, ROC-AUC: 0.972

ResearchandDevelopment
Precision: 0.740, Recall: 0.784
F1-Score: 0.761, ROC-AUC: 0.964



## Inference

Here we test the trained model on a new, unseen sentence.

In [77]:
import torch.nn.functional as F

test_doc = test_df[test_df.relation.str.len() == 0].sample()
text = test_doc.document.iloc[0]
entity_spans = test_doc.entity_spans.iloc[0]  # character-based entity spans 

inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")
inputs.to('cuda')

outputs = model.model(**inputs)
logits = outputs.logits
predicted_classes = (torch.sigmoid(logits)>THRESHOLD).float()
print("Sentence:", text)
print("Ground truth label:", ids2labels(test_doc.rel_one_hot.iloc[0]))
print("Predicted class idx:", ids2labels(predicted_classes.squeeze().tolist()))
print("Confidence:", torch.sigmoid(logits).squeeze().tolist())

Sentence: Jobs rut tips scales in favor of Fed stimulus. The Federal Reserve looks set to launch a third round of bond purchases this week to try to drive borrowing costs lower and breathe more life into an economy that is not growing fast enough to lower unemployment. Despite political opposition and some internal dissent, economists said a weak report on jobs growth for August was likely enough to convince the U.S.
Ground truth label: []
Predicted class idx: ['StrategicAlliance']
Confidence: [0.7804471254348755, 0.0005446382565423846, 0.005502107087522745, 0.0012131521943956614, 0.0020684748888015747, 0.0006311357719823718]


In [62]:
test_df

Unnamed: 0,document,entities,entity_spans,relation,rel_one_hot
289956,Mexican President Calderon Signs Labor-System ...,"[Calderon, Felipe Calderon]","[(18, 26), (84, 99)]",[],"[0, 0, 0, 0, 0, 0]"
13769,Volvo AB and Nvidia Corp formed a strategic al...,"[Volvo AB, Nvidia Corp]","[(0, 8), (13, 24)]",[StrategicAlliance],"[1, 0, 0, 0, 0, 0]"
44305,Generex Biotechnology Corp (GX) and MedGen Cor...,"[Generex Biotechnology Corp, MedGen Corp]","[(0, 26), (36, 47)]","[Licensing, Marketing, StrategicAlliance]","[1, 0, 1, 1, 0, 0]"
2110,Walgreen Co and Centura Health formed a strate...,"[Walgreen Co, Centura Health]","[(0, 11), (16, 30)]",[StrategicAlliance],"[1, 0, 0, 0, 0, 0]"
34642,Beeline.com Inc and Talentnet Inc formed a str...,"[Beeline.com Inc, Talentnet Inc]","[(0, 15), (20, 33)]",[StrategicAlliance],"[1, 0, 0, 0, 0, 0]"
...,...,...,...,...,...
45585,Genomenon Inc and Congenica Ltd formed a strat...,"[Genomenon Inc, Congenica Ltd]","[(0, 13), (18, 31)]","[ResearchandDevelopment, StrategicAlliance]","[1, 0, 0, 0, 0, 1]"
29225,Hashedin Technologies Pvt Ltd and Snowflake In...,"[Hashedin Technologies Pvt Ltd, Snowflake Inc]","[(0, 29), (34, 47)]",[StrategicAlliance],"[1, 0, 0, 0, 0, 0]"
28479,Tera Systems Inc (TS) and ED&C Ltd (ED) formed...,"[Tera Systems Inc, ED&C Ltd]","[(0, 16), (26, 34)]","[Marketing, StrategicAlliance]","[1, 0, 1, 0, 0, 0]"
171107,Bank of Montreal Posts Record Profit on Boost ...,"[Bank of Montreal Posts Record Profit, Bank of...","[(0, 36), (74, 90)]",[],"[0, 0, 0, 0, 0, 0]"
