# Library

In [None]:
# My library
from molgraph.graphmodel import *

import pickle
import numpy as np
import pandas as pd
import scipy.stats
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
from tqdm import tqdm
import os
from matplotlib.offsetbox import AnchoredText

# SINGLE

In [None]:
# Load the pickled attention weights of a single run into `attention1`.
# The run is located at ./dataset/<file>/<model>_<schema>_<reduced>/<ts>/.
file = 'bbbp'
model = 'GAT'
schema = 'A'
reduced = ['']

ts = "2023-Jan-09-23:29:08"
reduced_list = '_'.join(reduced)
log_folder_name = os.path.join(file, f"{model}_{schema}_{reduced_list}", ts)

path = f"./dataset/{log_folder_name}/attention0.pickle"
with open(path, 'rb') as handle:
    # NOTE: pickle executes arbitrary code on load; only open trusted files.
    attention1 = pickle.load(handle)

In [None]:
# Load the pickled attention weights of a second run into `attention2`.
# The run is located at ./dataset/<file>/<model>_<schema>_<reduced>/<ts>/.
file = 'bbbp'
model = 'GIN'
schema = 'R'
reduced = ['functional']

ts = "2022-Nov-19-05:42:56"
reduced_list = '_'.join(reduced)
log_folder_name = os.path.join(file, f"{model}_{schema}_{reduced_list}", ts)

path = f"./dataset/{log_folder_name}/attention.pickle"
with open(path, 'rb') as handle:
    # NOTE: pickle executes arbitrary code on load; only open trusted files.
    attention2 = pickle.load(handle)

In [None]:
# For every top-level key of attention1, collect the attention values of both
# runs and store the Pearson coefficient between the first two value lists.
# scatter_all[0] / scatter_all[1] pool all values of attention1 / attention2.
# presumably the top-level keys are molecule identifiers (SMILES) — verify.
attention_pearson = []
scatter_all = [[], []]
for mol_key in tqdm(attention1):
    compare = []
    for column, source in enumerate((attention1, attention2)):
        for node_weights in source[mol_key].values():
            values = list(node_weights.values())
            compare.append(values)
            scatter_all[column].extend(values)
    # Only the first two collected lists are correlated.
    attention_pearson.append(scipy.stats.pearsonr(compare[0], compare[1])[0])

In [None]:
# Histogram of the per-entry correlation coefficients in attention_pearson,
# with a vertical line at their mean.
figure(figsize=(8, 5), dpi=80)
fig = sns.histplot(data=attention_pearson, bins=np.arange(-1.1, 1.1, 0.05))  # histplot returns the Axes
# nanmean: a NaN coefficient (e.g. from constant input) would otherwise turn
# the mean into NaN and the marker line would not be drawn at a usable place.
fig.axvline(x=np.nanmean(attention_pearson), linewidth=1, color='black')
fig.set_xlim([-1, 1])
# The coefficients were computed with scipy.stats.pearsonr, so label them as such.
fig.set_xlabel('pearson')

In [None]:
# Scatter the pooled values of the first series against the second one,
# both axes clipped to [0, 1].
figure(figsize=(8, 5), dpi=80)
fig = sns.scatterplot(x=scatter_all[0], y=scatter_all[1])  # scatterplot returns the Axes
fig.set(xlim=(0, 1), ylim=(0, 1))
fig.set_xlabel('atom')
fig.set_ylabel('reduced')

# COMBINE

In [None]:
# Load the pickled attention weights of a combined-schema run into `attention`.
# The run is located at ./dataset/<file>/<model>_<schema>_<reduced>/<ts>/.
file = 'bbbp'
model = 'GIN'
schema = 'AR_0'
reduced = ['pharmacophore']

ts = "2023-Apr-29-17:22:22"
reduced_list = '_'.join(reduced)
log_folder_name = os.path.join(file, f"{model}_{schema}_{reduced_list}", ts)

path = f"./dataset/{log_folder_name}/attention1.pickle"
with open(path, 'rb') as handle:
    # NOTE: pickle executes arbitrary code on load; only open trusted files.
    attention = pickle.load(handle)

In [None]:
# For every top-level key of `attention`, gather the value list of each graph
# stored under it, pool the values per graph name in scatter_all, and record
# the Spearman coefficient between the first two graphs.
# (The list keeps the name attention_pearson although it holds Spearman values.)
attention_pearson = []
scatter_all = {}
for mol_key in tqdm(attention):
    graphs = attention[mol_key]
    # Entries without exactly two graphs are printed for inspection.
    if len(graphs) != 2:
        print(mol_key, graphs)
    compare = []
    for graph_name, node_weights in graphs.items():
        pooled = scatter_all.setdefault(graph_name, [])
        values = list(node_weights.values())
        compare.append(values)
        pooled.extend(values)
    r = scipy.stats.spearmanr(compare[0], compare[1])[0]
    attention_pearson.append(r)

