In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import argparse
import torch
import json
from collections import defaultdict

import matplotlib.pyplot as plt
from tqdm.notebook import trange
from tqdm import tqdm

from diffmask.models.sentiment_classification_sst_diffmask import (
    BertSentimentClassificationSSTDiffMask,
    RecurrentSentimentClassificationSSTDiffMask,
    PerSampleDiffMaskRecurrentSentimentClassificationSSTDiffMask,
    PerSampleREINFORCERecurrentSentimentClassificationSSTDiffMask,
)
from diffmask.utils.plot import plot_sst_attributions

plt.rcParams['font.family'] = 'NanumGothic'

In [None]:
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser holding the notebook's hyperparameters.

    Returns:
        An ``argparse.ArgumentParser`` with the GPU, data, model and dataset
        options used by the rest of this notebook.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpu", type=str, default="0")
    parser.add_argument("--model", type=str, default="./datasets/KorBERT")
    parser.add_argument("--train_filename", type=str, default="./datasets/nsmc/ratings_train.txt")
    parser.add_argument("--val_filename", type=str, default="./datasets/nsmc/ratings_test.txt")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--gate_bias", action="store_true")
    # Seeds are integers; a float type would turn e.g. "--seed 3" into 3.0.
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--architecture", type=str, default="bert", choices=["gru", "bert"])
    parser.add_argument(
        "--model_path",
        type=str,
        default="outputs/models.ckpt",
        help="Checkpoint to load (another example: models/sst-diffmask-input.ckpt).",
    )
    parser.add_argument("--num_labels", type=int, default=2)
    parser.add_argument("--dataset", type=str, default="nsmc", choices=["nsmc", "kornli"])
    return parser


if __name__ == "__main__":
    # parse_known_args ignores unrecognised arguments (such as those a Jupyter
    # kernel is launched with) instead of exiting with an error.
    hparams, _ = build_arg_parser().parse_known_args()

    # Restrict the visible GPUs before any further torch call so the setting
    # is in place before CUDA can be initialised.
    os.environ["CUDA_VISIBLE_DEVICES"] = hparams.gpu

    torch.manual_seed(hparams.seed)

# Loading a model

In [None]:
# Use the (first visible) GPU when CUDA is available, otherwise fall back to
# CPU so the notebook can still run on a machine without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"


def _diffmask_model_class(architecture, model_path):
    """Return the DiffMask model class to load for ``architecture``.

    For the GRU architecture the concrete variant is chosen from a substring
    of ``model_path``.

    Raises:
        ValueError: if ``architecture`` is neither ``"bert"`` nor ``"gru"``.
    """
    if architecture == "bert":
        return BertSentimentClassificationSSTDiffMask
    if architecture == "gru":
        if "per_sample-diffmask" in model_path:
            return PerSampleDiffMaskRecurrentSentimentClassificationSSTDiffMask
        if "per_sample-reinforce" in model_path:
            return PerSampleREINFORCERecurrentSentimentClassificationSSTDiffMask
        return RecurrentSentimentClassificationSSTDiffMask
    # Previously an unknown architecture silently left `model` undefined.
    raise ValueError(f"Unsupported architecture: {architecture!r}")


model_cls = _diffmask_model_class(hparams.architecture, hparams.model_path)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
model = model_cls.load_from_checkpoint(hparams.model_path, map_location=device).to(device)

model.freeze()

# Creating and plotting DiffMask attributions

In [None]:
model.prepare_data()

# A gate counts as "open" when its (exponentiated) attribution is at least this.
GATE_THRESHOLD = 0.5
# A token is kept in the rationale when at least this many of its gates are
# open (summed over the last attribution axis; presumably one gate per
# layer — TODO confirm against forward_explainer).
MIN_OPEN_GATES = 7

# Each line written is "<label>\t<space-separated rationale tokens>".
# The previous empty path made open() fail unconditionally.
save_path = os.path.join("outputs", f"{hparams.dataset}_rationales.tsv")
os.makedirs(os.path.dirname(save_path), exist_ok=True)

val_loader = model.val_dataloader()

# Explicit UTF-8 so the Korean tokens are written correctly on every platform.
with open(save_path, "w", encoding="utf-8") as f:
    # len(val_loader) also counts a final partial batch, unlike floor division.
    for batch in tqdm(val_loader, total=len(val_loader)):
        inputs = model.tokenizer.batch_encode_plus(
            batch[0], pad_to_max_length=True, return_tensors="pt"
        ).to(device)
        labels = batch[1]

        # forward_explainer's output is exponentiated here, i.e. it appears to
        # be in log space — NOTE(review): confirm.
        attributions = model.forward_explainer(
            inputs["input_ids"],
            inputs["attention_mask"],
            inputs["token_type_ids"],
            attribution=True,
        ).exp()

        for idx, text in enumerate(batch[0]):
            tokens = ["[CLS]"] + model.tokenizer.tokenize(text) + ["[SEP]"]

            gate = attributions[idx, : len(tokens)].cpu() >= GATE_THRESHOLD
            keep = gate.int().sum(1) >= MIN_OPEN_GATES

            # Special tokens are never reported as part of a rationale.
            rationale = [
                tokens[j]
                for j, selected in enumerate(keep.tolist())
                if selected and tokens[j] not in ("[CLS]", "[SEP]")
            ]

            f.write(f"{labels[idx].tolist()}\t{' '.join(rationale)}\n")

In [None]:
# Example review to explain ("Truly the best movie. More stunning and
# beautiful than any other movie.").
sentence = "정말 최고의 영화. 어떤 영화 보다도 멋지고 아름답다"

source = model.tokenizer.tokenize(sentence)
# Add the special tokens and prefix every token with a space for display.
tokens = [" " + tok for tok in ["[CLS]", *source, "[SEP]"]]

encoded = model.tokenizer.encode_plus(
    source,
    pad_to_max_length=True,
    return_tensors="pt",
)
inputs_dict = {name: tensor.to(device) for name, tensor in encoded.items()}
# forward_explainer takes the attention mask under the keyword `mask`.
inputs_dict["mask"] = inputs_dict.pop("attention_mask")

explainer_out = model.forward_explainer(**inputs_dict, attribution=True)
# Binarise the gates of the first (only) example over the real tokens.
open_gates = explainer_out.exp()[0, : len(tokens)].cpu() >= 0.5
attributions = open_gates.int()

plot_sst_attributions(attributions, tokens)