In [None]:
import argparse
from email.header import decode_header
import json
import os
from collections import OrderedDict
import torch
import csv
import util
from transformers import DistilBertTokenizerFast
from transformers import DistilBertForQuestionAnswering
from transformers import AdamW
from tensorboardX import SummaryWriter


from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from args import get_train_test_args

from tqdm import tqdm

from train import read_and_process, get_dataset

In [None]:
# Training / evaluation loop for extractive QA fine-tuning.
# TODO: compute a validation loss in `evaluate` (only F1/EM are reported today).
class Trainer:
    """Fine-tune and evaluate a question-answering model.

    Configuration is read from an ``args`` namespace (``lr``, ``num_epochs``,
    ``device``, ``eval_every``, ``save_dir``, ``num_visuals``,
    ``visualize_predictions``). Progress is reported through ``log`` and
    TensorBoard. The model with the best validation F1 is written to
    ``<save_dir>/checkpoint`` via ``model.save_pretrained``.
    """

    def __init__(self, args, log):
        self.lr = args.lr
        self.num_epochs = args.num_epochs
        self.device = args.device
        self.eval_every = args.eval_every
        self.path = os.path.join(args.save_dir, "checkpoint")
        self.num_visuals = args.num_visuals
        self.save_dir = args.save_dir
        self.log = log
        self.visualize_predictions = args.visualize_predictions
        # Create the checkpoint directory up front so `save` can write into it.
        os.makedirs(self.path, exist_ok=True)

    def save(self, model):
        """Write ``model`` to the checkpoint directory with ``save_pretrained``."""
        model.save_pretrained(self.path)

    def evaluate(
        self, model, data_loader, data_dict, return_preds=False, split="validation"
    ):
        """Run ``model`` over ``data_loader`` and score its span predictions.

        Args:
            model: QA model whose forward output exposes ``start_logits`` and
                ``end_logits``. It is switched to eval mode and left there.
            data_loader: yields dicts with ``input_ids`` and ``attention_mask``.
                ``data_loader.dataset`` must support ``len`` and expose
                ``encodings``, which are passed to
                ``util.postprocess_qa_predictions``.
            data_dict: gold data passed to ``util.postprocess_qa_predictions``
                and, when ``split == "validation"``, to ``util.eval_dicts``.
            return_preds: if true, also return the postprocessed predictions.
            split: only ``"validation"`` is scored. Any other split reports
                ``-1.0`` for both F1 and EM.

        Returns:
            An ``OrderedDict`` with keys ``"F1"`` and ``"EM"``, or the tuple
            ``(preds, results)`` when ``return_preds`` is true.
        """
        device = self.device

        model.eval()
        all_start_logits = []
        all_end_logits = []
        with torch.no_grad(), tqdm(total=len(data_loader.dataset)) as progress_bar:
            for batch in data_loader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                outputs = model(input_ids, attention_mask=attention_mask)
                # TODO: compute loss (requires start/end positions in the batch).
                all_start_logits.append(outputs.start_logits)
                all_end_logits.append(outputs.end_logits)
                progress_bar.update(len(input_ids))

        # Logits are concatenated in loader order before being turned into
        # answer spans, so the loader must not reorder examples.
        start_logits = torch.cat(all_start_logits).cpu().numpy()
        end_logits = torch.cat(all_end_logits).cpu().numpy()
        preds = util.postprocess_qa_predictions(
            data_dict, data_loader.dataset.encodings, (start_logits, end_logits)
        )
        if split == "validation":
            scores = util.eval_dicts(data_dict, preds)
            results_list = [("F1", scores["F1"]), ("EM", scores["EM"])]
        else:
            results_list = [("F1", -1.0), ("EM", -1.0)]
        results = OrderedDict(results_list)
        if return_preds:
            return preds, results
        return results

    def _evaluate_and_checkpoint(
        self, model, eval_dataloader, val_dict, tbx, global_idx, best_scores
    ):
        """Evaluate, log to the logger and TensorBoard, save on non-worse F1.

        Returns the (possibly updated) best scores.
        """
        self.log.info(f"Evaluating at step {global_idx}...")
        preds, curr_score = self.evaluate(
            model, eval_dataloader, val_dict, return_preds=True
        )
        results_str = ", ".join(f"{k}: {v:05.2f}" for k, v in curr_score.items())
        self.log.info("Visualizing in TensorBoard...")
        for k, v in curr_score.items():
            tbx.add_scalar(f"val/{k}", v, global_idx)
        self.log.info(f"Eval {results_str}")
        if self.visualize_predictions:
            util.visualize(
                tbx,
                pred_dict=preds,
                gold_dict=val_dict,
                step=global_idx,
                split="val",
                num_visuals=self.num_visuals,
            )
        # `>=` means a later checkpoint with equal F1 replaces the saved one.
        if curr_score["F1"] >= best_scores["F1"]:
            self.save(model)
            return curr_score
        return best_scores

    def train(self, model, train_dataloader, eval_dataloader, val_dict, tokenizer):
        """Fine-tune ``model`` and keep the checkpoint with the best validation F1.

        Runs ``self.num_epochs`` epochs of AdamW updates. Every
        ``self.eval_every`` optimizer steps, starting at step 0, the model is
        evaluated on ``eval_dataloader``. Scores are logged and written to
        TensorBoard, and the model is saved whenever its F1 ties or beats the
        best so far.

        Args:
            model: QA model. When called with ``start_positions`` and
                ``end_positions`` it must return the loss as its first output.
            train_dataloader: yields dicts with ``input_ids``,
                ``attention_mask``, ``start_positions`` and ``end_positions``.
            eval_dataloader: validation loader (see ``evaluate``).
            val_dict: gold validation data (see ``evaluate``).
            tokenizer: unused; kept for call-site compatibility.

        Returns:
            The best ``{"F1": ..., "EM": ...}`` scores seen during training.
        """
        device = self.device
        model.to(device)
        optim = AdamW(model.parameters(), lr=self.lr)
        global_idx = 0
        best_scores = {"F1": -1.0, "EM": -1.0}
        tbx = SummaryWriter(self.save_dir)

        try:
            for epoch_num in range(self.num_epochs):
                self.log.info(f"Epoch: {epoch_num}")
                with torch.enable_grad(), tqdm(
                    total=len(train_dataloader.dataset)
                ) as progress_bar:
                    for batch in train_dataloader:
                        optim.zero_grad()
                        # `evaluate` leaves the model in eval mode, so restore
                        # training mode on every step.
                        model.train()
                        input_ids = batch["input_ids"].to(device)
                        attention_mask = batch["attention_mask"].to(device)
                        start_positions = batch["start_positions"].to(device)
                        end_positions = batch["end_positions"].to(device)
                        outputs = model(
                            input_ids,
                            attention_mask=attention_mask,
                            start_positions=start_positions,
                            end_positions=end_positions,
                        )
                        loss = outputs[0]
                        loss.backward()
                        optim.step()
                        loss_value = loss.item()
                        progress_bar.update(len(input_ids))
                        progress_bar.set_postfix(epoch=epoch_num, NLL=loss_value)
                        tbx.add_scalar("train/NLL", loss_value, global_idx)
                        if (global_idx % self.eval_every) == 0:
                            best_scores = self._evaluate_and_checkpoint(
                                model,
                                eval_dataloader,
                                val_dict,
                                tbx,
                                global_idx,
                                best_scores,
                            )
                        global_idx += 1
        finally:
            # Flush pending TensorBoard events and release the event file,
            # even if training is interrupted.
            tbx.close()
        return best_scores

