In [23]:
from vocab_mismatch_utils import *
from data_formatter_utils import *
from datasets import DatasetDict
from datasets import Dataset
from datasets import load_dataset
import transformers
import pandas as pd
from collections import OrderedDict 

from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from torch.nn import CrossEntropyLoss

# Load modules, mainly huggingface basic model handlers.
# Make sure you install huggingface and other packages properly.
from collections import Counter
import json

from nltk.tokenize import TweetTokenizer
from sklearn.metrics import classification_report
from sklearn.feature_extraction import DictVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline
from sklearn.metrics import matthews_corrcoef
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import logging
logger = logging.getLogger(__name__)

import os
os.environ["TRANSFORMERS_CACHE"] = "../huggingface_cache/" # Not overload common dir 
                                                           # if run in shared resources.

import random
import sys
from dataclasses import dataclass, field
from typing import Optional
import torch
import argparse
import numpy as np
import pandas as pd
from datasets import load_dataset, load_metric
from datasets import Dataset
from datasets import DatasetDict
from tqdm import tqdm, trange

import transformers
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    EvalPrediction,
    HfArgumentParser,
    PretrainedConfig,
    Trainer,
    TrainingArguments,
    default_data_collator,
    set_seed,
    EarlyStoppingCallback
)
from transformers.trainer_utils import is_main_process, EvaluationStrategy

In [85]:
def get_dataset(inoculation_data_path, eval_data_path=None, test_data_path=None,
                inoculation_step_sample_size=1.0, 
                eval_sample_limit=-1, seed=42):
    """
    Load train / validation (/ test) splits and subsample the training split.

    Two input formats are supported:

    * ``inoculation_data_path`` ends in ``.tsv``: train, validation
      (``eval_data_path``) and optionally test (``test_data_path``) are read
      as tab-separated files and converted to huggingface ``Dataset`` objects.
    * ``inoculation_data_path`` is a directory written with
      ``save_to_disk``: it must already contain ``"train"`` and
      ``"validation"`` splits, so ``eval_data_path`` / ``test_data_path``
      are not used.

    Loading a dataset by hub name (e.g. ``"sst3"``) is not implemented.

    Args:
        inoculation_data_path: training ``.tsv`` file, or saved dataset dir.
        eval_data_path: validation ``.tsv`` file (required for tsv input).
        test_data_path: optional test ``.tsv`` file (tsv input only).
        inoculation_step_sample_size: fraction of the training split to keep.
        eval_sample_limit: if not -1, keep at most this many validation
            examples, chosen by a seeded shuffle.
        seed: seed for the training subsample and the validation shuffle.

    Returns:
        A plain ``dict`` of split name -> huggingface ``Dataset`` with
        ``"train"`` and ``"validation"``, plus ``"test"`` when a test tsv
        was given.

    Raises:
        ValueError: tsv input without ``eval_data_path``.
        NotImplementedError: path is neither a ``.tsv`` file nor a directory.
    """
    if inoculation_data_path.endswith(".tsv"):
        if eval_data_path is None:
            raise ValueError("eval_data_path is required when loading .tsv files.")
        train_df = pd.read_csv(inoculation_data_path, delimiter="\t")
        eval_df = pd.read_csv(eval_data_path, delimiter="\t")
        sample_count = int(len(train_df) * inoculation_step_sample_size)
        logger.info("***** Inoculation Sample Count: %s *****", sample_count)
        # Seeded sampling without replacement: a given (seed, fraction) pair
        # always selects the same training subset.
        inoculation_train_df = train_df.sample(n=sample_count,
                                               replace=False,
                                               random_state=seed)
        if eval_sample_limit != -1:
            # reset_index keeps the column set identical to the unlimited case
            # (from_pandas would otherwise add the shuffled index as a column).
            eval_df = (eval_df
                       .sample(n=min(eval_sample_limit, len(eval_df)), random_state=seed)
                       .reset_index(drop=True))
        splits = {
            "train": Dataset.from_pandas(inoculation_train_df),
            "validation": Dataset.from_pandas(eval_df),
        }
        if test_data_path is not None:
            splits["test"] = Dataset.from_pandas(
                pd.read_csv(test_data_path, delimiter="\t"))
        return splits

    if os.path.isdir(inoculation_data_path):
        logger.info("***** Loading pre-loaded datasets from the disk directly! *****")
        saved = DatasetDict.load_from_disk(inoculation_data_path)
        sample_count = int(len(saved["train"]) * inoculation_step_sample_size)
        logger.info("***** Inoculation Sample Count: %s *****", sample_count)
        train_split = saved["train"].shuffle(seed=seed).select(range(sample_count))
        validation_split = saved["validation"]
        if eval_sample_limit != -1:
            # Shuffle first so the kept subset is a seeded random sample.
            validation_split = validation_split.shuffle(seed=seed)
            validation_split = validation_split.select(
                range(min(eval_sample_limit, len(validation_split))))
        return {"train": train_split, "validation": validation_split}

    logger.info(f"***** Loading downloaded huggingface datasets: {inoculation_data_path}! *****")
    raise NotImplementedError(
        f"Cannot load {inoculation_data_path!r}: expected a .tsv file or a "
        "directory saved with save_to_disk; loading by hub name is not implemented."
    )

