In [1]:
import argparse
import numpy as np
import pandas as pd
import torch
from pathlib import Path
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import BertPreTrainedModel, BertModel, BertConfig, BertTokenizer
from typing import List, Text

In [2]:

class BertForQuestRegression(BertPreTrainedModel):
    """BERT encoder followed by a mean-pooled linear head.

    The head produces ``config.num_labels`` raw logits per example; no
    activation is applied inside the model.

    Args:
        config: BERT configuration; ``num_labels``, ``hidden_size`` and
            ``hidden_dropout_prob`` are read from it.
        head_dropout: dropout probability applied to the pooled vector before
            the linear layer. ``None`` means "use ``config.hidden_dropout_prob``".
    """

    def __init__(self, config, head_dropout=None):
        super(BertForQuestRegression, self).__init__(config)
        self.config = config
        self.num_labels = config.num_labels
        # Remember the caller's choice (possibly None) so that `load`, which
        # re-runs __init__, can rebuild the head with the same setting.
        self._head_dropout = head_dropout
        if head_dropout is None:
            head_dropout = config.hidden_dropout_prob

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(head_dropout)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        self.init_weights()

    def forward(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
    ):
        """Run the encoder and return the logits of the linear head.

        All arguments are forwarded unchanged to the wrapped ``BertModel``.

        Returns:
            The output of ``self.classifier`` applied to the pooled encoder output.
        """
        outputs = self.bert(
            input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
        )
        sequence_output = outputs[0]
        # Plain mean over dim 1. The attention mask is not applied here, so
        # every position along that dimension contributes to the average.
        pooled_output = torch.mean(sequence_output, dim=1)

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        return logits

    def load(self, checkpoint, strict=True, map_location=None, **cfg_args):
        """Rebuild the model from its config and load weights from ``checkpoint``.

        Args:
            checkpoint: path or file-like object accepted by ``torch.load``.
            strict: forwarded to ``load_state_dict``.
            map_location: forwarded to ``torch.load`` so that checkpoints saved
                on another device can be remapped (e.g. ``"cpu"``).
            **cfg_args: config attributes to override before the model is rebuilt.

        Returns:
            The result of ``load_state_dict``.
        """
        # Use setattr instead of writing into config.__dict__ so that config
        # attributes implemented as properties are updated through their setters.
        for key, value in cfg_args.items():
            setattr(self.config, key, value)

        # Read the stored head dropout before __init__ resets the module.
        head_dropout = getattr(self, "_head_dropout", None)
        self.__init__(self.config, head_dropout=head_dropout)

        state_dict = torch.load(checkpoint, map_location=map_location)
        return self.load_state_dict(state_dict, strict=strict)

In [3]:
# Target column names that start with "question_". Their order here is the
# column order of any frame built from this list.
QUESTION_TARGETS = [
    "question_asker_intent_understanding",
    "question_body_critical",
    "question_conversational",
    "question_expect_short_answer",
    "question_fact_seeking",
    "question_has_commonly_accepted_answer",
    "question_interestingness_others",
    "question_interestingness_self",
    "question_multi_intent",
    "question_not_really_a_question",
    "question_opinion_seeking",
    "question_type_choice",
    "question_type_compare",
    "question_type_consequence",
    "question_type_definition",
    "question_type_entity",
    "question_type_instructions",
    "question_type_procedure",
    "question_type_reason_explanation",
    "question_type_spelling",
    "question_well_written",
]
# Target column names that start with "answer_".
ANSWER_TARGETS = [
    "answer_helpful",
    "answer_level_of_information",
    "answer_plausible",
    "answer_relevance",
    "answer_satisfaction",
    "answer_type_instructions",
    "answer_type_procedure",
    "answer_type_reason_explanation",
    "answer_well_written",
]
# Question targets first, then answer targets (21 + 9 = 30 names).
ALL_TARGETS = QUESTION_TARGETS + ANSWER_TARGETS

