In [1]:
from vocab_mismatch_utils import *
from data_formatter_utils import *
from datasets import DatasetDict
from datasets import Dataset
from datasets import load_dataset
import transformers
import pandas as pd
from collections import OrderedDict
import operator

from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from torch.nn import CrossEntropyLoss

# Load modules, mainly huggingface basic model handlers.
# Make sure you install huggingface and other packages properly.
from collections import Counter
import json

from nltk.tokenize import TweetTokenizer
from sklearn.metrics import classification_report
from sklearn.feature_extraction import DictVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline
from sklearn.metrics import matthews_corrcoef
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import statistics

import logging
logger = logging.getLogger(__name__)

import os
# Point the Hugging Face cache at a project-local directory so that runs on
# shared machines do not fill up the common cache location.
# NOTE(review): `transformers` is already imported earlier in this cell; if the
# library reads TRANSFORMERS_CACHE at import time, this assignment comes too
# late to have any effect -- confirm, and move it above the first import if so.
os.environ["TRANSFORMERS_CACHE"] = "../huggingface_cache/"

import random
import sys
from dataclasses import dataclass, field
from typing import Optional
import torch
import argparse
import numpy as np
import pandas as pd
from datasets import load_dataset, load_metric
from datasets import Dataset
from datasets import DatasetDict
from tqdm import tqdm, trange

import transformers
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    EvalPrediction,
    HfArgumentParser,
    PretrainedConfig,
    Trainer,
    TrainingArguments,
    default_data_collator,
    set_seed,
    EarlyStoppingCallback
)
from transformers.trainer_utils import is_main_process, EvaluationStrategy

import matplotlib.pyplot as plt
# Global figure styling: every plot produced after this point uses
# Times New Roman.
plt.rcParams["font.family"] = "Times New Roman"
font = {'family' : 'Times New Roman',
        'size'   : 30}
# Applies both entries of `font`; the family repeats the rcParams assignment
# above, the size (30) is only set here.
plt.rc('font', **font)

#### Setups

In [63]:
# Task to run; used below as a key into TASK_CONFIG and to build the data paths.
task_name = "cola"

In [64]:
def get_dataset(inoculation_data_path, eval_data_path=None, test_data_path=None,
                inoculation_step_sample_size=1.0, 
                eval_sample_limit=-1, seed=42):
    """
    Load the splits used by this notebook and sub-sample the training split.

    Two input formats are supported:

    * a ``.tsv`` training file (``eval_data_path`` is then required and
      ``test_data_path`` is optional); the files are read with pandas and
      converted to huggingface ``Dataset`` objects;
    * a directory previously written with ``DatasetDict.save_to_disk``; in
      that case ``eval_data_path`` / ``test_data_path`` are not used.

    Args:
        inoculation_data_path: path of the training ``.tsv`` file or of the
            saved ``DatasetDict`` directory.
        eval_data_path: path of the validation ``.tsv`` file (tsv mode only).
        test_data_path: path of the test ``.tsv`` file (tsv mode only). When
            omitted, the returned dict has no ``"test"`` entry.
        inoculation_step_sample_size: fraction of the training split to keep.
        eval_sample_limit: maximum number of validation examples to keep when
            loading from disk; ``-1`` keeps the validation split untouched.
            Not applied in tsv mode.
        seed: random seed used for sampling / shuffling.

    Returns:
        A plain ``dict`` with ``"train"`` and ``"validation"`` entries (plus
        ``"test"`` in tsv mode when a test file is given), each a huggingface
        ``Dataset``.

    Raises:
        ValueError: in tsv mode when ``eval_data_path`` is missing.
        NotImplementedError: when the path looks like a hub dataset name;
            loading from the hub is not implemented.
    """
    if inoculation_data_path.split(".")[-1] == "tsv":
        if eval_data_path is None:
            raise ValueError("eval_data_path is required when loading from .tsv files.")
        train_df = pd.read_csv(inoculation_data_path, delimiter="\t")
        eval_df = pd.read_csv(eval_data_path, delimiter="\t")
        sample_count = int(len(train_df) * inoculation_step_sample_size)
        logger.info("***** Inoculation Sample Count: %s *****", sample_count)
        # Sampling without replacement; the fixed seed makes the subset
        # reproducible across runs.
        inoculation_train_df = train_df.sample(n=sample_count,
                                               replace=False,
                                               random_state=seed)
        datasets = {
            "train": Dataset.from_pandas(inoculation_train_df),
            "validation": Dataset.from_pandas(eval_df),
        }
        if test_data_path is not None:
            test_df = pd.read_csv(test_data_path, delimiter="\t")
            datasets["test"] = Dataset.from_pandas(test_df)
        return datasets

    # The original heuristic ("the path contains a dot", e.g. "../data-files/x")
    # is kept for backward compatibility; an existing directory is accepted as
    # well so that absolute paths without any dot also work.
    is_saved_on_disk = (len(inoculation_data_path.split(".")) > 1
                        or os.path.isdir(inoculation_data_path))
    if not is_saved_on_disk:
        logger.info(f"***** Loading downloaded huggingface datasets: {inoculation_data_path}! *****")
        raise NotImplementedError(
            "Loading datasets directly from the huggingface hub is not implemented."
        )

    logger.info("***** Loading pre-loaded datasets from the disk directly! *****")
    loaded = DatasetDict.load_from_disk(inoculation_data_path)
    sample_count = int(len(loaded["train"]) * inoculation_step_sample_size)
    logger.info("***** Inoculation Sample Count: %s *****", sample_count)
    train_split = loaded["train"].shuffle(seed=seed).select(range(sample_count))

    validation_split = loaded["validation"]
    if eval_sample_limit != -1:
        # Previously the shuffled / truncated validation split was computed
        # but the untouched one was returned, so the limit had no effect.
        keep = min(eval_sample_limit, len(validation_split))
        validation_split = validation_split.shuffle(seed=seed).select(range(keep))

    return {"train": train_split, "validation": validation_split}

In [65]:
# Maps each task to the dataset column(s) holding its input text as
# (sentence1_key, sentence2_key); sentence2_key is None for single-sentence tasks.
TASK_CONFIG = {
    "sst3": ("text", None),
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "snli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence")
}
# The BoW baselines do not need the BERT tokenizer / vocabulary; the basic
# tokenizer from the project utilities is used instead.
modified_basic_tokenizer = ModifiedBasicTokenizer()
max_length = 128
per_device_train_batch_size = 128
per_device_eval_batch_size = 128
no_cuda = True
device = torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
# Never let n_gpu drop to 0 (CUDA requested but no GPU visible): the data
# loaders use `batch_size * n_gpu`, which would become an invalid batch size.
n_gpu = max(1, torch.cuda.device_count()) if not no_cuda else 1 # 1 also covers the cpu-only case
seed = 42
lr = 1e-3
num_train_epochs = 10
sentence1_key, sentence2_key = TASK_CONFIG[task_name]

# Seed every random number generator the notebook relies on.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available() and not no_cuda:
    # The original referenced an undefined `args.seed` here.
    torch.cuda.manual_seed_all(seed)

In [66]:
# Configure logging once for the notebook, then load the three tsv splits.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
# The SST-3 files are stored under a directory name that differs from the task name.
data_file_name = "sst-tenary" if task_name == "sst3" else task_name
data_dir = f"../data-files/{data_file_name}"
datasets = get_dataset(
    f"{data_dir}/{data_file_name}-train.tsv",
    f"{data_dir}/{data_file_name}-dev.tsv",
    f"{data_dir}/{data_file_name}-test.tsv",
)
# Log the size of every split as a quick sanity check.
train_count = len(datasets["train"])
valid_count = len(datasets["validation"])
test_count = len(datasets["test"])
logger.info("***** Train Sample Count (Verify): %s *****" % train_count)
logger.info("***** Valid Sample Count (Verify): %s *****" % valid_count)
logger.info("***** Test Sample Count (Verify): %s *****" % test_count)

