In [None]:
import torch
import lib.utils.bookkeeping as bookkeeping
from torch.utils.data import DataLoader
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import lib.models.models as models
import lib.models.model_utils as model_utils
from lib.datasets import mnist, maze, protein, synthetic
import lib.datasets.dataset_utils as dataset_utils
import lib.losses.losses as losses
import lib.losses.losses_utils as losses_utils
import lib.training.training as training
import lib.training.training_utils as training_utils
import lib.optimizers.optimizers as optimizers
import lib.optimizers.optimizers_utils as optimizers_utils
import lib.sampling.sampling as sampling
import lib.sampling.sampling_utils as sampling_utils
import os
from lib.datasets.maze import maze_acc
from ruamel.yaml.scalarfloat import ScalarFloat
def count_turns(labyrinth):
    """Count the cells of a 2-D labyrinth grid at which the path changes direction.

    A cell is counted as a turn when its value is 1 and exactly two of its
    four direct neighbours are also 1, with one of them along the first axis
    (above/below) and the other along the second axis (left/right).  Cells
    whose two neighbours lie opposite each other are part of a straight
    corridor and are not counted; the previous implementation counted those
    as turns as well.

    Args:
        labyrinth: 2-D grid supporting ``.shape`` and ``[i, j]`` indexing
            (e.g. a NumPy array).  Cells equal to 1 are treated as path.

    Returns:
        int: number of turn cells.  The outermost rows and columns are never
        inspected as centre cells, so that every neighbour lookup stays
        inside the grid.
    """
    num_turns = 0
    n_rows, n_cols = labyrinth.shape[0], labyrinth.shape[1]

    # Skip the border so that i +/- 1 and j +/- 1 are always valid indices.
    for i in range(1, n_rows - 1):
        for j in range(1, n_cols - 1):
            if labyrinth[i, j] != 1:
                continue

            # Number of path neighbours along each axis (0, 1 or 2).
            vertical = int(labyrinth[i - 1, j] == 1) + int(labyrinth[i + 1, j] == 1)
            horizontal = int(labyrinth[i, j - 1] == 1) + int(labyrinth[i, j + 1] == 1)

            # One neighbour per axis means the path bends at this cell.
            if vertical == 1 and horizontal == 1:
                num_turns += 1

    return num_turns

In [None]:

# Registry of the trained runs that can be evaluated in this notebook.
# Each entry is (root directory, run date, config file, checkpoint file).
# Earlier notes also mention 'SavedModels/MNIST/' as a possible root directory
# and 'config_001_hollowMLEProb.yaml' as a possible config file.
CHECKPOINT_RUNS = {
    'bert': (
        'SavedModels/MAZE/', '2023-12-30',
        'config_001_bert.yaml', 'model_229999_bert.pt',
    ),
    'lastunet': (
        'SavedModels/MAZEunet/', '2023-12-30',
        'config_001_lastunet.yaml', 'model_399999_lastunet.pt',
    ),
    'hollow8M': (
        'SavedModels/MAZE/', '2023-12-22',
        'config_001_hollow8M.yaml', 'model_299999_hollow8M.pt',
    ),
    'hollowelbo': (
        'SavedModels/MAZE/', '2023-12-28',
        'config_001_hollowelbo.yaml', 'model_259999_hollowelbo.pt',
    ),
    'hollowdirect': (
        'SavedModels/MAZE/', '2024-01-05',
        'config_001_hollowdirect.yaml', 'model_299999_hollowdirect.pt',
    ),
}

# Pick the run to evaluate by changing this single key instead of reordering
# blocks of assignments ('hollow8M' is the run that was previously active).
selected_run = 'hollow8M'
path, date, config_name, model_name = CHECKPOINT_RUNS[selected_run]

# Config and checkpoint live side by side in <path>/<date>/.
config_path = os.path.join(path, date, config_name)
checkpoint_path = os.path.join(path, date, model_name)

In [None]:
# Build the model and sampler described by the saved config, then restore the checkpoint.
cfg = bookkeeping.load_config(config_path)

