In [1]:
import numpy as np
import pandas as pd
import os

from rationale_3players_sentence_classification_models import ClassifierModule, HardRationale3PlayerClassificationModel
from rationale_3players_for_emnlp import HardRationale3PlayerClassificationModelForEmnlp

import torch
from transformers import *
from torch.utils import data
from torch.autograd import Variable

from collections import deque

from tqdm import tqdm

To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html


## Specify arguments for the model and data processing

In [2]:
# load the data
DATA_FOLDER = os.path.join("../../sentiment_dataset/data/")
LABEL_COL = "label"
TEXT_COL = "sentence"
TO_LOWER = True
MAX_LEN = 150
BATCH_SIZE_PRED = 512
TRAIN_SIZE = 0.6
batch_size = 128
TOKEN_CUTOFF = 75

class Argument():
    def __init__(self):
        self.model_type = 'RNN'
        self.cell_type = 'GRU'
        self.hidden_dim = 400
        self.embedding_dim = 768
        self.kernel_size = 5
        self.layer_num = 1
        self.fine_tuning = False
        self.z_dim = 2
        self.gumbel_temprature = 0.1
        self.cuda = True
        self.batch_size = 40
        self.mlp_hidden_dim = 50
        self.dropout_rate = 0.4
        self.use_relative_pos = True
        self.max_pos_num = 20
        self.pos_embedding_dim = -1
        self.fixed_classifier = True
        self.fixed_E_anti = True
        self.lambda_sparsity = 3.0
        self.lambda_continuity = 1.0
        self.lambda_anti = 1.0
        self.lambda_pos_reward = 0.1
        self.exploration_rate = 0.05
        self.highlight_percentage = 0.3
        self.highlight_count = 8
        self.count_tokens = 8
        self.count_pieces = 4
        self.lambda_acc_gap = 1.2
        self.label_embedding_dim = 400
        self.game_mode = '3player'
        self.margin = 0.2
#         self.lm_setting = 'single'
        self.lm_setting = 'multiple'
#         self.lambda_lm = 100.0
        self.lambda_lm = 1.0
        self.ngram = 4
        self.with_lm = False
        self.batch_size_ngram_eval = 5
        self.lr=0.001
        self.working_dir = '/dccstor/yum-dbqa/Rationale/structured_rationale/game_model_with_lm/beer_single_working_dir'
        self.model_prefix = 'tmp.%s.highlight%.2f.cont%.2f'%(self.game_mode, 
                                                                             self.highlight_percentage, 
                                                                             self.lambda_continuity)
        self.pre_trained_model_prefix = 'pre_trained_cls.model'

        self.save_path = os.path.join("..", "models")
        self.model_prefix = "sst2rnpmodel"
        self.save_best_model = True
        self.num_labels = 2
        
args = Argument()


# Embedding Layer

We want to use the pre-trained BERT embeddings, which generates embedded word vectors from word tokens.

#### Process for a single sentence
1.) generate_tokens() takes the BERT tokenizer and a sentence and tokenizes this text, to a limit of TOKEN_CUTOFF tokens. If the number of tokens is less than TOKEN_CUTOFF, it pads the tokens with the BERT pad symbol. It also provides a mask that can be used to ignore any pad tokens in classifier models down the road.<br>
2.) embedding_func() takes the tokens from generate_tokens and uses them to make a corresponding embedding.

#### Multiple sentences
get_all_tokens takes a pandas dataframe, and adds columns tokens and mask into that dataframe.

In [3]:
pretrained_weights = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(pretrained_weights)

# Load pre-trained model (weights)
model = BertModel.from_pretrained('bert-base-uncased')
# Put the model in "evaluation" mode, meaning feed-forward operation.
if args.cuda:
    model.cuda()
model.eval()

def generate_tokens(tokenizer, text):
    tokenized_text = tokenizer.tokenize(text)
    tokenized_text = tokenized_text[:TOKEN_CUTOFF - 2]
    tokenized_text = ["[CLS]"] + tokenized_text + ["[SEP]"]
    pad_length = TOKEN_CUTOFF - len(tokenized_text)
    mask = [1] * len(tokenized_text) + [0] * pad_length
    
    tokenized_text = tokenized_text + ["[PAD]"] * pad_length
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    
    return np.array(indexed_tokens), np.array(mask)
    
def embedding_func(tokens):
    ones_mask = Variable(torch.from_numpy(np.ones((len(tokens), TOKEN_CUTOFF))))
    if args.cuda:
        ones_mask = ones_mask.cuda()
    with torch.no_grad():
        embeddings = model(tokens, ones_mask)[0]        
    return embeddings

