In [1]:
import os
import json
import torch

import numpy as np
from scipy.special import softmax
from numpy import argmax

from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification

from datetime import datetime
from torch.utils.data import Dataset, DataLoader


class FactoryEncoder:

    def __init__(self, tokenizer, device, max_length):
        self.tokenizer = tokenizer
        self.device = device
        self.max_length = max_length
    # end

    # Q: WHY NOT USING PARAMETERS?
    # A: FOR FUN
    def get_instance(self, sentences):
        instance = SimpleEncoder(sentences)
        instance.set_tokenizer(self.tokenizer)
        instance.set_device(self.device)
        instance.set_max_length(self.max_length)
        return instance
    # end
# end

class SimpleEncoder(Dataset):
    def __init__(self, sentences):
        self.sentences = sentences
    # end

    def __getitem__(self, index):
        sentence = self.sentences[index]
        sentence = ' '.join(sentence.split())
        inputs = self.tokenizer.encode_plus(
            sentence, None,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            return_token_type_ids=False,
            truncation=True,
            return_tensors='pt'
        )

        for key in inputs:
            inputs[key].to(self.device)
        # end

        return inputs
    # end

    def set_tokenizer(self, tokenizer):
        self.tokenizer = tokenizer
    # end

    def set_device(self, device):
        self.device = device
    # end

    def set_max_length(self, max_length):
        self.max_length = max_length
    # end


class FactoryDecoder:
    def __init__(self, labels):
        self.labels = labels
    # end

    # outputs_raw:
    def get_instance(self, outputs_raw):
        instance = SimpleDecoder(self.labels, outputs_raw)
        instance.enable_label().enable_proba().enable_index()
        return instance
    # end
# end

class SimpleDecoder:

    def __init__(self, labels, outputs_raw):
        self.outputs_raw = outputs_raw
        self.labels = labels
        self.decoders = {}
    # end

    def enable_label(self):
        def decode_label(outputs_raw):
            return [self.labels for i in range(len(outputs_raw))]
        # end

        self.decoders['label'] = decode_label
        return self
    # end

    def enable_proba(self):
        def decode_proba(outputs_raw):
            # print('jinyuj: decode_proba:185 outputs_raw: {}'.format(outputs_raw))
            probas = softmax(outputs_raw, axis=1).tolist()
            return probas
        # end

        self.decoders['proba'] = decode_proba
        return self
    # end

    def enable_index(self):
        def decode_index(outputs_raw):
            return [list(range(len(self.labels))) for i in range(len(outputs_raw))]
        # end

        self.decoders['index'] = decode_index
        return self
    # end

    def decode(self, str_items_output):
        items_output = str_items_output.split(',')
        dict_output_decoded = {}

        for item_output in items_output:
            if item_output in self.decoders:
                func_decode = self.decoders[item_output]
                dict_output_decoded[item_output] = func_decode(self.outputs_raw)
            # end
        # end

        # get inner quantity(quantity of samples)
        num_samples = len(dict_output_decoded[list(dict_output_decoded.keys())[0]])
        num_klasses = len(dict_output_decoded[list(dict_output_decoded.keys())[0]][0])

        # infos_outputs = [[{}] * num_klasses] * num_samples
        infos_outputs = [[{} for j in range(num_klasses)] for i in range(num_samples)]

        # transform
        #  {
        #       'label': [['safe', 'comment'], ['safe', 'comment']],
        #       'proba': [[0.02207409217953682, 0.9779258370399475], [0.7, 0.3]],
        #       'index': [[0, 1], [0, 1]]
        #  }
        # to
        #  [[{'label': 'safe', 'proba': 0.01701054722070694, 'index': 0}, {'label': 'comment', 'proba': 0.982989490032196, 'index': 1}]]
        for item_output in dict_output_decoded:
            outputs_decoded = dict_output_decoded[item_output]
            for i, output_decoded in enumerate(outputs_decoded):        # i for num_samples
                for j, item_decoded in enumerate(output_decoded):
                    infos_outputs[i][j][item_output] = item_decoded           # output_decoded = ['safe', comment]
                # end
            # end
        # end transformation

        return infos_outputs
    # end
# end