03/29/2021 14:17:15 - INFO - __main__ - ***** Inoculation Sample Count: 8551 *****
03/29/2021 14:17:15 - INFO - __main__ - ***** Train Sample Count (Verify): 8551 *****
03/29/2021 14:17:15 - INFO - __main__ - ***** Valid Sample Count (Verify): 1043 *****
03/29/2021 14:17:15 - INFO - __main__ - ***** Test Sample Count (Verify): 1063 *****


#### BoW preprocessor

In [67]:
def sanity_check_non_empty(sentece):
    """
    Tell whether ``sentece`` holds usable text.

    Returns True only for strings that, once stripped of surrounding
    whitespace, are neither empty nor the literal ``"None"``. Anything that is
    not a string (``None``, NaN from a missing tsv cell, numbers, ...) is
    reported as empty instead of raising.
    """
    if not isinstance(sentece, str):
        return False
    stripped = sentece.strip()
    return stripped != "" and stripped != "None"

# Build the token -> index vocabulary used by the BoW features.
# Indices are handed out in order of first appearance.
train_data_only = False  # True: build the vocabulary from the train split only
vocab_index = 0
original_vocab = OrderedDict()
vocab_split_names = ["train"] if train_data_only else ["train", "validation", "test"]
for split_name in vocab_split_names:
    if split_name not in datasets:
        continue
    for example in tqdm(datasets[split_name]):
        if sentence2_key is None:
            if not sanity_check_non_empty(example[sentence1_key]):
                # Nothing to add for an empty example. The original code
                # fell through and re-used the previous example's sentence
                # (or failed on the very first example).
                continue
            sentence_combined = example[sentence1_key]
        else:
            s1 = example[sentence1_key] if sanity_check_non_empty(example[sentence1_key]) else ""
            s2 = example[sentence2_key] if sanity_check_non_empty(example[sentence2_key]) else ""
            # The separator is tokenized too, so its pieces are part of the vocabulary.
            sentence_combined = s1 + " [SEP] " + s2
        for token in modified_basic_tokenizer.tokenize(sentence_combined):
            if token not in original_vocab:
                original_vocab[token] = vocab_index
                vocab_index += 1

100%|██████████| 8551/8551 [00:01<00:00, 8031.65it/s]
100%|██████████| 1043/1043 [00:00<00:00, 8366.12it/s]
100%|██████████| 1063/1063 [00:00<00:00, 8201.12it/s]


In [68]:
# BoW feature vectors for train split.
# Single-sentence tasks get one count vector of size len(original_vocab);
# sentence-pair tasks get the two per-sentence count vectors concatenated.
vocab_size = len(original_vocab)
train_input_features = []
train_label_ids = []
for (ex_index, example) in enumerate(tqdm(datasets["train"])):
    if sentence2_key is None:
        # Reset per example: an empty sentence must yield an all-zero vector
        # instead of silently inheriting the previous example's text.
        sentence_combined = ""
        if sanity_check_non_empty(example[sentence1_key]):
            sentence_combined = example[sentence1_key]
        if ex_index % 50000 == 0:
            print("Example sentence: " + sentence_combined)
        segments = [sentence_combined]
    else:
        s1 = example[sentence1_key] if sanity_check_non_empty(example[sentence1_key]) else ""
        s2 = example[sentence2_key] if sanity_check_non_empty(example[sentence2_key]) else ""
        if ex_index % 50000 == 0:
            print("Example sentence 1: " + s1)
            print("Example sentence 2: " + s2)
        segments = [s1, s2]

    segment_vectors = []
    for segment in segments:
        segment_vector = torch.zeros(vocab_size)
        segment_tokens = modified_basic_tokenizer.tokenize(segment) if segment else []
        for t in segment_tokens:
            segment_vector[original_vocab[t]] += 1
        segment_vectors.append(segment_vector)
    bow_feature = torch.cat(segment_vectors, dim=-1)

    train_input_features.append(bow_feature)
    train_label_ids.append(example["label"])

# torch.stack already returns a tensor; `.float()` keeps the dtype explicit
# without the copy-construct warning raised by torch.tensor(<tensor>).
train_input_features = torch.stack(train_input_features, dim=0).float()
train_label_ids = torch.tensor(train_label_ids, dtype=torch.long)
train_data = TensorDataset(train_input_features, train_label_ids)

  9%|▉         | 759/8551 [00:00<00:02, 3775.67it/s]

Example sentence: Where all did they go for their holidays?


100%|██████████| 8551/8551 [00:02<00:00, 3760.38it/s]


In [69]:
# BoW feature vectors for validation split.
# Single-sentence tasks get one count vector of size len(original_vocab);
# sentence-pair tasks get the two per-sentence count vectors concatenated.
vocab_size = len(original_vocab)
validation_input_features = []
validation_label_ids = []
for (ex_index, example) in enumerate(tqdm(datasets["validation"])):
    if sentence2_key is None:
        # Reset per example: an empty sentence must yield an all-zero vector
        # instead of silently inheriting the previous example's text.
        sentence_combined = ""
        if sanity_check_non_empty(example[sentence1_key]):
            sentence_combined = example[sentence1_key]
        if ex_index % 50000 == 0:
            print("Example sentence: " + sentence_combined)
        segments = [sentence_combined]
    else:
        s1 = example[sentence1_key] if sanity_check_non_empty(example[sentence1_key]) else ""
        s2 = example[sentence2_key] if sanity_check_non_empty(example[sentence2_key]) else ""
        if ex_index % 50000 == 0:
            print("Example sentence 1: " + s1)
            print("Example sentence 2: " + s2)
        segments = [s1, s2]

    segment_vectors = []
    for segment in segments:
        segment_vector = torch.zeros(vocab_size)
        segment_tokens = modified_basic_tokenizer.tokenize(segment) if segment else []
        for t in segment_tokens:
            # Tokens missing from the vocabulary (possible when it was built
            # from the train split only) have no feature slot and are skipped.
            token_index = original_vocab.get(t)
            if token_index is not None:
                segment_vector[token_index] += 1
        segment_vectors.append(segment_vector)
    bow_feature = torch.cat(segment_vectors, dim=-1)

    validation_input_features.append(bow_feature)
    validation_label_ids.append(example["label"])

# torch.stack already returns a tensor; `.float()` keeps the dtype explicit
# without the copy-construct warning raised by torch.tensor(<tensor>).
validation_input_features = torch.stack(validation_input_features, dim=0).float()
validation_label_ids = torch.tensor(validation_label_ids, dtype=torch.long)
validation_data = TensorDataset(validation_input_features, validation_label_ids)

 62%|██████▏   | 651/1043 [00:00<00:00, 3389.36it/s]

Example sentence: All who lost money in the scam are eligible for the program.


100%|██████████| 1043/1043 [00:00<00:00, 3073.06it/s]


In [70]:
# Data loaders: the effective batch size is the per-device size times n_gpu.
train_batch_size = per_device_train_batch_size * n_gpu
eval_batch_size = per_device_eval_batch_size * n_gpu

# Training batches are drawn in random order through an explicit sampler;
# validation batches keep the dataset order.
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(
    train_data,
    batch_size=train_batch_size,
    sampler=train_sampler,
)
validation_dataloader = DataLoader(
    validation_data,
    shuffle=False,
    batch_size=eval_batch_size,
)

