In [1]:
from collections import defaultdict
import os, json
from tqdm import tqdm

In [2]:
# Root of the SICGAN dataset (one sub-directory per ShapeNet synset id).
# Overridable via the SICGAN_DATA_DIR environment variable; os.path.join(..., '')
# guarantees a trailing separator because later cells build paths with `path + name`.
path = os.path.join(
    os.environ.get('SICGAN_DATA_DIR',
                   '/scratch/jiadeng_root/jiadeng/shared_data/datasets/SICGAN_data/'),
    '',
)
# ShapeNet synset id -> category name. Each id appears exactly once
# (the previous literal listed '02828884': 'bench' twice).
labels = {
    '04379243': 'table', '03211117': 'monitor', '04401088': 'cellphone',
    '04530566': 'watercraft', '03001627': 'chair', '03636649': 'lamp',
    '03691459': 'speaker', '02828884': 'bench', '02691156': 'plane',
    '02808440': 'bathtub', '02871439': 'bookcase', '02773838': 'bag',
    '02801938': 'basket', '02880940': 'bowl', '02924116': 'bus',
    '02933112': 'cabinet', '02942699': 'camera', '02958343': 'car',
    '03207941': 'dishwasher', '03337140': 'file', '03624134': 'knife',
    '03642806': 'laptop', '03710193': 'mailbox', '03761084': 'microwave',
    '03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol',
    '04004475': 'printer', '04099429': 'rocket', '04256520': 'sofa',
    '04554684': 'washer', '04090263': 'rifle',
}

In [None]:
# Image counts for the two categories used in this project (chair, table):
# summary[synset_id][model_id] -> number of files in <model>/images.
summary = defaultdict(dict)
for sid in ('03001627', '04379243'):
    sid_pth = os.path.join(path, sid)
    for mid in tqdm(os.listdir(sid_pth)):
        summary[sid][mid] = len(os.listdir(os.path.join(sid_pth, mid, 'images')))

In [None]:
# Number of models found per category.
for sid in summary:
    print(sid, f'{labels[sid]} : {len(summary[sid])}')

In [None]:
# Cache the per-model image counts next to the dataset.
with open(os.path.join(path, 'summary.json'), mode="w") as fh:
    json.dump(summary, fh)

In [None]:
# Reload the cached counts (json.load yields a plain dict, not a defaultdict).
with open(os.path.join(path, 'summary.json'), mode="r") as fh:
    summary = json.load(fh)

In [None]:
# Assign every model of every synset in `summary` to a split by its position in
# the sorted directory listing. Despite their names, the *_ratio values are
# cumulative model counts (exclusive upper bounds), not ratios:
#   [0, 3500) -> train, [3500, 4500) -> val, [4500, 5500) -> test, rest -> leftover.
# Result: split_file[split][synset_id][model_id] -> [0, 1, ..., num_images - 1].
# NOTE: sorting makes the split reproducible, but a p2m_splits.json produced
# from the old (unsorted) listing will not be regenerated identically.
split_file = defaultdict(dict)
trn_ratio = 3500
val_ratio = 4500
tst_ratio = 5500
for sid in summary:
    sid_pth = os.path.join(path, sid)
    # os.listdir order is arbitrary and filesystem-dependent; sort it.
    for idx, mid in enumerate(tqdm(sorted(os.listdir(sid_pth)))):
        if idx < trn_ratio:
            split_type = 'train'
        elif idx < val_ratio:
            split_type = 'val'
        elif idx < tst_ratio:
            split_type = 'test'
        else:
            split_type = 'leftover'
        mid_pth = os.path.join(sid_pth, mid)
        num_image = len(os.listdir(os.path.join(mid_pth, 'images')))
        # setdefault creates the per-synset dict on first use (replaces a bare except).
        split_file[split_type].setdefault(sid, {})[mid] = list(range(num_image))

In [11]:
# Number of models per (split, category).
for data in split_file:
    print(data)
    for sid, models in split_file[data].items():
        print(f'\t{labels[sid]} : {len(models)}')

train
	chair : 3500
	table : 3500
val
	chair : 1000
	table : 1000
test
	chair : 1000
	table : 1000
leftover
	chair : 11
	table : 2177


In [None]:
# Save the split assignment next to the dataset.
with open(os.path.join(path, 'p2m_splits.json'), mode="w") as fh:
    json.dump(split_file, fh)

In [None]:
# Also write a copy into the current working directory
# (NOTE(review): presumably what the config's SPLITS_FILE points at -- confirm).
with open('p2m_splits.json', mode="w") as fh:
    json.dump(split_file, fh)

