In [1]:
"""
Run BERT models.
"""
from nlp_pipeline.textpreprocessing.text_to_tensor import *
from nlp_pipeline.splits.split import *
from nlp_pipeline.metrics.ranks import *
from nlp_pipeline.textpreprocessing.summary import *
from utils import *
from nlp_pipeline.models.bert import *
# basic
import logging
import random
import string
import argparse


Using TensorFlow backend.


In [2]:
console_handler = logging.StreamHandler()

dl_logger = logging.getLogger(__name__)
dl_logger.addHandler(console_handler)
dl_logger.setLevel(logging.INFO)

#####################################
### Global Parameters That Need To Be Set
#####################################

SEED = random.randint(9, 999)
dl_logger.info("The seed chosen is " + str(SEED))

seed_all(SEED)

dl_logger.info("Creating pretrained model")

# PRETRAINED_MODEL_NAME = '../notebooks/pretrained_models/'
PRETRAINED_MODEL_NAME = "bert"

DATASET_DIR = "googlequestchallenge/torch_datasets"
create_dir(Path(DATASET_DIR))

# Directory for train, test and sample
input_dir = Path("googlequestchallenge/")

# Read in the datasets
train = pd.read_csv(input_dir / "train.csv")
test = pd.read_csv(input_dir / "test.csv")
sample_submissions = pd.read_csv(input_dir / "sample_submission.csv")

ALL_LABELS = list(sample_submissions.columns)[1:]
ALL_COLS = ["question_title", "question_body", "answer"]

QUESTION_LABELS = [x for x in ALL_LABELS if x.startswith("question")]
QUESTION_COLS = ["question_title", "question_body"]

ANSWER_COLS = ["answer"]
ANSWER_LABELS = [x for x in ALL_LABELS if x.startswith("answer")]

The seed chosen is 324
Creating pretrained model


In [3]:
#####################################
### Defining the tokenizer class and config class
#####################################
pretrain_model = Text2Tensor()
pretrain_model.choose_model("bert")

In [4]:
#####################################
### Feature Engineering the words
#####################################
dl_logger.info("Adding features that need to be be added to the pooler.")

QUESTION_ALL = QUESTION_COLS
ANSWER_ALL = ANSWER_COLS

# CAT_FEATURES = [col for col in train.columns if col.startswith("subcat") or col.startswith("category")]

Adding features that need to be be added to the pooler.


In [5]:
#####################################
### Change text data into iterable 
### In this case, we change this into pandas series
#####################################
train_questions = train["question_title"] + ["SEP"] + train["question_body"]
train_answers = train["answer"]

test_questions = test["question_title"] + ["SEP"] + test["question_body"]
test_answers = test["answer"]


In [7]:
dl_logger.info("Tokenising Question data...")

encode_method = None
head_len = None
train_q_input_list, train_q_att_list = pretrain_model.convert_text_to_tensor(
    train_questions, head_len=head_len,
    encode_method=encode_method)

dl_logger.info("Tokenising answer data...")
train_ans_input_list, train_ans_att_list = pretrain_model.convert_text_to_tensor(
    train_answers, encode_method=encode_method, head_len=head_len)

test_q_input_list, test_q_att_list = pretrain_model.convert_text_to_tensor(
    test_questions, encode_method=encode_method, head_len=head_len)

test_ans_input_list, test_ans_att_list = pretrain_model.convert_text_to_tensor(
    test_answers, encode_method=encode_method, head_len=head_len)


