In [1]:
# ========public pkgs========
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import json

import numpy as np


#========= private pkgs==========================
from load_data import custom_datasets
from load_data import custom_transform

from load_data_h5py import SDFDataset
from SDF_VAE_improved import Encoder,Decoder,SDF_VAE


In [2]:
import numpy as np

def make_symmetric(a_numpy):
    """
    Mirror an N x N x N SDF matrix to create a symmetric 2N x 2N x 2N SDF matrix.

    :param a_numpy: An N x N x N numpy array representing the SDF in positive XYZ space.
    :return: A 2N x 2N x 2N symmetric numpy array representing the full SDF.
    """
    N = a_numpy.shape[0]
    
    a_numpy = np.flip(a_numpy, axis=2)


    a_numpy = np.flip(a_numpy, axis=1)


    a_numpy = np.flip(a_numpy, axis=0)

    # Initialize the full symmetric matrix of size 2N x 2N x 2N
    full_matrix = np.zeros((2 * N, 2 * N, 2 * N))

    # Populate the full matrix by mirroring a_numpy across its dimensions
    # Front-top-left quarter
    full_matrix[:N, :N, :N] = a_numpy
    # Front-top-right quarter
    full_matrix[:N, N:, :N] = a_numpy[:, ::-1, :]
    # Front-bottom-left quarter
    full_matrix[N:, :N, :N] = a_numpy[::-1, :, :]
    # Front-bottom-right quarter
    full_matrix[N:, N:, :N] = a_numpy[::-1, ::-1, :]

    # Back mirrors of the front
    full_matrix[:, :, N:] = full_matrix[:, :, :N][:, :, ::-1]

    # return full_matrix
    return a_numpy


In [3]:
# ==================Load configuration file===============
with open('./config_cluster.json') as f:
    config = json.load(f)

# Extract configuration parameters
batch_size = config['model_params']['batch_size']
latent_dim = config['model_params']['latent_dim']
beta = config['model_params']['beta']
learning_rate = config['train_params']['learning_rate']
epochs = config['train_params']['epochs']
manual_seed = config['random_seed']['manual_seed']
cuda_manual_seed = config['random_seed']['cuda_manual_seed']
loading_checkpoint = config['train_params']['loading_checkpoint']
# Paths from configuration
data_path_train = config['Path']['train_data_path']
data_path_test = config['Path']['test_data_path']
save_path = config['Path']['save_path']
checkpoint_path = config['Path']['log_path']

In [4]:
# Set random seeds for reproducibility
torch.manual_seed(manual_seed)
torch.cuda.manual_seed(cuda_manual_seed)
sdf_dimen = 50
# load test and train data
dataset_train = SDFDataset(data_path_train)
loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)

dataset_train = SDFDataset(data_path_test)
loader_test = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
model = SDF_VAE(input_channels=1, latent_dim=latent_dim, D=sdf_dimen)


In [5]:
# Setup device (GPU/CPU)
if torch.cuda.is_available(): # GPU is available
    if torch.cuda.device_count() > 1:
        print('GPU is available')
        model = nn.DataParallel(model)
    device = torch.device("cuda:0")
    GPU = True
    model.to(device)
else:  # only cpu is available
    print('CPU is only available')
    GPU = False
    device = torch.device("cpu")
    model.to(device)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
)
# Load checkpoint if specified
if loading_checkpoint:
    checkpoint = torch.load(checkpoint_path+'checkpoint.tar')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    current_epoch = checkpoint['epoch']
    loss = checkpoint['loss']
else: # use the initial model and optimiser
    current_epoch = 0

CPU is only available


In [6]:
# Define the loss function
def lossfunc(sdf,sdf_hat,iso,iso_hat,mu,logvar,beta):
    """
    Computes the Variational Autoencoder (VAE) loss function, combining reconstruction loss and KL divergence.
    Args
    Returns:
        torch.Tensor: The computed loss value.
    """
    # print('sdf_hat',sdf_hat.shape)
    # print('sdf',sdf.shape)
    # print('iso_hat',iso_hat.shape)
    # print('iso',iso.shape)


    

    sdf_loss = F.mse_loss(sdf_hat, sdf,reduction = 'mean')
    iso_loss = F.mse_loss(iso_hat, iso,reduction = 'mean')
    recons_loss = sdf_loss+ iso_loss
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # kl_loss = torch.mean(-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim = 1), dim = 0)
    return recons_loss + beta* kl_loss

