In [None]:
# Render matplotlib figures inline and have IPython reload edited modules
# (e.g. gcnna, scripts) automatically before each cell is executed.
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [None]:
import argparse
import logging
import os,sys
from typing import Type
import random 
from tqdm import tqdm

import torch
import numpy as np
from pytorch3d.structures import Meshes
from pytorch3d.io import load_obj, save_obj
from pytorch3d.ops import GraphConv

In [None]:
from gcnna.config import Config
from gcnna.models.base_nn import GraphConvClf
from gcnna.layers.utils import unpack_mesh_attr, pack_mesh_attr, pad_mesh_attr
from scripts.ico_objects import ico_disk

In [None]:
def load_mesh(tst_obj, device="cuda"):
    """Load one .obj file as a single-mesh pytorch3d ``Meshes`` batch.

    Args:
        tst_obj: path to the .obj file.
        device: device the mesh is moved to. Defaults to ``"cuda"``, which
            matches the previous hard-coded ``.cuda()``.

    Returns:
        A ``Meshes`` holding one mesh, built from the file's vertex positions
        and face vertex indices, on ``device``.
    """
    # Only geometry is used (the aux/texture output is discarded), so skip
    # texture loading: it avoids reading material/texture files for no benefit.
    verts, faces, _ = load_obj(tst_obj, load_textures=False)
    return Meshes(verts=[verts], faces=[faces.verts_idx]).to(device)

# Root of ShapeNetCore.v1; override with the SHAPENET_ROOT environment variable
# to run outside the cluster.
SHAPENET_ROOT = os.environ.get(
    'SHAPENET_ROOT',
    '/scratch/jiadeng_root/jiadeng/shared_data/datasets/ShapeNetCore.v1',
)

# <root>/<synset id>/<model id>/model.obj
table_pth = os.path.join(SHAPENET_ROOT, '04379243', '1028a9cbaa7a333230bbd4cddd04c77b', 'model.obj')
airplane_pth = os.path.join(SHAPENET_ROOT, '02691156', '105f7f51e4140ee4b6b87e72ead132ed', 'model.obj')
chair_pth = os.path.join(SHAPENET_ROOT, '03001627', '1016f4debe988507589aae130c1f06fb', 'model.obj')
# NOTE(review): this model id has 30 hex digits whereas the other ids here have
# 32 — it may be truncated; verify that the directory exists.
rifle_pth = os.path.join(SHAPENET_ROOT, '04090263', '10439f1f81fbee202be79d8b285c1e', 'model.obj')
table, airplane, chair, rifle = load_mesh(table_pth), load_mesh(airplane_pth), load_mesh(chair_pth), load_mesh(rifle_pth)

In [None]:
# Build a one-mesh batch from the ShapeNet table model for a classifier smoke test.
# v2/f2 (rifle) are only referenced by the commented-out two-mesh batch below;
# swap that line in to run the classifier on a batch of two meshes.
v1, f1, _ = load_obj(table_pth)
v2, f2, _ = load_obj(rifle_pth)
# mesh = Meshes(verts=[v1,v2], faces=[f1.verts_idx,f2.verts_idx]).cuda()
mesh = Meshes(verts=[v1], faces=[f1.verts_idx]).cuda()

In [None]:
# Build the graph-conv classifier from the training config (path is relative to
# the notebook's working directory) and move it to the GPU.
cfg = Config('config/train_clf.yml')
clf = GraphConvClf(cfg).cuda()

In [None]:
# Display the model's repr.
clf

In [None]:
# Forward pass on the test mesh; `a` is not used by any later cell in this notebook.
a = clf(mesh)

In [None]:
# Switch to the synthetic mesh from scripts.ico_objects.ico_disk for the pooling
# experiments below. This rebinds `mesh`, so re-running the classifier cell above
# afterwards would feed it this mesh instead of the ShapeNet table.
mesh = ico_disk().cuda()

