In [1]:
import os
import sys
import json
import pickle
import matplotlib.pyplot as plt
import pandas as pd

# Move from the notebooks/ folder up to the repository root (everything in the
# current path before the first 'notebooks') so the project-local modules
# imported below (data, eval, models) resolve. When the path contains no
# 'notebooks' component, parent_path is the current directory itself.
parent_path = os.getcwd().partition('notebooks')[0]  # zeosyn_gen
os.chdir(parent_path)
print('Switched directory to:', os.getcwd())

import torch
import data.utils as utils
sys.modules['utils'] = utils # Way to get around relative imports in utils for ZeoSynGen_dataset # https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory
from eval import load_model, get_prediction_and_ground_truths, eval_zeolite_aggregated, eval_zeolite_osda, eval_single_system, get_metric_dataframes
from data.metrics import maximum_mean_discrepancy, wasserstein_distance
from models.diffusion import *

Switched directory to: /home/jupyter/Elton/Zeolites/zeosyn_gen


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run to evaluate: model family, run name/version, and data split.
model_type, fname, split = 'diff', 'v6', 'system'

In [3]:
def eval_diff(model_type, fname, split, cond_scale):
    """Evaluate one trained run at a single conditioning scale.

    Loads the model and its configs with ``load_model(model_type, fname, split)``,
    generates predictions and ground truths at ``cond_scale``, then computes the
    zeolite-aggregated and zeolite-OSDA metrics.

    Args:
        model_type: Model family of the run (e.g. ``'diff'``).
        fname: Run name / version (e.g. ``'v6'``).
        split: Data split the run was trained on (e.g. ``'system'``).
        cond_scale: Conditioning scale forwarded to
            ``get_prediction_and_ground_truths``.

    Returns:
        dict with keys ``'mmd_zeo_agg'``, ``'wsd_zeo_agg'``, ``'mmd_zeo_osda'``
        and ``'wsd_zeo_osda'``, holding the objects returned by
        ``eval_zeolite_aggregated`` and ``eval_zeolite_osda`` (presumably
        pandas DataFrames, judging by the ``_df`` naming; confirm in ``eval``).
        Returning them lets a sweep over ``cond_scale`` be compared directly
        instead of only read off the printed output.
    """
    # load_model returns its own configs. Any configs.json / loss-list reads here
    # would be overwritten or unused, so they are not done.
    model, configs = load_model(model_type, fname, split)
    syn_pred, syn_pred_scaled, syn_true, syn_true_scaled, dataset = get_prediction_and_ground_truths(
        model, configs, cond_scale=cond_scale
    )

    mmd_zeo_agg_df, wsd_zeo_agg_df = eval_zeolite_aggregated(
        syn_pred, syn_pred_scaled, syn_true, syn_true_scaled, dataset, configs
    )
    mmd_zeo_osda_df, wsd_zeo_osda_df = eval_zeolite_osda(
        syn_pred, syn_pred_scaled, syn_true, syn_true_scaled, dataset, configs
    )

    return {
        'mmd_zeo_agg': mmd_zeo_agg_df,
        'wsd_zeo_agg': wsd_zeo_agg_df,
        'mmd_zeo_osda': mmd_zeo_osda_df,
        'wsd_zeo_osda': wsd_zeo_osda_df,
    }

In [4]:
# Sweep the conditioning scale while the evaluated run (model_type/fname/split)
# stays fixed; a blank line separates each scale's printed report.
for cond_scale in (0.75, 1, 1.25):
    print('cond_scale: ', cond_scale)
    eval_diff(model_type=model_type, fname=fname, split=split, cond_scale=cond_scale)
    print()

cond_scale:  0.75
Loading model and configs...
Loading model at step 334611...
Getting model predictions and grouth truths...
SYSTEMS:
train+val: 1856 test: 464

n_datapoints:
train: 14749 val: 2107 test: 5168
Loading synthetic predictions from saved predictions...
Calculating metrics for zeolite-aggregated systems...


  prec_zeo_agg_df_mean, rec_zeo_agg_df_mean = prec_zeo_agg_df.mean(0), rec_zeo_agg_df.mean(0)


Si/Al_prec         0.640047
Al/P_prec          0.767220
Si/Ge_prec         0.764924
Si/B_prec          0.771434
Na/T_prec          0.794168
K/T_prec           0.934929
OH/T_prec          0.712459
F/T_prec           0.846071
H2O/T_prec         0.568139
sda1/T_prec        0.570906
cryst_temp_prec    0.589103
cryst_time_prec    0.459325
dtype: float64 Si/Al_rec         0.763083
Al/P_rec          0.904820
Si/Ge_rec         0.873963
Si/B_rec          0.873416
Na/T_rec          0.939690
K/T_rec           0.972721
OH/T_rec          0.849359
F/T_rec           0.948728
H2O/T_rec         0.868570
sda1/T_rec        0.831445
cryst_temp_rec    0.812534
cryst_time_rec    0.720802
dtype: float64                    0
Si/Al       0.701565
Al/P        0.836020
Si/Ge       0.819443
Si/B        0.822425
Na/T        0.866929
K/T         0.953825
OH/T        0.780909
F/T         0.897399
H2O/T       0.718354
sda1/T      0.701175
cryst_temp  0.700818
cryst_time  0.590063
Mean MMD: 1.7345713514548082
Mean WSD

  prec_zeo_osda_df_mean, rec_zeo_osda_df_mean = prec_zeo_osda_df.mean(0), rec_zeo_osda_df.mean(0)


