In [None]:
from google.colab import drive

# Mount Google Drive. The rest of the notebook reads/writes project files through the
# relative path 'drive/MyDrive/...' (presumably resolved against Colab's /content cwd).
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
!pip install lime anchor-exp transformers-interpret

Collecting lime
  Downloading lime-0.2.0.1.tar.gz (275 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/275.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m275.7/275.7 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting anchor-exp
  Downloading anchor_exp-0.0.2.0.tar.gz (427 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m427.3/427.3 kB[0m [31m43.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting transformers-interpret
  Downloading transformers_interpret-0.10.0-py3-none-any.whl (45 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m45.8/45.8 kB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
Collecting captum>=0.3.1 (from transformers-interpret)
  Downloading captum-0.7.0-py3-none-any.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [

In [None]:
import os

# ANSI escape sequences used by Logger to colour console output.
COLORS = {
    'red': '\033[91m',
    'green': '\033[92m',
    'yellow': '\033[93m',
    'blue': '\033[94m',
    'magenta': '\033[95m',
    'cyan': '\033[96m',
    'bold': '\033[1m',
    'reset': '\033[0m'
}

# Project root on the mounted Drive; every other path in the notebook is derived from it.
DEMET_PATH = 'drive/MyDrive/core/'
LOGPATH = DEMET_PATH + 'logs/'
# exist_ok=True makes this a no-op when the folder already exists, so re-running the cell is safe.
os.makedirs(LOGPATH, exist_ok=True)
import os
import time

class Logger:
    """Timestamped console logger that can also append every message to a log file.

    Console lines look like ``<color>[HH:MM:SS]: <reset>message``; when ``to_file`` is true,
    ``log``/``print_and_log`` also append ``[HH:MM:SS]: message`` to
    ``LOGPATH/log-<YYYY-mm-dd-HH-MM-SS>.log`` (one file per Logger instance).
    """

    def __init__(self, to_file = False):
        self.to_file = to_file
        # Creation timestamp; also names this instance's log file.
        self.data = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
        self.path = os.path.join(LOGPATH, 'log-' + self.data + '.log')
        self.colors = COLORS

    def __str__(self):
        # __str__ must always return a str; returning None (when to_file is False)
        # made str(logger) raise TypeError.
        if self.to_file:
            return 'Logging to file'
        return 'Logging to console'

    def print(self, message, color = 'reset', bold = False):
        """Print ``message`` to stdout with a coloured ``[HH:MM:SS]:`` prefix.

        ``color`` must be a key of ``COLORS``; ``bold`` additionally prepends the bold code.
        """
        self._print_line(message, self.string_by_time(), color, bold)

    def log(self, message):
        """Append ``message`` to this logger's file (only when ``to_file`` is true)."""
        if self.to_file:
            self._write_line(message, self.string_by_time())

    def print_and_log(self, message, color = 'reset', bold = False):
        """Print ``message`` and, if ``to_file`` is true, append it to the log file.

        A single timestamp is taken so console and file lines always agree.
        """
        stamp = self.string_by_time()
        self._print_line(message, stamp, color, bold)
        if self.to_file:
            self._write_line(message, stamp)

    def string_by_time(self):
        """Return the current local time as ``HH:MM:SS``."""
        return time.strftime('%H:%M:%S', time.localtime())

    def _print_line(self, message, stamp, color, bold):
        # Shared console formatting for print() and print_and_log().
        prefix = (self.colors['bold'] if bold else '') + self.colors[color]
        print(prefix + '[' + stamp + ']:' + ' ' + self.colors['reset'] + message)

    def _write_line(self, message, stamp):
        # Shared file formatting for log() and print_and_log(); the file is opened per write
        # so nothing stays open between messages.
        with open(self.path, 'a') as file:
            file.write('[' + stamp + ']:' + ' ' + message + '\n')

# Notebook-wide logger: echoes to the console and also appends to a timestamped file under LOGPATH.
logger = Logger(to_file=True)

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import RobertaTokenizer, RobertaForSequenceClassification
import torch
import re


# Fine-tuned checkpoint to explain; tokenizer and weights live in the same directory.
TOKENIZER_PATH = DEMET_PATH + 'models/distilbert/distilbert-base-cased_0'
MODEL_PATH = DEMET_PATH + 'models/distilbert/distilbert-base-cased_0'
# Selects the loader classes below and is used as sub-folder / file prefix for saved explanations.
MODEL_NAME = 'distil-bert'
EXPLANATION_PATH = DEMET_PATH + 'explanations/'
DATA_PATH = DEMET_PATH + 'shuffled_data.csv'

_is_roberta = 'roberta' in MODEL_NAME
if not _is_roberta:
    logger.print_and_log('Using BERT model for explanations', 'green')
    TOKENIZER = AutoTokenizer.from_pretrained(TOKENIZER_PATH)
    MODEL = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
else:
    logger.print_and_log('Using RoBERTa model for explanations', 'green')
    # NOTE(review): ignore_mismatched_sizes=True lets loading succeed when a checkpoint layer
    # (e.g. the classification head) has a different shape; such layers are then newly
    # initialised rather than loaded — confirm this is intended.
    MODEL = RobertaForSequenceClassification.from_pretrained(MODEL_PATH, ignore_mismatched_sizes=True)
    TOKENIZER = RobertaTokenizer.from_pretrained(TOKENIZER_PATH)

class ExplainerConfig:
    """Everything an ``Explainer`` needs: model, tokenizer, model name, output root and method.

    Components left as ``None`` fall back to the notebook-level ``MODEL``, ``TOKENIZER``,
    ``MODEL_NAME`` and ``EXPLANATION_PATH`` globals, so ``ExplainerConfig('lime')`` behaves
    exactly as before. Passing them explicitly allows configuring another model (or building a
    config without those globals).

    Args:
        method: explanation method name, e.g. 'lime', 'anchor', 'transformer-interpret'.
        model, tokenizer, model_name, explanation_path: optional overrides of the globals.
    """

    def __init__(self, method, model=None, tokenizer=None, model_name=None, explanation_path=None):
        self.method = method
        # Conditional expressions only touch a global when no override was given.
        self.model = model if model is not None else MODEL
        self.tokenizer = tokenizer if tokenizer is not None else TOKENIZER
        self.model_name = model_name if model_name is not None else MODEL_NAME
        self.explanation_path = explanation_path if explanation_path is not None else EXPLANATION_PATH

    def get_model(self):
        """Return the classification model to explain."""
        return self.model

    def get_model_name(self):
        """Return the model name used in output folder and file names."""
        return self.model_name

    def get_tokenizer(self):
        """Return the tokenizer paired with the model."""
        return self.tokenizer

    def get_explanation_path(self):
        """Return the root folder (with trailing '/') under which explanations are saved."""
        return self.explanation_path

    def get_method(self):
        """Return the explanation method name."""
        return self.method


def lime_tokenizer(text):
    """Split ``text`` into the tokens LIME perturbs.

    Every CHA annotation marker such as ``[CHA SHORT PAUSE]`` is kept whole as one token;
    all other text is broken into runs of word characters, so punctuation and whitespace
    are dropped.

    Returns:
        List of token strings in their order of appearance.
    """
    marker_names = (
        'REPETITION',
        'RETRACING',
        'SHORT PAUSE',
        'MEDIUM PAUSE',
        'LONG PAUSE',
        'TRAILING OFF',
        'PHONOLOGICAL FRAGMENT',
        'INTERPOSED WORD',
        'FILLER',
        'NON COMPLETION OF WORD',
        'BELCHES',
        'HISSES',
        'GRUNTS',
        'WHINES',
        'COUGHS',
        'HUMS',
        'ROARS',
        'WHISTLES',
        'CRIES',
        'LAUGHS',
        'SNEEZES',
        'WHIMPERS',
        'GASPS',
        'MOANS',
        'SIGHS',
        'YAWNS',
        'GROANS',
        'MUMBLES',
        'SINGS',
        'YELLS',
        'GROWLS',
        'PANTS',
        'SQUEALS',
        'VOCALIZES',
        'TRAILING OFF QUESTION',
        'QUESTION WITH EXCLAMATION',
        'INTERRUPTION',
        'INTERRUPTION OF QUESTION',
        'SELFINTERRUPTION',
        'SELFINTERRUPTED QUESTION',
    )
    alternatives = [re.escape('[CHA ' + name + ']') for name in marker_names]
    # Markers come before the generic word pattern so a marker is never split into words.
    alternatives.append(r'\w+')
    return re.findall('|'.join(alternatives), text)


def get_prediction_lime(texts, batch_size=32, model=None, tokenizer=None):
    """Classifier function for LIME: map raw texts to class probabilities.

    Args:
        texts: list of strings; a single str is treated as a one-text batch. LIME hands over
            all perturbed samples in one call (num_samples, 5000 by default per LIME docs),
            which is why inference is chunked.
        batch_size: texts per forward pass. Running every perturbation in a single pass can
            exhaust memory for long (up to 512-token) transcripts.
        model: defaults to the module-level ``MODEL``.
        tokenizer: defaults to the module-level ``TOKENIZER``.

    Returns:
        numpy array with one row of softmax probabilities per input text.

    Raises:
        ValueError: if ``texts`` is empty or ``batch_size`` < 1.
    """
    model = MODEL if model is None else model
    tokenizer = TOKENIZER if tokenizer is None else tokenizer
    if isinstance(texts, str):
        texts = [texts]
    texts = list(texts)
    if not texts:
        raise ValueError('get_prediction_lime needs at least one text')
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1')

    # Hugging Face models expose .device; move inputs there so a GPU-hosted model also works.
    device = getattr(model, 'device', None)
    chunks = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            inputs = tokenizer(batch, padding=True, truncation=True, max_length=512, return_tensors="pt")
            if device is not None:
                inputs = {key: value.to(device) for key, value in inputs.items()}
            outputs = model(**inputs)
            # .cpu() is required before .numpy() when the model runs on an accelerator.
            chunks.append(torch.nn.functional.softmax(outputs.logits, dim=-1).cpu())
    return torch.cat(chunks).numpy()


[92m[21:01:28]: [0mUsing BERT model for explanations


In [None]:
import warnings
# Silence every Python warning category for the rest of the session to keep cell output readable.
warnings.simplefilter("ignore")

from transformers import logging
# NOTE: this `logging` is transformers' logging module and shadows the stdlib name in this notebook.
logging.set_verbosity_error()

from lime.lime_text import LimeTextExplainer
from anchor import anchor_text
from datetime import datetime
from transformers_interpret import SequenceClassificationExplainer
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import CountVectorizer
from sklearn import linear_model
import os
import spacy
import pandas as pd

# Fresh logger for this stage (new timestamped log file, as before).
logger = Logger(True)

# Labelled transcripts: 'text' is the transcript, 'gt' the label (1 is treated as dementia
# in the counting cells below).
df = pd.read_csv(DATA_PATH)
data = df['text'].tolist()
labels = df['gt'].tolist()
# 80/20 train/test split, then 10% of train held out for validation; fixed seed for reproducibility.
train, test, train_labels, test_labels = train_test_split(data, labels, test_size=.2, random_state=42)
train, val, train_labels, val_labels = train_test_split(train, train_labels, test_size=.1, random_state=42)

# Bag-of-words + logistic-regression surrogate. It is the black box explained by the Anchor
# explainer (via predict_lr); LIME and transformers-interpret use the transformer instead.
vectorizer = CountVectorizer(min_df=1)
vectorizer.fit(train)
train_vectors = vectorizer.transform(train)
test_vectors = vectorizer.transform(test)
val_vectors = vectorizer.transform(val)

c = linear_model.LogisticRegression()
c.fit(train_vectors, train_labels)
preds = c.predict(val_vectors)
# Report surrogate quality: warnings are globally silenced, so a badly fitted surrogate
# (e.g. non-convergence) would otherwise go unnoticed.
val_accuracy = sum(int(p == y) for p, y in zip(preds, val_labels)) / max(len(val_labels), 1)
logger.print_and_log('Surrogate LR validation accuracy: %.3f' % val_accuracy, 'cyan')


def predict_lr(texts):
    """Predict class labels with the bag-of-words logistic-regression surrogate.

    Args:
        texts: iterable of raw strings; a single str is treated as one document
            (CountVectorizer.transform rejects a bare string).

    Returns:
        Array of predicted labels, one per document.
    """
    if isinstance(texts, str):
        texts = [texts]
    return c.predict(vectorizer.transform(texts))

class Explainer:
    """Run one explanation method on a text and save the result to disk.

    The method comes from ``ExplainerConfig.get_method()``:
      * 'lime'                  -> LimeTextExplainer over the transformer (``get_prediction_lime``)
      * 'anchor'                -> AnchorText over the logistic-regression surrogate (``predict_lr``)
      * 'transformer-interpret' -> SequenceClassificationExplainer over the configured model
      * 'shap'                  -> accepted but not implemented yet

    Output goes to ``<explanation_path><method>/<model_name>/``; each ``save()`` also appends
    the explanation duration to ``times.txt`` in that folder.

    Raises:
        ValueError: from ``__init__`` when the method name is not one of SUPPORTED_METHODS.
    """

    SUPPORTED_METHODS = ('lime', 'shap', 'anchor', 'transformer-interpret')

    def __init__(self, ExplainerConfig):
        # NOTE: the parameter name shadows the ExplainerConfig class inside this method; it is
        # kept so keyword callers keep working.
        self.method = ExplainerConfig.get_method()
        if self.method not in self.SUPPORTED_METHODS:
            # Fail fast: an unknown name used to build an explainer that silently did nothing.
            raise ValueError('Unknown explanation method %r; expected one of %s'
                             % (self.method, ', '.join(self.SUPPORTED_METHODS)))
        self.tokenizer = ExplainerConfig.get_tokenizer()
        self.model = ExplainerConfig.get_model()
        self.explanation_path = (ExplainerConfig.get_explanation_path()
                                 + self.method + '/'
                                 + ExplainerConfig.get_model_name() + '/')
        self.explanation_name = ExplainerConfig.get_model_name() + '_'
        self.explanation = None
        self.explainer = None
        self.timer_start = None
        self.timer_end = None

        logger.print_and_log(str('Initializing ' + self.method.upper() + ' explainer ...'), 'green')

        if self.method == 'lime':
            self.explainer = LimeTextExplainer(class_names=['Non-Dementia', 'Dementia'], split_expression=lime_tokenizer)

        elif self.method == 'shap':
            # TODO: Implement SHAP
            pass

        elif self.method == 'anchor':
            nlp = spacy.load('en_core_web_sm')
            self.explainer = anchor_text.AnchorText(nlp, ['Non-Dementia', 'Dementia'], use_unk_distribution = True)

        elif self.method == 'transformer-interpret':
            self.explainer = SequenceClassificationExplainer(self.model, self.tokenizer)

    def explain(self, text):
        """Explain ``text`` with the configured method and time it.

        Returns:
            'lime': list of (feature, weight) pairs from ``Explanation.as_list()``;
            'anchor': a human-readable report string;
            'transformer-interpret': the value returned by SequenceClassificationExplainer;
            None for 'shap' or when explaining fails (the error is logged, not raised).
        """
        self.timer_start = datetime.now()
        self.timer_end = None
        # Drop any previous result so a failed run cannot later be saved as this text's explanation.
        self.explanation = None
        logger.print_and_log(str('Explaining with ' + self.method.upper() + ' ...'), 'green')
        try:
            if self.method == 'lime':
                # Request as many features as lime_tokenizer produces tokens (an upper bound on
                # LIME's features); counting text.split(' ') can undercount (e.g. "a,b",
                # newlines) and silently truncate the explanation.
                num_features = max(1, len(lime_tokenizer(text)))
                self.explanation = self.explainer.explain_instance(text, get_prediction_lime, num_features=num_features)
                return self.explanation.as_list()

            elif self.method == 'shap':
                # TODO: Implement SHAP
                pass

            elif self.method == 'anchor':
                # NOTE: Anchor explains the logistic-regression surrogate (predict_lr), not self.model.
                self.explanation = self.explainer.explain_instance(text, predict_lr, verbose=False)
                predicted = int(predict_lr([text])[0])
                pred = self.explainer.class_names[predicted]
                alternative = self.explainer.class_names[1 - predicted]
                temp = ""
                temp += 'Model Predicts %s\n' % pred
                temp += '\n'
                temp += 'Examples where anchor applies and model predicts %s:\n' % pred
                temp += '\n'.join([x[0] for x in self.explanation.examples(only_same_prediction=True)])
                temp += '\n\nExamples where anchor applies and model predicts %s:\n' % alternative
                temp += '\n'.join([x[0] for x in self.explanation.examples(only_different_prediction=True)])
                self.explanation = temp
                return self.explanation

            elif self.method == 'transformer-interpret':
                self.explanation = self.explainer(text, class_name='Dementia')
                return self.explanation

        except Exception as e:
            logger.print_and_log('Error in explaining: ' + str(e), 'red')
        finally:
            # Stop the clock here so the reported duration covers only the explanation itself,
            # not whatever happens between explain() and save().
            self.timer_end = datetime.now()

    def _elapsed(self):
        # Duration of the last explain() call, or None if explain() was never called.
        if self.timer_start is None:
            return None
        end = self.timer_end if self.timer_end is not None else datetime.now()
        return end - self.timer_start

    def save(self):
        """Write the last explanation to ``explanation_path`` and append its duration to times.txt.

        Errors while writing are logged, not raised. Nothing is written when there is no
        explanation (explain() not called or failed), except for the unimplemented 'shap'
        method, which only records the duration as before.
        """
        logger.print_and_log('Saving explanation...', 'green')
        if self.explanation is None and self.method != 'shap':
            logger.print_and_log('Error in saving explanation: nothing to save (explain() was not called or failed)', 'red')
            return
        elapsed = self._elapsed()
        logger.print_and_log('Explanation took: ' + str(elapsed), 'green')
        os.makedirs(self.explanation_path, exist_ok=True)
        stem = self.explanation_path + self.explanation_name + datetime.now().strftime("%d_%m_%Y-%H_%M_%S")
        try:
            if self.method == 'lime':
                self.explanation.save_to_file(stem + '.html')

            elif self.method == 'shap':
                # TODO: Implement SHAP
                pass

            elif self.method == 'anchor':
                with open(stem + '.txt', 'w') as f:
                    f.write(str(self.explanation))
                # NOTE(review): fixed 10 s pause kept from the original; presumably it keeps the
                # second-resolution timestamp of the next saved file unique — confirm before removing.
                time.sleep(10)

            elif self.method == 'transformer-interpret':
                self.explainer.visualize(stem + '.html')

            with open(self.explanation_path + 'times.txt', 'a') as f:
                f.write(str(elapsed) + '\n')

        except Exception as e:
            logger.print_and_log('Error in saving explanation: ' + str(e), 'red')



In [None]:
methods = ['lime']
# Row of the dataset whose transcript is explained; change to explain another sample.
SAMPLE_INDEX = 0

df = pd.read_csv(DATA_PATH)
data = df['text'].to_list()

confs = [ExplainerConfig(name) for name in methods]
explainers = [Explainer(conf) for conf in confs]
for explainer in explainers:
    # Explain a real transcript; the previous call explained an empty string, which yields a
    # meaningless explanation and left the loaded data unused.
    explainer.explain(data[SAMPLE_INDEX])
    explainer.save()

[92m[21:01:33]: [0mInitializing LIME explainer ...
[92m[21:01:33]: [0mExplaining with LIME ...
[92m[21:02:23]: [0mSaving explanation...
[92m[21:02:23]: [0mExplanation took: 0:00:50.355885


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import re

# Reload the labelled transcripts from the single configured location.
df = pd.read_csv(DATA_PATH)

# CHA annotation markers analysed in the plots below (a subset of the markers lime_tokenizer keeps whole).
CHA_TOKENS = [

              '[CHA REPETITION]',
              '[CHA RETRACING]',
              '[CHA SHORT PAUSE]',
              '[CHA MEDIUM PAUSE]',
              '[CHA LONG PAUSE]',
              '[CHA TRAILING OFF]',
              '[CHA PHONOLOGICAL FRAGMENT]',
              '[CHA FILLER]',
              '[CHA NON COMPLETION OF WORD]',
              '[CHA LAUGHS]',
              '[CHA SIGHS]',
              '[CHA TRAILING OFF QUESTION]',
              '[CHA INTERRUPTION]',
              '[CHA INTERRUPTION OF QUESTION]',

              ]

# Total occurrences of each CHA token per label (gt == 1 -> dementia, anything else -> non-dementia).
is_dementia = df['gt'] == 1
dementia_counts = {}
non_dementia_counts = {}
for token in CHA_TOKENS:
    # Non-overlapping literal matches per transcript; missing transcripts count as 0.
    per_text = df['text'].str.count(re.escape(token)).fillna(0)
    dementia_counts[token] = int(per_text[is_dementia].sum())
    non_dementia_counts[token] = int(per_text[~is_dementia].sum())

dementia_df = pd.DataFrame(list(dementia_counts.items()), columns=['token', 'count_dementia'])
non_dementia_df = pd.DataFrame(list(non_dementia_counts.items()), columns=['token', 'count_non_dementia'])

counts_df = pd.merge(dementia_df, non_dementia_df, on='token')

counts_melted = pd.melt(counts_df, id_vars='token', var_name='condition', value_name='count')

# Readable condition names so the legend is derived from the data instead of being relabelled by position.
CONDITION_NAMES = {'count_dementia': 'Dementia', 'count_non_dementia': 'Non-Dementia'}
plot_df = counts_melted.assign(condition=counts_melted['condition'].map(CONDITION_NAMES))

PLOTS_DIR = DEMET_PATH + 'plots/'
os.makedirs(PLOTS_DIR, exist_ok=True)

sns.set_theme(style="whitegrid")
fig, ax = plt.subplots(figsize=(20, 10))
sns.barplot(data=plot_df, x='token', y='count', hue='condition',
            hue_order=['Dementia', 'Non-Dementia'],
            palette={'Dementia': 'red', 'Non-Dementia': 'blue'}, ax=ax)

ax.set_xlabel('CHA Tokens')
ax.set_ylabel('Count')
ax.set_title('CHA Token Counts in Dementia and Non-Dementia Transcripts')
ax.tick_params(axis='x', labelrotation=90)
ax.legend(title='Condition')
fig.tight_layout()
filename = 'CHA_Token_Counts.png'
fig.savefig(PLOTS_DIR + filename)
plt.close(fig)


In [None]:
counts_df['total'] = counts_df['count_dementia'] + counts_df['count_non_dementia']
counts_df['percentage_dementia'] = (counts_df['count_dementia'] / counts_df['total']) * 100
counts_df['percentage_non_dementia'] = (counts_df['count_non_dementia'] / counts_df['total']) * 100

# Loop-invariant pie settings; named so they do not overwrite the dataset-level `labels` list.
PIE_LABELS = ['Dementia', 'Non-Dementia']
PIE_COLORS = ['red', 'blue']
PIE_EXPLODE = (0.1, 0)
pie_plots_dir = DEMET_PATH + 'plots/'
os.makedirs(pie_plots_dir, exist_ok=True)

for _, row in counts_df.iterrows():
    token = row['token']
    if row['total'] == 0:
        # 0/0 gives NaN percentages, which cannot be drawn as a meaningful pie.
        logger.print_and_log('Skipping pie chart for ' + token + ': token never occurs', 'yellow')
        continue
    sizes = [row['percentage_dementia'], row['percentage_non_dementia']]

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.pie(sizes, explode=PIE_EXPLODE, labels=PIE_LABELS, colors=PIE_COLORS, autopct='%1.1f%%', shadow=True, startangle=140)
    ax.set_title(f'Percentage Distribution for {token}')
    ax.axis('equal')
    filename = f'percentage_distribution_{token.strip("[]").replace(" ", "_")}.png'
    fig.savefig(pie_plots_dir + filename)
    # Close each figure so the loop does not accumulate open figures.
    plt.close(fig)

In [None]:
# Per-transcript count of every CHA token: one list per token, one entry per transcript.
# Missing transcripts count as 0 instead of crashing re.findall.
transcript_counts = {
    token: df['text'].str.count(re.escape(token)).fillna(0).astype(int).tolist()
    for token in CHA_TOKENS
}

co_occurrence_df = pd.DataFrame(transcript_counts)

# Pearson correlation of token counts across transcripts; tokens whose count never varies
# give NaN entries, which the heatmap leaves blank.
corr_matrix = co_occurrence_df.corr()

heatmap_dir = DEMET_PATH + 'plots/'
os.makedirs(heatmap_dir, exist_ok=True)

fig, ax = plt.subplots(figsize=(20, 20))
sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', center=0, ax=ax)
ax.tick_params(axis='x', labelrotation=90)
ax.tick_params(axis='y', labelrotation=0)
ax.set_title('Correlation Matrix of CHA Token Co-occurrences', fontsize=16)
fig.tight_layout()
filename = 'token_co-occurance_matrix.png'
fig.savefig(heatmap_dir + filename, dpi=300)
plt.close(fig)