In [86]:
# Per-task (first, second) text-column names; the second is None for
# single-sentence tasks.
TASK_CONFIG = {
    "sst3": ("text", None),
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "snli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence")
}
# Bag-of-words features only need a basic tokenizer, not BERT's wordpiece one.
# NOTE(review): ModifiedBasicTokenizer comes from a star import of the project
# utils; its exact splitting rules are not visible here.
modified_basic_tokenizer = ModifiedBasicTokenizer()
max_length = 128  # tokens kept per example when building BOW features
per_device_train_batch_size = 128
per_device_eval_batch_size = 128
no_cuda = True
use_cuda = torch.cuda.is_available() and not no_cuda
device = torch.device("cuda" if use_cuda else "cpu")
# Number of devices a batch is spread over (1 on CPU). It must never be 0
# because the DataLoaders multiply the per-device batch size by it.
n_gpu = torch.cuda.device_count() if use_cuda else 1
seed = 42
lr = 1e-3
num_train_epochs = 10
task_name = "sst3"
sentence1_key, sentence2_key = TASK_CONFIG[task_name]

# Seed every RNG in play: Python, NumPy, PyTorch CPU and (if used) CUDA.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if use_cuda:
    torch.cuda.manual_seed_all(seed)

In [87]:
# Configure root logging so the `logger.info` calls in this notebook are shown.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
# The SST ternary task is stored under a differently named data directory.
data_file_name = task_name if task_name != "sst3" else "sst-tenary"
data_dir = f"../data-files/{data_file_name}"
datasets = get_dataset(f"{data_dir}/{data_file_name}-train.tsv",
                       f"{data_dir}/{data_file_name}-dev.tsv",
                       f"{data_dir}/{data_file_name}-test.tsv")
# Sanity-check split sizes (lazy %-args: the logger formats the message).
for split_label, split_name in (("Train", "train"), ("Valid", "validation"), ("Test", "test")):
    logger.info("***** %s Sample Count (Verify): %s *****", split_label, len(datasets[split_name]))

03/14/2021 23:57:02 - INFO - __main__ - ***** Inoculation Sample Count: 159274 *****
03/14/2021 23:57:02 - INFO - __main__ - ***** Train Sample Count (Verify): 159274 *****
03/14/2021 23:57:02 - INFO - __main__ - ***** Valid Sample Count (Verify): 1100 *****
03/14/2021 23:57:02 - INFO - __main__ - ***** Test Sample Count (Verify): 2210 *****


In [97]:
# Build the token -> feature-column vocabulary for the bag-of-words model.
# Indices are assigned in first-seen order: the train split first, then
# (unless `train_data_only`) the validation and test splits.
# NOTE(review): with train_data_only=False, tokens that only occur in
# validation/test get columns whose classifier weights never receive a
# gradient (the feature is always 0 during training), so they keep their
# random initial values yet still contribute to predictions at eval time.
train_data_only = False
vocab_splits = ["train"] if train_data_only else ["train", "validation", "test"]

original_vocab = OrderedDict()
for split_name in vocab_splits:
    if split_name not in datasets:
        continue
    for example in tqdm(datasets[split_name]):
        if sentence2_key is None:
            sentence_combined = example[sentence1_key]
        else:
            sentence_combined = example[sentence1_key] + " [SEP] " + example[sentence2_key]
        for token in modified_basic_tokenizer.tokenize(sentence_combined):
            if token not in original_vocab:
                # New tokens get the next free index (== current vocab size).
                original_vocab[token] = len(original_vocab)
