In [None]:
import cohortintgrad as csig
import pickle, torch, itertools, os, json
import util_xtda_chem, tobacco_util, data_construct
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd

In [None]:
# Presumably builds the categorical-feature (cat_feat) table for the ToBaCCo 1.0
# methane / 100 bar data set.
# NOTE(review): a later cell reads "cat_feat_tobacco_1.0.csv" — confirm this call
# is what writes it.
tobacco_util.cat_feat_ver1(
    tobacco_dir="3rdparty/tobacco_1.0/",
    item="Methane",
    pressure_item=100,
)

In [None]:
# Load the scaled feature array and keep only the first `num_data` samples; every
# later cell works on this subset.
_, x, _, _ = util_xtda_chem.dataload(
    npz_fn='tobacco_1.0_CH4_100bar_x_sc.npy',
    df_fn='tobacco_1.0_CH4_100bar_sc.csv',
)
num_data = 1000
x = x[:num_data]

# NOTE(review): pickle.load executes arbitrary code embedded in the file — only
# load model files produced by a trusted run of this project.
model_fn = 'model_tobacco.pkl'
with open(model_fn, mode="rb") as f:
    rf = pickle.load(f)

# The model is fed flat feature vectors: collapse every non-sample axis.
y = rf.predict(x.reshape(x.shape[0], -1))

In [None]:
# Cohort integrated-gradient attributions for every sample, explained against the
# model predictions `y`. The stacked result is saved to disk so the dump cell can
# read it back by filename.
igcs_cube_fn = 'igcs_tobacco_100bar_cube_predict'
IG = csig.CohortIntGrad(torch.Tensor(x), torch.Tensor(y), ratio=0.01, n_step=50)
ig, rd = IG.igcs_stack(list(range(x.shape[0])))
torch.save(ig, igcs_cube_fn)

In [None]:
# Structure to explain, identified by its CIF id.
cif_id = 4049

# Position of that CIF id among the values of the test-id -> CIF-id map (dict order).
# `.item()` raises unless exactly one entry matches.
with open("map_from_testID_to_cifID_tobacco.json", mode="r") as f:
    map_js = json.load(f)
mapped_cif_ids = np.array(list(map_js.values()))
data_id = np.nonzero(mapped_cif_ids == cif_id)[0].item()

# Plot ranges handed to data_construct.ticks: "one" is used for the H1 panels,
# "two" for the H2 panels; B = max birth, P = max persistence.
onemaxB = 27.0
twomaxB = 27.0
onemaxP = 8.8
twomaxP = 3.5
# Pixels per side of each image (attributions are reshaped to (2, px_size, px_size)).
px_size = 54

In [None]:
# Presumably dumps pixel-by-pixel CS values for the selected structure into
# `cs_dump_dir`, reading the IG cube saved earlier. Uses the configured px_size
# instead of repeating the literal 54.
cs_dump_dir = "tobacco_cat_predict"
os.makedirs(cs_dump_dir, exist_ok=True)
tobacco_util.cs_for_csig_dump_px_by_px_forsingle(
    cif_id=cif_id,
    dir_name=cs_dump_dir,
    igcs_cube_fn="igcs_tobacco_100bar_cube_predict",
    map_js_fn="map_from_testID_to_cifID_tobacco.json",
    cat_df_fn="cat_feat_tobacco_1.0.csv",
    output_stack_fn=f"cs_for_igcs_allstack_predict_001_{cif_id}.npy",
    px_size=px_size,
    # NOTE(review): 1000 equals num_data (the number of samples the IG cube was
    # built from) — confirm they must agree and, if so, pass num_data here.
    cohort_size=1000,
)

In [None]:
# Inputs for the CS plots: the categorical feature table, tick labels for the two
# diagrams, the dumped CS stack for `cif_id`, and the CIF ids of the test split.
cat_df = pd.read_csv("cat_feat_tobacco_1.0.csv").set_index("Unnamed: 0")
xt1, yt1 = data_construct.ticks(
    px_size=px_size, max_birth=onemaxB, max_persistence=onemaxP
)
xt2, yt2 = data_construct.ticks(
    px_size=px_size, max_birth=twomaxB, max_persistence=twomaxP
)

all_np_fn = f"tobacco_cat_predict/cs_for_igcs_allstack_predict_001_{cif_id}.npy"
# The plotting cell reshapes each column to (2, px_size, px_size).
igcs_target = np.load(all_np_fn)

