In [1]:
# Load IPython's autoreload extension; mode 2 picks up edits to imported project
# modules without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Guard flag read by the next cell, which changes the working directory once and then
# sets this to True. NOTE(review): re-running this cell resets the flag, so the next
# cell would then change directory a second time.
notebook_fixed_dir = False

In [2]:
# Step the working directory up one level, at most once per flag reset.
# The notebook_fixed_dir flag (initialised to False in the first cell) records that
# the move already happened, so re-running this cell does not climb any further.
import os

if not notebook_fixed_dir:
    os.chdir('..')
# The flag is only ever False or True, so setting it unconditionally after a
# successful (or skipped) chdir leaves it in the same state as setting it in the branch.
notebook_fixed_dir = True
print(os.getcwd())

/home/svcl-oowl/brandon/research/sil_consistent_at_inference


In [3]:
import pprint

import torch
import pickle
from tqdm import tqdm
from PIL import Image
import numpy as np
from pytorch3d.renderer import (
    look_at_view_transform
)
import matplotlib.pyplot as plt
import glob
from pathlib import Path

from utils import utils
import deformation.losses as def_losses
from deformation.semantic_discriminator_loss import SemanticDiscriminatorLoss 
from semantic_discriminator_trainer import train
from deformation.semantic_discriminator_dataset import SemanticDiscriminatorDataset
from deformation.semantic_discriminator_net import SemanticDiscriminatorNetwork

In [4]:
# Run configuration for this notebook.
# gpu_num : index of the CUDA device to use when CUDA is available.
# device  : torch.device that models and batches are moved to in later cells.
# cfg_path: experiment config handed to utils.load_config in later cells.
gpu_num = 0
# Fall back to the CPU when no CUDA runtime is present, so the notebook can still be
# executed (slowly) on a machine without a GPU instead of failing at the first .to(device).
device = torch.device("cuda:" + str(gpu_num) if torch.cuda.is_available() else "cpu")
cfg_path = "configs/testing.yaml"

In [5]:
# Collect the .obj meshes to evaluate.
# glob.glob returns matches in arbitrary, filesystem-dependent order, so the result is
# sorted: slices such as mesh_paths[:4] then select the same meshes on every run.
# The commented-out lines are alternative inputs kept from earlier experiments.
#mesh_paths = ["data/test_dataset/0001old.obj"]
#mesh_path = "data/test_dataset_one_processed/batch_1_of_1/0001old.obj"
mesh_paths = sorted(glob.glob(os.path.join("data/onet_chair_pix3d_dann_simplified_processed/batch_1_of_5","*.obj")))
#mesh_paths = [str(path) for path in list(Path(os.path.join("data/misc/example_shapenet")).rglob('*.obj'))]
#pprint.pprint(mesh_paths)

In [None]:
# Load the merged config (cfg_path with "configs/default.yaml" passed as the second
# argument) and build the semantic discriminator loss on the selected device.
cfg = utils.load_config(cfg_path, "configs/default.yaml")
semantic_loss_computer = SemanticDiscriminatorLoss(cfg, device)

# Report the semantic discriminator loss for the first four meshes.
# Nothing is optimised here, so gradient tracking is disabled for the whole loop.
with torch.no_grad():
    for mesh_path in mesh_paths[:4]:
        mesh = utils.load_untextured_mesh(mesh_path, device)
        # compute_loss returns a (loss, renders) pair; only the loss is printed.
        semantic_dis_loss, semantic_loss_renders = semantic_loss_computer.compute_loss(mesh)
        print("{}: {}".format(mesh_path, semantic_dis_loss.item()))

In [8]:
# Reload the config and loss helper, then build the validation split and its loader.
cfg = utils.load_config(cfg_path, "configs/default.yaml")
semantic_loss_computer = SemanticDiscriminatorLoss(cfg, device)

val_dataset = SemanticDiscriminatorDataset(cfg, "val")
# One sample per batch, in dataset order, loaded by four worker processes.
val_loader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    collate_fn=None,
    worker_init_fn=None,
)