class SimpleBertClassifier(torch.nn.Module):

    DEFAULT_FILENAME_LABEL = 'labels.json'

    def __init__(self, path_folder_model=None, name_model=None, max_length=512):
        super(SimpleBertClassifier, self).__init__()

        self.path_folder_model = path_folder_model
        self.name_model = name_model
        path_file_labels = os.path.join(self.path_folder_model, self.__class__.DEFAULT_FILENAME_LABEL)

        with open(path_file_labels, 'r') as file:
            self.labels_output_classifier = sorted(json.load(file))
        # end

        self.dict_label_index = {label: index for index, label in enumerate(self.labels_output_classifier)}
        self.classifier_max_length = max_length
        self.classifier = None

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.loaded = False
    # end

    def forward(self, *inputs, **params):
        output = self.classifier(*inputs, **params)
        return output
    # end

    def load(self):
        if not self.loaded:
            self.classifier = DistilBertForSequenceClassification.from_pretrained(self.path_folder_model)
            self.tokenizer = DistilBertTokenizerFast.from_pretrained(self.name_model)

            self.factory_encoder = FactoryEncoder(self.tokenizer, self.device, self.classifier_max_length)
            self.factory_decoder = FactoryDecoder(self.labels_output_classifier)
            self.loaded = True
        # end

        return self
    # end

    def predicts(self, samples_input, outputs='label'):
        items_output = outputs  # transform name
        encoder = self.factory_encoder.get_instance(samples_input)
        outputs_raw = []
        for sample_encoded in encoder:
            # sample_encoded['output_hidden_states'] = True
            # sample_encoded['output_attentions'] = True

            with torch.no_grad():
                output_raw = self(**sample_encoded).logits.cpu().numpy().flatten()
                outputs_raw.append(output_raw)
            # end
        # end
        outputs_raw = np.array(outputs_raw)

        decoder = self.factory_decoder.get_instance(outputs_raw)
        info_output = decoder.decode(items_output)
        return info_output
    # end
# end

In [2]:
import functools
import numpy as np
from sklearn import linear_model
from sklearn import metrics
import copy


class SimpleLimeExplainer:

    TOKENIZER_DEFAULT = str.split

    def __init__(
            self,
            token_mask=None,
            num_samples=None,
            alpha=None,
            name_distance_function=None,
            scale_distance=None,
            name_solver=None,
            width_kernel=None
        ):
        # hard code configuration starts
        self.tokenizer = self.__class__.TOKENIZER_DEFAULT
        self.function_kernel = self._exponential_kernel
        self.width_kernel = width_kernel

        # parameter configuration starts
        self.token_mask = token_mask
        self.num_samples = num_samples
        self.alpha = alpha
        self.function_distance = functools.partial(metrics.pairwise.pairwise_distances, metric=name_distance_function)
        self.scale_distance = scale_distance
        self.name_solver = name_solver
    # end


    def explain(self, sentence, model_predict):
        tokens = self.tokenizer(sentence)
        num_features = len(tokens)

        masks_input = self._prepare_masks(self.num_samples, num_features)
        distances_input = self._calculate_distance(self.function_distance, self.scale_distance, self.function_kernel, self.width_kernel, masks_input)
        samples_input = list(self._get_perturbations(tokens, masks_input, self.token_mask))

        #
        # [
        #       [
        #           {'label': 'safe', 'proba': 0.01701054722070694, 'index': 0},        # item 1
        #           {'label': 'comment', 'proba': 0.982989490032196, 'index': 1},       # item 2
        #           {...}                                                               # item 3
        #       ],      # sample 1
        #       [...],  # sample 2
        #       ...     # sample N
        # ]
        list_items_sample = model_predict.predicts(samples_input, outputs='label,proba,index')    # numpy outputs
        list_items_sample_sorted = [sorted(items_sample, key=lambda item: -item['proba']) for items_sample in list_items_sample]
        indexes_base = np.array([items_sample_sorted[0]['index'] for items_sample_sorted in list_items_sample_sorted])
        # print('indexes_base: {}'.format(indexes_base.tolist()))
        info_final = {}
        items_sample_sorted_origin = list_items_sample_sorted[0]
        for item_sample_sorted_root in items_sample_sorted_origin:

            index_item = item_sample_sorted_root['index']
            label_item = item_sample_sorted_root['label']
            proba_item = item_sample_sorted_root['proba']
            # print('jinyuj: simple_lime_explainer.explain {} {} {}'.format(index_item, label_item, proba_item))

            indexes = copy.deepcopy(indexes_base)
            indexes[np.where(indexes != index_item)] = -1
            indexes[np.where(indexes == index_item)] = 1

            # print('indexes: {}'.format(indexes.tolist()))

            model_explain = linear_model.Ridge(alpha=self.alpha, solver=self.name_solver)
            model_explain.fit(masks_input, indexes, sample_weight=distances_input)  # one-time use
            coefs_token = model_explain.coef_

            info_final[label_item] = {
                'confidence': proba_item,
                'features': [(token, coefs_token) for token, coefs_token in zip(tokens, coefs_token)]
            }

        return info_final
    # end

    def _prepare_masks(self, num_samples, num_features):
        masks = self._sample_masks(num_samples + 1, num_features)
        # print('jinyuj: _prepare_masks:64: masks.shape: {}, num_samples: {}'.format(masks.shape, num_samples))
        assert masks.shape[0] == num_samples + 1, 'Expected num_samples + 1 masks.'
        all_true_mask = np.ones_like(masks[0], dtype=np.bool)
        masks[0] = all_true_mask
        return masks
    # end

    def _sample_masks(self, num_samples, num_features):
        rng = np.random.RandomState()
        positions = np.tile(np.arange(num_features), (num_samples, 1))
        permutation_fn = np.vectorize(rng.permutation, signature='(n)->(n)')
        permutations = permutation_fn(positions)  # A shuffled range of positions.
        num_disabled_features = rng.randint(1, num_features + 1, (num_samples, 1))
        return permutations >= num_disabled_features
    # end

    def _calculate_distance(self, function_distance, scale_distance, function_kernel, width_kernel, masks):
        distances = function_distance(masks[0].reshape(1, -1), masks).flatten()
        distances = scale_distance * distances
        distances = function_kernel(distances, width_kernel)
        return distances
    # end

    def _get_perturbations(self, tokens, masks, token_mask):
        for mask in masks:
            parts = [t if mask[i] else token_mask for i, t in enumerate(tokens)]
            yield ' '.join(parts)
        # end
    # end

    def _exponential_kernel(self, distance, width_kernel):
        return np.sqrt(np.exp(-(distance ** 2) / width_kernel ** 2))
    # end
