In [163]:
# Redirect the Hugging Face cache to a project-local directory so a shared
# machine's common cache is not filled up. This has to run BEFORE anything
# imports `transformers` (including, possibly, the project utility modules
# below): the library resolves its cache directory from the environment when
# it is first imported, so exporting the variable afterwards has no effect.
import os
os.environ["TRANSFORMERS_CACHE"] = "../huggingface_cache/"

from vocab_mismatch_utils import *
from data_formatter_utils import *
from datasets import DatasetDict
from datasets import Dataset
from datasets import load_dataset
import transformers
import pandas as pd

from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from torch.nn import CrossEntropyLoss

# Load modules, mainly huggingface basic model handlers.
# Make sure you install huggingface and other packages properly.
from collections import Counter
import json

from nltk.tokenize import TweetTokenizer
from sklearn.metrics import classification_report
from sklearn.feature_extraction import DictVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline
from sklearn.metrics import matthews_corrcoef
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import logging
# Module-level logger shared by every cell in this notebook.
logger = logging.getLogger(__name__)

import random
import sys
from dataclasses import dataclass, field
from typing import Optional
import torch
import argparse
import numpy as np
import pandas as pd
from datasets import load_dataset, load_metric
from datasets import Dataset
from datasets import DatasetDict
from tqdm import tqdm, trange

import transformers
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    EvalPrediction,
    HfArgumentParser,
    PretrainedConfig,
    Trainer,
    TrainingArguments,
    default_data_collator,
    set_seed,
    EarlyStoppingCallback
)
from transformers.trainer_utils import is_main_process, EvaluationStrategy

In [14]:
def get_dataset(inoculation_data_path, eval_data_path=None, 
                inoculation_step_sample_size=1.0, 
                eval_sample_limit=-1, seed=42):
    """
    Load the train / validation splits used for one inoculation step.

    The input format is inferred from ``inoculation_data_path`` itself:

    * a path whose last "."-separated component is ``tsv`` is read as a
      tab-separated file (``eval_data_path`` must then point at the dev TSV);
    * any other path that contains at least one "." is treated as a
      huggingface ``DatasetDict`` directory written with ``save_to_disk``
      (note that a relative path such as ``../data/foo`` qualifies only
      because of the dots in ``..``);
    * a path without any "." is interpreted as the name of a hub dataset,
      which is not implemented.

    Args:
        inoculation_data_path: train TSV file, or saved ``DatasetDict`` dir.
        eval_data_path: dev TSV file; only used (and required) for TSV input.
        inoculation_step_sample_size: fraction of the training split to keep;
            the kept row count is ``int(len(train) * fraction)``.
        eval_sample_limit: maximum number of validation rows to keep, or -1
            for no limit. Only applied to saved-to-disk datasets; the TSV
            branch returns the full dev file.
        seed: seed for the train sampling / split shuffling.

    Returns:
        dict with keys ``"train"`` and ``"validation"``, each a huggingface
        ``Dataset``.

    Raises:
        ValueError: TSV input was given without ``eval_data_path``.
        NotImplementedError: the path looks like a hub dataset name.
    """
    path_parts = inoculation_data_path.split(".")

    if path_parts[-1] == "tsv":
        if eval_data_path is None:
            # Fail with a clear message instead of whatever pandas raises
            # when asked to read `None`.
            raise ValueError(
                "eval_data_path is required when inoculation_data_path is a .tsv file."
            )
        train_df = pd.read_csv(inoculation_data_path, delimiter="\t")
        eval_df = pd.read_csv(eval_data_path, delimiter="\t")
        inoculation_step_sample_size = int(len(train_df) * inoculation_step_sample_size)
        logger.info("***** Inoculation Sample Count: %s *****", inoculation_step_sample_size)
        # Sampling without replacement; with the default fraction of 1.0 this
        # is simply a seeded shuffle of the whole training file.
        inoculation_train_df = train_df.sample(n=inoculation_step_sample_size,
                                               replace=False,
                                               random_state=seed)
        return {
            "train": Dataset.from_pandas(inoculation_train_df),
            "validation": Dataset.from_pandas(eval_df),
        }

    if len(path_parts) > 1:
        logger.info("***** Loading pre-loaded datasets from the disk directly! *****")
        saved_splits = DatasetDict.load_from_disk(inoculation_data_path)
        inoculation_step_sample_size = int(
            len(saved_splits["train"]) * inoculation_step_sample_size)
        logger.info("***** Inoculation Sample Count: %s *****", inoculation_step_sample_size)
        train_split = saved_splits["train"].shuffle(seed=seed).select(
            range(inoculation_step_sample_size))
        # Shuffle first so that a limited validation set is a random subset.
        # The shuffled / limited split is what gets returned: previously the
        # untouched split was returned and eval_sample_limit had no effect.
        validation_split = saved_splits["validation"].shuffle(seed=seed)
        if eval_sample_limit != -1:
            validation_split = validation_split.select(range(eval_sample_limit))
        return {"train": train_split, "validation": validation_split}

    logger.info(f"***** Loading downloaded huggingface datasets: {inoculation_data_path}! *****")
    raise NotImplementedError(
        f"Loading hub dataset '{inoculation_data_path}' by name is not supported; "
        "pass a .tsv file or a saved DatasetDict directory."
    )

