In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import os
import matplotlib.pyplot as plt
from main import Generator3D, generate_noise, upsample_3d

# Checkpoint location and compute device. Use the first GPU when available and
# fall back to CPU so the notebook can still be inspected on a machine without CUDA.
output_dir = "Oxy_3d_demo"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_path = os.path.join(output_dir, 'trained_model_3d.pth')

# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
saved_data = torch.load(model_path, map_location=device)
trained_generators_state_dicts = saved_data['generators_state_dicts']
fixed_noise_maps = saved_data['fixed_noise_maps']
pyramid_shapes = saved_data['pyramid_shapes']
train_opt_dict = saved_data['opt']
pyramid = saved_data['pyramid']

In [None]:
# List the top-level entries stored in the checkpoint.
saved_data.keys()

In [None]:
# Type of a single per-scale generator entry (same object as saved_data['generators_state_dicts'][0]).
type(trained_generators_state_dicts[0])

In [None]:
# Shape of the first pyramid level (`pyramid` aliases saved_data['pyramid']).
pyramid[0].shape

In [None]:
# First recorded pyramid shape (`pyramid_shapes` aliases saved_data['pyramid_shapes']).
pyramid_shapes[0]

In [None]:
# Shape of fixed noise map index 4 (`fixed_noise_maps` aliases saved_data['fixed_noise_maps']).
fixed_noise_maps[4].shape

In [None]:
# Projection (mean over axis 0 after squeezing singleton dims) of pyramid level 4.
# The array shape goes in the title; a bare mid-cell `training_data.shape`
# expression would never be displayed.
training_data = pyramid[4].detach().cpu().numpy()
plt.imshow(training_data.squeeze().mean(0))
plt.title(f"pyramid[4], shape {training_data.shape}: mean over axis 0")
plt.colorbar()

In [None]:
# Rebuild an attribute-style options object from the saved training options dict,
# so it can be passed where the training code expects `opt.<name>` access.
class DummyOpt:
    """Plain attribute container populated from `train_opt_dict`."""
    pass


train_opt = DummyOpt()
for k, v in train_opt_dict.items():
    setattr(train_opt, k, v)
# Default to a single channel only when the checkpoint did not record nc_im;
# overwriting unconditionally would clobber a saved multi-channel setting.
if not hasattr(train_opt, 'nc_im'):
    train_opt.nc_im = 1