Si/Al_prec         0.562948
Al/P_prec          0.716012
Si/Ge_prec         0.825279
Si/B_prec          0.831382
Na/T_prec          0.719876
K/T_prec           0.913810
OH/T_prec          0.671711
F/T_prec           0.822657
H2O/T_prec         0.552812
sda1/T_prec        0.493103
cryst_temp_prec    0.493367
cryst_time_prec    0.319017
dtype: float64 Si/Al_rec         0.799414
Al/P_rec          0.895380
Si/Ge_rec         0.904608
Si/B_rec          0.901778
Na/T_rec          0.930131
K/T_rec           0.972036
OH/T_rec          0.878780
F/T_rec           0.930554
H2O/T_rec         0.837363
sda1/T_rec        0.812569
cryst_temp_rec    0.790033
cryst_time_rec    0.672406
dtype: float64                    0
Si/Al       0.681181
Al/P        0.805696
Si/Ge       0.864943
Si/B        0.866580
Na/T        0.825004
K/T         0.942923
OH/T        0.775245
F/T         0.876605
H2O/T       0.695088
sda1/T      0.652836
cryst_temp  0.641700
cryst_time  0.495712
Mean MMD: 1.9682576438432098
Mean WSD

  prec_zeo_agg_df_mean, rec_zeo_agg_df_mean = prec_zeo_agg_df.mean(0), rec_zeo_agg_df.mean(0)


Si/Al_prec         0.649864
Al/P_prec          0.776353
Si/Ge_prec         0.760426
Si/B_prec          0.767836
Na/T_prec          0.790147
K/T_prec           0.929484
OH/T_prec          0.703016
F/T_prec           0.834369
H2O/T_prec         0.564272
sda1/T_prec        0.566778
cryst_temp_prec    0.581449
cryst_time_prec    0.483260
dtype: float64 Si/Al_rec         0.781080
Al/P_rec          0.877334
Si/Ge_rec         0.869579
Si/B_rec          0.887989
Na/T_rec          0.909655
K/T_rec           0.967097
OH/T_rec          0.822863
F/T_rec           0.932516
H2O/T_rec         0.794547
sda1/T_rec        0.770652
cryst_temp_rec    0.752170
cryst_time_rec    0.657510
dtype: float64                    0
Si/Al       0.715472
Al/P        0.826843
Si/Ge       0.815003
Si/B        0.827912
Na/T        0.849901
K/T         0.948290
OH/T        0.762940
F/T         0.883442
H2O/T       0.679410
sda1/T      0.668715
cryst_temp  0.666809
cryst_time  0.570385
Mean MMD: 1.8946683590228741
Mean WSD

  prec_zeo_osda_df_mean, rec_zeo_osda_df_mean = prec_zeo_osda_df.mean(0), rec_zeo_osda_df.mean(0)


Si/Al_prec         0.582355
Al/P_prec          0.725133
Si/Ge_prec         0.828740
Si/B_prec          0.835034
Na/T_prec          0.726822
K/T_prec           0.918465
OH/T_prec          0.680348
F/T_prec           0.820122
H2O/T_prec         0.558794
sda1/T_prec        0.499567
cryst_temp_prec    0.500424
cryst_time_prec    0.330599
dtype: float64 Si/Al_rec         0.783261
Al/P_rec          0.852827
Si/Ge_rec         0.896742
Si/B_rec          0.898004
Na/T_rec          0.887043
K/T_rec           0.953972
OH/T_rec          0.820618
F/T_rec           0.918058
H2O/T_rec         0.770895
sda1/T_rec        0.733232
cryst_temp_rec    0.735166
cryst_time_rec    0.598519
dtype: float64                    0
Si/Al       0.682808
Al/P        0.788980
Si/Ge       0.862741
Si/B        0.866519
Na/T        0.806933
K/T         0.936218
OH/T        0.750483
F/T         0.869090
H2O/T       0.664844
sda1/T      0.616400
cryst_temp  0.617795
cryst_time  0.464559
Mean MMD: 2.1385311862473846
Mean WSD

  prec_zeo_agg_df_mean, rec_zeo_agg_df_mean = prec_zeo_agg_df.mean(0), rec_zeo_agg_df.mean(0)


Si/Al_prec         0.659204
Al/P_prec          0.779785
Si/Ge_prec         0.760760
Si/B_prec          0.766683
Na/T_prec          0.788170
K/T_prec           0.927839
OH/T_prec          0.700155
F/T_prec           0.834427
H2O/T_prec         0.554993
sda1/T_prec        0.568313
cryst_temp_prec    0.581438
cryst_time_prec    0.477382
dtype: float64 Si/Al_rec         0.747125
Al/P_rec          0.866614
Si/Ge_rec         0.849836
Si/B_rec          0.855001
Na/T_rec          0.873707
K/T_rec           0.968767
OH/T_rec          0.795296
F/T_rec           0.945345
H2O/T_rec         0.749901
sda1/T_rec        0.746788
cryst_temp_rec    0.747634
cryst_time_rec    0.611949
dtype: float64                    0
Si/Al       0.703164
Al/P        0.823200
Si/Ge       0.805298
Si/B        0.810842
Na/T        0.830939
K/T         0.948303
OH/T        0.747726
F/T         0.889886
H2O/T       0.652447
sda1/T      0.657550
cryst_temp  0.664536
cryst_time  0.544666
Mean MMD: 1.9840838267252996
Mean WSD

  prec_zeo_osda_df_mean, rec_zeo_osda_df_mean = prec_zeo_osda_df.mean(0), rec_zeo_osda_df.mean(0)