In [213]:
# Per-task text columns: (first sentence column, second sentence column).
# The second entry is None for single-sentence tasks.
TASK_CONFIG = {
    "sst3": ("text", None),
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "snli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence")
}
# The vocabulary lookup and the tokenizer are built from the same vocab file,
# so every token the tokenizer emits can be looked up in `original_vocab`.
original_vocab = load_bert_vocab("../data-files/bert_vocab.txt")
original_tokenizer = transformers.BertTokenizer(
    vocab_file="../data-files/bert_vocab.txt")
max_length = 128  # cut-off applied to the combined input text before tokenizing
per_device_train_batch_size = 128
per_device_eval_batch_size = 128
no_cuda = True  # force CPU even when a GPU is visible
device = torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
n_gpu = torch.cuda.device_count() if not no_cuda else 1 # 1 means just on cpu
seed = 42
lr = 1e-3
num_train_epochs = 1000
task_name = "qnli"
sentence1_key, sentence2_key = TASK_CONFIG[task_name]

# Seed every RNG the notebook touches so runs are repeatable.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if n_gpu > 0 and not no_cuda:
    # Was `args.seed`: there is no `args` object in this notebook, so this
    # line raised NameError as soon as CUDA was enabled.
    torch.cuda.manual_seed_all(seed)

In [214]:
# Configure root logging so the INFO-level progress messages below are shown.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)

# Clean train / dev TSV files for the selected task.
datasets = get_dataset(
    inoculation_data_path=f"../data-files/{task_name}/{task_name}-train.tsv",
    eval_data_path=f"../data-files/{task_name}/{task_name}-dev.tsv",
)

# Sanity-check the split sizes that were actually loaded.
logger.info("***** Train Sample Count (Verify): %s *****", len(datasets["train"]))
logger.info("***** Valid Sample Count (Verify): %s *****", len(datasets["validation"]))

02/22/2021 13:27:05 - INFO - __main__ - ***** Inoculation Sample Count: 104743 *****
02/22/2021 13:27:05 - INFO - __main__ - ***** Train Sample Count (Verify): 104743 *****
02/22/2021 13:27:05 - INFO - __main__ - ***** Valid Sample Count (Verify): 5463 *****