In [None]:
# Histogram of the coefficients in attention_pearson; the NaN-ignoring mean is
# drawn as a vertical line, printed, and shown in an anchored text box.
figure(figsize=(8, 5), dpi=150)
fig = sns.histplot(data=attention_pearson, bins=np.arange(-1.1, 1.1, 0.1))  # histplot returns the Axes
mean_corr = np.nanmean(attention_pearson)
fig.axvline(x=mean_corr, linewidth=1, color='black')
fig.set_xlim([-1, 1])
fig.set_xlabel('spearmanr')
# Title suffix is the upper-cased first letter of the first reduced-graph name.
fig.set_title('Attention Correlation - ' + 'A+' + reduced[0][0].upper())
print(len(attention_pearson))
print(mean_corr)
anchored_text = AnchoredText(
    "AVG = {:.4f}".format(mean_corr),
    loc='upper right',
    prop=dict(size='small'),
)
fig.add_artist(anchored_text)

In [None]:
# Scatter the pooled values of the first graph in scatter_all against the
# second one (dict insertion order), both axes clipped to [0, 1].
figure(figsize=(8, 5), dpi=80)
pooled_series = list(scatter_all.values())
fig = sns.scatterplot(x=pooled_series[0], y=pooled_series[1])  # scatterplot returns the Axes
fig.set(xlim=(0, 1), ylim=(0, 1))
fig.set_xlabel('atom')
fig.set_ylabel('reduced')

# CROSS

In [None]:
# Cross-run comparison: 'atom' values from attention1 versus 'substructure'
# values from attention, one Spearman coefficient per top-level key.
# NOTE(review): the loader cell above sets reduced = ['pharmacophore'];
# confirm `attention` really contains a 'substructure' graph when this runs.
attention_pearson = []
scatter_all = [[], []]
for mol_key in tqdm(attention):
    compare = []
    atom_values = list(attention1[mol_key]['atom'].values())
    compare.append(atom_values)
    scatter_all[0].extend(atom_values)
    reduced_values = list(attention[mol_key]['substructure'].values())
    compare.append(reduced_values)
    scatter_all[1].extend(reduced_values)
    attention_pearson.append(scipy.stats.spearmanr(compare[0], compare[1])[0])

In [None]:
# Histogram of the Spearman coefficients in attention_pearson, with a vertical
# line at their mean.
figure(figsize=(8, 5), dpi=80)
fig = sns.histplot(data=attention_pearson, bins=np.arange(-1.1, 1.1, 0.1))  # histplot returns the Axes
# nanmean (as in the combined-schema histogram): spearmanr yields NaN for
# constant input, and a single NaN would make np.mean return NaN.
fig.axvline(x=np.nanmean(attention_pearson), linewidth=1, color='black')
fig.set_xlim([-1, 1])
fig.set_xlabel('spearmanr')

In [None]:
# Scatter the pooled values of the first series against the second one,
# both axes clipped to [0, 1].
figure(figsize=(8, 5), dpi=80)
fig = sns.scatterplot(x=scatter_all[0], y=scatter_all[1])  # scatterplot returns the Axes
fig.set(xlim=(0, 1), ylim=(0, 1))
fig.set_xlabel('atom')
fig.set_ylabel('reduced')

In [None]:
# Cross-run comparison: 'functional' values from attention2 versus
# 'functional' values from attention, one Pearson coefficient per key.
# NOTE(review): the loader cell above sets reduced = ['pharmacophore'];
# confirm `attention` really contains a 'functional' graph when this runs.
attention_pearson = []
scatter_all = [[], []]
for mol_key in tqdm(attention):
    compare = []
    first_values = list(attention2[mol_key]['functional'].values())
    compare.append(first_values)
    scatter_all[0].extend(first_values)
    second_values = list(attention[mol_key]['functional'].values())
    compare.append(second_values)
    scatter_all[1].extend(second_values)
    attention_pearson.append(scipy.stats.pearsonr(compare[0], compare[1])[0])

In [None]:
# Histogram of the per-entry correlation coefficients in attention_pearson,
# with a vertical line at their mean.
figure(figsize=(8, 5), dpi=80)
fig = sns.histplot(data=attention_pearson, bins=np.arange(-1.1, 1.1, 0.1))  # histplot returns the Axes
# nanmean: a NaN coefficient (e.g. from constant input) would otherwise turn
# the mean into NaN and the marker line would not be drawn at a usable place.
fig.axvline(x=np.nanmean(attention_pearson), linewidth=1, color='black')
fig.set_xlim([-1, 1])
# The coefficients were computed with scipy.stats.pearsonr, so label them as such.
fig.set_xlabel('pearson')

