In [102]:
"""
Run BERT models.
"""
from nlp_pipeline.textpreprocessing.text_to_tensor import *
from nlp_pipeline.splits.split import *
from nlp_pipeline.metrics.ranks import *
from nlp_pipeline.textpreprocessing.summary import *
from utils import *
from nlp_pipeline.models.bert import *
# basic
import logging
import random
import string
import argparse


In [103]:
console_handler = logging.StreamHandler()

dl_logger = logging.getLogger(__name__)
dl_logger.addHandler(console_handler)
dl_logger.setLevel(logging.INFO)

#####################################
### Global Parameters That Need To Be Set
#####################################

SEED = random.randint(9, 999)
dl_logger.info("The seed chosen is " + str(SEED))

seed_all(SEED)

dl_logger.info("Creating pretrained model")

# PRETRAINED_MODEL_NAME = '../notebooks/pretrained_models/'
PRETRAINED_MODEL_NAME = "bert"

DATASET_DIR = "googlequestchallenge/torch_datasets"
create_dir(Path(DATASET_DIR))

# Directory for train, test and sample
input_dir = Path("googlequestchallenge/")

# Read in the datasets
train = pd.read_csv(input_dir / "train.csv")
test = pd.read_csv(input_dir / "test.csv")
sample_submissions = pd.read_csv(input_dir / "sample_submission.csv")

ALL_LABELS = list(sample_submissions.columns)[1:]
ALL_COLS = ["question_title", "question_body", "answer"]

QUESTION_LABELS = [x for x in ALL_LABELS if x.startswith("question")]
QUESTION_COLS = ["question_title", "question_body"]

ANSWER_COLS = ["answer"]
ANSWER_LABELS = [x for x in ALL_LABELS if x.startswith("answer")]

The seed chosen is 830
The seed chosen is 830
The seed chosen is 830
The seed chosen is 830
Creating pretrained model
Creating pretrained model
Creating pretrained model
Creating pretrained model


In [3]:
#####################################
### Defining the tokenizer class and config class
#####################################
pretrain_model = Text2Tensor()
pretrain_model.choose_model("bert")

In [5]:
#####################################
### Feature Engineering the words
#####################################
dl_logger.info("Adding features that need to be be added to the pooler.")

QUESTION_ALL = QUESTION_COLS
ANSWER_ALL = ANSWER_COLS

# CAT_FEATURES = [col for col in train.columns if col.startswith("subcat") or col.startswith("category")]

Adding features that need to be be added to the pooler.


In [6]:
#####################################
### Change text data into iterable 
### In this case, we change this into pandas series
#####################################
train_questions = train["question_title"] + ["SEP"] + train["question_body"]
train_answers = train["answer"]

test_questions = test["question_title"] + ["SEP"] + test["question_body"]
test_answers = test["answer"]


In [9]:
dl_logger.info("Tokenising Question data...")

encode_method = None
head_len = None
train_q_input_list, train_q_att_list = pretrain_model.convert_text_to_tensor(
    train_questions, head_len=head_len 
    encode_method=encode_method)

dl_logger.info("Tokenising answer data...")
train_ans_input_list, train_ans_att_list = pretrain_model.convert_text_to_tensor(
    train_answers, encode_method=encode_method, head_len=head_len)

test_q_input_list, test_q_att_list = pretrain_model.convert_text_to_tensor(
    test_questions, encode_method=encode_method, head_len=head_len)

test_ans_input_list, test_ans_att_list = pretrain_model.convert_text_to_tensor(
    test_answers, encode_method=encode_method, head_len=head_len)