In [4]:
class QuestDataset(Dataset):
    """Dataset that encodes title/body/answer text into fixed-length BERT inputs.

    An item is ``((input_ids, attention_mask, token_type_ids), targets)``: the
    three features are ``LongTensor`` objects of length ``max_seg_length`` and
    ``targets`` is a ``FloatTensor`` read from ``target_cols``. When
    ``target_cols`` is ``None`` there are no targets and an item is just the
    feature tuple.

    Args:
        data_df: frame holding the text columns (and target columns, if any).
        tokenizer: object providing ``encode``, ``build_inputs_with_special_tokens``,
            ``create_token_type_ids_from_sequences``, ``pad_token_id`` and
            ``sep_token_id``.
        max_seg_length: length every feature vector is padded or truncated to.
        target_cols: ``"all_targets"`` for every question and answer target,
            an explicit column selection, or ``None`` for no targets.
        answer_ratio: share of the token budget reserved for the answer.
        title_ratio: passed as the second segment's share when the question
            budget is split between title and body.
        use_title, use_body, use_answer: whether each text column is read.
        title_col, body_col, answer_col: names of the text columns.
        title_transform, body_transform, answer_transform: optional callables
            invoked as ``transform(text, idx=index)`` before tokenization.
    """

    def __init__(
        self,
        data_df,
        tokenizer,
        max_seg_length=256,
        target_cols="all_targets",
        answer_ratio=0.5,
        title_ratio=0.5,
        use_title=True,
        use_body=True,
        use_answer=True,
        title_col="question_title",
        body_col="question_body",
        answer_col="answer",
        title_transform=None,
        body_transform=None,
        answer_transform=None,
    ):
        self.tokenizer = tokenizer
        self.max_seg_length = max_seg_length

        # Compare by value: `is` on a string literal depends on interning and
        # is not a reliable equality test. The isinstance guard keeps the
        # comparison scalar when a list or array of columns is passed.
        if isinstance(target_cols, str) and target_cols == "all_targets":
            target_cols = QUESTION_TARGETS + ANSWER_TARGETS
        self.target_cols = target_cols

        self.answer_ratio = answer_ratio
        self.title_ratio = title_ratio

        # Always define `targets` so target-less datasets fail nowhere on a
        # missing attribute.
        self.targets = data_df[target_cols].values if target_cols is not None else None

        self.question_title = data_df[title_col].values if use_title else None
        self.question_body = data_df[body_col].values if use_body else None
        self.answer = data_df[answer_col].values if use_answer else None

        self.title_transform = title_transform
        self.body_transform = body_transform
        self.answer_transform = answer_transform

    def _encode_segments(self, *text_segments: Text) -> List[List[int]]:
        """Tokenize each segment without special tokens; ``None`` becomes ``[]``."""
        return [
            self.tokenizer.encode(
                txt, max_length=self.max_seg_length, add_special_tokens=False
            )
            if txt is not None
            else []
            for txt in text_segments
        ]

    def _process(self, title=None, body=None, answer=None):
        """Build the three feature arrays, each of length ``max_seg_length``.

        Returns:
            ``(input_ids, attention_mask, token_type_ids)`` as numpy arrays.
        """
        input_ids, attention_mask, token_type_ids = self._prepare_features(
            title, body, answer
        )

        input_ids = self._pad_and_truncate(
            input_ids, pad_value=self.tokenizer.pad_token_id
        )
        # Padding positions reuse the id of the last real token type.
        token_type_ids = self._pad_and_truncate(
            token_type_ids, pad_value=token_type_ids[-1]
        )
        attention_mask = self._pad_and_truncate(attention_mask, pad_value=0)
        return input_ids, attention_mask, token_type_ids

    def _pad_and_truncate(self, features, pad_value=0):
        """Clip ``features`` to ``max_seg_length`` and right-pad with ``pad_value``."""
        clipped = list(features[: self.max_seg_length])
        clipped += [pad_value] * (self.max_seg_length - len(clipped))
        return np.array(clipped)

    @staticmethod
    def _balance_segments(
        first_segment_length, second_segment_length, second_ratio, max_length
    ):
        """Split ``max_length`` between two segments.

        Each segment is capped at its own share of the budget plus whatever the
        other segment leaves unused. The first segment is capped first, and the
        second cap uses the already-capped first length.

        Returns:
            ``(first_length, second_length)`` truncated to ints.
        """
        first_segment_length = min(
            first_segment_length,
            (1 - second_ratio) * max_length
            + max(second_ratio * max_length - second_segment_length, 0),
        )

        second_segment_length = min(
            second_segment_length,
            second_ratio * max_length
            + max((1 - second_ratio) * max_length - first_segment_length, 0),
        )

        return int(first_segment_length), int(second_segment_length)

    def _prepare_features(self, title, body, answer):
        """Tokenize, budget and join the segments into unpadded feature lists.

        Returns:
            ``(input_ids, attention_mask, token_type_ids)`` as Python lists.
        """
        title_input_ids, body_input_ids, answer_input_ids = self._encode_segments(
            title, body, answer
        )

        title_length = len(title_input_ids)
        body_length = len(body_input_ids)
        answer_length = len(answer_input_ids)

        # First split the budget between question (title + body) and answer,
        # then split the question share between title and body.
        question_length, answer_length = self._balance_segments(
            title_length + body_length,
            answer_length,
            self.answer_ratio,
            self.max_seg_length,
        )

        title_length, body_length = self._balance_segments(
            title_length, body_length, self.title_ratio, question_length
        )

        # TODO: generalize this
        # NOTE(review): the separator added below and the special tokens added
        # by the tokenizer are not part of the balanced budget, so the joined
        # sequence can exceed max_seg_length and be clipped by _pad_and_truncate.
        question_input_ids = body_input_ids[:body_length]
        if title_length > 0:
            question_input_ids = (
                title_input_ids[:title_length]
                + [self.tokenizer.sep_token_id]
                + question_input_ids
            )
        answer_input_ids = answer_input_ids[:answer_length]

        input_ids = self.tokenizer.build_inputs_with_special_tokens(
            question_input_ids, answer_input_ids if answer_length > 0 else None
        )
        # NOTE(review): unlike the call above, an empty answer is passed as []
        # rather than None here. Left as-is because changing it would alter the
        # inputs the checkpoints see - confirm whether this is intended.
        token_type_ids = self.tokenizer.create_token_type_ids_from_sequences(
            question_input_ids, answer_input_ids
        )
        # Integer mask: it is converted to a LongTensor in __getitem__ anyway.
        attention_mask = [1] * len(input_ids)

        return input_ids, attention_mask, token_type_ids

    def _get_text(self, index):
        """Return ``(title, body, answer)`` for ``index`` after optional transforms.

        A segment whose column was not loaded is ``None`` before its transform
        (if any) is applied.
        """
        title = self.question_title[index] if self.question_title is not None else None
        body = self.question_body[index] if self.question_body is not None else None
        answer = self.answer[index] if self.answer is not None else None

        def apply_transform(txt, transform):
            if transform is not None:
                return transform(txt, idx=index)
            else:
                return txt

        title, body, answer = [
            apply_transform(txt, transform)
            for txt, transform in zip(
                [title, body, answer],
                [self.title_transform, self.body_transform, self.answer_transform],
            )
        ]

        return title, body, answer

    def __getitem__(self, index):
        """Return the encoded features for ``index``, plus targets when present."""
        title, body, answer = self._get_text(index)

        input_ids, attention_mask, token_type_ids = self._process(title, body, answer)
        features = tuple(
            torch.LongTensor(feature)
            for feature in (input_ids, attention_mask, token_type_ids)
        )

        if self.targets is None:
            # No target columns were requested: items are features only.
            return features

        targets = torch.FloatTensor(self.targets[index])
        return features, targets

    def __len__(self):
        """Number of rows, taken from the first loaded text column."""
        if self.answer is not None:
            return len(self.answer)
        elif self.question_title is not None:
            return len(self.question_title)
        else:
            return len(self.question_body)