In [None]:
def generate_3d_sample(trained_generators_state_dicts,
                       pyramid,
                       opt, device,
                       gen_start_scale=0,
                       custom_noise_shape=None):
    """Draw one random 3D realization by running the trained generator stack.

    Args:
        trained_generators_state_dicts: list of ``Generator3D`` state dicts, one per
            scale. Loop step ``scale_idx`` uses entry ``num_scales - 1 - scale_idx``,
            so entry 0 is used for the first step (``scale_idx == N``) and entry N for
            the last step (``scale_idx == 0``).
        pyramid: list of training volume tensors. ``pyramid[::-1][gen_start_scale]``
            seeds the first step and supplies the starting noise shape.
            ``pyramid[num_scales - 1 - (scale_idx + 1)]`` sets later target sizes.
        opt: training options. Reads ``Generator3D(opt)``, ``opt.scale_factor``, and
            ``opt.nc_im`` when ``custom_noise_shape`` is given.
        device: torch device for the generators and all tensors.
        gen_start_scale: number of leading loop steps to skip; 0 runs every generator.
        custom_noise_shape: optional spatial shape (D, H, W) for the first step,
            overriding the pyramid-derived shape.

    Returns:
        The tensor produced by the generator at the final step (``scale_idx == 0``).
    """
    num_scales = len(trained_generators_state_dicts)

    # Rebuild every generator from its state dict (repeated on each call).
    generators = []
    for state_dict in trained_generators_state_dicts:
        netG = Generator3D(opt).to(device)
        netG.load_state_dict(state_dict)
        netG.eval()  # inference mode
        generators.append(netG)
    # NOTE(review): an earlier comment claimed this list is finest-to-coarsest, but
    # the index mapping in the loop hands entry 0 to the step the code labels
    # "coarsest". Confirm the ordering against the training code.

    # Loop index of the first step to run (N = num_scales - 1 when gen_start_scale == 0).
    start_scale_idx_actual = num_scales - 1 - gen_start_scale

    # Starting noise shape (C, D, H, W).
    if custom_noise_shape:
        # custom_noise_shape is spatial only (D, H, W); the channel dim is prepended here.
        noise_shape = (opt.nc_im,) + tuple(custom_noise_shape)
    else:
        noise_shape = pyramid[::-1][gen_start_scale].shape[1:]  # drop the leading batch dim

    current_noise = generate_noise(noise_shape, device)
    # The first generator's "previous output" is a volume taken from the training
    # pyramid, not zeros.
    # NOTE(review): with gen_start_scale == 0 this is pyramid[-1], while later target
    # sizes are derived from pyramid[0], pyramid[1], ... Both orderings cannot
    # correspond to coarse-to-fine at once. Confirm that this seeding is intended.
    current_vol = pyramid[::-1][gen_start_scale]

    scale_factor_r = opt.scale_factor

    with torch.no_grad():
        # Iterate N, N-1, ..., 0 (starting at start_scale_idx_actual).
        for scale_idx in range(start_scale_idx_actual, -1, -1):
            netG = generators[num_scales - 1 - scale_idx]

            # Upsample the previous step's output.
            prev_vol_upsampled = upsample_3d(current_vol,
                                             scale_factor=scale_factor_r)

            # Spatial target size for this step.
            if custom_noise_shape and scale_idx == start_scale_idx_actual:
                target_size = noise_shape[-3:]
            elif scale_idx < num_scales - 1:
                # Scale up the spatial dims of the pyramid entry used for the previous step.
                coarser_pyramid_idx = num_scales - 1 - (scale_idx + 1)
                coarser_dims = np.array(pyramid[coarser_pyramid_idx].shape[-3:])
                target_dims_float = coarser_dims * scale_factor_r
                target_size = tuple(np.round(target_dims_float).astype(int))
                target_size = tuple(max(1, d) for d in target_size)  # keep every dim >= 1
            else:
                # First step when scale_idx == num_scales - 1: use the starting noise shape.
                target_size = noise_shape[-3:]

            # Resize the upsampled volume exactly to the target size.
            prev_vol_upsampled = F.interpolate(prev_vol_upsampled, size=target_size,
                                               mode='trilinear', align_corners=False)
            # Fresh noise with current_noise's channel count at the target size. This is
            # the same shape the old code obtained by interpolating current_noise and
            # then discarding the interpolated values.
            noise_shape_this_scale = torch.Size(
                [current_noise.shape[1], *(int(d) for d in target_size)])
            noise_this_scale = generate_noise(noise_shape_this_scale, device)

            current_vol = netG(noise_this_scale, prev_vol_upsampled)

            # Noise carried into the next (finer) step, if any.
            if scale_idx > 0:
                # NOTE(review): the channel count is hard-coded to 1 here, while the first
                # step uses noise_shape[0]. These differ when the data has nc_im > 1.
                current_noise = generate_noise((1, *target_size), device)

    return current_vol

In [None]:
# Reference volume used for comparison below; the file name suggests a (76, 88, 114) grid.
original_path = '3d_data_channel_(76, 88, 114).npy'
original = np.load(original_path)
original.shape

In [None]:
# Largest value in the reference volume.
np.max(original)

In [None]:
# Draw a small batch of realizations for quick visual checks. Each one is min-max
# rescaled onto the value range of `original` so both share a colour scale.
orig_min, orig_max = original.min(), original.max()
REAL_1 = []
for _ in range(10):
    sample = generate_3d_sample(trained_generators_state_dicts,
                                pyramid,    # saved training pyramid (shape reference + seed volume)
                                train_opt,  # options from training
                                device,
                                gen_start_scale=0).detach().cpu().numpy()
    sample_range = sample.max() - sample.min()
    if sample_range > 0:
        sample = (sample - sample.min()) / sample_range * (orig_max - orig_min) + orig_min
    else:
        # A constant realization has no range to stretch; 0/0 would fill it with NaNs.
        sample = np.full_like(sample, orig_min)
    REAL_1.append(sample)
REAL_1 = np.array(REAL_1).squeeze()

In [None]:
# Projection (mean over axis 0 after squeezing singleton dims) of the last pyramid
# entry. The array shape goes in the title; a bare mid-cell `training_data.shape`
# expression would never be displayed.
training_data = pyramid[-1].detach().cpu().numpy()
plt.imshow(training_data.squeeze().mean(0))
plt.title(f"pyramid[-1], shape {training_data.shape}: mean over axis 0")
plt.colorbar()

In [None]:
# Reference volume, averaged over axis 0, for side-by-side comparison with the pyramid.
plt.imshow(original.mean(0))
plt.title(f"original, shape {original.shape}: mean over axis 0")
plt.colorbar()

