# Initialization

In [None]:
%matplotlib inline
# Re-import edited project modules (gcnna, GEOMetrics) automatically before each cell runs.
%reload_ext autoreload
%autoreload 2
# Widen the notebook container to the full browser width.
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
import torch, os, shutil, pickle
from tqdm import tqdm
from glob import glob
from pytorch3d.io import load_obj, save_obj
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.structures import Meshes
import scipy.io as sio
import pandas as pd
from gcnna.launcher import FeatureVisualization
from GEOMetrics.layers import * 
from GEOMetrics.models import *
from GEOMetrics.voxel  import voxel2obj
from GEOMetrics.utils import Voxel_loader, GCN_Loader
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.animation #import FuncAnimation
from matplotlib.animation import FuncAnimation
mpl.rcParams['savefig.dpi'] = 80
mpl.rcParams['figure.dpi'] = 80
from pytorch3d.loss import chamfer_distance, mesh_edge_loss, mesh_laplacian_smoothing, mesh_normal_consistency

In [None]:
from plotly.offline import plot, iplot, init_notebook_mode
import plotly.graph_objects as go
init_notebook_mode(connected=True)

In [None]:
objects = ['bench','cabinet','car','cellphone','chair','lamp','monitor','plane','rifle','sofa','speaker','table','watercraft']
labels = {'04379243':'table','03211117':'monitor','04401088':'cellphone','04530566': 'watercraft',  '03001627' : 'chair','03636649' : 'lamp',  '03691459': 'speaker' ,  '02828884':'bench',
'02691156': 'plane', '02808440': 'bathtub',  '02871439': 'bookcase',
'02773838': 'bag', '02801938': 'basket', '02828884' : 'bench','02880940': 'bowl' ,
'02924116': 'bus', '02933112': 'cabinet', '02942699': 'camera', '02958343': 'car', '03207941': 'dishwasher',
'03337140': 'file', '03624134': 'knife', '03642806': 'laptop', '03710193': 'mailbox',
'03761084': 'microwave', '03928116': 'piano', '03938244':'pillow', '03948459': 'pistol', '04004475': 'printer',
'04099429': 'rocket', '04256520': 'sofa', '04554684': 'washer', '04090263': 'rifle'}

# Visualization Functions

In [None]:
# A bare `save_obj()` call here raised TypeError (it needs a path, verts and faces) and aborted
# "Run All"; `save_obj` is called with real arguments inside `plot_mesh` below.

In [None]:
from PyGEL3D import gel
from PyGEL3D import js
def plot_mesh(mesh=None, verts=None, faces=None, tmp_path='mesh.obj'):
    """Render a triangle mesh interactively with PyGEL3D.

    Pass either a pytorch3d ``Meshes`` object as ``mesh``, or raw ``verts`` and
    ``faces``. The geometry is written to an OBJ file at ``tmp_path`` and read back
    with ``gel.obj_load``, which loads from disk.

    Args:
        mesh: pytorch3d ``Meshes``; its packed verts/faces are used and take
            precedence over ``verts``/``faces``.
        verts: vertex data accepted by ``save_obj`` (used when ``mesh`` is None).
        faces: face-index data accepted by ``save_obj`` (used when ``mesh`` is None).
        tmp_path: location of the intermediate OBJ file (overwritten on each call).

    Raises:
        ValueError: if neither ``mesh`` nor both ``verts`` and ``faces`` are given.
    """
    if mesh is not None:
        verts, faces = mesh.verts_packed(), mesh.faces_packed()
    elif verts is None or faces is None:
        raise ValueError('plot_mesh needs either `mesh` or both `verts` and `faces`')
    save_obj(tmp_path, verts, faces)
    js.set_export_mode()
    m = gel.obj_load(tmp_path)
    js.display(m, smooth=False)
    