vocab_index = len(original_vocab)

100%|██████████| 159274/159274 [00:18<00:00, 8413.97it/s]
100%|██████████| 1100/1100 [00:00<00:00, 4097.65it/s]
100%|██████████| 2210/2210 [00:00<00:00, 4186.32it/s]


In [98]:
# Turn each training example into a dense bag-of-words count vector over
# `original_vocab`, counting only the first `max_length` tokens.
# NOTE(review): one dense float vector of len(original_vocab) per example, so
# memory grows as n_examples * vocab_size; consider sparse features if large.
train_input_features, train_label_ids = [], []
for idx, row in enumerate(tqdm(datasets["train"])):
    counts = torch.zeros(len(original_vocab))
    text = (row[sentence1_key] if sentence2_key is None
            else row[sentence1_key] + " [SEP] " + row[sentence2_key])
    kept_tokens = modified_basic_tokenizer.tokenize(text)[:max_length]
    if idx % 50000 == 0:
        # Periodic sanity print of the raw input text.
        print("Example sentence: " + text)
    for tok in kept_tokens:
        # Every training token is in the vocabulary (it was built from train).
        counts[original_vocab[tok]] += 1
    train_input_features.append(counts)
    train_label_ids.append(row["label"])

  0%|          | 317/159274 [00:00<00:59, 2652.06it/s]

Example sentence: Surprisingly, considering that Baird is a former film editor, the movie is rather choppy.


 32%|███▏      | 50471/159274 [00:13<00:29, 3718.51it/s]

Example sentence: achronological


 63%|██████▎   | 100714/159274 [00:26<00:16, 3649.76it/s]

Example sentence: Show


 95%|█████████▍| 150740/159274 [00:40<00:02, 3691.99it/s]

Example sentence: picked me up ,


100%|██████████| 159274/159274 [00:43<00:00, 3692.34it/s]


In [99]:
# Pack the per-example count vectors into one (n_examples, vocab_size) float
# matrix and the labels into an int64 tensor (CrossEntropyLoss targets).
# The isinstance guard keeps the cell re-runnable: after the first run the
# names already hold tensors, which torch.stack would reject. as_tensor /
# .float() avoid torch.tensor's copy-construct-from-tensor warning.
if isinstance(train_input_features, list):
    train_input_features = torch.stack(train_input_features, dim=0)
train_input_features = train_input_features.float()
train_label_ids = torch.as_tensor(train_label_ids, dtype=torch.long)
train_data = TensorDataset(train_input_features, train_label_ids)

  


In [100]:
# Featurize the validation split with the vocabulary built earlier, counting
# only the first `max_length` tokens; tokens absent from `original_vocab`
# are skipped.
validation_input_features, validation_label_ids = [], []
for example_row in tqdm(datasets["validation"]):
    token_counts = torch.zeros(len(original_vocab))
    if sentence2_key is None:
        combined_text = example_row[sentence1_key]
    else:
        combined_text = example_row[sentence1_key] + " [SEP] " + example_row[sentence2_key]
    for tok in modified_basic_tokenizer.tokenize(combined_text)[:max_length]:
        if tok in original_vocab:
            token_counts[original_vocab[tok]] += 1
    validation_input_features.append(token_counts)
    validation_label_ids.append(example_row["label"])

100%|██████████| 1100/1100 [00:00<00:00, 1962.63it/s]


In [101]:
# Pack validation count vectors into a float matrix and labels into an int64
# tensor. The isinstance guard makes the cell safe to re-run; as_tensor /
# .float() avoid torch.tensor's copy-construct-from-tensor warning.
if isinstance(validation_input_features, list):
    validation_input_features = torch.stack(validation_input_features, dim=0)
validation_input_features = validation_input_features.float()
validation_label_ids = torch.as_tensor(validation_label_ids, dtype=torch.long)
validation_data = TensorDataset(validation_input_features, validation_label_ids)

  


In [102]:
# Batches span all devices, so the per-device batch size is scaled by the
# device count. Clamp it to at least 1: n_gpu is 0 when CUDA is requested
# but no device is present, which would otherwise give batch_size=0.
num_devices = max(n_gpu, 1)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler,
                              batch_size=per_device_train_batch_size * num_devices)
