In [None]:
%load_ext autoreload
%autoreload 2

import os
import torch
import argparse

from tqdm.notebook import tqdm, trange

from diffmask.models.sentiment_classification_sst import (
    BertSentimentClassificationSST,
    MyDataset,
    my_collate_fn_token,
    load_sst
)

# Notebook setup: read the run options, seed torch, select the GPU and restore
# the checkpointed model that the following cell evaluates.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpu", type=str, default="0")
    parser.add_argument("--seed", type=int, default=0)
    # The three path options share the same type and an empty-string default.
    for path_option in ("--val_filename", "--val_rationale", "--model_path"):
        parser.add_argument(path_option, type=str, default="")

    # parse_known_args returns unrecognised arguments separately (discarded
    # here as `_`) instead of exiting, so extra command-line arguments do not
    # abort the cell.
    hparams, _ = parser.parse_known_args()

    torch.manual_seed(hparams.seed)

    # NOTE(review): this only takes effect if CUDA has not been initialised
    # yet in this process — confirm when re-running the cell in a live kernel.
    os.environ["CUDA_VISIBLE_DEVICES"] = hparams.gpu

# Index 0 is relative to the devices left visible by CUDA_VISIBLE_DEVICES.
device = "cuda:0"

model = BertSentimentClassificationSST.load_from_checkpoint(
    hparams.model_path
).to(device)

# presumably freeze() disables gradients / switches to eval mode and
# prepare_data() builds the model's datasets — verify against the model class.
model.freeze()
model.prepare_data()

In [None]:
# Run the model over the validation set and write its per-token predictions to
# two text files: one line per example in each file.
#   rationale_token_<my_name>.txt : the tokens whose predicted label is 1
#   rationale_idx_<my_name>.txt   : the predicted label of every real token
val_dataset, _ = load_sst(
    hparams.val_filename,
    None,
    model.hparams.dataset,
    3,
    hparams.val_rationale,
    model.hparams.token_cls,
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=model.hparams.batch_size,
    collate_fn=my_collate_fn_token,
    num_workers=8,
)

# Suffix and directory prefix used to build the two output file names.
my_name = ''
dir_path = ''

with open(dir_path + 'rationale_token_' + my_name + ".txt", 'w') as f, \
        open(dir_path + "rationale_idx_" + my_name + ".txt", 'w') as w:
    # len(val_dataloader) also counts a final partial batch, which the former
    # `len(val_dataset) // batch_size` estimate left out of the progress total.
    # (Alternative data source kept from the original: model.val_dataloader().)
    for i, batch in tqdm(enumerate(val_dataloader), total=len(val_dataloader)):
        # batch[0] holds the text pairs; each item is indexed below as
        # (first segment, second segment).
        inputs = model.tokenizer.batch_encode_plus(
            batch[0], pad_to_max_length=True, return_tensors='pt'
        ).to(device)
        input_ids = inputs['input_ids']
        mask = inputs['attention_mask']
        token_type_ids = inputs['token_type_ids']

        logits = model.forward(input_ids, mask, token_type_ids)[0]
        results = logits.argmax(-1)

        # Copy predictions and mask to the host once per batch so the token
        # loop below does not trigger a device synchronisation per element.
        # assumes `results` holds one integer label per (example, token) — TODO confirm
        results_cpu = results.cpu()
        mask_cpu = mask.cpu()

        for idx in range(len(batch[0])):
            source_1 = model.tokenizer.tokenize(batch[0][idx][0])
            source_2 = model.tokenizer.tokenize(batch[0][idx][1])
            tokens = ["[CLS]"] + source_1 + ["[SEP]"] + source_2 + ["[SEP]"]
            rationale, rationale_idx = [], []

            # NOTE(review): this index is the first token of the second
            # segment in `tokens` (the [SEP] sits at len(source_1) + 1), so
            # that token is replaced by '|' in the token file even when it is
            # predicted as rationale. Behaviour kept unchanged — confirm intent.
            separator_pos = len(source_1) + 2

            # Only walk the non-padded positions of this example.
            for j in range(int(mask_cpu[idx].sum())):
                token_label = int(results_cpu[idx][j])
                if j == separator_pos:
                    rationale.append('|')
                elif token_label == 1:
                    rationale.append(tokens[j])
                rationale_idx.append(str(token_label))

            f.write(' '.join(rationale) + '\n')
            w.write(' '.join(rationale_idx) + '\n')