#### BoW Classifier

In [71]:
class BOWClassifier(nn.Module):
    """
    Linear classifier applied directly to bag-of-words feature vectors.

    ``forward`` returns the logits, or ``(loss, logits)`` with a
    cross-entropy loss when ``labels`` are supplied.
    """

    def __init__(self, num_labels, vocab_size):
        super().__init__()
        self.classifier = nn.Linear(in_features=vocab_size,
                                    out_features=num_labels,
                                    bias=True)

    def forward(self, x, labels=None):
        logits = self.classifier(x)
        # Inference path: no labels, no loss.
        if labels is None:
            return logits
        criterion = CrossEntropyLoss()
        return criterion(logits, labels), logits

In [72]:
class MockBERTBOWClassifier(nn.Module):
    """
    Two-layer BoW classifier: a bias-free linear projection followed by tanh
    (the "mock BERT" encoder), then a bias-free linear classification layer.

    Args:
        num_labels: number of output classes.
        vocab_size: size of the input feature vector.
        hidden_dim: width of the intermediate representation. Defaults to 32,
            the value that used to be hard-coded.

    ``forward`` returns the logits, or ``(loss, logits)`` with a
    cross-entropy loss when ``labels`` are supplied.
    """

    def __init__(self, num_labels, vocab_size, hidden_dim=32):
        super(MockBERTBOWClassifier, self).__init__()
        # Layer creation order is kept so that seeded initialisation matches
        # the previous implementation.
        self.mock_bert = nn.Linear(vocab_size, hidden_dim, bias=False)
        self.mock_activation = nn.Tanh()
        self.classifier = nn.Linear(hidden_dim, num_labels, bias=False)

    def forward(self, x, labels=None):
        cls = self.mock_activation(self.mock_bert(x))
        logits = self.classifier(cls)

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits, labels)
            return loss, logits
        return logits

In [73]:
# Hyper-parameter overrides for the BoW run, then model / optimizer creation.
lr = 1e-3
num_train_epochs = 50
# Sentence-pair tasks concatenate two BoW vectors, doubling the input width.
in_dim = len(original_vocab) if sentence2_key is None else len(original_vocab) * 2
# Size the output layer from the labels of both splits rather than from the
# validation split alone, which may miss a class that occurs in training.
# Assumes class ids are the integers 0..K-1, as the cross-entropy loss requires.
bow_num_labels = int(torch.cat([train_label_ids, validation_label_ids]).max().item()) + 1
model = BOWClassifier(bow_num_labels, in_dim)
# The batches are moved to `device` in the training loop, so the model has to
# live there as well (no-op on cpu).
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
if n_gpu > 0 and not no_cuda and torch.cuda.is_available():
    model = torch.nn.DataParallel(model)

#### Main training loop

In [74]:
def _predict_on_dataloader(eval_model, dataloader):
    """
    Run `eval_model` over `dataloader` without gradients.

    Returns two 1-D numpy arrays: the predicted class ids and the gold label
    ids, in dataloader order. The model is put in eval mode for the pass and
    switched back to its previous mode afterwards.
    """
    was_training = eval_model.training
    eval_model.eval()
    predicted, gold = [], []
    with torch.no_grad():
        for eval_features, eval_labels in dataloader:
            if torch.cuda.is_available() and not no_cuda:
                torch.cuda.empty_cache()
            eval_features = eval_features.to(device)
            eval_logits = eval_model(eval_features)
            # argmax over logits; a softmax would not change the winner.
            predicted.append(eval_logits.argmax(dim=-1).cpu().numpy())
            gold.append(eval_labels.cpu().numpy())
    eval_model.train(was_training)
    return np.concatenate(predicted, axis=0), np.concatenate(gold, axis=0)


# Train for `num_train_epochs` epochs and evaluate on the validation split
# every 500 optimisation steps (including step 0), tracking the best scores.
global_step = 0
best_f1 = -1
best_mcc = -1
for _epoch in range(int(num_train_epochs)):

    model.train()
    for input_features, label_ids in train_dataloader:
        if torch.cuda.is_available() and not no_cuda:
            torch.cuda.empty_cache()

        # No-op on cpu.
        input_features = input_features.to(device)
        label_ids = label_ids.to(device)

        optimizer.zero_grad()
        loss, _ = model(input_features, labels=label_ids)

        if n_gpu > 1:
            loss = loss.mean() # mean() to average on multi-gpu.
        loss.backward()
        optimizer.step()

        if global_step % 500 == 0:
            logger.info("***** Evaluation Interval Hit *****")
            # `all_logits` holds predicted class ids (name kept from the
            # original cell). The helper restores train mode afterwards, so
            # the rest of the epoch no longer runs in eval mode.
            all_logits, all_label_ids = _predict_on_dataloader(model, validation_dataloader)
            result_to_save = classification_report(all_label_ids, all_logits, digits=5, output_dict=True)
            print(classification_report(all_label_ids, all_logits, digits=5))
            f1 = result_to_save["macro avg"]["f1-score"]
            print("Macro-F1: ", f1)
            best_f1 = max(best_f1, f1)
            mcc = matthews_corrcoef(all_label_ids, all_logits)
            best_mcc = max(best_mcc, mcc)
            print("MCC: ", mcc)

        global_step += 1
print("Best Macro-F1: ", best_f1)
print("Best MCC: ", best_mcc)

03/29/2021 14:17:28 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.32227   0.72360   0.44593       322
           1    0.72188   0.32039   0.44380       721

    accuracy                        0.44487      1043
   macro avg    0.52207   0.52200   0.44487      1043
weighted avg    0.59851   0.44487   0.44446      1043

Macro-F1:  0.44486852446809977
MCC:  0.044067014240337134


03/29/2021 14:17:30 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.54545   0.01863   0.03604       322
           1    0.69380   0.99307   0.81689       721

    accuracy                        0.69223      1043
   macro avg    0.61963   0.50585   0.42646      1043
weighted avg    0.64800   0.69223   0.57582      1043

Macro-F1:  0.42646068772708823
MCC:  0.0529051568355168


03/29/2021 14:17:31 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.39024   0.04969   0.08815       322
           1    0.69461   0.96533   0.80789       721

    accuracy                        0.68265      1043
   macro avg    0.54243   0.50751   0.44802      1043
weighted avg    0.60065   0.68265   0.58569      1043

Macro-F1:  0.4480237397453669
MCC:  0.03569488815055102


03/29/2021 14:17:33 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.44444   0.09938   0.16244       322
           1    0.70134   0.94452   0.80496       721

    accuracy                        0.68360      1043
   macro avg    0.57289   0.52195   0.48370      1043
weighted avg    0.62203   0.68360   0.60660      1043

Macro-F1:  0.4837005436152212
MCC:  0.07999963096499767


03/29/2021 14:17:34 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.44706   0.11801   0.18673       322
           1    0.70355   0.93481   0.80286       721

    accuracy                        0.68265      1043
   macro avg    0.57530   0.52641   0.49480      1043
weighted avg    0.62436   0.68265   0.61265      1043

Macro-F1:  0.4947955156412572
MCC:  0.08919578997554263


03/29/2021 14:17:36 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.42574   0.13354   0.20331       322
           1    0.70382   0.91956   0.79735       721

    accuracy                        0.67689      1043
   macro avg    0.56478   0.52655   0.50033      1043
weighted avg    0.61797   0.67689   0.61396      1043

Macro-F1:  0.5003319359328111
MCC:  0.08294222652033967