validation_dataloader = DataLoader(validation_data,
                                   batch_size=per_device_eval_batch_size * num_devices,
                                   shuffle=False)

In [103]:
class BOWClassifier(nn.Module):
    """Linear (softmax-regression) classifier over bag-of-words vectors.

    A single ``nn.Linear`` maps a ``vocab_size``-dim feature vector to
    ``num_labels`` logits.
    """

    def __init__(self, num_labels, vocab_size):
        super().__init__()
        # Attribute name kept as `classifier` so state_dict keys are stable.
        self.classifier = nn.Linear(vocab_size, num_labels)

    def forward(self, x, labels=None):
        """Return logits, or ``(loss, logits)`` when ``labels`` is given.

        ``loss`` is the mean cross-entropy of the logits against ``labels``.
        """
        scores = self.classifier(x)
        if labels is None:
            return scores
        criterion = CrossEntropyLoss()
        return criterion(scores, labels), scores

In [104]:
# Size the output layer so every label id is a valid CrossEntropyLoss target.
# Counting distinct validation labels (as before) under-sizes it whenever a
# class is missing from validation; use the largest id seen in train or
# validation instead (labels are used as 0-based class indices).
num_labels = int(torch.cat([train_label_ids, validation_label_ids]).max().item()) + 1
model = BOWClassifier(num_labels, len(original_vocab))
# Parameters must live where the training loop sends its inputs. Move before
# building the optimizer so it references the on-device parameters.
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
if device.type == "cuda":
    model = torch.nn.DataParallel(model)

In [105]:
# Train with Adam for `num_train_epochs` epochs, printing a validation
# classification report every `eval_interval` optimizer steps (the first one
# right after step 0).
eval_interval = 500
global_step = 0
for _ in range(int(num_train_epochs)):
    model.train()
    for input_features, label_ids in train_dataloader:
        # `.to(device)` is a no-op when `device` is the CPU.
        input_features = input_features.to(device)
        label_ids = label_ids.to(device)

        loss, _ = model(input_features, labels=label_ids)
        if n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu.
        loss.backward()

        optimizer.step()
        model.zero_grad()

        if global_step % eval_interval == 0:
            logger.info("***** Evaluation Interval Hit *****")
            model.eval()
            eval_preds = []
            eval_labels = []
            with torch.no_grad():
                # Separate names so the outer training-loop variables are
                # not shadowed.
                for eval_features, eval_label_ids in validation_dataloader:
                    eval_logits = model(eval_features.to(device))
                    # Softmax is monotonic, so argmax of raw logits is the
                    # same predicted class.
                    eval_preds.append(eval_logits.argmax(dim=-1).cpu().numpy())
                    eval_labels.append(eval_label_ids.numpy())

            eval_preds = np.concatenate(eval_preds, axis=0)
            eval_labels = np.concatenate(eval_labels, axis=0)
            result_to_save = classification_report(eval_labels, eval_preds, digits=5, output_dict=True)
            print(classification_report(eval_labels, eval_preds, digits=5))
            print("Macro-F1: ", result_to_save["macro avg"]["f1-score"])
            # Evaluation switched the model to eval mode; restore training
            # mode before the next optimizer step.
            model.train()

        global_step += 1

03/15/2021 00:04:34 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.41014   0.62383   0.49490       428
           1    0.39572   0.16667   0.23455       444
           2    0.23664   0.27193   0.25306       228

    accuracy                        0.36636      1100
   macro avg    0.34750   0.35414   0.32750      1100
weighted avg    0.36836   0.36636   0.33969      1100

Macro-F1:  0.3275040827127372


03/15/2021 00:04:36 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.69416   0.47196   0.56189       428
           1    0.62500   0.77703   0.69277       444
           2    0.31518   0.35526   0.33402       228

    accuracy                        0.57091      1100
   macro avg    0.54478   0.53475   0.52956      1100
weighted avg    0.58769   0.57091   0.56749      1100

Macro-F1:  0.5295610729628291


03/15/2021 00:04:39 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.66754   0.59579   0.62963       428
           1    0.66414   0.78829   0.72091       444
           2    0.35602   0.29825   0.32458       228

    accuracy                        0.61182      1100
   macro avg    0.56257   0.56078   0.55837      1100
weighted avg    0.60160   0.61182   0.60324      1100

Macro-F1:  0.5583727502383646