In [3]:
# Reload the saved split assignment from the dataset directory.
with open(os.path.join(path, 'p2m_splits.json'), mode="r") as fh:
    split_file = json.load(fh)

In [7]:
import argparse
import logging
import os,sys
from typing import Type
import random 
from tqdm import tqdm

import torch
import numpy as np
from torch import nn, optim

from sicgan.config import Config
from sicgan.models import Pixel2MeshHead
from sicgan.models import GraphConvClf
from sicgan.data.build_data_loader import build_data_loader
from sicgan.models import MeshLoss
from sicgan.utils.torch_utils import save_checkpoint
from torch.utils.tensorboard import SummaryWriter


import warnings
warnings.filterwarnings("ignore")

In [8]:
# Load the P2M training config; the second argument is an empty list of
# overrides (assumed -- see sicgan.config.Config).
_C = Config('config/train_p2m.yml', [])

In [None]:
# Training loader over the "MeshVox" dataset, built from the config above.
trn_dataloader = build_data_loader(_C, "MeshVox", split_name='train')

In [None]:
# Loader length -- number of batches if this is a torch DataLoader (assumed).
len(trn_dataloader)

In [9]:
# Select one split from the splits file referenced by the config.
# 'train_eval' reuses the 'train' split. `split` is explicitly None when no
# split name is given (previously it was left undefined in that case).
splits_file = _C.DATASETS.SPLITS_FILE
split_name = 'test'
with open(splits_file, "r") as f:
    splits = json.load(f)
split = None
if split_name is not None:
    split = splits["train" if split_name in ("train", "train_eval") else split_name]

In [12]:
# Number of synsets in the selected split (keys of split are synset ids).
len(split)

2

In [20]:
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import json
import logging
import os
import torch
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.structures import Meshes
from torch.utils.data import Dataset

import torchvision.transforms as T
from PIL import Image
from sicgan.data.utils import imagenet_preprocess, project_verts
# from shapenet.utils.coords import SHAPENET_MAX_ZMAX, SHAPENET_MIN_ZMIN, project_verts

# Named logger; not used by the notebook copy of the dataset class below (it prints).
logger = logging.getLogger('mesh')


class MeshVoxDataset(Dataset):
    """Notebook copy of the MeshVox dataset, used to debug split filtering.

    Reads ``<data_dir>/summary.json`` (synset id -> {model id: image count}) and
    builds three parallel lists -- ``synset_ids``, ``model_ids``, ``image_ids`` --
    with one entry per (synset, model, image index) admitted by ``split``.

    ``split`` may be None (keep everything) or a dict keyed by synset id whose
    values are either a list of model ids (all images of those models are kept)
    or a dict ``{model_id: [image indices]}``. Synsets missing from ``split``
    are skipped.

    Per-synset diagnostics are printed: the number of allowed models (when a
    split is given), then "<models skipped> <(model, image) pairs kept>".

    ``__getitem__`` is a stub in this notebook and returns None.
    """

    def __init__(
        self,
        data_dir,
        normalize_images=True,
        split=None,
        return_mesh=False,
        voxel_size=32,    # Not required; accepted for interface compatibility only.
        num_samples=5000,
        sample_online=False,
        in_memory=False,
        return_id_str=False,
    ):
        super(MeshVoxDataset, self).__init__()
        if not return_mesh and sample_online:
            raise ValueError("Cannot sample online without returning mesh")
        self.data_dir = data_dir
        self.return_mesh = return_mesh
        self.num_samples = num_samples
        self.sample_online = sample_online
        self.return_id_str = return_id_str

        self.synset_ids = []
        self.model_ids = []
        self.image_ids = []
        self.mid_to_samples = {}

        transform = [T.Resize((192, 256)), T.ToTensor()]
        if normalize_images:
            transform.append(imagenet_preprocess())   # Change this to r2n2 params
        self.transform = T.Compose(transform)

        summary_json = os.path.join(data_dir, "summary.json")
        print(data_dir)
        with open(summary_json, "r") as f:
            summary = json.load(f)

        for sid, models in summary.items():
            print("Starting synset %s" % sid)
            allowed_mids = None
            if split is not None:
                if sid not in split:
                    print("Skipping synset %s" % sid)
                    continue
                elif isinstance(split[sid], list):
                    print('list')
                    allowed_mids = set(split[sid])
                elif isinstance(split[sid], dict):
                    # Previously tested `split` itself, which is always a dict here.
                    print('dict')
                    allowed_mids = set(split[sid].keys())
            # Guarded: len(None) used to raise TypeError when split was None.
            if allowed_mids is not None:
                print(len(allowed_mids))

            num_skipped = 0  # models listed in summary but not in the split
            num_kept = 0     # (model, image) pairs added for this synset
            for mid, num_imgs in models.items():
                if allowed_mids is not None and mid not in allowed_mids:
                    num_skipped += 1
                    continue
                # A dict split restricts image indices per model; a list keeps all.
                allowed_iids = None
                if split is not None and isinstance(split[sid], dict):
                    allowed_iids = set(split[sid][mid])
                if not sample_online and in_memory:
                    samples_path = os.path.join(data_dir, sid, mid, "samples.pt")
                    self.mid_to_samples[mid] = torch.load(samples_path)
                for iid in range(num_imgs):
                    if allowed_iids is None or iid in allowed_iids:
                        num_kept += 1
                        self.synset_ids.append(sid)
                        self.model_ids.append(mid)
                        self.image_ids.append(iid)
            print(num_skipped, num_kept)

    def __len__(self):
        return len(self.synset_ids)

    def __getitem__(self, idx):
        # Stub: resolves the index (IndexError when out of range) but loading
        # the image/mesh is not implemented in this notebook copy.
        sid = self.synset_ids[idx]
        mid = self.model_ids[idx]
        iid = self.image_ids[idx]
        return None