In [None]:
# Scatter the pooled values of the first series against the second one,
# both axes clipped to [0, 1].
# NOTE(review): the preceding cell pools 'functional' values from two runs;
# the 'atom' / 'reduced' axis labels may not describe them — confirm.
figure(figsize=(8, 5), dpi=80)
fig = sns.scatterplot(x=scatter_all[0], y=scatter_all[1])  # scatterplot returns the Axes
fig.set(xlim=(0, 1), ylim=(0, 1))
fig.set_xlabel('atom')
fig.set_ylabel('reduced')

# SINGLE ALL

In [None]:
# Load the per-fold attention pickles (attention<f>.pickle) of one run per
# graph type into attention_all, in fold order.
attention_all = []
file = 't04_CYP2C8_533'
model = 'GIN'
schema = 'A'
reduced = []
fold = 5

for f in range(fold):
    s = schema
    # Schema 'A' has no reduced-graph suffix, so its folder name ends in a
    # bare '_'; every other schema is loaded once per entry of `reduced`.
    for r in ([''] if s == 'A' else reduced):
        directory = './dataset/' + os.path.join(file, model + '_' + s + '_' + r)
        # os.walk lists sub-directories in arbitrary order; sort so the same
        # run folder is picked on every execution and for both schema kinds.
        runs = sorted(next(os.walk(directory), (directory, [], []))[1])
        if not runs:
            raise FileNotFoundError(f"no run folder found under {directory!r}")
        ts = runs[0]
        reduced_list = r
        log_folder_name = os.path.join(file, model + '_' + s + '_' + reduced_list, ts)

        path = './dataset/' + log_folder_name + '/attention' + str(f) + '.pickle'
        with open(path, 'rb') as handle:
            # NOTE: pickle executes arbitrary code on load; only open trusted files.
            attention_all.append(pickle.load(handle))

In [None]:
# Sanity check: number of attention dictionaries loaded into attention_all.
len(attention_all)

In [None]:
# Preview: print the entry stored under the first key of every loaded dictionary.
for loaded in attention_all:
    first_key = next(iter(loaded), None)
    if first_key is not None:
        print(loaded[first_key])

In [None]:
from rdkit.Chem import AllChem

fingerprint = []
test_y = []
# One list of per-entry value lists for each loaded dictionary.
attention_result = [[] for _ in attention_all]

# Walk the keys of the first dictionary and, for each loaded dictionary,
# append the value list of every non-empty graph stored under that key.
for mol_key in attention_all[0]:
    for slot, loaded in enumerate(attention_all):
        for node_weights in loaded[mol_key].values():
            if node_weights:
                attention_result[slot].append(list(node_weights.values()))

# NOTE(review): np.transpose converts each nested list to an array first; if
# the per-entry lists differ in length, recent NumPy versions reject ragged
# input — confirm the data is rectangular or the NumPy version in use.
attention_result_t = [np.transpose(rows) for rows in attention_result]

attention_result = attention_result_t

In [None]:
# Sanity check: number of arrays in attention_result (one per loaded dictionary).
len(attention_result)

In [None]:
# Preview: print the length of every array in attention_result and its first row.
for rows in attention_result:
    row_count = len(rows)
    print(row_count)
    if row_count > 0:
        print(rows[0])

In [None]:
# Lower-triangular grid of histograms: panel (r_i, r_j) shows the distribution
# of Spearman coefficients between row ii of attention_result[r_i] and row ii
# of attention_result[r_j].
fig, ax = plt.subplots(len(attention_result), len(attention_result), sharex=True, sharey=True, figsize=(15,15))
label = range(fold)

for r_i in tqdm(range(len(attention_result))):
    for r_j in range(r_i):  # only pairs with r_i > r_j
        attention_pearson = []
        for ii, row_i in enumerate(attention_result[r_i]):
            corr = scipy.stats.spearmanr(row_i, attention_result[r_j][ii])[0]
            attention_pearson.append(corr)

        panel = ax[r_i, r_j]
        pair_label = str(label[r_i]) + '-' + str(label[r_j])
        sns.histplot(data=attention_pearson, bins=np.arange(-1.1, 1.1, 0.05), ax=panel, label=pair_label)
        # nanmean: spearmanr yields NaN for constant input and a single NaN
        # would make np.mean return NaN.
        panel.axvline(x=np.nanmean(attention_pearson), linewidth=1, color='black')
        panel.set_xlim([-1, 1])
        # Label this panel explicitly; plt.xlabel would act on the current
        # axes (the last subplot created), not on the panel just drawn.
        panel.set_xlabel('spearman')
        # Without a legend the label passed to histplot is never displayed.
        panel.set_title(pair_label)

