# Relation extraction with BERT

---

The goal of this repo is to show how to use [BERT](https://arxiv.org/abs/1810.04805)
to [extract relation](https://en.wikipedia.org/wiki/Relationship_extraction) from text.

Used libraries:
- [Transformers](https://huggingface.co/transformers/index.html)
- [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/)

Used datasets:
- SemEval 2010 Task 8 - [paper](https://arxiv.org/pdf/1911.10422.pdf) - [download](https://github.com/sahitya0000/Relation-Classification/blob/master/corpus/SemEval2010_task8_all_data.zip?raw=true)
-  Google IISc Distant Supervision (GIDS) - [paper](https://arxiv.org/pdf/1804.06987.pdf) - [download](https://drive.google.com/open?id=1gTNAbv8My2QDmP-OHLFtJFlzPDoCG4aI)

## Install dependencies

This project uses [Python 3.7+](https://www.python.org/downloads/release/python-378/)

In [2]:
!conda install -c conda-forge gdown --yes
!pip install requests==2.24.0 numpy==1.19.0 pandas==1.0.5 \
    scikit-learn==0.23.1 pytorch-lightning==0.8.4 torch==1.5.1 \
    transformers==3.0.2 sklearn==0.0 tqdm==4.47.0




## Import needed modules

In [1]:
import json
import os
import pickle
import shutil
import zipfile
from abc import ABC, abstractmethod
from typing import Tuple
from urllib.parse import urlparse

import gdown
import requests
import torch
from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer as LightningTrainer
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch import Tensor, nn
from torch.nn import functional as F
from torch.optim import AdamW
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
from transformers import *

## Define constants

In [2]:
# --- Directory ---
ROOT_DIR = os.path.abspath('.')
RAW_DATA_DIR = os.path.join(ROOT_DIR, 'data/raw')
PROCESSED_DATA_DIR = os.path.join(ROOT_DIR, 'data/processed') 
CHECKPOINT_DIR = os.path.join(ROOT_DIR, 'checkpoint')

# --- Datasets ---
DATASET_MAPPING = {
    'SemEval2010Task8': {
        'dir': os.path.join(RAW_DATA_DIR,'SemEval2010_task8_all_data'),
        'url': 'https://github.com/sahitya0000/Relation-Classification/'
               'blob/master/corpus/SemEval2010_task8_all_data.zip?raw=true',
        'num_classes': 10,
    },
    'GIDS': {
        'dir': os.path.join(RAW_DATA_DIR,'gids_data'),
        'url': 'https://drive.google.com/uc?id=1gTNAbv8My2QDmP-OHLFtJFlzPDoCG4aI&export=download',
        'num_classes': 5
    }
}
DATASET_NAME = 'SemEval2010Task8'

# --- BERT ---
SUB_START_CHAR = '{'
SUB_END_CHAR = '}'
OBJ_START_CHAR = '['
OBJ_END_CHAR = ']'

# --- BERT Model ---
# See https://huggingface.co/transformers/pretrained_models.html for the full list

BERT_VARIANT_MAPPING = {
    'bert': {
        'model': BertModel,
        'tokenizer': BartTokenizer,
        'pretrain_weight': 'bert-base-uncased',
        'available_pretrain_weights': ['bert-base-uncased', 'bert-base-cased']
    },
    'distilbert': {
        'model': DistilBertModel,
        'tokenizer': DistilBertTokenizer,
        'pretrain_weight': 'distilbert-base-uncased',
        'available_pretrain_weights': ['distilbert-base-uncased', 'distilbert-base-cased']
    },
    'roberta': {
        'model': RobertaModel,
        'tokenizer': RobertaTokenizer,
        'pretrain_weight': 'roberta-base',
        'available_pretrain_weights': ['roberta-base', 'distilroberta-base']
    },
}
BERT_VARIANT = 'distilbert'

## Create subdirectories

In [3]:
if not os.path.exists(RAW_DATA_DIR):
    print("Creating raw data directory " + RAW_DATA_DIR)
    os.makedirs(RAW_DATA_DIR)

if not os.path.exists(PROCESSED_DATA_DIR):
    print("Creating processed data directory " + PROCESSED_DATA_DIR)
    os.makedirs(PROCESSED_DATA_DIR)

if not os.path.exists(CHECKPOINT_DIR):
    print(f"Creating checkpoint directory "+ CHECKPOINT_DIR)
    os.makedirs(CHECKPOINT_DIR)

## Download data

First, we define some download util functions:

In [4]:
def download_from_url(url: str, save_path: str, chunk_size: int = 2048) -> None:
    with open(save_path, "wb") as f:
        print(f"Downloading...\nFrom: {url}\nTo: {save_path}")
        response = requests.get(url, stream=True)
        for data in tqdm(response.iter_content(chunk_size=chunk_size)):
            f.write(data)

def download_from_google_drive(url: str, save_path: str) -> None:
    gdown.download(url, save_path, use_cookies=False)

def extract_zip(zip_file_path: str, extract_dir: str, remove_zip_file=True):
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        print("Extracting to " + extract_dir)
        zip_ref.extractall(extract_dir)

    if remove_zip_file:
        print("Removing zip file")
        os.unlink(zip_file_path)

The download function itself:

In [5]:
def download(dataset_name, dataset_url, dataset_dir, force_redownload: bool):
    print(f"\n---> Downloading dataset {dataset_name} <---")

    # check data has been downloaded
    if os.path.exists(dataset_dir):
        if force_redownload:
            print(f"Removing old raw data {dataset_dir}")
            shutil.rmtree(dataset_dir)
        else:
            print(f"Directory {dataset_dir} exists, skip downloading.")
            return


    # download
    tmp_file_path = os.path.join(RAW_DATA_DIR, dataset_name + '.zip')
    if urlparse(dataset_url).netloc == 'drive.google.com':
        download_from_google_drive(dataset_url, tmp_file_path)
    else:
        download_from_url(dataset_url, tmp_file_path)

    # unzip
    extract_zip(tmp_file_path, RAW_DATA_DIR)

Download all datasets:

In [6]:
for dataset_name, dataset_info in DATASET_MAPPING.items():
    download(
        dataset_name,
        dataset_url=dataset_info['url'],
        dataset_dir=dataset_info['dir'],
        force_redownload=False
    )


---> Downloading dataset SemEval2010Task8 <---
Directory /media/dthung1602/WORKING/bert-relation-extraction/data/raw/SemEval2010_task8_all_data exists, skip downloading.

---> Downloading dataset GIDS <---
Directory /media/dthung1602/WORKING/bert-relation-extraction/data/raw/gids_data exists, skip downloading.


## Preprocess

The abstract preprocessor

In [6]:
class AbstractPreprocessor(ABC):
    DATASET_NAME = ''

    def __init__(self, tokenizer: PreTrainedTokenizer):
        self.tokenizer = tokenizer

    def preprocess_data(self, reprocess: bool):
        print(f"\n---> Preprocessing {self.DATASET_NAME} dataset <---")

        # stop preprocessing if file existed
        pickled_file_names = [self.get_pickle_file_name(k) for k in ('train', 'val', 'test')]
        existed_files = [fn for fn in pickled_file_names if os.path.exists(fn)]
        if existed_files:
            file_text = "- " + "\n- ".join(existed_files)
            if not reprocess:
                print("The following files already exist:")
                print(file_text)
                print("Preprocessing is skipped. See option --reprocess.")
                return
            else:
                print("The following files will be overwritten:")
                print(file_text)

        self._preprocess_data()

    @abstractmethod
    def _preprocess_data(self):
        pass

    def _pickle_data(self, data, file_name):
        print(f"Saving to pickle file {file_name}")
        with open(file_name, 'wb') as f:
            pickle.dump(data, f)

    @classmethod
    def get_pickle_file_name(cls, key: str):
        return os.path.join(PROCESSED_DATA_DIR, f'{cls.DATASET_NAME.lower()}_{key}.pkl')

For each dataset, define a preprocessor:

In [8]:
class SemEval2010Task8Preprocessor(AbstractPreprocessor):
    DATASET_NAME = 'SemEval2010Task8'
    RAW_TRAIN_FILE_NAME = os.path.join(RAW_DATA_DIR,
                                       'SemEval2010_task8_all_data/SemEval2010_task8_training/TRAIN_FILE.TXT')
    RAW_TEST_FILE_NAME = os.path.join(RAW_DATA_DIR,
                                      'SemEval2010_task8_all_data/SemEval2010_task8_testing_keys/TEST_FILE_FULL.TXT')
    RAW_TRAIN_DATA_SIZE = 8000
    RAW_TEST_DATA_SIZE = 2717
    RANDOM_SEED = 2020
    VAL_DATA_PROPORTION = 0.2

    def _preprocess_data(self):
        print("Processing training data")
        train_data = self._get_data_from_file(
            self.RAW_TRAIN_FILE_NAME,
            self.RAW_TRAIN_DATA_SIZE
        )

        print("Processing test data")
        test_data = self._get_data_from_file(
            self.RAW_TEST_FILE_NAME,
            self.RAW_TEST_DATA_SIZE
        )

        print("Encoding labels to integers")
        le = LabelEncoder()
        le.fit(train_data['labels'])
        train_data['labels'] = le.transform(train_data['labels']).tolist()
        test_data['labels'] = le.transform(test_data['labels']).tolist()

        print("Splitting train & validate data")
        train_data, val_data = self._train_val_split(train_data)

        self._pickle_data(train_data, self.get_pickle_file_name('train'))
        self._pickle_data(val_data, self.get_pickle_file_name('val'))
        self._pickle_data(test_data, self.get_pickle_file_name('test'))

    def _train_val_split(self, original_data):
        k = list(original_data.keys())[0]
        indies = list(range(len(original_data[k])))
        train_indies, val_indies = train_test_split(
            indies,
            test_size=self.VAL_DATA_PROPORTION,
            random_state=self.RANDOM_SEED
        )
        train_data = {k: self._get_sample(v, train_indies) for k, v in original_data.items()}
        val_data = {k: self._get_sample(v, val_indies) for k, v in original_data.items()}

        return train_data, val_data

    def _get_sample(self, data, indies):
        return [data[i] for i in indies]

    def _get_data_from_file(self, file_name: str, dataset_size: int):
        raw_sentences = []
        labels = []
        with open(file_name) as f:
            for _ in tqdm(range(dataset_size)):
                raw_sentences.append(self._process_sentence(f.readline()))
                labels.append(self._process_label(f.readline()))
                f.readline()
                f.readline()
        data = self.tokenizer(raw_sentences, truncation=True, padding=True)
        data['labels'] = labels
        return data

    def _process_sentence(self, sentence: str):
        # TODO distinguish e1 e2 sub obj
        return sentence.split("\t")[1][1:-2] \
            .replace("<e1>", SUB_START_CHAR) \
            .replace("</e1>", SUB_END_CHAR) \
            .replace("<e2>", OBJ_START_CHAR) \
            .replace("</e2>", OBJ_END_CHAR)

    def _process_label(self, label: str):
        return label[:-8]


class GIDSPreprocessor(AbstractPreprocessor):
    DATASET_NAME = 'GIDS'
    RAW_TRAIN_FILE_NAME = os.path.join(RAW_DATA_DIR, 'gids_data/gids_train.json')
    RAW_VAL_FILE_NAME = os.path.join(RAW_DATA_DIR, 'gids_data/gids_dev.json')
    RAW_TEST_FILE_NAME = os.path.join(RAW_DATA_DIR, 'gids_data/gids_test.json')

    def _preprocess_data(self):
        print("Processing validate data")
        val_data = self._get_data_from_file(self.RAW_VAL_FILE_NAME)
        le = LabelEncoder()
        le.fit(val_data['labels'])
        val_data['labels'] = le.transform(val_data['labels']).tolist()
        self._pickle_data(val_data, self.get_pickle_file_name('val'))
        del val_data

        print("Processing train data")
        train_data = self._get_data_from_file(self.RAW_TRAIN_FILE_NAME)
        train_data['labels'] = le.transform(train_data['labels']).tolist()
        self._pickle_data(train_data, self.get_pickle_file_name('train'))
        del train_data
        
        print("Processing test data")
        test_data = self._get_data_from_file(self.RAW_TEST_FILE_NAME)
        test_data['labels'] = le.transform(test_data['labels']).tolist()
        self._pickle_data(test_data, self.get_pickle_file_name('test'))
        del test_data

    def _get_data_from_file(self, file_name: str):
        raw_sentences = []
        labels = []
        with open(file_name) as f:
            for line in tqdm(f.readlines()):
                dt = json.loads(line)
                sentence = " ".join(dt['sent'])

                # add subject markup
                new_sub = SUB_START_CHAR + dt['sub'].replace('_', '') + SUB_END_CHAR # TODO keep _ or not?
                new_obj = OBJ_START_CHAR + dt['obj'].replace('_', '') + OBJ_END_CHAR
                sentence = sentence.replace(dt['sub'], new_sub).replace(dt['obj'], new_obj)
                raw_sentences.append(sentence)
                labels.append(dt['rel'])
        data = self.tokenizer(raw_sentences, truncation=True, padding=True)
        data['labels'] = labels
        return data


---> Preprocessing SemEval2010Task8 dataset <---
Processing training data

Processing test data

Encoding labels to integers
Splitting train & validate data
Saving to pickle file /media/dthung1602/WORKING/bert-relation-extraction/data/processed/semeval2010task8_train.pkl
Saving to pickle file /media/dthung1602/WORKING/bert-relation-extraction/data/processed/semeval2010task8_val.pkl
Saving to pickle file /media/dthung1602/WORKING/bert-relation-extraction/data/processed/semeval2010task8_test.pkl


HBox(children=(FloatProgress(value=0.0, max=8000.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=2717.0), HTML(value='')))

Factory function to get preprocessor:

In [None]:
def get_preprocessor_class(dataset_name: str):
    return globals()[f'{dataset_name}Preprocessor']

def get_preprocessor(dataset_name: str)-> AbstractPreprocessor:
    bert_model_info = BERT_VARIANT_MAPPING[BERT_VARIANT]
    bert_pretrain_weight = bert_model_info['pretrain_weight']
    tokenizer = bert_model_info['tokenizer'].from_pretrained(bert_pretrain_weight)
    preprocessors_class = get_preprocessor_class(dataset_name)
    return preprocessors_class(tokenizer)

Preprocess data:

In [9]:
preprocessor = get_preprocessor('GIDS')
preprocessor.preprocess_data(reprocess=False)


---> Preprocessing GIDS dataset <---
Processing validate data

Saving to pickle file /media/dthung1602/WORKING/bert-relation-extraction/data/processed/gids_val.pkl
Processing train data

Saving to pickle file /media/dthung1602/WORKING/bert-relation-extraction/data/processed/gids_train.pkl
Processing test data

Saving to pickle file /media/dthung1602/WORKING/bert-relation-extraction/data/processed/gids_test.pkl


HBox(children=(FloatProgress(value=0.0, max=1864.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=11297.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=5663.0), HTML(value='')))

## Model

### Dataset

In [None]:
class GenericDataset(Dataset):

    def __init__(self, dataset_name: str, subset: str):
        preprocessor_class = get_preprocessor_class(dataset_name)
        if subset not in ['train', 'val', 'test']:
            raise ValueError('subset must be train, val or test')
        with open(preprocessor_class.get_pickle_file_name(subset), 'rb') as f:
            self.data = pickle.load(f)

    def __getitem__(self, index: int) -> Tuple[Tensor, Tensor, Tensor]:
        return (torch.tensor(self.data['input_ids'][index]),
                torch.tensor(self.data['attention_mask'][index]),
                torch.tensor(self.data['labels'][index]))

    def __len__(self) -> int:
        return len(self.data['label'])

### Torch Lightning Module

In [None]:
class BERTModule(LightningModule):

    def __init__(self, bert_variant, dataset_name, batch_size, learning_rate):
        print("---> Start building model <----")
        super().__init__()
        self.save_hyperparameters()

        bert_info = BERT_VARIANT_MAPPING[bert_variant]
        bert_model_class = bert_info['model']
        bert_pretrain_weight = bert_info['pretrain_weight']
        self.bert = bert_model_class.from_pretrained(bert_pretrain_weight, output_attentions=True)

        dataset_info = DATASET_MAPPING[dataset_name]
        self.num_classes = dataset_info['num_classes']
        self.linear = nn.Linear(self.bert.config.hidden_size, self.num_classes)

        print("Done building model\n")

    def on_train_start(self) -> None:
        print("\n---> Start training <----")

    def train_dataloader(self) -> DataLoader:
        return self.__get_dataloader('train')

    def val_dataloader(self) -> DataLoader:
        return self.__get_dataloader('val')

    def test_dataloader(self) -> DataLoader:
        return self.__get_dataloader('test')

    def __get_dataloader(self, subset: str) -> DataLoader:
        print(f"Loading {subset} data")
        return DataLoader(
            GenericDataset(self.hparams.dataset_name, subset),
            batch_size=self.hparams.batch_size,
            shuffle=(subset == 'train')
        )

    def configure_optimizers(self) -> Optimizer:
        return AdamW(
            [p for p in self.parameters() if p.requires_grad],
            lr=self.hparams.learning_rate,
            eps=1e-08
        )

    def forward(self, input_ids, attention_mask) -> Tensor:
        bert_output, _ = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        bert_cls = bert_output[:, 0]
        logits = self.linear(bert_cls)
        return logits

    def training_step(self, batch, batch_nb) -> dict:
        input_ids, attention_mask, label = batch

        y_hat = self(input_ids, attention_mask)

        loss = F.cross_entropy(y_hat, label)
        tensorboard_logs = {'train_loss': loss}

        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_nb) -> dict:
        input_ids, attention_mask, label = batch

        y_hat = self(input_ids, attention_mask)

        loss = F.cross_entropy(y_hat, label)

        a, y_hat = torch.max(y_hat, dim=1)
        y_hat = y_hat.cpu()
        label = label.cpu()

        return {
            'val_loss': loss,
            'val_pre': torch.tensor(precision_score(label, y_hat, average='micro')),
            'val_rec': torch.tensor(recall_score(label, y_hat, average='micro')),
            'val_acc': torch.tensor(accuracy_score(label, y_hat)),
            'val_f1': torch.tensor(f1_score(label, y_hat, average='micro'))
        }

    def validation_epoch_end(self, outputs) -> dict:
        avg_val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_val_pre = torch.stack([x['val_pre'] for x in outputs]).mean()
        avg_val_rec = torch.stack([x['val_rec'] for x in outputs]).mean()
        avg_val_acc = torch.stack([x['val_acc'] for x in outputs]).mean()
        avg_val_f1 = torch.stack([x['val_f1'] for x in outputs]).mean()

        tensorboard_logs = {
            'val_loss': avg_val_loss,
            'avg_val_pre': avg_val_pre,
            'avg_val_rec': avg_val_rec,
            'avg_val_acc': avg_val_acc,
            'avg_val_f1': avg_val_f1,
        }
        return {'val_loss': avg_val_loss, 'progress_bar': tensorboard_logs}

    def test_step(self, batch, batch_nb) -> dict:
        input_ids, attention_mask, label = batch

        y_hat = self(input_ids, attention_mask)

        a, y_hat = torch.max(y_hat, dim=1)
        y_hat = y_hat.cpu()
        label = label.cpu()
        test_pre = precision_score(label, y_hat, average='micro')
        test_rec = recall_score(label, y_hat, average='micro')
        test_acc = accuracy_score(label, y_hat)
        test_f1 = f1_score(label, y_hat, average='micro')

        return {
            'test_pre': torch.tensor(test_pre),
            'test_rec': torch.tensor(test_rec),
            'test_acc': torch.tensor(test_acc),
            'test_f1': torch.tensor(test_f1),
        }

    def test_epoch_end(self, outputs) -> dict:
        avg_test_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        avg_test_pre = torch.stack([x['test_pre'] for x in outputs]).mean()
        avg_test_rec = torch.stack([x['test_rec'] for x in outputs]).mean()
        avg_test_acc = torch.stack([x['test_acc'] for x in outputs]).mean()
        avg_test_f1 = torch.stack([x['test_f1'] for x in outputs]).mean()

        tensorboard_logs = {
            'test_loss': avg_test_loss,
            'avg_test_pre': avg_test_pre,
            'avg_test_rec': avg_test_rec,
            'avg_test_acc': avg_test_acc,
            'avg_test_f1': avg_test_f1,
        }
        return {'test_loss': avg_test_loss, 'progress_bar': tensorboard_logs}



## Trainer class

In [None]:
GPUS = 1
MIN_EPOCHS = 1
MAX_EPOCHS = 4

trainer = LightningTrainer(
    gpus=GPUS,
    min_epochs=MIN_EPOCHS,
    max_epochs=MAX_EPOCHS,
    default_root_dir=CHECKPOINT_DIR,
)

## Training

In [None]:
BATCH_SIZE = 32
LEARNING_RATE = 2e-05

model = BERTModule(
    bert_variant=BERT_VARIANT,
    dataset_name=DATASET_NAME,
    batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE
)

trainer.fit(model)

## Testing

In [None]:
trainer.test(model)