03/15/2021 00:04:41 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.66919   0.61916   0.64320       428
           1    0.67692   0.79279   0.73029       444
           2    0.35326   0.28509   0.31553       228

    accuracy                        0.62000      1100
   macro avg    0.56646   0.56568   0.56301      1100
weighted avg    0.60683   0.62000   0.61044      1100

Macro-F1:  0.5630094401697351


03/15/2021 00:04:44 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.66430   0.65654   0.66040       428
           1    0.69261   0.78153   0.73439       444
           2    0.35227   0.27193   0.30693       228

    accuracy                        0.62727      1100
   macro avg    0.56973   0.57000   0.56724      1100
weighted avg    0.61106   0.62727   0.61700      1100

Macro-F1:  0.5672405858085295


03/15/2021 00:04:48 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.67990   0.64019   0.65945       428
           1    0.68582   0.80631   0.74120       444
           2    0.35429   0.27193   0.30769       228

    accuracy                        0.63091      1100
   macro avg    0.57334   0.57281   0.56945      1100
weighted avg    0.61480   0.63091   0.61954      1100

Macro-F1:  0.5694465286366087


03/15/2021 00:04:52 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.68550   0.65187   0.66826       428
           1    0.69246   0.80631   0.74506       444
           2    0.38068   0.29386   0.33168       228

    accuracy                        0.64000      1100
   macro avg    0.58621   0.58401   0.58167      1100
weighted avg    0.62513   0.64000   0.62950      1100

Macro-F1:  0.5816679578068906


03/15/2021 00:04:57 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.68599   0.66355   0.67458       428
           1    0.69941   0.80180   0.74711       444
           2    0.37853   0.29386   0.33086       228

    accuracy                        0.64273      1100
   macro avg    0.58798   0.58640   0.58419      1100
weighted avg    0.62768   0.64273   0.63262      1100

Macro-F1:  0.5841876320756892


03/15/2021 00:05:04 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.70398   0.66121   0.68193       428
           1    0.69364   0.81081   0.74766       444
           2    0.39106   0.30702   0.34398       228

    accuracy                        0.64818      1100
   macro avg    0.59623   0.59301   0.59119      1100
weighted avg    0.63495   0.64818   0.63841      1100

Macro-F1:  0.5911905354085288


03/15/2021 00:05:10 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.69927   0.66822   0.68339       428
           1    0.69591   0.80405   0.74608       444
           2    0.39888   0.31140   0.34975       228

    accuracy                        0.64909      1100
   macro avg    0.59802   0.59456   0.59308      1100
weighted avg    0.63565   0.64909   0.63954      1100

Macro-F1:  0.5930760899244399


03/15/2021 00:05:15 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.70218   0.67757   0.68966       428
           1    0.70472   0.80631   0.75210       444
           2    0.39665   0.31140   0.34889       228

    accuracy                        0.65364      1100
   macro avg    0.60118   0.59843   0.59688      1100
weighted avg    0.63988   0.65364   0.64423      1100

Macro-F1:  0.5968834538814255


03/15/2021 00:05:17 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.71139   0.65654   0.68287       428
           1    0.70541   0.79279   0.74655       444
           2    0.38350   0.34649   0.36406       228

    accuracy                        0.64727      1100
   macro avg    0.60010   0.59861   0.59783      1100
weighted avg    0.64101   0.64727   0.64249      1100

Macro-F1:  0.5978254699156306


03/15/2021 00:05:19 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.70707   0.65421   0.67961       428
           1    0.69472   0.79955   0.74346       444
           2    0.38342   0.32456   0.35154       228

    accuracy                        0.64455      1100
   macro avg    0.59507   0.59277   0.59154      1100
weighted avg    0.63500   0.64455   0.63738      1100

Macro-F1:  0.5915370302868367


03/15/2021 00:05:22 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.70443   0.66822   0.68585       428
           1    0.70766   0.79054   0.74681       444
           2    0.39899   0.34649   0.37089       228

    accuracy                        0.65091      1100
   macro avg    0.60369   0.60175   0.60118      1100
weighted avg    0.64243   0.65091   0.64517      1100

Macro-F1:  0.6011839494541615


03/15/2021 00:05:24 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.70976   0.67991   0.69451       428
           1    0.69574   0.80856   0.74792       444
           2    0.40805   0.31140   0.35323       228

    accuracy                        0.65545      1100
   macro avg    0.60451   0.59996   0.59855      1100