In [5]:
class TestQuestDataset(QuestDataset):
    """Target-less variant of ``QuestDataset`` used for inference.

    It builds the parent with ``target_cols=None`` and yields only the
    ``(input_ids, attention_mask, token_type_ids)`` tuple for each row.
    """

    def __init__(
        self,
        data_df,
        tokenizer,
        max_seg_length=512,
        answer_ratio=0.5,
        title_ratio=0.5,
        # Michael's Notes: Changed it so it only considers the question body
        use_title=False,
        use_body=True,
        use_answer=False,
        title_col="question_title",
        body_col="question_body",
        answer_col="answer",
    ):
        # No target columns: this dataset is only used to produce model inputs.
        super().__init__(
            data_df,
            tokenizer,
            max_seg_length=max_seg_length,
            target_cols=None,
            answer_ratio=answer_ratio,
            title_ratio=title_ratio,
            use_title=use_title,
            use_body=use_body,
            use_answer=use_answer,
            title_col=title_col,
            body_col=body_col,
            answer_col=answer_col,
        )

    def __getitem__(self, index):
        """Return the three encoded feature tensors for row ``index``."""
        encoded = self._process(*self._get_text(index))
        return tuple(torch.LongTensor(part) for part in encoded)

In [6]:
def predict(model, test_loader, columns, device="cuda"):
    """Run ``model`` over ``test_loader`` and return sigmoid scores as a DataFrame.

    The model is switched to eval mode and moved to ``device``. Each batch is
    moved to ``device`` and unpacked as positional arguments to the model.

    Args:
        model: callable module returning one row of logits per example.
        test_loader: iterable of batches.
        columns: column labels for the returned frame.
        device: device passed to ``model.to`` and used for every batch.

    Returns:
        DataFrame of sigmoid-transformed outputs clipped to ``[0, 1 - 1e-8]``.
    """
    model.eval()
    model.to(device)

    with torch.no_grad():
        batch_outputs = [
            torch_to_numpy(model(*torch_to(batch, device)))
            for batch in tqdm(test_loader)
        ]

    logits = np.vstack(batch_outputs)

    scores = torch.from_numpy(logits).sigmoid().numpy()
    scores = np.clip(scores, 0, 1 - 1e-8)
    return pd.DataFrame(scores, columns=columns)