In [215]:
# Turn every training example into a bag-of-words count vector over the
# BERT vocabulary, collecting the gold labels alongside.
# NOTE: the [:max_length] slice is applied to the raw combined string, so it
# limits the number of *characters*, not the number of wordpiece tokens.
train_input_features = []
train_label_ids = []
for train_example in tqdm(datasets["train"]):
    combined_text = train_example[sentence1_key]
    if sentence2_key is not None:
        # Sentence-pair tasks: join both texts around a literal [SEP] marker.
        combined_text = combined_text + " [SEP] " + train_example[sentence2_key]
    token_counts = Counter(original_tokenizer.tokenize(combined_text[:max_length]))

    count_vector = torch.zeros(len(original_vocab))
    for token, occurrences in token_counts.items():
        count_vector[original_vocab[token]] += occurrences

    train_input_features.append(count_vector)
    train_label_ids.append(train_example["label"])

100%|██████████| 104743/104743 [01:40<00:00, 1046.33it/s]


In [216]:
# Stack the per-example count vectors into one 2-D float matrix
# (one row per example, one column per vocabulary entry).
# `torch.tensor(existing_tensor)` copy-constructs the whole matrix and emits a
# UserWarning; `.to(torch.float)` returns the stacked tensor as-is when it is
# already float32, so nothing is copied.
train_input_features = torch.stack(train_input_features, dim=0).to(torch.float)
train_label_ids = torch.tensor(train_label_ids, dtype=torch.long)
# Pair each feature row with its label for use in a DataLoader.
train_data = TensorDataset(train_input_features, train_label_ids)

  


In [217]:
# Same bag-of-words featurisation as for training, applied to the clean
# validation split.
# NOTE: the [:max_length] slice is applied to the raw combined string, so it
# limits the number of *characters*, not the number of wordpiece tokens.
validation_input_features = []
validation_label_ids = []
for validation_example in tqdm(datasets["validation"]):
    combined_text = validation_example[sentence1_key]
    if sentence2_key is not None:
        # Sentence-pair tasks: join both texts around a literal [SEP] marker.
        combined_text = combined_text + " [SEP] " + validation_example[sentence2_key]
    token_counts = Counter(original_tokenizer.tokenize(combined_text[:max_length]))

    count_vector = torch.zeros(len(original_vocab))
    for token, occurrences in token_counts.items():
        count_vector[original_vocab[token]] += occurrences

    validation_input_features.append(count_vector)
    validation_label_ids.append(validation_example["label"])

100%|██████████| 5463/5463 [00:04<00:00, 1130.13it/s]


In [218]:
# Stack the per-example count vectors into one 2-D float matrix
# (one row per example, one column per vocabulary entry).
# `torch.tensor(existing_tensor)` copy-constructs the whole matrix and emits a
# UserWarning; `.to(torch.float)` returns the stacked tensor as-is when it is
# already float32, so nothing is copied.
validation_input_features = torch.stack(validation_input_features, dim=0).to(torch.float)
validation_label_ids = torch.tensor(validation_label_ids, dtype=torch.long)
# Pair each feature row with its label for use in a DataLoader.
validation_data = TensorDataset(validation_input_features, validation_label_ids)

  


In [219]:
# Batch sizes are specified per device, so the effective batch size is the
# per-device size multiplied by the device count.
train_sampler = RandomSampler(train_data)  # fresh random order each epoch
train_dataloader = DataLoader(
    dataset=train_data,
    batch_size=per_device_train_batch_size * n_gpu,
    sampler=train_sampler,
)
# Validation is iterated in its stored order.
validation_dataloader = DataLoader(
    dataset=validation_data,
    batch_size=per_device_eval_batch_size * n_gpu,
    shuffle=False,
)

In [220]:
class BOWClassifier(nn.Module):
    """Linear classifier over bag-of-words vectors.

    A single ``nn.Linear`` layer maps a ``vocab_size``-dimensional input to
    ``num_labels`` logits.
    """

    def __init__(self, num_labels, vocab_size):
        super().__init__()
        self.classifier = nn.Linear(in_features=vocab_size, out_features=num_labels)

    def forward(self, x, labels=None):
        """Score a batch of inputs.

        Returns the logits alone when ``labels`` is None; otherwise returns
        ``(loss, logits)`` where ``loss`` is the mean cross-entropy of the
        logits against ``labels``.
        """
        logits = self.classifier(x)
        if labels is None:
            return logits
        return F.cross_entropy(logits, labels), logits