# end

In [3]:
config_explainer = {
    "token_mask": "[MASK]",
    "alpha": 1.0,
    "name_distance_function": "cosine",
    "name_solver": "cholesky",
    "num_samples": 128,
    "width_kernel": 25,
    "scale_distance": 100.0
}
explainer = SimpleLimeExplainer(**config_explainer)

In [4]:
import os
import pandas as pd
import numpy as np
import random

In [5]:
# model_name = "bert-base-uncased"
model_name = "distilbert-base-uncased"
model_dir = 'models_target'
max_length = 512

dir_data = 'data'
name_data_file = 'log_content_3_20220209.csv'
#name_data_file_target = 'log_content_3_20220209_target.csv'

name_label_file = 'labels.json'
path_data_relative = os.path.join(dir_data, name_data_file)
path_label_relative = os.path.join(dir_data, name_label_file)
#path_file_target_relative = os.path.join(dir_data, name_data_file_target)

In [6]:
def read_passages(path_data, path_label):
    df = pd.read_csv(path_data)
    documents = df['processed'].to_list()
    
    with open(path_label, 'r') as file:
        labels_list = sorted(json.load(file))
    # end
    
    labels_all = {l:idx for idx, l in enumerate(labels_list)}

    return documents, labels_list
# end

In [7]:
texts_predict, target_names = read_passages(path_data_relative, path_label_relative)

In [8]:
classifier = SimpleBertClassifier('models_target', 'distilbert-base-uncased')
classifier.load()

SimpleBertClassifier(
  (classifier): DistilBertForSequenceClassification(
    (distilbert): DistilBertModel(
      (embeddings): Embeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (transformer): Transformer(
        (layer): ModuleList(
          (0): TransformerBlock(
            (attention): MultiHeadSelfAttention(
              (dropout): Dropout(p=0.1, inplace=False)
              (q_lin): Linear(in_features=768, out_features=768, bias=True)
              (k_lin): Linear(in_features=768, out_features=768, bias=True)
              (v_lin): Linear(in_features=768, out_features=768, bias=True)
              (out_lin): Linear(in_features=768, out_features=768, bias=True)
            )
            (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        

In [9]:
from datetime import datetime
from collections import defaultdict

dict_text_predict = defaultdict(list)

for id_text, text_predict in enumerate(texts_predict):
    
    if text_predict not in dict_text_predict:

        print('[{}] start to handle {}: {}'.format(datetime.now().strftime("%H:%M:%S"), id_text, text_predict))
        info_explained = explainer.explain(text_predict, classifier)

        with open('data/explains/{}.json'.format(id_text), 'w+') as file:
            file.write(json.dumps({
                    'processed': text_predict,
                    'detail': info_explained
                }, indent=4)
            )
        # end
        print('[{}] done with {}'.format(datetime.now().strftime("%H:%M:%S"), id_text))
    else:
        print('[{}] ignore duplicated case: {}'.format(datetime.now().strftime("%H:%M:%S"), id_text))
    # end
    dict_text_predict[text_predict].append(id_text)
# end

In [10]:
with open('data/dict_text_predict.json', 'w+') as file:
    file.write(json.dumps(dict_text_predict))
# end