In [7]:
def torch_to_numpy(obj, copy=False):
    """Convert every tensor found in ``obj`` to a numpy array.

    Each tensor is moved to CPU and detached before conversion. With
    ``copy=True`` the resulting array is additionally copied.
    """

    def _convert(tensor):
        array = tensor.cpu().detach().numpy()
        return array.copy() if copy else array

    return torch_apply(obj, _convert)


In [8]:
def torch_to(obj, *args, **kargs):
    """Call ``.to(*args, **kargs)`` on every tensor found in ``obj``."""

    def _move(tensor):
        return tensor.to(*args, **kargs)

    return torch_apply(obj, _move)

In [9]:
def torch_apply(obj, func):
    """Apply ``func`` to each tensor in ``obj``; other leaves are returned unchanged."""

    def _tensors_only(value):
        if torch.is_tensor(value):
            return func(value)
        return value

    return _apply(obj, _tensors_only)

In [10]:
def _apply(obj, func):
    if isinstance(obj, (list, tuple)):
        return type(obj)(_apply(el, func) for el in obj)
    if isinstance(obj, dict):
        return {k: _apply(el, func) for k, el in obj.items()}
    return func(obj)

In [11]:
# Empty attribute container: the settings below are attached to one instance.
class Arg:
    pass

args = Arg()
# NOTE(review): these are absolute paths on one specific machine; the notebook
# will not run elsewhere without editing them.
# Directory that test.csv is read from.
args.data_path = Path('/home/ubuntu/existing-projects/modified-qa-labeler/custom-input-output/')
# Directory holding the BERT config json, the vocab file and the *.pth checkpoints.
args.model_dir = Path('/home/ubuntu/existing-projects/modified-qa-labeler/model1_ckpt/')
# Destination csv used by the final to_csv cell.
args.sub_file = Path('/home/ubuntu/existing-projects/modified-qa-labeler/output/model1_submission.csv')
# DataLoader settings used by predict_frame.
args.batch_size = 8
args.num_workers = 2


In [12]:
 def get_model(targets=ALL_TARGETS):
     """Build a ``BertForQuestRegression`` with one output per entry of ``targets``.

     The BERT configuration is read from ``stackx-base-cased-config.json``
     inside ``args.model_dir``. The model is returned freshly initialized; no
     checkpoint weights are loaded here.

     Args:
         targets: sequence whose length sets the number of output labels.

     Returns:
         The constructed model.
     """
     config = BertConfig.from_json_file(
         args.model_dir / "stackx-base-cased-config.json"
     )
     # Assign through the attribute rather than config.__dict__: where
     # `num_labels` is implemented as a property, a raw __dict__ entry is
     # ignored and the default label count would be used instead.
     config.num_labels = len(targets)

     return BertForQuestRegression(config)

