In [105]:
from random import sample
import re
import warnings

import pandas as pd
import xml.etree.ElementTree as et
from collections import defaultdict

import build_bert_ready_dataset as bbrd

## Processing Data for BERT 

In this notebook we process the combined dataset and write a function for splitting the dataset into training and testing, based on the topics.

The topics in training and testing must be mutually exclusive.

In [130]:
df = pd.read_csv("../data/trials_topics_combined_full.csv")
df.to_pickle("../data/trials_topics_combined_full.pickle")

In [131]:
df.head()

Unnamed: 0,score,id,brief_summary,brief_title,minimum_age,gender,primary_outcome,detailed_description,keywords,official_title,...,intervention_browse,condition_browse,inclusion,exclusion,topic,_,label,disease,gene,age
0,1.0,NCT01513603,"CLAG-M is an active, well tolerated regimen in...","Trial of Cladribine, Cytarabine, Mitoxantrone,...",,female,Complete remission percentage 1 month,Patients will receive standard dose CLAG-M (cl...,,Phase II Trial of CLAG-M in Relapsed ALL,...,Cytarabine Cladribine Mitoxantrone,Lymphoma Leukemia Precursor Cell Lymphoblastic...,- Relapsed or refractory acute lymphoblastic l...,,32,0,0,leukemia,ABL1,4
1,1.0,NCT00582257,The purpose of this study is to establish a ga...,Early Onset and Familial Gastric Cancer Registry,6570.0,female,Create registry of families w/ early onset & f...,,Gastric Cancer Stomach Cancer 05-118,Early Onset and Familial Gastric Cancer Registry,...,,Stomach Neoplasms,Patient/Relative Cohort: Must meet one or more...,Patients are ineligible for the study if they:...,33,0,0,gastric cancer,EGFR,60
2,1.0,NCT02472678,Patients with a neuroendocrine tumor (NET) fre...,Web-based Tailored Information and Support for...,6570.0,female,A composite outcome of difference in distress ...,Rationale: Patients with a neuroendocrine tumo...,,Web-based Tailored Information and Support for...,...,,Neuroendocrine Tumors Carcinoid Tumor,- Adult NET patients (aged ≥ 18 years of age) ...,- Estimated life expectancy less than 3 months...,20,0,0,melanoma,high tumor mutational burden,86
3,1.0,NCT02956889,"This is a Fleming-A' Hern, single arm, multice...",To Assess The Efficacy And Safety Of Vismodegi...,6570.0,female,evaluate the activity of the study therapy in ...,"This is a Fleming-A' Hern, single arm, multice...","Carcinoma, Basal Cell Vismodegib Radiotherapy","A Single Arm, Phase II, Multicenter Study To A...",...,,"Carcinoma Carcinoma, Basal Cell","1. Written, signed informed consent 2. Age ≥ 1...",1. Inability or unwillingness to swallow capsu...,43,0,0,basal cell carcinoma,PTCH1,56
4,1.0,NCT02510001,This trial is designed to try two new cancer d...,MErCuRIC1: MEK and MET Inhibition in Colorecta...,5840.0,male,Maximal tolerated dose (MTD) of PD-0325901 or ...,This is a two stage study. Firstly a dose esca...,RASMT CRC RASWT/c-MET CRC Dose Escalation Hist...,A Sequential Phase I Study of MEK1/2 Inhibitor...,...,Crizotinib,Colorectal Neoplasms,(Inclusion criteria for the completed initial ...,(Exclusion criteria for the completed initial ...,26,0,0,colorectal cancer,NRAS,49


In [107]:
SUBSET_COLUMNS = [
    "id", "gene", "disease", "brief_summary", "brief_title",
    "topic", "label"
]

df_input = df[SUBSET_COLUMNS].copy()

# Restrict to binary class
df_input["label"] = df_input["label"].replace(to_replace=2, value=1)

In [108]:
df_sub = bbrd.subset_data(df_input, "label", p=0.3).reset_index(drop=True)
df_sub["label"].value_counts()

0    3642
1     614
Name: label, dtype: int64

In [109]:
sample(range(0, 10), 2)

[9, 6]

In [110]:
def remove_duplicates():
    pass

def select_topics_for_training(df, col, max_training_topics, num_training_topics):
    """ The 2019 TREC PM dataset contains up to 50 topics """
    
    topic_idx_arr = sample(range(0, max_training_topics), num_training_topics)
    
    training_set = df[df[col].isin(topic_idx_arr)]
    test_set = df[~df[col].isin(topic_idx_arr)]
    
    return training_set, test_set

In [111]:
training_set, test_set = select_topics_for_training(
    df, "topic", 
    max_training_topics=50,
    num_training_topics=30
)

In [112]:
training_set["topic"].unique()

array([20, 26, 49, 36, 21, 44, 12, 17, 18, 31, 11, 48,  1,  2,  3,  4,  5,
        6,  8,  9, 13, 23, 29, 30, 40, 35, 39, 46, 27])

In [113]:
test_set["topic"].unique()

array([32, 33, 43, 50, 37, 41, 42, 19, 22, 10, 14, 25, 38, 34, 24,  7, 15,
       16, 45, 28, 47])

In [114]:
df["topic"].value_counts().sort_values()

27    154
38    173
8     187
39    188
9     191
28    197
2     199
6     205
1     207
11    210
41    217
3     218
26    220
48    224
5     224
12    225
14    231
10    231
50    234
13    234
7     236
4     236
30    236
45    240
43    251
47    266
34    286
49    288
16    289
40    310
15    314
23    320
46    329
44    330
29    343
32    343
21    351
25    355
22    366
33    368
36    368
31    369
35    372
42    380
24    386
17    390
18    401
19    405
20    445
37    446
Name: topic, dtype: int64

In [115]:
for col in df.columns:
    print(df[col].isnull().value_counts())

False    14188
Name: score, dtype: int64
False    14188
Name: id, dtype: int64
False    14188
Name: brief_summary, dtype: int64
False    14188
Name: brief_title, dtype: int64
False    12831
True      1357
Name: minimum_age, dtype: int64
False    14188
Name: gender, dtype: int64
False    13003
True      1185
Name: primary_outcome, dtype: int64
False    10528
True      3660
Name: detailed_description, dtype: int64
False    10325
True      3863
Name: keywords, dtype: int64
False    13913
True       275
Name: official_title, dtype: int64
False    12718
True      1470
Name: intervention_type, dtype: int64
False    12718
True      1470
Name: intervention_name, dtype: int64
False    8008
True     6180
Name: intervention_browse, dtype: int64
False    13408
True       780
Name: condition_browse, dtype: int64
False    14181
True         7
Name: inclusion, dtype: int64
False    11686
True      2502
Name: exclusion, dtype: int64
False    14188
Name: topic, dtype: int64
False    14188
Name: _, dtyp

In [116]:
# trial_cols = ["brief_t_and_s", "brief_title", "brief_summary"]

df["brief_t_and_s"] = df["brief_title"] + " " + df["brief_summary"]

In [117]:
print(df["brief_t_and_s"][0])

Trial of Cladribine, Cytarabine, Mitoxantrone, Filgrastim (CLAG-M) in Relapsed Acute Lymphoblastic Leukemia CLAG-M is an active, well tolerated regimen in acute myelogenous leukemia. Each of the agents is active in Acute Lymphoblastic Leukemia (ALL) as well. The current trial will determine the efficacy of the regimen in patients with relapsed ALL.


In [118]:
df["d_and_g"] = df["disease"] + " " + df["gene"]

In [119]:
seq_a = [x.split() for x in list(df["d_and_g"])]
seq_b = [x.split() for x in list(df["brief_t_and_s"])]

In [125]:
labels = list(df["label"])
labels[:5]

[0, 0, 0, 0, 0]

In [129]:
attribute_seq = list(df["topic"])
print(set(attribute_seq))

{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50}


In [16]:
from keras.preprocessing.sequence import pad_sequences

import torch

from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

from sklearn.model_selection import train_test_split

from transformers import BertConfig, BertTokenizer
# from transformers import AdamW
# from transformers import get_linear_schedule_with_warmup

Using TensorFlow backend.


In [17]:
model_name = "bert-base-uncased"

tokenizer = BertTokenizer.from_pretrained(
            model_name,
            do_lower_case=True
        )

In [18]:
seq_a = list(df["d_and_g"])
seq_b = list(df["brief_t_and_s"])
print(seq_a[:5])

['leukemia ABL1', 'gastric cancer EGFR', 'melanoma high tumor mutational burden', 'basal cell carcinoma PTCH1', 'colorectal cancer NRAS']


In [19]:
seq_a_tokenized = [tokenizer.tokenize(s) for s in seq_a]
seq_a_tokenized[:5]

[['leukemia', 'ab', '##l', '##1'],
 ['gas', '##tric', 'cancer', 'e', '##gf', '##r'],
 ['mel', '##ano', '##ma', 'high', 'tumor', 'mutation', '##al', 'burden'],
 ['basal', 'cell', 'car', '##cino', '##ma', 'pt', '##ch', '##1'],
 ['color', '##ect', '##al', 'cancer', 'nr', '##as']]

In [20]:
seq_b_tokenized = [tokenizer.tokenize(s) for s in seq_b]
seq_b_tokenized[:5]

[['trial',
  'of',
  'clad',
  '##ri',
  '##bine',
  ',',
  'cy',
  '##tara',
  '##bine',
  ',',
  'mit',
  '##ox',
  '##ant',
  '##rone',
  ',',
  'fi',
  '##l',
  '##gra',
  '##sti',
  '##m',
  '(',
  'cl',
  '##ag',
  '-',
  'm',
  ')',
  'in',
  're',
  '##la',
  '##pse',
  '##d',
  'acute',
  'l',
  '##ym',
  '##ph',
  '##ob',
  '##lastic',
  'leukemia',
  'cl',
  '##ag',
  '-',
  'm',
  'is',
  'an',
  'active',
  ',',
  'well',
  'tolerated',
  'regime',
  '##n',
  'in',
  'acute',
  'my',
  '##elo',
  '##gen',
  '##ous',
  'leukemia',
  '.',
  'each',
  'of',
  'the',
  'agents',
  'is',
  'active',
  'in',
  'acute',
  'l',
  '##ym',
  '##ph',
  '##ob',
  '##lastic',
  'leukemia',
  '(',
  'all',
  ')',
  'as',
  'well',
  '.',
  'the',
  'current',
  'trial',
  'will',
  'determine',
  'the',
  'efficacy',
  'of',
  'the',
  'regime',
  '##n',
  'in',
  'patients',
  'with',
  're',
  '##la',
  '##pse',
  '##d',
  'all',
  '.'],
 ['early',
  'onset',
  'and',
  'fa',
  '##mi'

In [64]:
class dummy():
    def __init__(self):
        self.model_name = model_name 
#         self.config = config # initialised outside of class
        self.model = None
        self.tokenizer = None
#         self.hf_model_class = BertForSequenceClassification
        self.hf_token_class = BertTokenizer

        self.tokenised_ids = None
        self.labels = None
        self.attention_masks = None
        self.training_data_loader = None
        self.test_data_loader = None
        self.validation_accuracy = None
        
        self.NUM_LABELS = 2
        self.MAX_TOKEN_LEN = 128
    
    #######################################################
    
    def _truncate_seq_pair(self, pair_a, pair_b=None):
        """Truncates a sequence pair to the maximum length."""

        if pair_b is None:
            if len(pair_a) > (self.MAX_TOKEN_LEN - 2):
                return pair_a[:(self.MAX_TOKEN_LEN - 2)], pair_b
        else:
            while True:
                total_length = len(pair_a) + len(pair_b)
                if total_length <= (self.MAX_TOKEN_LEN - 3):
                    break
                if len(pair_a) > len(pair_b):
                    pair_a.pop()
                else:
                    pair_b.pop()

        return pair_a, pair_b
    
    def _tokenize_seq(self, pair_a, pair_b):
        
        pair_a, pair_b = self._truncate_seq_pair(pair_a, pair_b)
        
        # print(f"pair_a: {pair_a}")
        pair_a = ["[CLS]"] + pair_a + ["[SEP]"]
        seg_ids_a = [0] * len(pair_a)
        
        if pair_b is not None:
            pair_b = pair_b + ["[SEP]"]
            seg_ids_b = [1] * len(pair_b)

            pair_ab = pair_a + pair_b
            seg_ids = seg_ids_a + seg_ids_b
            input_mask = [1] * (len(pair_a) + len(pair_b))
        else:
            pair_ab = pair_a
            seg_ids = seg_ids_a
            input_mask = [1] * len(pair_a)
            
        pair_ab_token_id = [self.tokenizer.convert_tokens_to_ids(token) for token in pair_ab] 
        
        # Pad the rest
        # We only pad the tokens that have been converted to ids 
        # corresponding to the BERT vocabulary book.
        while len(pair_ab_token_id) < self.MAX_TOKEN_LEN:
            pair_ab_token_id.append(0)
            seg_ids.append(0)
            input_mask.append(0)
        
        return pair_ab_token_id, seg_ids, input_mask
            
    def _build_tokenised_dataset_seq(self, seq_a, seq_b):
        seq_a_tokenised = [self.tokenizer.tokenize(s) for s in seq_a]
        
        # Use naming convention consistent with example code in 
        # HuggingFace repo. `seg_ids` corresponds to `token_type_ids` 
        # and thus the name change.
        self.input_ids = []
        self.token_type_ids = []
        self.attention_masks = []
        
        if seq_b is not None:
            seq_b_tokenised = [self.tokenizer.tokenize(s) for s in seq_b]

            for pair_a, pair_b in tqdm(zip(seq_a_tokenised, seq_b_tokenised), desc="SEQ_A_and_B"):
                pair_ab_token_id, seg_ids, input_mask = self._tokenize_seq(pair_a, pair_b)

                self.input_ids.append(pair_ab_token_id)
                self.token_type_ids.append(seg_ids)
                self.attention_masks.append(input_mask)
        else:
            for pair_a in tqdm(seq_a_tokenised, desc="SEQ_A"):
                pair_a_token_id, seg_ids, input_mask = self._tokenize_seq(pair_a, None)

                self.input_ids.append(pair_a_token_id)
                self.token_type_ids.append(seg_ids)
                self.attention_masks.append(input_mask)
            

    #######################################################

    def _split_data_by_attribute(self, seq_a, seq_b, attribute_seq):
        sample(set(attribute_seq)
        
    def create_train_and_test_dataset(self, 
                                      seq_a, seq_b, labels, 
                                      test_size, batch_size,
                                      split_by_attribute=None,
                                      attribute_seq=None,
                                      ):
        """
        
        Params:

        seq_a: Array of text strings containing the first sentence pair
        seq_b: Likewise, for the second sentence pair
        labels: Labels of <seq_a, seq_b>
        test_size: Hold out percentage
        batch_size: Number of training samples per backprop
        split_by_attribute: Choose specific attribute to split up training and test set.
                            This is required for TREC PM datasets, where we randomise topics
                            so that during validation the test set contains ONLY topics not 
                            seen during training time.
        attribute_seq: If using `split_by_attribute`, then you must pass in an array that 
                       specifies the attribute corresponding to the data.
        
        """

        if split_by_attribute is not None:
            pass

        # Convert to tokenised_text and save as self.input_ids
        self._build_tokenised_dataset_seq(seq_a, seq_b)
        self.labels = labels
        
        # To ensure stratify goes correctly (actually for any of this to go 
        # correctly) we need to set the random_state.
        X_train, X_test, y_train, y_test = train_test_split(
            self.input_ids, self.labels,
            random_state=self.RANDOM_STATE,
            test_size=test_size,
            stratify=self.labels
        )
        X_mask, X_mask_test, _, _ = train_test_split(
            self.attention_masks, self.labels,
            random_state=self.RANDOM_STATE,
            test_size=test_size,
            stratify=self.labels
        )
        
        ########################################

        # Load the training and testing data in.
        
        self.training_data_loader = DataLoader(
            TensorDataset(
                torch.tensor(X_train),
                torch.tensor(X_mask),
                torch.tensor(y_train)
            ),
            shuffle=True,
            batch_size=batch_size
        )

        self.test_data_loader = DataLoader(
            TensorDataset(
                torch.tensor(X_test),
                torch.tensor(X_mask_test),
                torch.tensor(y_test)
            ),
            shuffle=False,
            batch_size=batch_size
        )


In [62]:
seq_a_tokenized[0] + seq_b_tokenized[0]

['leukemia',
 'ab',
 '##l',
 '##1',
 'trial',
 'of',
 'clad',
 '##ri',
 '##bine',
 ',',
 'cy',
 '##tara',
 '##bine',
 ',',
 'mit',
 '##ox',
 '##ant',
 '##rone',
 ',',
 'fi',
 '##l',
 '##gra',
 '##sti',
 '##m',
 '(',
 'cl',
 '##ag',
 '-',
 'm',
 ')',
 'in',
 're',
 '##la',
 '##pse',
 '##d',
 'acute',
 'l',
 '##ym',
 '##ph',
 '##ob',
 '##lastic',
 'leukemia',
 'cl',
 '##ag',
 '-',
 'm',
 'is',
 'an',
 'active',
 ',',
 'well',
 'tolerated',
 'regime',
 '##n',
 'in',
 'acute',
 'my',
 '##elo',
 '##gen',
 '##ous',
 'leukemia',
 '.',
 'each',
 'of',
 'the',
 'agents',
 'is',
 'active',
 'in',
 'acute',
 'l',
 '##ym',
 '##ph',
 '##ob',
 '##lastic',
 'leukemia',
 '(',
 'all',
 ')',
 'as',
 'well',
 '.',
 'the',
 'current',
 'trial',
 'will',
 'determine',
 'the',
 'efficacy',
 'of',
 'the',
 'regime',
 '##n',
 'in',
 'patients',
 'with',
 're',
 '##la',
 '##pse',
 '##d',
 'all',
 '.']

In [63]:
model = dummy()

pair_ab, seg_ids, input_mask = model._tokenize_one_two_seq_pair(
    seq_a_tokenized[0], seq_b_tokenized[0] 
)

print(len(seg_ids))

assert len(pair_ab) == 128
assert len(seg_ids) == 128
assert len(input_mask) == 128

print(seg_ids[5:])
print(seg_ids[6+98:])

105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


In [None]:
pair_ab, seg_ids, input_mask = model._tokenize_one_two_seq_pair(
    seq_a_tokenized[0], seq_b_tokenized[0] 
)


In [45]:
pair_ab[:100]

['[CLS]',
 'leukemia ABL1',
 'gastric cancer EGFR',
 'melanoma high tumor mutational burden',
 'basal cell carcinoma PTCH1',
 'colorectal cancer NRAS',
 'acute myeloid leukemia IDH1',
 'acute myeloid leukemia FLT3',
 'lung cancer ERBB2',
 'non-small cell carcinoma MET',
 'papillary thyroid carcinoma NTRK1',
 'glioblastoma CDK6',
 'gastric cancer EGFR',
 'melanoma tumor cells negative for PD-L1 expression',
 'melanoma extensive tumor infiltrating lymphocytes',
 'melanoma no tumor infiltrating lymphocytes',
 'glioblastoma CDK6',
 'glioma BRAF',
 'melanoma KIT (L576P)',
 'melanoma KIT (K642E)',
 'melanoma KIT amplification',
 'melanoma TP53 loss of function',
 'melanoma tumor cells with >50% membranous PD-L1 expression',
 'melanoma tumor cells negative for PD-L1 expression',
 'melanoma high tumor mutational burden',
 'melanoma high serum LDH levels',
 'cholangiocarcinoma IDH1',
 'sarcoma MDM2',
 'basal cell carcinoma PTCH1',
 'melanoma APC loss of function',
 'glioblastoma CDK6',
 'non-sm

In [176]:
len(seg_ids)

28382

### Easier way but not doing it

This way is easier, but inconsistent with how I want the data to be ingested.

And also it doesn't take into account the length of the tokens in each sequence.

In [96]:
df["input"] = " [CLS] " + df["disease"] + " " + df["gene"] + " [SEP] " \
                + df["brief_title"] + " " + df["brief_summary"] + " [SEP] "

In [97]:
[x.split() for x in df["input"]]

[['[CLS]',
  'leukemiaABL1',
  '[SEP]',
  'Trial',
  'of',
  'Cladribine,',
  'Cytarabine,',
  'Mitoxantrone,',
  'Filgrastim',
  '(CLAG-M)',
  'in',
  'Relapsed',
  'Acute',
  'Lymphoblastic',
  'LeukemiaCLAG-M',
  'is',
  'an',
  'active,',
  'well',
  'tolerated',
  'regimen',
  'in',
  'acute',
  'myelogenous',
  'leukemia.',
  'Each',
  'of',
  'the',
  'agents',
  'is',
  'active',
  'in',
  'Acute',
  'Lymphoblastic',
  'Leukemia',
  '(ALL)',
  'as',
  'well.',
  'The',
  'current',
  'trial',
  'will',
  'determine',
  'the',
  'efficacy',
  'of',
  'the',
  'regimen',
  'in',
  'patients',
  'with',
  'relapsed',
  'ALL.',
  '[SEP]'],
 ['[CLS]',
  'gastric',
  'cancerEGFR',
  '[SEP]',
  'Early',
  'Onset',
  'and',
  'Familial',
  'Gastric',
  'Cancer',
  'RegistryThe',
  'purpose',
  'of',
  'this',
  'study',
  'is',
  'to',
  'establish',
  'a',
  'gastric',
  'cancer',
  'registry.',
  'A',
  'registry',
  'is',
  'a',
  'database',
  'of',
  'information.',
  'With',
  't

## Building Topic Splitter

In [69]:
[random.randint(0, 3) for i in range(10)]

[1, 2, 1, 0, 3, 2, 1, 0, 1, 1]

In [73]:
chr(97)

'a'

In [71]:
ord('a')

97

In [75]:
seq_a = [i for i in range(20)]
seq_b = [chr(97+i) for i in range(20)]

print(seq_a)
print(seq_b)

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't']


In [100]:
import random

seq_a = [i for i in range(20)]
seq_b = [chr(97+i) for i in range(20)]
labels = [random.randint(0, 1) for i in range(20)]
attrib_seq = [random.randint(0, 4) for i in range(20)]

attribute_split_ratio = 0.50

def _split_data_by_attribute(self, seq_a, seq_b, labels, attribute_seq):
#     random.seed(self.RANDOM_STATE)
    uniq_attrib = set(attribute_seq)
    
    attrib_for_test = random.sample(uniq_attrib, int(len(uniq_attrib)*attribute_split_ratio))
    
    print(f"attrib_for_test: {attrib_for_test}")
    idx = 0
    
    seq_a_test = []
    labels_test = []
    
    seq_a_train = []
    labels_train = []

    if seq_b is None:
        seq_b_test = None
        seq_b_train = None
    else:
        seq_b_test = []
        seq_b_train = []
    
    for attrib in attribute_seq:
        if attrib in attrib_for_test:
            seq_a_test.append(seq_a[idx])
            labels_test.append(labels[idx])
            if seq_b is not None:
                seq_b_test.append(seq_b[idx])
        else:
            seq_a_train.append(seq_a[idx])
            labels_train.append(labels[idx])
            if seq_b is not None:
                seq_b_train.append(seq_b[idx])
        idx += 1
    
    return seq_a_train, seq_b_train, labels_train, seq_a_test, seq_b_test, labels_test

In [102]:
print(attrib_seq)

_split_data_by_attribute(None, 
    seq_a, None,
    labels,
    attrib_seq
)

[3, 1, 0, 4, 2, 0, 3, 4, 3, 2, 4, 4, 3, 0, 2, 0, 1, 1, 1, 4]
attrib_for_test: [2, 4]


([0, 1, 2, 5, 6, 8, 12, 13, 15, 16, 17, 18],
 None,
 [0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1],
 [3, 4, 7, 9, 10, 11, 14, 19],
 None,
 [0, 1, 1, 1, 1, 0, 0, 0])

In [None]:
random.sample(set([]), 2)

In [77]:
x = []
if False:
    x = None

print(x)

[]


## TEST-BERT

In [None]:
# Think long and hard about whether you actually need 
# BertForSequenceClassification as a super class...
# and also other design issues. 

class BertForSeqFinetune(BertForSequenceClassification):
    def __init__(self, model_name, config, num_labels):
        super(BertForSeqFinetune, self).__init__(config)
        
        # self.args_loaded = False

        self.model_name = model_name 
        self.config = config # initialised outside of class
        self.model = None
        self.tokenizer = None
        self.hf_model_class = BertForSequenceClassification
        self.hf_token_class = BertTokenizer

        self.tokenised_ids = None
        self.labels = None
        self.attention_masks = None
        self.training_data_loader = None
        self.test_data_loader = None
        self.validation_accuracy = None
        
        self.NUM_LABELS = num_labels
        self.MAX_TOKEN_LEN = 128
        self.LR = 2e-5
        # self.TEST_SIZE = 0.2
        # self.EPOCHS = 3
        # self.BATCH_SIZE = 8
        self.SAVE_STEPS = 10
        self.WARMUP_STEPS = 100
        self.TOTAL_STEPS = 1000
        self.LOGGING_STEPS = 50
        self.MAX_GRAD_NORM = 1.0
        self.LOSS_OVER_TIME = []
        self.RANDOM_STATE = 2018
        
        self.cache_dir = None
        self.output_dir = "./save_files"
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)   

        self._specify_model(self.model_name, self.config, self.NUM_LABELS)     

    def _specify_model(self, model_name, config, num_labels):
        self.model = self.hf_model_class.from_pretrained(
            model_name, 
            config=self.config
        )
        self.tokenizer = self.hf_token_class.from_pretrained(
            model_name,
            do_lower_case=True
        )

    #######################################################

    #########    AREA OF UNDECIDED CODE DESIGN    #########

    #######################################################

    def _build_tokenised_dataset_two_seq(self, text_a, text_b):
        pass

    def _build_tokenised_dataset_one_seq(self, texts):
        tt_raw = [self.tokenizer.tokenize(s) for s in texts]

        tokenised_text = []
        for tokens in tt_raw:
            tokens = ['[CLS]'] + tokens
            if len(tokens) >= self.MAX_TOKEN_LEN:
                tokens[self.MAX_TOKEN_LEN - 1] = '[SEP]'
            else:
                tokens = tokens + ['[SEP]']
            tokenised_text.append(tokens)

        return tokenised_text

    def _convert_to_ids(self):
        tokenised_ids = [self.tokenizer.convert_tokens_to_ids(sent) for sent in self.tokenised_text]

        return pad_sequences(
                  tokenised_ids,
                  maxlen=self.MAX_TOKEN_LEN, dtype="long",
                  truncating="post", padding="post"
              )

    def _add_attention_masks(self):
        attention_masks = []
        # Create a mask of 1s for each token followed by 0s for padding
        for seq in self.tokenised_ids:
            seq_mask = [float(i > 0) for i in seq]
            attention_masks.append(seq_mask)
        
        return attention_masks

    #######################################################

    def create_train_and_test_dataset(self, 
                                      seq_a, seq_b, 
                                      labels, test_size, 
                                      batch_size, seg_ids=None
                                      ):
        if seq_b is None:
            _build_tokenised_dataset_one_seq(seq_a)
        else:
            _build_tokenised_dataset_two_seq(seq_a, seq_b)
        
        # Continue as usual

    def create_train_and_test_dataset(self, 
                                      raw_text, labels,
                                      test_size, batch_size,
                                      seg_ids=None
                                      ):
      
        # Convert to tokenised_text and save as self.tokenised_ids
        self.tokenised_text = self._build_tokenised_dataset_one_seq(raw_text)
        self.tokenised_ids = self._convert_to_ids()
        self.attention_masks = self._add_attention_masks()
        self.labels = labels
        

        X_train, X_test, y_train, y_test = train_test_split(
            self.tokenised_ids, self.labels,
            random_state=self.RANDOM_STATE,
            test_size=test_size
        )
        X_mask, X_mask_test, _, _ = train_test_split(
            self.attention_masks, self.tokenised_ids,
            random_state=self.RANDOM_STATE,
            test_size=test_size
        )

        self.training_data_loader = DataLoader(
            TensorDataset(
                torch.tensor(X_train),
                torch.tensor(X_mask),
                torch.tensor(y_train)
            ),
            shuffle=True,
            batch_size=batch_size
        )

        self.test_data_loader = DataLoader(
            TensorDataset(
                torch.tensor(X_test),
                torch.tensor(X_mask_test),
                torch.tensor(y_test)
            ),
            shuffle=False,
            batch_size=batch_size
        )

    def train(self, epochs, batch_size):
        if self.model is None:
            raise ValueError("Model has not been specified!")
        
        self.optimizer = AdamW(self.model.parameters(), lr=2e-5, correct_bias=False)
        self.scheduler = get_linear_schedule_with_warmup(self.optimizer,
            num_warmup_steps=self.WARMUP_STEPS,
            num_training_steps=self.TOTAL_STEPS
        )

        global_steps = 0
        tr_loss, tr_loss_prev = 0.0, 0.0
        nb_tr_examples = 0
        self.model.zero_grad()
        
        for _ in trange(epochs, desc="EPOCHS"):
            epoch_iterator = tqdm(self.training_data_loader, desc="Iteration")
            for step, batch in enumerate(epoch_iterator):
                inputs = {
                    'input_ids':      batch[0],
                    'attention_mask': batch[1],
                    'labels':         batch[2]
                }
            # if args.model_type != 'distilbert':
            #     inputs['token_type_ids'] = batch[2] if args.model_type in ['bert', 'xlnet'] else None
            
                self.model.zero_grad()
                
                outputs = self.model(**inputs)
                loss = outputs[0]
                print(f"loss: {loss}")
                loss.backward()

                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), 
                    self.MAX_GRAD_NORM
                )
                self.optimizer.step()
                self.scheduler.step()

                tr_loss += loss.item()
                self.LOSS_OVER_TIME.append(tr_loss)
                nb_tr_examples += inputs["input_ids"].size(0)
                global_steps += 1

                # @TODO: Find suitable way to record this information
                if global_steps % self.LOGGING_STEPS == 0:
                    avg_loss = (tr_loss - tr_loss_prev)/self.LOGGING_STEPS
                    tr_loss_prev = tr_loss
                    print(f"Statistics over the last {self.LOGGING_STEPS} steps:")
                    print(f"\t global_steps: {global_steps}")
                    print(f"\t average loss: {avg_loss}")
                    print(f"\t loss.item(): {loss.item()}")
                    print(f"\t tr_loss: {tr_loss}")
                    print(f"\t nb_tr_examples: {nb_tr_examples}")
            
            output_dir = os.path.join(self.output_dir, 'checkpoint-{}'.format(global_steps))
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)

            self.save_model()
            # Take care of distributed/parallel training
            # model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
            # model_to_save.save_pretrained(output_dir)
            
            # @TODO: Do we want to implement a way to save the arguments?
            # torch.save(args, os.path.join(output_dir, 'training_args.bin'))
            
        return global_steps, tr_loss/global_steps
            
    def evaluate(self, compute_metric):
        test_loss = 0.0
        nb_eval_steps = 0
        
        self.model.eval()
        
        for batch in tqdm(self.test_data_loader, desc="EVALUATING"):
            with torch.no_grad():
                inputs = {
                    'input_ids':      batch[0],
                    'attention_mask': batch[1],
                    'labels':         batch[2]
                }
            # if args.model_type != 'distilbert':
            #     inputs['token_type_ids'] = batch[2] if args.model_type in ['bert', 'xlnet'] else None  # XLM, DistilBERT and RoBERTa don't use segment_ids

            outputs = self.model(**inputs)
            tmp_test_loss, logits = outputs[:2]
            test_loss += tmp_test_loss.mean().item()
            # nb_eval_steps += 1

            batch_test_loss = compute_metric(logits.detach().numpy(), inputs["labels"].numpy())
            test_loss += batch_test_loss

        num_test_points = len(self.test_data_loader.dataset)  
        
        # previous compute_metric function accuracy percentage for each batch 
        # self.validation_accuracy = test_loss/nb_eval_steps
        print(f"test_loss: {test_loss}")
        print(f"num_test_points: {num_test_points}")
        self.validation_accuracy = test_loss/num_test_points
        print("Validation Accuracy: {}".format(self.validation_accuracy))
        
    def save_model(self):
        model_to_save = self.model.module if hasattr(self.model, 'module') else self.model  # Take care of distributed/parallel training
        model_to_save.save_pretrained(self.output_dir)
        self.tokenizer.save_pretrained(self.output_dir)

        # @TODO: Implement dict of args
#         torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))

        # Load a trained model and vocabulary that you have fine-tuned
        self.model = self.hf_model_class.from_pretrained(self.output_dir)
        self.tokenizer = self.hf_token_class.from_pretrained(self.output_dir)

print("function refreshed")


In [None]:
    def _tokenize_single_seq(self, pair_a):
        pair_a, _ = _truncate_seq_pair(pair_a)
        
        pair_a = ["[CLS]"] + pair_a + ["[SEP]"]
        seg_ids_a = [0] * len(pair_a)
        input_mask = [1] * len(pair_a)
    
        while len(pair_a) < self.MAX_TOKEN_LEN:
            pair_a.append(0)
            seg_ids_a.append(0)
            input_mask.append(0)
        
        return pair_a, seg_ids_a, input_mask
    
    def _tokenize_seq_pair(self, pair_a, pair_b):
        
        pair_a, pair_b = self._truncate_seq_pair(pair_a, pair_b)
    
        pair_a = ["[CLS]"] + pair_a + ["[SEP]"]
        seg_ids_a = [0] * len(pair_a)
        
        pair_b = pair_b + ["[SEP]"]
        seg_ids_b = [1] * len(pair_b)

        pair_ab = pair_a + pair_b
        seg_ids = seg_ids_a + seg_ids_b
        input_mask = [1] * (len(pair_a) + len(pair_b))
        
        # Pad the rest
        while len(pair_ab) < self.MAX_TOKEN_LEN:
            print(len(pair_ab))
            pair_ab.append(0)
            seg_ids.append(0)
            input_mask.append(0)
        
        return pair_ab, seg_ids, input_mask

In [None]:
    def _build_tokenised_dataset_one_seq(self, seq_a):
        seq_a_tokenised = [self.tokenizer.tokenize(s) for s in seq_a]
        
        self.seq_ab_arr = []
        self.seg_ids_arr = []
        self.input_mask_arr = []
        
        for pair_a in seq_a:
            pair_a, seg_ids, input_mask = _tokenize_single_seq(pair_a)
            
            self.seq_ab_arr.append(pair_ab)
            self.seg_ids_arr.append(seg_ids)
            self.input_mask_arr.append(input_mask)
    
    def _build_tokenised_dataset_two_seq(self, seq_a, seq_b):
        seq_a_tokenised = [self.tokenizer.tokenize(s) for s in seq_a]
        seq_b_tokenised = [self.tokenizer.tokenize(s) for s in seq_b]

        self.seq_ab_arr = []
        self.seg_ids_arr = []
        self.input_mask_arr = []
        
        for pair_a, pair_b in zip(seq_a, seq_b):
            pair_ab, seg_ids, input_mask = _tokenize_seq_pair(pair_a, pair_b)
            
            self.seq_ab_arr.append(pair_ab)
            self.seg_ids_arr.append(seg_ids)
            self.input_mask_arr.append(input_mask)