In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install lime torch

In [3]:
import os

COLORS = {
    'red': '\033[91m',
    'green': '\033[92m',
    'yellow': '\033[93m',
    'blue': '\033[94m',
    'magenta': '\033[95m',
    'cyan': '\033[96m',
    'bold': '\033[1m',
    'reset': '\033[0m'
}

DEMET_PATH = 'drive/MyDrive/core/'
LOGPATH = DEMET_PATH + 'logs/'
if (os.path.exists(LOGPATH) == False):
    os.makedirs(LOGPATH)
else:
    os.makedirs(LOGPATH, exist_ok=True)
import os
import time

class Logger:
    def __init__(self, to_file = False):
        self.to_file = to_file
        self.data = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
        self.path = os.path.join(LOGPATH, 'log-' + self.data + '.log')
        self.colors = COLORS

    def __str__(self):
        if self.to_file:
            return 'Logging to file'

    def print(self, message, color = 'reset', bold = False):
        if bold:
            print(self.colors['bold'] + self.colors[color] + '[' + self.string_by_time() + ']:' + ' ' + self.colors['reset'] + message)
        else:
            print(self.colors[color] + '[' + self.string_by_time() + ']:' + ' ' + self.colors['reset'] + message)

    def log(self, message):
        if self.to_file:
            with open(self.path, 'a') as file:
                file.write('[' + self.string_by_time() + ']:' + ' ' + message + '\n')

    def print_and_log(self, message, color = 'reset', bold = False):
        if bold:
            print(self.colors['bold'] + self.colors[color] + '[' + self.string_by_time() + ']:' + ' ' + self.colors['reset'] + message)
        else:
            print(self.colors[color] + '[' + self.string_by_time() + ']:' + ' ' + self.colors['reset'] + message)
        if self.to_file:
            with open(self.path, 'a') as file:
                file.write('[' + self.string_by_time() + ']:' + ' ' + message + '\n')

    def string_by_time(self):
        return time.strftime('%H:%M:%S', time.localtime())

logger = Logger(True)

In [11]:
MODEL_PATH = '/content/drive/MyDrive/core/models/'
CORE_PATH = '/content/drive/MyDrive/core/'
DATA_PATH = '/content/drive/MyDrive/core/shuffled_data.csv'
DEMET_PATH = 'drive/MyDrive/core/'
EXPLANATION_PATH = DEMET_PATH + 'explanations/'

In [16]:
import numpy as np
from transformers import BertForSequenceClassification, BertTokenizer, RobertaForSequenceClassification, RobertaTokenizer
from lime.lime_text import LimeTextExplainer
import torch
import re
from datetime import datetime

model_names = [MODEL_PATH + 'bert/' + 'bert-base-cased_0', MODEL_PATH + 'roberta/' + 'roberta-base_0', MODEL_PATH + 'distilbert/' + 'distilbert-base-cased_0']
models = [BertForSequenceClassification.from_pretrained(model_names[0]), RobertaForSequenceClassification.from_pretrained(model_names[1]), BertForSequenceClassification.from_pretrained(model_names[2])]
tokenizers = [BertTokenizer.from_pretrained(model_names[0]), RobertaTokenizer.from_pretrained(model_names[1]), BertTokenizer.from_pretrained(model_names[2])]

def lime_tokenizer(text):
    cha_tokens = [
        r'\[CHA REPETITION\]',
        r'\[CHA RETRACING\]',
        r'\[CHA SHORT PAUSE\]',
        r'\[CHA MEDIUM PAUSE\]',
        r'\[CHA LONG PAUSE\]',
        r'\[CHA TRAILING OFF\]',
        r'\[CHA PHONOLOGICAL FRAGMENT\]',
        r'\[CHA INTERPOSED WORD\]',
        r'\[CHA FILLER\]',
        r'\[CHA NON COMPLETION OF WORD\]',
        r'\[CHA BELCHES\]',
        r'\[CHA HISSES\]',
        r'\[CHA GRUNTS\]',
        r'\[CHA WHINES\]',
        r'\[CHA COUGHS\]',
        r'\[CHA HUMS\]',
        r'\[CHA ROARS\]',
        r'\[CHA WHISTLES\]',
        r'\[CHA CRIES\]',
        r'\[CHA LAUGHS\]',
        r'\[CHA SNEEZES\]',
        r'\[CHA WHIMPERS\]',
        r'\[CHA GASPS\]',
        r'\[CHA MOANS\]',
        r'\[CHA SIGHS\]',
        r'\[CHA YAWNS\]',
        r'\[CHA GROANS\]',
        r'\[CHA MUMBLES\]',
        r'\[CHA SINGS\]',
        r'\[CHA YELLS\]',
        r'\[CHA GROWLS\]',
        r'\[CHA PANTS\]',
        r'\[CHA SQUEALS\]',
        r'\[CHA VOCALIZES\]',
        r'\[CHA TRAILING OFF QUESTION\]',
        r'\[CHA QUESTION WITH EXCLAMATION\]',
        r'\[CHA INTERRUPTION\]',
        r'\[CHA INTERRUPTION OF QUESTION\]',
        r'\[CHA SELFINTERRUPTION\]',
        r'\[CHA SELFINTERRUPTED QUESTION\]'
    ]
    pattern = '|'.join(cha_tokens) + r'|' + r'\w+'
    return re.findall(pattern, text)

explainer = LimeTextExplainer(class_names=['Non-Dementia', 'Dementia'], split_expression=lime_tokenizer)

def predict(texts, model, tokenizer):
    inputs = tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
    outputs = model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=1)
    return probs.detach().numpy()

#text = ""
text = ""

explanations = [explainer.explain_instance(text, lambda x: predict(x, model, tokenizer), num_features=len(text.split(' '))) for model, tokenizer in zip(models, tokenizers)]

def get_lime_values(exp):
    feature_values = exp.as_list()
    features, values = zip(*feature_values)
    return features, np.array(values)

all_features = []
lime_values = []

for exp in explanations:
    features, values = get_lime_values(exp)
    all_features.append(features)
    lime_values.append(values)

unique_features = list(set(f for features in all_features for f in features))
lime_values_aligned = []

for values, features in zip(lime_values, all_features):
    aligned_values = []
    for uf in unique_features:
        if uf in features:
            aligned_values.append(values[features.index(uf)])
        else:
            aligned_values.append(0.0)
    lime_values_aligned.append(np.array(aligned_values))

lime_values_aligned = np.array(lime_values_aligned)

weights = np.array([0.2, 0.5, 0.3])
weights = weights / np.sum(weights)

weighted_lime_values = np.average(lime_values_aligned, axis=0, weights=weights)

def update_html_with_weighted_values(explanation, unique_features, weighted_lime_values):
    html = explanation.as_html()
    for feature, value in zip(unique_features, weighted_lime_values):
        html = re.sub(r'({})" data-w="[-+]?[0-9]*\.?[0-9]+"'.format(re.escape(feature)), r'\1" data-w="{}"'.format(value), html)
    return html

weighted_html = update_html_with_weighted_values(explanations[0], unique_features, weighted_lime_values)

explanation_name = 'weighted_lime_explanation_'
file_path = os.path.join(EXPLANATION_PATH, explanation_name + datetime.now().strftime("%d_%m_%Y-%H_%M_%S") + '.html')
with open(file_path, 'w') as f:
    f.write(weighted_html)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