In [221]:
# One output unit per distinct label seen in the validation set; the input
# width is the vocabulary size used to build the bag-of-words features.
model = BOWClassifier(len(validation_label_ids.unique()), len(original_vocab))
# Put the parameters on the same device the batches are moved to. Without
# this the CUDA path fed GPU inputs to CPU parameters. On CPU it is a no-op.
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
if n_gpu > 0 and not no_cuda:
    # Replicate across the visible GPUs; the optimizer keeps pointing at the
    # same underlying parameters.
    model = torch.nn.DataParallel(model)

In [222]:
def _predict_labels(eval_model, dataloader, target_device, clear_cuda_cache=False):
    """Run `eval_model` over `dataloader` without gradients.

    Puts the model in eval mode (the caller decides whether to switch back)
    and returns ``(gold_label_ids, predicted_label_ids)`` as 1-D numpy arrays
    concatenated over all batches.
    """
    gold_batches = []
    predicted_batches = []
    eval_model.eval()
    with torch.no_grad():
        for eval_features, eval_labels in dataloader:
            if clear_cuda_cache:
                torch.cuda.empty_cache()
            eval_features = eval_features.to(target_device)
            eval_labels = eval_labels.to(target_device)

            _, eval_logits = eval_model(eval_features, labels=eval_labels)
            class_probs = F.softmax(eval_logits, dim=-1).detach().cpu().numpy()
            predicted_batches.append(np.argmax(class_probs, axis=1))
            gold_batches.append(eval_labels.to('cpu').numpy())
    return (np.concatenate(gold_batches, axis=0),
            np.concatenate(predicted_batches, axis=0))


use_cuda = torch.cuda.is_available() and not no_cuda
eval_every_n_steps = 500  # evaluate on the validation set every N optimizer steps

global_step = 0
for _epoch in range(int(num_train_epochs)):

    model.train()
    for input_features, label_ids in train_dataloader:
        if use_cuda:
            torch.cuda.empty_cache()

        # `device` is the CPU unless CUDA is both available and enabled, in
        # which case this moves the batch to the GPU; on CPU it is a no-op.
        input_features = input_features.to(device)
        label_ids = label_ids.to(device)

        loss, _ = model(input_features, labels=label_ids)

        if n_gpu > 1:
            loss = loss.mean() # mean() to average on multi-gpu.
        loss.backward()

        optimizer.step()
        model.zero_grad()

        if global_step % eval_every_n_steps == 0:
            logger.info("***** Evaluation Interval Hit *****")
            # Evaluation lives in a helper with its own variable names, so it
            # no longer rebinds the training loop's batch variables.
            all_label_ids, all_predictions = _predict_labels(
                model, validation_dataloader, device, clear_cuda_cache=use_cuda)
            result_to_save = classification_report(
                all_label_ids, all_predictions, digits=5, output_dict=True)
            print(classification_report(all_label_ids, all_predictions, digits=5))
            print("Macro-F1: ", result_to_save["macro avg"]["f1-score"])
            # Switch back: previously the model stayed in eval mode for the
            # remainder of the epoch after every evaluation.
            model.train()

        global_step += 1

02/22/2021 13:31:14 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.49442   0.27868   0.35645      2702
           1    0.50533   0.72112   0.59424      2761

    accuracy                        0.50229      5463
   macro avg    0.49987   0.49990   0.47534      5463
weighted avg    0.49993   0.50229   0.47663      5463

Macro-F1:  0.47534468493177295


02/22/2021 13:31:17 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.56423   0.52665   0.54479      2702
           1    0.56511   0.60196   0.58295      2761

    accuracy                        0.56471      5463
   macro avg    0.56467   0.56430   0.56387      5463
