# Relation extraction with BERT

---

The goal of this notebook is to show how to use [BERT](https://arxiv.org/abs/1810.04805)
to [extract relations](https://en.wikipedia.org/wiki/Relationship_extraction) from text.

Libraries used:
- [PyTorch](https://pytorch.org/)
- [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/)
- [Transformers](https://huggingface.co/transformers/index.html)

Datasets used:
- SemEval 2010 Task 8 - [paper](https://arxiv.org/pdf/1911.10422.pdf) - [download](https://github.com/sahitya0000/Relation-Classification/blob/master/corpus/SemEval2010_task8_all_data.zip?raw=true)
- Google IISc Distant Supervision (GIDS) - [paper](https://arxiv.org/pdf/1804.06987.pdf) - [download](https://drive.google.com/open?id=1gTNAbv8My2QDmP-OHLFtJFlzPDoCG4aI)
- Riedel's New York Times - [paper](https://www.researchgate.net/publication/220698997_Modeling_Relations_and_Their_Mentions_without_Labeled_Text) - [download](https://drive.google.com/uc?id=1D7bZPvrSAbIPaFSG7ZswYQcPA3tmouCw&export=download)

## Install dependencies

This project uses [Python 3.7+](https://www.python.org/downloads/release/python-378/).

In [1]:
!pip install requests==2.23.0 numpy==1.18.5 pandas==1.0.3 \
    scikit-learn==0.23.1 pytorch-lightning==0.8.4 torch==1.5.1 \
    transformers==3.0.2 sklearn==0.0 tqdm==4.45.0 


Collecting pytorch-lightning==0.8.4
  Downloading pytorch_lightning-0.8.4-py3-none-any.whl (304 kB)
     |████████████████████████████████| 304 kB 2.8 MB/s eta 0:00:01
Collecting transformers==3.0.2
  Downloading transformers-3.0.2-py3-none-any.whl (769 kB)
     |████████████████████████████████| 769 kB 8.1 MB/s eta 0:00:01
Collecting tokenizers==0.8.1.rc1
  Downloading tokenizers-0.8.1rc1-cp37-cp37m-manylinux1_x86_64.whl (3.0 MB)
     |████████████████████████████████| 3.0 MB 17.6 MB/s eta 0:00:01
ERROR: allennlp 1.0.0 has requirement transformers<2.12,>=2.9, but you'll have transformers 3.0.2 which is incompatible.
Installing collected packages: pytorch-lightning, tokenizers, transformers
  Attempting uninstall: tokenizers
    Found existing installation: tokenizers 0.7.0
    Uninstalling tokenizers-0.7.0:
      Successfully uninstalled tokenizers-0.7.0
  Attempting uninstall: transformers
    Found existing installation: transformers 2.11.0
    Uninstalling transfo

## Import needed modules

In [1]:
import gc
import json
import math
import os
import shutil
import zipfile
from abc import ABC, abstractmethod
from typing import TextIO, Iterable
from urllib.parse import urlparse

import pandas as pd
import requests
import torch
from pandas import DataFrame
from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer as LightningTrainer
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch import Tensor, nn
from torch.nn import functional as F
from torch.optim import AdamW
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader, IterableDataset
from tqdm.auto import tqdm
from transformers import *

## Define constants

In [2]:
# --- Directory ---
ROOT_DIR = os.path.abspath('.')
PROCESSED_DATA_DIR = os.path.join(ROOT_DIR, 'data/processed') 
METADATA_FILE_NAME = os.path.join(PROCESSED_DATA_DIR, 'metadata.json')
CHECKPOINT_DIR = os.path.join(ROOT_DIR, 'checkpoint')

# In a local environment:
RAW_DATA_DIR = os.path.join(ROOT_DIR, 'data/raw')

# In a Kaggle environment the 3 datasets should already have been added to the notebook.
# The last assignment wins, so comment this line out when running locally.
RAW_DATA_DIR = os.path.join(ROOT_DIR, '../input')

# --- Datasets ---
DATASET_MAPPING = {
    'SemEval2010Task8': {
        'dir': os.path.join(RAW_DATA_DIR,'semeval2010-task-8'),
        'extract_dir': os.path.join(RAW_DATA_DIR, 'SemEval2010_task8_all_data'),
        'url': 'https://github.com/sahitya0000/Relation-Classification/'
               'blob/master/corpus/SemEval2010_task8_all_data.zip?raw=true',
        'num_classes': 10,
    },
    'GIDS': {
        'dir': os.path.join(RAW_DATA_DIR,'gids-dataset'),
        'extract_dir': os.path.join(RAW_DATA_DIR, 'gids_data'),
        'url': 'https://drive.google.com/uc?id=1gTNAbv8My2QDmP-OHLFtJFlzPDoCG4aI&export=download',
        'num_classes': 5
    },
    'NYT': {
        'dir': os.path.join(RAW_DATA_DIR,'nyt-relation-extraction'),
        'extract_dir': os.path.join(RAW_DATA_DIR, 'riedel_data'),
        'url': 'https://drive.google.com/uc?id=1D7bZPvrSAbIPaFSG7ZswYQcPA3tmouCw&export=download',
        'num_classes': 53 # TODO merge some relations
    }
}

# change this variable to switch dataset in later tasks
DATASET_NAME = 'GIDS'

# --- BERT ---
SUB_START_CHAR = '{'
SUB_END_CHAR = '}'
OBJ_START_CHAR = '['
OBJ_END_CHAR = ']'

# --- BERT Model ---
# See https://huggingface.co/transformers/pretrained_models.html for the full list

BERT_VARIANT_MAPPING = {
    'bert': {
        'model': BertModel,
        'tokenizer': BertTokenizerFast,
        'pretrain_weight': 'bert-base-uncased',
        'hidden_state_size': 768
    },
    'distilbert': {
        'model': DistilBertModel,
        'tokenizer': DistilBertTokenizerFast,
        'pretrain_weight': 'distilbert-base-uncased',
        'hidden_state_size': 768
    },
    'roberta': {
        'model': RobertaModel,
        'tokenizer': RobertaTokenizerFast,
        'pretrain_weight': 'roberta-base',
        'hidden_state_size': 768
    },
}

# change this variable to switch BERT variant in later tasks
BERT_VARIANT = 'distilbert'
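
As a rough illustration of what the four marker characters are for: the preprocessing code later in this notebook swaps the SemEval entity tags for these characters and then looks them up as token ids. The sketch below repeats only that replacement step. The helper name `mark_entities` and the sample sentence are made up for this example.

```python
# Marker characters, copied from the constants cell above.
SUB_START_CHAR, SUB_END_CHAR = '{', '}'
OBJ_START_CHAR, OBJ_END_CHAR = '[', ']'


def mark_entities(tagged_sentence: str) -> str:
    """Replace <e1>/<e2> tags with the subject/object marker characters."""
    return (
        tagged_sentence
        .replace("<e1>", SUB_START_CHAR)
        .replace("</e1>", SUB_END_CHAR)
        .replace("<e2>", OBJ_START_CHAR)
        .replace("</e2>", OBJ_END_CHAR)
    )


# Hypothetical sentence in the SemEval tagging style.
example = "The <e1>author</e1> wrote the <e2>book</e2>."
marked = mark_entities(example)
print(marked)
```

Single characters are used here, presumably so that each marker maps to exactly one token id, which is what the position lookup during preprocessing relies on.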

## Download data

This part **CAN BE SKIPPED** if this notebook is running in a Kaggle environment, since the datasets are already included.

First, we install `gdown` to download files from Google Drive:

In [None]:
!pip install gdown==3.11.1
import gdown

Some download utility functions:

In [None]:
def download_from_url(url: str, save_path: str, chunk_size: int = 2048):
    print(f"Downloading...\nFrom: {url}\nTo: {save_path}")
    response = requests.get(url, stream=True)
    # fail early on HTTP errors instead of saving an error page as the archive
    response.raise_for_status()
    with open(save_path, "wb") as f:
        for data in tqdm(response.iter_content(chunk_size=chunk_size)):
            f.write(data)

def download_from_google_drive(url: str, save_path: str):
    gdown.download(url, save_path, use_cookies=False)

def extract_zip(zip_file_path: str, extract_dir: str, remove_zip_file: bool = True):
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        print("Extracting to " + extract_dir)
        for member in tqdm(zip_ref.infolist()):
            zip_ref.extract(member, extract_dir)

    if remove_zip_file:
        print("Removing zip file")
        os.unlink(zip_file_path)

The download function itself:

In [None]:
def download(dataset_name: str, dataset_url: str, dataset_dir: str, dataset_extract_dir: str, force_redownload: bool):
    print(f"\n---> Downloading dataset {dataset_name} <---")
    
    # create raw data dir
    if not os.path.exists(RAW_DATA_DIR):
        print("Creating raw data directory " + RAW_DATA_DIR)
        os.makedirs(RAW_DATA_DIR)
    
    # check data has been downloaded
    if os.path.exists(dataset_dir):
        if force_redownload:
            print(f"Removing old raw data {dataset_dir}")
            shutil.rmtree(dataset_dir)
        else:
            print(f"Directory {dataset_dir} exists, skip downloading.")
            return


    # download
    tmp_file_path = os.path.join(RAW_DATA_DIR, dataset_name + '.zip')
    if urlparse(dataset_url).netloc == 'drive.google.com':
        download_from_google_drive(dataset_url, tmp_file_path)
    else:
        download_from_url(dataset_url, tmp_file_path)

    # unzip
    extract_zip(tmp_file_path, RAW_DATA_DIR)

    # rename
    os.rename(dataset_extract_dir, dataset_dir)
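
The `download` function picks its download helper from the host name of the URL. A minimal sketch of just that decision, using two of the dataset URLs from the constants cell; the function name `pick_downloader` and its string return values are illustrative and not part of the notebook.

```python
from urllib.parse import urlparse


def pick_downloader(url: str) -> str:
    """Return which download path a URL would take: 'google_drive' or 'http'."""
    # urlparse(...).netloc is the host part of the URL, e.g. 'drive.google.com'.
    if urlparse(url).netloc == 'drive.google.com':
        return 'google_drive'
    return 'http'


gids_url = 'https://drive.google.com/uc?id=1gTNAbv8My2QDmP-OHLFtJFlzPDoCG4aI&export=download'
semeval_url = ('https://github.com/sahitya0000/Relation-Classification/'
               'blob/master/corpus/SemEval2010_task8_all_data.zip?raw=true')

print(pick_downloader(gids_url))
print(pick_downloader(semeval_url))
```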

Download all datasets:

In [None]:
def download_all_dataset():
    for dataset_name, dataset_info in DATASET_MAPPING.items():
        download(
            dataset_name,
            dataset_url=dataset_info['url'],
            dataset_dir=dataset_info['dir'],
            dataset_extract_dir=dataset_info['extract_dir'],
            force_redownload=False
        )

download_all_dataset()

## Preprocess

In [17]:
class AbstractPreprocessor(ABC):
    DATASET_NAME = ''
    RANDOM_SEED = 2020
    VAL_DATA_PROPORTION = 0.2

    def __init__(self, tokenizer: PreTrainedTokenizer):
        self.tokenizer = tokenizer
        self.SUB_START_ID, self.SUB_END_ID, self.OBJ_START_ID, self.OBJ_END_ID \
            = tokenizer.convert_tokens_to_ids([SUB_START_CHAR, SUB_END_CHAR, OBJ_START_CHAR, OBJ_END_CHAR])

    def preprocess_data(self, reprocess: bool):
        print(f"\n---> Preprocessing {self.DATASET_NAME} dataset <---")
        
        # create processed data dir
        if not os.path.exists(PROCESSED_DATA_DIR):
            print("Creating processed data directory " + PROCESSED_DATA_DIR)
            os.makedirs(PROCESSED_DATA_DIR)

        # stop preprocessing if file existed
        json_file_names = [self.get_json_file_name(k) for k in ('train', 'val', 'test')]
        existed_files = [fn for fn in json_file_names if os.path.exists(fn)]
        if existed_files:
            file_text = "- " + "\n- ".join(existed_files)
            if not reprocess:
                print("The following files already exist:")
                print(file_text)
                print("Preprocessing is skipped. Pass reprocess=True to overwrite them.")
                return
            else:
                print("The following files will be overwritten:")
                print(file_text)

        self._preprocess_data()

    @abstractmethod
    def _preprocess_data(self):
        pass

    def _find_sub_obj_pos(self, input_ids_list: Iterable) -> DataFrame:
        return DataFrame({
            'sub_start_pos': [self._index(s, self.SUB_START_ID) + 1 for s in input_ids_list],
            'sub_end_pos': [self._index(s, self.SUB_END_ID) for s in input_ids_list],
            'obj_start_pos': [self._index(s, self.OBJ_START_ID) + 1 for s in input_ids_list],
            'obj_end_pos': [self._index(s, self.OBJ_END_ID) for s in input_ids_list],
        })

    def _index(self, l: list, e: int) -> int:
        try:
            return l.index(e)
        except ValueError:
            return -1

    def _remove_invalid_sentences(self, data: DataFrame) -> DataFrame:
        seq_max_len = self.tokenizer.model_max_length
        return data.loc[
            (data['sub_end_pos'] < seq_max_len)
            & (data['obj_end_pos'] < seq_max_len)
            & (data['sub_end_pos'] > -1)
            & (data['obj_end_pos'] > -1)
        ]

    def _get_label_mapping(self, le: LabelEncoder):
        id_to_label = dict(enumerate(le.classes_))
        label_to_id = {v: k for k, v in id_to_label.items()}
        return {
            'id_to_label': id_to_label,
            'label_to_id': label_to_id
        }

    def _append_data_to_file(self, data: DataFrame, file: TextIO):
        lines = ""
        for _, row in data.iterrows():
            lines += row.to_json() + "\n"
        file.write(lines)

    def _save_metadata(self, metadata: dict):
        # create metadata file
        if not os.path.exists(METADATA_FILE_NAME):
            print(f"Create metadata file at {METADATA_FILE_NAME}")
            with open(METADATA_FILE_NAME, 'w') as f:
                f.write("{}\n")

        # add metadata
        print("Saving metadata")
        with open(METADATA_FILE_NAME) as f:
            root_metadata = json.load(f)
        with open(METADATA_FILE_NAME, 'w') as f:
            root_metadata[self.DATASET_NAME] = metadata
            json.dump(root_metadata, f, indent=4)

    @classmethod
    def get_json_file_name(cls, key: str) -> str:
        return os.path.join(PROCESSED_DATA_DIR, f'{cls.DATASET_NAME.lower()}_{key}.json')

class SemEval2010Task8Preprocessor(AbstractPreprocessor):
    DATASET_NAME = 'SemEval2010Task8'
    RAW_TRAIN_FILE_NAME = os.path.join(DATASET_MAPPING['SemEval2010Task8']['dir'],
                                       'SemEval2010_task8_training/TRAIN_FILE.TXT')
    RAW_TEST_FILE_NAME = os.path.join(DATASET_MAPPING['SemEval2010Task8']['dir'],
                                      'SemEval2010_task8_testing_keys/TEST_FILE_FULL.TXT')
    RAW_TRAIN_DATA_SIZE = 8000
    RAW_TEST_DATA_SIZE = 2717

    def _preprocess_data(self):
        print("Processing training data")
        train_data = self._get_data_from_file(
            self.RAW_TRAIN_FILE_NAME,
            self.RAW_TRAIN_DATA_SIZE
        )

        print("Processing test data")
        test_data = self._get_data_from_file(
            self.RAW_TEST_FILE_NAME,
            self.RAW_TEST_DATA_SIZE
        )

        print("Encoding labels to integers")
        le = LabelEncoder()
        train_data['label'] = le.fit_transform(train_data['label'])
        test_data['label'] = le.transform(test_data['label'])

        print("Splitting train & validate data")
        train_data, val_data = train_test_split(
            train_data,
            test_size=self.VAL_DATA_PROPORTION,
            shuffle=True,
            random_state=self.RANDOM_SEED
        )

        print("Saving to json files")
        with open(self.get_json_file_name('train'), 'w') as f:
            self._append_data_to_file(train_data, f)
        with open(self.get_json_file_name('val'), 'w') as f:
            self._append_data_to_file(val_data, f)
        with open(self.get_json_file_name('test'), 'w') as f:
            self._append_data_to_file(test_data, f)

        self._save_metadata({
            'train_size': len(train_data) ,
            'val_size': len(val_data),
            'test_size': len(test_data),
            **self._get_label_mapping(le)
        })

    def _get_data_from_file(self, file_name: str, dataset_size: int) -> DataFrame:
        raw_sentences = []
        labels = []
        with open(file_name) as f:
            for _ in tqdm(range(dataset_size)):
                raw_sentences.append(self._process_sentence(f.readline()))
                labels.append(self._process_label(f.readline()))
                f.readline()
                f.readline()
        tokens = self.tokenizer(raw_sentences, truncation=True, padding=True)
        data = DataFrame(tokens.data)
        data['label'] = labels
        sub_obj_position = self._find_sub_obj_pos(data['input_ids'])
        data = pd.concat([data, sub_obj_position], axis=1)
        data = self._remove_invalid_sentences(data)
        return data

    def _process_sentence(self, sentence: str) -> str:
        # TODO distinguish e1 e2 sub obj
        return sentence.split("\t")[1][1:-2] \
            .replace("<e1>", SUB_START_CHAR) \
            .replace("</e1>", SUB_END_CHAR) \
            .replace("<e2>", OBJ_START_CHAR) \
            .replace("</e2>", OBJ_END_CHAR)

    def _process_label(self, label: str) -> str:
        return "Other" if label == 'Other\n' else label[:-8]

class LargeDatasetPreprocessor(AbstractPreprocessor):
    PROCESS_BATCH_SIZE = 2**12

    def _preprocess_data(self):
        pass

    def _process_batch(self, le: LabelEncoder, in_file: TextIO) -> DataFrame:
        raw_sentences = []
        labels = []

        for _ in range(self.PROCESS_BATCH_SIZE):
            dt = in_file.readline()
            if dt == "":  # EOF
                break
            dt = json.loads(dt)

            # add subject markup
            sub = dt['sub']  # TODO keep _ or not?
            obj = dt['obj']
            new_sub = SUB_START_CHAR + ' ' + sub.replace("_", "") + ' ' + SUB_END_CHAR
            new_obj = OBJ_START_CHAR + ' ' +  obj.replace("_", "") + ' ' + OBJ_END_CHAR
            self._replace_once(dt['sent'], sub, new_sub)
            self._replace_once(dt['sent'], obj, new_obj)
            raw_sentences.append(" ".join(dt['sent']))
            labels.append(dt['rel'])

        if not raw_sentences:
            return DataFrame()

        tokens = self.tokenizer(raw_sentences, truncation=True, padding=True)
        data = DataFrame(tokens.data)
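        # NOTE: fit_transform refits the encoder on every batch, so label ids are only
        # consistent across batches if each batch contains every relation class.
        data['label'] = le.fit_transform(labels)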
        sub_obj_position = self._find_sub_obj_pos(data['input_ids'])
        data = pd.concat([data, sub_obj_position], axis=1)
        data = self._remove_invalid_sentences(data)
        return data

    def _replace_once(self, arr: list, element, replacement):
        for i, e in enumerate(arr):
            if e == element:
                arr[i] = replacement
                return
            if e[:-1] == element and e[-1] in ',.?!;:':
                arr[i] = replacement + e[-1]
                return

    def _process_subset(self, le: LabelEncoder, in_file_name, out_file_name, data_size) -> int:
        total_data_size = 0
        with open(in_file_name) as in_file, open(out_file_name, 'w') as out_file:
            batch_count = math.ceil(data_size / self.PROCESS_BATCH_SIZE)
            for _ in tqdm(range(batch_count)):
                data = self._process_batch(le, in_file)
                self._append_data_to_file(data, out_file)
                total_data_size += len(data)
        return total_data_size

class GIDSPreprocessor(LargeDatasetPreprocessor):
    DATASET_NAME = 'GIDS'
    RAW_TRAIN_FILE_NAME = os.path.join(DATASET_MAPPING['GIDS']['dir'], 'gids_train.json')
    RAW_VAL_FILE_NAME = os.path.join(DATASET_MAPPING['GIDS']['dir'], 'gids_dev.json')
    RAW_TEST_FILE_NAME = os.path.join(DATASET_MAPPING['GIDS']['dir'], 'gids_test.json')
    TRAIN_SIZE = 11297
    VAL_SIZE = 1864
    TEST_SIZE = 5663
    PROCESS_BATCH_SIZE = 1024

    def _preprocess_data(self):
        le = LabelEncoder()
        
        print("Process train dataset")
        actual_train_size = self._process_subset(
            le,
            self.RAW_TRAIN_FILE_NAME,
            self.get_json_file_name('train'),
            self.TRAIN_SIZE
        )

        print("Process val dataset")
        actual_val_size = self._process_subset(
            le,
            self.RAW_VAL_FILE_NAME,
            self.get_json_file_name('val'),
            self.VAL_SIZE
        )
        
        print("Process test dataset")
        actual_test_size = self._process_subset(
            le, 
            self.RAW_TEST_FILE_NAME, 
            self.get_json_file_name('test'),
            self.TEST_SIZE
        )

        self._save_metadata({
            'train_size': actual_train_size,
            'val_size': actual_val_size,
            'test_size': actual_test_size,
            **self._get_label_mapping(le)
        })

class NYTPreprocessor(LargeDatasetPreprocessor):
    DATASET_NAME = 'NYT'
    RAW_TRAIN_FILE_NAME = os.path.join(DATASET_MAPPING['NYT']['dir'], 'riedel_train.json')
    RAW_TEST_FILE_NAME = os.path.join(DATASET_MAPPING['NYT']['dir'], 'riedel_test.json')
    TRAIN_SIZE = 570084
    TEST_SIZE = 172448
    PROCESS_BATCH_SIZE = 4096 * 4

    def _preprocess_data(self):
        le = LabelEncoder()
        actual_train_size = 0
        actual_val_size = 0

        print("Process train & val dataset")
        batch_count = math.ceil(self.TRAIN_SIZE / self.PROCESS_BATCH_SIZE)
        with open(self.RAW_TRAIN_FILE_NAME) as in_file,\
                open(self.get_json_file_name('train'), 'w') as train_file,\
                open(self.get_json_file_name('val'), 'w') as val_file:
                    for _ in tqdm(range(batch_count)):
                        data = self._process_batch(le, in_file)
                        train_data, val_data = train_test_split(
                            data,
                            test_size=self.VAL_DATA_PROPORTION,
                            shuffle=True,
                            random_state=self.RANDOM_SEED
                        )
                        self._append_data_to_file(train_data, train_file)
                        self._append_data_to_file(val_data, val_file)
                        actual_train_size += len(train_data)
                        actual_val_size += len(val_data)

        print("Process test dataset")
        actual_test_size = self._process_subset(
            le, 
            self.RAW_TEST_FILE_NAME, 
            self.get_json_file_name('test'),
            self.TEST_SIZE
        )

        self._save_metadata({
            'train_size': actual_train_size,
            'val_size': actual_val_size,
            'test_size': actual_test_size,
            **self._get_label_mapping(le)
        })

def get_preprocessor_class(dataset_name: str):
    return globals()[f'{dataset_name}Preprocessor']

def get_preprocessor(dataset_name: str) -> AbstractPreprocessor:
    bert_model_info = BERT_VARIANT_MAPPING[BERT_VARIANT]
    bert_pretrain_weight = bert_model_info['pretrain_weight']
    tokenizer = bert_model_info['tokenizer'].from_pretrained(bert_pretrain_weight)
    preprocessors_class = get_preprocessor_class(dataset_name)
    return preprocessors_class(tokenizer)

preprocessor = get_preprocessor('SemEval2010Task8')
preprocessor.preprocess_data(reprocess=True)


---> Preprocessing SemEval2010Task8 dataset <---
Processing training data

Processing test data

Encoding labels to integers
Splitting train & validate data
Saving to json files
Saving metadata
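
To make the position bookkeeping in `_find_sub_obj_pos` concrete, here is a small sketch of the same lookup on a plain list of token ids. The ids are invented for this example (real ids come from the tokenizer), and the helper names are hypothetical.

```python
from typing import List, Tuple


def index_or_minus_one(ids: List[int], target: int) -> int:
    """Like list.index, but returns -1 when the target is missing."""
    try:
        return ids.index(target)
    except ValueError:
        return -1


def find_span(ids: List[int], start_id: int, end_id: int) -> Tuple[int, int]:
    """Return (start, end): start is the token after the opening marker,
    end is the position of the closing marker, so ids[start:end] is the entity."""
    return index_or_minus_one(ids, start_id) + 1, index_or_minus_one(ids, end_id)


# Invented ids: 900/901 mark the subject, 902/903 mark the object.
SUB_START_ID, SUB_END_ID, OBJ_START_ID, OBJ_END_ID = 900, 901, 902, 903
input_ids = [101, 7, 900, 11, 12, 901, 8, 902, 13, 903, 102]

sub_span = find_span(input_ids, SUB_START_ID, SUB_END_ID)
obj_span = find_span(input_ids, OBJ_START_ID, OBJ_END_ID)
print(sub_span, input_ids[sub_span[0]:sub_span[1]])
print(obj_span, input_ids[obj_span[0]:obj_span[1]])

# If truncation cuts off a closing marker, its position is -1; such rows
# are the ones that _remove_invalid_sentences filters out.
truncated_span = find_span(input_ids[:5], SUB_START_ID, SUB_END_ID)
```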



## Model

### Dataset

In [14]:
class GenericDataset(IterableDataset):

    def __init__(self, dataset_name: str, subset: str, batch_size: int):
        if subset not in ['train', 'val', 'test']:
            raise ValueError('subset must be train, val or test')
            
        with open(METADATA_FILE_NAME) as f:
            metadata = json.load(f)
        self.length = math.ceil(metadata[dataset_name][f'{subset}_size'] / batch_size)
        
        preprocessor_class = get_preprocessor_class(dataset_name)
        self.file = open(preprocessor_class.get_json_file_name(subset))

    def __del__(self):
        self.file.close()

    def __iter__(self):
        def get_data():
            for line in self.file:
                data = json.loads(line)
                input_data = {k: torch.tensor(v) for k, v in data.items() if k != 'label'}
                yield input_data, data['label']

        return get_data()
    
    def __len__(self):
        return self.length
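
The dataset streams one JSON object per line from a single open file handle. The sketch below mimics that with an in-memory file and plain lists instead of tensors, to show why a second pass over the same handle yields nothing unless the file is rewound or reopened. The sample rows and the name `iter_examples` are invented for illustration.

```python
import io
import json


def iter_examples(file):
    """Yield (features, label) pairs from a JSON-lines file object."""
    for line in file:
        row = json.loads(line)
        label = row.pop('label')
        yield row, label


# In-memory stand-in for one of the processed *.json files.
fake_file = io.StringIO(
    '{"input_ids": [101, 7, 102], "attention_mask": [1, 1, 1], "label": 2}\n'
    '{"input_ids": [101, 9, 102], "attention_mask": [1, 1, 1], "label": 0}\n'
)

first_pass = list(iter_examples(fake_file))   # reads both rows
second_pass = list(iter_examples(fake_file))  # handle is at EOF: no rows
fake_file.seek(0)                             # rewinding makes the rows available again
third_pass = list(iter_examples(fake_file))

print(len(first_pass), len(second_pass), len(third_pass))
```

This is presumably why the trainer below is created with `reload_dataloaders_every_epoch=True`: building a fresh dataset each epoch reopens the file from the start.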

### Torch Lightning Module

In [10]:
class BERTModule(LightningModule):

    def __init__(self, bert_variant, dataset_name, batch_size, learning_rate,
                 bert_cls_size, bert_entity_size):
        super().__init__()
        self.save_hyperparameters()

        self.cls_stream = torch.cuda.Stream()
        self.obj_stream = torch.cuda.Stream()
        self.sub_stream = torch.cuda.Stream()

        bert_info = BERT_VARIANT_MAPPING[bert_variant]
        bert_model_class = bert_info['model']
        bert_pretrain_weight = bert_info['pretrain_weight']
        self.bert = bert_model_class.from_pretrained(bert_pretrain_weight, output_attentions=True)
        
        self.cls_linear = nn.Linear(self.bert.config.hidden_size, bert_cls_size)
        self.cls_activate = nn.PReLU()
        self.sub_linear = nn.Linear(self.bert.config.hidden_size, bert_entity_size)
        self.sub_activate = nn.PReLU()
        self.obj_linear = nn.Linear(self.bert.config.hidden_size, bert_entity_size)
        self.obj_activate = nn.PReLU()
        
        dataset_info = DATASET_MAPPING[dataset_name]
        self.linear = nn.Linear(bert_cls_size + 2 * bert_entity_size, dataset_info['num_classes'])

    def train_dataloader(self) -> DataLoader:
        return self.__get_dataloader('train')

    def val_dataloader(self) -> DataLoader:
        return self.__get_dataloader('val')

    def test_dataloader(self) -> DataLoader:
        return self.__get_dataloader('test')

    def __get_dataloader(self, subset: str) -> DataLoader:
        print(f"Loading {subset} data")
        return DataLoader(
            GenericDataset(self.hparams.dataset_name, subset, self.hparams.batch_size),
            batch_size=self.hparams.batch_size,
            #shuffle=(subset == 'train'),
            num_workers=1
        )

    def configure_optimizers(self) -> Optimizer:
        return AdamW(
            [p for p in self.parameters() if p.requires_grad],
            lr=self.hparams.learning_rate,
            eps=1e-08
        )

    def forward(self, input_ids, attention_mask, sub_start_pos, sub_end_pos,
                obj_start_pos, obj_end_pos) -> Tensor:
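        # The first element of the output tuple is the last hidden state for all
        # variants in BERT_VARIANT_MAPPING; the number of extra outputs differs.
        bert_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )[0]

        bert_cls = bert_output[:, 0]

        # Average the hidden states over each entity span [start, end). The positions
        # differ per sample, so select them with a mask rather than a slice.
        positions = torch.arange(bert_output.size(1), device=bert_output.device).unsqueeze(0)
        sub_mask = ((positions >= sub_start_pos.unsqueeze(1))
                    & (positions < sub_end_pos.unsqueeze(1))).unsqueeze(-1).type_as(bert_output)
        obj_mask = ((positions >= obj_start_pos.unsqueeze(1))
                    & (positions < obj_end_pos.unsqueeze(1))).unsqueeze(-1).type_as(bert_output)
        bert_sub = (bert_output * sub_mask).sum(dim=1) / sub_mask.sum(dim=1).clamp(min=1)
        bert_obj = (bert_output * obj_mask).sum(dim=1) / obj_mask.sum(dim=1).clamp(min=1)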
        
        torch.cuda.synchronize()
        
        with torch.cuda.stream(self.cls_stream):
            cls_output = self.cls_activate(self.cls_linear(bert_cls))
        with torch.cuda.stream(self.sub_stream):
            sub_output = self.sub_activate(self.sub_linear(bert_sub))
        with torch.cuda.stream(self.obj_stream):
            obj_output = self.obj_activate(self.obj_linear(bert_obj))

        torch.cuda.synchronize()

        linear_input = torch.cat((cls_output, sub_output, obj_output), dim=1)
        logits = self.linear(linear_input)

        return logits

    def training_step(self, batch, batch_nb) -> dict:
        input_data, label = batch
        y_hat = self(**input_data)

        loss = F.cross_entropy(y_hat, label)
        tensorboard_logs = {'train_loss': loss}

        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_nb) -> dict:
        input_data, label = batch
        y_hat = self(**input_data)

        loss = F.cross_entropy(y_hat, label)

        _, y_hat = torch.max(y_hat, dim=1)
        y_hat = y_hat.cpu()
        label = label.cpu()

        return {
            'val_loss': loss,
            'val_pre': torch.tensor(precision_score(label, y_hat, average='micro')),
            'val_rec': torch.tensor(recall_score(label, y_hat, average='micro')),
            'val_acc': torch.tensor(accuracy_score(label, y_hat)),
            'val_f1': torch.tensor(f1_score(label, y_hat, average='micro'))
        }

    def validation_epoch_end(self, outputs) -> dict:
        avg_val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_val_pre = torch.stack([x['val_pre'] for x in outputs]).mean()
        avg_val_rec = torch.stack([x['val_rec'] for x in outputs]).mean()
        avg_val_acc = torch.stack([x['val_acc'] for x in outputs]).mean()
        avg_val_f1 = torch.stack([x['val_f1'] for x in outputs]).mean()

        tensorboard_logs = {
            'avg_val_loss': avg_val_loss,
            'avg_val_pre': avg_val_pre,
            'avg_val_rec': avg_val_rec,
            'avg_val_acc': avg_val_acc,
            'avg_val_f1': avg_val_f1,
        }
        return {'val_loss': avg_val_loss, 'progress_bar': tensorboard_logs}

    def test_step(self, batch, batch_nb) -> dict:
        input_data, label = batch
        y_hat = self(**input_data)

        _, y_hat = torch.max(y_hat, dim=1)
        
        y_hat = y_hat.cpu()
        label = label.cpu()
        
        test_pre = precision_score(label, y_hat, average='micro')
        test_rec = recall_score(label, y_hat, average='micro')
        test_acc = accuracy_score(label, y_hat)
        test_f1 = f1_score(label, y_hat, average='micro')

        return {
            'test_pre': torch.tensor(test_pre),
            'test_rec': torch.tensor(test_rec),
            'test_acc': torch.tensor(test_acc),
            'test_f1': torch.tensor(test_f1),
        }

    def test_epoch_end(self, outputs) -> dict:
        avg_test_pre = torch.stack([x['test_pre'] for x in outputs]).mean()
        avg_test_rec = torch.stack([x['test_rec'] for x in outputs]).mean()
        avg_test_acc = torch.stack([x['test_acc'] for x in outputs]).mean()
        avg_test_f1 = torch.stack([x['test_f1'] for x in outputs]).mean()

        tensorboard_logs = {
            'avg_test_pre': avg_test_pre,
            'avg_test_rec': avg_test_rec,
            'avg_test_acc': avg_test_acc,
            'avg_test_f1': avg_test_f1,
        }
        return {'progress_bar': tensorboard_logs}
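
The subject and object representations in `forward` are meant to be the mean of BERT's hidden states over each entity's tokens. Since the start and end positions differ from sentence to sentence, one way to compute this for a whole batch is with a boolean mask. This is a minimal sketch on toy tensors, assuming each span is the half-open range `[start, end)` produced by the preprocessing step; the function name `span_mean` is illustrative.

```python
import torch


def span_mean(hidden: torch.Tensor, start: torch.Tensor, end: torch.Tensor) -> torch.Tensor:
    """Mean of hidden[b, start[b]:end[b]] for every sample b.

    hidden: (batch, seq_len, hidden_size); start, end: (batch,) integer tensors.
    """
    positions = torch.arange(hidden.size(1), device=hidden.device).unsqueeze(0)  # (1, seq_len)
    mask = (positions >= start.unsqueeze(1)) & (positions < end.unsqueeze(1))    # (batch, seq_len)
    mask = mask.unsqueeze(-1).type_as(hidden)                                    # (batch, seq_len, 1)
    # clamp avoids a division by zero for an empty span
    return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)


# Toy batch: 2 sentences, 4 tokens each, hidden size 3.
hidden = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3)
start = torch.tensor([1, 0])
end = torch.tensor([3, 1])

pooled = span_mean(hidden, start, end)
print(pooled.shape)
```

For the first sentence this averages tokens 1 and 2, and for the second it returns token 0 unchanged, which matches slicing each sample separately.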

## Claiming back memory

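The last cell of this section drops the reference to the model, runs the garbage collector, and empties PyTorch's CUDA cache. For background, see [this answer](https://stackoverflow.com/a/61707643/7342188) and [this one](https://stackoverflow.com/a/57860310/7342188) on Stack Overflow.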

In [16]:
1 / 0

ZeroDivisionError: division by zero

In [17]:
model = None
gc.collect()
torch.cuda.empty_cache()

## Trainer

In [18]:
GPUS = 1
MIN_EPOCHS = 1
MAX_EPOCHS = 4

trainer = LightningTrainer(
    gpus=GPUS,
    min_epochs=MIN_EPOCHS,
    max_epochs=MAX_EPOCHS,
    default_root_dir=CHECKPOINT_DIR,
    reload_dataloaders_every_epoch=True # needed as we loop over a file
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]


## Training

Create a model object:

In [19]:
BATCH_SIZE = 32
LEARNING_RATE = 2e-05

BERT_CLS_SIZE = 64
BERT_ENTITY_SIZE = 64

model = BERTModule(
    bert_variant=BERT_VARIANT,
    dataset_name=DATASET_NAME,
    batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    bert_cls_size=BERT_CLS_SIZE,
    bert_entity_size=BERT_ENTITY_SIZE
)

Start training:

In [None]:
trainer.fit(model)


  | Name   | Type            | Params
-------------------------------------------
0 | bert   | DistilBertModel | 66 M  
1 | linear | Linear          | 3 K   


Loading val data



Loading val data



Loading train data



Loading val data
Loading train data


## Testing

In [None]:
trainer.test(model)
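
A note on reading the reported metrics: the module computes precision, recall and F1 with `average='micro'`. When every sample has exactly one label, each wrong prediction counts as one false positive and one false negative, so the micro-averaged scores all equal accuracy. The sketch below checks this by hand on invented labels; the function `micro_scores` is written for this illustration and is not the scikit-learn implementation.

```python
def micro_scores(y_true, y_pred):
    """Micro-averaged precision, recall, F1 and accuracy for single-label data."""
    classes = set(y_true) | set(y_pred)
    pairs = list(zip(y_true, y_pred))
    # Sum true positives, false positives and false negatives over all classes.
    tp = sum(1 for c in classes for t, p in pairs if t == c and p == c)
    fp = sum(1 for c in classes for t, p in pairs if t != c and p == c)
    fn = sum(1 for c in classes for t, p in pairs if t == c and p != c)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    accuracy = sum(1 for t, p in pairs if t == p) / len(pairs)
    return precision, recall, f1, accuracy


# Invented labels: 3 of 5 predictions are correct.
y_true = [0, 1, 2, 2, 1]
y_pred = [0, 2, 2, 1, 1]
precision, recall, f1, accuracy = micro_scores(y_true, y_pred)
print(precision, recall, f1, accuracy)
```

If per-class behaviour matters, for example on an imbalanced dataset, `average='macro'` would generally give scores that differ from accuracy.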