Tokenising Question data...
  0%|          | 12/6079 [00:00<00:54, 110.98it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (2501 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (517 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2425 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1061 > 512). Running this sequence through the model will result in indexing errors
  1%|          | 39/6079 [00:00<00:44, 134.46it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (518 > 512). Running this sequence through 

Token indices sequence length is longer than the specified maximum sequence length for this model (699 > 512). Running this sequence through the model will result in indexing errors
  7%|▋         | 438/6079 [00:01<00:16, 339.58it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (3338 > 512). Running this sequence through the model will result in indexing errors
  8%|▊         | 474/6079 [00:01<00:16, 341.73it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1627 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (713 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (519 > 512). Running this sequence through the model will result in in

Token indices sequence length is longer than the specified maximum sequence length for this model (630 > 512). Running this sequence through the model will result in indexing errors
 15%|█▍        | 906/6079 [00:02<00:12, 410.13it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (573 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (744 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (579 > 512). Running this sequence through the model will result in indexing errors
 16%|█▌        | 949/6079 [00:02<00:12, 415.32it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (538 > 512). Running this sequence through the model will result in inde

Token indices sequence length is longer than the specified maximum sequence length for this model (753 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (520 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1122 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (3240 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1097 > 512). Running this sequence through the model will result in indexing errors
 24%|██▍       | 1461/6079 [00:03<00:12, 372.98it/s]Token indices sequence length is lo

 31%|███▏      | 1911/6079 [00:04<00:10, 413.44it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1024 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (567 > 512). Running this sequence through the model will result in indexing errors
 32%|███▏      | 1954/6079 [00:04<00:09, 415.27it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (2449 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (701 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2684 > 512). Running this sequence through the model will result in

Token indices sequence length is longer than the specified maximum sequence length for this model (839 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1740 > 512). Running this sequence through the model will result in indexing errors
 38%|███▊      | 2317/6079 [00:05<00:09, 377.73it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (614 > 512). Running this sequence through the model will result in indexing errors
 39%|███▉      | 2370/6079 [00:06<00:08, 413.26it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (790 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1523 > 512). Running this sequence through the model will result in 

Token indices sequence length is longer than the specified maximum sequence length for this model (1654 > 512). Running this sequence through the model will result in indexing errors
 49%|████▉     | 3007/6079 [00:07<00:07, 437.54it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1375 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1239 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (982 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1013 > 512). Running this sequence through the model will result in indexing errors
 50%|█████     | 3052/6079 [00:07<

Token indices sequence length is longer than the specified maximum sequence length for this model (1833 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (3082 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (519 > 512). Running this sequence through the model will result in indexing errors
 58%|█████▊    | 3556/6079 [00:08<00:07, 335.13it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (871 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (526 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is lon

Token indices sequence length is longer than the specified maximum sequence length for this model (989 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2332 > 512). Running this sequence through the model will result in indexing errors
 68%|██████▊   | 4114/6079 [00:10<00:04, 431.84it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1483 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (519 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (3164 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is lo

Token indices sequence length is longer than the specified maximum sequence length for this model (593 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (656 > 512). Running this sequence through the model will result in indexing errors
 75%|███████▍  | 4550/6079 [00:11<00:03, 412.46it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (575 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (598 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (841 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longe

Token indices sequence length is longer than the specified maximum sequence length for this model (528 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (691 > 512). Running this sequence through the model will result in indexing errors
 81%|████████▏ | 4948/6079 [00:12<00:02, 419.88it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (566 > 512). Running this sequence through the model will result in indexing errors
 82%|████████▏ | 4997/6079 [00:12<00:02, 436.84it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (4230 > 512). Running this sequence through the model will result in indexing errors
 83%|████████▎ | 5042/6079 [00:12<00:02, 431.48it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1726 > 512). Run

Token indices sequence length is longer than the specified maximum sequence length for this model (539 > 512). Running this sequence through the model will result in indexing errors
 89%|████████▉ | 5438/6079 [00:13<00:01, 389.56it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (660 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1205 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (759 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1205 > 512). Running this sequence through the model will result in indexing errors
 90%|█████████ | 5478/6079 [00:13<00

Token indices sequence length is longer than the specified maximum sequence length for this model (693 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (903 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (526 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (575 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1107 > 512). Running this sequence through the model will result in indexing errors
 98%|█████████▊| 5930/6079 [00:14<00:00, 368.70it/s]Token indices sequence length is long

Token indices sequence length is longer than the specified maximum sequence length for this model (1125 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1658 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (805 > 512). Running this sequence through the model will result in indexing errors
  6%|▌         | 335/6079 [00:00<00:13, 416.28it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1460 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1242 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is lo

Token indices sequence length is longer than the specified maximum sequence length for this model (521 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (767 > 512). Running this sequence through the model will result in indexing errors
 16%|█▌        | 951/6079 [00:02<00:11, 443.97it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (569 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1127 > 512). Running this sequence through the model will result in indexing errors
 17%|█▋        | 1010/6079 [00:02<00:10, 478.46it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (2926 > 512). Running this sequence through the model will result in i

 25%|██▍       | 1497/6079 [00:03<00:10, 448.66it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (707 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (512 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (3016 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (847 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1028 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is lon

Token indices sequence length is longer than the specified maximum sequence length for this model (793 > 512). Running this sequence through the model will result in indexing errors
 33%|███▎      | 2033/6079 [00:04<00:08, 457.94it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1109 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (954 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1050 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (666 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is lon

Token indices sequence length is longer than the specified maximum sequence length for this model (1151 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (995 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (534 > 512). Running this sequence through the model will result in indexing errors
 42%|████▏     | 2535/6079 [00:05<00:08, 436.55it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (513 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (856 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is long

Token indices sequence length is longer than the specified maximum sequence length for this model (593 > 512). Running this sequence through the model will result in indexing errors
 51%|█████     | 3085/6079 [00:07<00:07, 415.19it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (726 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (547 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (799 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (3200 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is long

Token indices sequence length is longer than the specified maximum sequence length for this model (915 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (780 > 512). Running this sequence through the model will result in indexing errors
 59%|█████▉    | 3590/6079 [00:08<00:05, 417.66it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1742 > 512). Running this sequence through the model will result in indexing errors
 60%|█████▉    | 3635/6079 [00:08<00:05, 425.57it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (606 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1072 > 512). Running this sequence through the model will result in 

 68%|██████▊   | 4113/6079 [00:09<00:04, 461.84it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (733 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (799 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (535 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1374 > 512). Running this sequence through the model will result in indexing errors
 68%|██████▊   | 4162/6079 [00:09<00:04, 454.20it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (554 > 512). Running this sequence through the model will result in i

 75%|███████▍  | 4534/6079 [00:10<00:04, 331.35it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (597 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (647 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (682 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1179 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (732 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is long

 85%|████████▍ | 5158/6079 [00:12<00:01, 460.88it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (873 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (647 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (581 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (532 > 512). Running this sequence through the model will result in indexing errors
 86%|████████▌ | 5205/6079 [00:12<00:01, 448.62it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1027 > 512). Running this sequence through the model will result in i

Token indices sequence length is longer than the specified maximum sequence length for this model (2130 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (621 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (753 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1029 > 512). Running this sequence through the model will result in indexing errors
 95%|█████████▌| 5778/6079 [00:13<00:00, 400.51it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (529 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is lon

Token indices sequence length is longer than the specified maximum sequence length for this model (530 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (594 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (915 > 512). Running this sequence through the model will result in indexing errors
 34%|███▍      | 161/476 [00:00<00:00, 346.76it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1288 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (864 > 512). Running this sequence through the model will result in indexing errors
 44%|████▍     | 210/476 [00:00<00:00, 

 28%|██▊       | 133/476 [00:00<00:00, 400.12it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (961 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1470 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (602 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (715 > 512). Running this sequence through the model will result in indexing errors
 36%|███▌      | 170/476 [00:00<00:00, 389.07it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1067 > 512). Running this sequence through the model will result in inde

In [8]:
class QuestDataset(Dataset):
    def __init__(
        self,
        input_ids: List[torch.Tensor]=None,
        attention_mask: List[torch.Tensor]=None,
        label_tensors: torch.FloatTensor = None,
        **kwargs
    ):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.label_tensors = label_tensors
        # Create an obvious x
        self.x = list(zip(self.input_ids, self.attention_mask))
        self.y = label_tensors
        
    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return self.x[idx], self.y[idx]


In [9]:
train_q_label_tensors = torch.FloatTensor(train[QUESTION_LABELS].values)
test_q_label_tensors = torch.zeros(test.size, 30)

train_ans_label_tensors = torch.FloatTensor(train[ANSWER_LABELS].values)
test_ans_label_tensors = torch.zeros(test.size, 30)

train_q_dataset = QuestDataset(
    input_ids=train_q_input_list,
    attention_mask=train_q_att_list,
    label_tensors=train_q_label_tensors,
)

test_q_dataset = QuestDataset(
    input_ids=test_q_input_list,
    attention_mask=test_q_att_list,
    label_tensors=test_q_label_tensors,
)

train_ans_dataset = QuestDataset(
    input_ids=train_ans_input_list,
    attention_mask=train_ans_att_list,
    label_tensors=train_ans_label_tensors,
)

test_ans_dataset = QuestDataset(
    input_ids=test_ans_input_list,
    attention_mask=test_ans_att_list,
    label_tensors=test_ans_label_tensors,
)


In [10]:
torch_save_dict = {
    train_q_dataset: "train_q",
    test_q_dataset: "test_q",
    train_ans_dataset: "train_ans",
    test_ans_dataset: "test_ans"
}

for dataset, dataset_name in torch_save_dict.items():
    # dataset dir, description with underscores and fold number
    output_path = Path(DATASET_DIR, dataset_name)
    # We do not save the test fold because there is no fold
    torch.save(dataset, output_path)
    print("Saving to: "+str(output_path))

Saving to: googlequestchallenge/torch_datasets/train_q
Saving to: googlequestchallenge/torch_datasets/test_q
Saving to: googlequestchallenge/torch_datasets/train_ans
Saving to: googlequestchallenge/torch_datasets/test_ans


In [11]:
# Load in the custom tokenizer still
pretrain_model.config.num_labels = len(QUESTION_LABELS)
pretrain_model.config.output_hidden_states = True

## Training 

In [12]:
class CustomTransformerModel(nn.Module):
    """Model for our custom transformer"""
    def __init__(self, transformer_model: PreTrainedModel):
        """Simplest Transformer Model"""
        super(CustomTransformerModel, self).__init__()
        self.transformer = transformer_model

    def forward(self, input_ids, attention_mask, engineered_features=None):
        results = self.transformer(
            input_ids, attention_mask=attention_mask, engineered_features=None,
        )
        logits = results[0]
        return logits

In [14]:
# q_custom_transformer_model = CustomTransformerModel(q_transformer_model)
reduce_lr_callback = partial(
    ReduceLROnPlateauCallback,
    monitor="spearman_rho",
    mode="max",
    patience=3,
    min_delta=0.001,
    min_lr=1e-6,
)

In [15]:
def flattenAnneal(learn: Learner, lr: float, n_epochs: int, start_pct: float):
    """Learning method"""
    n = len(learn.data.train_dl)
    anneal_start = int(n * n_epochs * start_pct)
    anneal_end = int(n * n_epochs) - anneal_start
    # 2.6 was chosen from ULMFIT paper as the ideal number
    lr_array = np.array(
        [
            lr / (2.6 ** 8),
            lr / (2.6 ** 7),
            lr / (2.6 ** 6),
            lr / (2.6 ** 5),
            lr / (2.6 ** 4),
            lr / (2.6 ** 3),
            lr / (2.6 ** 2),
            lr / (2.6 ** 1),
            lr,
        ]
    )
    phases = [
        TrainingPhase(anneal_start).schedule_hp("lr", lr_array),
        TrainingPhase(anneal_end).schedule_hp("lr", lr_array, anneal=annealing_cos),
    ]
    sched = GeneralScheduler(learn, phases)
    learn.callbacks.append(sched)
    learn.fit(n_epochs)

In [None]:
OUTPUT_MODEL_DIR = "."
# batch size
bs = 8

# Create GroupKFold index - 3 was chosen because of time limit 
from sklearn.model_selection import GroupKFold
gkf = GroupKFold(n_splits=3)
gkf_splits = gkf.split(train_questions, groups=train['question_body'])

for i, (train_split, valid_split) in enumerate(gkf_splits):
    print("Fold :" + str(i))
    all_text_data = [
        train_questions, train_answers,  
        test_questions, test_answers
    ]
    #####################################
    ### Converting Text to tensor
    #####################################
    train_sampler = iter(train_split)
    valid_sampler = iter(valid_split)
    ()
    dl_kwargs = {
        "batch_size": bs,
        "shuffle": False,
        "batch_sampler": None,
        "num_workers": 0,
        # "pin_memory": True, - this has been defined in the collate wrapper implementation
    }
    train_sampler = RandomSampler(train_split,replacement=False, num_samples=None)
    valid_sampler = RandomSampler(valid_split,replacement=False, num_samples=None)
    
    train_q_dl = DataLoader(train_q_dataset, sampler=train_sampler, **dl_kwargs)
    valid_q_dl = DataLoader(train_q_dataset, sampler=valid_sampler, **dl_kwargs)
    
    
    q_databunch = TextDataBunch(
        train_dl=train_q_dl, valid_dl=valid_q_dl, device="cuda:0",
    )
    
    opt_func = partial(
        Ranger, betas=(0.9, 0.99), eps=0.05
    )
    
    # Defining the learnin
    model_class = BertSequenceClassification
    q_transformer_model = model_class.from_pretrained(
        'bert-base-uncased',
        config=pretrain_model.config,
        dropout_rate=0.05,
        hidden_layer_output=2,
    )
    q_custom_transformer_model = CustomTransformerModel(q_transformer_model)
    q_learner = BertLearner(
        q_databunch,
        q_custom_transformer_model,
        opt_func=opt_func,
        bn_wd=False,
        true_wd=True,
        metrics=[SpearmanRho()],
        callback_fns=[reduce_lr_callback],
    ).to_fp16()
    q_learner.freeze()
    # Find a learning rate
    q_learner.lr_find()
    lr = q_learner.recorder.lrs[np.array(q_learner.recorder.losses).argmin()] / 10
    freeze_to_counter = 1
    # 6 has been chosen from experience
    while freeze_to_counter < 6:
        freeze_to = freeze_to_counter
        print("Freezing up to "+str(freeze_to))
        q_learner.freeze_to(-freeze_to)
        flattenAnneal(q_learner, lr, 5, 0.55)
        
        # Save model 
        model_save_name = "bert_q_" + str(freeze_to) + "_fold_" + str(i)
        q_learner.save(model_save_name, with_opt=False)
        print("Saved to " + str(model_save_name))
        
        # Reset model
        del q_learner
        gc.collect()
        torch.cuda.empty_cache()
        q_databunch = TextDataBunch(
            train_dl=train_q_dl,
            valid_dl=valid_q_dl,
            device="cuda:0",
        )
        q_learner = BertLearner(
            q_databunch,
            q_custom_transformer_model,
            opt_func=opt_func,
            bn_wd=False,
            true_wd=True,
            metrics=[SpearmanRho()],
            callback_fns=[reduce_lr_callback],
        ).to_fp16()
        
        q_learner.load(model_save_name);
        freeze_to_counter += 1

Fold :0


epoch,train_loss,valid_loss,spearman_rho,time


LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
Freezing up to 1


epoch,train_loss,valid_loss,spearman_rho,time
0,0.447322,0.452833,0.303379,02:54
1,0.437046,0.409267,0.311145,03:09
2,0.44856,0.444406,0.313755,03:09
3,0.42016,0.403408,0.331118,03:01
4,0.401592,0.397222,0.34254,03:07


Saved to bert_q_1_fold_0
Freezing up to 2


epoch,train_loss,valid_loss,spearman_rho,time
0,0.444831,0.446868,0.259386,03:13
1,0.437794,0.412911,0.29641,03:14
2,0.426324,0.425208,0.289432,03:14
3,0.403179,0.40092,0.325346,03:07
4,0.397226,0.391346,0.355181,03:05


Saved to bert_q_2_fold_0
Freezing up to 3


epoch,train_loss,valid_loss,spearman_rho,time
0,0.386149,0.392072,0.368999,03:32
1,0.387444,0.374491,0.39989,03:21
2,0.382796,0.381866,0.404757,03:29
3,0.368509,0.36119,0.436965,03:25
4,0.359168,0.356058,0.450239,03:18


Saved to bert_q_3_fold_0
Freezing up to 4


epoch,train_loss,valid_loss,spearman_rho,time
0,0.384853,0.376841,0.407955,04:13
1,0.37574,0.365552,0.421693,04:13
2,0.374227,0.368073,0.429258,04:13
3,0.362473,0.355173,0.45627,04:15
4,0.357086,0.347906,0.471147,04:13


Saved to bert_q_4_fold_0
Freezing up to 5


epoch,train_loss,valid_loss,spearman_rho,time
0,0.375562,0.365659,0.433783,04:57
1,0.374159,0.357558,0.442303,04:57
2,0.368358,0.362115,0.444033,04:57
3,0.357757,0.349735,0.467405,05:02
4,0.346552,0.342165,0.484977,04:55


Saved to bert_q_5_fold_0
Fold :1


epoch,train_loss,valid_loss,spearman_rho,time


LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
Freezing up to 1


epoch,train_loss,valid_loss,spearman_rho,time
0,0.428966,0.443475,0.30518,03:09
1,0.427516,0.405457,0.318886,03:00
2,0.433183,0.430827,0.321009,03:00
3,0.412629,0.402739,0.334962,02:56
4,0.399331,0.397583,0.342553,02:51


Saved to bert_q_1_fold_1
Freezing up to 2


epoch,train_loss,valid_loss,spearman_rho,time
0,0.429293,0.449885,0.288779,02:51
1,0.426606,0.411429,0.295324,02:50
2,0.426774,0.453097,0.301671,02:50
3,0.401036,0.396207,0.341683,02:50
4,0.392526,0.390588,0.358609,02:50


Saved to bert_q_2_fold_1
Freezing up to 3


epoch,train_loss,valid_loss,spearman_rho,time