weighted avg    0.56468   0.56471   0.56408      5463

Macro-F1:  0.5638733057850795


02/22/2021 13:31:20 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.55943   0.52258   0.54038      2702
           1    0.56108   0.59725   0.57860      2761

    accuracy                        0.56031      5463
   macro avg    0.56025   0.55991   0.55949      5463
weighted avg    0.56026   0.56031   0.55969      5463

Macro-F1:  0.5594857695329023


02/22/2021 13:31:23 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.54843   0.57624   0.56199      2702
           1    0.56364   0.53568   0.54930      2761

    accuracy                        0.55574      5463
   macro avg    0.55604   0.55596   0.55565      5463
weighted avg    0.55612   0.55574   0.55558      5463

Macro-F1:  0.5556480206553427


02/22/2021 13:31:26 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.54890   0.57328   0.56083      2702
           1    0.56342   0.53894   0.55091      2761

    accuracy                        0.55592      5463
   macro avg    0.55616   0.55611   0.55587      5463
weighted avg    0.55624   0.55592   0.55581      5463

Macro-F1:  0.5558662801156329


02/22/2021 13:31:29 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.54359   0.59771   0.56936      2702
           1    0.56380   0.50887   0.53493      2761

    accuracy                        0.55281      5463
   macro avg    0.55370   0.55329   0.55215      5463
weighted avg    0.55381   0.55281   0.55196      5463

Macro-F1:  0.5521480359791383


02/22/2021 13:31:34 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.55082   0.50740   0.52822      2702
           1    0.55245   0.59507   0.57297      2761

    accuracy                        0.55171      5463
   macro avg    0.55164   0.55124   0.55060      5463
weighted avg    0.55165   0.55171   0.55084      5463

Macro-F1:  0.5505974477656235


02/22/2021 13:31:38 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.54702   0.55329   0.55014      2702
           1    0.55788   0.55161   0.55473      2761

    accuracy                        0.55244      5463
   macro avg    0.55245   0.55245   0.55243      5463
weighted avg    0.55251   0.55244   0.55246      5463

Macro-F1:  0.5524319548070381


02/22/2021 13:31:42 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.54587   0.57476   0.55994      2702
           1    0.56112   0.53205   0.54620      2761

    accuracy                        0.55318      5463
   macro avg    0.55349   0.55341   0.55307      5463
weighted avg    0.55357   0.55318   0.55300      5463

Macro-F1:  0.5530702446296019


02/22/2021 13:31:46 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.54817   0.51591   0.53155      2702
           1    0.55205   0.58385   0.56751      2761

    accuracy                        0.55025      5463
   macro avg    0.55011   0.54988   0.54953      5463
weighted avg    0.55013   0.55025   0.54972      5463

Macro-F1:  0.5495297908218135


02/22/2021 13:31:49 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.54419   0.56514   0.55447      2702
           1    0.55777   0.53676   0.54707      2761

    accuracy                        0.55080      5463
   macro avg    0.55098   0.55095   0.55077      5463
weighted avg    0.55105   0.55080   0.55073      5463

Macro-F1:  0.5507657843499157


02/22/2021 13:31:53 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.54161   0.56847   0.55471      2702
           1    0.55615   0.52916   0.54232      2761

    accuracy                        0.54860      5463
   macro avg    0.54888   0.54881   0.54851      5463
weighted avg    0.54896   0.54860   0.54845      5463

Macro-F1:  0.5485145755464773


02/22/2021 13:31:56 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.54024   0.58882   0.56349      2702
           1    0.55878   0.50960   0.53306      2761

    accuracy                        0.54878      5463
   macro avg    0.54951   0.54921   0.54827      5463
weighted avg    0.54961   0.54878   0.54811      5463

Macro-F1:  0.5482702696193118


02/22/2021 13:32:00 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.54233   0.55477   0.54848      2702
           1    0.55428   0.54183   0.54799      2761

    accuracy                        0.54823      5463
   macro avg    0.54830   0.54830   0.54823      5463
