### Downloading required files

In [1]:
!gdown 18UdsemUdKo75GXmQLLVYdcOhNZ3zM215

Downloading...
From: https://drive.google.com/uc?id=18UdsemUdKo75GXmQLLVYdcOhNZ3zM215
To: /content/shapenet_car.pt
100% 518M/518M [00:02<00:00, 220MB/s]


In [2]:
# !pip install -q torch==1.11.0 torchvision==0.10.0 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
!pip install -q ninja xatlas gdown
!pip install -q git+https://github.com/NVlabs/nvdiffrast/
!pip install -q meshzoo ipdb imageio gputil h5py point-cloud-utils imageio imageio-ffmpeg==0.4.4 pyspng==0.1.0
!pip install -q urllib3
!pip install -q scipy
!pip install -q click
!pip install -q tqdm
!pip install -q opencv-python==4.5.4.58

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m307.2/307.2 kB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m229.4/229.4 kB[0m [31m24.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for nvdiffrast (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.2/9.2 MB[0m [31m23.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m26.9/26.9 MB[0m [31m62.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m106.2/106.2 kB[0m [31m13.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━

In [3]:
!git clone https://github.com/m22cs058/GET3D.git

Cloning into 'GET3D'...
remote: Enumerating objects: 252, done.[K
remote: Counting objects: 100% (116/116), done.[K
remote: Compressing objects: 100% (76/76), done.[K
remote: Total 252 (delta 53), reused 48 (delta 40), pack-reused 136[K
Receiving objects: 100% (252/252), 175.43 MiB | 26.64 MiB/s, done.
Resolving deltas: 100% (76/76), done.


In [4]:
!mv shapenet_car.pt GET3D/pretrained_model/

In [5]:
%cd GET3D

/content/GET3D


### Import Libraries

In [None]:
import os
import click
import re
import json
import tempfile
import numpy as np
import torch
import PIL.Image
import imageio
import cv2
from tqdm import tqdm
import copy
import dnnlib
from torch_utils.ops import conv2d_gradfix, grid_sample_gradfix
from training import training_loop_3d
from metrics import metric_main
from torch_utils import training_stats
from torch_utils import custom_ops

### Loading and Saving Mesh and OBJs

In [None]:
def save_obj(points, faces, filepath):
    """Write a geometry-only Wavefront OBJ file.

    Args:
        points: iterable of vertices; components 0, 1 and 2 of each entry are
            written as one ``v x y z`` record.
        faces: iterable of faces, each an iterable of 0-based vertex indices.
        filepath: destination path; opened in ``'w'`` mode.
    """
    with open(filepath, 'w') as obj_file:
        obj_file.writelines(
            f"v {vertex[0]} {vertex[1]} {vertex[2]}\n" for vertex in points
        )
        for corner_ids in faces:
            # OBJ indices are 1-based, so every index is shifted by one.
            record = ' '.join([str(vertex_id + 1) for vertex_id in corner_ids])
            obj_file.write(f"f {record}\n")

In [None]:
def savemeshtes2(points, textures, faces, texture_faces, filepath):
    """Write a textured mesh as a Wavefront OBJ plus a companion MTL file.

    The MTL file is written next to ``filepath`` with the same base name and
    names ``<basename>.png`` as its ``map_Kd`` texture. The PNG itself is not
    written by this function.

    Args:
        points: iterable of vertices; components 0, 1 and 2 are written.
        textures: iterable of UV coordinates; components 0 and 1 are written.
        faces: iterable of faces holding 0-based vertex indices.
        texture_faces: iterable of faces holding 0-based UV indices, parallel
            to ``faces`` (corner ``k`` of a face uses UV index ``k`` of the
            matching texture face).
        filepath: destination ``.obj`` path.
    """
    folder, filename = os.path.split(filepath)
    filename, _ = os.path.splitext(filename)

    material_file = os.path.join(folder, f"{filename}.mtl")
    with open(material_file, 'w') as f:
        f.write("newmtl material_0\n")
        f.write("map_Kd {}.png\n".format(filename))

    with open(filepath, 'w') as f:
        f.write(f"mtllib {filename}.mtl\n")

        for point in points:
            f.write(f"v {point[0]} {point[1]} {point[2]}\n")

        for texture in textures:
            f.write(f"vt {texture[0]} {texture[1]}\n")

        f.write("usemtl material_0\n")

        for face, tface in zip(faces, texture_faces):
            # Each corner is "vertex_index/uv_index" (1-based) and corners are
            # separated by spaces. This is the layout loadobjtex parses; joining
            # every index with '/' would produce an unreadable face record.
            corners = ' '.join(
                f"{int(vidx) + 1}/{int(tidx) + 1}" for vidx, tidx in zip(face, tface)
            )
            f.write(f"f {corners}\n")

In [6]:
def loadobj(filepath):
    """Read vertex positions and face vertex indices from a Wavefront OBJ file.

    Only ``v`` and ``f`` records are used; every other record is ignored.
    For face corners written as ``v/vt/...`` only the leading vertex index is
    kept, converted from 1-based to 0-based.

    Args:
        filepath: path of the OBJ file to read.

    Returns:
        Tuple ``(verts, faces)``: a float32 array of vertex rows and an int64
        array of face rows.
    """
    vertex_rows, face_rows = [], []

    with open(filepath, 'r') as obj_file:
        for raw_line in obj_file:
            fields = raw_line.strip().split()
            if len(fields) == 0:
                continue  # blank line

            keyword, values = fields[0], fields[1:]
            if keyword == 'v':
                vertex_rows.append(list(map(float, values)))
            elif keyword == 'f':
                face_rows.append([int(corner.split('/')[0]) - 1 for corner in values])

    faces = np.array(face_rows, dtype=np.int64)
    verts = np.array(vertex_rows, dtype=np.float32)
    return verts, faces


def loadobjtex(filepath):
    """Read geometry and UV data from a textured Wavefront OBJ file.

    Uses ``v``, ``vt`` and ``f`` records; everything else is ignored. Face
    corners must be written as ``v/vt[/...]``: the first field is taken as the
    vertex index and the second as the UV index, both converted to 0-based.

    Args:
        filepath: path of the OBJ file to read.

    Returns:
        Tuple ``(verts, textures, faces, texture_faces)``: float32 vertex rows,
        float32 UV rows (first two components of each ``vt``), int64 vertex
        indices per face, int64 UV indices per face.
    """
    vertex_rows, uv_rows = [], []
    face_rows, uv_face_rows = [], []

    with open(filepath, 'r') as obj_file:
        for raw_line in obj_file:
            fields = raw_line.strip().split()
            if len(fields) == 0:
                continue  # blank line

            keyword = fields[0]
            if keyword == 'v':
                vertex_rows.append(list(map(float, fields[1:])))
            elif keyword == 'vt':
                # Keep only u and v, dropping any further component.
                uv_rows.append(list(map(float, fields[1:3])))
            elif keyword == 'f':
                corners = [field.split('/') for field in fields[1:]]
                vertex_ids = [int(corner[0]) - 1 for corner in corners]
                uv_ids = [int(corner[1]) - 1 for corner in corners]
                face_rows.append(vertex_ids)
                uv_face_rows.append(uv_ids)

    faces = np.array(face_rows, dtype=np.int64)
    texture_faces = np.array(uv_face_rows, dtype=np.int64)
    verts = np.array(vertex_rows, dtype=np.float32)
    textures = np.array(uv_rows, dtype=np.float32)
    return verts, textures, faces, texture_faces

### Utility functions for inference and results



In [None]:
def save_image_grid(img, fname, drange, grid_size):
    """Tile a batch of images into one uint8 grid and optionally save it.

    Args:
        img: array-like of shape ``(N, C, H, W)`` with ``C`` equal to 1 or 3.
        fname: output image path, or ``None`` to skip writing a file.
        drange: ``(lo, hi)`` value range of ``img``; mapped linearly to 0..255.
        grid_size: ``(gw, gh)``. Only ``gh`` (the row count) is used; the
            column count is recomputed as ``N // gh``.

    Returns:
        The tiled image as a uint8 array of shape ``(gh * H, gw * W, C)``.
    """
    low, high = drange
    pixels = np.asarray(img, dtype=np.float32)
    pixels = (pixels - low) * (255 / (high - low))
    pixels = np.rint(pixels).clip(0, 255).astype(np.uint8)

    _, rows = grid_size
    n_images, channels, height, width = pixels.shape
    cols = n_images // rows
    # (rows, cols, C, H, W) -> (rows, H, cols, W, C) -> one flat image.
    grid = pixels.reshape([rows, cols, channels, height, width])
    grid = grid.transpose(0, 3, 1, 4, 2)
    grid = grid.reshape([rows * height, cols * width, channels])

    assert channels in [1, 3]
    if fname is not None:
        if channels == 1:
            PIL.Image.fromarray(grid[:, :, 0], 'L').save(fname)
        if channels == 3:
            PIL.Image.fromarray(grid, 'RGB').save(fname)
    return grid

In [None]:
def save_3d_shape(mesh_v_list, mesh_f_list, root, idx):
    """Write each (vertices, faces) pair as an OBJ file under ``root/mesh_pred``.

    Files are named ``<idx:07d>_<mesh index:02d>.obj``. The number of meshes
    written is ``len(mesh_f_list)``.

    Args:
        mesh_v_list: indexable collection of vertex arrays.
        mesh_f_list: sized, indexable collection of face arrays.
        root: directory in which the ``mesh_pred`` folder is created.
        idx: integer used as the file-name prefix.
    """
    mesh_count = len(mesh_f_list)
    out_dir = os.path.join(root, 'mesh_pred')
    os.makedirs(out_dir, exist_ok=True)
    for i_mesh in range(mesh_count):
        target = os.path.join(out_dir, '%07d_%02d.obj' % (idx, i_mesh))
        save_obj(mesh_v_list[i_mesh], mesh_f_list[i_mesh], target)


In [None]:
def save_visualization_for_interpolation(generator, num_sam=10, c_to_compute_w_avg=None, save_dir=None, gen_mesh=False):
    """Render linear latent interpolations between consecutive sampled shapes.

    Samples ``num_sam`` geometry and texture latents, maps them with
    truncation 0.7, and for each consecutive pair among the first four blends
    the mapped codes in 10 steps. Each blended batch is handed to ``gen_swap``
    together with one camera and an ``interpolate_XX`` output folder.

    Args:
        generator: generator exposing ``update_w_avg``, ``mapping_geo``,
            ``mapping``, ``synthesis``, ``z_dim`` and ``device``.
        num_sam: number of latent codes to sample. The first four are indexed,
            so fewer than four will fail.
        c_to_compute_w_avg: forwarded to ``generator.update_w_avg``.
        save_dir: base output directory. It is joined with a sub-folder name,
            so the ``None`` default cannot be used as-is.
        gen_mesh: forwarded to ``gen_swap``.
    """
    # NOTE(review): gen_swap is neither defined nor imported in the visible
    # part of this notebook; confirm it is in scope before calling this.
    with torch.no_grad():
        generator.update_w_avg(c_to_compute_w_avg)
        # Geometry latents are drawn before texture latents; keep this order so
        # the random stream is consumed the same way.
        z_geo = torch.randn(num_sam, generator.z_dim, device=generator.device)
        z_tex = torch.randn(num_sam, generator.z_dim, device=generator.device)
        ws_geo = generator.mapping_geo(z_geo, None, truncation_psi=0.7)
        ws_tex = generator.mapping(z_tex, None, truncation_psi=0.7)
        camera = generator.synthesis.generate_rotate_camera_list(n_batch=num_sam)[4]

        select_geo_codes = np.arange(4)  # You can change to other selected shapes
        select_tex_codes = np.arange(4)
        n_interpolate = 10
        for i in range(len(select_geo_codes) - 1):
            geo_start = ws_geo[select_geo_codes[i]].unsqueeze(dim=0)
            geo_end = ws_geo[select_geo_codes[i + 1]].unsqueeze(dim=0)
            tex_start = ws_tex[select_tex_codes[i]].unsqueeze(dim=0)
            tex_end = ws_tex[select_tex_codes[i + 1]].unsqueeze(dim=0)

            geo_steps, tex_steps = [], []
            for step in range(n_interpolate):
                # Weight on the start code; it shrinks to 0 at the last step.
                w = float(step + 1) / n_interpolate
                w = 1 - w
                geo_steps.append(geo_start * w + geo_end * (1 - w))
                tex_steps.append(tex_start * w + tex_end * (1 - w))
            tex_batch = torch.cat(tex_steps, dim=0)
            geo_batch = torch.cat(geo_steps, dim=0)

            out_dir = os.path.join(save_dir, 'interpolate_%02d' % (i))
            os.makedirs(out_dir, exist_ok=True)
            gen_swap(geo_batch, tex_batch, camera, generator, save_path=out_dir, gen_mesh=gen_mesh)

In [None]:
def save_visualization(G_ema, grid_z, grid_c, run_dir, cur_nimg, grid_size, cur_tick, image_snapshot_ticks=50, save_gif_name=None, save_all=True, grid_tex_z=None,):
    """Render generated shapes from one or several cameras and save the results.

    For every camera, each latent chunk is rendered; the RGB channels and the
    mask (permuted to channel-first and expanded to three channels) are
    concatenated along the last axis and tiled into a grid.

    * ``save_all=True``: every camera view is rendered and the grids are
      written as one GIF in ``run_dir``.
    * ``save_all=False``: only camera index 4 is rendered and its grid is
      written as a PNG in ``run_dir``.

    When ``cur_tick`` is a multiple of ``min(image_snapshot_ticks * 20, 100)``,
    up to 10 meshes from the last rendered camera are also saved as OBJ files.

    Args:
        G_ema: generator exposing ``update_w_avg``, ``synthesis`` and ``generate_3d``.
        grid_z: sequence of geometry latent chunks.
        grid_c: sequence of conditioning chunks, zipped with the latents.
        run_dir: output directory.
        cur_nimg: image counter used in file names.
        grid_size: forwarded to ``save_image_grid``.
        cur_tick: current tick, used to decide whether meshes are saved.
        image_snapshot_ticks: snapshot interval used in the mesh-saving rule.
        save_gif_name: optional GIF file name; its part before the first
            ``.`` is also used as a suffix for PNG names.
        save_all: see above.
        grid_tex_z: texture latent chunks; defaults to ``grid_z``.
    """
    with torch.no_grad():
        G_ema.update_w_avg()
        cameras = G_ema.synthesis.generate_rotate_camera_list(n_batch=grid_z[0].shape[0])
        if not save_all:
            cameras = [cameras[4]]  # we only save one camera for this
        tex_latents = grid_z if grid_tex_z is None else grid_tex_z

        frames = []
        for i_camera, camera in enumerate(cameras):
            batch_images, mesh_v_list, mesh_f_list = [], [], []
            for z, geo_z, c in zip(tex_latents, grid_z, grid_c):
                (rendered, mask, _sdf, _deformation, _v_deformed,
                 mesh_v, mesh_f, _gen_camera, _img_wo_light, _tex_hard_mask) = G_ema.generate_3d(
                    z=z, geo_z=geo_z, c=c, noise_mode='const',
                    generate_no_light=True, truncation_psi=0.7, camera=camera)
                rgb = rendered[:, :3]
                mask_rgb = mask.permute(0, 3, 1, 2).expand(-1, 3, -1, -1)
                side_by_side = torch.cat([rgb, mask_rgb], dim=-1).detach()
                batch_images.append(side_by_side.cpu().numpy())
                mesh_v_list.extend(v.data.cpu().numpy() for v in mesh_v)
                mesh_f_list.extend(f.data.cpu().numpy() for f in mesh_f)
            images = np.concatenate(batch_images, axis=0)

            prefix = 'fakes' if save_gif_name is None else 'fakes_%s' % (save_gif_name.split('.')[0])
            # Individual PNGs are written only in single-camera mode.
            png_path = None
            if not save_all:
                png_path = os.path.join(
                    run_dir,
                    f'{prefix}_{cur_nimg // 1000:06d}_{i_camera:02d}.png')
            frames.append(
                save_image_grid(images, png_path, drange=[-1, 1], grid_size=grid_size))

        gif_name = f'fakes_{cur_nimg // 1000:06d}.gif' if save_gif_name is None else save_gif_name
        if save_all:
            imageio.mimsave(os.path.join(run_dir, gif_name), frames)

        n_shape = 10  # we only save 10 shapes to check performance
        if cur_tick % min((image_snapshot_ticks * 20), 100) == 0:
            save_3d_shape(mesh_v_list[:n_shape], mesh_f_list[:n_shape], run_dir, cur_nimg // 100)

In [None]:
def save_textured_mesh_for_inference(
        G_ema, grid_z, grid_c, run_dir, save_mesh_dir=None,
        c_to_compute_w_avg=None, grid_tex_z=None, use_style_mixing=False):
    """Generate textured meshes and save each as OBJ/MTL plus a PNG texture.

    For every latent chunk the generator's ``generate_3d_mesh`` is called with
    truncation 0.7. Each returned mesh is written as ``<index:07d>.obj`` (with
    its MTL) and the texture map as ``<index:07d>.png``. Before saving, texture
    pixels whose channel sum is <= 3 (after mapping [-1, 1] to [0, 255]) are
    replaced by a 3x3 dilation of the texture, and the image is flipped
    vertically.

    Args:
        G_ema: generator exposing ``update_w_avg`` and ``generate_3d_mesh``.
        grid_z: sequence of geometry latent chunks.
        grid_c: not used by this function.
        run_dir: base output directory.
        save_mesh_dir: sub-directory of ``run_dir`` for the output; when
            ``None`` the files are written directly into ``run_dir``.
        c_to_compute_w_avg: forwarded to ``G_ema.update_w_avg``.
        grid_tex_z: texture latent chunks; defaults to the geometry latents.
        use_style_mixing: forwarded to ``generate_3d_mesh``.
    """
    with torch.no_grad():
        G_ema.update_w_avg(c_to_compute_w_avg)
        save_mesh_idx = 0
        # os.path.join rejects None, so the declared default would crash here;
        # fall back to run_dir itself when no sub-directory is given.
        mesh_dir = run_dir if save_mesh_dir is None else os.path.join(run_dir, save_mesh_dir)
        os.makedirs(mesh_dir, exist_ok=True)
        for idx, geo_z in enumerate(grid_z):
            tex_z = geo_z if grid_tex_z is None else grid_tex_z[idx]
            generated_mesh = G_ema.generate_3d_mesh(
                geo_z=geo_z, tex_z=tex_z, c=None, truncation_psi=0.7,
                use_style_mixing=use_style_mixing)
            for mesh_v, mesh_f, all_uvs, all_mesh_tex_idx, tex_map in zip(*generated_mesh):
                savemeshtes2(
                    mesh_v.data.cpu().numpy(),
                    all_uvs.data.cpu().numpy(),
                    mesh_f.data.cpu().numpy(),
                    all_mesh_tex_idx.data.cpu().numpy(),
                    os.path.join(mesh_dir, '%07d.obj' % (save_mesh_idx))
                )
                lo, hi = (-1, 1)
                img = np.asarray(tex_map.permute(1, 2, 0).data.cpu().numpy(), dtype=np.float32)
                img = (img - lo) * (255 / (hi - lo))
                img = img.clip(0, 255)
                # The np.float alias was removed in NumPy 1.24; np.float64 is
                # the type it used to resolve to.
                mask = np.sum(img.astype(np.float64), axis=-1, keepdims=True)
                mask = (mask <= 3.0).astype(np.float64)
                # Fill near-black texels from their neighbours.
                kernel = np.ones((3, 3), 'uint8')
                dilate_img = cv2.dilate(img, kernel, iterations=1)
                img = img * (1 - mask) + dilate_img * mask
                img = img.clip(0, 255).astype(np.uint8)
                # Rows are reversed (vertical flip) before writing the PNG.
                PIL.Image.fromarray(np.ascontiguousarray(img[::-1, :, :]), 'RGB').save(
                    os.path.join(mesh_dir, '%07d.png' % (save_mesh_idx)))
                save_mesh_idx += 1

In [7]:
def save_geo_for_inference(G_ema, run_dir):
    """Generate untextured meshes and surface point samples for evaluation.

    Runs 7500 generation steps with truncation 1.0 and style mixing enabled.
    Every non-empty mesh is written as ``<index:07d>.obj`` and 2048 points
    sampled from its size-normalised surface are stored as
    ``<index:07d>.npz`` under the key ``pcd``.

    Args:
        G_ema: generator exposing ``z_dim``, ``device`` and ``generate_3d_mesh``.
        run_dir: base directory; two sub-folders are created inside it.
    """
    # NOTE(review): kaolin is not installed by the pip cell visible in this
    # notebook; confirm it is available before calling this.
    import kaolin as kal

    def normalize_and_sample_points(mesh_v, mesh_f, kal, n_sample, normalized_scale=1.0):
        """Scale the mesh by its largest per-axis extent and sample surface points."""
        verts_gpu = mesh_v.cuda()
        extent = verts_gpu.max(dim=0)[0] - verts_gpu.min(dim=0)[0]
        verts_scaled = verts_gpu / extent.max() * normalized_scale
        faces_gpu = mesh_f.cuda()
        sampled, _ = kal.ops.mesh.sample_points(verts_scaled.unsqueeze(dim=0), faces_gpu, n_sample)
        return sampled

    with torch.no_grad():
        use_style_mixing = True
        truncation_phi = 1.0
        mesh_dir = os.path.join(run_dir, 'gen_geo_for_eval_phi_%.2f' % (truncation_phi))
        surface_point_dir = os.path.join(run_dir, 'gen_geo_surface_points_for_eval_phi_%.2f' % (truncation_phi))
        for out_dir in (mesh_dir, surface_point_dir):
            os.makedirs(out_dir, exist_ok=True)

        n_gen = 1500 * 5  # We generate 5x of test set here
        saved_count = 0
        for _ in tqdm(range(n_gen)):
            geo_z = torch.randn(1, G_ema.z_dim, device=G_ema.device)
            generated_mesh = G_ema.generate_3d_mesh(
                geo_z=geo_z, tex_z=None, c=None, truncation_psi=truncation_phi,
                with_texture=False, use_style_mixing=use_style_mixing)
            for mesh_v, mesh_f in zip(*generated_mesh):
                if mesh_v.shape[0] == 0:
                    continue  # skip empty meshes without consuming an index
                obj_path = os.path.join(mesh_dir, '%07d.obj' % (saved_count))
                save_obj(mesh_v.data.cpu().numpy(), mesh_f.data.cpu().numpy(), obj_path)
                points = normalize_and_sample_points(mesh_v, mesh_f, kal, n_sample=2048, normalized_scale=1.0)
                npz_path = os.path.join(surface_point_dir, '%07d.npz' % (saved_count))
                np.savez(npz_path, pcd=points.data.cpu().numpy())
                saved_count += 1

### Generate and save results


In [None]:
def clean_training_set_kwargs_for_metrics(training_set_kwargs):
    """Force ``add_camera_cond`` to True when the key is present.

    The mapping is modified in place and the same object is returned;
    a mapping without the key is returned untouched.
    """
    flag = 'add_camera_cond'
    if flag not in training_set_kwargs:
        return training_set_kwargs
    training_set_kwargs[flag] = True
    return training_set_kwargs

def initialize_torch_settings(random_seed, num_gpus, rank):
    """Seed NumPy and PyTorch per process and enable the cuDNN/TF32/gradfix switches.

    The seed is ``random_seed * num_gpus + rank``.
    """
    process_seed = random_seed * num_gpus + rank
    np.random.seed(process_seed)
    torch.manual_seed(process_seed)

    # cuDNN and TF32 switches.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Custom gradient-fix ops.
    for gradfix_module in (conv2d_gradfix, grid_sample_gradfix):
        gradfix_module.enabled = True

def generate_and_save_visualizations(G_ema, device, run_dir, grid_z, grid_c, grid_tex_z=None):
    """Render one single-camera preview grid (5x5) into ``run_dir``.

    ``device`` is accepted but not used here.
    """
    print('Generating... ')
    save_visualization(
        G_ema, grid_z, grid_c, run_dir,
        cur_nimg=0, grid_size=(5, 5), cur_tick=0,
        save_all=False, grid_tex_z=grid_tex_z)

In [None]:
def generate_textured_mesh_for_inference(G_ema, device, run_dir, grid_z, grid_c, grid_tex_z=None):
    """Export textured meshes into ``run_dir/texture_mesh_for_inference``.

    ``device`` is accepted but not used here.
    """
    print('==> generate inference 3d shapes with texture')
    export_options = dict(
        save_mesh_dir='texture_mesh_for_inference',
        c_to_compute_w_avg=None,
        grid_tex_z=grid_tex_z,
    )
    save_textured_mesh_for_inference(G_ema, grid_z, grid_c, run_dir, **export_options)

def generate_interpolation_results(G_ema, run_dir):
    """Write latent-interpolation results into ``run_dir/interpolation``."""
    print('==> generate interpolation results')
    interpolation_dir = os.path.join(run_dir, 'interpolation')
    save_visualization_for_interpolation(G_ema, save_dir=interpolation_dir)

def compute_fid_scores(metrics, G_ema, training_set_kwargs, num_gpus, rank, device, run_dir, resume_pretrain):
    """Compute and report every requested metric on the ``test`` split.

    For each metric, ``training_set_kwargs`` is cleaned with
    ``clean_training_set_kwargs_for_metrics`` and its ``split`` is set to
    ``'test'``. Both changes are applied to the passed mapping in place.
    Nothing is modified when ``metrics`` is empty.
    """
    print('==> compute FID scores for generation')
    for metric_name in metrics:
        training_set_kwargs = clean_training_set_kwargs_for_metrics(training_set_kwargs)
        training_set_kwargs['split'] = 'test'
        calc_args = dict(
            metric=metric_name, G=G_ema,
            dataset_kwargs=training_set_kwargs,
            num_gpus=num_gpus, rank=rank, device=device,
        )
        result_dict = metric_main.calc_metric(**calc_args)
        metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=resume_pretrain)

def generate_shapes_for_evaluation(G_ema, run_dir):
    """Announce the run, then delegate to ``save_geo_for_inference``."""
    banner = '==> generate 7500 shapes for evaluation'
    print(banner)
    save_geo_for_inference(G_ema=G_ema, run_dir=run_dir)

### Inference function

In [9]:
def inference(
        run_dir='.', training_set_kwargs={}, G_kwargs={}, D_kwargs={}, metrics=[],
        random_seed=0, num_gpus=1, rank=0, inference_vis=False,
        inference_to_generate_textured_mesh=False, resume_pretrain=None,
        inference_save_interpolation=False, inference_compute_fid=False,
        inference_generate_geo=False, **dummy_kawargs
):
    """Build the generator, optionally load weights, and run the inference tasks.

    A preview visualization is always produced; the remaining tasks are
    enabled by their flags.

    Args:
        run_dir: output directory for all results.
        training_set_kwargs: dataset options; ``resolution`` selects the image
            resolution (1024 when absent). Also used for metrics.
        G_kwargs: generator construction options. A copy is made before
            ``device`` is added, so the caller's mapping is left untouched.
        D_kwargs: accepted but not used here.
        metrics: metric names for ``inference_compute_fid``.
        random_seed, num_gpus, rank: forwarded to ``initialize_torch_settings``;
            ``rank`` also selects the CUDA device.
        inference_vis: accepted but not used here.
        inference_to_generate_textured_mesh: export textured meshes.
        resume_pretrain: checkpoint path holding ``'G'`` and ``'G_ema'`` state
            dicts. It is loaded only when ``rank == 0``.
        inference_save_interpolation: export latent interpolations.
        inference_compute_fid: compute and report ``metrics``.
        inference_generate_geo: export geometry for evaluation.
        **dummy_kawargs: extra options are accepted and ignored.
    """
    from torch_utils.ops import upfirdn2d, bias_act, filtered_lrelu

    upfirdn2d._init()
    bias_act._init()
    filtered_lrelu._init()

    device = torch.device('cuda', rank)
    initialize_torch_settings(random_seed, num_gpus, rank)

    common_kwargs = dict(
        c_dim=0, img_resolution=training_set_kwargs.get('resolution', 1024), img_channels=3
    )
    # Work on a copy: writing into G_kwargs directly would modify the caller's
    # mapping, or the shared mutable default when the argument is omitted.
    G_kwargs = {**G_kwargs, 'device': device}

    G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device)
    G_ema = copy.deepcopy(G).eval()

    if resume_pretrain is not None and rank == 0:
        print('==> resume from pretrained path %s' % resume_pretrain)
        model_state_dict = torch.load(resume_pretrain, map_location=device)
        G.load_state_dict(model_state_dict['G'], strict=True)
        G_ema.load_state_dict(model_state_dict['G_ema'], strict=True)

    # 25 samples, split into chunks of one.
    n_shape = 25
    grid_z = torch.randn([n_shape, G.z_dim], device=device).split(1)
    grid_tex_z = torch.randn([n_shape, G.z_dim], device=device).split(1)
    grid_c = torch.ones(n_shape, device=device).split(1)

    generate_and_save_visualizations(G_ema, device, run_dir, grid_z, grid_c, grid_tex_z)

    if inference_to_generate_textured_mesh:
        generate_textured_mesh_for_inference(G_ema, device, run_dir, grid_z, grid_c, grid_tex_z)

    if inference_save_interpolation:
        generate_interpolation_results(G_ema, run_dir)

    if inference_compute_fid:
        compute_fid_scores(metrics, G_ema, training_set_kwargs, num_gpus, rank, device, run_dir, resume_pretrain)

    if inference_generate_geo:
        generate_shapes_for_evaluation(G_ema, run_dir)


### train_3d.py

In [None]:
def subprocess_fn(rank, c, temp_dir):
    """Per-process entry point: set up logging and distributed state, then run.

    Args:
        rank: index of this process; also used as its CUDA device index.
        c: config object; its items are forwarded as keyword arguments to
            ``inference`` or ``training_loop_3d.training_loop``.
        temp_dir: directory holding the file used to initialise
            ``torch.distributed`` when more than one GPU is used.
    """
    dnnlib.util.Logger(file_name=os.path.join(c.run_dir, 'log.txt'), file_mode='a', should_flush=True)

    # Init torch.distributed.
    if c.num_gpus > 1:
        init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
        if os.name == 'nt':
            backend = 'gloo'
            init_method = 'file:///' + init_file.replace('\\', '/')
        else:
            backend = 'nccl'
            init_method = f'file://{init_file}'
        torch.distributed.init_process_group(
            backend=backend, init_method=init_method, rank=rank, world_size=c.num_gpus)

    # Init torch_utils.
    sync_device = None
    if c.num_gpus > 1:
        sync_device = torch.device('cuda', rank)
    training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
    if rank != 0:
        custom_ops.verbosity = 'none'

    # Run inference or the training loop, depending on the config.
    if not c.inference_vis:
        training_loop_3d.training_loop(rank=rank, **c)
    else:
        inference(rank=rank, **c)

#### Training/Inference Launcher

In [None]:
def launch_training(c, desc, outdir, dry_run):
    """Pick the run directory, save the options and launch the worker process(es).

    Args:
        c: config object. ``c.run_dir`` is set here; ``c.inference_vis`` and
            ``c.num_gpus`` are read.
        desc: description string appended to the run-directory name.
        outdir: parent directory for run directories.
        dry_run: when true, print the options and the chosen run directory and
            return without creating anything or launching processes.
    """
    dnnlib.util.Logger(should_flush=True)

    # Pick output directory: next free numeric prefix, or a fixed 'inference' folder.
    prev_run_dirs = []
    if os.path.isdir(outdir):
        prev_run_dirs = [x for x in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, x))]
    prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
    prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
    cur_run_id = max(prev_run_ids, default=-1) + 1
    if c.inference_vis:
        c.run_dir = os.path.join(outdir, 'inference')
    else:
        c.run_dir = os.path.join(outdir, f'{cur_run_id:05d}-{desc}')
        assert not os.path.exists(c.run_dir)

    # Honor the dry-run request before touching the file system.
    if dry_run:
        print(json.dumps(c, indent=2))
        print(f'Output directory: {c.run_dir}')
        print('Dry run; exiting.')
        return

    # Create output directory.
    print('Creating output directory...')
    os.makedirs(c.run_dir, exist_ok=True)
    with open(os.path.join(c.run_dir, 'training_options.json'), 'wt') as f:
        json.dump(c, f, indent=2)

    # Launch processes.
    print('Launching processes...')
    torch.multiprocessing.set_start_method('spawn', force=True)
    with tempfile.TemporaryDirectory() as temp_dir:
        if c.num_gpus == 1:
            # Single GPU: run in this process.
            subprocess_fn(rank=0, c=c, temp_dir=temp_dir)
        else:
            torch.multiprocessing.spawn(fn=subprocess_fn, args=(c, temp_dir), nprocs=c.num_gpus)

#### Initializing Arguments

In [None]:
def init_dataset_kwargs(data, opt=None):
    """Build the dataset options and probe the dataset for its actual properties.

    The dataset is constructed once so that ``resolution``, ``use_labels`` and
    ``max_size`` can be filled in from the real object.

    Args:
        data: dataset path.
        opt: options object providing ``use_shapenet_split``, ``img_res``,
            ``data_camera_mode``, ``add_camera_cond``, ``camera_path`` and,
            when the ShapeNet split is used, ``inference_vis``. Despite the
            ``None`` default it is required.

    Returns:
        Tuple ``(dataset_kwargs, dataset_name)``.

    Raises:
        click.ClickException: when building the dataset raises ``IOError``.
    """
    try:
        use_split = opt.use_shapenet_split
        dataset_kwargs = dnnlib.EasyDict(
            class_name='training.dataset.ImageFolderDataset',
            path=data, use_labels=True, max_size=None, xflip=False,
            resolution=opt.img_res,
            data_camera_mode=opt.data_camera_mode,
            add_camera_cond=opt.add_camera_cond,
            camera_path=opt.camera_path,
        )
        if use_split:
            # With the ShapeNet split, inference reads 'test' and training reads 'train'.
            dataset_kwargs['split'] = 'test' if opt.inference_vis else 'train'

        dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs)  # Subclass of training.dataset.Dataset.
        dataset_kwargs.camera_path = opt.camera_path
        # Be explicit about resolution, labels and dataset size.
        dataset_kwargs.resolution = dataset_obj.resolution
        dataset_kwargs.use_labels = dataset_obj.has_labels
        dataset_kwargs.max_size = len(dataset_obj)
        return dataset_kwargs, dataset_obj.name
    except IOError as err:
        raise click.ClickException(f'--data: {err}')


def parse_comma_separated_list(s):
    """Turn a comma-separated string into a list of its parts.

    A list is returned as-is. ``None``, the empty string and any casing of
    ``'none'`` give an empty list.
    """
    if not isinstance(s, list):
        means_empty = s is None or s.lower() == 'none' or s == ''
        s = [] if means_empty else s.split(',')
    return s

### Arguments using click command and main file [from main repo]


In [10]:
@click.command()
# Required from StyleGAN2.
@click.option('--outdir', help='Where to save the results', metavar='DIR', required=True)
@click.option('--cfg', help='Base configuration', type=click.Choice(['stylegan3-t', 'stylegan3-r', 'stylegan2']), default='stylegan2')
@click.option('--gpus', help='Number of GPUs to use', metavar='INT', type=click.IntRange(min=1), required=True)
@click.option('--batch', help='Total batch size', metavar='INT', type=click.IntRange(min=1), required=True)
@click.option('--gamma', help='R1 regularization weight', metavar='FLOAT', type=click.FloatRange(min=0), required=True)
# My custom configs
### Configs for inference
@click.option('--resume_pretrain', help='Resume from given network pickle', metavar='[PATH|URL]', type=str)
@click.option('--inference_vis', help='whther we run infernce', metavar='BOOL', type=bool, default=False, show_default=True)
@click.option('--inference_to_generate_textured_mesh', help='inference to generate textured meshes', metavar='BOOL', type=bool, default=False, show_default=False)
@click.option('--inference_save_interpolation', help='inference to generate interpolation results', metavar='BOOL', type=bool, default=False, show_default=False)
@click.option('--inference_compute_fid', help='inference to generate interpolation results', metavar='BOOL', type=bool, default=False, show_default=False)
@click.option('--inference_generate_geo', help='inference to generate geometry points', metavar='BOOL', type=bool, default=False, show_default=False)
### Configs for dataset

@click.option('--data', help='Path to the Training data Images', metavar='[DIR]', type=str, default='./tmp')
@click.option('--camera_path', help='Path to the camera root', metavar='[DIR]', type=str, default='./tmp')
@click.option('--img_res', help='The resolution of image', metavar='INT', type=click.IntRange(min=1), default=1024)
@click.option('--data_camera_mode', help='The type of dataset we are using', type=str, default='shapenet_car', show_default=True)
@click.option('--use_shapenet_split', help='whether use the training split or all the data for training', metavar='BOOL', type=bool, default=False, show_default=False)
### Configs for 3D generator##########
@click.option('--iso_surface', help='Differentiable iso-surfacing method', type=click.Choice(['dmtet', 'flexicubes']), default='dmtet')
@click.option('--use_style_mixing', help='whether use style mixing for generation during inference', metavar='BOOL', type=bool, default=True, show_default=False)
@click.option('--one_3d_generator', help='whether we detach the gradient for empty object', metavar='BOOL', type=bool, default=True, show_default=True)
@click.option('--dmtet_scale', help='Scale for the dimention of dmtet', metavar='FLOAT', type=click.FloatRange(min=0, max=10.0), default=1.0, show_default=True)
@click.option('--n_implicit_layer', help='Number of Implicit FC layer for XYZPlaneTex model', metavar='INT', type=click.IntRange(min=1), default=1)
@click.option('--feat_channel', help='Feature channel for TORGB layer', metavar='INT', type=click.IntRange(min=0), default=16)
@click.option('--mlp_latent_channel', help='mlp_latent_channel for XYZPlaneTex network', metavar='INT', type=click.IntRange(min=8), default=32)
@click.option('--deformation_multiplier', help='Multiplier for the predicted deformation', metavar='FLOAT', type=click.FloatRange(min=1.0), default=1.0, required=False)
@click.option('--tri_plane_resolution', help='The resolution for tri plane', metavar='INT', type=click.IntRange(min=1), default=256)
@click.option('--n_views', help='number of views when training generator', metavar='INT', type=click.IntRange(min=1), default=1)
@click.option('--use_tri_plane', help='Whether use tri plane representation', metavar='BOOL', type=bool, default=True, show_default=True)
@click.option('--tet_res', help='Resolution for teteahedron', metavar='INT', type=click.IntRange(min=1), default=90)
@click.option('--latent_dim', help='Dimention for latent code', metavar='INT', type=click.IntRange(min=1), default=512)
@click.option('--geometry_type', help='The type of geometry generator', type=str, default='conv3d', show_default=True)
@click.option('--render_type', help='Type of renderer we used', metavar='STR', type=click.Choice(['neural_render', 'spherical_gaussian']), default='neural_render', show_default=True)
### Configs for training loss and discriminator#
@click.option('--d_architecture', help='The architecture for discriminator', metavar='STR', type=str, default='skip', show_default=True)
@click.option('--use_pl_length', help='whether we apply path length regularization', metavar='BOOL', type=bool, default=False, show_default=False)  # We didn't use path lenth regularzation to avoid nan error
@click.option('--gamma_mask', help='R1 regularization weight for mask', metavar='FLOAT', type=click.FloatRange(min=0), default=0.0, required=False)
@click.option('--d_reg_interval', help='The internal for R1 regularization', metavar='INT', type=click.IntRange(min=1), default=16)
@click.option('--add_camera_cond', help='Whether we add camera as condition for discriminator', metavar='BOOL', type=bool, default=True, show_default=True)
@click.option('--lambda_flexicubes_surface_reg', help='Weights for flexicubes regularization L_dev', metavar='FLOAT', type=click.FloatRange(min=0), default=0.5, required=False)
@click.option('--lambda_flexicubes_weights_reg', help='Weights for flexicubes regularization on weights', metavar='FLOAT', type=click.FloatRange(min=0), default=0.1, required=False)

## Miscs
# Optional features.
@click.option('--cond', help='Train conditional model', metavar='BOOL', type=bool, default=False, show_default=True)
@click.option('--freezed', help='Freeze first layers of D', metavar='INT', type=click.IntRange(min=0), default=0, show_default=True)
# Misc hyperparameters.
@click.option('--batch-gpu', help='Limit batch size per GPU', metavar='INT', type=click.IntRange(min=1), default=4)
@click.option('--cbase', help='Capacity multiplier', metavar='INT', type=click.IntRange(min=1), default=32768, show_default=True)
@click.option('--cmax', help='Max. feature maps', metavar='INT', type=click.IntRange(min=1), default=512, show_default=True)
@click.option('--glr', help='G learning rate  [default: varies]', metavar='FLOAT', type=click.FloatRange(min=0))
@click.option('--dlr', help='D learning rate', metavar='FLOAT', type=click.FloatRange(min=0), default=0.002, show_default=True)
@click.option('--map-depth', help='Mapping network depth  [default: varies]', metavar='INT', type=click.IntRange(min=1))
@click.option('--mbstd-group', help='Minibatch std group size', metavar='INT', type=click.IntRange(min=1), default=4, show_default=True)
# Misc settings.
@click.option('--desc', help='String to include in result dir name', metavar='STR', type=str)
@click.option('--metrics', help='Quality metrics', metavar='[NAME|A,B,C|none]', type=parse_comma_separated_list, default='fid50k', show_default=True)
@click.option('--kimg', help='Total training duration', metavar='KIMG', type=click.IntRange(min=1), default=20000, show_default=True)
@click.option('--tick', help='How often to print progress', metavar='KIMG', type=click.IntRange(min=1), default=1, show_default=True)  ##
@click.option('--snap', help='How often to save snapshots', metavar='TICKS', type=click.IntRange(min=1), default=50, show_default=True)  ###
@click.option('--seed', help='Random seed', metavar='INT', type=click.IntRange(min=0), default=0, show_default=True)
@click.option('--fp32', help='Disable mixed-precision', metavar='BOOL', type=bool, default=True, show_default=True)  # Let's use fp32 all the case without clamping
@click.option('--nobench', help='Disable cuDNN benchmarking', metavar='BOOL', type=bool, default=False, show_default=True)
@click.option('--workers', help='DataLoader worker processes', metavar='INT', type=click.IntRange(min=0), default=3, show_default=True)
@click.option('-n', '--dry-run', help='Print training options and exit', is_flag=True)

def main(**kwargs):
    # NOTE(review): if `main` is registered as a click command (the decorator
    # stack starts above this function), click shows this docstring as the
    # --help text; everything after the form feed (\f) is hidden from it.
    """Build the GET3D run configuration from CLI options and launch it.

    \f
    ``kwargs`` holds the option values declared by the ``@click.option``
    decorators on this function. They are copied into the nested
    ``dnnlib.EasyDict`` config ``c`` (generator, discriminator, optimizer,
    loss, data-loader and dataset settings), sanity-checked, and handed to
    ``launch_training`` together with a description string and ``outdir``.

    Raises:
        click.ClickException: if ``--cond`` is set but the dataset kwargs
            report no labels; if ``--batch`` is not a multiple of ``--gpus``
            or of ``--gpus`` times ``--batch-gpu``; if ``--batch-gpu`` is
            smaller than ``--mbstd-group``; or if ``--metrics`` contains a
            name rejected by ``metric_main.is_valid_metric``.
    """
    # Initialize config.
    print('==> start')
    opts = dnnlib.EasyDict(kwargs)  # Command line arguments.
    c = dnnlib.EasyDict()  # Main config dict.
    # G's class_name is filled in further down ("Base configuration").
    c.G_kwargs = dnnlib.EasyDict(
        class_name=None, z_dim=opts.latent_dim, w_dim=opts.latent_dim, mapping_kwargs=dnnlib.EasyDict())
    c.D_kwargs = dnnlib.EasyDict(
        class_name='training.networks_get3d.Discriminator', block_kwargs=dnnlib.EasyDict(),
        mapping_kwargs=dnnlib.EasyDict(), epilogue_kwargs=dnnlib.EasyDict())
    c.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0, 0.99], eps=1e-8)
    c.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', betas=[0, 0.99], eps=1e-8)
    c.loss_kwargs = dnnlib.EasyDict(class_name='training.loss.StyleGAN2Loss')

    c.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, prefetch_factor=2)

    # Inference switches: the detailed flags are only copied into the config
    # when inference mode is requested.
    c.inference_vis = opts.inference_vis
    if opts.inference_vis:
        c.inference_to_generate_textured_mesh = opts.inference_to_generate_textured_mesh
        c.inference_save_interpolation = opts.inference_save_interpolation
        c.inference_compute_fid = opts.inference_compute_fid
        c.inference_generate_geo = opts.inference_generate_geo

    # Training set.
    c.training_set_kwargs, dataset_name = init_dataset_kwargs(data=opts.data, opt=opts)
    # Check the dataset's own use_labels value before it is overwritten below.
    if opts.cond and not c.training_set_kwargs.use_labels:
        raise click.ClickException('--cond=True requires labels specified in dataset.json')
    c.training_set_kwargs.split = 'train' if opts.use_shapenet_split else 'all'
    # With the ShapeNet split enabled, inference switches to the 'test' split.
    if opts.use_shapenet_split and opts.inference_vis:
        c.training_set_kwargs.split = 'test'
    c.training_set_kwargs.use_labels = opts.cond
    c.training_set_kwargs.xflip = False

    # Hyperparameters & settings.
    c.G_kwargs.iso_surface = opts.iso_surface
    c.G_kwargs.one_3d_generator = opts.one_3d_generator
    c.G_kwargs.n_implicit_layer = opts.n_implicit_layer
    c.G_kwargs.deformation_multiplier = opts.deformation_multiplier
    c.resume_pretrain = opts.resume_pretrain
    c.D_reg_interval = opts.d_reg_interval
    c.G_kwargs.use_style_mixing = opts.use_style_mixing
    c.G_kwargs.dmtet_scale = opts.dmtet_scale
    c.G_kwargs.feat_channel = opts.feat_channel
    c.G_kwargs.mlp_latent_channel = opts.mlp_latent_channel
    c.G_kwargs.tri_plane_resolution = opts.tri_plane_resolution
    c.G_kwargs.n_views = opts.n_views

    c.G_kwargs.render_type = opts.render_type
    c.G_kwargs.use_tri_plane = opts.use_tri_plane
    c.D_kwargs.data_camera_mode = opts.data_camera_mode
    c.D_kwargs.add_camera_cond = opts.add_camera_cond

    c.G_kwargs.tet_res = opts.tet_res

    c.G_kwargs.geometry_type = opts.geometry_type
    c.num_gpus = opts.gpus
    c.batch_size = opts.batch
    # Per-GPU batch defaults to an even split of the total batch across GPUs.
    c.batch_gpu = opts.batch_gpu or opts.batch // opts.gpus
    # c.G_kwargs.geo_pos_enc = opts.geo_pos_enc
    c.G_kwargs.data_camera_mode = opts.data_camera_mode
    c.G_kwargs.channel_base = c.D_kwargs.channel_base = opts.cbase
    c.G_kwargs.channel_max = c.D_kwargs.channel_max = opts.cmax

    # Mapping depth is fixed at 8 here; opts.map_depth is not consulted.
    c.G_kwargs.mapping_kwargs.num_layers = 8

    c.D_kwargs.architecture = opts.d_architecture
    c.D_kwargs.block_kwargs.freeze_layers = opts.freezed
    c.D_kwargs.epilogue_kwargs.mbstd_group_size = opts.mbstd_group
    # A gamma_mask of exactly 0.0 means "reuse --gamma for the mask term".
    c.loss_kwargs.gamma_mask = opts.gamma if opts.gamma_mask == 0.0 else opts.gamma_mask
    c.loss_kwargs.r1_gamma = opts.gamma
    c.loss_kwargs.lambda_flexicubes_surface_reg = opts.lambda_flexicubes_surface_reg
    c.loss_kwargs.lambda_flexicubes_weights_reg = opts.lambda_flexicubes_weights_reg
    # G learning rate: explicit --glr wins, otherwise it depends on --cfg.
    c.G_opt_kwargs.lr = (0.002 if opts.cfg == 'stylegan2' else 0.0025) if opts.glr is None else opts.glr
    c.D_opt_kwargs.lr = opts.dlr
    c.metrics = opts.metrics
    c.total_kimg = opts.kimg
    c.kimg_per_tick = opts.tick
    # --snap only drives image snapshots: network snapshots are pinned to 200
    # ticks. (Previously opts.snap was assigned to network_snapshot_ticks here
    # and then unconditionally overwritten with 200 a few lines later.)
    c.image_snapshot_ticks = opts.snap
    c.network_snapshot_ticks = 200
    c.random_seed = c.training_set_kwargs.random_seed = opts.seed
    c.data_loader_kwargs.num_workers = opts.workers

    # Sanity checks.
    if c.batch_size % c.num_gpus != 0:
        raise click.ClickException('--batch must be a multiple of --gpus')
    if c.batch_size % (c.num_gpus * c.batch_gpu) != 0:
        raise click.ClickException('--batch must be a multiple of --gpus times --batch-gpu')
    if c.batch_gpu < c.D_kwargs.epilogue_kwargs.mbstd_group_size:
        # The message names the real option (--mbstd-group), not '--mbstd'.
        raise click.ClickException('--batch-gpu cannot be smaller than --mbstd-group')
    if any(not metric_main.is_valid_metric(metric) for metric in c.metrics):
        raise click.ClickException(
            '\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))

    # Base configuration.
    c.ema_kimg = c.batch_size * 10 / 32  # Scales with the total batch size (10 at batch 32).
    c.G_kwargs.class_name = 'training.networks_get3d.GeneratorDMTETMesh'
    c.loss_kwargs.style_mixing_prob = 0.9  # Enable style mixing regularization.
    # NOTE(review): a weight of 0.0 presumably turns path length regularization
    # off rather than on -- confirm against training.loss.StyleGAN2Loss.
    c.loss_kwargs.pl_weight = 0.0
    c.G_reg_interval = 4  # Enable lazy regularization for G.
    c.G_kwargs.fused_modconv_default = 'inference_only'  # Speed up training by using regular convolutions instead of grouped convolutions.

    # Performance-related toggles.
    if opts.fp32:
        # Full precision: no fp16 resolutions and no convolution clamping.
        c.G_kwargs.num_fp16_res = c.D_kwargs.num_fp16_res = 0
        c.G_kwargs.conv_clamp = c.D_kwargs.conv_clamp = None
    if opts.nobench:
        c.cudnn_benchmark = False

    # Description string.
    desc = f'{opts.cfg:s}-{dataset_name:s}-gpus{c.num_gpus:d}-batch{c.batch_size:d}-gamma{c.loss_kwargs.r1_gamma:g}'
    if opts.desc is not None:
        desc += f'-{opts.desc}'

    # Launch.
    print('==> launch training')
    launch_training(c=c, desc=desc, outdir=opts.outdir, dry_run=opts.dry_run)

### Invoking `main` to run the whole pipeline

In [11]:
# Call `main` through a click context, passing the run options as one
# unpacked mapping.
ctx = click.Context(main)
ctx.invoke(main, **{
    'outdir': 'save_inference_results/shapenet_car',
    'gpus': 1,
    'batch': 4,
    'gamma': 40,
    'data_camera_mode': 'shapenet_car',
    'dmtet_scale': 1.0,
    'use_shapenet_split': True,
    'one_3d_generator': True,
    'fp32': False,
    'inference_vis': True,
    'resume_pretrain': '/content/GET3D/pretrained_model/shapenet_car.pt',
})

==> start
==> use shapenet dataset
==> ERROR!!!! THIS SHOULD ONLY HAPPEN WHEN USING INFERENCE
==> use image path: ./tmp, num images: 1234
==> launch training

Training options:
{
  "G_kwargs": {
    "class_name": "training.networks_get3d.GeneratorDMTETMesh",
    "z_dim": 512,
    "w_dim": 512,
    "mapping_kwargs": {
      "num_layers": 8
    },
    "iso_surface": "dmtet",
    "one_3d_generator": true,
    "n_implicit_layer": 1,
    "deformation_multiplier": 1.0,
    "use_style_mixing": true,
    "dmtet_scale": 1.0,
    "feat_channel": 16,
    "mlp_latent_channel": 32,
    "tri_plane_resolution": 256,
    "n_views": 1,
    "render_type": "neural_render",
    "use_tri_plane": true,
    "tet_res": 90,
    "geometry_type": "conv3d",
    "data_camera_mode": "shapenet_car",
    "channel_base": 32768,
    "channel_max": 512,
    "fused_modconv_default": "inference_only"
  },
  "D_kwargs": {
    "class_name": "training.networks_get3d.Discriminator",
    "block_kwargs": {
      "freeze_layers"

### To delete the results

In [7]:
# !rm -rf save_inference_results/