03/29/2021 14:17:37 - INFO - __main__ - ***** Evaluation Interval Hit *****


              precision    recall  f1-score   support

           0    0.42342   0.14596   0.21709       322
           1    0.70494   0.91123   0.79492       721

    accuracy                        0.67498      1043
   macro avg    0.56418   0.52860   0.50600      1043
weighted avg    0.61803   0.67498   0.61653      1043

Macro-F1:  0.5060041997962973
MCC:  0.08568412322809724
Best Macro-F1:  0.5060041997962973
Best MCC:  0.08919578997554263


#### Evaluations with frequency-matched scrambling

In [78]:
# Evaluate the trained model on the frequency-matched scrambled validation set.
# basicConfig does nothing when the root logger is already configured, so
# repeating it here only matters if this cell is run on its own.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
corrupt_method = "matched"
data_file_name = task_name if task_name != "sst3" else "sst-tenary"
corrupt_datasets = get_dataset(f"../data-files/{data_file_name}-corrupted-{corrupt_method}")
# Report the sizes of the corrupted splits that were just loaded (the original
# cell logged the sizes of the clean `datasets` instead).
logger.info("***** Train Sample Count (Verify): %s *****", len(corrupt_datasets["train"]))
logger.info("***** Valid Sample Count (Verify): %s *****", len(corrupt_datasets["validation"]))

# BoW features of the corrupted validation split, built like the clean ones.
vocab_size = len(original_vocab)
corrupt_validation_input_features = []
corrupt_validation_label_ids = []
for (ex_index, example) in enumerate(tqdm(corrupt_datasets["validation"])):
    if sentence2_key is None:
        # Reset per example: an empty sentence must yield an all-zero vector
        # instead of silently inheriting the previous example's text.
        sentence_combined = ""
        if sanity_check_non_empty(example[sentence1_key]):
            sentence_combined = example[sentence1_key]
        if ex_index % 50000 == 0:
            print("Example sentence: " + sentence_combined)
        segments = [sentence_combined]
    else:
        s1 = example[sentence1_key] if sanity_check_non_empty(example[sentence1_key]) else ""
        s2 = example[sentence2_key] if sanity_check_non_empty(example[sentence2_key]) else ""
        if ex_index % 50000 == 0:
            print("Example sentence 1: " + s1)
            print("Example sentence 2: " + s2)
        segments = [s1, s2]

    segment_vectors = []
    for segment in segments:
        segment_vector = torch.zeros(vocab_size)
        segment_tokens = modified_basic_tokenizer.tokenize(segment) if segment else []
        for t in segment_tokens:
            # Scrambled text may contain tokens the vocabulary has never
            # seen; they have no feature slot and are skipped.
            token_index = original_vocab.get(t)
            if token_index is not None:
                segment_vector[token_index] += 1
        segment_vectors.append(segment_vector)
    bow_feature = torch.cat(segment_vectors, dim=-1)

    corrupt_validation_input_features.append(bow_feature)
    corrupt_validation_label_ids.append(example["label"])

# torch.stack already returns a tensor; `.float()` keeps the dtype explicit
# without the copy-construct warning raised by torch.tensor(<tensor>).
corrupt_validation_input_features = torch.stack(corrupt_validation_input_features, dim=0).float()
corrupt_validation_label_ids = torch.tensor(corrupt_validation_label_ids, dtype=torch.long)
corrupt_validation_data = TensorDataset(corrupt_validation_input_features, corrupt_validation_label_ids)
corrupt_validation_dataloader = DataLoader(corrupt_validation_data, batch_size=per_device_eval_batch_size*n_gpu, shuffle=False)

logger.info("***** Evaluation With Corrupt Data *****")
model.eval()
all_logits = []  # predicted class ids per batch (name kept from the original cell)
all_label_ids = []
with torch.no_grad():
    for input_features, label_ids in corrupt_validation_dataloader:
        if torch.cuda.is_available() and not no_cuda:
            torch.cuda.empty_cache()

        input_features = input_features.to(device)  # no-op on cpu
        logits = model(input_features)
        # argmax over logits; a softmax would not change the winner.
        all_logits.append(logits.argmax(dim=-1).cpu().numpy())
        all_label_ids.append(label_ids.cpu().numpy())

all_logits = np.concatenate(all_logits, axis=0)
all_label_ids = np.concatenate(all_label_ids, axis=0)
result_to_save = classification_report(all_label_ids, all_logits, digits=5, output_dict=True)
print(classification_report(all_label_ids, all_logits, digits=5))
print("Macro-F1: ", result_to_save["macro avg"]["f1-score"])
mcc = matthews_corrcoef(all_label_ids, all_logits)
print("MCC: ", mcc)

03/29/2021 14:18:31 - INFO - __main__ - ***** Loading pre-loaded datasets from the disk directly! *****
03/29/2021 14:18:31 - INFO - __main__ - ***** Inoculation Sample Count: 8551 *****
Loading cached shuffled indices for dataset at ../data-files/cola-corrupted-matched/train/cache-f4202775805e0baf.arrow
Loading cached shuffled indices for dataset at ../data-files/cola-corrupted-matched/validation/cache-411cd289c16c3d03.arrow
03/29/2021 14:18:31 - INFO - __main__ - ***** Train Sample Count (Verify): 8551 *****
03/29/2021 14:18:31 - INFO - __main__ - ***** Valid Sample Count (Verify): 1043 *****
 26%|██▌       | 267/1043 [00:00<00:00, 2663.66it/s]

Example sentence: can bill lifted whether and . sushi but includes on . critics the


100%|██████████| 1043/1043 [00:00<00:00, 3237.66it/s]
03/29/2021 14:18:31 - INFO - __main__ - ***** Evaluation With Corrupt Data *****


              precision    recall  f1-score   support

           0    0.29327   0.18944   0.23019       322
           1    0.68743   0.79612   0.73779       721

    accuracy                        0.60882      1043
   macro avg    0.49035   0.49278   0.48399      1043
weighted avg    0.56574   0.60882   0.58108      1043

Macro-F1:  0.483988941165058
MCC:  -0.016697947067186646


#### Evaluations with frequency-unmatched scrambling

In [79]:
# Evaluate the trained model on the frequency-mismatched scrambled validation set.
# basicConfig does nothing when the root logger is already configured, so
# repeating it here only matters if this cell is run on its own.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
corrupt_method = "mismatched"
data_file_name = task_name if task_name != "sst3" else "sst-tenary"
corrupt_datasets = get_dataset(f"../data-files/{data_file_name}-corrupted-{corrupt_method}")
# Report the sizes of the corrupted splits that were just loaded (the original
# cell logged the sizes of the clean `datasets` instead).
logger.info("***** Train Sample Count (Verify): %s *****", len(corrupt_datasets["train"]))
logger.info("***** Valid Sample Count (Verify): %s *****", len(corrupt_datasets["validation"]))