weighted avg    0.64156   0.65545   0.64533      1100

Macro-F1:  0.5985537457897466


03/15/2021 00:05:26 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.71608   0.66589   0.69007       428
           1    0.69709   0.80856   0.74870       444
           2    0.41176   0.33772   0.37108       228

    accuracy                        0.65545      1100
   macro avg    0.60831   0.60406   0.60328      1100
weighted avg    0.64534   0.65545   0.64762      1100

Macro-F1:  0.6032845118300387


03/15/2021 00:05:29 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.71867   0.65654   0.68620       428
           1    0.69608   0.79955   0.74423       444
           2    0.39196   0.34211   0.36534       228

    accuracy                        0.64909      1100
   macro avg    0.60224   0.59940   0.59859      1100
weighted avg    0.64183   0.64909   0.64312      1100

Macro-F1:  0.5985923551651977


03/15/2021 00:05:31 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.70812   0.65187   0.67883       428
           1    0.70935   0.78604   0.74573       444
           2    0.37850   0.35526   0.36652       228

    accuracy                        0.64455      1100
   macro avg    0.59866   0.59772   0.59702      1100
weighted avg    0.64030   0.64455   0.64110      1100

Macro-F1:  0.5970248165396298


03/15/2021 00:05:33 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.70781   0.65654   0.68121       428
           1    0.70876   0.78378   0.74439       444
           2    0.38208   0.35526   0.36818       228

    accuracy                        0.64545      1100
   macro avg    0.59955   0.59853   0.59793      1100
weighted avg    0.64068   0.64545   0.64183      1100

Macro-F1:  0.5979263220439691


03/15/2021 00:05:36 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.70812   0.65187   0.67883       428
           1    0.71341   0.79054   0.75000       444
           2    0.37383   0.35088   0.36199       228

    accuracy                        0.64545      1100
   macro avg    0.59846   0.59776   0.59694      1100
weighted avg    0.64097   0.64545   0.64189      1100

Macro-F1:  0.5969410223381885


03/15/2021 00:05:39 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.72051   0.65654   0.68704       428
           1    0.69608   0.79955   0.74423       444
           2    0.38000   0.33333   0.35514       228

    accuracy                        0.64727      1100
   macro avg    0.59886   0.59647   0.59547      1100
weighted avg    0.64007   0.64727   0.64133      1100

Macro-F1:  0.5954721841822127


03/15/2021 00:05:41 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.72063   0.64486   0.68064       428
           1    0.70020   0.78378   0.73964       444
           2    0.36364   0.35088   0.35714       228

    accuracy                        0.64000      1100
   macro avg    0.59482   0.59317   0.59247      1100
weighted avg    0.63839   0.64000   0.63740      1100

Macro-F1:  0.592474241039859


03/15/2021 00:05:44 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.71899   0.66355   0.69016       428
           1    0.70243   0.78153   0.73987       444
           2    0.37441   0.34649   0.35991       228

    accuracy                        0.64545      1100
   macro avg    0.59861   0.59719   0.59665      1100
weighted avg    0.64088   0.64545   0.64177      1100

Macro-F1:  0.5966463035816281


03/15/2021 00:05:46 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.72379   0.66121   0.69109       428
           1    0.70423   0.78829   0.74389       444
           2    0.38208   0.35526   0.36818       228

    accuracy                        0.64909      1100
   macro avg    0.60336   0.60159   0.60105      1100
weighted avg    0.64506   0.64909   0.64547      1100

Macro-F1:  0.6010526628486246


03/15/2021 00:05:49 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.73387   0.63785   0.68250       428
           1    0.70341   0.79054   0.74443       444
           2    0.36681   0.36842   0.36761       228

    accuracy                        0.64364      1100
   macro avg    0.60136   0.59894   0.59818      1100
weighted avg    0.64549   0.64364   0.64223      1100

Macro-F1:  0.5981825137892708


Evaluate on corrupted validation data

In [112]:
# Logging setup (basicConfig does nothing if the root logger already has
# handlers, e.g. from an earlier cell).
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
corrupt_method = "S3"
data_file_name = task_name if task_name != "sst3" else "sst-tenary"
corrupt_datasets = get_dataset(f"../data-files/{data_file_name}-corrupted-{corrupt_method}")
# Verify the *corrupted* splits (previously this logged the clean `datasets`).
logger.info("***** Train Sample Count (Verify): %s *****", len(corrupt_datasets["train"]))
logger.info("***** Valid Sample Count (Verify): %s *****", len(corrupt_datasets["validation"]))