In [None]:
# Argument presets for this notebook, one flag per line. The next cell
# flattens the selected string into argv tokens for `get_train_test_args`.
# Set ROBUSTQA_FINETUNE_PATH to point at a different checkpoint without
# editing the notebook.
FINETUNE_PATH = os.environ.get(
    "ROBUSTQA_FINETUNE_PATH",
    "/vision/u/naagnes/github/robustqa/save/elliot-checkpoint",
)

# Out-of-domain preset: trains and validates on the oodomain splits.
OODOMAIN_ARGS = f"""
--do-train
--run-name elliot-train
--finetune-path {FINETUNE_PATH}
--train-datasets=race,relation_extraction,duorc
--train-dir=datasets/oodomain_train
--val-dir=datasets/oodomain_val
--num-epochs 10
--eval-every 10
"""

# Default preset: datasets and directories fall back to whatever defaults
# `get_train_test_args` defines.
DEFAULT_ARGS = f"""
--do-train
--run-name elliot-train
--finetune-path {FINETUNE_PATH}
--num-epochs 10
--eval-every 10
"""

# Select the preset that the next cell parses.
args = DEFAULT_ARGS

In [None]:
# Define parser and arguments: flatten the multi-line preset string into argv tokens.
args = get_train_test_args(args=args.replace("\n", " ").split())