In [None]:
def get_adjacency(verts, edges):
    """Build the symmetric sparse adjacency matrix of a mesh graph.

    Args:
        verts: tensor whose first dimension is the vertex count V; only
            ``verts.shape[0]`` and ``verts.device`` are used.
        edges: (E, 2) integer tensor of vertex-index pairs.

    Returns:
        A sparse COO float32 tensor ``A`` of shape (V, V) with
        ``A[i, j] = A[j, i] = 1`` for every edge (i, j). Entries are not
        coalesced, so a duplicated edge sums to 2 when densified.
    """
    num_verts = verts.shape[0]
    e0, e1 = edges.unbind(1)

    # Store every edge in both directions so A is symmetric: (2, 2E) indices.
    idx = torch.stack([torch.cat([e0, e1]), torch.cat([e1, e0])], dim=0)
    ones = torch.ones(idx.shape[1], dtype=torch.float32, device=verts.device)

    # torch.sparse.FloatTensor is a deprecated CPU-only legacy constructor and
    # rejects CUDA tensors; sparse_coo_tensor infers the device from its inputs.
    return torch.sparse_coo_tensor(idx, ones, (num_verts, num_verts))

In [None]:
def update_mesh(attn_w, verts, edges, mask_idx):
    """Keep a subset of vertices, gate them by attention, and re-index edges.

    Args:
        attn_w: per-vertex attention scores; ``attn_w.size(0)`` is taken as
            the vertex count (in this notebook it is the output of
            ``GraphConv(3, 1)``, assumed (V, 1)).
        verts: vertex features, assumed (V, D) — TODO confirm for other callers.
        edges: (E, 2) integer tensor of vertex-index pairs.
        mask_idx: 1-D long tensor of vertex indices to keep (assumed unique;
            the notebook passes sorted top-k indices).

    Returns:
        (verts_updated, edges_updated): the kept vertices, each scaled by the
        tanh of its attention score, and the edges whose two endpoints were
        both kept, renumbered to 0..len(mask_idx)-1 in ``mask_idx`` order.
    """
    # Lookup table: old vertex id -> position in mask_idx, -1 if dropped.
    new_id = mask_idx.new_full((attn_w.size(0),), -1)
    new_id[mask_idx] = torch.arange(
        mask_idx.size(0), dtype=torch.long, device=mask_idx.device
    )

    src, dst = (new_id[endpoint] for endpoint in edges.unbind(1))
    # An edge survives only if neither endpoint was dropped.
    survives = ~((src < 0) | (dst < 0))
    edges_updated = torch.stack([src[survives], dst[survives]], dim=1)

    gate = torch.tanh(attn_w[mask_idx]).view(-1, 1)
    return verts[mask_idx] * gate, edges_updated

In [None]:
mesh.edges_packed()

In [None]:
# from pytorch3d.structures.utils import packed_to_list, list_to_padded, padded_to_list, list_to_packed
verts = mesh.verts_packed()
edges = mesh.edges_packed()
verts_idx = mesh.verts_packed_to_mesh_idx()
edges_idx = mesh.edges_packed_to_mesh_idx()
ratio = 0.5

V = verts.shape[0] 
attn = GraphConv(3,1).cuda()
attn_w = attn(verts, edges)

# Get list of packed meshes
attn_w_unpkd, _  = pad_mesh_attr(verts_idx, attn_w)
verts_unpkd, _  = pad_mesh_attr(verts_idx, verts)
edges_unpkd, _ = pad_mesh_attr(edges_idx, edges)
assert len(attn_w_unpkd) == len(edges_unpkd)
B = len(attn_w_unpkd)

# verts_upd, edges_upd = [], []
# verts_idx_upd, edges_idx_upd = [], []

# for i in range(B):
#     print('Unpacked: ',attn_w_unpkd[i].shape, edges_unpkd[i].shape)
#     _, mask_idx = torch.topk(attn_w_unpkd[i], int(V*ratio), dim=0, sorted=False)  
#     mask_idx, _ = torch.sort(mask_idx.view(-1))
#     v, e = update_mesh(attn_w_unpkd[i], verts_unpkd[i], edges_unpkd[i], mask_idx)
#     verts_upd.append(v)
#     edges_upd.append(e)
#     verts_idx_upd.append(torch.Tensor([i]*v.shape[0]).to(device=verts_idx.device, dtype=verts_idx.dtype))
#     edges_idx_upd.append(torch.Tensor([i]*e.shape[0]).to(device=edges_idx.device, dtype=edges_idx.dtype))
    