In [None]:
# Difference between two realizations (each averaged over axis 0). A diverging
# colormap only reads correctly when centred on zero, so the limits are symmetric.
realization_diff = REAL_1[1].mean(0) - REAL_1[0].mean(0)
diff_lim = np.abs(realization_diff).max()
plt.imshow(realization_diff, cmap=plt.cm.seismic, vmin=-diff_lim, vmax=diff_lim)
plt.title("Realization 1 - realization 0 (mean over axis 0)")
plt.colorbar()

In [None]:
# Figure height (inches) for a grid of all REAL_1 panels. The layout constants are
# defined here so this cell also runs on a fresh kernel; previously they were only
# set in a later cell, which raised NameError.
img_height = 4
num_of_image_in_a_row = 5
img_height*len(REAL_1)//num_of_image_in_a_row

In [None]:
# Grid of individual realizations, each averaged over axis 0.
img_width = 4
img_height = 4
num_of_image_in_a_row = 5
# Show up to 100 realizations. REAL_1 may hold fewer (e.g. the 10-sample preview),
# and indexing past its end raised IndexError.
len_images = min(100, len(REAL_1))
# Ceil division so a partial last row still gets subplot slots.
n_rows = max(1, -(-len_images // num_of_image_in_a_row))
plt.figure(figsize=(img_width * num_of_image_in_a_row, img_height * n_rows))
for i in range(len_images):
    plt.subplot(n_rows, num_of_image_in_a_row, i + 1)
    plt.imshow(REAL_1[i].mean(0))
    plt.colorbar()

In [None]:
# Grid of (realization - original) differences, each averaged over axis 0.
img_width = 4
img_height = 4
num_of_image_in_a_row = 5

n_images = len(REAL_1)
# Ceil division so a partial last row still fits.
n_rows = max(1, -(-n_images // num_of_image_in_a_row))
original_proj = original.mean(0)
# Width scales with panels per row, not with the total number of images
# (which produced figures thousands of inches wide for 500 samples).
plt.figure(figsize=(img_width * num_of_image_in_a_row, img_height * n_rows))
for i in range(n_images):
    # NOTE(review): [:, 1:] drops the first column, presumably because generated
    # volumes come out one voxel wider than `original` along the last axis. Confirm.
    diff = REAL_1[i].mean(0)[:, 1:] - original_proj
    lim = np.abs(diff).max()
    plt.subplot(n_rows, num_of_image_in_a_row, i + 1)
    # Symmetric limits so the diverging colormap's white point means zero difference.
    plt.imshow(diff, cmap=plt.cm.seismic, vmin=-lim, vmax=lim)
    plt.colorbar()

In [None]:
# TODO(review): this cell originally ended in a dangling "-" (a SyntaxError) and the
# right-hand operand was never written. It shows realization 1 alone until the
# intended comparison is known.
plt.imshow(REAL_1[1].mean(0))

In [None]:
# First realization, averaged over axis 0.
plt.imshow(REAL_1[0].mean(axis=0))

In [None]:
# Large ensemble for pixel-wise statistics. Each realization is min-max rescaled
# onto the value range of `original`.
orig_min, orig_max = original.min(), original.max()
REAL_1 = []
for _ in range(500):
    sample = generate_3d_sample(trained_generators_state_dicts,
                                pyramid,    # saved training pyramid (shape reference + seed volume)
                                train_opt,  # options from training
                                device,
                                gen_start_scale=0).detach().cpu().numpy()
    sample_range = sample.max() - sample.min()
    if sample_range > 0:
        sample = (sample - sample.min()) / sample_range * (orig_max - orig_min) + orig_min
    else:
        # A constant realization has no range to stretch; 0/0 would fill it with NaNs.
        sample = np.full_like(sample, orig_min)
    REAL_1.append(sample)
REAL_1 = np.array(REAL_1).squeeze()

In [None]:
# Pixel-wise ensemble statistics of the realizations after averaging each over axis 1.
# The count in the titles is derived from the data; the old hard-coded
# "1,000" did not match the 500 samples actually drawn.
n_real = len(REAL_1)
axis1_means = REAL_1.mean(1)  # assumes REAL_1 is (n_real, D, H, W) -> (n_real, H, W); confirm
plt.figure(figsize=(15, 5))
plt.subplot(1, 2, 1)
plt.title(f"Mean of {n_real:,} realizations")
plt.imshow(axis1_means.mean(0), cmap='viridis')
plt.colorbar()
plt.subplot(1, 2, 2)
plt.title(f"Standard deviation of {n_real:,} realizations")
plt.imshow(axis1_means.std(0), cmap='jet')
plt.colorbar()