util.set_seed(args.seed)

BASE_MODEL_NAME = "distilbert-base-uncased"

# Load the model from a finetuned checkpoint when one is given, otherwise start
# from the pretrained base model. Hidden states and attention maps are
# requested in every forward pass's output.
model_source = args.finetune_path if args.finetune_path is not None else BASE_MODEL_NAME
model = DistilBertForQuestionAnswering.from_pretrained(
    model_source, output_hidden_states=True, output_attentions=True
)
tokenizer = DistilBertTokenizerFast.from_pretrained(BASE_MODEL_NAME)

if args.do_train:
    os.makedirs(args.save_dir, exist_ok=True)
    # NOTE(review): presumably a per-run subdirectory of save_dir; see util.get_save_dir.
    args.save_dir = util.get_save_dir(args.save_dir, args.run_name)
    log = util.get_logger(args.save_dir, "log_train")
    # Dump args before `args.device` is attached: a torch.device is not JSON-serializable.
    log.info(f"Args: {json.dumps(vars(args), indent=4, sort_keys=True)}")
    log.info("Preparing Training Data...")
    args.device = (
        torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    )
    trainer = Trainer(args, log)
    train_dataset, _ = get_dataset(
        args, args.train_datasets, args.train_dir, tokenizer, "train"
    )
    log.info("Preparing Validation Data...")
    # Validation reads the same dataset names from val_dir.
    val_dataset, val_dict = get_dataset(
        args, args.train_datasets, args.val_dir, tokenizer, "val"
    )
    # Shuffle training batches. Keep validation in dataset order, because
    # Trainer.evaluate concatenates logits batch by batch.
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        sampler=RandomSampler(train_dataset),
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        sampler=SequentialSampler(val_dataset),
    )

In [None]:
# Manual training setup mirroring the start of `Trainer.train`, for step-by-step inspection.
# Fall back to CPU so the notebook also runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optim = AdamW(model.parameters(), lr=args.lr)
global_idx = 0
best_scores = {"F1": -1.0, "EM": -1.0}
tbx = SummaryWriter(args.save_dir)

optim.zero_grad()
model.train()

In [None]:
# Run a single forward pass (gradients enabled, labels supplied so a loss is
# returned) on the first training batch, for inspection.
batch = next(iter(train_loader), None)
if batch is None:
    # Fail loudly instead of leaving `outputs` undefined or stale from an earlier run.
    raise ValueError("train_loader produced no batches")

with torch.enable_grad():
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    start_positions = batch["start_positions"].to(device)
    end_positions = batch["end_positions"].to(device)
    outputs = model(
        input_ids,
        attention_mask=attention_mask,
        start_positions=start_positions,
        end_positions=end_positions,
    )

In [None]:
# Report the shapes of the batch inputs, the gold span labels and the predicted logits.
for shape_line in (
    f"input_ids shape {input_ids.shape}",
    f"start_positions shape: {start_positions.shape}, end_positions shape {end_positions.shape}",
    f"start_logits shape: {outputs.start_logits.shape}, end_logits shape: {outputs.end_logits.shape}",
):
    print(shape_line)

In [None]:
# Decode the first few examples of the batch next to their labelled answer span.
NUM_EXAMPLES_TO_SHOW = 16
# Bound by the batch size so smaller batches (e.g. the last one) don't raise IndexError.
for qid in range(min(NUM_EXAMPLES_TO_SHOW, len(input_ids))):
    decoded_string = tokenizer.decode(input_ids[qid])
    start, end = int(start_positions[qid]), int(end_positions[qid])
    # Positions are inclusive token indices, hence `end + 1`.
    answer_string = tokenizer.decode(input_ids[qid][start : end + 1])
    # A span covering only the first token decodes to "[CLS]"; this notebook
    # treats that as "no answer".
    if answer_string == "[CLS]":
        answer_string = "NO ANSWER"
    print(decoded_string)
    print(start, end, answer_string)
    print()

In [None]:
# Display the model; its repr lists the module hierarchy.
model

In [None]:
# Display the full forward-pass output for the inspection batch.
outputs

In [None]:
# Compact view of the forward-pass output: tensor fields are shown as shapes,
# tuple fields (e.g. per-layer hidden states / attentions) as their length.
{
    name: tuple(value.shape) if torch.is_tensor(value) else f"tuple of {len(value)}"
    for name, value in outputs.items()
}