weighted avg    0.54837   0.54823   0.54823      5463

Macro-F1:  0.5482334350610969


02/22/2021 13:32:03 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.54219   0.57069   0.55608      2702
           1    0.55708   0.52843   0.54238      2761

    accuracy                        0.54933      5463
   macro avg    0.54964   0.54956   0.54923      5463
weighted avg    0.54972   0.54933   0.54915      5463

Macro-F1:  0.549227816826354


02/22/2021 13:32:07 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.53940   0.54219   0.54079      2702
           1    0.54969   0.54690   0.54829      2761

    accuracy                        0.54457      5463
   macro avg    0.54454   0.54455   0.54454      5463
weighted avg    0.54460   0.54457   0.54458      5463

Macro-F1:  0.5445416754126284


02/22/2021 13:32:10 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.53738   0.56403   0.55038      2702
           1    0.55158   0.52481   0.53786      2761

    accuracy                        0.54421      5463
   macro avg    0.54448   0.54442   0.54412      5463
weighted avg    0.54455   0.54421   0.54405      5463

Macro-F1:  0.5441205568170027


02/22/2021 13:32:13 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.53738   0.55329   0.54522      2702
           1    0.54979   0.53386   0.54171      2761

    accuracy                        0.54347      5463
   macro avg    0.54359   0.54358   0.54347      5463
weighted avg    0.54366   0.54347   0.54345      5463

Macro-F1:  0.5434675355067201


02/22/2021 13:32:17 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.53206   0.54663   0.53925      2702
           1    0.54410   0.52952   0.53671      2761

    accuracy                        0.53798      5463
   macro avg    0.53808   0.53808   0.53798      5463
weighted avg    0.53815   0.53798   0.53797      5463

Macro-F1:  0.5379793101118933


02/22/2021 13:32:20 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.53368   0.56588   0.54931      2702
           1    0.54850   0.51612   0.53182      2761

    accuracy                        0.54073      5463
   macro avg    0.54109   0.54100   0.54056      5463
weighted avg    0.54117   0.54073   0.54047      5463

Macro-F1:  0.5405620309454928


02/22/2021 13:32:23 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.53368   0.57476   0.55346      2702
           1    0.54994   0.50851   0.52842      2761

    accuracy                        0.54128      5463
   macro avg    0.54181   0.54164   0.54094      5463
weighted avg    0.54190   0.54128   0.54080      5463

Macro-F1:  0.5409361921641647


KeyboardInterrupt: 

In [223]:
# Setup logging (has no effect if logging was already configured by an
# earlier cell, but lets this cell run on its own).
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
# The corrupted variant of the task is stored as a saved-to-disk DatasetDict.
corrupt_datasets = get_dataset(f"../data-files/{task_name}-corrupted")
# Report the sizes of the dataset that was just loaded. These lines used to
# read from `datasets` (the clean data), so they verified nothing about the
# corrupted splits.
logger.info("***** Train Sample Count (Verify): %s *****", len(corrupt_datasets["train"]))
logger.info("***** Valid Sample Count (Verify): %s *****", len(corrupt_datasets["validation"]))

02/22/2021 13:32:28 - INFO - __main__ - ***** Loading pre-loaded datasets from the disk directly! *****
02/22/2021 13:32:29 - INFO - __main__ - ***** Inoculation Sample Count: 104743 *****
Loading cached shuffled indices for dataset at ../data-files/qnli-corrupted/train/cache-1554dfc1dcdbbd75.arrow
Loading cached shuffled indices for dataset at ../data-files/qnli-corrupted/validation/cache-04fc1ef0f3bc5901.arrow
02/22/2021 13:32:29 - INFO - __main__ - ***** Train Sample Count (Verify): 104743 *****
02/22/2021 13:32:29 - INFO - __main__ - ***** Valid Sample Count (Verify): 5463 *****