# Sampler overrides applied on top of the stored config.
# Other sampler names noted here: ExactSampling, ElboLBJF, CRMTauL, CRMLBJF.
cfg.sampler.name = 'TAULStepSize'
cfg.sampler.num_corrector_steps = 0
cfg.sampler.corrector_entry_time = ScalarFloat(0.0)
cfg.sampler.num_steps = 5
cfg.sampler.is_ordinal = False
# print(cfg)

device = torch.device(cfg.device)

model = model_utils.create_model(cfg, device).float()
n_params = sum(p.numel() for p in model.parameters())
print("number of parameters: ", n_params)

# Disabled alternatives kept for reference:
# modified_model_state = utils.remove_module_from_keys(loaded_state['model'])
# model.load_state_dict(modified_model_state)
# optimizer = optimizers_utils.get_optimizer(model.parameters(), cfg)
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.optimizer.lr)

sampler = sampling_utils.get_sampler(cfg)

# load_state receives the freshly built objects together with the checkpoint path.
state = bookkeeping.load_state(
    {"model": model, "optimizer": optimizer, "n_iter": 0},
    checkpoint_path,
)
state['model'].eval()


In [None]:
# Number of samples to draw (alternative noted here: cfg.data.batch_size).
n_samples = 30
samples, changes = sampler.sample(model, n_samples)

# Show the raw `changes` values returned by the sampler, followed by their mean.
print(changes, np.mean(changes), sep="\n")

# Second name for the same samples object; the maze-accuracy cell reads it.
saved_samples = samples

In [None]:
# Disabled sample-grid plot: the code below is wrapped in a bare triple-quoted
# string, so none of it is executed.
# NOTE(review): the disabled loop indexes samples[i] for i < n_samples2 (100),
# while the sampling cell draws n_samples = 30 — adjust before re-enabling.
"""
samples = samples.reshape(n_samples, 1, cfg.data.image_size, cfg.data.image_size)
n_samples2 = 100
saving_train_path = os.path.join(cfg.saving.sample_plot_path, f"{cfg.model.name}{state['n_iter']}_{cfg.sampler.name}{cfg.sampler.num_steps}.png")
fig = plt.figure(figsize=(9, 9)) 
for i in range(n_samples2):
    plt.subplot(int(np.sqrt(n_samples2)), int(np.sqrt(n_samples2)), 1 + i)
    plt.axis("off")
    plt.imshow(np.transpose(samples[i, ...], (1,2,0)), cmap="gray")


#plt.savefig(saving_train_path)
plt.show()
plt.close()
"""

In [None]:
# Evaluate the sampled mazes with maze_acc (imported from lib.datasets.maze).
# NOTE(review): the names suggest the result describes how many samples are valid mazes —
# confirm the exact return value against maze_acc.
correct_mazes = maze_acc(saved_samples)

In [None]:
# Load reference data from the dataset, requesting as many items as were sampled.
cfg.data.name = 'Maze3SComplete'
cfg.data.batch_size = n_samples

if cfg.data.name == 'Maze3SComplete':
    # NOTE(review): cfg.data.limit presumably caps how many mazes the dataset
    # provides — confirm in dataset_utils.
    limit = cfg.data.batch_size
    cfg.data.limit = limit

dataset = dataset_utils.get_dataset(cfg, device)
dataloader = DataLoader(
    dataset,
    batch_size=cfg.data.batch_size,
    shuffle=cfg.data.shuffle,
)

# As before, the values from the last batch yielded by the loader are the ones kept.
for batch in dataloader:
    c_i = maze_acc(batch.cpu().numpy())
    # Flatten to 1-D without assuming the batch holds exactly
    # cfg.data.batch_size items, so a shorter final batch no longer makes
    # the reshape fail.
    true_dl = batch.reshape(-1)

In [None]:
from scipy.stats import wasserstein_distance

# Compare the flattened sample values with the flattened reference values.
samples = samples.reshape(cfg.data.batch_size, -1).flatten()
reference_values = true_dl.cpu().numpy()
emd = wasserstein_distance(samples, reference_values)
print("EMD", emd)


# only EMD: 5000
# Unet:
# TauL: 0.02457 40 %
# LBJF: 0.023745777777777866 45%
# MPTauL:

# Hollow:
# LBJF: 0.0011570370370370675 85%
# TauL: 0.005959999999999965 83%
# Analytical: 0.006370370370369915 83.5%