In [21]:
# Build the notebook's debug dataset for the split selected above; its
# constructor prints per-synset model/image counts.
dset = MeshVoxDataset(
    _C.DATASETS.DATA_DIR,
    split=split,
    num_samples=_C.G.MESH_HEAD.GT_NUM_SAMPLES,
    return_mesh=True,
    sample_online=False,
    return_id_str=False,
)

/scratch/jiadeng_root/jiadeng/shared_data/datasets/SICGAN_data/
Starting synset 03001627
dict
1000
4511 24000
Starting synset 04379243
dict
1000
6677 24000


In [23]:
# Samples / 32 -- presumably batches per epoch at batch size 32 (NOTE(review): confirm vs config).
len(dset)/32

1500.0

In [17]:
# Fresh accumulators for the hand-unrolled index build in the next cell.
synset_ids, model_ids, image_ids = [], [], []
mid_to_samples = dict()

In [19]:
# Hand-unrolled copy of MeshVoxDataset's index build, to inspect the per-synset
# counts directly. `summary`, `sid`, `allowed_mids` and `mid` are left bound to
# the last synset and are reused by the debugging cells below.
data_dir = _C.DATASETS.DATA_DIR
summary_json = os.path.join(data_dir, "summary.json")
print(data_dir)
# Reset the accumulators so re-running this cell does not duplicate entries.
synset_ids, model_ids, image_ids = [], [], []
with open(summary_json, "r") as f:
    summary = json.load(f)
for sid in summary:
    print("Starting synset %s" % sid)
    allowed_mids = None
    if split is not None:
        if sid not in split:
            print("Skipping synset %s" % sid)
            continue
        elif isinstance(split[sid], list):
            print('list')
            allowed_mids = set(split[sid])
        elif isinstance(split[sid], dict):
            # Previously tested `split` itself, which is always a dict here.
            print('dict')
            allowed_mids = set(split[sid].keys())
    # Guarded: len(None) used to raise TypeError when split was None.
    if allowed_mids is not None:
        print(len(allowed_mids))
    kept_mids = []  # models of this synset that pass the split filter
    num_kept = 0    # (model, image) pairs appended for this synset
    for mid, num_imgs in summary[sid].items():
        if allowed_mids is not None and mid not in allowed_mids:
            continue
        kept_mids.append(mid)
        allowed_iids = None
        if split is not None and isinstance(split[sid], dict):
            allowed_iids = set(split[sid][mid])
        for iid in range(num_imgs):
            if allowed_iids is None or iid in allowed_iids:
                num_kept += 1
                synset_ids.append(sid)
                model_ids.append(mid)
                image_ids.append(iid)
    print(len(kept_mids), num_kept)
        
        

/scratch/jiadeng_root/jiadeng/shared_data/datasets/SICGAN_data/
Starting synset 03001627
dict
1000
1000 24000
Starting synset 04379243
dict
1000
1000 24000


In [None]:
# Debug: recount the models of `sid` (left over from the index-build cell)
# against `allowed_mids` (also left over from that cell).
# a = models in summary[sid], b = models not allowed, m = allowed model ids.
b = 0 
a = 0
m =[]
for mid, num_imgs in summary[sid].items():
    a+=1
    if allowed_mids is not None and mid not in allowed_mids:
        b+=1
    else:
        m.append(mid)