In [None]:
%load_ext autoreload
%autoreload 2

import os
import argparse
import torch

from tqdm.notebook import tqdm, trange

from diffmask.models.sentiment_classification_sst import (
    BertSentimentClassificationSST,
    MyDataset,
    my_collate_fn,
    my_collate_fn_rationale,
    load_sst
)

# Notebook setup for the accuracy evaluation: read the run options, seed
# torch, select the GPU and restore the checkpointed model.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # (flag, type, default) for every supported option.
    option_specs = (
        ("--gpu", str, "0"),
        ("--seed", int, 0),
        ("--val_filename", str, ""),
        ("--val_rationale", str, ""),
        ("--model_path", str, ""),
    )
    for flag, flag_type, flag_default in option_specs:
        parser.add_argument(flag, type=flag_type, default=flag_default)

    # Unrecognised command-line arguments are returned separately and dropped.
    hparams, _ = parser.parse_known_args()

    torch.manual_seed(hparams.seed)

    # NOTE(review): if CUDA was already initialised earlier in this kernel,
    # changing CUDA_VISIBLE_DEVICES here may have no effect — confirm.
    os.environ["CUDA_VISIBLE_DEVICES"] = hparams.gpu

# Index 0 is relative to the devices left visible by CUDA_VISIBLE_DEVICES.
device = "cuda:0"

model = BertSentimentClassificationSST.load_from_checkpoint(
    hparams.model_path
).to(device)

# presumably freeze() disables gradients / switches to eval mode and
# prepare_data() builds the model's datasets — verify against the model class.
model.freeze()
model.prepare_data()

In [None]:
# Compute per-class and overall accuracy of the model on the validation set.
def _accuracy(correct, total):
    """Return ``correct / total`` as a float, or NaN when ``total`` is zero."""
    return correct / total if total else float("nan")


val_dataset, _ = load_sst(
    hparams.val_filename,
    None,
    model.hparams.dataset,
    model.hparams.num_labels,
    hparams.val_rationale,
    model.hparams.token_cls,
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=model.hparams.batch_size,
    collate_fn=my_collate_fn_rationale,
    num_workers=8,
)

# Per-class counters indexed by label id: correct predictions / examples seen.
val_acc, num = [0, 0, 0], [0, 0, 0]

# len(val_dataloader) also counts a final partial batch, which the former
# `len(val_dataset) // batch_size` estimate left out of the progress total.
# (Alternative data source kept from the original: model.val_dataloader().)
for i, batch in tqdm(enumerate(val_dataloader), total=len(val_dataloader)):
    inputs = model.tokenizer.batch_encode_plus(
        batch[0], pad_to_max_length=True, return_tensors='pt'
    ).to(device)
    input_ids = inputs['input_ids']
    mask = inputs['attention_mask']
    token_type_ids = inputs['token_type_ids']
    labels = batch[1].to(device)
    # rationale_ids = batch[2].to(device)

    logits = model.forward(input_ids, mask, token_type_ids)[0]
    # logits = model.forward(input_ids, mask, token_type_ids, rationale_ids=rationale_ids)[0]

    # Count with plain Python ints: the counters previously accumulated device
    # tensors, so the printed accuracies were tensor reprs and an unseen class
    # caused an int / int ZeroDivisionError.
    predictions = logits.argmax(-1).cpu()
    for prediction, label in zip(predictions, labels.cpu()):
        label_id = int(label)
        val_acc[label_id] += int(int(prediction) == label_id)
        num[label_id] += 1

print('entailment acc:', _accuracy(val_acc[0], num[0]), num[0]) # 'entailment': 0, 'neutral': 1, 'contradiction': 2
print('neutral acc:', _accuracy(val_acc[1], num[1]), num[1])
print('contradiction acc:', _accuracy(val_acc[2], num[2]), num[2])
print('all acc:', _accuracy(sum(val_acc), sum(num)))