# BoW features of the corrupted validation split, built like the clean ones.
vocab_size = len(original_vocab)
corrupt_validation_input_features = []
corrupt_validation_label_ids = []
for (ex_index, example) in enumerate(tqdm(corrupt_datasets["validation"])):
    if sentence2_key is None:
        # Reset per example: an empty sentence must yield an all-zero vector
        # instead of silently inheriting the previous example's text.
        sentence_combined = ""
        if sanity_check_non_empty(example[sentence1_key]):
            sentence_combined = example[sentence1_key]
        if ex_index % 50000 == 0:
            print("Example sentence: " + sentence_combined)
        segments = [sentence_combined]
    else:
        s1 = example[sentence1_key] if sanity_check_non_empty(example[sentence1_key]) else ""
        s2 = example[sentence2_key] if sanity_check_non_empty(example[sentence2_key]) else ""
        if ex_index % 50000 == 0:
            print("Example sentence 1: " + s1)
            print("Example sentence 2: " + s2)
        segments = [s1, s2]

    segment_vectors = []
    for segment in segments:
        segment_vector = torch.zeros(vocab_size)
        segment_tokens = modified_basic_tokenizer.tokenize(segment) if segment else []
        for t in segment_tokens:
            # Scrambled text may contain tokens the vocabulary has never
            # seen; they have no feature slot and are skipped.
            token_index = original_vocab.get(t)
            if token_index is not None:
                segment_vector[token_index] += 1
        segment_vectors.append(segment_vector)
    bow_feature = torch.cat(segment_vectors, dim=-1)

    corrupt_validation_input_features.append(bow_feature)
    corrupt_validation_label_ids.append(example["label"])

# torch.stack already returns a tensor; `.float()` keeps the dtype explicit
# without the copy-construct warning raised by torch.tensor(<tensor>).
corrupt_validation_input_features = torch.stack(corrupt_validation_input_features, dim=0).float()
corrupt_validation_label_ids = torch.tensor(corrupt_validation_label_ids, dtype=torch.long)
corrupt_validation_data = TensorDataset(corrupt_validation_input_features, corrupt_validation_label_ids)
corrupt_validation_dataloader = DataLoader(corrupt_validation_data, batch_size=per_device_eval_batch_size*n_gpu, shuffle=False)

logger.info("***** Evaluation With Corrupt Data *****")
model.eval()
all_logits = []  # predicted class ids per batch (name kept from the original cell)
all_label_ids = []
with torch.no_grad():
    for input_features, label_ids in corrupt_validation_dataloader:
        if torch.cuda.is_available() and not no_cuda:
            torch.cuda.empty_cache()

        input_features = input_features.to(device)  # no-op on cpu
        logits = model(input_features)
        # argmax over logits; a softmax would not change the winner.
        all_logits.append(logits.argmax(dim=-1).cpu().numpy())
        all_label_ids.append(label_ids.cpu().numpy())

all_logits = np.concatenate(all_logits, axis=0)
all_label_ids = np.concatenate(all_label_ids, axis=0)
result_to_save = classification_report(all_label_ids, all_logits, digits=5, output_dict=True)
print(classification_report(all_label_ids, all_logits, digits=5))
print("Macro-F1: ", result_to_save["macro avg"]["f1-score"])
mcc = matthews_corrcoef(all_label_ids, all_logits)
print("MCC: ", mcc)

03/29/2021 14:18:43 - INFO - __main__ - ***** Loading pre-loaded datasets from the disk directly! *****
03/29/2021 14:18:43 - INFO - __main__ - ***** Inoculation Sample Count: 8551 *****
Loading cached shuffled indices for dataset at ../data-files/cola-corrupted-mismatched/train/cache-452a380612a1acc4.arrow
Loading cached shuffled indices for dataset at ../data-files/cola-corrupted-mismatched/validation/cache-83a4d14ad2bf06f0.arrow
03/29/2021 14:18:43 - INFO - __main__ - ***** Train Sample Count (Verify): 8551 *****
03/29/2021 14:18:43 - INFO - __main__ - ***** Valid Sample Count (Verify): 1043 *****
 59%|█████▊    | 611/1043 [00:00<00:00, 3043.29it/s]

Example sentence: bartlett infinite ? delivered respect outdone 1492 penny nina majestic outdone unpopular ugliest


100%|██████████| 1043/1043 [00:00<00:00, 3019.72it/s]
03/29/2021 14:18:43 - INFO - __main__ - ***** Evaluation With Corrupt Data *****


              precision    recall  f1-score   support

           0    0.38462   0.04658   0.08310       322
           1    0.69422   0.96671   0.80812       721

    accuracy                        0.68265      1043
   macro avg    0.53942   0.50665   0.44561      1043
weighted avg    0.59864   0.68265   0.58429      1043

Macro-F1:  0.4456092175518889
MCC:  0.03237739483038874


#### Random guessing baseline
If we guess the labels at random, what performance do we get?

In [77]:
# Average macro-F1 / MCC of a random-guessing baseline on the validation split.
import numpy as np
from sklearn.dummy import DummyClassifier

runs = 100
mf1s, mccs = [], []
for _run in range(runs):
    # A fresh stratified dummy classifier is fitted and scored on every run;
    # each run contributes one macro-F1 and one MCC value.
    dummy_clf = DummyClassifier(strategy="stratified")
    dummy_clf.fit(validation_input_features, validation_label_ids)
    dummy_labels = dummy_clf.predict(validation_input_features)

    result_to_save = classification_report(
        validation_label_ids, dummy_labels, digits=5, output_dict=True
    )
    mf1s.append(result_to_save["macro avg"]["f1-score"])
    mcc = matthews_corrcoef(validation_label_ids, dummy_labels)
    mccs.append(mcc)

# The detailed report below covers the last run only; the averages cover all runs.
print(classification_report(validation_label_ids, dummy_labels, digits=5))
avg_mf1 = round(sum(mf1s)/len(mf1s), 6)
avg_mcc = round(sum(mccs)/len(mccs), 6)
print(f"AVG over {runs} runs mF1: {avg_mf1}.")
print("Standard Deviation of sample is % s " % (statistics.stdev(mf1s)))
print(f"AVG over {runs} runs MCC: {avg_mcc}.")


              precision    recall  f1-score   support

           0    0.32298   0.32298   0.32298       322
           1    0.69764   0.69764   0.69764       721

    accuracy                        0.58198      1043
   macro avg    0.51031   0.51031   0.51031      1043
weighted avg    0.58198   0.58198   0.58198      1043

AVG over 100 runs mF1: 0.500884.
Standard Deviation of sample is 0.015327163760148876 
AVG over 100 runs MCC: 0.002038.


#### FrequencyBoW classifiers

In [None]:
# task setups
task_name = "sst3"
num_labels = 3
FILENAME_CONFIG = {
    "sst3" : "sst-tenary"
}

# let us corrupt SST3 in the same way as before
# Each split is stored as <external_output_dirname>/<dataset>/<dataset>-<split>.tsv
dataset_dirname = FILENAME_CONFIG[task_name]
split_paths = {
    split: os.path.join(external_output_dirname, dataset_dirname, f"{dataset_dirname}-{split}.tsv")
    for split in ("train", "dev", "test")
}

# Read every tab-separated file with pandas and wrap it as a HF `Dataset`.
# (The *_df names are kept even though the final objects are Datasets.)
train_df = Dataset.from_pandas(pd.read_csv(split_paths["train"], delimiter="\t"))
eval_df = Dataset.from_pandas(pd.read_csv(split_paths["dev"], delimiter="\t"))
test_df = Dataset.from_pandas(pd.read_csv(split_paths["test"], delimiter="\t"))

In [None]:
def _accumulate_token_statistics(dataset, tokenizer, label_vocab_map, token_frequency_map):
    """Add the token statistics of one dataset split to the two running maps.

    Args:
        dataset: iterable of examples exposing 'text' and 'label' entries.
        tokenizer: object with a `tokenize(str) -> list` method.
        label_vocab_map: dict label -> list of every token seen under that
            label (with repetitions); updated in place.
        token_frequency_map: dict token -> occurrence count; updated in place.
    """
    for i, example in enumerate(dataset):
        if i % 10000 == 0 and i != 0:
            print(f"processing #{i} example...")
        original_sentence = example['text']
        label = example['label']
        # Blank / whitespace-only sentences contribute nothing.
        if len(original_sentence.strip()) == 0:
            continue
        tokens = tokenizer.tokenize(original_sentence)
        label_vocab_map.setdefault(label, []).extend(tokens)
        for t in tokens:
            token_frequency_map[t] = token_frequency_map.get(t, 0) + 1


