# Using ALBERT for Question Answering - SQuAD

Updates:

Things implemented:
- Data Loading
- Model
- Label Smoothing
- Learner
- Prediction on new data
- Have sortish sampler work with two keys (len(x),len(y))

TODO:
- Consider using the sliding window approach for long articles
- Use a more powerful machine for training
- Load in Daily Mail data

In [1]:
from src import *
from transformers import AutoTokenizer, AlbertForQuestionAnswering
from pathlib import Path
import numpy as np
import pandas as pd
import json
import pickle
import re
import logging
from tqdm import tqdm, trange

In [2]:
# Central experiment configuration; the cells below read these values as config.<name>.
config = Config(
    data_path = Path("../SQuAD/1.1"), # replace with the directory containing the parsed csv files
    task = "SQuAD",
    testing=False,                    # True -> train/valid are cut down to 1000/500 rows
    seed = 2020,                      # random_state used when shuffling the dataframes
    model = 'albert-base-v2',         # checkpoint name for both the tokenizer and the model
    max_lr=5e-5,
    epochs=2,
    use_fp16=False,
    bs=8,                             # batch size for the samplers and data loaders
    max_seq_len=384,                  # length limit for row filtering and token truncation
    start_tok = "[CLS]",
    end_tok = "[SEP]",
    sep_tok = "[SEP]",                # inserted between the feature columns when they are joined
    unk_tok_idx=1,                    # id given to tokens missing from the vocab
    pad_idx=0,
    feat_cols = ["paragraph","question"],
    label_cols = "idxs",
    adjustment = 1,
)

# Model family = everything before the first "-" (e.g. "albert" for "albert-base-v2").
# partition() returns the whole name when there is no hyphen, where indexing an empty
# re.findall() result would raise IndexError.
config.model_name = config.model.partition("-")[0]

In [3]:
# utility functions
def remove_max_sl(df, max_seq_len=None):
    """Return only the rows of ``df`` that are short enough for the model.

    A row is kept when ``seq_len < max_seq_len - 2``. The 2 matches the
    start and end tokens that ``TokenizerProcessor.proc1`` adds around each
    tokenized text.

    Args:
        df: DataFrame with a numeric ``seq_len`` column.
        max_seq_len: length limit; defaults to ``config.max_seq_len`` when
            omitted, so existing single-argument calls behave as before.

    Returns:
        The filtered DataFrame; ``df`` itself is not modified. The number of
        dropped rows is printed.
    """
    if max_seq_len is None:
        max_seq_len = config.max_seq_len
    init_len = len(df)
    kept = df[df.seq_len < max_seq_len - 2]
    print(f"dropping {init_len - len(kept)} out of {init_len} questions")
    return kept

# Loading the data

In [4]:
# Load the pre-parsed splits; the file names embed config.model_name
# (e.g. train_albert.csv / val_albert.csv for 'albert-base-v2').
train = pd.read_csv(config.data_path/f"train_{config.model_name}.csv")
valid = pd.read_csv(config.data_path/f"val_{config.model_name}.csv")

In [5]:
# reduce df sizes if testing:
# keep only the first 1000 training rows and the first 500 validation rows
if config.testing:
    train = train[:1000]
    valid = valid[:500]

In [6]:
# Drop rows that exceed the sequence-length limit in both splits
# (remove_max_sl prints how many rows were dropped from each).
train, valid = remove_max_sl(train), remove_max_sl(valid)

dropping 998 out of 87599 questions
dropping 643 out of 34726 questions


In [7]:
# randomizing the row order of both the training and the validation data
# (fixed seed from config.seed; the index is renumbered from 0 afterwards)
train = train.sample(frac=1,random_state = config.seed).reset_index(drop=True)
valid = valid.sample(frac=1, random_state = config.seed).reset_index(drop=True)

In [8]:
# Preview the first rows of the shuffled training frame.
train.head()

Unnamed: 0,question,paragraph,answer,idxs,seq_len
0,How does short term memory encode information?,While short-term memory encodes information ac...,"['▁acoustic', 'ally']","[8, 10]",136
1,What showed the importance of Utrecht,"By the mid-7th century, English and Irish miss...","['▁utrecht', '▁as', '▁a', '▁centre', '▁of', '▁...","[172, 204]",224
2,"What is the legal name of the ""Stupid Motorist...",The monsoon can begin any time from mid-June t...,"['▁arizona', '▁traffic', '▁code', '▁title', '▁...","[155, 162]",260
3,How many miles did Elizabeth cover on her worl...,"From Elizabeth's birth onwards, the British Em...","['▁', '40,000', '▁miles']","[68, 71]",160
4,Who was the RAF night fighter ace that used ai...,"Nevertheless, it was radar that proved to be c...","['▁john', '▁cunningham']","[141, 143]",178


