In [None]:
import torch
import numpy as np
import os
from DiffusionFreeGuidence.DiffusionCondition import GaussianDiffusionSampler, GaussianDiffusionTrainer
from DiffusionFreeGuidence.ModelCondition import UNet
from Utils import *

# Evaluation configuration for the label-conditioned diffusion model.
# Keys read by the cells in this notebook:
#   - UNet construction: T, channel, channel_mult, num_res_blocks, dropout
#   - GaussianDiffusionSampler: beta_1, beta_T, T, w
#   - initial noise: batch_size, img_size
#   - checkpoint: save_dir joined with test_load_weight, loaded onto `device`
# The remaining keys (state, epoch, lr, multiplier, grad_clip,
# training_load_weight, sampled_dir, sampled*ImgName, nrow) are not read by the
# cells shown here.
modelConfig = dict(
    state="eval",
    epoch=1,
    batch_size=100,
    T=4000,
    channel=128,
    channel_mult=[1, 2, 2, 2],
    num_res_blocks=2,
    dropout=0.15,
    lr=1e-4,
    multiplier=2.5,
    beta_1=1e-4,
    beta_T=0.028,
    img_size=32,
    grad_clip=1.0,
    device="cuda:0",
    w=1.0,
    save_dir="./CheckpointsCondition4000/Model_0/",
    training_load_weight=None,
    test_load_weight="ckpt_99_.pt",
    sampled_dir="./SampledImgs/",
    sampledNoisyImgName="NoisyGuidenceImgs.png",
    sampledImgName="SampledGuidenceImgs.png",
    nrow=8,
)

In [None]:
# Conditioning label used for sampling and PMI estimation below.
# NOTE(review): the plotting cell loads result files for labels 1..10, so labels
# appear to be 1-based here — confirm against the training code.
label: int = 5

In [None]:
# Generate `batch_size` images of the chosen label and save their generation paths.
# Seed the RNG so the initial noise, and hence the saved trajectories, are
# reproducible across kernel restarts (override via modelConfig["seed"]).
torch.manual_seed(modelConfig.get("seed", 0))

device = torch.device(modelConfig["device"])
n_samples = modelConfig["batch_size"]
ckpt_path = os.path.join(modelConfig["save_dir"], modelConfig["test_load_weight"])

# Load the model and sample without tracking gradients.
with torch.no_grad():
    model = UNet(T=modelConfig["T"], num_labels=10, ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"],
                 num_res_blocks=modelConfig["num_res_blocks"], dropout=modelConfig["dropout"]).to(device)
    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt)
    print("model load weight done.")
    model.eval()
    sampler = GaussianDiffusionSampler(
        model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"], w=modelConfig["w"]).to(device)
    # Starting point: standard-normal noise of shape (n_samples, 3, img_size, img_size).
    noisyImage = torch.randn(
        size=[n_samples, 3, modelConfig["img_size"], modelConfig["img_size"]], device=device)
    # One conditioning label per sample, built directly on the target device.
    labels = torch.full((n_samples,), label, dtype=torch.long, device=device)
    # tracking_mode=True makes the sampler also return intermediate quantities;
    # only sampledImgs and x_t_tensor are persisted below.
    sampledImgs, x_hat_tensor, x_t_tensor, snr, alphas_bar = sampler(noisyImage, labels, tracking_mode=True)

# The file-name prefix reflects the actual sample count (100 with the default config).
torch.save(sampledImgs, f'{n_samples}_sampled_images_label_{label}_model_0.pt')
torch.save(x_t_tensor, f'{n_samples}_generation_steps_label_{label}_model_0.pt')

In [None]:
# Estimate PMI along the reverse trajectory of the first 10 generated samples:
# at t = 3750, 3500, ..., 250 and at t = 0 (the final image). Each sample's
# 16-value curve is saved individually, then all curves are stacked and saved.
sampledImgs = torch.load(f'100_sampled_images_label_{label}_model_0.pt')
x_t_tensor = torch.load(f'100_generation_steps_label_{label}_model_0.pt')

checkpoint_steps = range(3750, 249, -250)


def _pmi(x, step):
    """Return get_est(...) for the global `label` at `step` as a NumPy array."""
    return get_est(x, label, model_num=0, t=step, iter=5).cpu().numpy()


per_sample = []
for sample_idx in range(10):
    # NOTE(review): the trajectory is indexed as [step - 1]; the layout of the
    # sampler's tracking output is assumed, not shown here — confirm.
    trajectory = x_t_tensor[sample_idx]
    estimates = [_pmi(trajectory[step - 1], step) for step in checkpoint_steps]
    estimates.append(_pmi(sampledImgs[sample_idx], 0))
    est_list = np.array(estimates)
    per_sample.append(est_list)
    torch.save(est_list, f'PMI_along_generation_label_{label}index{sample_idx}_model_0.pt')

est_list_list = torch.tensor(np.array(per_sample))
torch.save(est_list_list, f'PMI_along_generation_label_{label}_model_0.pt')
print("saved")

In [None]:
# Example code for plotting PMI along the reverse sampling process for all 10 labels.
# NOTE(review): `plt` is not imported in the visible cells; it presumably comes
# from `from Utils import *` — confirm, or import matplotlib.pyplot explicitly.

SAVE_PLOTS = False  # set True to also write each figure to "<title>.png"

# Result files are named with labels 1..10. Dedicated loop names are used so the
# globals `label` and `est_list_list` from earlier cells are not overwritten,
# which previously broke re-running those cells after this one.
est_table = []
for file_label in range(1, 11):
    label_pmi = torch.load('PMI_along_generation_1/PMI_along_generation_label_'+str(file_label)+'_model_0.pt').cpu().numpy()
    est_table.append(label_pmi)

# 16 points: 3750, 3500, ..., 250, 0 — the steps at which PMI was estimated.
timestep = np.linspace(4000-250, 0, 4000//250)

for label_idx, label_pmi in enumerate(est_table):
    fig, ax = plt.subplots()
    # One curve per sample; the first 10 samples are plotted.
    for curve in label_pmi[:10]:
        ax.plot(timestep, curve)
    ax.set_xlabel('timestep')
    ax.set_ylabel('PMI')
    ax.axhline(0, color='black', linewidth=0.5)
    # NOTE(review): the title uses the 0-based list index, while the file label
    # is label_idx + 1 — confirm this is the intended class numbering.
    title = 'PMI along sampling process - Label ' + str(label_idx)
    ax.set_title(title)

    if SAVE_PLOTS:
        fig.savefig(title + '.png')

    plt.show()
    # Release the figure explicitly so figures do not accumulate across iterations.
    plt.close(fig)