modified_basic_tokenizer = ModifiedBasicTokenizer()
label_vocab_map = {}
token_frequency_map = {} # overwrite this everytime for a new dataset

# Statistics are collected over train, validation and test together, in that order.
for split_dataset in (train_df, eval_df, test_df):
    _accumulate_token_statistics(split_dataset, modified_basic_tokenizer,
                                 label_vocab_map, token_frequency_map)

# Tokens ordered from most to least frequent (ties keep first-seen order).
task_token_frequency_map = OrderedDict(
    sorted(token_frequency_map.items(), key=operator.itemgetter(1), reverse=True)
)

Training BoW with 1st-order frequency bins

In [None]:
# freq and bucket mappings
# Distinct token frequencies observed in the corpus, in ascending order.
freq_set = sorted(set(task_token_frequency_map.values()))

# Log-spaced bucket edges running from the smallest to the largest frequency.
# The last edge is dropped and the remaining ones are rounded up to integers.
bucket_count = 256
log_lowest_freq = math.log(freq_set[0], 10)
log_highest_freq = math.log(freq_set[-1], 10)
freq_bucket = np.logspace(log_lowest_freq, log_highest_freq, bucket_count, endpoint=True)
freq_bucket = [math.ceil(edge) for edge in freq_bucket[:-1]]
# finally the bucket is a map between freq and bucket number
def find_bucket_number(freq, freq_bucket):
    """Return the 1-based position of the first edge that `freq` does not exceed.

    Args:
        freq: the frequency to place.
        freq_bucket: sequence of bucket edges, scanned from the front.

    Returns:
        The 1-based index of the first edge `e` for which `freq > e` is false,
        or `len(freq_bucket)` when `freq` exceeds every edge (this is 0 for an
        empty edge list).
    """
    return next(
        (
            position
            for position, edge in enumerate(freq_bucket, start=1)
            if not freq > edge
        ),
        len(freq_bucket),
    )

# One bucket per distinct frequency: the k-th entry of `freq_set` maps to
# bucket k. The log-spaced edges in `freq_bucket` are not used for this map.
freq_bucket_map = {freq: bucket_idx for bucket_idx, freq in enumerate(freq_set)}
new_bucket_idx = len(freq_set)

# Feature-vector width = number of buckets just created.
bucket_length = new_bucket_idx # len(freq_bucket)

In [None]:
# these lines of code make random buckets and assign words to them.
# Vocabulary = every token type in the frequency map.
vocab = list(task_token_frequency_map.keys())

# Histogram: frequency value -> number of token types having that frequency.
freq_count = {}
for type_frequency in task_token_frequency_map.values():
    freq_count[type_frequency] = freq_count.get(type_frequency, 0) + 1

# Shuffle in place so that the contiguous chunks taken later are random groups.
random.shuffle(vocab)
bucket_length = 600
def split(a, n):
    """Lazily cut `a` into `n` contiguous chunks whose sizes differ by at most one.

    The first `len(a) % n` chunks get one extra element. The chunk sizes are
    computed when the function is called; the slices themselves are produced
    on iteration.

    Returns:
        A generator yielding the `n` slices of `a` in order.
    """
    base_size, oversized_chunks = divmod(len(a), n)

    def _chunks():
        start = 0
        for chunk_idx in range(n):
            stop = start + base_size + (1 if chunk_idx < oversized_chunks else 0)
            yield a[start:stop]
            start = stop

    return _chunks()
# Chunk the shuffled vocabulary into `bucket_length` groups and give every word
# the index of the group it landed in.
bucket_vocab_random = split(vocab, bucket_length)
random_bucket_vocab_map = {}
bucket_id = 0
for words_in_bucket in bucket_vocab_random:
    random_bucket_vocab_map.update(dict.fromkeys(words_in_bucket, bucket_id))
    bucket_id += 1

In [None]:
# FBoW feature vectors for train split
# Each example becomes a `bucket_length`-dim vector with a 1 in every random
# bucket that contains at least one of its (truncated) tokens.
train_input_features = []
train_label_ids = []
for (ex_index, example) in enumerate(tqdm(train_df)):
    bow_feature = torch.zeros(bucket_length)
    if sentence2_key is None:
        sentence_combined = example[sentence1_key]
    else:
        sentence_combined = example[sentence1_key] + " [SEP] " + example[sentence2_key]
    sentence_tokens = modified_basic_tokenizer.tokenize(sentence_combined)
    sentence_tokens = sentence_tokens[:max_length]
    for t in sentence_tokens:
        # Binary presence indicator, not a count: the previous `= +1` assigned
        # positive one (it was never `+= 1`), so `= 1` is the same behaviour.
        # NOTE(review): a token missing from random_bucket_vocab_map raises
        # KeyError here -- confirm the vocabulary covers these text fields.
        bow_feature[random_bucket_vocab_map[t]] = 1
    if ex_index % 50000 == 0:
        print("Example sentence: " + sentence_combined)
        print(bow_feature)
    train_input_features.append(bow_feature)
    train_label_ids.append(example["label"])

# torch.stack already returns a tensor; cast it with .to(...) instead of
# re-wrapping it in torch.tensor(...), which copy-constructs and emits a
# UserWarning.
train_input_features = torch.stack(train_input_features, dim=0).to(dtype=torch.float)
train_label_ids = torch.tensor(train_label_ids, dtype=torch.long)
train_data = TensorDataset(train_input_features, train_label_ids)

In [None]:
# FBoW feature vectors for validation split
# Each example becomes a `bucket_length`-dim vector with a 1 in every random
# bucket that contains at least one of its (truncated) tokens.
validation_input_features = []
validation_label_ids = []
for (ex_index, example) in enumerate(tqdm(eval_df)):
    bow_feature = torch.zeros(bucket_length)
    if sentence2_key is None:
        sentence_combined = example[sentence1_key]
    else:
        sentence_combined = example[sentence1_key] + " [SEP] " + example[sentence2_key]
    sentence_tokens = modified_basic_tokenizer.tokenize(sentence_combined)
    sentence_tokens = sentence_tokens[:max_length]
    for t in sentence_tokens:
        # Binary presence indicator, not a count: the previous `= +1` assigned
        # positive one (it was never `+= 1`), so `= 1` is the same behaviour.
        # NOTE(review): a token missing from random_bucket_vocab_map raises
        # KeyError here -- confirm the vocabulary covers these text fields.
        bow_feature[random_bucket_vocab_map[t]] = 1
    validation_input_features.append(bow_feature)
    validation_label_ids.append(example["label"])

# torch.stack already returns a tensor; cast it with .to(...) instead of
# re-wrapping it in torch.tensor(...), which copy-constructs and emits a
# UserWarning.
validation_input_features = torch.stack(validation_input_features, dim=0).to(dtype=torch.float)
validation_label_ids = torch.tensor(validation_label_ids, dtype=torch.long)
validation_data = TensorDataset(validation_input_features, validation_label_ids)