# Hollow Elbo:
# Analytical: 0.002561481481481387 62% # 
# LBJF: # 0.00424 63%
# TauL: 0.0092 63%

# 30: 0.000552592592592438

# TODO: rerun TauL with 250 or with 500

# ExactSampling: 1500: 85% and EMD  1.22
# LBJF: 93% 1.04

# Explicit Hollow:
# LBJF: 18% EMD 0.005293333333333261
# TauL: 19 % EMD 0.011789629629629594

# Hollow-Transformer:
# 1 Step:
# LBJF: acc 0 und EMD 2: 0 und 0.0470.29 5: 9% 0.009 0.005 0.007: 10: 57, 0.006 0.0053
# TauL: 1: 0 0.2924 2: 0 und 0.196 5: 0 und 0.096
# Exact: 5: 16.5 16.5 % 0.014 0.0155 10: 55% 0.05 0.07


In [None]:
# Recorded EMD values per sampler; the plotting cell draws them against `steps`
# with the labels 'Euler' (lbjf), 'Analytical' (anal) and 'Tau-Leaping' (taul),
# on an axis labelled in units of 1e-3.
emd_lbjf = np.array(
    [3.44, 3.38, 2.97, 2.35, 2.12, 1.75, 1.15]
)
emd_anal = np.array(
    [4.75, 3.54, 3.24, 2.92, 2.65, 1.19, 1.16]
)
emd_taul = np.array(
    [63.58, 36.05, 27.64, 17.66, 12.47, 5.96, 2.28]
)

In [None]:
# Step counts used as the shared x-axis ('NFE') of the result plots.
steps = np.array([10, 20, 30, 50, 100, 250, 500])

# Accuracies were noted in percent and are stored as fractions.
acc_lbjf = np.array([60, 75, 80, 82, 84, 85, 88]) / 100.0
acc_anal = np.array([55, 75, 78, 82, 83, 83, 85]) / 100.0
acc_taul = np.array([3, 30, 50, 73, 80, 82, 84]) / 100.0

# Values plotted as the average rejection rate of the tau-leaping scheme.
rej_rate = np.array(
    [0.1662, 0.0732, 0.0434, 0.0243, 0.0115, 0.003, 0.00084]
)
# Scratch notes: 20: 0.07 30: 50: 0.0434, 0.024, 0.0

In [None]:
# Accuracy of each sampling method as a function of the step count.
fig, ax = plt.subplots()
for series, label in (
    (acc_lbjf, 'Euler'),
    (acc_anal, 'Analytical'),
    (acc_taul, 'Tau-Leaping'),
):
    ax.plot(steps, series, label=label, marker='o')  # 'o' draws circle markers
ax.set_xlabel('NFE', fontsize=13)
ax.set_ylabel('Accuracy', fontsize=13)
ax.set_title('Accuracy of Sampling Methods with Varying NFE', fontsize=14)
ax.legend()
ax.grid(True)
fig.savefig('accuracy_plot.png')
plt.show()

In [None]:
# EMD of each sampling method as a function of the step count.
fig, ax = plt.subplots()
for series, label in (
    (emd_lbjf, 'Euler'),
    (emd_anal, 'Analytical'),
    (emd_taul, 'Tau-Leaping'),
):
    ax.plot(steps, series, label=label, marker='o')  # 'o' draws circle markers
ax.set_xlabel('NFE', fontsize=13)
ax.set_ylabel('EMD (in units of $1\\times 10^{-3}$)', fontsize=13)
ax.set_title('EMD of Sampling Methods with Varying NFE', fontsize=14)
ax.legend()
ax.grid(True)
fig.savefig('emd_plot.png')
plt.show()

In [None]:
# Average rejection rate of the tau-leaping scheme as a function of the step count.
fig, ax = plt.subplots()
ax.plot(steps, rej_rate, marker='o', color='black')
ax.set_ylabel('Average Rejection Rate', fontsize=13)
ax.set_xlabel('NFE', fontsize=13)
ax.set_title('Average Rejection Rate of Tau-Leaping Scheme', fontsize=14)
ax.grid(True)
fig.savefig('rej_rate_plot.png')
plt.show()