In [1]:
import dataclasses
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Dict, Optional

import numpy as np

In [2]:
from transformers import (
    AutoConfig,
    AutoModelForMultipleChoice,
    AutoTokenizer,
    EvalPrediction,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed )

In [3]:
from utils_superglue_mcqa import (
    MultipleChoiceDataset, 
    Split, 
    processors, 
    superglue_mcqa_output_modes, 
    superglue_mcqa_tasks_num_labels, 
    superglue_mcqa_compute_metrics,
)

In [4]:
@dataclass
class ModelArguments:
    """
    Selects the pretrained checkpoint to fine-tune, plus optional separate
    config/tokenizer sources and a download cache location.
    """

    # Required: hub identifier or local path of the starting checkpoint.
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
    )
    # Optional override for the config source (None -> use model_name_or_path).
    config_name: Optional[str] = field(
        metadata={"help": "Pretrained config name or path if not the same as model_name"},
        default=None,
    )
    # Optional override for the tokenizer source (None -> use model_name_or_path).
    tokenizer_name: Optional[str] = field(
        metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
        default=None,
    )
    # Optional cache directory for downloaded weights/configs.
    cache_dir: Optional[str] = field(
        metadata={"help": "Where do you want to store the pretrained models downloaded from s3"},
        default=None,
    )

@dataclass
class DataTrainingArguments:
    """
    Describes which task to run, where its data lives, and how inputs are
    truncated/padded and cached for training and evaluation.
    """

    # Required: task identifier; the help text enumerates the registered processors.
    task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(processors.keys())})
    # Required: directory containing this task's data files.
    data_dir: str = field(metadata={"help": "Should contain the data files for the task."})
    # Token budget per encoded input.
    max_seq_length: int = field(
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
        default=128,
    )
    # When True, rebuild cached features instead of reusing them.
    overwrite_cache: bool = field(
        metadata={"help": "Overwrite the cached training and evaluation sets"},
        default=False,
    )

    def __post_init__(self):
        """Normalise task_name to lower case (e.g. "MultiRC" -> "multirc")."""
        normalised = self.task_name.lower()
        self.task_name = normalised

# Logger for messages emitted by this notebook itself.
logger = logging.getLogger(__name__)
# A single parser that populates all three argument dataclasses from one argv list.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))

# --- Experiment configuration ------------------------------------------------
MODEL_NAME = "bert-base-cased"
DATESTAMP = "20200805"
TASK_NAME = "MultiRC"
PER_DEVICE_BATCH_SIZE = 16
# NOTE(review): previously written as "10e-5", which is 1e-4 (not 1e-5). The
# recorded run appears to diverge (train loss 1.47 -> 0.0 -> 25.06 while
# eval_acc stays at 0.572), and 1e-4 is high for BERT fine-tuning -- confirm
# this value is intended.
LEARNING_RATE = 1e-4
# Log/evaluate every 212 steps -- presumably ~1/4 of the 847 steps per epoch
# shown in the training progress bar.
EVAL_EVERY_STEPS = 212

# Machine-specific locations; override through environment variables when
# running elsewhere (defaults reproduce the original paths).
SUPER_GLUE_DIR = os.environ.get("SUPER_GLUE_DIR", "/home/keyur/medhas/superglue_data/")
EXPERIMENT_DIR = os.path.join(
    os.environ.get("GLUE_EXPERIMENTS_ROOT", "/mnt/data/medhas/glue_experiments"),
    MODEL_NAME,
    DATESTAMP,
)
TASK_OUTPUT_DIR = os.path.join(EXPERIMENT_DIR, TASK_NAME)

# Command-line style arguments handed to HfArgumentParser instead of sys.argv.
custom_sysargv = [
    "--model_name_or_path=%s" % MODEL_NAME,
    "--task_name=%s" % TASK_NAME,
    "--do_train",
    "--do_eval",
    "--data_dir=%s" % os.path.join(SUPER_GLUE_DIR, TASK_NAME),
    "--max_seq_length=512",
    "--per_device_train_batch_size=%s" % PER_DEVICE_BATCH_SIZE,
    "--learning_rate=%s" % LEARNING_RATE,
    "--num_train_epochs=10",
    "--output_dir=%s" % TASK_OUTPUT_DIR,
    "--logging_dir=%s" % os.path.join(TASK_OUTPUT_DIR, "logs"),
    "--logging_steps=%d" % EVAL_EVERY_STEPS,
    "--evaluate_during_training",
    # Use the full option name; the former "--eval_step" only resolved through
    # argparse prefix-abbreviation matching.
    "--eval_steps=%d" % EVAL_EVERY_STEPS,
    "--save_total_limit=2",
    "--save_steps=1000",
    "--gradient_accumulation_steps=1",
    "--overwrite_output_dir",
]