Tokenising Question data...
  0%|          | 1/6079 [00:00<32:27,  3.12it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (2501 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (517 > 512). Running this sequence through the model will result in indexing errors
  0%|          | 24/6079 [00:00<22:45,  4.43it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (2425 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1061 > 512). Running this sequence through the model will result in indexing errors
  1%|          | 47/6079 [00:00<16:00,  6.28it/s]Token indices sequence length is longer than the specified maximum sequence length for this mod

  7%|▋         | 416/6079 [00:01<00:50, 112.37it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (699 > 512). Running this sequence through the model will result in indexing errors
  8%|▊         | 464/6079 [00:01<00:38, 145.71it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (3338 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1627 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (713 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (519 > 512). Running this sequence through the model will result in in

Token indices sequence length is longer than the specified maximum sequence length for this model (630 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (573 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (744 > 512). Running this sequence through the model will result in indexing errors
 15%|█▌        | 937/6079 [00:02<00:12, 403.30it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (579 > 512). Running this sequence through the model will result in indexing errors
 16%|█▌        | 987/6079 [00:02<00:11, 427.78it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (538 > 512). Running this sequence through the model will result in inde

Token indices sequence length is longer than the specified maximum sequence length for this model (753 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (520 > 512). Running this sequence through the model will result in indexing errors
 24%|██▍       | 1446/6079 [00:03<00:11, 409.96it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1122 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (3240 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1097 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is lo

Token indices sequence length is longer than the specified maximum sequence length for this model (1106 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1024 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (567 > 512). Running this sequence through the model will result in indexing errors
 32%|███▏      | 1928/6079 [00:05<00:10, 408.54it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (2449 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (701 > 512). Running this sequence through the model will result in indexing errors
 32%|███▏      | 1973/6079 [00:05<0

Token indices sequence length is longer than the specified maximum sequence length for this model (2838 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (839 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1740 > 512). Running this sequence through the model will result in indexing errors
 38%|███▊      | 2309/6079 [00:06<00:09, 390.85it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (614 > 512). Running this sequence through the model will result in indexing errors
 39%|███▉      | 2358/6079 [00:06<00:08, 416.03it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (790 > 512). Running this sequence through the model will result in 

 49%|████▊     | 2960/6079 [00:07<00:06, 450.23it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (513 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1654 > 512). Running this sequence through the model will result in indexing errors
 49%|████▉     | 3007/6079 [00:07<00:06, 448.39it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1375 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1239 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (982 > 512). Running this sequence through the model will result in

Token indices sequence length is longer than the specified maximum sequence length for this model (6862 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1833 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (3082 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (519 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (871 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for 

Token indices sequence length is longer than the specified maximum sequence length for this model (989 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2332 > 512). Running this sequence through the model will result in indexing errors
 68%|██████▊   | 4110/6079 [00:10<00:04, 428.31it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1483 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (519 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (3164 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is lo

Token indices sequence length is longer than the specified maximum sequence length for this model (593 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (656 > 512). Running this sequence through the model will result in indexing errors
 75%|███████▍  | 4553/6079 [00:11<00:03, 430.36it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (575 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (598 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (841 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longe

Token indices sequence length is longer than the specified maximum sequence length for this model (528 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (691 > 512). Running this sequence through the model will result in indexing errors
 81%|████████  | 4931/6079 [00:12<00:02, 420.63it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (566 > 512). Running this sequence through the model will result in indexing errors
 82%|████████▏ | 4980/6079 [00:12<00:02, 437.86it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (4230 > 512). Running this sequence through the model will result in indexing errors
 83%|████████▎ | 5039/6079 [00:12<00:02, 450.35it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1726 > 512). Run

Token indices sequence length is longer than the specified maximum sequence length for this model (539 > 512). Running this sequence through the model will result in indexing errors
 90%|████████▉ | 5442/6079 [00:13<00:01, 393.75it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (660 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1205 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (759 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1205 > 512). Running this sequence through the model will result in indexing errors
 90%|█████████ | 5483/6079 [00:13<00

Token indices sequence length is longer than the specified maximum sequence length for this model (693 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (903 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (526 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (575 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1107 > 512). Running this sequence through the model will result in indexing errors
 98%|█████████▊| 5932/6079 [00:14<00:00, 380.82it/s]Token indices sequence length is long

Token indices sequence length is longer than the specified maximum sequence length for this model (1125 > 512). Running this sequence through the model will result in indexing errors
  5%|▌         | 316/6079 [00:00<00:12, 455.05it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1658 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (805 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1460 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1242 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is lo

Token indices sequence length is longer than the specified maximum sequence length for this model (521 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (767 > 512). Running this sequence through the model will result in indexing errors
 16%|█▌        | 953/6079 [00:02<00:10, 470.64it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (569 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1127 > 512). Running this sequence through the model will result in indexing errors
 17%|█▋        | 1010/6079 [00:02<00:10, 496.11it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (2926 > 512). Running this sequence through the model will result in i

Token indices sequence length is longer than the specified maximum sequence length for this model (707 > 512). Running this sequence through the model will result in indexing errors
 25%|██▍       | 1512/6079 [00:03<00:10, 435.55it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (512 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (3016 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (847 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1028 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is lon

 33%|███▎      | 2007/6079 [00:04<00:09, 447.65it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1109 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (954 > 512). Running this sequence through the model will result in indexing errors
 34%|███▍      | 2057/6079 [00:04<00:08, 460.73it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1050 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (666 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (731 > 512). Running this sequence through the model will result in 

Token indices sequence length is longer than the specified maximum sequence length for this model (995 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (534 > 512). Running this sequence through the model will result in indexing errors
 42%|████▏     | 2537/6079 [00:05<00:07, 467.69it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (513 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (856 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (541 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longe

Token indices sequence length is longer than the specified maximum sequence length for this model (726 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (547 > 512). Running this sequence through the model will result in indexing errors
 51%|█████     | 3095/6079 [00:06<00:07, 424.32it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (799 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (3200 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (585 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is long

Token indices sequence length is longer than the specified maximum sequence length for this model (1742 > 512). Running this sequence through the model will result in indexing errors
 59%|█████▉    | 3595/6079 [00:08<00:05, 443.02it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (606 > 512). Running this sequence through the model will result in indexing errors
 60%|█████▉    | 3647/6079 [00:08<00:05, 460.27it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1072 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (535 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1055 > 512). Running this sequence through the model will result in

Token indices sequence length is longer than the specified maximum sequence length for this model (535 > 512). Running this sequence through the model will result in indexing errors
 68%|██████▊   | 4157/6079 [00:09<00:03, 495.14it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1374 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (554 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (760 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (769 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is long

 75%|███████▌  | 4561/6079 [00:10<00:04, 368.85it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1179 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (732 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2063 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (707 > 512). Running this sequence through the model will result in indexing errors
 76%|███████▌  | 4602/6079 [00:10<00:03, 379.92it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (691 > 512). Running this sequence through the model will result in 

Token indices sequence length is longer than the specified maximum sequence length for this model (532 > 512). Running this sequence through the model will result in indexing errors
 85%|████████▌ | 5192/6079 [00:11<00:01, 459.74it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1027 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (732 > 512). Running this sequence through the model will result in indexing errors
 86%|████████▋ | 5246/6079 [00:11<00:01, 480.34it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (2630 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (609 > 512). Running this sequence through the model will result in 

Token indices sequence length is longer than the specified maximum sequence length for this model (1029 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (529 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (592 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (531 > 512). Running this sequence through the model will result in indexing errors
 95%|█████████▌| 5797/6079 [00:13<00:00, 400.48it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (720 > 512). Running this sequence through the model will result in indexing errors
 96%|█████████▌| 5845/6079 [00:13<00:

Token indices sequence length is longer than the specified maximum sequence length for this model (1288 > 512). Running this sequence through the model will result in indexing errors
 38%|███▊      | 182/476 [00:00<00:00, 319.15it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (864 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1282 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2037 > 512). Running this sequence through the model will result in indexing errors
 47%|████▋     | 225/476 [00:00<00:00, 345.42it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1237 > 512). Running this sequence through the model will result in in

Token indices sequence length is longer than the specified maximum sequence length for this model (715 > 512). Running this sequence through the model will result in indexing errors
 36%|███▌      | 169/476 [00:00<00:00, 384.68it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1067 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1192 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (641 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1588 > 512). Running this sequence through the model will result in indexing errors
 44%|████▍     | 210/476 [00:00<00:00

In [53]:
class QuestDataset(Dataset):
    def __init__(
        self,
        input_ids: List[torch.Tensor]=None,
        attention_mask: List[torch.Tensor]=None,
        label_tensors: torch.FloatTensor = None,
        **kwargs
    ):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.label_tensors = label_tensors
        # Create an obvious x
        self.x = list(zip(self.input_ids, self.attention_mask))
        self.y = label_tensors
        
    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return self.x[idx], self.y[idx]


In [56]:
train_q_label_tensors = torch.FloatTensor(train[QUESTION_LABELS].values)
test_q_label_tensors = torch.zeros(test.size, 30)

train_ans_label_tensors = torch.FloatTensor(train[ANSWER_LABELS].values)
test_ans_label_tensors = torch.zeros(test.size, 30)

train_q_dataset = QuestDataset(
    input_ids=train_q_input_list,
    attention_mask=train_q_att_list,
    label_tensors=train_q_label_tensors,
)

test_q_dataset = QuestDataset(
    input_ids=test_q_input_list,
    attention_mask=test_q_att_list,
    label_tensors=test_q_label_tensors,
)

train_ans_dataset = QuestDataset(
    input_ids=train_ans_input_list,
    attention_mask=train_ans_att_list,
    label_tensors=train_ans_label_tensors,
)

test_ans_dataset = QuestDataset(
    input_ids=test_ans_input_list,
    attention_mask=test_ans_att_list,
    label_tensors=test_ans_label_tensors,
)


In [63]:
torch_save_dict = {
    train_q_dataset: "train_q",
    test_q_dataset: "test_q",
    train_ans_dataset: "train_ans",
    test_ans_dataset: "test_ans"
}

for dataset, dataset_name in torch_save_dict.items():
    # dataset dir, description with underscores and fold number
    output_path = Path(DATASET_DIR, dataset_name)
    # We do not save the test fold because there is no fold
    torch.save(dataset, output_path)
    print("Saving to: "+str(output_path))

Saving to: googlequestchallenge/torch_datasets/train_q
Saving to: googlequestchallenge/torch_datasets/test_q
Saving to: googlequestchallenge/torch_datasets/train_ans
Saving to: googlequestchallenge/torch_datasets/test_ans


In [64]:
# Load in the custom tokenizer still
pretrain_model.config.num_labels = len(QUESTION_LABELS)
pretrain_model.config.output_hidden_states = True

## Training 

In [98]:
class CustomTransformerModel(nn.Module):
    """Model for our custom transformer"""
    def __init__(self, transformer_model: PreTrainedModel):
        """Simplest Transformer Model"""
        super(CustomTransformerModel, self).__init__()
        self.transformer = transformer_model

    def forward(self, input_ids, attention_mask, engineered_features=None):
        results = self.transformer(
            input_ids, attention_mask=attention_mask, engineered_features=None,
        )
        logits = results[0]
        return logits

In [99]:
q_custom_transformer_model = CustomTransformerModel(q_transformer_model)
reduce_lr_callback = partial(
    ReduceLROnPlateauCallback,
    monitor="spearman_rho",
    mode="max",
    patience=3,
    min_delta=0.001,
    min_lr=1e-6,
)

In [None]:
def flattenAnneal(learn: Learner, lr: float, n_epochs: int, start_pct: float):
    """Learning method"""
    n = len(learn.data.train_dl)
    anneal_start = int(n * n_epochs * start_pct)
    anneal_end = int(n * n_epochs) - anneal_start
    lr_array = np.array(
        [
            lr / (2.6 ** 8),
            lr / (2.6 ** 7),
            lr / (2.6 ** 6),
            lr / (2.6 ** 5),
            lr / (2.6 ** 4),
            lr / (2.6 ** 3),
            lr / (2.6 ** 2),
            lr / (2.6 ** 1),
            lr,
        ]
    )
    phases = [
        TrainingPhase(anneal_start).schedule_hp("lr", lr_array),
        TrainingPhase(anneal_end).schedule_hp("lr", lr_array, anneal=annealing_cos),
    ]
    sched = GeneralScheduler(learn, phases)
    learn.callbacks.append(sched)
    learn.fit(n_epochs)

In [None]:
# batch size
bs = 8

# Create GroupKFold index 
from sklearn.model_selection import GroupKFold
gkf = GroupKFold(n_splits=3)
gkf_splits = gkf.split(train_questions, groups=train['question_body'])

for i, (train_split, valid_split) in enumerate(gkf_splits):
    print("Fold :" + str(i))
    all_text_data = [
        train_questions, train_answers,  
        test_questions, test_answers
    ]
    #####################################
    ### Converting Text to tensor
    #####################################
    train_sampler = iter(train_split)
    valid_sampler = iter(valid_split)
    ()
    dl_kwargs = {
        "batch_size": bs,
        "shuffle": False,
        "batch_sampler": None,
        "num_workers": 0,
        # "pin_memory": True, - this has been defined in the collate wrapper implementation
    }
    train_sampler = RandomSampler(train_split,replacement=False, num_samples=None)
    valid_sampler = RandomSampler(valid_split,replacement=False, num_samples=None)
    
    train_q_dl = DataLoader(train_q_dataset, sampler=train_sampler, **dl_kwargs)
    valid_q_dl = DataLoader(train_q_dataset, sampler=valid_sampler, **dl_kwargs)
    
    
    q_databunch = TextDataBunch(
        train_dl=train_q_dl, valid_dl=valid_q_dl, device="cuda:0",
    )
    
    opt_func = partial(
        Ranger, betas=(0.9, 0.99), eps=0.05
    )
    
    q_learner = BertLearner(
        q_databunch,
        q_custom_transformer_model,
        opt_func=opt_func,
        bn_wd=False,
        true_wd=True,
        metrics=[SpearmanRho()],
        callback_fns=[reduce_lr_callback],
    ).to_fp16()
    
    q_learner.freeze()
    
    # Find a learning rate
    q_learner.lr_find()
    lr = q_learner.recorder.lrs[np.array(q_learner.recorder.losses).argmin()] / 10
    
    # Defining the learnin
    model_class = BertSequenceClassification
    q_transformer_model = model_class.from_pretrained(
        'bert-base-uncased',
        config=pretrain_model.config,
        dropout_rate=0.05,
        hidden_layer_output=2,
    )
    name = "Bert"
    freeze_to_counter = 1
    while freeze_to_counter < 6:
        freeze_to = freeze_to_counter
        print("Freezing up to "+str(freeze_to))
        model_save_name = "bert_" + str(freeze_to) + "_" + str(i)
        
        q_learner.freeze_to(-freeze_to)
        flattenAnneal(q_learner, lr, 5, 0.55)
        
        # Save the models 
        q_learner.save(model_save_name, with_opt=False)
        print("Saved to " + str(model_save_name))
        del q_learner
        gc.collect()
        torch.cuda.empty_cache()
        
        # Reset the index after every gradual unfreeze (strange workaround)
        q_databunch = TextDataBunch(
            train_dl=train_q_dl,
            valid_dl=valid_q_dl,
            test_dl=test_q_dl,
            device="cuda:0",
        )
        q_learner = model_class(
            q_databunch,
            q_custom_transformer_model,
            opt_func=opt_func,
            metrics=[SpearmanRho()],
            callback_fns=[reduce_lr_callback],
        ).to_fp16()
        # Kaggle is better than a legal drug
        q_learner.model_dir = OUTPUT_MODEL_DIR
        q_learner.load(model_save_name);
        freeze_to_counter += 1

Fold :0


epoch,train_loss,valid_loss,spearman_rho,time


LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
Freezing up to 1


epoch,train_loss,valid_loss,spearman_rho,time
0,0.443122,0.467513,0.304283,03:15
1,0.427925,0.402734,0.324784,03:14
2,0.430748,0.428227,0.322503,03:14