df_fn = "tobacco_1.0_CH4_100bar_sc.csv"
df = pd.read_csv(df_fn).set_index("Unnamed: 0")
# Re-derive the test split with a fixed seed; only the test file names are kept.
# NOTE(review): presumably this must match the split used to train model_tobacco.pkl.
_, test_x, _, _ = train_test_split(
    df["fn"].values, df["adsorption"].values, test_size=0.2, random_state=1018
)

# Assumes file names shaped like "<prefix>-<cif id>.<ext>".
test_x_cif = [int(name.split('-')[1].split('.')[0]) for name in test_x]

# Number of tick intervals along each image axis.
tick_num = len(xt1) - 1
# Symmetric colour limit shared by every panel.
vmax = np.max(np.abs(igcs_target))

In [None]:
# Grid of CS maps: rows are categorical features (positions 2..5 of the target's
# cat_df row), columns are H1 / H2. All panels share the symmetric [-vmax, vmax]
# colour range.
tick_step = px_size // tick_num
tick_positions = range(0, px_size + tick_step, tick_step)
# (x tick labels, y tick labels) for the H1 column and the H2 column.
tick_labels = [(xt1, yt1), (xt2, yt2)]
target_cat = cat_df.loc[test_x_cif[data_id]]

fig, ax = plt.subplots(4, 2, figsize=(13, 26))
for i, j in itertools.product(range(4), range(2)):
    # Slice j of the (2, px_size, px_size) stack with its rows reversed.
    # NOTE(review): combined with invert_yaxis below — confirm the resulting
    # persistence-axis orientation matches the tick labels.
    cs_map = igcs_target[:, i].reshape(2, px_size, px_size)[j, ::-1]
    # .iloc: cat_df columns are string labels, so an integer key must be positional
    # (plain [] relied on pandas' deprecated positional fallback).
    ax[i, j].set_title(
        f"CS of {target_cat.iloc[i + 2]} for cycles in H{j+1}\nsum of attr={np.round(np.sum(cs_map), 2)}",
        fontsize=18,
    )
    im_content = ax[i, j].imshow(
        cs_map,
        cmap="seismic",
        vmin=-vmax,
        vmax=vmax,
    )
    cb = fig.colorbar(im_content, ax=ax[i, j])
    cb.ax.tick_params(labelsize=18)

    ax[i, j].invert_yaxis()
    xt, yt = tick_labels[j]
    ax[i, j].set_xticks(tick_positions)
    ax[i, j].set_yticks(tick_positions)
    ax[i, j].set_xticklabels(xt, fontsize=18, rotation=270)
    ax[i, j].set_yticklabels(yt, fontsize=18)
    ax[i, j].set_xlabel("Birth (angstrom)", fontsize=18)
    ax[i, j].set_ylabel("Persistence (angstrom)", fontsize=18)
fig.tight_layout()

plt.show()

In [None]:
# Direct attribution on the categorical descriptors: cohort kernel SHAP over the
# (template, node1, node2, edge) columns of the first num_data mapped structures.
# list(...) so .loc receives a plain list of labels rather than a dict view.
cat_sub = cat_df.loc[list(map_js.values())].iloc[:num_data]
cat_x = cat_sub[['template', 'node1', 'node2', 'edge']].values
IG = csig.CohortIntGrad(x=cat_x, y=y[:num_data], n_step=50, ratio=0.01)
direct_attr = np.vstack([IG.cohort_kernel_shap(t_id=i) for i in range(num_data)])

# After reset_index the row labels are 0..num_data-1, i.e. row positions in direct_attr.
cat_df_lim = cat_sub.reset_index()
l43_idx = cat_df_lim[cat_df_lim['5'] == 'L_43'].index
# Mean of attribution column 3 (presumably 'edge', the 4th cat_x column) over the
# structures whose column '5' is 'L_43'.
pd.DataFrame(direct_attr[l43_idx])[3].mean()

In [None]:
# Attribution column 3 for the selected CIF within the 'L_43' subset.
l43_rows = cat_df_lim[cat_df_lim['5'] == 'L_43']
l43_attr = pd.DataFrame(direct_attr[l43_rows.index])
target_pos = np.where(l43_rows['0'] == cif_id)[0]
l43_attr.loc[target_pos][3]