In [6]:
# Build the semantic discriminator and restore its trained weights from the path
# stored under cfg["training"]["semantic_dis_weight_path"].
semantic_discriminator_net = SemanticDiscriminatorNetwork(cfg)
# map_location remaps the checkpoint's tensors onto `device`; without it, a checkpoint
# saved on a different GPU index (or loaded on a CPU-only machine) fails to deserialize
# or lands on the wrong device.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
semantic_discriminator_net.load_state_dict(
    torch.load(cfg["training"]["semantic_dis_weight_path"], map_location=device)
)
# Left as the cell's last expression so the notebook displays the module structure.
semantic_discriminator_net.to(device)

SemanticDiscriminatorNetwork(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Dropout(p=0.8, inplace=False)
    (6): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.2, inplace=True)
    (9): Dropout(p=0.8, inplace=False)
    (10): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (11): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): LeakyReLU(negative_slope=0.2, inplace=True)
    (13): Dropout(p=0.8, inplace=

In [9]:
# Evaluate the discriminator on the first `n` validation batches and print a running accuracy.
# Label convention in this cell: real -> 0, fake -> 1, i.e. a sigmoid output above 0.5
# is read as a "fake" prediction.
#
# Results left in the namespace:
#   val_accuracies : list of per-batch boolean correctness arrays (real, then fake)
#   val_accuracy   : fraction of correct predictions over everything seen so far
n = 100
val_accuracies = []
num_correct = 0
num_seen = 0
# Defined up front so the follow-up cell that prints it still works if the loader is empty.
val_accuracy = float("nan")

# Inference mode only has to be set once, not per batch.
semantic_discriminator_net.eval()
with torch.no_grad():
    # zip(range(n), ...) stops after n batches without a manual counter, and the explicit
    # total makes the progress bar reflect the batches actually evaluated.
    for _, val_batch in tqdm(zip(range(n), val_loader), total=min(n, len(val_loader))):
        #print(val_batch['real_path'])
        #print(val_batch['fake_path'])
        prob_real = torch.sigmoid(semantic_discriminator_net(val_batch['real'].to(device)))
        prob_fake = torch.sigmoid(semantic_discriminator_net(val_batch['fake'].to(device)))
        print("real preds:")
        print(prob_real)
        print("fake preds:")
        print(prob_fake)

        # A real sample is correct when it is NOT flagged as fake; a fake one when it is.
        real_correct_vec = ~(prob_real > 0.5)
        fake_correct_vec = prob_fake > 0.5
        val_accuracies.append(real_correct_vec.cpu().numpy())
        val_accuracies.append(fake_correct_vec.cpu().numpy())

        # Running counters give the same mean as concatenating the whole history,
        # without redoing that work for every batch.
        num_correct += int(real_correct_vec.sum()) + int(fake_correct_vec.sum())
        num_seen += real_correct_vec.numel() + fake_correct_vec.numel()
        val_accuracy = num_correct / num_seen
        print(val_accuracy)


  0%|          | 0/7672 [00:00<?, ?it/s][A
  0%|          | 1/7672 [00:00<17:58,  7.11it/s][A
  0%|          | 19/7672 [00:00<12:46,  9.98it/s][A

real preds:
tensor([[0.5516]], device='cuda:0')
fake preds:
tensor([[0.6661]], device='cuda:0')
0.5
real preds:
tensor([[0.4679]], device='cuda:0')
fake preds:
tensor([[0.6746]], device='cuda:0')
0.75
real preds:
tensor([[0.5001]], device='cuda:0')
fake preds:
tensor([[0.6115]], device='cuda:0')
0.6666666666666666
real preds:
tensor([[0.4760]], device='cuda:0')
fake preds:
tensor([[0.6308]], device='cuda:0')
0.75
real preds:
tensor([[0.5672]], device='cuda:0')
fake preds:
tensor([[0.6540]], device='cuda:0')
0.7
real preds:
tensor([[0.4510]], device='cuda:0')
fake preds:
tensor([[0.5551]], device='cuda:0')
0.75
real preds:
tensor([[0.4887]], device='cuda:0')
fake preds:
tensor([[0.5923]], device='cuda:0')
0.7857142857142857
real preds:
tensor([[0.4636]], device='cuda:0')
fake preds:
tensor([[0.6144]], device='cuda:0')
0.8125
real preds:
tensor([[0.5757]], device='cuda:0')
fake preds:
tensor([[0.6125]], device='cuda:0')
0.7777777777777778
real preds:
tensor([[0.5138]], device='cuda:0')
f


  0%|          | 36/7672 [00:00<09:09, 13.90it/s][A
  1%|          | 53/7672 [00:00<06:37, 19.18it/s][A

real preds:
tensor([[0.3719]], device='cuda:0')
fake preds:
tensor([[0.8015]], device='cuda:0')
0.8
real preds:
tensor([[0.3740]], device='cuda:0')
fake preds:
tensor([[0.7562]], device='cuda:0')
0.8055555555555556
real preds:
tensor([[0.4426]], device='cuda:0')
fake preds:
tensor([[0.6985]], device='cuda:0')
0.8108108108108109
real preds:
tensor([[0.3929]], device='cuda:0')
fake preds:
tensor([[0.6329]], device='cuda:0')
0.8157894736842105
real preds:
tensor([[0.3530]], device='cuda:0')
fake preds:
tensor([[0.6847]], device='cuda:0')
0.8205128205128205
real preds:
tensor([[0.3969]], device='cuda:0')
fake preds:
tensor([[0.6827]], device='cuda:0')
0.825
real preds:
tensor([[0.3614]], device='cuda:0')
fake preds:
tensor([[0.7493]], device='cuda:0')
0.8292682926829268
real preds:
tensor([[0.3991]], device='cuda:0')
fake preds:
tensor([[0.7803]], device='cuda:0')
0.8333333333333334
real preds:
tensor([[0.3820]], device='cuda:0')
fake preds:
tensor([[0.7743]], device='cuda:0')
0.8372093023


  1%|          | 70/7672 [00:00<04:50, 26.13it/s][A
  1%|          | 87/7672 [00:00<03:36, 35.01it/s][A

real preds:
tensor([[0.4506]], device='cuda:0')
fake preds:
tensor([[0.5631]], device='cuda:0')
0.8913043478260869
real preds:
tensor([[0.4478]], device='cuda:0')
fake preds:
tensor([[0.5207]], device='cuda:0')
0.8928571428571429
real preds:
tensor([[0.4248]], device='cuda:0')
fake preds:
tensor([[0.6511]], device='cuda:0')
0.8943661971830986
real preds:
tensor([[0.4575]], device='cuda:0')
fake preds:
tensor([[0.5470]], device='cuda:0')
0.8958333333333334
real preds:
tensor([[0.5340]], device='cuda:0')
fake preds:
tensor([[0.7166]], device='cuda:0')
0.8904109589041096
real preds:
tensor([[0.4369]], device='cuda:0')
fake preds:
tensor([[0.6598]], device='cuda:0')
0.8918918918918919
real preds:
tensor([[0.4547]], device='cuda:0')
fake preds:
tensor([[0.7179]], device='cuda:0')
0.8933333333333333
real preds:
tensor([[0.4429]], device='cuda:0')
fake preds:
tensor([[0.6978]], device='cuda:0')
0.8947368421052632
real preds:
tensor([[0.5328]], device='cuda:0')
fake preds:
tensor([[0.6845]], d


  1%|          | 87/7672 [00:20<03:36, 35.01it/s][A

In [None]:
# Show the final running accuracy left behind by the validation loop in the previous cell.
print(val_accuracy)

In [28]:
# Scratch check: sample from a Uniform distribution whose bounds are both 0.
# Uniform requires low < high and, with argument validation enabled (the default in
# current PyTorch), raises ValueError for this degenerate range. Validation is switched
# off explicitly so the cell keeps demonstrating that every sample collapses to 0.
m = torch.distributions.Uniform(torch.tensor([0.0]), torch.tensor([0.0]), validate_args=False)
# Resulting shape is sample_shape + batch_shape = (10, 1) + (1,) = (10, 1, 1).
t = m.sample((10,1))
print(t)
print(t.shape)


tensor([[[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]],

        [[0.]]])
torch.Size([10, 1, 1])