In [13]:
def predict_test_checkpoints(checkpoints, test_loader, targets, device="cuda"):
    """Average the predictions of several checkpoints over ``test_loader``.

    One model is built with ``get_model(targets)`` and each checkpoint is
    loaded into it in turn.

    Args:
        checkpoints: iterable of checkpoint paths; must not be empty.
        test_loader: loader passed to ``predict`` for every checkpoint.
        targets: target column names; they set the model's output size and the
            columns of the returned frame.
        device: device passed to ``predict``.

    Returns:
        DataFrame with columns ``targets`` holding the element-wise mean of the
        per-checkpoint predictions, clipped to ``[0, 1 - 1e-8]``.

    Raises:
        ValueError: if ``checkpoints`` is empty.
    """
    checkpoints = list(checkpoints)
    if not checkpoints:
        # Fail early with a clear message instead of averaging an empty list.
        raise ValueError(
            "No checkpoints given; expected at least one checkpoint path to average over."
        )

    model = get_model(targets)
    per_checkpoint = []
    try:
        for path in checkpoints:
            # NOTE(review): map_location is handed to `model.load` as a keyword;
            # confirm that `load` passes it on to `torch.load`.
            model.load(path, map_location="cpu")
            frame = predict(model, test_loader, targets, device=device)
            per_checkpoint.append(frame[targets].values)
    finally:
        # Drop the model and release cached GPU memory even when a checkpoint fails.
        del model
        torch.cuda.empty_cache()

    pred = np.mean(per_checkpoint, axis=0)
    pred = np.clip(pred, 0, 1 - 1e-8)
    return pd.DataFrame(pred, columns=targets)

In [14]:
# Read test.csv from args.data_path into a DataFrame. In the cells visible
# here it is only referenced from commented-out code.
test = pd.read_csv(args.data_path / "test.csv")


In [15]:
def predict_frame(df, tokenizer_path=None):
    """Predict all targets for the rows of ``df`` using every checkpoint in ``args.model_dir``.

    Args:
        df: frame wrapped in a ``TestQuestDataset`` (built here with its
            default column settings and ``max_seg_length=512``).
        tokenizer_path: vocabulary file for ``BertTokenizer``. When omitted,
            ``stackx-base-cased-vocab.txt`` inside ``args.model_dir`` is used.

    Returns:
        DataFrame with one column per entry of ``ALL_TARGETS``, as returned by
        ``predict_test_checkpoints``.
    """
    if tokenizer_path is None:
        tokenizer_path = args.model_dir / "stackx-base-cased-vocab.txt"
    tokenizer = BertTokenizer(tokenizer_path, do_lower_case=False)

    # Sort so the checkpoints are processed in the same order on every run
    # (glob order is not guaranteed).
    checkpoints = sorted(args.model_dir.glob("*.pth"))

    dataset = TestQuestDataset(df, tokenizer, max_seg_length=512)
    dataset_loader = DataLoader(
        dataset,
        shuffle=False,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )
    return predict_test_checkpoints(checkpoints, dataset_loader, ALL_TARGETS)
    

In [16]:
# test_df = test[0:1]
# test_df = pd.DataFrame(test_df, columns=['question_title', 'question_body', 'answer'])
# test_df['question_title'] = 'Foo'
# predict_frame(test_df)


In [19]:
import json
import os

cwd = os.getcwd()

# Load the questions to label. The path is relative to the working directory.
with open('custom-input-output/khan_questions.json', encoding='utf-8') as jsonfile:
    contents = json.load(jsonfile)

# Build the three text columns the dataset classes look up by name. Only
# 'question_body' carries content; title and answer are empty strings.
question_arr = np.array(contents)
frame = pd.DataFrame(question_arr, columns=['question_body'])
frame['question_title'] = ''
frame['answer'] = ''

# The vocabulary file shipped alongside the checkpoints in args.model_dir.
tokenizer_path = args.model_dir / "stackx-base-cased-vocab.txt"
labels = predict_frame(frame, tokenizer_path)