# Setting up the Tokenizer

In [9]:
class TokenizerProcessor(Processor):
    """Tokenize raw strings and frame each token list with start/end tokens.

    Args:
        tok_func: callable taking a string and returning a list of tokens.
        max_sl: maximum length of a processed item, including the two
            framing tokens.
        start_tok: token placed at the front of every item.
        end_tok: token placed at the end of every item.
        pre_rules, post_rules: stored on the instance; not applied by this
            class.
    """
    def __init__(self, tok_func, max_sl, start_tok, end_tok, pre_rules=None,post_rules=None):
        self.tok_func,self.max_sl = tok_func,max_sl
        self.pre_rules,self.post_rules=pre_rules,post_rules
        self.start_tok, self.end_tok = start_tok, end_tok

    def proc1(self, x):
        """Tokenize one string, truncating so the framed result has at most ``max_sl`` tokens."""
        # Reserve two slots for the start and end tokens.
        return [self.start_tok] + self.tok_func(x)[:self.max_sl-2] + [self.end_tok]

    def __call__(self, items):
        """Process every item and return the results as a plain list.

        tqdm wraps the input, so the bar advances while items are being
        tokenized. Wrapping the finished list would measure nothing and
        would hand a tqdm object back to the caller.
        """
        return [self.proc1(x) for x in tqdm(items)]

import collections
class NumericalizeProcessor(Processor):
    """
    Map token lists to id lists using an existing vocab.

    only works with an existing vocab at the moment and min_freq is not accounted for

    Args:
        vocab: dict mapping token -> id.
        unk_tok_idx: id used for tokens that are not in ``vocab``.
        min_freq: stored on the instance; not used by this class.
    """
    def __init__(self, vocab:dict, unk_tok_idx:int, min_freq=2): 
        self.vocab, self.unk_tok_idx, self.min_freq = vocab, unk_tok_idx, min_freq

    def proc1(self, x):
        """Convert one token list to ids, falling back to ``unk_tok_idx``."""
        return [self.vocab.get(token, self.unk_tok_idx) for token in x]

    def __call__(self, items):
        """Numericalize every item and return the results as a plain list."""
        if getattr(self, 'otoi', None) is None:
            # token -> position of the token in the vocab's iteration order.
            # Kept as an attribute for compatibility; proc1 does not read it.
            self.otoi = collections.defaultdict(int,{v:k for k,v in enumerate(self.vocab)})
        # Wrap the input so the progress bar tracks the actual work.
        return [self.proc1(x) for x in tqdm(items)]

In [10]:
# Pretrained tokenizer for config.model; its tokenize() method is the tok_func of the
# processor, which frames every text with config.start_tok / config.end_tok.
tok = AutoTokenizer.from_pretrained(config.model)
proc_tok = TokenizerProcessor(tok.tokenize, config.max_seq_len, config.start_tok, config.end_tok)

In [11]:
# token -> id lookup, built by asking the tokenizer for the token of every id below tok.vocab_size
vocab = {tok.convert_ids_to_tokens(i):i for i in range(tok.vocab_size)}
# tokens missing from the lookup are mapped to config.unk_tok_idx
proc_num = NumericalizeProcessor(vocab, unk_tok_idx=config.unk_tok_idx)

In [12]:
def str2tensor(s):
    """Parse the first two integers found in ``s`` into a long tensor.

    Intended for label strings such as ``"[8, 10]"``; any further numbers in
    the string are ignored.

    Args:
        s: string containing at least two runs of digits.

    Returns:
        A tensor of shape (2,) and dtype ``torch.long``.

    Raises:
        IndexError: if ``s`` contains fewer than two runs of digits.
    """
    # Raw string: a plain "\d" is an invalid escape sequence in a normal string literal.
    indices = re.findall(r"\d+", s)
    return torch.tensor([int(indices[0]), int(indices[1])], dtype=torch.long)

class QALabelProcessor(Processor):
    """Parse raw labels and shift each parsed label by a constant.

    Args:
        parse_func: callable applied to every raw label.
        adjustment: value added to the parsed label.
    """
    def __init__(self, parse_func = noop, adjustment = 1):
        self.parse_func, self.adjustment = parse_func, adjustment

    def proc1(self, item):
        """Parse one label, then add the adjustment."""
        parsed = self.parse_func(item)
        return parsed + self.adjustment

    def __call__(self, items):
        """Return a list with every item processed by ``proc1``."""
        return list(map(self.proc1, items))
    