def plot_pointcloud(mesh, title="", num_samples=5000):
    """Scatter-plot points sampled uniformly from a mesh surface.

    Args:
        mesh: pytorch3d ``Meshes`` passed to ``sample_points_from_meshes``.
            Presumably a single mesh: the batch axis is removed with ``squeeze``
            before splitting into x/y/z (TODO confirm for batched input).
        title: title of the 3D axes.
        num_samples: number of surface points to sample (5000 by default).
    """
    points = sample_points_from_meshes(mesh, num_samples)
    x, y, z = points.detach().cpu().squeeze().unbind(1)
    fig = plt.figure(figsize=(5, 5))
    # `Axes3D(fig)` no longer attaches itself to the figure on recent Matplotlib
    # (the figure renders empty); creating the 3D axes through the figure works everywhere.
    ax = fig.add_subplot(projection='3d')
    # Display axes are (x, z, -y); the labels below follow that swap.
    ax.scatter3D(x, z, -y)
    ax.set_xlabel('x')
    ax.set_ylabel('z')
    ax.set_zlabel('y')
    ax.set_title(title)
    ax.view_init(190, 30)
    plt.show()

def load_mesh(pth):
    """Read one OBJ file and return it as a single-mesh pytorch3d ``Meshes`` on the GPU.

    Only the vertex positions and the faces' vertex indices are kept; the auxiliary
    data returned by ``load_obj`` is discarded.
    """
    obj_verts, obj_faces, _ = load_obj(pth)
    return Meshes(verts=[obj_verts], faces=[obj_faces.verts_idx]).cuda()


# Load Trained Models

In [None]:
# Root of the GEOMetrics ShapeNet export (class directories, each holding model directories).
path = '/scratch/jiadeng_root/jiadeng/shared_data/datasets/GEOMetrics/shapenet/'
# Every <class>/<model> directory. Stored under its own name so `path` stays a string
# instead of silently turning into a list.
all_model_paths = glob(os.path.join(path, '*', '*'))

### 0N-GCN autoencoder (latent dim = 50, `zgcn_run_norm` checkpoint) - https://github.com/EdwardSmith1884/GEOMetrics

In [None]:
# 0N-GCN mesh encoder + decoder (50-d latent), weights from the `zgcn_run_norm` run.
zgcn_ckpt = 'GEOMetrics/checkpoint/zgcn_run_norm/'
enc_zgcn, dec_zgcn = MeshEncoder(50).cuda(), Decoder(50).cuda()
enc_zgcn.load_state_dict(torch.load(zgcn_ckpt + 'encoder_of'))
dec_zgcn.load_state_dict(torch.load(zgcn_ckpt + 'decoder'))

### 0N-GCN autoencoder (latent dim = 50, older `zeron_gcn_vanilla` checkpoint)

In [None]:
# Older 0N-GCN encoder + decoder (50-d latent), weights from the `zeron_gcn_vanilla` run.
zgcn_old_ckpt = 'GEOMetrics/checkpoint/zeron_gcn_vanilla/'
enc_zgcn_old, dec_zgcn_old = MeshEncoder(50).cuda(), Decoder(50).cuda()
enc_zgcn_old.load_state_dict(torch.load(zgcn_old_ckpt + 'encoder'))
dec_zgcn_old.load_state_dict(torch.load(zgcn_old_ckpt + 'decoder'))

### GCN autoencoder (latent dim = 50, `gcn_run_norm` checkpoint)

In [None]:
# Standard-GCN encoder + decoder (50-d latent), weights from the `gcn_run_norm` run.
gcn_ckpt = 'GEOMetrics/checkpoint/gcn_run_norm/'
enc_gcn, dec_gcn = MeshEncoderGCN(50).cuda(), Decoder(50).cuda()
enc_gcn.load_state_dict(torch.load(gcn_ckpt + 'encoder_of'))
dec_gcn.load_state_dict(torch.load(gcn_ckpt + 'decoder'))

### Data Loader

In [None]:
# Restrict the loader to the chair class (a wider set tried earlier: bench, sofa, chair, lamp, table).
objects = ['chair']
path = '/scratch/jiadeng_root/jiadeng/shared_data/datasets/GEOMetrics/shapenet/'
# Collect every model directory inside each selected class directory.
paths = []
for class_dir in glob(path + '/*'):
    if class_dir.split('/')[-1] not in objects:
        continue
    paths.extend(glob(class_dir + '/*'))
