# Relation extraction with BERT

---

The goal of this repo is to show how to use [BERT](https://arxiv.org/abs/1810.04805)
to [extract relations](https://en.wikipedia.org/wiki/Relationship_extraction) from text.

Used libraries:
- [Transformers](https://huggingface.co/transformers/index.html)
- [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/)

Used datasets:
- SemEval 2010 Task 8 - [paper](https://arxiv.org/pdf/1911.10422.pdf) - [download](https://github.com/sahitya0000/Relation-Classification/blob/master/corpus/SemEval2010_task8_all_data.zip?raw=true)
- Google IISc Distant Supervision (GIDS) - [paper](https://arxiv.org/pdf/1804.06987.pdf) - [download](https://drive.google.com/open?id=1gTNAbv8My2QDmP-OHLFtJFlzPDoCG4aI)

## Install dependencies

This project uses [Python 3.7+](https://www.python.org/downloads/release/python-378/)

In [13]:
!pip install requests==2.23.0 numpy==1.18.5 pandas==1.0.3 \
    scikit-learn==0.23.1 pytorch-lightning==0.8.4 torch==1.5.1 \
    transformers==3.0.2 sklearn==0.0 tqdm==4.45.0 




## Import needed modules

In [1]:
import json
import math
import multiprocessing
import os
import shutil
import zipfile
from abc import ABC, abstractmethod
from urllib.parse import urlparse

import requests
import torch
from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer as LightningTrainer
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch import Tensor, nn
from torch.nn import functional as F
from torch.optim import AdamW
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader, IterableDataset
from tqdm.auto import tqdm
from transformers import *

## Define constants

Change the following constant to `False` if you are not running in a Kaggle environment:

In [2]:
# True: raw datasets are read from the `../input` directory;
# False: they are read from (and downloaded to) `data/raw`.
KAGGLE = False

Other constants:

In [12]:
# --- Directory ---
# All paths are resolved against the current working directory.
ROOT_DIR = os.path.abspath('.')
# Raw datasets: `../input` when KAGGLE is set, otherwise `data/raw`.
RAW_DATA_DIR = os.path.join(ROOT_DIR, '../input') if KAGGLE else os.path.join(ROOT_DIR, 'data/raw')
# Output directory of the preprocessors (one JSON object per line per sample).
PROCESSED_DATA_DIR = os.path.join(ROOT_DIR, 'data/processed') 
# Passed to the Lightning trainer as its default root directory.
CHECKPOINT_DIR = os.path.join(ROOT_DIR, 'checkpoint')

# --- Datasets ---
# Per-dataset settings, keyed by dataset name:
#   dir          final location of the raw dataset
#   extract_dir  directory the downloaded archive is expected to unpack into
#                (renamed to `dir` once the download has finished)
#   url          download location of the zip archive
#   num_classes  output size of the classification layer
DATASET_MAPPING = {
    'SemEval2010Task8': {
        'dir': os.path.join(RAW_DATA_DIR,'semeval2010-task-8'),
        'extract_dir': os.path.join(RAW_DATA_DIR, 'SemEval2010_task8_all_data'),
        'url': 'https://github.com/sahitya0000/Relation-Classification/'
               'blob/master/corpus/SemEval2010_task8_all_data.zip?raw=true',
        'num_classes': 10,
    },
    'GIDS': {
        'dir': os.path.join(RAW_DATA_DIR,'gids-dataset'),
        'extract_dir': os.path.join(RAW_DATA_DIR, 'gids_data'),
        'url': 'https://drive.google.com/uc?id=1gTNAbv8My2QDmP-OHLFtJFlzPDoCG4aI&export=download',
        'num_classes': 5
    },
    'NYT': {
        'dir': os.path.join(RAW_DATA_DIR,'nyt-relation-extraction'),
        'extract_dir': os.path.join(RAW_DATA_DIR, 'riedel_data'),
        'url': 'https://drive.google.com/uc?id=1D7bZPvrSAbIPaFSG7ZswYQcPA3tmouCw&export=download',
        'num_classes': 53 # TODO merge some relations
    }
}
# Dataset the model further down is trained and tested on.
DATASET_NAME = 'GIDS'

# --- BERT ---
# Marker characters the preprocessors wrap around the subject and the object
# entity of each sentence before tokenization.
SUB_START_CHAR = '{'
SUB_END_CHAR = '}'
OBJ_START_CHAR = '['
OBJ_END_CHAR = ']'

# --- BERT Model ---
# See https://huggingface.co/transformers/pretrained_models.html for the full list

# Per-variant settings, keyed by variant name:
#   model              encoder class
#   tokenizer          tokenizer class (must belong to the same model family
#                      as `model` so the vocabulary matches the weights)
#   pretrain_weight    name of the pretrained checkpoint to load
#   hidden_state_size  width of the encoder's hidden states
BERT_VARIANT_MAPPING = {
    'bert': {
        'model': BertModel,
        # Fixed: this used to be BartTokenizerFast, whose vocabulary does not
        # match the `bert-base-uncased` checkpoint.
        'tokenizer': BertTokenizerFast,
        'pretrain_weight': 'bert-base-uncased',
        'hidden_state_size': 768
    },
    'distilbert': {
        'model': DistilBertModel,
        'tokenizer': DistilBertTokenizerFast,
        'pretrain_weight': 'distilbert-base-uncased',
        'hidden_state_size': 768
    },
    'roberta': {
        'model': RobertaModel,
        'tokenizer': RobertaTokenizerFast,
        'pretrain_weight': 'roberta-base',
        'hidden_state_size': 768
    },
}
# Variant used by both the preprocessor (tokenizer) and the model (encoder).
BERT_VARIANT = 'distilbert'

## Download data

This part **CAN BE SKIPPED** if this notebook is running in a Kaggle environment, since the datasets are already included there.

First, we install `gdown` to download files from Google Drive:

In [17]:
!pip install gdown==3.11.1
import gdown



Some download util functions:

In [18]:
def download_from_url(url: str, save_path: str, chunk_size: int = 2048) -> None:
    """Stream the resource at ``url`` into the file ``save_path``.

    Args:
        url: Location of the file to download.
        save_path: Path of the local file to (over)write.
        chunk_size: Number of bytes requested per streamed chunk.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
    """
    print(f"Downloading...\nFrom: {url}\nTo: {save_path}")
    # The context manager releases the connection even if writing fails.
    with requests.get(url, stream=True) as response:
        # Fail before touching the disk so an HTML error page is never
        # saved (and later unzipped) as if it were the dataset archive.
        response.raise_for_status()
        with open(save_path, "wb") as f:
            for data in tqdm(response.iter_content(chunk_size=chunk_size)):
                f.write(data)

def download_from_google_drive(url: str, save_path: str) -> None:
    """Download the Google Drive file at ``url`` to ``save_path`` via ``gdown``.

    Cookies are disabled for the request (``use_cookies=False``).
    """
    gdown.download(url, save_path, use_cookies=False)

def extract_zip(zip_file_path: str, extract_dir: str, remove_zip_file=True):
    """Unpack a zip archive and optionally delete it afterwards.

    Args:
        zip_file_path: Path of the archive to unpack.
        extract_dir: Directory the archive members are extracted into.
        remove_zip_file: When true, the archive is deleted after extraction.
    """
    # TODO https://stackoverflow.com/questions/4341584/extract-zipfile-using-python-display-progress-percentage
    with zipfile.ZipFile(zip_file_path, 'r') as archive:
        print("Extracting to " + extract_dir)
        archive.extractall(extract_dir)

    if not remove_zip_file:
        return

    print("Removing zip file")
    os.remove(zip_file_path)

The download function itself:

In [19]:
def download(dataset_name, dataset_url, dataset_dir, dataset_extract_dir, force_redownload: bool):
    """Fetch one dataset archive, unpack it and move it to its final directory.

    Args:
        dataset_name: Name used in log messages and for the temporary zip file.
        dataset_url: Where to download the archive from; Google Drive URLs
            are fetched through ``gdown``, everything else through ``requests``.
        dataset_dir: Final directory of the raw dataset.
        dataset_extract_dir: Directory the archive unpacks into; it is renamed
            to ``dataset_dir`` at the end.
        force_redownload: Delete an existing ``dataset_dir`` and download
            again instead of skipping.
    """
    print(f"\n---> Downloading dataset {dataset_name} <---")

    # Make sure the directory that holds all raw datasets exists.
    if not os.path.exists(RAW_DATA_DIR):
        print("Creating raw data directory " + RAW_DATA_DIR)
        os.makedirs(RAW_DATA_DIR)

    # An existing dataset directory is either kept (skip) or wiped (redownload).
    if os.path.exists(dataset_dir):
        if not force_redownload:
            print(f"Directory {dataset_dir} exists, skip downloading.")
            return
        print(f"Removing old raw data {dataset_dir}")
        shutil.rmtree(dataset_dir)

    # Pick the downloader that matches the host of the URL.
    archive_path = os.path.join(RAW_DATA_DIR, dataset_name + '.zip')
    hosted_on_drive = urlparse(dataset_url).netloc == 'drive.google.com'
    fetch = download_from_google_drive if hosted_on_drive else download_from_url
    fetch(dataset_url, archive_path)

    # Unpack next to the other raw datasets, then give the result its final name.
    extract_zip(archive_path, RAW_DATA_DIR)
    os.rename(dataset_extract_dir, dataset_dir)

Download all datasets:

In [20]:
def download_all_dataset():
    """Download every dataset listed in ``DATASET_MAPPING`` that is not on disk yet."""
    for name, info in DATASET_MAPPING.items():
        # Positional order: name, url, final dir, extract dir, force_redownload.
        download(name, info['url'], info['dir'], info['extract_dir'], False)

download_all_dataset()


---> Downloading dataset SemEval2010Task8 <---
Directory /media/dthung1602/WORKING/bert-relation-extraction/data/raw/semeval2010-task-8 exists, skip downloading.

---> Downloading dataset GIDS <---
Directory /media/dthung1602/WORKING/bert-relation-extraction/data/raw/gids-dataset exists, skip downloading.

---> Downloading dataset NYT <---
Directory /media/dthung1602/WORKING/bert-relation-extraction/data/raw/nyt-relation-extraction exists, skip downloading.


## Preprocess

The abstract preprocessor

In [16]:
class AbstractPreprocessor(ABC):
    """Base class that turns a raw dataset into tokenized JSON-lines files.

    Subclasses set ``DATASET_NAME`` and implement ``_preprocess_data``, which
    is expected to write the ``train``, ``val`` and ``test`` files whose paths
    are given by ``get_json_file_name``.
    """
    # Used in log messages and (lower-cased) as prefix of the output files.
    DATASET_NAME = ''
    # Seed of the train/validation split, fixed for reproducibility.
    RANDOM_SEED = 2020
    # Fraction of the samples that goes into the validation split.
    VAL_DATA_PROPORTION = 0.2

    def __init__(self, tokenizer: "PreTrainedTokenizer"):
        """Store the tokenizer subclasses use to encode sentences."""
        self.tokenizer = tokenizer

    def preprocess_data(self, reprocess: bool):
        """Run the preprocessing unless its output already exists.

        Args:
            reprocess: When false and at least one output file exists,
                nothing is done; when true, existing files are overwritten.
        """
        print(f"\n---> Preprocessing {self.DATASET_NAME} dataset <---")

        # create processed data dir
        if not os.path.exists(PROCESSED_DATA_DIR):
            print("Creating processed data directory " + PROCESSED_DATA_DIR)
            os.makedirs(PROCESSED_DATA_DIR)

        # stop preprocessing if file existed
        json_file_names = [self.get_json_file_name(k) for k in ('train', 'val', 'test')]
        existed_files = [fn for fn in json_file_names if os.path.exists(fn)]
        if existed_files:
            file_text = "- " + "\n- ".join(existed_files)
            if not reprocess:
                print("The following files already exist:")
                print(file_text)
                print("Preprocessing is skipped. See option --reprocess.")
                return
            print("The following files will be overwritten:")
            print(file_text)

        self._preprocess_data()

    @abstractmethod
    def _preprocess_data(self):
        """Read the raw dataset and write the train/val/test files."""
        pass

    def _train_val_split(self, original_data):
        """Split column-oriented data into a train and a validation part.

        Args:
            original_data: Mapping of column name to a list of per-sample
                values; all columns are indexed with the same sample indices.

        Returns:
            A ``(train_data, val_data)`` tuple of plain dicts with the same
            keys as ``original_data``.
        """
        def get_sample(column, idxs):
            return [column[i] for i in idxs]

        # Any column gives the number of samples; the first one is used.
        first_key = list(original_data.keys())[0]
        indices = list(range(len(original_data[first_key])))
        train_indices, val_indices = train_test_split(
            indices,
            test_size=self.VAL_DATA_PROPORTION,
            random_state=self.RANDOM_SEED
        )
        train_data = {k: get_sample(v, train_indices) for k, v in original_data.items()}
        val_data = {k: get_sample(v, val_indices) for k, v in original_data.items()}

        return train_data, val_data

    def _append_data_to_file(self, data: dict, file):
        """Write column-oriented ``data`` to ``file``, one JSON object per sample.

        Args:
            data: Mapping of column name to a list of per-sample values.
            file: Writable text file object.
        """
        keys = list(data.keys())
        # Stream the rows instead of growing one big string with `+=`.
        file.writelines(
            json.dumps(dict(zip(keys, values))) + "\n"
            for values in zip(*data.values())
        )

    @classmethod
    def get_json_file_name(cls, key: str):
        """Return the path of the processed file of subset ``key`` (e.g. 'train')."""
        return os.path.join(PROCESSED_DATA_DIR, f'{cls.DATASET_NAME.lower()}_{key}.json')

Custom preprocessor for each dataset:

In [17]:
class SemEval2010Task8Preprocessor(AbstractPreprocessor):
    """Preprocessor for the SemEval 2010 Task 8 text files.

    The raw files are read as records of four lines each: the sentence, the
    relation label and two lines that are not used.
    """
    DATASET_NAME = 'SemEval2010Task8'
    RAW_TRAIN_FILE_NAME = os.path.join(DATASET_MAPPING['SemEval2010Task8']['dir'],
                                       'SemEval2010_task8_training/TRAIN_FILE.TXT')
    RAW_TEST_FILE_NAME = os.path.join(DATASET_MAPPING['SemEval2010Task8']['dir'],
                                      'SemEval2010_task8_testing_keys/TEST_FILE_FULL.TXT')
    # Number of records read from each raw file.
    RAW_TRAIN_DATA_SIZE = 8000
    RAW_TEST_DATA_SIZE = 2717

    def _preprocess_data(self):
        """Tokenize both raw files, encode labels, split off validation data and save."""
        print("Processing training data")
        train_data = self._get_data_from_file(self.RAW_TRAIN_FILE_NAME, self.RAW_TRAIN_DATA_SIZE)

        print("Processing test data")
        test_data = self._get_data_from_file(self.RAW_TEST_FILE_NAME, self.RAW_TEST_DATA_SIZE)

        print("Encoding labels to integers")
        # The encoder is fitted on the training labels only and reused for test.
        encoder = LabelEncoder()
        encoder.fit(train_data['label'])
        for split in (train_data, test_data):
            split['label'] = encoder.transform(split['label']).tolist()

        print("Splitting train & validate data")
        train_data, val_data = self._train_val_split(train_data)

        print("Saving to json files")
        subsets = (('train', train_data), ('val', val_data), ('test', test_data))
        for subset, subset_data in subsets:
            with open(self.get_json_file_name(subset), 'w') as f:
                self._append_data_to_file(subset_data, f)

    def _get_data_from_file(self, file_name: str, dataset_size: int):
        """Read ``dataset_size`` records from ``file_name`` and tokenize the sentences.

        Returns:
            The tokenizer output with an extra ``'label'`` entry holding the
            (still textual) relation label of every sentence.
        """
        raw_sentences, labels = [], []
        with open(file_name) as f:
            for _ in tqdm(range(dataset_size)):
                sentence_line = f.readline()
                label_line = f.readline()
                # The remaining two lines of the record are not used.
                f.readline()
                f.readline()
                raw_sentences.append(self._process_sentence(sentence_line))
                labels.append(self._process_label(label_line))
        data = self.tokenizer(raw_sentences, truncation=True, padding=True)
        data['label'] = labels
        return data

    def _process_sentence(self, sentence: str):
        """Extract the sentence text and swap the entity tags for marker characters."""
        # TODO distinguish e1 e2 sub obj
        # Second tab-separated field, without its first and its last two characters.
        text = sentence.split("\t")[1][1:-2]
        tag_to_marker = (
            ("<e1>", SUB_START_CHAR),
            ("</e1>", SUB_END_CHAR),
            ("<e2>", OBJ_START_CHAR),
            ("</e2>", OBJ_END_CHAR),
        )
        for tag, marker in tag_to_marker:
            text = text.replace(tag, marker)
        return text

    def _process_label(self, label: str):
        """Drop the last 8 characters of the label line.

        NOTE(review): presumably the "(e1,e2)" direction suffix plus the
        newline; lines shorter than 8 characters become an empty string.
        """
        return label[:-8]


class LargeDatasetPreprocessor(AbstractPreprocessor):
    """Base preprocessor for JSON-lines datasets that are processed in batches.

    Every raw line is a JSON object with the keys ``sent`` (list of tokens),
    ``sub``, ``obj`` and ``rel``.
    """
    # Number of raw lines tokenized per call of `_process_batch`.
    PROCESS_BATCH_SIZE = 2**12

    def _preprocess_data(self):
        pass

    @staticmethod
    def _fit_label_encoder(le: LabelEncoder, in_file):
        """Fit ``le`` once, on all relations from the current position of ``in_file`` on.

        Does nothing when ``le`` is already fitted, so the mapping learned from
        the first file that is processed is reused for every later batch and
        file. The read position of ``in_file`` is restored afterwards.
        """
        if hasattr(le, 'classes_'):
            return
        start = in_file.tell()
        relations = {json.loads(line)['rel'] for line in iter(in_file.readline, "")}
        in_file.seek(start)
        if relations:
            le.fit(sorted(relations))

    def _process_batch(self, le: LabelEncoder, in_file):
        """Read up to ``PROCESS_BATCH_SIZE`` lines and return them tokenized.

        Args:
            le: Label encoder shared by all batches; fitted on first use.
            in_file: Raw JSON-lines file, opened for reading in text mode.

        Returns:
            The tokenizer output with an extra ``'labels'`` entry holding the
            integer-encoded relation of every sentence.

        Raises:
            ValueError: If a relation occurs that ``le`` was not fitted on.
        """
        # Fitting per batch (`fit_transform`) would give the same relation
        # different ids in different batches/files, so fit once up front.
        self._fit_label_encoder(le, in_file)

        raw_sentences = []
        labels = []
        for _ in range(self.PROCESS_BATCH_SIZE):
            dt = in_file.readline()
            if dt == "":  # EOF
                break
            dt = json.loads(dt)

            # add subject/object markup
            # NOTE(review): plain string replacement; if one entity is a
            # substring of the other, the markers may end up nested.
            sentence = " ".join(dt['sent'])
            new_sub = SUB_START_CHAR + dt['sub'].replace('_', '') + SUB_END_CHAR # TODO keep _ or not?
            new_obj = OBJ_START_CHAR + dt['obj'].replace('_', '') + OBJ_END_CHAR
            sentence = sentence.replace(dt['sub'], new_sub).replace(dt['obj'], new_obj)
            raw_sentences.append(sentence)
            labels.append(dt['rel'])

        data = self.tokenizer(raw_sentences, truncation=True, padding=True)
        # NOTE: the key is 'labels' here, while SemEval2010Task8Preprocessor
        # stores its labels under 'label'.
        data['labels'] = le.transform(labels).tolist() if labels else []

        return data

    def _process_subset(self, le: LabelEncoder, in_file_name, out_file_name, data_size):
        """Convert one raw file into one processed file, batch by batch.

        Args:
            le: Label encoder shared by all subsets.
            in_file_name: Path of the raw JSON-lines file.
            out_file_name: Path of the processed file to (over)write.
            data_size: Number of lines in the raw file; determines how many
                batches are read.
        """
        with open(in_file_name) as in_file, open(out_file_name, 'w') as out_file:
            for _ in tqdm(range(math.ceil(data_size/self.PROCESS_BATCH_SIZE))):
                data = self._process_batch(le, in_file)
                self._append_data_to_file(data, out_file)

class GIDSPreprocessor(LargeDatasetPreprocessor):
    """Preprocessor for the GIDS dataset, which ships separate train/dev/test files."""
    DATASET_NAME = 'GIDS'
    RAW_TRAIN_FILE_NAME = os.path.join(DATASET_MAPPING['GIDS']['dir'], 'gids_train.json')
    RAW_VAL_FILE_NAME = os.path.join(DATASET_MAPPING['GIDS']['dir'], 'gids_dev.json')
    RAW_TEST_FILE_NAME = os.path.join(DATASET_MAPPING['GIDS']['dir'], 'gids_test.json')
    # Number of lines in each raw file.
    TRAIN_SIZE = 11297
    VAL_SIZE = 1864
    TEST_SIZE = 5663
    PROCESS_BATCH_SIZE = 1024

    def _preprocess_data(self):
        """Process the train, val and test files in turn with one shared label encoder."""
        label_encoder = LabelEncoder()
        subsets = (
            ('train', self.RAW_TRAIN_FILE_NAME, self.TRAIN_SIZE),
            ('val', self.RAW_VAL_FILE_NAME, self.VAL_SIZE),
            ('test', self.RAW_TEST_FILE_NAME, self.TEST_SIZE),
        )
        for subset, raw_file_name, size in subsets:
            print(f"Process {subset} dataset")
            self._process_subset(
                label_encoder,
                raw_file_name,
                self.get_json_file_name(subset),
                size
            )

class NYTPreprocessor(LargeDatasetPreprocessor):
    """Preprocessor for the NYT (Riedel) dataset.

    There is no raw validation file, so every processed training batch is
    split into a train and a validation part.
    """
    DATASET_NAME = 'NYT'
    RAW_TRAIN_FILE_NAME = os.path.join(DATASET_MAPPING['NYT']['dir'], 'riedel_train.json')
    RAW_TEST_FILE_NAME = os.path.join(DATASET_MAPPING['NYT']['dir'], 'riedel_test.json')
    # Number of lines in each raw file.
    TRAIN_SIZE = 570084
    TEST_SIZE = 172448
    PROCESS_BATCH_SIZE = 4096 * 4

    def _preprocess_data(self):
        """Write the train/val files from the raw train file, then the test file."""
        label_encoder = LabelEncoder()
        train_path = self.get_json_file_name('train')
        val_path = self.get_json_file_name('val')
        num_batches = math.ceil(self.TRAIN_SIZE / self.PROCESS_BATCH_SIZE)

        print("Process train & val dataset")
        with open(self.RAW_TRAIN_FILE_NAME) as raw_file, \
                open(train_path, 'w') as train_out, \
                open(val_path, 'w') as val_out:
            for _ in tqdm(range(num_batches)):
                batch = self._process_batch(label_encoder, raw_file)
                # Each batch is split on its own, always with the same seed.
                train_part, val_part = self._train_val_split(batch)
                self._append_data_to_file(train_part, train_out)
                self._append_data_to_file(val_part, val_out)

        print("Process test dataset")
        self._process_subset(
            label_encoder,
            self.RAW_TEST_FILE_NAME,
            self.get_json_file_name('test'),
            self.TEST_SIZE
        )


Factory function to get preprocessor:

In [5]:
def get_preprocessor_class(dataset_name: str):
    """Look up the module-level class named ``<dataset_name>Preprocessor``.

    Raises:
        KeyError: If no global with that name exists.
    """
    class_name = '{}Preprocessor'.format(dataset_name)
    module_globals = globals()
    return module_globals[class_name]

def get_preprocessor(dataset_name: str)-> AbstractPreprocessor:
    """Build the preprocessor of ``dataset_name`` with the tokenizer of ``BERT_VARIANT``."""
    variant_info = BERT_VARIANT_MAPPING[BERT_VARIANT]
    tokenizer_class = variant_info['tokenizer']
    # Load the tokenizer that belongs to the variant's pretrained checkpoint.
    tokenizer = tokenizer_class.from_pretrained(variant_info['pretrain_weight'])
    return get_preprocessor_class(dataset_name)(tokenizer)

Preprocess data:

In [19]:
preprocessor = get_preprocessor('SemEval2010Task8')
preprocessor.preprocess_data(reprocess=True)


---> Preprocessing SemEval2010Task8 dataset <---
Processing training data

Processing test data

Encoding labels to integers
Splitting train & validate data
Saving to json files


HBox(children=(FloatProgress(value=0.0, max=8000.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=2717.0), HTML(value='')))

## Model

### Dataset

In [None]:
class GenericDataset(IterableDataset):
    """Streams the preprocessed samples of one subset of a dataset.

    The processed file holds one JSON object per line. It is reopened on every
    iteration, so the dataset can be iterated any number of times (e.g. once
    per epoch) and no file handle stays open in between.
    """

    def __init__(self, dataset_name: str, subset: str):
        """
        Args:
            dataset_name: Name of the dataset; selects the preprocessor class
                that knows where the processed files are.
            subset: One of 'train', 'val' or 'test'.

        Raises:
            ValueError: If ``subset`` is not one of the allowed values.
            FileNotFoundError: If the processed file does not exist.
        """
        preprocessor_class = get_preprocessor_class(dataset_name)
        if subset not in ['train', 'val', 'test']:
            raise ValueError('subset must be train, val or test')
        self.file_name = preprocessor_class.get_json_file_name(subset)
        # Fail at construction time, not in the middle of the first epoch.
        if not os.path.isfile(self.file_name):
            raise FileNotFoundError(self.file_name)

    def __iter__(self):
        # Inside a DataLoader worker every worker gets its own copy of the
        # dataset; hand each one a disjoint share of the lines so samples are
        # not yielded once per worker.
        worker = torch.utils.data.get_worker_info()
        if worker is None:
            return self._read_samples(0, 1)
        return self._read_samples(worker.id, worker.num_workers)

    def _read_samples(self, worker_id: int, num_workers: int):
        """Yield every ``num_workers``-th sample, starting at line ``worker_id``."""
        with open(self.file_name) as f:
            for line_no, line in enumerate(f):
                if line_no % num_workers == worker_id:
                    yield json.loads(line)

### Torch Lightning Module

In [None]:
class BERTModule(LightningModule):
    """Relation classifier: a pretrained BERT-style encoder plus one linear layer.

    The hidden state of the first token of every sequence is fed to the linear
    layer, which produces one logit per relation class.
    """

    def __init__(self, bert_variant, dataset_name, batch_size, learning_rate):
        """
        Args:
            bert_variant: Key of ``BERT_VARIANT_MAPPING`` selecting the encoder.
            dataset_name: Key of ``DATASET_MAPPING`` selecting data and class count.
            batch_size: Number of samples per batch in all dataloaders.
            learning_rate: Learning rate of the optimizer.
        """
        super().__init__()
        # Makes the constructor arguments available as `self.hparams.<name>`.
        self.save_hyperparameters()

        bert_info = BERT_VARIANT_MAPPING[bert_variant]
        bert_model_class = bert_info['model']
        bert_pretrain_weight = bert_info['pretrain_weight']
        self.bert = bert_model_class.from_pretrained(bert_pretrain_weight, output_attentions=True)

        dataset_info = DATASET_MAPPING[dataset_name]
        self.num_classes = dataset_info['num_classes']
        self.linear = nn.Linear(self.bert.config.hidden_size, self.num_classes)

    def train_dataloader(self) -> DataLoader:
        return self.__get_dataloader('train')

    def val_dataloader(self) -> DataLoader:
        return self.__get_dataloader('val')

    def test_dataloader(self) -> DataLoader:
        return self.__get_dataloader('test')

    def __get_dataloader(self, subset: str) -> DataLoader:
        """Create the dataloader of one subset ('train', 'val' or 'test')."""
        print(f"Loading {subset} data")
        return DataLoader(
            GenericDataset(self.hparams.dataset_name, subset),
            batch_size=self.hparams.batch_size,
            # No `shuffle`: DataLoader rejects it for an IterableDataset, so
            # samples arrive in file order.
            collate_fn=self._collate,
            # Load in the main process: with worker processes every worker
            # would iterate its own copy of the dataset, and samples would be
            # duplicated unless the dataset shards itself.
            num_workers=0
        )

    def _collate(self, samples):
        """Turn a list of sample dicts into ``(input_ids, attention_mask, label)`` tensors.

        Sequences are right-padded to the longest one in the batch, because
        samples tokenized in different preprocessing batches can have
        different lengths. Both label keys written by the preprocessors
        ('label' and 'labels') are accepted.
        """
        pad_id = getattr(self.bert.config, 'pad_token_id', None)
        if pad_id is None:
            pad_id = 0

        max_len = max(len(sample['input_ids']) for sample in samples)
        input_ids = torch.full((len(samples), max_len), pad_id, dtype=torch.long)
        attention_mask = torch.zeros((len(samples), max_len), dtype=torch.long)
        for row, sample in enumerate(samples):
            length = len(sample['input_ids'])
            input_ids[row, :length] = torch.tensor(sample['input_ids'], dtype=torch.long)
            attention_mask[row, :length] = torch.tensor(sample['attention_mask'], dtype=torch.long)

        label = torch.tensor(
            [sample['label'] if 'label' in sample else sample['labels'] for sample in samples],
            dtype=torch.long
        )
        return input_ids, attention_mask, label

    def configure_optimizers(self) -> Optimizer:
        # NOTE(review): `AdamW` is imported from torch.optim, but the later
        # `from transformers import *` may rebind the name - confirm which
        # implementation is intended.
        return AdamW(
            [p for p in self.parameters() if p.requires_grad],
            lr=self.hparams.learning_rate,
            eps=1e-08
        )

    def forward(self, input_ids, attention_mask) -> Tensor:
        """Return the class logits for a batch of token id sequences."""
        # Index instead of unpacking: the number of encoder outputs differs
        # between variants, but the first one is always the hidden states.
        bert_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )[0]
        bert_cls = bert_output[:, 0]
        logits = self.linear(bert_cls)
        return logits

    @staticmethod
    def _batch_metrics(prefix: str, logits, label) -> dict:
        """Micro-averaged precision/recall/F1 and accuracy of one batch.

        Keys are ``<prefix>_pre``, ``<prefix>_rec``, ``<prefix>_acc`` and
        ``<prefix>_f1``; values are scalar tensors.
        """
        # scikit-learn works on CPU data.
        y_hat = torch.argmax(logits, dim=1).cpu()
        label = label.cpu()
        return {
            f'{prefix}_pre': torch.tensor(precision_score(label, y_hat, average='micro')),
            f'{prefix}_rec': torch.tensor(recall_score(label, y_hat, average='micro')),
            f'{prefix}_acc': torch.tensor(accuracy_score(label, y_hat)),
            f'{prefix}_f1': torch.tensor(f1_score(label, y_hat, average='micro')),
        }

    @staticmethod
    def _mean_over_batches(outputs, key: str):
        """Average the per-batch values stored under ``key``."""
        return torch.stack([x[key] for x in outputs]).mean()

    def training_step(self, batch, batch_nb) -> dict:
        input_ids, attention_mask, label = batch

        y_hat = self(input_ids, attention_mask)

        loss = F.cross_entropy(y_hat, label)
        tensorboard_logs = {'train_loss': loss}

        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_nb) -> dict:
        input_ids, attention_mask, label = batch

        y_hat = self(input_ids, attention_mask)

        loss = F.cross_entropy(y_hat, label)

        return {'val_loss': loss, **self._batch_metrics('val', y_hat, label)}

    def validation_epoch_end(self, outputs) -> dict:
        avg_val_loss = self._mean_over_batches(outputs, 'val_loss')

        tensorboard_logs = {
            'val_loss': avg_val_loss,
            'avg_val_pre': self._mean_over_batches(outputs, 'val_pre'),
            'avg_val_rec': self._mean_over_batches(outputs, 'val_rec'),
            'avg_val_acc': self._mean_over_batches(outputs, 'val_acc'),
            'avg_val_f1': self._mean_over_batches(outputs, 'val_f1'),
        }
        return {'val_loss': avg_val_loss, 'progress_bar': tensorboard_logs}

    def test_step(self, batch, batch_nb) -> dict:
        input_ids, attention_mask, label = batch

        y_hat = self(input_ids, attention_mask)

        return self._batch_metrics('test', y_hat, label)

    def test_epoch_end(self, outputs) -> dict:
        tensorboard_logs = {
            'avg_test_pre': self._mean_over_batches(outputs, 'test_pre'),
            'avg_test_rec': self._mean_over_batches(outputs, 'test_rec'),
            'avg_test_acc': self._mean_over_batches(outputs, 'test_acc'),
            'avg_test_f1': self._mean_over_batches(outputs, 'test_f1'),
        }
        return {'progress_bar': tensorboard_logs}

## Trainer

In [None]:
# Train on one GPU when CUDA is available; otherwise fall back to the CPU
# instead of failing on machines without a GPU.
GPUS = 1 if torch.cuda.is_available() else 0
MIN_EPOCHS = 1
MAX_EPOCHS = 1

trainer = LightningTrainer(
    gpus=GPUS,
    min_epochs=MIN_EPOCHS,
    max_epochs=MAX_EPOCHS,
    default_root_dir=CHECKPOINT_DIR,
)

## Training

In [None]:
# Hyperparameters handed to the module; they are used for the dataloaders
# (batch size) and the optimizer (learning rate).
BATCH_SIZE = 64
LEARNING_RATE = 2e-05

# Encoder variant and dataset are taken from the constants defined at the top.
model = BERTModule(
    bert_variant=BERT_VARIANT,
    dataset_name=DATASET_NAME,
    batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE
)

In [None]:
trainer.fit(model)

## Testing

In [None]:
trainer.test(model)