In [13]:
# Label processor: parse the "[start, end]" strings into tensors, then shift them.
# The shift is taken from config.adjustment (1, the same as the class default), so
# that changing the config actually takes effect.
proc_qa = QALabelProcessor(str2tensor, adjustment=config.adjustment)

In [14]:
class TextList(ItemList):
    """Item list whose items are texts assembled from DataFrame columns."""

    @classmethod
    def from_df(cls, df, feat_cols, label_col, sep_tok, test=False):
        """Build a list of texts (with labels attached) from a DataFrame.

        Args:
            df: source DataFrame.
            feat_cols: one column name or a list of them; the columns are
                joined per row, in order, with ``" {sep_tok} "`` in between.
            label_col: column holding the labels; not read when ``test`` is
                True.
            sep_tok: separator token placed between the feature columns.
            test: when True, every item gets the dummy label 0.

        Returns:
            ``cls(texts, labels=labels)`` where ``labels`` is itself a ``cls``
            instance.
        """
        feat_cols = listify(feat_cols)
        x = df[feat_cols[0]]
        for col in feat_cols[1:]:
            # `x = x + ...` builds a new Series. An in-place `+=` can write the
            # joined text back into the caller's DataFrame column, depending
            # on the pandas version.
            x = x + f" {sep_tok} " + df[col]
        # range(len(df)): iterating over the bare integer len(df) raises TypeError.
        labels = cls([0 for _ in range(len(df))]) if test else cls(df[label_col])
        return cls(x,labels=labels)

In [15]:
# Join config.feat_cols with config.sep_tok into one text per row and attach the
# config.label_cols column as labels, for each split.
il_train = TextList.from_df(train,config.feat_cols,config.label_cols,config.sep_tok)
il_valid = TextList.from_df(valid,config.feat_cols,config.label_cols,config.sep_tok)

In [None]:
# Pair inputs with labels: inputs go through tokenization then numericalization,
# labels go through the QA label processor.
ll_train = LabeledData(il_train,il_train.labels, proc_x = [proc_tok,proc_num], proc_y=proc_qa)
ll_valid = LabeledData(il_valid,il_valid.labels, proc_x = [proc_tok,proc_num], proc_y=proc_qa)

In [None]:
from torch.utils.data import Sampler

class SortSampler(Sampler):
    """Yield dataset indices ordered by ``key``, largest key first.

    Args:
        data_source: sized collection; only its length is used.
        key: callable mapping an index to a sortable value.
    """
    def __init__(self, data_source, key):
        self.data_source = data_source
        self.key = key

    def __len__(self):
        return len(self.data_source)

    def __iter__(self):
        # Stable descending sort: indices with equal keys keep ascending order.
        order = list(range(len(self.data_source)))
        order.sort(key=self.key, reverse=True)
        return iter(order)

In [None]:
class SortishSampler(Sampler):
    """Yield indices in a shuffled order where each batch holds similar keys.

    Indices are shuffled, cut into megabatches of ``bs * 50``, sorted by
    ``key`` (descending) inside each megabatch and then cut into batches of
    ``bs``. The batch that starts with the largest key is moved to the
    front, the last batch stays last, and the batches in between are
    shuffled.

    Args:
        data_source: sized collection; only its length is used.
        key: callable mapping an index (passed as a 0-d tensor) to a
            sortable number.
        bs: batch size.
    """
    def __init__(self, data_source, key, bs):
        self.data_source,self.key,self.bs = data_source,key,bs

    def __len__(self) -> int: return len(self.data_source)

    def __iter__(self):
        n = len(self.data_source)
        if n == 0:
            # Nothing to sample; torch.cat() on an empty list would raise.
            return iter(torch.LongTensor([]))

        idxs = torch.randperm(n)
        mega_sz = self.bs * 50
        megabatches = [idxs[i:i+mega_sz] for i in range(0, n, mega_sz)]
        # sorted() over a 1-d tensor yields 0-d tensors; stack them back into a 1-d tensor.
        sorted_idx = torch.cat([torch.stack(sorted(s, key=self.key, reverse=True)) for s in megabatches])
        batches = [sorted_idx[i:i+self.bs] for i in range(0, n, self.bs)]

        # find the chunk with the largest key, then make sure it goes first.
        max_idx = int(torch.argmax(torch.tensor([self.key(ck[0]) for ck in batches])))
        batches[0],batches[max_idx] = batches[max_idx],batches[0]

        if len(batches) == 1:
            # Only one batch: emitting first + last would duplicate it.
            return iter(batches[0])

        # Shuffle everything between the first and the last batch. With exactly
        # two batches there is nothing to shuffle.
        middle = batches[1:-1]
        if middle:
            middle = [middle[int(i)] for i in torch.randperm(len(middle))]
        return iter(torch.cat([batches[0], *middle, batches[-1]]))