# load data
data = Voxel_loader(paths)

In [None]:
# Shortcut to FeatureVisualization.normalize_verts. It is called below unbound, as
# `norm(None, verts)`, i.e. with None in the `self` slot.
norm = FeatureVisualization.normalize_verts

In [None]:
# x = []
# y = []
# z = []
# for batch in tqdm(data):
#     verts = batch['verts']
#     x.append((verts[:,0].max(), verts[:,0].min()))
#     y.append((verts[:,1].max(), verts[:,1].min()))
#     z.append((verts[:,2].max(), verts[:,2].min()))

In [None]:
# batch = data[0]
# plot_mesh(None, norm(None,batch['verts']), batch['faces']) 
# batch['id'], batch['verts'].shape

# Deep Dream for each channel

## Some Experiments

In [None]:
from pytorch3d.utils import ico_sphere
from GEOMetrics.ico_objects import ico_disk
from torch.autograd import Variable

In [None]:
# Same alias as in the Data Loader section, re-bound so this section can be run on its own.
norm = FeatureVisualization.normalize_verts

In [None]:
# m = ico_sphere(2)
# deform_verts = Variable(norm(None, m.verts_packed()), requires_grad = True)
# m_ = Meshes(verts=[deform_verts], faces=[m.faces_packed()])
# plot_mesh(m_)

## Dreaming

In [None]:
# Ask the frontend (via JavaScript) to execute every cell above this one.
# NOTE(review): `IPython.notebook` is a classic-Notebook JS object (not available in JupyterLab),
# and under "Run All" this re-runs all preceding cells a second time, including model loading.
from IPython.display import Javascript
display(Javascript('IPython.notebook.execute_cells_above()'))

In [None]:
# Show the encoder architecture. `.modules` without parentheses only displays a bound method;
# the module's own repr lists its registered submodules directly.
enc_zgcn_old

In [None]:
# Visualisation wrapper around the older 0N-GCN encoder. 'sphere' presumably selects the
# ico-sphere starting mesh (cf. `ico_level` below), and 'dream_layer' is the experiment name
# (the third argument plays the role of `exp` in the Feature Inversion section).
viz_dl = FeatureVisualization(enc_zgcn_old, 'sphere', 'dream_layer')  

In [None]:
# Deep-dream filter 0 of layer 'h1' for 200 iterations, starting from an ico-level-2 mesh
# (exact semantics live in gcnna.launcher — confirm there).
viz_dl.dream_layer(layer='h1', filter=0, iters=200, ico_level=2)

In [None]:
# Show the source mesh after optimisation (presumably the dreamed result).
plot_mesh(viz_dl.new_src_mesh)

In [None]:
# Feature inversion toward the first loaded chair model.
# Run-All fixes: `batch` used to be defined only further down, and `enc` was never defined.
# NOTE(review): the encoder is assumed to be `enc_gcn` (as in the Feature Inversion section),
# and the call uses the keyword names that section uses (`trg_obj_path`, `layer`, `filter`)
# instead of the stale `target_layer` — confirm against gcnna.launcher.FeatureVisualization.
# trg_obj_pth = '/scratch/jiadeng_root/jiadeng/shared_data/datasets/GEOMetrics/shapenet/chair/47dde30e987efc6c8687ff9b0b4e4ac/model.obj'
batch = data[0]
trg_obj_pth = os.path.join(batch['id'], 'model.obj')
feat_inv = FeatureVisualization(enc_gcn, 'disk', 'feat_inv')
feat_inv.invert_feats(trg_obj_path=trg_obj_pth, layer='h11', filter=None, iters=150)  # trg_feats = pca_chair.cuda()

In [None]:
# Show the source mesh after inversion (presumably the reconstructed shape).
plot_mesh(feat_inv.new_src_mesh)

