In [1]:
%load_ext autoreload
%autoreload 2
import os
os.chdir('/home/svcl-oowl/brandon/research/CVPR_2021_REFINE/sil_consistent_at_inference')
print(os.getcwd())

/home/svcl-oowl/brandon/research/CVPR_2021_REFINE/sil_consistent_at_inference


In [2]:
import pickle
import pprint

import torch

from utils import general_utils, datasets
from utils.datasets import gen_data_collate
from deformation.deformation_net_graph_convolutional_full import DeformationNetworkGraphConvolutionalFull
from deformation.multiview_semantic_discriminator_network import MultiviewSemanticDiscriminatorNetwork
from deformation.semantic_discriminator_net_points import PointsSemanticDiscriminatorNetwork
from deformation.forward_pass import batched_forward_pass, compute_sem_dis_logits
from utils.visualization_tools import save_tensor_img

In [9]:
cfg_path = "configs/test.yaml"

device = torch.device("cuda:0")
cfg = general_utils.load_config(cfg_path, "configs/default.yaml")
batch_size = cfg["semantic_dis_training"]["batch_size"]

# Test Datasets

In [9]:
generation_dataset = datasets.GenerationDataset(cfg)
generation_loader = torch.utils.data.DataLoader(generation_dataset, batch_size=batch_size, num_workers=1, shuffle=False, collate_fn=gen_data_collate, drop_last=True)


Loading cached generation dataset at caches/generation_dataset_cache_500.pt...


In [6]:
print(len(generation_dataset))
print(generation_dataset[0])

500
{'instance_name': '000_fake', 'mesh': <pytorch3d.structures.meshes.Meshes object at 0x7fa3af594748>, 'mesh_verts': tensor([[5.0947, 4.9838, 5.0988],
        [5.1497, 5.0246, 5.1155],
        [5.0561, 5.0058, 5.0963],
        ...,
        [5.1186, 4.9819, 5.0694],
        [5.0715, 4.9620, 5.1475],
        [5.1361, 4.9888, 5.0745]]), 'image': tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
      

In [13]:
real_dataset = datasets.RealDataset(cfg, device)
semantic_dis_loader = torch.utils.data.DataLoader(real_dataset, batch_size=batch_size, num_workers=1, shuffle=True)

Reusing previusly computed points assets...
Caching generation dataset...


HBox(children=(FloatProgress(value=0.0, max=500.0), HTML(value='')))




In [16]:
print(real_dataset[0].shape)
print(len(real_dataset))

torch.Size([642, 3])
500


# Test Networks

In [16]:
deform_net = DeformationNetworkGraphConvolutionalFull(cfg, device)
deform_net.to(device)
semantic_dis_net = MultiviewSemanticDiscriminatorNetwork(cfg)
semantic_dis_net.to(device)

MultiviewSemanticDiscriminatorNetwork(
  (net_1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (14): ReLU(inplace=True)
    (15): MaxPool2d(kernel_size=2, stride=2, padding=0,

In [17]:
for real_batch in semantic_dis_loader:
    #real_image_batch = real_batch["mv_images"].to(device)
    #print(real_image_batch.shape)
    print(real_batch.shape)
    print(len(real_batch.shape))
    real_batch = real_batch.to(device)
    pred_logits_real = semantic_dis_net(real_batch)
    print(pred_logits_real)

torch.Size([1, 8, 3, 224, 224])
5
tensor([[-1.0398]], device='cuda:0', grad_fn=<AddmmBackward>)


In [18]:
# TODO: sanity check these images
for gen_batch in generation_loader:
    _, deformed_meshes, _ = batched_forward_pass(cfg, device, deform_net, semantic_dis_net, gen_batch, compute_losses=False)
    pred_logits_fake, semantic_dis_debug_data = compute_sem_dis_logits(deformed_meshes, semantic_dis_net, device, cfg)

In [19]:
save_tensor_img(real_batch, "real", "notebooks/out")
save_tensor_img(semantic_dis_debug_data, "generated", "notebooks/out")

# Misc

In [18]:
# make dummy pose dict
dummy_pose_dict = {"sphere_642":{"dist":1, "elev":10, "azim":8}}
dummy_pose_path = "data/adversarial/gen/cubetest/poses.p"
pickle.dump(dummy_pose_dict, open(dummy_pose_path,"wb"))
pprint.pprint(pickle.load(open(dummy_pose_path, 'rb')))

{'sphere_642': {'azim': 8, 'dist': 1, 'elev': 10}}