In [None]:
# data loader
# Effective batch size = per-device batch size x number of GPUs. `max(1, n_gpu)`
# keeps it positive on CPU-only runs (n_gpu == 0), where the plain product
# would be 0 and DataLoader would reject it.
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=per_device_train_batch_size*max(1, n_gpu))
# Validation order is kept fixed (no shuffling).
validation_dataloader = DataLoader(validation_data, batch_size=per_device_eval_batch_size*max(1, n_gpu), shuffle=False)

In [None]:
# some overriding fun stuffs!
lr = 1e-3
num_train_epochs = 20
# Output size = number of distinct labels in the validation split;
# input size = width of the bag-of-buckets feature vector.
model = BOWClassifier(len(validation_label_ids.unique()), bucket_length)
if n_gpu > 0 and not no_cuda:
    # DataParallel expects the wrapped module's parameters to already live on
    # its first device; without this move a multi-GPU run fails in forward().
    # NOTE(review): assumes `device` is the CUDA device used by the training
    # loop -- confirm against the cell that defines it.
    model.to(device)
    model = torch.nn.DataParallel(model)
# Build the optimizer once the parameters are on their final device.
optimizer = optim.Adam(model.parameters(), lr=lr)

In [None]:
# Train for `num_train_epochs` epochs; every 500 optimizer steps (including the
# very first one) score the model on the validation split and keep the best
# macro-F1 seen so far.
global_step = 0
max_score = -1
use_cuda = torch.cuda.is_available() and not no_cuda
for _ in range(int(num_train_epochs)):

    model.train()
    for step, batch in enumerate(train_dataloader):
        if use_cuda:
            torch.cuda.empty_cache()

        input_features, label_ids = batch

        if use_cuda:
            input_features = input_features.to(device)
            label_ids = label_ids.to(device)

        loss, _ = model(input_features, labels=label_ids)

        if n_gpu > 1:
            loss = loss.mean() # mean() to average on multi-gpu.
        loss.backward()

        optimizer.step()
        model.zero_grad()

        if global_step % 500 == 0:
            model.eval()
            all_logits = []
            all_label_ids = []
            with torch.no_grad():
                # Dedicated names for the evaluation batches so the training
                # loop's `step` / `batch` / `loss` are not clobbered.
                for eval_features, eval_label_ids in validation_dataloader:
                    if use_cuda:
                        torch.cuda.empty_cache()
                        eval_features = eval_features.to(device)
                        eval_label_ids = eval_label_ids.to(device)

                    _, logits = model(eval_features, labels=eval_label_ids)
                    logits = F.softmax(logits, dim=-1)
                    logits = logits.detach().cpu().numpy()
                    # Despite the list's name, class ids (argmax) are stored.
                    all_logits.append(np.argmax(logits, axis=1))
                    all_label_ids.append(eval_label_ids.to('cpu').numpy())

            all_logits = np.concatenate(all_logits, axis=0)
            all_label_ids = np.concatenate(all_label_ids, axis=0)
            result_to_save = classification_report(all_label_ids, all_logits, digits=5, output_dict=True)
            print("Macro-F1: ", result_to_save["macro avg"]["f1-score"])
            if result_to_save["macro avg"]["f1-score"] > max_score:
                max_score = result_to_save["macro avg"]["f1-score"]

            # Back to training mode: otherwise the rest of the epoch would be
            # trained with the model still in eval mode.
            model.train()

        global_step += 1
print("Best Macro-F1: ", max_score)

Training BoW with 1st- and 2nd-order frequency bins

In [None]:
# repartition the first order information
# Distinct token frequencies, ascending.
second_order_freq_set = sorted(set(task_token_frequency_map.values()))

# temp_bucket_count + 1 log-spaced edges from the smallest to the largest
# frequency; dropping the last one leaves temp_bucket_count edges, each rounded
# up to an integer.
temp_bucket_count = 24
log_min_freq = math.log(second_order_freq_set[0], 10)
log_max_freq = math.log(second_order_freq_set[-1], 10)
second_order_freq_bucket = np.logspace(log_min_freq, log_max_freq,
                                       temp_bucket_count+1, endpoint=True)
second_order_freq_bucket = [math.ceil(edge) for edge in second_order_freq_bucket[:-1]]
# finally the bucket is a map between freq and bucket number
def find_bucket_number(freq, freq_bucket):
    """Map a frequency to a 1-based bucket number using ascending bucket edges.

    Scans `freq_bucket` from the front and returns the 1-based position of the
    first edge for which `freq > edge` is false. When `freq` is above every
    edge the result is `len(freq_bucket)`, i.e. such values share the number of
    the last edge (0 for an empty edge list).

    Note: this redefines a function of the same name declared earlier in the
    notebook, with the same behaviour.
    """
    bucket_total = len(freq_bucket)
    for position, edge in enumerate(freq_bucket, start=1):
        if not freq > edge:
            return position
    return bucket_total

# Token frequency -> coarse log-spaced bucket number (as returned by find_bucket_number).
second_order_freq_bucket_map = {
    freq: find_bucket_number(freq, second_order_freq_bucket)
    for freq in second_order_freq_set
}

In [None]:
def _count_bucket_pair_cooccurrences(dataset, tokenizer, token_to_freq, freq_to_bucket, pair_counts):
    """Count, for one split, how often each unordered pair of frequency buckets co-occurs.

    Every unordered pair of token positions inside a sentence contributes one
    count to the key `(smaller_bucket, larger_bucket)`.

    Args:
        dataset: iterable of examples exposing a 'text' entry.
        tokenizer: object with a `tokenize(str) -> list` method.
        token_to_freq: dict token -> corpus frequency.
        freq_to_bucket: dict frequency -> first-order bucket number.
        pair_counts: dict (bucket, bucket) -> count; updated in place.
    """
    for example_idx, example in enumerate(dataset):
        if example_idx % 10000 == 0 and example_idx != 0:
            print(f"processing #{example_idx} example...")
        original_sentence = example['text']
        if len(original_sentence.strip()) == 0:
            continue
        tokens = tokenizer.tokenize(original_sentence)
        if len(tokens) < 2:
            continue  # no pairs to count
        # Resolve every token's bucket once instead of once per pair.
        bucket_ids = [freq_to_bucket[token_to_freq[t]] for t in tokens]
        for left in range(len(bucket_ids) - 1):
            for right in range(left + 1, len(bucket_ids)):
                # Sorted so that (a, b) and (b, a) share one key.
                index_tuple = tuple(sorted((bucket_ids[left], bucket_ids[right])))
                pair_counts[index_tuple] = pair_counts.get(index_tuple, 0) + 1


modified_basic_tokenizer = ModifiedBasicTokenizer()
token_freq_freq_map = {} # overwrite this everytime for a new dataset

# Pair statistics are collected over train, validation and test, in that order.
for split_dataset in (train_df, eval_df, test_df):
    _count_bucket_pair_cooccurrences(split_dataset, modified_basic_tokenizer,
                                     token_frequency_map, second_order_freq_bucket_map,
                                     token_freq_freq_map)

# Bucket pairs ordered from most to least frequent (ties keep first-seen order).
task_token_freq_freq_map = OrderedDict(
    sorted(token_freq_freq_map.items(), key=operator.itemgetter(1), reverse=True)
)

In [None]:
# Partition the second-order (bucket-pair) counts.
# Distinct pair counts, ascending.
second_order_freq_freq_set = sorted(set(task_token_freq_freq_map.values()))

# An earlier variant grouped these counts into log-spaced buckets; it is
# disabled, and every distinct count now receives its own bucket index.
freq_bucket_map = {}  # the first-order map is reset here as well
second_order_freq_freq_bucket_map = {
    pair_count: bucket_idx
    for bucket_idx, pair_count in enumerate(second_order_freq_freq_set)
}
new_bucket_idx = len(second_order_freq_freq_set)

