## CAKE experiment on HoC

In [2]:
import yake
import numpy as np
from numpy.linalg import norm
import re
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, average_precision_score
from dataset import Dataset
from myModel import MyModel, MyDataset
from myExplainers import MyExplainer
from myEvaluation import MyEvaluation
import pickle
from scipy.special import softmax
from tqdm import tqdm
import datetime
import csv
import warnings
import torch
import tensorflow as tf
from helper import print_results
from cake import CAKE

Load model, data and task

In [5]:
data_path = ''
model_path = 'Trained Models/'
save_path = '/home/myloniko/ethos/Results/HoC/'

In [6]:
model_name = 'bert'
existing_rationales = True

In [None]:
task = 'multi_label'
sentence_level = True
labels = 10

model = MyModel(model_path, 'bert_hoc2', model_name, task, labels, False, False)
max_sequence_len = model.tokenizer.max_len_single_sentence
tokenizer = model.tokenizer
torch.cuda.is_available()
model.trainer.model.to('cuda')

In [8]:
hoc = Dataset(path = data_path)
x, y, label_names, rationales = hoc.load_hoc()

Split data

In [9]:
indices = np.arange(len(y))
train_texts, test_texts, train_labels, test_labels, _, test_indexes = train_test_split(x, y, indices, test_size=.2, random_state=42)
if existing_rationales:
    test_rationales = [rationales[x] for x in test_indexes]

size = (0.1 * len(y)) / len(train_labels)
train_texts, validation_texts, train_labels, validation_labels = train_test_split(list(train_texts), train_labels, test_size=size, random_state=42)
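The `size` passed to the second split is measured against the remaining training pool, not the full dataset. It is chosen so that the validation set is about 10% of all documents. Below is a minimal sketch of that arithmetic. It assumes an 80/20 first split as above, and `relative_val_size` is a hypothetical helper that is not part of the notebook.

```python
def relative_val_size(n_total: int, n_pool: int, val_frac: float = 0.1) -> float:
    """Fraction of the post-test-split pool to hold out for validation,
    chosen so that validation is about `val_frac` of the whole dataset."""
    return (val_frac * n_total) / n_pool


# Hypothetical example with 1000 documents: the first split keeps 800 (80%) for training.
n_total = 1000
n_pool = 800
size = relative_val_size(n_total, n_pool)
print(size)  # 0.125 of the pool, i.e. about 100 documents

# Resulting proportions: about 70% train, 10% validation, 20% test.
print(round((1 - size) * n_pool), round(size * n_pool), n_total - n_pool)
```

With these sizes the overall split is roughly 70/10/20 (train/validation/test).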

In [10]:
test_label_rationales = []
for test_rational in test_rationales:
    label_rationales = []
    for label in range(labels):
        label_rationales.append([])
    for sentence in test_rational:
        for label in range(labels):
            if label_names[label] in sentence:
                label_rationales[label].append(1)
            else:
                label_rationales[label].append(0)
    test_label_rationales.append(label_rationales)
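For every test document, the loop above builds one 0/1 vector per label that marks which sentences carry that label. Below is a compact equivalent sketch. It assumes each entry of a rationale is a list of label names for one sentence. If an entry were a plain string, the notebook's `in` test would do substring matching instead. `label_rationale_masks` is a hypothetical name.

```python
from typing import List, Sequence


def label_rationale_masks(
    sentence_labels: Sequence[Sequence[str]], label_names: Sequence[str]
) -> List[List[int]]:
    """Turn per-sentence label annotations into one binary mask per label.

    sentence_labels[s] holds the label names annotated on sentence s;
    the result has shape [n_labels][n_sentences], with 1 where that
    sentence is annotated with the label.
    """
    return [
        [1 if name in labels_of_sentence else 0 for labels_of_sentence in sentence_labels]
        for name in label_names
    ]


# Hypothetical 3-sentence document annotated with two labels.
names = ["angiogenesis", "inflammation"]
doc = [["angiogenesis"], [], ["angiogenesis", "inflammation"]]
print(label_rationale_masks(doc, names))  # [[1, 0, 1], [0, 0, 1]]
```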

Define the label descriptions

In [12]:
description =["sustaining proliferative signaling label: refers to the ability of cancer cells to continuously promote their own growth and division through the activation of various signaling pathways. In normal cells, proliferation is tightly regulated by complex signaling networks that include growth factors, receptors, and downstream effectors. However, cancer cells can acquire mutations or alterations in these pathways that allow them to bypass normal regulatory mechanisms and promote their own uncontrolled growth. This hallmark of cancer is often associated with the activation of oncogenes, mutations in tumor suppressor genes, and dysregulation of signaling pathways such as the MAPK/ERK and PI3K/AKT pathways. Targeting these pathways has become an important therapeutic strategy in cancer treatment.",
              "resisting cell death label: refers to the ability of cancer cells to evade programmed cell death (apoptosis) which is a normal process that eliminates damaged or unwanted cells in the body. Cancer cells can acquire mutations or dysregulation in key apoptotic pathways that allow them to survive and continue to proliferate, even in unfavorable conditions. This hallmark is often associated with mutations in tumor suppressor genes such as TP53, and dysregulation of survival pathways such as the PI3K/AKT and NF-κB pathways.",
              "genomic instability and mutation label: refers to the accumulation of genetic alterations and mutations in cancer cells. Cancer cells can acquire mutations in oncogenes, tumor suppressor genes, and DNA repair genes that lead to the loss of normal functions and promote uncontrolled growth and survival. This hallmark is often associated with defects in DNA repair pathways, exposure to mutagens such as radiation or chemicals, and errors in DNA replication or segregation during cell division.",
              "activating invasion and metastasis label: refers to the ability of cancer cells to invade and spread to other tissues and organs. Cancer cells can acquire mutations or alterations in genes that regulate cell adhesion, migration, and invasion, allowing them to penetrate the basement membrane and invade nearby tissues. This hallmark is often associated with the activation of oncogenes such as RAS and EGFR, the loss of tumor suppressor genes such as PTEN and CDH1, and dysregulation of signaling pathways such as the WNT and TGF-β pathways.",
              "evading growth suppressors label: refers to the ability of cancer cells to overcome normal mechanisms that restrain cell growth and proliferation. Normal cells are subject to various checkpoints that ensure proper cell cycle progression and prevent uncontrolled proliferation, but cancer cells can acquire mutations or alterations in genes that bypass these checkpoints and allow them to divide indefinitely. This hallmark is often associated with mutations in tumor suppressor genes such as RB1 and TP53, and dysregulation of signaling pathways such as the CDK and mTOR pathways.",
              "tumor-promoting inflammation label: refers to the role of chronic inflammation in promoting cancer growth and progression. Inflammatory cells and mediators can create a microenvironment that supports the survival and proliferation of cancer cells, as well as promoting angiogenesis, invasion, and metastasis. This hallmark is often associated with chronic infections, autoimmune diseases, and exposure to environmental toxins or pollutants.",
              "inducing angiogenesis label: refers to the ability of cancer cells to stimulate the formation of new blood vessels that supply nutrients and oxygen to the tumor. Cancer cells can secrete pro-angiogenic factors that promote the proliferation and migration of endothelial cells, as well as suppressors of anti-angiogenic factors that normally prevent excessive blood vessel growth. This hallmark is often associated with the activation of oncogenes such as VEGF and FGF, and dysregulation of signaling pathways such as the HIF and Notch pathways.",
              "enabling replicative immortality label: refers to the ability of cancer cells to bypass normal mechanisms that limit the number of times a cell can divide. Normal cells have a limited capacity to divide due to the shortening of telomeres, the protective caps on the ends of chromosomes, but cancer cells can acquire mutations or alterations in genes that maintain or lengthen telomeres, allowing them to divide indefinitely. This hallmark is often associated with the activation of telomerase or alternative lengthening of telomeres (ALT) pathways.",
              "avoiding immune destruction label: refers to the ability of cancer cells to evade recognition and destruction by the immune system. Normally, the immune system is able to detect and eliminate abnormal cells, including cancer cells, through a complex process of immune surveillance. However, cancer cells can develop various strategies to avoid detection and attack by the immune system, such as downregulating the expression of antigens that can be recognized by immune cells, producing immunosuppressive factors, and impairing the function of immune cells themselves. This hallmark of cancer is a major obstacle to the success of cancer immunotherapy, which aims to harness the power of the immune system to fight cancer.",
              "cellular energetics label: refers to the altered metabolic pathways and energy utilization patterns that are characteristic of cancer cells. Cancer cells have a high demand for energy and nutrients to support their uncontrolled growth and proliferation, and they often rely on different metabolic pathways than normal cells to meet these demands. One of the most well-known metabolic alterations in cancer cells is the 'Warburg effect,' which involves a shift toward aerobic glycolysis, even in the presence of sufficient oxygen. This altered metabolism provides cancer cells with a survival advantage and is thought to be involved in various other aspects of cancer progression, such as angiogenesis and metastasis. Targeting cancer cell metabolism has emerged as a promising strategy for cancer therapy."
]
len(description)

10

In [None]:
predictions = []
for test_text in test_texts:
    outputs = model.my_predict(test_text)
    predictions.append(outputs[0])

In [14]:
import tensorflow as tf
a = tf.constant(predictions, dtype = tf.float32)
b = tf.keras.activations.sigmoid(a)
predictions = b.numpy()

# Multi-label decoding: predict every label whose sigmoid probability is >= 0.5
pred_labels = []
for prediction in predictions:
    pred_labels.append([1 if i >= 0.5 else 0 for i in prediction])

# Macro AP and macro F1; note that AP is computed here on the thresholded labels, not on the probabilities
print(average_precision_score(test_labels, pred_labels, average='macro'), f1_score(test_labels, pred_labels, average='macro'))

0.7046940013545056 0.8243528479665017
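The cell above imports TensorFlow only to apply a sigmoid to the logits. The same decoding can be sketched in NumPy alone. With a 0.5 threshold it is equivalent to predicting every label whose logit is at least 0. `decode_multilabel` is a hypothetical helper. Note that scikit-learn's `average_precision_score` is designed for continuous scores (`y_score`). Passing `probs` instead of hard labels would give a ranking-based AP, whereas the value printed above is computed on thresholded labels.

```python
import numpy as np


def decode_multilabel(logits: np.ndarray, threshold: float = 0.5):
    """Apply an element-wise sigmoid to raw logits and threshold each label independently."""
    probs = 1.0 / (1.0 + np.exp(-logits))
    labels = (probs >= threshold).astype(int)
    return probs, labels


# Hypothetical logits for 2 documents and 3 labels.
logits = np.array([[-1.0, 0.0, 2.0],
                   [3.0, -2.0, -0.5]])
probs, labels = decode_multilabel(logits)
print(labels.tolist())  # [[0, 1, 1], [1, 0, 0]]
```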


In [15]:
del x, y, predictions, outputs, validation_labels, validation_texts, hoc

Create a small cake (a CAKE instance)

In [16]:
cake = CAKE(model_path = 'Trained Models/bert_hoc2', tokenizer = tokenizer, label_names = label_names, 
            label_descriptions = description, input_docs = train_texts, input_labels = train_labels, 
            input_docs_test = test_texts)

In [30]:
my_explainers = MyExplainer(label_names, model, sentence_level=True, cake = cake)

my_evaluators = MyEvaluation(label_names, model.my_predict, True, True)
my_evaluatorsP = MyEvaluation(label_names, model.my_predict, True, False)
evaluation = {'F': my_evaluators.faithfulness, 'FTP': my_evaluators.faithful_truthfulness_penalty,
              'NZW': my_evaluators.nzw, 'AUPRC': my_evaluators.auprc}
evaluationP = {'F': my_evaluatorsP.faithfulness, 'FTP': my_evaluatorsP.faithful_truthfulness_penalty,
               'NZW': my_evaluatorsP.nzw, 'AUPRC': my_evaluatorsP.auprc}

In [None]:
confs = []
for key_emb in [1, 2, 3]:
    for label_emb in [1, 2, "2_doc", 3]:
        for keyphrases in [5, 10, 15, 20]: 
            for width in [0, 1, 2, 3, 5]:
                for negatives in [True, False]:
                    confs.append([key_emb, label_emb, keyphrases, width, negatives])
len(confs)
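The nested loops enumerate the Cartesian product of the CAKE hyperparameters: 3 × 4 × 4 × 5 × 2 = 480 configurations. `itertools.product` yields the same list in the same order, because the last factor varies fastest, just as in the innermost loop. The variable names follow the loop variables. What each value means inside CAKE is not documented in this notebook.

```python
from itertools import product

key_embs = [1, 2, 3]
label_embs = [1, 2, "2_doc", 3]
keyphrase_counts = [5, 10, 15, 20]
widths = [0, 1, 2, 3, 5]
negatives_opts = [True, False]

# Same ordering as the nested loops: the rightmost factor changes fastest.
confs_product = [list(c) for c in product(key_embs, label_embs, keyphrase_counts, widths, negatives_opts)]
print(len(confs_product))  # 3 * 4 * 4 * 5 * 2 = 480
print(confs_product[:2])   # [[1, 1, 5, 0, True], [1, 1, 5, 0, False]]
```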

In [None]:
import time
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)
    
    now = datetime.datetime.now()
    file_name = save_path + 'HOC_BERT_CAKE_'+str(now.day) + '_' + str(now.month) + '_' + str(now.year)
    metrics = {'F':[], 'FTP':[], 'AUPRC': [], 'NZW':[]}
    metricsP = {'F':[], 'FTP':[], 'AUPRC': [], 'NZW':[]}
    time_r = []
    for conf in confs:
        time_r.append([])
    techniques = [my_explainers.cake_explain] 
    for ind in tqdm(range(0,len(test_texts))):
        torch.cuda.empty_cache() 
        test_label_rational = test_label_rationales[ind].copy()
        instance = test_texts[ind]
        if len(instance.split('.')) -1 < len(test_label_rational[0]):
            for label in range(labels):
                test_label_rational[label] = test_label_rational[label][:len(instance.split('.'))-1]
        my_evaluators.clear_states()
        my_evaluatorsP.clear_states()
        prediction, _, _ = model.my_predict(instance)
        enc = model.tokenizer([instance,instance], truncation=True, padding=True)[0]
        mask = enc.attention_mask
        tokens = enc.tokens
        
        if tokens.count('.') >= 2:
            interpretations = []
            kk = 0
            for conf in confs:
                ts = time.time()
                if conf[1] == 3:
                    my_explainers.cake_conf = [conf[0], conf[1], ind, conf[2], conf[3], conf[4]]
                else:
                    my_explainers.cake_conf = [conf[0], conf[1], None, conf[2], conf[3], conf[4]]
                temp = techniques[0](instance, prediction, tokens, mask, _, _)
                temp_tokens = tokens.copy()
                if sentence_level:
                    temp_tokens = temp[0].copy()[0]
                    temp = temp[1].copy()
                interpretations.append([np.array(i)/np.max(np.abs(i)) if np.max(np.abs(i))!=0 else np.zeros(len(i)) for i in temp])
                time_r[kk].append(time.time()-ts)
                kk = kk + 1
            for metric in metrics.keys():
                evaluated = []
                for interpretation in interpretations:
                    evaluated.append(evaluation[metric](interpretation, _, instance, prediction, temp_tokens, _, _, test_label_rational))
                metrics[metric].append(evaluated)
            my_evaluatorsP.saved_state = my_evaluators.saved_state.copy()
            my_evaluators.clear_states()
            for metric in metrics.keys():
                evaluatedP = []
                for interpretation in interpretations:
                    evaluatedP.append(evaluationP[metric](interpretation, _, instance, prediction, temp_tokens, _, _, test_label_rational))
                metricsP[metric].append(evaluatedP)
            with open(file_name+' (A).pickle', 'wb') as handle:
                pickle.dump(metrics, handle, protocol=pickle.HIGHEST_PROTOCOL)
            with open(file_name+' (P).pickle', 'wb') as handle:
                pickle.dump(metricsP, handle, protocol=pickle.HIGHEST_PROTOCOL)
            with open(file_name+'_TIME.pickle', 'wb') as handle:
                pickle.dump(time_r, handle, protocol=pickle.HIGHEST_PROTOCOL)
                
time_r = np.array(time_r)
time_r.mean(axis=1)
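Inside the loop, each per-label interpretation is divided by its maximum absolute value. This puts scores from different configurations and labels on the same [-1, 1] scale, and all-zero vectors stay zero. Below is a standalone sketch of that step. `max_abs_normalize` is a hypothetical name, and the empty-input guard is an addition not present in the notebook.

```python
import numpy as np


def max_abs_normalize(scores):
    """Scale a vector of attribution scores into [-1, 1] by its largest absolute value.

    All-zero (or empty) vectors are returned as zeros to avoid dividing by zero.
    """
    scores = np.asarray(scores, dtype=float)
    peak = np.max(np.abs(scores)) if scores.size else 0.0
    if peak == 0:
        return np.zeros_like(scores)
    return scores / peak


print(max_abs_normalize([0.2, -0.8, 0.4]))  # values: 0.25, -1.0, 0.5
print(max_abs_normalize([0.0, 0.0]))        # all zeros
```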

In [None]:
print_results(file_name+'(P)', confs, metricsP, label_names)

## Time analysis

In [18]:
confs = []
for key_emb in [1, 2, 3]:
    for label_emb in [1, 2, 3]:
        for keyphrases in [5, 10, 15, 20]:
            for width in [0, 1, 2, 3]:
                for negatives in [False]:
                    confs.append([key_emb, label_emb, keyphrases, width, negatives])
len(confs)

144

In [None]:
import time
from tqdm.notebook import tqdm

with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)
    
    now = datetime.datetime.now()
    time_r = []
    for conf in confs:
        time_r.append([])
    techniques = [my_explainers.cake_explain] 
    for ind in tqdm(range(10)):
        torch.cuda.empty_cache() 
        test_label_rational = test_label_rationales[ind].copy()
        instance = test_texts[ind]
        if len(instance.split('.')) -1 < len(test_label_rational[0]):
            for label in range(labels):
                test_label_rational[label] = test_label_rational[label][:len(instance.split('.'))-1]
        my_evaluators.clear_states()
        my_evaluatorsP.clear_states()
        prediction, _, _ = model.my_predict(instance)
        enc = model.tokenizer([instance,instance], truncation=True, padding=True)[0]
        mask = enc.attention_mask
        tokens = enc.tokens
        
        if tokens.count('.') >= 2:
            interpretations = []
            kk = 0
            for conf in confs:
                ts = time.time()
                if conf[1] == 3:
                    my_explainers.cake_conf = [conf[0], conf[1], ind, conf[2], conf[3], conf[4]]
                else:
                    my_explainers.cake_conf = [conf[0], conf[1], None, conf[2], conf[3], conf[4]]
                temp = techniques[0](instance, prediction, tokens, mask, _, _)
                temp_tokens = tokens.copy()
                if sentence_level:
                    temp_tokens = temp[0].copy()[0]
                    temp = temp[1].copy()
                aa = [np.array(i)/np.max(np.abs(i)) if np.max(np.abs(i))!=0 else np.zeros(len(i)) for i in temp]
                time_r[kk].append(time.time()-ts)
                kk = kk + 1
time_r = np.array(time_r)
time_r.mean(axis=1)
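The timing loop above measures each configuration with `time.time()` and averages over instances. The sketch below does the same per-configuration averaging with `time.perf_counter()`, which the Python documentation recommends for measuring short durations. It keeps the loop order of the notebook: instances outer, configurations inner. `run` stands in for setting `cake_conf` and calling `cake_explain`. `time_configs` and `fake_explain` are hypothetical placeholders, not CAKE code.

```python
import time
from statistics import mean
from typing import Callable, List, Sequence


def time_configs(
    run: Callable[[Sequence, str], object],
    configs: Sequence[Sequence],
    docs: Sequence[str],
) -> List[float]:
    """Return the mean wall-clock time in seconds of `run(conf, doc)` for each configuration."""
    per_conf: List[List[float]] = [[] for _ in configs]
    for doc in docs:
        for k, conf in enumerate(configs):
            start = time.perf_counter()  # monotonic, high-resolution clock
            run(conf, doc)
            per_conf[k].append(time.perf_counter() - start)
    return [mean(d) if d else float("nan") for d in per_conf]


# Hypothetical stand-in for the explainer: cost grows with the keyphrase count (conf[2]).
def fake_explain(conf, doc):
    return sum(range(conf[2] * 1000))


confs_demo = [[1, 1, 5, 0, False], [1, 1, 20, 0, False]]
means = time_configs(fake_explain, confs_demo, ["doc a", "doc b"])
fastest = min(range(len(confs_demo)), key=means.__getitem__)
print(len(means), confs_demo[fastest])
```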

In [29]:
c = np.array(time_r).reshape((len(confs), -1))  # one row per configuration, one column per timed instance
c.shape

(144, 10)

In [30]:
list(zip(confs,list(c.mean(axis=1))))

[([1, 1, 5, 0, False], 0.4882839202880859),
 ([1, 1, 5, 1, False], 0.4836899995803833),
 ([1, 1, 5, 2, False], 0.47516965866088867),
 ([1, 1, 5, 3, False], 0.4787972211837769),
 ([1, 1, 10, 0, False], 0.5817655086517334),
 ([1, 1, 10, 1, False], 0.9842474937438965),
 ([1, 1, 10, 2, False], 1.005061960220337),
 ([1, 1, 10, 3, False], 1.016555166244507),
 ([1, 1, 15, 0, False], 0.9647884607315064),
 ([1, 1, 15, 1, False], 0.9507078409194947),
 ([1, 1, 15, 2, False], 0.9060809373855591),
 ([1, 1, 15, 3, False], 0.8923252344131469),
 ([1, 1, 20, 0, False], 0.9106181144714356),
 ([1, 1, 20, 1, False], 0.9065115928649903),
 ([1, 1, 20, 2, False], 0.7036923885345459),
 ([1, 1, 20, 3, False], 0.5504607200622559),
 ([1, 2, 5, 0, False], 0.5532642602920532),
 ([1, 2, 5, 1, False], 0.5472845792770386),
 ([1, 2, 5, 2, False], 0.5445155620574951),
 ([1, 2, 5, 3, False], 0.8771285057067871),
 ([1, 2, 10, 0, False], 0.9978608608245849),
 ([1, 2, 10, 1, False], 0.9837345123291016),
 ([1, 2, 10, 2, Fal

## Qualitative example

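We pick one test document (`pid`) and one of its gold labels (`lid`). As the printed label vector shows, label 3 is positive for this document.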

In [23]:
pid = 15  # index of the test document to explain
lid = 3   # index of the label whose explanation we inspect
print(test_labels[pid])
print(test_texts[pid])

[1, 0, 0, 1, 1, 0, 0, 0, 0, 0]
micrornas ( mirnas ) are involved in cancer development and progression , acting as tumor suppressors or oncogenes . in this study , mirna profiling was performed on 10 paired bladder cancer ( bc ) tissues using 20 genechiptm mirna array , and 10 differentially expressed mirnas were identified in bc and adjacent noncancerous tissues of any disease stage grade . after validated on expanded cohort of 67 paired bc tissues and 10 human bc cell lines by qrt pcr , it was found that mir 100 was down regulated most significantly in cancer tissues . ectopic restoration of mir 100 expression in bc cells suppressed cell proliferation and motility , induced cell cycle arrest in vitro , and inhibited tumorigenesis in vivo both in subcutaneous and intravesical passage . bioinformatic analysis showed that mtor gene was a direct target of mir 100, sirna mediated mtor knockdown phenocopied the effect of mir 100 in bc cell lines . in addition , the cancerous metastatic nud

We retrieve CAKE's explanation and keep the keyphrases that receive a positive score for label `lid`.

In [24]:
results = cake.keyphrase_interpretation2(test_texts[pid], 5, 1, 1, 1, False, pid)
[[i,j] for i,j in zip(results[1],results[2][lid]) if j>0]

[['paired bladder cancer', 0.5448872],
 ['cancerous metastatic nude', 0.8914017],
 ['metastatic nude mouse', 0.74588424],
 ['distant metastatic foci', 0.86254436],
 ['tumor metastasis', 0.9092887]]
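The cell above lists the positively scored keyphrases in document order. Below is a small sketch that ranks them by score instead. It assumes, as the `zip` suggests, that `results[1]` holds the keyphrases and `results[2][lid]` their scores for label `lid`. The phrases and scores are copied from the output above, and `top_keyphrases` is a hypothetical helper.

```python
def top_keyphrases(phrases, scores, k=3):
    """Keep keyphrases with a positive score and return the k highest-scoring ones."""
    positive = [(p, s) for p, s in zip(phrases, scores) if s > 0]
    return sorted(positive, key=lambda ps: ps[1], reverse=True)[:k]


phrases = ["paired bladder cancer", "cancerous metastatic nude", "metastatic nude mouse",
           "distant metastatic foci", "tumor metastasis"]
scores = [0.5448872, 0.8914017, 0.74588424, 0.86254436, 0.9092887]
print(top_keyphrases(phrases, scores))
# [('tumor metastasis', 0.9092887), ('cancerous metastatic nude', 0.8914017), ('distant metastatic foci', 0.86254436)]
```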