03/15/2021 00:18:06 - INFO - __main__ - ***** Loading pre-loaded datasets from the disk directly! *****
03/15/2021 00:18:06 - INFO - __main__ - ***** Inoculation Sample Count: 159274 *****
03/15/2021 00:18:06 - INFO - __main__ - ***** Train Sample Count (Verify): 159274 *****
03/15/2021 00:18:06 - INFO - __main__ - ***** Valid Sample Count (Verify): 1100 *****


In [113]:
# Featurize the corrupted validation split with the vocabulary built from the
# clean data. Corruption can introduce tokens missing from `original_vocab`;
# as in the clean validation cell, those are skipped instead of raising
# KeyError.
corrupt_validation_input_features = []
corrupt_validation_label_ids = []
for example in tqdm(corrupt_datasets["validation"]):
    bow_feature = torch.zeros(len(original_vocab))
    if sentence2_key is None:
        sentence_combined = example[sentence1_key]
    else:
        sentence_combined = example[sentence1_key] + " [SEP] " + example[sentence2_key]
    sentence_tokens = modified_basic_tokenizer.tokenize(sentence_combined)
    sentence_tokens = sentence_tokens[:max_length]
    for t in sentence_tokens:
        token_index = original_vocab.get(t)
        if token_index is not None:
            bow_feature[token_index] += 1
    corrupt_validation_input_features.append(bow_feature)
    corrupt_validation_label_ids.append(example["label"])

100%|██████████| 1100/1100 [00:00<00:00, 1194.33it/s]


In [114]:
# Pack corrupted-validation count vectors into a float matrix and labels into
# an int64 tensor. The isinstance guard makes the cell safe to re-run;
# as_tensor / .float() avoid torch.tensor's copy-construct warning.
if isinstance(corrupt_validation_input_features, list):
    corrupt_validation_input_features = torch.stack(corrupt_validation_input_features, dim=0)
corrupt_validation_input_features = corrupt_validation_input_features.float()
corrupt_validation_label_ids = torch.as_tensor(corrupt_validation_label_ids, dtype=torch.long)
corrupt_validation_data = TensorDataset(corrupt_validation_input_features, corrupt_validation_label_ids)

  


In [115]:
# Clamp the device multiplier to at least 1 (n_gpu can be 0 when CUDA is
# requested but unavailable, which would give batch_size=0).
corrupt_validation_dataloader = DataLoader(corrupt_validation_data,
                                           batch_size=per_device_eval_batch_size * max(n_gpu, 1),
                                           shuffle=False)

In [116]:
# Evaluate the trained model on the corrupted validation split.
logger.info("***** Evaluation With Corrupt Data *****")
model.eval()
# NOTE: despite its name, `all_logits` holds predicted class ids (argmax).
# The names are kept because they are notebook-level results.
all_logits = []
all_label_ids = []
with torch.no_grad():
    for input_features, label_ids in corrupt_validation_dataloader:
        # `.to(device)` is a no-op on CPU; inputs go where the model is
        # expected to live.
        logits = model(input_features.to(device))
        # Softmax is monotonic, so argmax of raw logits is the same class.
        all_logits.append(logits.argmax(dim=-1).cpu().numpy())
        all_label_ids.append(label_ids.numpy())

all_logits = np.concatenate(all_logits, axis=0)
all_label_ids = np.concatenate(all_label_ids, axis=0)
result_to_save = classification_report(all_label_ids, all_logits, digits=5, output_dict=True)
print(classification_report(all_label_ids, all_logits, digits=5))
print("Macro-F1: ", result_to_save["macro avg"]["f1-score"])

03/15/2021 00:18:12 - INFO - __main__ - ***** Evaluation With Corrupt Data *****


              precision    recall  f1-score   support

           0    0.42553   0.04673   0.08421       428
           1    0.38776   0.04279   0.07708       444
           2    0.21016   0.92544   0.34253       228

    accuracy                        0.22727      1100
   macro avg    0.34115   0.33832   0.16794      1100
weighted avg    0.36564   0.22727   0.13488      1100

Macro-F1:  0.16794070045110934


Evaluate on another type of corrupted data