## reuse the model on cpu machine, 
### the trained model from GPU or parallel gpus, they are different


In [7]:
from collections import OrderedDict

mode_index = '192'
PATH  = './save_model/model_for_resol_50/'
loaded_model = PATH+'/VAEmodel_'+mode_index+'.pt'
def load_model(model, path):
    # Load the state dictionary from the file.
    state_dict = torch.load(path, map_location=torch.device('cpu'))

    # Check if the model was trained and saved using DataParallel or DistributedDataParallel
    # by checking the presence of 'module.' prefix in the state dictionary keys.
    is_multi_gpu_model = any(k.startswith('module.') for k in state_dict.keys())

    if is_multi_gpu_model:
        # If the model was saved with 'module.' prefixes (indicative of DataParallel or DistributedDataParallel usage),
        # create a new state dictionary without these prefixes.
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            # Remove the 'module.' prefix.
            name = k[7:] if k.startswith('module.') else k
            new_state_dict[name] = v
        state_dict = new_state_dict

    # If there are no 'module.' prefixes, the state_dict is assumed to be from a single GPU training
    # and is used without modification.

    # Load the state dictionary into the model.
    model.load_state_dict(state_dict)
    return model

model = load_model(model, PATH+'/VAEmodel_'+mode_index+'.pt')
model.eval()


SDF_VAE(
  (encoder): Encoder(
    (decoder_sdf): Sequential(
      (0): Sequential(
        (0): Conv3d(1, 16, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
        (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.01)
      )
      (1): Sequential(
        (0): Conv3d(16, 32, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
        (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.01)
      )
      (2): Sequential(
        (0): Conv3d(32, 32, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
        (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.01)
      )
      (3): Sequential(
        (0): Conv3d(32, 64, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
        (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine

In [8]:
mu = torch.load('./save_model//model_for_resol_50/mu_list_'+str(mode_index)+'.pt', map_location=device)

In [9]:
# sdf_z,iso_z = model.decoder(mu)
index = 52
sdf_z,iso_z = model.decoder(mu[index,:].unsqueeze(0))

import numpy as np
sdf = make_symmetric(sdf_z.squeeze().detach().numpy())

iso = iso_z.detach().numpy()
path_recon = './data/reconstruct_sdf/'
np.save(path_recon+'sdf'+str(index)+'.npy',sdf)
np.save(path_recon+'iso'+str(index)+'.npy',iso)

decodeed sdf torch.Size([1, 1, 50, 50, 50])
decodeed iso torch.Size([1])


In [12]:
mu_1 = mu[10,:].unsqueeze(0)
mu_2 = mu[80,:].unsqueeze(0)

# mu_1 = mu[45,:].unsqueeze(0)
# mu_2 = mu[88,:].unsqueeze(0)
def interpolate_betweeen_A_and_B(latent_A,latent_B,N):
    ''' inteprolatte of two latent vector 'A' 'B' ,with N stepts
    '''
    code = torch.Tensor(N, 6).to(device)
    for i in range(N):
        code[i] = i / (N - 1) * latent_A + (1 - i / (N - 1) ) * latent_B
    return code

Nbcell = 50
mu_inter = interpolate_betweeen_A_and_B(mu_1, mu_2, Nbcell)

In [13]:
sdf_z,iso_z = model.decoder(mu_inter)
for ii in range(Nbcell):
    sdf = sdf_z[ii].squeeze().detach().cpu().numpy()
    print(sdf.shape)
    iso = iso_z[ii].detach().cpu().numpy()
    np.save(path_recon+'sdf'+str(ii),sdf)
    np.save(path_recon+'iso'+str(ii), iso)

decodeed sdf torch.Size([100, 1, 50, 50, 50])
decodeed iso torch.Size([100])
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)
(50, 50, 50)


In [14]:
sdf_z.shape


torch.Size([100, 1, 50, 50, 50])

In [None]:
iso