def get_all_tokens(data):
    l = []
    m = []
    for sentence in data:
        token_list, mask = generate_tokens(tokenizer, sentence)
        l.append(token_list)
        m.append(mask)
    tokens = pd.DataFrame({"tokens": l, "mask": m})
    return tokens


## Load the Data

We read the data from files that already have it split between test and train sets.


In [4]:
def load_data(fpath):
    df_dict = {LABEL_COL: [], TEXT_COL: []}
    with open(fpath, 'r') as f:
        label_start = 0
        sentence_start = 2
        for line in f:
            label = int(line[label_start])
            sentence = line[sentence_start:]
            df_dict[LABEL_COL].append(label)
            df_dict[TEXT_COL].append(sentence)
    return pd.DataFrame.from_dict(df_dict)

df_train = load_data(os.path.join(DATA_FOLDER, 'stsa.binary.train'))
df_test = load_data(os.path.join(DATA_FOLDER, 'stsa.binary.test'))

# create training and testing labels
y_train = df_train[LABEL_COL]
y_test = df_test[LABEL_COL]

# create training and testing inputs
X_train = df_train[TEXT_COL]
X_test = df_test[TEXT_COL]

df_train = pd.concat([df_train, get_all_tokens(X_train)], axis=1)
df_test = pd.concat([df_test, get_all_tokens(X_test)], axis=1)

## Set up the model

In [5]:
args = Argument()

classification_model = HardRationale3PlayerClassificationModelForEmnlp(embedding_func, args)

if args.cuda:
    classification_model.cuda()

classification_model.init_optimizers()
classification_model.init_C_model()

args.fixed_E_anti = False
classification_model.fixed_E_anti = args.fixed_E_anti
args.with_lm = False
args.lambda_lm = 1.0

train_losses = []
train_accs = []
dev_accs = [0.0]
dev_anti_accs = [0.0]
dev_cls_accs = [0.0]
test_accs = [0.0]
test_anti_accs = [0.0]
test_cls_accs = [0.0]
best_dev_acc = 0.0
best_test_acc = 0.0
num_iteration = 100
display_iteration = 1
test_iteration = 1

eval_accs = [0.0]
eval_anti_accs = [0.0]

queue_length = 200
z_history_rewards = deque(maxlen=queue_length)
z_history_rewards.append(0.)

classification_model.init_optimizers()
classification_model.init_rl_optimizers()
classification_model.init_reward_queue()

old_E_anti_weights = classification_model.E_anti_model.predictor._parameters['weight'][0].cpu().data.numpy()



### Utilization Functions

In [8]:
def generate_data(batch):
    x_mat = np.stack(batch["tokens"], axis=0)
    x_mask = np.stack(batch["mask"], axis=0)
    y_vec = np.stack(batch["label"], axis=0)
    
    batch_x_ = Variable(torch.from_numpy(x_mat)).to(torch.int64)
    batch_m_ = Variable(torch.from_numpy(x_mask)).type(torch.FloatTensor)
    batch_y_ = Variable(torch.from_numpy(y_vec)).to(torch.int64)

    if args.cuda:
        batch_x_ = batch_x_.cuda()
        batch_m_ = batch_m_.cuda()
        batch_y_ = batch_y_.cuda()

    return batch_x_, batch_m_, batch_y_

def _get_sparsity(z, mask):
    mask_z = z * mask
    seq_lengths = torch.sum(mask, dim=1)

    sparsity_ratio = torch.sum(mask_z, dim=-1) / seq_lengths #(batch_size,)
#     sparsity_count = torch.sum(mask_z, dim=-1)

    return sparsity_ratio

def _get_continuity(z, mask):
    mask_z = z * mask
    seq_lengths = torch.sum(mask, dim=1)
    
    mask_z_ = torch.cat([mask_z[:, 1:], mask_z[:, -1:]], dim=-1)
        
    continuity_ratio = torch.sum(torch.abs(mask_z - mask_z_), dim=-1) / seq_lengths #(batch_size,) 
#     continuity_count = torch.sum(torch.abs(mask_z - mask_z_), dim=-1)
    
    return continuity_ratio

def display_example(x, m, z):
    seq_len = int(m.sum().item())
    ids = x[:seq_len]
    tokens = tokenizer.convert_ids_to_tokens(ids)
    
    final = ""
    for i in range(1, len(tokens) - 1):
        if z[i]:
            final += "[" + tokens[i] + "]"
        else:
            final += tokens[i]
        final += " "
    print(final)