# Feature-vector width for the second-order BoW vectors built next.
bucket_length = new_bucket_idx # len(freq_bucket)

In [None]:
# FBoW feature vectors for train split.
# Only 2nd-order (token-pair) features are produced here; the 1st-order
# per-token features are currently disabled and are NOT concatenated.
train_input_features = []
train_label_ids = []
for (ex_index, example) in enumerate(tqdm(train_df)):
    bow_feature = torch.zeros(bucket_length) # 2nd-order feature map
    if sentence2_key is None:
        sentence_combined = example[sentence1_key]
    else:
        sentence_combined = example[sentence1_key] + " [SEP] " + example[sentence2_key]
    sentence_tokens = modified_basic_tokenizer.tokenize(sentence_combined)
    sentence_tokens = sentence_tokens[:max_length]

    # Resolve each token's first-order bucket once (not once per pair).
    # Sentences with fewer than two tokens have no pairs and stay all-zero.
    if len(sentence_tokens) > 1:
        token_buckets = [second_order_freq_bucket_map[token_frequency_map[t]] for t in sentence_tokens]
    else:
        token_buckets = []

    # second order: one count per unordered pair of token positions
    for i in range(len(token_buckets)-1):
        for j in range(i+1, len(token_buckets)):
            # Sorted so that (a, b) and (b, a) address the same statistic.
            index_tuple = tuple(sorted((token_buckets[i], token_buckets[j])))
            second_order_bucket = second_order_freq_freq_bucket_map[task_token_freq_freq_map[index_tuple]]
            bow_feature[second_order_bucket] += 1 # bucket count

    if ex_index % 50000 == 0:
        print("Example sentence: " + sentence_combined)
        print(bow_feature)
    train_input_features.append(bow_feature)
    train_label_ids.append(example["label"])

# torch.stack already returns a tensor; cast it with .to(...) instead of
# re-wrapping it in torch.tensor(...), which copy-constructs and emits a
# UserWarning.
train_input_features = torch.stack(train_input_features, dim=0).to(dtype=torch.float)
train_label_ids = torch.tensor(train_label_ids, dtype=torch.long)
train_data = TensorDataset(train_input_features, train_label_ids)

In [None]:
# FBoW feature vectors for validation split.
# Only 2nd-order (token-pair) features are produced here; the 1st-order
# per-token features are currently disabled.
validation_input_features = []
validation_label_ids = []
for (ex_index, example) in enumerate(tqdm(eval_df)):
    bow_feature = torch.zeros(bucket_length) # 2nd-order feature map
    if sentence2_key is None:
        sentence_combined = example[sentence1_key]
    else:
        sentence_combined = example[sentence1_key] + " [SEP] " + example[sentence2_key]
    sentence_tokens = modified_basic_tokenizer.tokenize(sentence_combined)
    sentence_tokens = sentence_tokens[:max_length]

    # Resolve each token's first-order bucket once (not once per pair).
    # Sentences with fewer than two tokens have no pairs and stay all-zero.
    if len(sentence_tokens) > 1:
        token_buckets = [second_order_freq_bucket_map[token_frequency_map[t]] for t in sentence_tokens]
    else:
        token_buckets = []

    # second order: one count per unordered pair of token positions
    for i in range(len(token_buckets)-1):
        for j in range(i+1, len(token_buckets)):
            # Sorted so that (a, b) and (b, a) address the same statistic.
            index_tuple = tuple(sorted((token_buckets[i], token_buckets[j])))
            second_order_bucket = second_order_freq_freq_bucket_map[task_token_freq_freq_map[index_tuple]]
            bow_feature[second_order_bucket] += 1 # bucket count

    validation_input_features.append(bow_feature)
    validation_label_ids.append(example["label"])

# torch.stack already returns a tensor; cast it with .to(...) instead of
# re-wrapping it in torch.tensor(...), which copy-constructs and emits a
# UserWarning.
validation_input_features = torch.stack(validation_input_features, dim=0).to(dtype=torch.float)
validation_label_ids = torch.tensor(validation_label_ids, dtype=torch.long)
validation_data = TensorDataset(validation_input_features, validation_label_ids)

In [None]:
# data loader
# Effective batch size = per-device batch size x number of GPUs. `max(1, n_gpu)`
# keeps it positive on CPU-only runs (n_gpu == 0), where the plain product
# would be 0 and DataLoader would reject it.
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=per_device_train_batch_size*max(1, n_gpu))
# Validation order is kept fixed (no shuffling).
validation_dataloader = DataLoader(validation_data, batch_size=per_device_eval_batch_size*max(1, n_gpu), shuffle=False)

In [None]:
# restart the model
# Output size = number of distinct labels in the validation split;
# input size = width of the second-order feature vector.
model = BOWClassifier(len(validation_label_ids.unique()),
                      bucket_length)
lr = 1e-3
if n_gpu > 0 and not no_cuda:
    # DataParallel expects the wrapped module's parameters to already live on
    # its first device; without this move a multi-GPU run fails in forward().
    # NOTE(review): assumes `device` is the CUDA device used by the training
    # loop -- confirm against the cell that defines it.
    model.to(device)
    model = torch.nn.DataParallel(model)
# Build the optimizer once the parameters are on their final device.
optimizer = optim.Adam(model.parameters(), lr=lr)

In [None]:
# Train for `num_train_epochs` epochs; every 500 optimizer steps (including the
# very first one) score the model on the validation split and keep the best
# macro-F1 seen so far.
global_step = 0
num_train_epochs = 20
max_score = -1
use_cuda = torch.cuda.is_available() and not no_cuda
for _ in range(int(num_train_epochs)):

    model.train()
    for step, batch in enumerate(train_dataloader):
        if use_cuda:
            torch.cuda.empty_cache()

        input_features, label_ids = batch

        if use_cuda:
            input_features = input_features.to(device)
            label_ids = label_ids.to(device)

        loss, _ = model(input_features, labels=label_ids)

        if n_gpu > 1:
            loss = loss.mean() # mean() to average on multi-gpu.
        loss.backward()

        optimizer.step()
        model.zero_grad()

        if global_step % 500 == 0:
            model.eval()
            all_logits = []
            all_label_ids = []
            with torch.no_grad():
                # Dedicated names for the evaluation batches so the training
                # loop's `step` / `batch` / `loss` are not clobbered.
                for eval_features, eval_label_ids in validation_dataloader:
                    if use_cuda:
                        torch.cuda.empty_cache()
                        eval_features = eval_features.to(device)
                        eval_label_ids = eval_label_ids.to(device)

                    _, logits = model(eval_features, labels=eval_label_ids)
                    logits = F.softmax(logits, dim=-1)
                    logits = logits.detach().cpu().numpy()
                    # Despite the list's name, class ids (argmax) are stored.
                    all_logits.append(np.argmax(logits, axis=1))
                    all_label_ids.append(eval_label_ids.to('cpu').numpy())

            all_logits = np.concatenate(all_logits, axis=0)
            all_label_ids = np.concatenate(all_label_ids, axis=0)
            result_to_save = classification_report(all_label_ids, all_logits, digits=5, output_dict=True)
            print("Macro-F1: ", result_to_save["macro avg"]["f1-score"])
            if result_to_save["macro avg"]["f1-score"] > max_score:
                max_score = result_to_save["macro avg"]["f1-score"]

            # Back to training mode: otherwise the rest of the epoch would be
            # trained with the model still in eval mode.
            model.train()

        global_step += 1
print("Best Macro-F1: ", max_score)