In [None]:
def pad_collate_qa(samples, pad_idx=1, pad_first=False):
    """Collate ``(token_ids, label_tensor)`` samples into one padded batch.

    Args:
        samples: sequence of samples; element 0 is a list of token ids and
            element 1 is a label tensor.
        pad_idx: id used to fill the unused positions.
        pad_first: if True the padding goes in front of the tokens,
            otherwise after them.

    Returns:
        ``(inputs, labels)``: a long tensor of shape
        (len(samples), longest sample) and the labels stacked along a new
        first dimension.
    """
    lengths = [len(sample[0]) for sample in samples]
    batch = torch.zeros(len(samples), max(lengths)).long() + pad_idx
    for row, sample in enumerate(samples):
        ids = torch.LongTensor(sample[0])
        n_tokens = len(sample[0])
        if pad_first:
            batch[row, -n_tokens:] = ids
        else:
            batch[row, :n_tokens] = ids
    labels = torch.cat([sample[1].unsqueeze(0) for sample in samples])
    return batch, labels

In [None]:
# Pad batches with config.pad_idx. pad_collate_qa's own default (1) is the id the
# config declares as unk_tok_idx, so relying on the default would fill batches with
# the unknown-token id instead of the padding id.
qa_collate = partial(pad_collate_qa, pad_idx=config.pad_idx)

# Training: roughly length-sorted batches in shuffled order (key = tokenized input length).
train_sampler = SortishSampler(ll_train.x, key=lambda t: len(ll_train[int(t)][0]), bs=config.bs)
train_dl = DataLoader(ll_train, batch_size=config.bs, sampler=train_sampler, collate_fn=qa_collate)

# Validation: strictly sorted by input length, longest first.
valid_sampler = SortSampler(ll_valid.x, key=lambda t: len(ll_valid[int(t)][0]))
valid_dl = DataLoader(ll_valid, batch_size=config.bs, sampler=valid_sampler, collate_fn=qa_collate)

In [None]:
# Sanity check: pull one batch from the training loader and show the shape of its inputs.
iter_dl = iter(train_dl)
x,y = next(iter_dl); x.shape

# Setting up the Databunch

In [None]:
# Bundle the training and validation loaders into one object for the Learner.
data = DataBunch(train_dl,valid_dl)

# Training Model

In [None]:
class CustomAlbertModel(nn.Module):
    """Thin wrapper around the pretrained ``AlbertForQuestionAnswering`` model.

    The checkpoint named by ``config.model`` is loaded in ``__init__`` and
    put into training mode.
    """
    def __init__(self):
        super().__init__()
        self.bert = AlbertForQuestionAnswering.from_pretrained(config.model)
        self.bert.train()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        """Run the wrapped model and return its outputs unchanged.

        Args:
            input_ids: batch of token ids.
            token_type_ids: optional segment ids, forwarded to the model.
            attention_mask: optional mask, forwarded to the model.
            labels: accepted for interface compatibility; not used here.
        """
        # Forward the optional inputs instead of silently dropping them. When they
        # are None (only input_ids is given) the call is equivalent to passing
        # input_ids alone.
        outputs = self.bert(input_ids,
                            token_type_ids=token_type_ids,
                            attention_mask=attention_mask)
        return outputs

In [None]:
# defining the loss function
def cross_entropy_qa(input, target):
    """
    Loss for span prediction: cross entropy of the start index plus cross
    entropy of the end index.

    ``input[0]`` / ``input[1]`` are the start / end logits; column 0 / 1 of
    ``target`` hold the matching gold indices.
    """
    start_logits, end_logits = input[0], input[1]
    start_loss = F.cross_entropy(start_logits, target[:, 0])
    end_loss = F.cross_entropy(end_logits, target[:, 1])
    return start_loss + end_loss

In [None]:
# defining the evaluation metric
def acc_qa(input,target):
    """
    Mean of two accuracies: the start-index accuracy (``input[0]`` against
    column 0 of ``target``) and the end-index accuracy (``input[1]`` against
    column 1).
    """
    start_acc = accuracy(input[0], target[:,0])
    end_acc = accuracy(input[1], target[:,1])
    return (start_acc + end_acc)/2

In [None]:
# Callback factories handed to the Learner.
cbfs = [
    partial(AvgStatsCallback, acc_qa),  # stats callback, given acc_qa as its metric argument
    CudaCallback,
    ProgressCallback,
    Recorder,
]