def test():
    classification_model.eval()
    
    test_batch = df_test.sample(100)
    batch_x_, batch_m_, batch_y_ = generate_data(test_batch)
    predict, anti_predict, z, neg_log_probs = classification_model(batch_x_, batch_m_)
    
    # do a softmax on the predicted class probabilities
    _, y_pred = torch.max(predict, dim=1)
    
    # calculate sparsity
    print("Test sparsity: ", _get_sparsity(z, batch_m_).sum().item() / batch_size)
    
    accuracy = (y_pred == batch_y_).sum().item()
    print("Test accuracy: ", accuracy, "%")

    # display an example
    print("Gold Label: ", batch_y_[0].item(), " Pred label: ", y_pred[0].item())
    display_example(batch_x_[0], batch_m_[0], z[0])

## Train

In [None]:
test_freq = 10

for iteration in tqdm(range(100)):
    classification_model.train()

    # sample a batch of data
    batch = df_train.sample(batch_size, replace=True)
    batch_x_, batch_m_, batch_y_ = generate_data(batch)

    losses, predict = classification_model.train_cls_one_step(batch_x_, batch_y_, batch_m_)

    # calculate classification accuarcy
    _, y_pred = torch.max(predict, dim=1)

    acc = np.float((y_pred == batch_y_).sum().cpu().data.item()) / args.batch_size
    train_accs.append(acc)
    
    if iteration % test_freq == 0:
        test()

  1%|▊                                                                                 | 1/100 [00:03<06:06,  3.70s/it]

Test sparsity:  0.18539123237133026
Test accuracy:  49 %
Gold Label:  0  Pred label:  1
dr ##ear ##y , highly annoying . . . ` some body ' will appeal to no one . 


 11%|████████▉                                                                        | 11/100 [00:26<03:56,  2.66s/it]

Test sparsity:  0.1956399530172348
Test accuracy:  63 %
Gold Label:  1  Pred label:  1
what kids will discover is a new [collect] [##ible] [.] 


 21%|█████████████████                                                                | 21/100 [00:48<03:21,  2.55s/it]

Test sparsity:  0.1972774863243103
Test accuracy:  61 %
Gold Label:  0  Pred label:  1
acting , particularly by tam [##bor] [,] [almost] [makes] ` ` never again ' ' worth ##while , but - l ##rb - writer [\] [/] director - rr ##b - sc ##hae ##ffer should [follow] his [titular] [advice] 


 31%|█████████████████████████                                                        | 31/100 [01:10<03:00,  2.61s/it]

Test sparsity:  0.19416850805282593
Test accuracy:  73 %
Gold Label:  1  Pred label:  0
this is such a high - energy movie where the drumming and the marching are [so] excellent , who cares if the story ' [s] a little weak [.] 


 41%|█████████████████████████████████▏                                               | 41/100 [01:32<02:30,  2.55s/it]

Test sparsity:  0.18205201625823975
Test accuracy:  63 %
Gold Label:  0  Pred label:  1
. . . a weak , mani ##pu ##lative , pencil - thin story that is miraculous ##ly able to entertain anyway [.] 


 51%|█████████████████████████████████████████▎                                       | 51/100 [01:54<02:04,  2.55s/it]

Test sparsity:  0.18922418355941772
Test accuracy:  76 %
Gold Label:  1  Pred label:  1
chicago [is] sophisticated , [bra] ##sh [,] [sar] [##don] ##ic , completely [joy] ##ful in its [execution] [.] 


 61%|█████████████████████████████████████████████████▍                               | 61/100 [02:17<01:39,  2.55s/it]

Test sparsity:  0.17030000686645508
Test accuracy:  75 %
Gold Label:  0  Pred label:  0
just one more collection of penis [,] breast [and] [flat] [##ule] [##nce] [gag] ##s in search of a story [.] 


 71%|█████████████████████████████████████████████████████████▌                       | 71/100 [02:39<01:13,  2.55s/it]

Test sparsity:  0.19955259561538696
Test accuracy:  70 %
Gold Label:  0  Pred label:  1
notorious c . h . o . has o ##odle ##s of vulgar highlights [.] 


 81%|█████████████████████████████████████████████████████████████████▌               | 81/100 [03:01<00:48,  2.54s/it]

Test sparsity:  0.1928156316280365
Test accuracy:  71 %
Gold Label:  1  Pred label:  1
a dec ##ei ##ving ##ly simple film , one that grows in power in retro ##sp [##ect] [.] 


 82%|██████████████████████████████████████████████████████████████████▍              | 82/100 [03:03<00:43,  2.39s/it]