100%|██████████| 6/6 [00:00<00:00,  8.93it/s]
100%|██████████| 6/6 [00:00<00:00,  9.03it/s]
100%|██████████| 6/6 [00:00<00:00,  9.02it/s]
100%|██████████| 6/6 [00:00<00:00,  8.91it/s]
100%|██████████| 6/6 [00:00<00:00,  8.98it/s]


In [165]:
# Rebuild `labels` with only the QUESTION_TARGETS columns (this rebinds the
# name), then display the question text and those columns side by side.
labels = pd.DataFrame(labels, columns=QUESTION_TARGETS)
frames = [frame['question_body'], labels]
pd.concat(frames, axis=1)

Unnamed: 0,question_body,question_asker_intent_understanding,question_body_critical,question_conversational,question_expect_short_answer,question_fact_seeking,question_has_commonly_accepted_answer,question_interestingness_others,question_interestingness_self,question_multi_intent,...,question_type_choice,question_type_compare,question_type_consequence,question_type_definition,question_type_entity,question_type_instructions,question_type_procedure,question_type_reason_explanation,question_type_spelling,question_well_written
0,"why didn't we call the war of 1812 ,world war 1?",0.945643,0.868897,0.179526,0.83707,0.875727,0.783926,0.723753,0.669387,0.039323,...,0.02643,0.023828,0.011115,0.081532,0.026037,0.017439,0.013348,0.94005,0.004743,0.881338
1,Why the Federalist Party was dissoluted after ...,0.927207,0.862812,0.230581,0.769753,0.816946,0.767172,0.717266,0.648196,0.035061,...,0.028899,0.020764,0.018978,0.029043,0.018781,0.02525,0.019494,0.958029,0.005142,0.868621
2,what was discussed at the Hartford convention,0.930462,0.714814,0.173825,0.881098,0.845838,0.797951,0.655597,0.597833,0.032063,...,0.06597,0.032571,0.015016,0.15996,0.154743,0.018464,0.052286,0.231092,0.008725,0.788448
3,"After the War of 1812, did the British stop tr...",0.968957,0.723829,0.067224,0.857686,0.893337,0.8984,0.703929,0.720916,0.218027,...,0.860559,0.026775,0.020826,0.031357,0.021412,0.020929,0.03421,0.343612,0.002573,0.924115
4,Was the War of 1812 really when US citizens st...,0.961232,0.836647,0.17303,0.719276,0.843397,0.765005,0.710112,0.72215,0.150674,...,0.71114,0.026368,0.019887,0.042975,0.013739,0.010068,0.014282,0.646467,0.002899,0.917618
5,Most of the events leading up to America growi...,0.946165,0.73966,0.442933,0.658478,0.399217,0.298733,0.74504,0.760461,0.049653,...,0.043941,0.019796,0.019624,0.015395,0.03263,0.026723,0.037554,0.72436,0.002363,0.882818
6,Couldn't someone argue that the French and Ind...,0.931253,0.624015,0.150641,0.533234,0.849955,0.586269,0.708665,0.716302,0.368662,...,0.184449,0.023608,0.022959,0.230627,0.033706,0.019552,0.034865,0.697787,0.003578,0.847382
7,who won the war of 1812,0.938526,0.815289,0.121631,0.939695,0.862525,0.790204,0.62519,0.550433,0.023453,...,0.063888,0.029847,0.018001,0.050493,0.620614,0.043319,0.048633,0.125697,0.008345,0.791754
8,why did it all happen?,0.91744,0.826733,0.237,0.775816,0.82302,0.738261,0.715536,0.61157,0.045972,...,0.017735,0.023214,0.025306,0.037049,0.02286,0.038195,0.027199,0.965079,0.007387,0.870301
9,did any other European countries fight in the ...,0.947529,0.713103,0.053253,0.866881,0.90955,0.912919,0.685022,0.627767,0.178557,...,0.857553,0.015797,0.019522,0.017419,0.148611,0.021195,0.027002,0.302099,0.00438,0.862474


In [None]:
# Write the results to args.sub_file without the index column.
# NOTE(review): `submission` is not defined in any cell visible here, so this
# raises NameError on a fresh run - confirm which frame should be written.
submission.to_csv(args.sub_file, index=False)