model_args, data_args, training_args = parser.parse_args_into_dataclasses(args=custom_sysargv)

# Setup logging: INFO on the main process (local_rank -1 = not distributed,
# 0 = first rank), WARN elsewhere. Previously both branches were WARN, so the
# logger.info call below was never shown.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
)
logger.warning(
    "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
    training_args.local_rank,
    training_args.device,
    training_args.n_gpu,
    bool(training_args.local_rank != -1),
    training_args.fp16,
)
logger.info("Training/evaluation parameters %s", training_args)

set_seed(training_args.seed)

# Task metadata is looked up with the (lower-cased) task name.
num_labels = superglue_mcqa_tasks_num_labels[data_args.task_name]
output_mode = superglue_mcqa_output_modes[data_args.task_name]
print("Task:", data_args.task_name, "Labels:", num_labels, ', Output', output_mode)




Task: multirc Labels: 2 , Output classification


In [5]:
# Model config: use config_name when given, otherwise the checkpoint's own config.
config = AutoConfig.from_pretrained(
    model_args.config_name or model_args.model_name_or_path,
    cache_dir=model_args.cache_dir,
    finetuning_task=data_args.task_name,
    num_labels=num_labels,
)

In [6]:
# Tokenizer: use tokenizer_name when given, otherwise the checkpoint's tokenizer.
tokenizer = AutoTokenizer.from_pretrained(
    model_args.tokenizer_name or model_args.model_name_or_path,
    cache_dir=model_args.cache_dir,
)

In [7]:
# Pretrained encoder with a multiple-choice head, built from `config`.
# from_tf is set only when the path looks like a TF checkpoint (contains ".ckpt").
model = AutoModelForMultipleChoice.from_pretrained(
    model_args.model_name_or_path,
    config=config,
    cache_dir=model_args.cache_dir,
    from_tf=".ckpt" in model_args.model_name_or_path,
)

- This IS expected if you are initializing BertForMultipleChoice from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForMultipleChoice from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
def _load_split(split, enabled):
    """Build the MultipleChoiceDataset for `split`, or return None if `enabled` is False.

    All splits share the same data dir, tokenizer, task, sequence length and
    cache policy from `data_args`; only the split differs.
    """
    if not enabled:
        return None
    return MultipleChoiceDataset(
        data_args.data_dir,
        tokenizer=tokenizer,
        task=data_args.task_name,
        max_seq_length=data_args.max_seq_length,
        overwrite_cache=data_args.overwrite_cache,
        mode=split,
    )


train_dataset = _load_split(Split.train, training_args.do_train)
eval_dataset = _load_split(Split.dev, training_args.do_eval)
test_dataset = _load_split(Split.test, training_args.do_predict)

In [9]:
def compute_metrics(p: EvalPrediction) -> Dict:
    """Convert raw Trainer predictions into task metrics.

    Args:
        p: EvalPrediction holding `predictions` (model outputs) and `label_ids`.

    Returns:
        The metric dict from `superglue_mcqa_compute_metrics` for the current task.

    Raises:
        ValueError: if the global `output_mode` is neither "classification"
            nor "regression" (previously this surfaced as UnboundLocalError).
    """
    if output_mode == "classification":
        # Pick the highest-scoring column per row; presumably logits are
        # (n_examples, n_choices) from AutoModelForMultipleChoice -- confirm.
        preds = np.argmax(p.predictions, axis=1)
    elif output_mode == "regression":
        preds = np.squeeze(p.predictions)
    else:
        raise ValueError(f"Unsupported output_mode: {output_mode!r}")
    return superglue_mcqa_compute_metrics(data_args.task_name, preds, p.label_ids)