In [None]:
# (total models, skipped models, kept models) from the cell above.
a, b, len(m)

In [None]:
# Peek at the first (model id, image count) entry of summary[sid]; leaves `mid` bound to it.
for mid, num_imgs in summary[sid].items():
    print(mid, num_imgs)
    break

In [None]:
# Model id bound by the most recently run loop.
mid

In [None]:
# Recompute allowed model ids for `sid`; assumes split[sid] is a dict keyed by model id.
allowed_mids = set(split[sid].keys())

In [None]:
# Is that model id part of the split?
mid in allowed_mids

In [None]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
%load_ext autoreload
%autoreload 2 

In [None]:
from sicgan.models import MeshLoss

In [None]:
from PyGEL3D import gel
from PyGEL3D import js
import re
from pytorch3d.io import load_obj, save_obj

In [None]:
from sicgan.models import Pixel2MeshHead
from sicgan.config import Config


# NOTE: rebinds `_C` to a different config (sicgan_train.yml) than the
# train_p2m.yml loaded earlier; all cells below use this one.
_C = Config('./config/sicgan_train.yml', [])
# Mesh-generating Pixel2Mesh head on the GPU (presumably the GAN generator).
G = Pixel2MeshHead(_C).cuda()

In [None]:
from sicgan.data.build_data_loader import build_data_loader
# Training loader built from the sicgan_train config.
train = build_data_loader(_C, "MeshVox", split_name='train')

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
from tqdm import tqdm

In [None]:
# Fetch a single batch for interactive inspection. Element 0 is used below as
# the image batch and element 1 as the ground-truth meshes.
# next(iter(...)) replaces a tqdm loop that broke after one step and left a
# dangling progress bar; an empty loader now raises StopIteration explicitly.
data = next(iter(train))
imgs = data[0]
meshes = data[1]

In [None]:
# Inspect the ground-truth mesh batch.
meshes

In [None]:
# Show the third image of the batch, permuted from CHW to HWC for matplotlib.
# NOTE(review): if the loader normalizes images, values leave [0, 1] and imshow clips them.
plt.imshow(imgs[2].permute(1,2,0).numpy())

In [None]:
def plot_mesh(mesh, obj_path='mesh.obj'):
    """Render a mesh inline with PyGEL3D.

    The mesh's packed vertices and faces are written to ``obj_path`` with
    ``save_obj``, reloaded with ``gel.obj_load`` and shown via ``js.display``
    (flat shading).

    Args:
        mesh: object exposing ``verts_packed()`` and ``faces_packed()``
            (e.g. a PyTorch3D ``Meshes``).
        obj_path: scratch OBJ file to write; overwritten on every call.
            Defaults to ``'mesh.obj'`` in the working directory, as before.
    """
    save_obj(obj_path, mesh.verts_packed(), mesh.faces_packed())
    js.set_export_mode()
    m = gel.obj_load(obj_path)
    js.display(m, smooth=False)

In [None]:
# plot_mesh(i[1][1])

In [None]:
# Run the image batch through G. NOTE: rebinds `m`, previously the list of kept model ids.
m = G(imgs.cuda())

In [None]:
# NOTE(review): result is only displayed; detach() presumably returns a new object and leaves `m` unchanged.
m.detach()

In [None]:
from sicgan.models import GraphConvClf

In [None]:
# Graph-conv mesh classifier on the GPU (presumably the GAN discriminator).
D = GraphConvClf(_C).cuda()

In [None]:
# Score the meshes produced by G with D.
D(m.cuda())

In [None]:
    loss_fn_kwargs = {
        "chamfer_weight": _C.G.MESH_HEAD.CHAMFER_LOSS_WEIGHT,
        "normal_weight": _C.G.MESH_HEAD.NORMAL_LOSS_WEIGHT,
        "edge_weight": _C.G.MESH_HEAD.EDGE_LOSS_WEIGHT,
        "gt_num_samples": _C.G.MESH_HEAD.GT_NUM_SAMPLES,
        "pred_num_samples": _C.G.MESH_HEAD.PRED_NUM_SAMPLES,
    }

In [None]:
# Mesh loss configured with the weights above, moved to the GPU.
mesh_loss = MeshLoss(**loss_fn_kwargs).cuda()

In [None]:
# Loss between G's output `m` and ground-truth `meshes` (presumably (pred, gt) order -- confirm).
mesh_loss(m.cuda(), meshes.cuda())