In [None]:
# Save the inversion target mesh as `base.obj` alongside the inversion results.
# `res_path` was never defined in this notebook; it is assumed to be the results root used by
# the obj2gif section ('results/gcnna_data') — NOTE(review): confirm.
res_path = 'results/gcnna_data'
obj_path = os.path.join(res_path, 'feat_inv', feat_inv.src_mesh_name)
os.makedirs(obj_path, exist_ok=True)  # the target directory must exist before writing the OBJ
save_obj(os.path.join(obj_path, 'base.obj'), feat_inv.trg_mesh.verts_packed(), feat_inv.trg_mesh.faces_packed())

# Feature Inversion

In [None]:
# First model of the loader; the last line displays its model directory and vertex-tensor shape.
batch = data[0]
# plot_mesh(None, norm(None,batch['verts']), batch['faces']) 
batch['id'], batch['verts'].shape

In [None]:
# Experiment name; the obj2gif section also uses it as the results sub-directory.
exp = 'feat_inv_vertex'
# Feature-inversion wrapper around the GCN encoder, starting from the 'disk' mesh type.
viz_fi = FeatureVisualization(enc_gcn, 'disk', exp)

In [None]:
# Loss weights passed to invert_feats; the keys presumably refer to the chamfer, Laplacian-smoothing
# and edge-length losses imported from pytorch3d.loss at the top.
weights = {'cd_loss': 1, 'lap_loss': 1, 'edge_loss':1}

In [None]:
# (Inline plotting is already enabled in the setup cell.)
%matplotlib inline
# Invert the 'h11' features of the first chair, starting from an ico-level-3 disk, for 100 iterations.
viz_fi.invert_feats(trg_obj_path=os.path.join(batch['id'], 'model.obj'), layer='h11', filter=None, lr = 0.07, weights=weights, 
                    iters=100, ico_level=3, verbose=True)

In [None]:
# Show the source mesh after feature inversion.
plot_mesh(viz_fi.new_src_mesh)

In [None]:
# %matplotlib inline
# for i in range(60):
#     viz_fi.invert_feats(trg_obj_path=os.path.join(batch['id'], 'model.obj'), layer='h1', filter=i, lr = 0.01, weights=weights, iters=300, ico_level=3)

# Render OBJ Meshes to GIFs (obj2gif)

In [None]:
from GEOMetrics.obj2gif import render_main

In [None]:
import sys
sys.path.append('PyHTMLWriter/src/')
from Element import Element
from TableRow import TableRow
from Table import Table
from TableWriter import TableWriter
import numpy as np

In [None]:
# Result meshes of the experiment named by `exp` (e.g. exp = 'feat_inv' for the earlier run).
dir_pth = f'results/gcnna_data/{exp}'
mesh_pths = glob(f'{dir_pth}/*.obj')

In [None]:
# Name of the experiment whose results are rendered below.
exp

In [None]:
# Render every result mesh in `dir_pth` to an animated GIF.
# NOTE(review): `output_filename` receives the directory rather than a per-mesh file name; the
# next cell globs `dir_pth/*.gif`, so render_main presumably derives each GIF's name from `pth`
# — confirm in GEOMetrics.obj2gif.
for pth in tqdm(mesh_pths):
    render_main(pth, camera_elevation=45, camera_rdistance=2, batch_size=30, image_size=300, output_filename=dir_pth)

In [None]:
# Rendered GIFs in a stable, name-sorted order for the HTML table.
gif_pths = sorted(glob(f'{dir_pth}/*.gif'))

In [None]:
# Lay the GIFs out three per row in an HTML table; the first row is created as a header row.
t = Table()
for start in range(0, len(gif_pths), 3):
    row = TableRow(isHeader = True) if start == 0 else TableRow()
    for gif in gif_pths[start:start + 3]:
        cell = Element()
        cell.addTxt(gif.split('/')[-1])  # caption: the file name only
        cell.addGIF(gif)
        row.addElement(cell)
    t.addRow(row)
tw = TableWriter(t, exp)
tw.write()

In [None]:
# Sorted GIF paths that went into the table.
gif_pths

# Exploring the Trained Models

In [None]:
# First model again, wrapped as a vertex-normalised pytorch3d mesh on the GPU for the encoders.
batch = data[0]
mesh = Meshes(verts=[norm(None, batch['verts'])], faces=[batch['faces']]).cuda()
batch['id'], batch['verts'].shape