In [10]:
# Initialize our Trainer: fine-tune `model` on train_dataset and score it on
# eval_dataset with compute_metrics, using the parsed TrainingArguments.
trainer = Trainer(
    args=training_args,
    model=model,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

In [11]:
if training_args.do_train:
    # model_path is passed only when model_name_or_path is a local directory
    # (the Trainer may then restore state from it -- confirm for the installed version).
    trainer.train(
        model_path=(
            model_args.model_name_or_path
            if os.path.isdir(model_args.model_name_or_path)
            else None
        )
    )
    trainer.save_model()
    # Also write the tokenizer to output_dir (from the main process only) so the
    # saved directory can be shared on huggingface.co/models as-is.
    if trainer.is_world_master():
        tokenizer.save_pretrained(training_args.output_dir)

HBox(children=(FloatProgress(value=0.0, description='Epoch', max=10.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=847.0, style=ProgressStyle(description_wi…



{'loss': 1.474056603773585, 'learning_rate': 9.749704840613932e-05, 'epoch': 0.2502951593860685, 'step': 212}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=302.0, style=ProgressStyle(description_w…


{'eval_loss': 0.12400425761777735, 'eval_acc': 0.572079536039768, 'epoch': 0.2502951593860685, 'step': 212}
{'loss': 0.0, 'learning_rate': 9.499409681227863e-05, 'epoch': 0.500590318772137, 'step': 424}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=302.0, style=ProgressStyle(description_w…


{'eval_loss': 0.12400425761777735, 'eval_acc': 0.572079536039768, 'epoch': 0.500590318772137, 'step': 424}
{'loss': 0.0, 'learning_rate': 9.249114521841796e-05, 'epoch': 0.7508854781582054, 'step': 636}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=302.0, style=ProgressStyle(description_w…


{'eval_loss': 0.12400425761777735, 'eval_acc': 0.572079536039768, 'epoch': 0.7508854781582054, 'step': 636}



HBox(children=(FloatProgress(value=0.0, description='Iteration', max=847.0, style=ProgressStyle(description_wi…

{'loss': 25.058962264150942, 'learning_rate': 8.998819362455726e-05, 'epoch': 1.001180637544274, 'step': 848}


HBox(children=(FloatProgress(value=0.0, description='Evaluation', max=302.0, style=ProgressStyle(description_w…


{'eval_loss': 0.12400425761777735, 'eval_acc': 0.572079536039768, 'epoch': 1.001180637544274, 'step': 848}








KeyboardInterrupt: 

In [None]:
import torch

# Quick device smoke test (e.g. after interrupting training): move a tiny tensor
# to the GPU when one is available. The unguarded `.cuda()` used previously
# raised on CPU-only kernels and broke Restart & Run All.
cuda_probe = torch.tensor([1.0, 3.0, 4.0])
if torch.cuda.is_available():
    cuda_probe = cuda_probe.cuda()
cuda_probe

In [None]:
# Evaluate on the dev split and persist the metrics as "key = value" lines.
eval_results = {}
if training_args.do_eval:
    logger.info("*** Evaluate ***")
    eval_result = trainer.evaluate(eval_dataset=eval_dataset)

    # Use data_args.task_name (the task the dataset was built with) rather than
    # eval_dataset.args.task_name: the dataset is constructed without an `args`
    # object, so that attribute is not guaranteed to exist.
    output_eval_file = os.path.join(
        training_args.output_dir, f"eval_results_{data_args.task_name}.txt"
    )
    if trainer.is_world_master():
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results %s *****", data_args.task_name)
            for key, value in eval_result.items():
                logger.info("  %s = %s", key, value)
                writer.write("%s = %s\n" % (key, value))

    eval_results.update(eval_result)

In [None]:
# Display the collected evaluation metrics (empty dict when do_eval is off).
eval_results

In [None]:
import torch

# Scratch tensor: 16 standard-normal samples for the reshape check below.
a = torch.randn((16,))

In [None]:
# Display the scratch tensor.
a

In [None]:
# Scratch check: view the tensor as rows of 4 (-1 lets torch infer the row count).
a.view(-1, 4)

In [16]:
# Inspect the first training feature (example_id, token ids, ...).
train_dataset[0]

InputFeatures(example_id='train-1-0-0', input_ids=[[101, 1109, 11158, 1261, 1282, 1113, 1357, 1542, 117, 1103, 4598, 1113, 1428, 1853, 119, 5630, 117, 2530, 1273, 10448, 4884, 1132, 9829, 1112, 5307, 25856, 131, 107, 4673, 1759, 1118, 1260, 1643, 2047, 3970, 1128, 1104, 5618, 1105, 18111, 1240, 1713, 1106, 5475, 1103, 12374, 118, 118, 1114, 2423, 6014, 4133, 119, 1135, 1110, 8431, 1193, 117, 1191, 25731, 26610, 1193, 117, 1694, 119, 107, 1249, 2382, 1807, 117, 1103, 107, 1121, 1139, 2504, 2044, 1493, 107, 1226, 1110, 2566, 4673, 112, 188, 1236, 1106, 8698, 1124, 6186, 119, 2966, 2256, 1133, 4673, 112, 188, 4217, 2458, 1122, 1112, 1625, 1950, 136, 1124, 4664, 1674, 1136, 107, 19795, 1122, 1106, 170, 4055, 1187, 1122, 1108, 1136, 23056, 107, 1105, 117, 1112, 2382, 1807, 117, 1833, 1177, 3059, 1156, 1294, 1185, 2305, 20748, 1191, 4673, 1108, 1103, 3283, 22448, 1260, 23566, 1197, 1115, 1117, 4217, 3548, 1119, 1110, 119, 16752, 14840, 3381, 1103, 13206, 9800, 2315, 3669, 1187, 1124, 6186, 1

In [17]:
# Inspect the second training feature; compare its input_ids with index 0.
train_dataset[1]

InputFeatures(example_id='train-1-0-1', input_ids=[[101, 1109, 11158, 1261, 1282, 1113, 1357, 1542, 117, 1103, 4598, 1113, 1428, 1853, 119, 5630, 117, 2530, 1273, 10448, 4884, 1132, 9829, 1112, 5307, 25856, 131, 107, 4673, 1759, 1118, 1260, 1643, 2047, 3970, 1128, 1104, 5618, 1105, 18111, 1240, 1713, 1106, 5475, 1103, 12374, 118, 118, 1114, 2423, 6014, 4133, 119, 1135, 1110, 8431, 1193, 117, 1191, 25731, 26610, 1193, 117, 1694, 119, 107, 1249, 2382, 1807, 117, 1103, 107, 1121, 1139, 2504, 2044, 1493, 107, 1226, 1110, 2566, 4673, 112, 188, 1236, 1106, 8698, 1124, 6186, 119, 2966, 2256, 1133, 4673, 112, 188, 4217, 2458, 1122, 1112, 1625, 1950, 136, 1124, 4664, 1674, 1136, 107, 19795, 1122, 1106, 170, 4055, 1187, 1122, 1108, 1136, 23056, 107, 1105, 117, 1112, 2382, 1807, 117, 1833, 1177, 3059, 1156, 1294, 1185, 2305, 20748, 1191, 4673, 1108, 1103, 3283, 22448, 1260, 23566, 1197, 1115, 1117, 4217, 3548, 1119, 1110, 119, 16752, 14840, 3381, 1103, 13206, 9800, 2315, 3669, 1187, 1124, 6186, 1

In [18]:
# Inspect the third training feature; compare its input_ids with indices 0 and 1.
train_dataset[2]

InputFeatures(example_id='train-1-0-2', input_ids=[[101, 1109, 11158, 1261, 1282, 1113, 1357, 1542, 117, 1103, 4598, 1113, 1428, 1853, 119, 5630, 117, 2530, 1273, 10448, 4884, 1132, 9829, 1112, 5307, 25856, 131, 107, 4673, 1759, 1118, 1260, 1643, 2047, 3970, 1128, 1104, 5618, 1105, 18111, 1240, 1713, 1106, 5475, 1103, 12374, 118, 118, 1114, 2423, 6014, 4133, 119, 1135, 1110, 8431, 1193, 117, 1191, 25731, 26610, 1193, 117, 1694, 119, 107, 1249, 2382, 1807, 117, 1103, 107, 1121, 1139, 2504, 2044, 1493, 107, 1226, 1110, 2566, 4673, 112, 188, 1236, 1106, 8698, 1124, 6186, 119, 2966, 2256, 1133, 4673, 112, 188, 4217, 2458, 1122, 1112, 1625, 1950, 136, 1124, 4664, 1674, 1136, 107, 19795, 1122, 1106, 170, 4055, 1187, 1122, 1108, 1136, 23056, 107, 1105, 117, 1112, 2382, 1807, 117, 1833, 1177, 3059, 1156, 1294, 1185, 2305, 20748, 1191, 4673, 1108, 1103, 3283, 22448, 1260, 23566, 1197, 1115, 1117, 4217, 3548, 1119, 1110, 119, 16752, 14840, 3381, 1103, 13206, 9800, 2315, 3669, 1187, 1124, 6186, 1