# verts_upd = pack_mesh_attr(verts_upd)
# edges_upd = pack_mesh_attr(edges_upd)
# verts_idx_upd = pack_mesh_attr(verts_idx_upd)
# edges_idx_upd = pack_mesh_attr(edges_idx_upd)
# print('Packed: ',verts_pkd.shape, edges_pkd.shape)

In [None]:
print(attn_w_unpkd.shape, verts_unpkd.shape, edges_unpkd.shape)

In [None]:
print(attn_w.shape, verts.shape, edges.shape)

In [None]:
_, mask_idx = torch.topk(attn_w, int(V*ratio), dim=0, sorted=False)  
mask_idx, _ = torch.sort(mask_idx.view(-1))
mask_idx

In [None]:
num_nodes = attn_w.size(0)
mask = mask_idx.new_full((num_nodes, ), -1)
mask

In [None]:
i = torch.arange(mask_idx.size(0), dtype=torch.long, device=mask_idx.device)
i

In [None]:
mask[mask_idx] = i
mask

In [None]:
row, col = edges.unbind(1)
row, col

In [None]:
row, col = mask[row], mask[col]
row, col


In [None]:
mask = (row >= 0) & (col >= 0)
mask

In [None]:
row, col = row[mask], col[mask]
row, col

In [None]:
from gcnna.layers.norm import BatchNorm
from pytorch3d.ops import GraphConv
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

In [None]:
# Inspect the project's BatchNorm layer; the argument 3 is defined by
# gcnna.layers.norm (presumably the feature count — confirm).
bn = BatchNorm(3)

In [None]:
bn

In [None]:
# pytorch3d GraphConv with input_dim=3, output_dim=7.
g = GraphConv(3,7)

In [None]:
g

In [None]:
optimizer = torch.optim.Adam(
    clf.parameters(),
    lr=1,
)

In [None]:
optimizer.step()

In [None]:
optimizer.state_dict()['param_groups']

In [None]:
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0 = 4, T_mult=1, eta_min=1e-6, last_epoch=-1)

In [None]:
scheduler.get_lr()

In [None]:
# Re-fetch the packed tensors of the current `mesh` (the same calls as in the
# attention-scoring cell). verts_idx / edges_idx give, for each packed vertex /
# edge, the index of the mesh it belongs to in the batch.
verts = mesh.verts_packed()
edges = mesh.edges_packed()
verts_idx = mesh.verts_packed_to_mesh_idx()
edges_idx = mesh.edges_packed_to_mesh_idx()

In [None]:
# detach() returns a tensor cut from autograd; edges_idx itself is not modified.
edges_idx.detach()

In [None]:
# Cross-check the step-by-step top-k replay against update_mesh() itself,
# using the same attention scores and selected vertices (mask_idx).
verts_pooled, edges_pooled = update_mesh(attn_w, verts, edges, mask_idx)
assert verts_pooled.shape[0] == mask_idx.numel(), "one pooled vertex per selected index"
assert edges_pooled.numel() == 0 or (
    int(edges_pooled.min()) >= 0 and int(edges_pooled.max()) < mask_idx.numel()
), "pooled edges must index only pooled vertices"
verts_pooled.shape, edges_pooled.shape

In [None]:
from gcnna.layers.pooling import SAGPool

In [None]:
# Instantiate the project's SAGPool layer; the argument 3 is defined by
# gcnna.layers.pooling (presumably the input feature dim — confirm).
s = SAGPool(3)

In [None]:
# Check whether the layer's string representation contains 'SAGPool'
# (presumably a repr-based way to detect pooling layers — verify where it is used).
'SAGPool' in str(s)

In [None]:
torch.Tensor([0]*100).to(device=verts_idx.device, dtype=verts_idx.dtype)

In [None]:
verts_idx.device

In [42]:
import cv2