In [61]:
import argparse
import json
import logging
import os
import sys

import numpy as np
import torch
import torch.optim as optim

from torch.utils.data import DataLoader, Subset

from tqdm import tqdm

from data import PartNetVoxelDataset
from vae import VAE

# Run on the GPU when one is available; fall back to CPU so the notebook still works
# on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [62]:
# Rebuild the VAE and load the trained weights for inference.
# The constructor arguments must match the configuration the checkpoint was trained with,
# otherwise load_state_dict (strict by default) rejects the tensors.
model = VAE(128, 64)
# map_location="cpu" lets a checkpoint saved from a GPU run load on any machine;
# the model is moved to `device` right after.
state_dict = torch.load("checkpoints/epoch20.pt", map_location="cpu")
model.load_state_dict(state_dict)
model = model.to(device)
# Switch layers such as dropout / batch-norm (if the VAE has any) to inference behaviour.
model.eval()

In [63]:
# Dataset root. Defaults to the original author's local path; set the SHAPENET_DIR
# environment variable to point at a different copy without editing the notebook.
input_path = os.environ.get(
    "SHAPENET_DIR",
    "C:\\Users\\Alexandru\\Documents\\Python Scripts\\vaebert\\shapenet",
)
dataset = PartNetVoxelDataset(input_path)

In [64]:
# Load the predefined train/test index splits stored next to the dataset.
def _load_split(file_name):
    """Return the indices stored in ``file_name`` under ``input_path``.

    The file is parsed with ``json.load``; presumably it holds a list of integer
    dataset indices (it is passed straight to ``Subset``) — TODO confirm.
    """
    with open(os.path.join(input_path, file_name), "r") as handle:
        return json.load(handle)


train_indices = _load_split("train_indexes.json")
test_indices = _load_split("test_indexes.json")

train_dataset, test_dataset = Subset(dataset, train_indices), Subset(dataset, test_indices)

# Both loaders share batch size and worker count; only shuffling differs
# (test order stays fixed so the "first test batch" is reproducible).
loader_kwargs = dict(batch_size=40, num_workers=1)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [65]:
# Take the first test batch (fixed, because test_loader does not shuffle),
# cast it to float32 and move it to the compute device.
batch = (
    next(iter(test_loader))
    .float()
    .to(device)
)

In [66]:
# Sanity-check the input dimensions before running the model.
print(batch.size())

torch.Size([40, 1, 64, 64, 64])


In [67]:
# Reconstruct the batch. The model returns a pair; only the second element is kept
# (presumably the reconstruction logits, same shape as `batch` — TODO confirm in vae.VAE).
# no_grad() avoids building an autograd graph, which inference does not need and which
# would hold large activation buffers for a full voxel batch.
with torch.no_grad():
    _, batch_output = model(batch)

In [68]:
# The reconstruction should have the same dimensions as the input batch.
print(batch_output.size())

torch.Size([40, 1, 64, 64, 64])


In [69]:
# Drop only the channel axis: (B, 1, D, H, W) -> (B, D, H, W).
# A bare .squeeze() would also remove the batch axis when B == 1, which breaks the
# per-sample indexing (batch[0], batch_output[0], ...) in later cells.
# squeeze(1) is a no-op once the channel axis is gone, so re-running the cell is safe.
batch_output = batch_output.squeeze(1)
batch = batch.squeeze(1)

In [70]:
from binvox_rw import Voxels

In [74]:
# Export sample 0 as .binvox files: ground truth ("old") and model reconstruction ("new").
# batch_output presumably holds per-voxel logits (the cell decodes it with a sigmoid) — TODO confirm.
sample_idx = 0
occupancy_threshold = 0.5
ground_truth = np.round(batch[sample_idx].detach().cpu().numpy()) > 0
# torch.sigmoid is numerically stable; 1/(1 + np.exp(-x)) overflows for large negative logits.
reconstruction = torch.sigmoid(batch_output[sample_idx].detach()).cpu().numpy() > occupancy_threshold
for suffix, occupancy in (("old", ground_truth), ("new", reconstruction)):
    # binvox stores a single 3-D grid; fail loudly if the channel axis was not squeezed.
    assert occupancy.ndim == 3, f"expected a 3-D volume, got shape {occupancy.shape}"
    with open(f"{sample_idx + 1}_{suffix}.binvox", "wb") as f:
        # Dims come from the array itself instead of a hard-coded 64^3.
        Voxels(occupancy, list(occupancy.shape), [0.0, 0.0, 0.0], 1.0, 'xyz').write(f)

In [79]:
# Export sample 1 as .binvox files: ground truth ("old") and model reconstruction ("new").
# batch_output presumably holds per-voxel logits (the cell decodes it with a sigmoid) — TODO confirm.
sample_idx = 1
occupancy_threshold = 0.5
ground_truth = np.round(batch[sample_idx].detach().cpu().numpy()) > 0
# torch.sigmoid is numerically stable; 1/(1 + np.exp(-x)) overflows for large negative logits.
reconstruction = torch.sigmoid(batch_output[sample_idx].detach()).cpu().numpy() > occupancy_threshold
for suffix, occupancy in (("old", ground_truth), ("new", reconstruction)):
    # binvox stores a single 3-D grid; fail loudly if the channel axis was not squeezed.
    assert occupancy.ndim == 3, f"expected a 3-D volume, got shape {occupancy.shape}"
    with open(f"{sample_idx + 1}_{suffix}.binvox", "wb") as f:
        # Dims come from the array itself instead of a hard-coded 64^3.
        Voxels(occupancy, list(occupancy.shape), [0.0, 0.0, 0.0], 1.0, 'xyz').write(f)

In [78]:
# Fraction of sample-0 voxels the model marks as occupied with probability > 0.8.
# torch.sigmoid is numerically stable, unlike 1/(1 + np.exp(-x)), which overflows
# (RuntimeWarning) for large negative logits. The mean is taken in numpy on a bool
# array, so the printed value is a float64 fraction as before.
confidence_threshold = 0.8
probabilities = torch.sigmoid(batch_output[0].detach()).cpu().numpy()
print((probabilities > confidence_threshold).mean())

0.21095657348632812