# SINGLE COMPARE ALL

In [None]:
# Load one attention pickle per graph type into attention_all: first the
# schema 'A' run, then one run per entry of `reduced` for every other schema.
attention_all = []
file = 'bbbp'
model = 'GAT'
schema = ['A', 'AR_0']
reduced = ['junctiontree', 'cluster', 'functional', 'pharmacophore', 'substructure']

for s in schema:
    # Schema 'A' has no reduced-graph suffix, so its folder name ends in a
    # bare '_'; every other schema is loaded once per entry of `reduced`.
    for r in ([''] if s == 'A' else reduced):
        directory = './dataset/' + os.path.join(file, model + '_' + s + '_' + r)
        # os.walk lists sub-directories in arbitrary order; sort so the same
        # run folder is picked on every execution and for both schema kinds.
        runs = sorted(next(os.walk(directory), (directory, [], []))[1])
        if not runs:
            raise FileNotFoundError(f"no run folder found under {directory!r}")
        ts = runs[0]
        reduced_list = r
        log_folder_name = os.path.join(file, model + '_' + s + '_' + reduced_list, ts)

        path = './dataset/' + log_folder_name + '/attention.pickle'
        with open(path, 'rb') as handle:
            # NOTE: pickle executes arbitrary code on load; only open trusted files.
            attention_all.append(pickle.load(handle))

In [None]:
# Sanity check: number of attention dictionaries loaded into attention_all.
len(attention_all)

In [None]:
# Preview: print the entry stored under the first key of every loaded dictionary.
for loaded in attention_all:
    first_key = next(iter(loaded), None)
    if first_key is not None:
        print(loaded[first_key])

In [None]:
from rdkit.Chem import AllChem

fingerprint = list()
test_y = list()
# One list of per-entry value lists for each loaded dictionary.
attention_result = [[] for _ in range(len(attention_all))]

for a in attention_all[0]:  # keys of the first loaded dictionary
    for i in range(len(attention_all)):  # every loaded dictionary
        graphs = attention_all[i][a]
        if len(graphs) != 1:
            # Several graphs under one key: sum the first two over the node
            # keys they share. Iterate the first graph's keys (not a set
            # intersection) so the values keep that graph's node order and
            # stay position-aligned with the single-graph rows below.
            dict_list = list(graphs.values())
            first, second = dict_list[0], dict_list[1]
            dictf = {k: first[k] + second[k] for k in first if k in second}
            attention_result[i].append(list(dictf.values()))
        else:
            for g in graphs:
                attention_result[i].append(list(graphs[g].values()))

# NOTE(review): np.transpose converts each nested list to an array first; if
# the per-entry lists differ in length, recent NumPy versions reject ragged
# input — confirm the data is rectangular or the NumPy version in use.
attention_result_t = []
for r in attention_result:
    attention_result_t.append(np.transpose(r))

attention_result = attention_result_t

In [None]:
# Sanity check: number of arrays in attention_result (one per loaded dictionary).
len(attention_result)

In [None]:
# Preview: print the length of every array in attention_result and its first row.
for rows in attention_result:
    row_count = len(rows)
    print(row_count)
    if row_count > 0:
        print(rows[0])

In [None]:
# attention_result

In [None]:
# Lower-triangular grid of histograms: panel (r_i, r_j) shows the distribution
# of Spearman coefficients between row ii of attention_result[r_i] and row ii
# of attention_result[r_j].
fig, ax = plt.subplots(len(attention_result), len(attention_result), sharex=True, sharey=True, figsize=(15,15))
label = ['atom'] + reduced

for r_i in tqdm(range(len(attention_result))):
    for r_j in range(r_i):  # only pairs with r_i > r_j
        attention_pearson = []
        for ii, row_i in enumerate(attention_result[r_i]):
            corr = scipy.stats.spearmanr(row_i, attention_result[r_j][ii])[0]
            attention_pearson.append(corr)

        panel = ax[r_i, r_j]
        pair_label = label[r_i] + '-' + label[r_j]
        sns.histplot(data=attention_pearson, bins=np.arange(-1.1, 1.1, 0.05), ax=panel, label=pair_label)
        # nanmean: spearmanr yields NaN for constant input and a single NaN
        # would make np.mean return NaN.
        panel.axvline(x=np.nanmean(attention_pearson), linewidth=1, color='black')
        panel.set_xlim([-1, 1])
        # Label this panel explicitly; plt.xlabel would act on the current
        # axes (the last subplot created), not on the panel just drawn.
        panel.set_xlabel('spearman')
        # Without a legend the label passed to histplot is never displayed.
        panel.set_title(pair_label)