In [None]:
# Interactive view of the normalised input mesh.
plot_mesh(mesh)

In [None]:
def decode(enc, dec, mesh=None, threshold = 0.3, latent = None):
    """Decode a mesh (or a given latent code) into voxels and plot the occupied cells.

    Args:
        enc: mesh encoder exposing ``extract_feats(mesh, 'latent')``; only used when
            ``latent`` is None.
        dec: decoder mapping the latent code to per-voxel scores.
        mesh: pytorch3d ``Meshes`` to encode; required when ``latent`` is None.
        threshold: voxels whose score is strictly greater than this are drawn as filled.
        latent: optional precomputed latent code; when given, the encoder is skipped.
    """
    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        if latent is None:
            latent = enc.extract_feats(mesh, 'latent')
        voxel_pred = dec(latent)
    # Boolean occupancy grid as a CPU NumPy array: matplotlib cannot convert CUDA tensors
    # (or tensors that require grad) by itself.
    filled = (voxel_pred > threshold).squeeze(0).cpu().numpy()
    fig = plt.figure()
    # `fig.gca(projection=...)` was removed in Matplotlib 3.6; add_subplot works on all versions.
    ax = fig.add_subplot(projection='3d')
    ax.voxels(filled)
    plt.show()

In [None]:
# sphere =  12, 42, 162, 642, 2562, 
# disk = 9, 25, 81, 289, 1089, 4255

In [None]:
# Ground-truth voxel grid of the same model, for comparison with the decoders below.
# (Inline plotting is already enabled in the setup cell; the former unsqueeze/squeeze pair was a no-op.)
voxel_gt = batch['voxels']
fig = plt.figure()
# `fig.gca(projection=...)` was removed in Matplotlib 3.6; add_subplot works on all versions.
ax = fig.add_subplot(projection='3d')
ax.voxels(voxel_gt.cpu().numpy())
plt.show()

In [None]:
# GCN autoencoder reconstruction; note the much lower threshold (0.06) than for the 0N-GCN decoders.
decode(enc=enc_gcn, dec=dec_gcn, mesh=mesh, threshold=0.06)

In [None]:
# 0N-GCN (`zgcn_run_norm`) reconstruction of the same normalised mesh.
decode(enc=enc_zgcn, dec=dec_zgcn, mesh=mesh, threshold=0.25)

In [None]:
# Older 0N-GCN reconstruction.
# NOTE(review): unlike the two cells above, this feeds the raw OBJ loaded from disk (no `norm`)
# rather than `mesh` — confirm that is intended for the older checkpoint.
mesh_ = load_mesh(os.path.join(batch['id'], 'model.obj'))
decode(enc=enc_zgcn_old, dec=dec_zgcn_old, mesh=mesh_, threshold=0.3)

# Latent Space Exploration

In [None]:
# Chairs and tables, capped at the first 100 model directories per class (in glob order).
objects = ['chair', 'table']
path = '/scratch/jiadeng_root/jiadeng/shared_data/datasets/GEOMetrics/shapenet/'
paths = []
for class_dir in glob(path + '/*'):
    if class_dir.split('/')[-1] in objects:
        paths.extend(glob(class_dir + '/*')[:100])
# load data
data = Voxel_loader(paths)

In [None]:
# First item of the loader for a quick look (replaces a tqdm loop that broke after one step;
# an empty loader now raises StopIteration instead of silently keeping a stale `batch`).
batch = next(iter(data))

In [None]:
# NOTE(review): this mesh is not used afterwards; the next cell rebuilds `mesh` for every model.
mesh = Meshes(verts=[norm(None,batch['verts'])], faces=[batch['faces']])

In [None]:
# Encode every loaded model with the GCN encoder; each row is [class_name, z_0, ..., z_49].
latent = []
with torch.no_grad():
    for batch in tqdm(data):
        mesh = Meshes(verts=[norm(None, batch['verts'])], faces=[batch['faces']])
        # Class name = parent directory of the model directory.
        cls = batch['id'].split('/')[-2]
        try:
            feats = enc_gcn.extract_feats(mesh.cuda(), 'latent').detach().cpu().numpy()[0].tolist()
        except Exception as err:
            # A bare `except:` here also swallowed KeyboardInterrupt and hid the cause.
            print(f"skipped {batch['id']}: {err}")
            continue
        latent.append([cls] + feats)