In [224]:
# Bag-of-words featurisation of the corrupted validation split, using the
# same vocabulary and tokenizer as for the clean data.
# NOTE: the [:max_length] slice is applied to the raw combined string, so it
# limits the number of *characters*, not the number of wordpiece tokens.
corrupt_validation_input_features = []
corrupt_validation_label_ids = []
for corrupt_example in tqdm(corrupt_datasets["validation"]):
    combined_text = corrupt_example[sentence1_key]
    if sentence2_key is not None:
        # Sentence-pair tasks: join both texts around a literal [SEP] marker.
        combined_text = combined_text + " [SEP] " + corrupt_example[sentence2_key]
    token_counts = Counter(original_tokenizer.tokenize(combined_text[:max_length]))

    count_vector = torch.zeros(len(original_vocab))
    for token, occurrences in token_counts.items():
        count_vector[original_vocab[token]] += occurrences

    corrupt_validation_input_features.append(count_vector)
    corrupt_validation_label_ids.append(corrupt_example["label"])

100%|██████████| 5463/5463 [00:06<00:00, 821.28it/s]


In [225]:
# Stack the per-example count vectors into one 2-D float matrix
# (one row per example, one column per vocabulary entry).
# `torch.tensor(existing_tensor)` copy-constructs the whole matrix and emits a
# UserWarning; `.to(torch.float)` returns the stacked tensor as-is when it is
# already float32, so nothing is copied.
corrupt_validation_input_features = torch.stack(corrupt_validation_input_features, dim=0).to(torch.float)
corrupt_validation_label_ids = torch.tensor(corrupt_validation_label_ids, dtype=torch.long)
# Pair each feature row with its label for use in a DataLoader.
corrupt_validation_data = TensorDataset(corrupt_validation_input_features, corrupt_validation_label_ids)

  


In [226]:
# Iterate the corrupted validation set in stored order, using the evaluation
# batch size scaled by the device count.
corrupt_validation_dataloader = DataLoader(
    dataset=corrupt_validation_data,
    batch_size=per_device_eval_batch_size * n_gpu,
    shuffle=False,
)

In [227]:
# Score the trained classifier on the corrupted validation set and print a
# per-class report plus the macro-averaged F1.
logger.info("***** Evaluation With Corrupt Data *****")
model.eval()
on_gpu = torch.cuda.is_available() and not no_cuda
predicted_batches = []
gold_batches = []
with torch.no_grad():
    for corrupt_features, corrupt_labels in corrupt_validation_dataloader:
        if on_gpu:
            torch.cuda.empty_cache()
            corrupt_features = corrupt_features.to(device)
            corrupt_labels = corrupt_labels.to(device)

        # The model returns (loss, logits) when labels are given; only the
        # logits are needed here.
        batch_logits = model(corrupt_features, labels=corrupt_labels)[1]
        class_probs = F.softmax(batch_logits, dim=-1).detach().cpu().numpy()
        predicted_batches.append(np.argmax(class_probs, axis=1))
        gold_batches.append(corrupt_labels.to('cpu').numpy())

# Despite its name, `all_logits` holds the predicted class ids.
all_logits = np.concatenate(predicted_batches, axis=0)
all_label_ids = np.concatenate(gold_batches, axis=0)
result_to_save = classification_report(all_label_ids, all_logits, digits=5, output_dict=True)
print(classification_report(all_label_ids, all_logits, digits=5))
print("Macro-F1: ", result_to_save["macro avg"]["f1-score"])

02/22/2021 13:32:47 - INFO - __main__ - ***** Evaluation With Corrupt Data *****


              precision    recall  f1-score   support

           0    0.48950   0.92339   0.63983      2702
           1    0.43443   0.05759   0.10169      2761

    accuracy                        0.48581      5463
   macro avg    0.46196   0.49049   0.37076      5463
weighted avg    0.46167   0.48581   0.36785      5463

Macro-F1:  0.37076026696164865