In [None]:
# Instantiating the wrapper loads the pretrained weights for config.model.
model = CustomAlbertModel()

In [None]:
def albert_splitter(m, g1=None, g2=None):
    """Split the parameters of ``m`` into two groups.

    Group 2 receives the parameters of every ``LayerNorm`` module and of the
    ``ffn`` / ``ffn_output`` children of any module that has a child named
    ``ffn``. Group 1 receives the parameters of every other module that has
    a ``weight`` attribute.

    Each parameter is placed in one group only, and the first assignment
    wins. Without that rule the ``ffn`` layers would be added to group 2 by
    their parent and again to group 1 when the traversal reaches them.

    Args:
        m: root ``nn.Module`` to traverse recursively.
        g1: optional list to extend with group-1 parameters.
        g2: optional list to extend with group-2 parameters.

    Returns:
        ``(g1, g2)``. If lists were passed in, the same list objects are
        returned after being extended in place.
    """
    # Fresh lists per call: mutable default arguments would be shared between
    # calls, so parameters would pile up every time the splitter runs.
    g1 = [] if g1 is None else g1
    g2 = [] if g2 is None else g2
    seen = {id(p) for p in g1} | {id(p) for p in g2}

    def _add(group, params):
        # Append only parameters that have not been assigned to a group yet.
        for p in params:
            if id(p) not in seen:
                seen.add(id(p))
                group.append(p)

    def _walk(mod):
        if "ffn" in dict(mod.named_children()):
            _add(g2, mod.ffn_output.parameters())
            _add(g2, mod.ffn.parameters())
        if isinstance(mod, torch.nn.modules.normalization.LayerNorm):
            _add(g2, mod.parameters())
        elif hasattr(mod, 'weight'):
            _add(g1, mod.parameters())
        for child in mod.children():
            _walk(child)

    _walk(m)
    return g1,g2

# Show how many parameter tensors land in each group. `model` is used because
# `learn` is only created further down. Fresh lists are passed so that this
# preview cannot leave parameters behind in the splitter's default arguments.
[len(group) for group in albert_splitter(model, [], [])]

In [None]:
# https://github.com/fastai/course-v3/blob/master/nbs/dl2/11_train_imagenette.ipynb
def create_phases(phases):
    """Return ``phases`` as a list with one extra entry that makes the total 1.

    For example, 0.3 becomes [0.3, 0.7].
    """
    phases = listify(phases)
    remainder = 1-sum(phases)
    return [*phases, remainder]

# https://github.com/fastai/course-v3/blob/master/nbs/dl2/11a_transfer_learning.ipynb
def sched_1cycle(lrs, pct_start=0.3, mom_start=0.95, mom_mid=0.85, mom_end=0.95):
    """Build learning-rate and momentum schedulers for a 1cycle policy.

    Args:
        lrs: iterable of peak learning rates; one lr schedule is built for
            each, going from ``lr/10`` to ``lr`` to ``lr/1e5``.
        pct_start: share of training given to the first phase.
        mom_start, mom_mid, mom_end: momentum at the start, at the phase
            boundary and at the end.

    Returns:
        ``[ParamScheduler('lr', ...), ParamScheduler('mom', ...)]``.
    """
    phases = create_phases(pct_start)
    lr_scheds = []
    for lr in lrs:
        lr_scheds.append(combine_scheds(phases, cos_1cycle_anneal(lr/10., lr, lr/1e5)))
    mom_sched = combine_scheds(phases, cos_1cycle_anneal(mom_start, mom_mid, mom_end))
    lr_cb = ParamScheduler('lr', lr_scheds)
    mom_cb = ParamScheduler('mom', mom_sched)
    return [lr_cb, mom_cb]

# Two peak learning rates (config.max_lr and 1e-4), 30% of training in the first phase.
# NOTE(review): presumably one rate per parameter group returned by albert_splitter — confirm
# against ParamScheduler. The inline values below differ from the ones actually passed.
disc_lr_sched = sched_1cycle([config.max_lr,1e-4], 0.3) # 3e-2 best with adam, 1e-3 for lamb

In [None]:
# the learning rate we apply here does not matter since we are scheduling 
# Wires together the model, the data bunch, the QA loss, the callback factories,
# the parameter-group splitter and the optimizer factory returned by lamb_opt().
learn = Learner(model, data, cross_entropy_qa,lr=config.max_lr,cb_funcs=cbfs,splitter=albert_splitter,opt_func=lamb_opt())


In [None]:
# Train with the 1cycle schedulers attached as extra callbacks.
# NOTE(review): this runs a single epoch although config.epochs is set to 2 — confirm which is intended.
learn.fit(1,cbs=disc_lr_sched)