In [None]:
# One row per successfully encoded model: column 0 = class name, columns 1.. = latent values.
# The column labels are integers (0, 1, ...), not strings.
latent = pd.DataFrame(latent)

In [None]:
# latent = pd.read_csv('latent.csv')

In [None]:
# Feature matrix (latent values) and class-name labels.
X, y = latent.iloc[:,1:].values, latent.iloc[:,0].values

In [None]:
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns

In [None]:
# 2-D t-SNE embedding of the latent codes. A fixed random_state makes the layout reproducible
# across Restart & Run All.
tsne = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=3000, random_state=0)
tsne_results = tsne.fit_transform(X)

In [None]:
# Long-form frame for seaborn: class label plus the two t-SNE coordinates.
df = pd.DataFrame({
    'y': y,
    'tsne-2d-one': tsne_results[:, 0],
    'tsne-2d-two': tsne_results[:, 1],
})

In [None]:
# t-SNE scatter of the latent codes, coloured by class.
plt.figure(figsize=(16,10))
sns.scatterplot(
    x="tsne-2d-one", y="tsne-2d-two",
    hue="y",
    data=df,
    legend="full",
    alpha=0.6
)

In [None]:
# Latent codes of the chair models only. Column 0 holds the class name; it is selected by
# position because the in-memory frame has integer column labels (`latent['0']` raised KeyError).
df_chair = latent.loc[latent.iloc[:, 0] == 'chair']
X_chair = df_chair.iloc[:, 1:].values

In [None]:
# (latent dimensions, number of chair models)
X_chair.T.shape

In [None]:
# Per-latent-dimension maximum over all chair codes, as a CUDA tensor.
max_chair = torch.Tensor(np.max(X_chair.T, axis=1)).cuda()

In [None]:
# One value per latent dimension.
max_chair.shape

In [None]:
# NOTE(review): PCA is fitted on X_chair.T, so the latent dimensions act as samples and the chair
# models as features; `pca_chair` is therefore one score per latent dimension rather than a 1-D
# projection of each chair — confirm this is intended (it is used as a target feature vector).
pca = PCA(n_components=1)
pca_chair = pca.fit_transform(X_chair.T)
pca_chair = torch.Tensor(pca_chair).view(-1)
# Fraction of variance captured by the single component.
np.sum(pca.explained_variance_ratio_)

In [None]:
# Flattened to one value per latent dimension.
pca_chair.shape

In [None]:
path = '/scratch/jiadeng_root/jiadeng/shared_data/datasets/SICGAN_data/'
labels = {'04379243':'table','03211117':'monitor','04401088':'cellphone','04530566': 'watercraft',  '03001627' : 'chair','03636649' : 'lamp',  '03691459': 'speaker' ,  '02828884':'bench',
'02691156': 'plane', '02808440': 'bathtub',  '02871439': 'bookcase',
'02773838': 'bag', '02801938': 'basket', '02828884' : 'bench','02880940': 'bowl' ,
'02924116': 'bus', '02933112': 'cabinet', '02942699': 'camera', '02958343': 'car', '03207941': 'dishwasher',
'03337140': 'file', '03624134': 'knife', '03642806': 'laptop', '03710193': 'mailbox',
'03761084': 'microwave', '03928116': 'piano', '03938244':'pillow', '03948459': 'pistol', '04004475': 'printer',
'04099429': 'rocket', '04256520': 'sofa', '04554684': 'washer', '04090263': 'rifle'}

In [None]:
import json
path = '/scratch/jiadeng_root/jiadeng/shared_data/datasets/SICGAN_data/'
# summary.json maps synset ids (looked up in `labels`) to lists of entries; print how many
# distinct entries each class has.
with open(path + 'summary.json', "r") as f:
    summary = json.load(f)
for synset_id, entries in summary.items():
    print(f'{labels[synset_id]} : {len(set(entries))}')