In [1]:
import os
import sys
import json
import pickle
import matplotlib.pyplot as plt
import pandas as pd

# Switch to parent path to import local module
parent_path = str(os.getcwd()).split('notebooks')[0] # zeosyn_gen
os.chdir(parent_path)
print('Switched directory to:', os.getcwd())

import torch
import data.utils as utils
sys.modules['utils'] = utils # Way to get around relative imports in utils for ZeoSynGen_dataset # https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory
from eval import load_model, get_prediction_and_ground_truths, eval_zeolite_aggregated, eval_zeolite_osda, eval_single_system, get_metric_dataframes
from data.metrics import maximum_mean_discrepancy, wasserstein_distance
from models.diffusion import *

Switched directory to: /home/jupyter/Elton/Zeolites/zeosyn_gen


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Identify the trained run to evaluate; its artifacts live under
# runs/{model_type}/{split}/{fname}/.
model_type, fname, split = 'diff', 'v2', 'system'

In [3]:
def eval_diff(model_type, fname, split, cond_scale):
    """Evaluate one trained run at a single `cond_scale` and return its metric tables.

    Loads the model and its configs for the run identified by
    (`model_type`, `fname`, `split`) via `load_model`, generates predictions and
    ground truths with the given `cond_scale`, then scores them with MMD and
    Wasserstein distance (WSD) at two granularities: zeolite-aggregated systems and
    zeolite-OSDA systems.

    Args:
        model_type: Model family of the run (e.g. ``'diff'``).
        fname: Run name (e.g. ``'v2'``).
        split: Data split the run was trained on (e.g. ``'system'``).
        cond_scale: Forwarded unchanged to `get_prediction_and_ground_truths`.

    Returns:
        dict with keys ``'mmd_zeo_agg'``, ``'wsd_zeo_agg'``, ``'mmd_zeo_osda'`` and
        ``'wsd_zeo_osda'`` holding the tables returned by `eval_zeolite_aggregated`
        and `eval_zeolite_osda` (presumably pandas DataFrames, per the `_df` naming —
        confirm in eval.py). Returning them lets a sweep over `cond_scale` compare
        results programmatically instead of only via printed logs.
    """
    # `load_model` supplies the run's configs, so configs.json (and the unused
    # train/val loss pickles) are not read separately here.
    model, configs = load_model(model_type, fname, split)
    syn_pred, syn_pred_scaled, syn_true, syn_true_scaled, dataset = get_prediction_and_ground_truths(
        model, configs, cond_scale=cond_scale
    )

    eval_args = (syn_pred, syn_pred_scaled, syn_true, syn_true_scaled, dataset, configs)
    mmd_zeo_agg_df, wsd_zeo_agg_df = eval_zeolite_aggregated(*eval_args)
    mmd_zeo_osda_df, wsd_zeo_osda_df = eval_zeolite_osda(*eval_args)

    return {
        'mmd_zeo_agg': mmd_zeo_agg_df,
        'wsd_zeo_agg': wsd_zeo_agg_df,
        'mmd_zeo_osda': mmd_zeo_osda_df,
        'wsd_zeo_osda': wsd_zeo_osda_df,
    }

In [4]:
# Re-evaluate the same trained run at several cond_scale values. Each pass prints
# a header line, the evaluation log, and a blank separator line.
COND_SCALES = (0.75, 1, 1.25)

for cond_scale in COND_SCALES:
    print('cond_scale: ', cond_scale)
    eval_diff(model_type, fname, split, cond_scale=cond_scale)
    print()

cond_scale:  0.75
Loading model and configs...
Getting model predictions and grouth truths...
SYSTEMS:
train+val: 1856 test: 464

n_datapoints:
train: 14749 val: 2107 test: 5168
Loading synthetic predictions from saved predictions...
Calculating metrics for zeolite-aggregated systems...
Mean MMD: 1.6142277946838965
Mean WSD: 0.5239517722854049
Calculating metrics for zeolite-OSDA systems...
Mean MMD: 1.9497980212652555
Mean WSD: 0.4894656485095479

cond_scale:  1
Loading model and configs...
Getting model predictions and grouth truths...
SYSTEMS:
train+val: 1856 test: 464

n_datapoints:
train: 14749 val: 2107 test: 5168
Loading synthetic predictions from saved predictions...
Calculating metrics for zeolite-aggregated systems...
Mean MMD: 1.7617650215442364
Mean WSD: 0.5149444227115215
Calculating metrics for zeolite-OSDA systems...
Mean MMD: 2.042940849898964
Mean WSD: 0.46391971207263494

cond_scale:  1.25
Loading model and configs...
Getting model predictions and grouth truths...
SYS