In [1]:
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import os
import argparse
import time
from im2scene import config
from im2scene.checkpoints import CheckpointIO
import logging
# Logger named after the current module (`__name__`).
logger_py = logging.getLogger(__name__)

In [2]:
import os
import urllib
import torch
from torch.utils import model_zoo
import shutil
import datetime

In [3]:
import mrcfile
from tqdm import tqdm

In [4]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so the notebook can still run (slowly) on a machine without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [5]:
# Directory handed to CheckpointIO for loading the checkpoint; the exported
# .mrc volumes are written into the same directory.
out_dir = 'out/FFHQ_256b/'

In [6]:
# Load the experiment configuration.
# NOTE(review): the second path is presumably the defaults file that the
# experiment YAML overrides - confirm against im2scene.config.load_config.
cfg = config.load_config("configs/256res/FFHQ_256b.yaml", 'configs/default.yaml')

In [7]:
# Load model

In [8]:
# Build the model described by `cfg` for `device`.
# NOTE(review): len_dataset=1 looks like a placeholder because no dataset is
# loaded in this notebook - confirm it has no effect on the generator.
model = config.get_model(cfg, device=device, len_dataset=1)

In [9]:
# Load 'model_best.pt' through a CheckpointIO built from `out_dir` and `model`.
checkpoint_io = CheckpointIO(out_dir, model=model)
try:
    load_dict = checkpoint_io.load('model_best.pt')
except FileExistsError:
    # NOTE(review): presumably CheckpointIO.load signals a missing file with
    # FileExistsError - confirm. Execution continues with `model` as built.
    load_dict = {}
    print("No model checkpoint found.")
else:
    print("Loaded model checkpoint.")


out/FFHQ_256b/model_best.pt
=> Loading checkpoint from local file...
Loaded model checkpoint.


In [22]:
def create_samples(N=128, voxel_origin=(0, 0, 0), cube_length=2.0):
    """Build a dense, axis-aligned cubic grid of 3D query points.

    The cube is centred on ``voxel_origin`` and spans ``cube_length`` along
    every axis, with both end points included.

    Args:
        N (int): Number of samples per axis. Must be at least 2, because the
            spacing is ``cube_length / (N - 1)``.
        voxel_origin (sequence of 3 numbers): Centre ``(x, y, z)`` of the cube.
        cube_length (float): Edge length of the cube.

    Returns:
        torch.Tensor: Tensor of shape ``(1, N**3, 3)`` holding the ``(x, y, z)``
        coordinates. Points are ordered with x varying slowest and z fastest.

    Raises:
        ValueError: If ``N`` is smaller than 2 (the spacing would otherwise be
            a division by zero and silently yield NaN coordinates).
    """
    if N < 2:
        raise ValueError(f"N must be at least 2 to span the cube, got {N}")

    origin_x, origin_y, origin_z = voxel_origin
    half_length = cube_length / 2

    # Offsets from the lower face of the cube; shared by all three axes.
    steps = torch.arange(0, N) * cube_length / (N - 1)
    xs = steps + origin_x - half_length
    ys = steps + origin_y - half_length
    zs = steps + origin_z - half_length

    # Broadcast each axis over the full (N, N, N) lattice. This gives the same
    # grid[i, j, k] = (xs[i], ys[j], zs[k]) layout as the default 'ij'
    # behaviour of torch.meshgrid, without depending on its `indexing`
    # argument (absent in old releases, warned about when omitted in new ones).
    grid = torch.stack(
        (
            xs.view(N, 1, 1).expand(N, N, N),
            ys.view(1, N, 1).expand(N, N, N),
            zs.view(1, 1, N).expand(N, N, N),
        ),
        dim=-1,
    )

    # Flatten the lattice and add a leading batch dimension of size 1.
    return grid.reshape(1, -1, 3)

In [16]:
# Latent sampling setup. Per the original author's note, the values are copied
# from the render script so that the same latent codes are drawn here.
batch_size = 15
z_dim = 256
torch.manual_seed(0)
# Use `batch_size` rather than repeating the literal so the two cannot drift.
# NOTE(review): 0.65 is presumably the sampling temperature/truncation used by
# the render script - confirm against model.generator.get_latent_codes.
latent = model.generator.get_latent_codes(batch_size, 0.65)  # copy from render script to get consistent results


In [23]:
# CREATE SAMPLES
# Number of grid points along each axis of the cubic query volume.
voxel_res = 512

# Dense grid of query points from create_samples, moved onto `device`.
points = create_samples(N=voxel_res).to(device)
# All-zero view directions with the same shape, dtype and device as `points`.
raydirs = torch.zeros_like(points)  # shape is independent of raydir

In [24]:
## RUN SAMPLES

In [25]:
# Entries to export. Despite the name, each value is used as a position in the
# sampled latent batch (it slices `latent[0]`), not as an RNG seed.
seeds = [5,]
# Upper bound on the number of points sent to the decoder in one call.
MAX_SAMPLES_PER_BATCH = 2**20

for seed in seeds:
    # An out-of-range position would slice to an empty latent; fail clearly.
    if not 0 <= seed < len(latent[0]):
        raise IndexError(
            f"seed {seed} is outside the sampled latent batch of size {len(latent[0])}")

    # Inference only: no gradients are needed.
    with torch.no_grad():
        # Slicing (rather than indexing) keeps a leading dimension of size 1.
        shape_z = latent[0][seed:seed+1]
        # NOTE(review): app_z is read from latent[0], exactly like shape_z.
        # This looks like a copy-paste slip (latent[1] may have been intended).
        # Only the density output is kept below, so it is left unchanged here
        # to preserve existing results - confirm against the decoder.
        app_z = latent[0][seed:seed+1]

        # Allocate the output directly on the target device instead of
        # creating it on the CPU and copying it over.
        sigma = torch.zeros(points.shape[:2], device=device)
        num_points = sigma.shape[1]

        with tqdm(total=num_points) as pbar:
            # Evaluate the decoder in chunks to bound peak memory.
            for start in range(0, num_points, MAX_SAMPLES_PER_BATCH):
                stop = min(start + MAX_SAMPLES_PER_BATCH, num_points)
                # The first decoder output is discarded; the second is stored.
                _, sigma[:, start:stop] = model.generator.decoder(
                    points[:, start:stop], raydirs[:, start:stop], shape_z, app_z)
                # Advance by the real chunk size so a final partial chunk does
                # not push the bar past its total.
                pbar.update(stop - start)

        # Squash the raw decoder values into (0, 1).
        sigma = torch.sigmoid(sigma)

    # Back to a (voxel_res, voxel_res, voxel_res) volume on the host.
    sigma = sigma.reshape(1, voxel_res, voxel_res, voxel_res)[0].cpu().numpy()
    # Reorder the axes from (0, 1, 2) to (1, 2, 0) and mirror the last one.
    # NOTE(review): presumably this matches the axis convention of the MRC
    # viewer in use - confirm.
    sigma = np.transpose(sigma, (1, 2, 0))
    sigma = np.flip(sigma, axis=2)
    # mrc_mode=2 is the 32-bit float MRC mode; any existing file is overwritten.
    with mrcfile.new_mmap(os.path.join(out_dir, f'shape_{seed}.mrc'), overwrite=True, shape=sigma.shape, mrc_mode=2) as mrc:
        mrc.data[:] = sigma

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 134217728/134217728 [00:12<00:00, 10988636.20it/s]
