In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
import numpy as np
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import json

# Restrict this kernel to a single GPU. Device indices are requested in PCI
# bus order so that index "1" refers to the same physical card on every run.
# These must be in place before CUDA is first initialised in this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1",
})

# Load data

In [2]:
# Root of the input data. Resolved to an absolute path immediately, so it does
# not depend on the working directory at the time it is used later on.
path = Path("../dataset/dmitry_input/").resolve()

# Files located under <path>/pti_out/PTI_render.
pti_render = path.joinpath("pti_out", "PTI_render")
post_c_path = pti_render.joinpath("post_c.npy")
post_mp4_path = pti_render.joinpath("post.mp4")

# Training

In [30]:
# Change into the neural-head-avatars checkout. The target is relative to the
# current working directory, so running this cell a second time without
# restarting the kernel applies the relative move again from the new location.
%cd ../deps/neural-head-avatars

/home/dmitry/clones/PanoHead-clostra/deps/neural-head-avatars


In [42]:
from argparse import ArgumentParser
from configargparse import ArgumentParser as ConfigArgumentParser
# Put the directory two levels above the current working directory at the front
# of the import path before importing `distillation` (presumably the package
# lives there — confirm). "../.." is resolved against the working directory, so
# this only points at the intended place after the `%cd` cell has run once.
sys.path.insert(0, str(Path("../..").resolve()))
from distillation.data import PanoheadPostDataModule
from distillation.model import NHAStaticTrainer

In [40]:
from nha.util.general import dict_2_device, stack_dicts

In [42]:
# `pl` is used below (and by the trainer cell) but was never imported in this
# notebook; import it here so the cell works on a fresh kernel.
import pytorch_lightning as pl

# Options in the same form a command line would supply them. The config path is
# relative to the current working directory; the other paths derive from `path`.
cli_args = [
    '--config', 'configs/optimize_avatar_mesh_guidance.ini',
    '--default_root_dir', str(path / 'ds' / 'results'),
    '--data_path', str(path / 'ds'),
    '--pti_out_path', str(path / 'pti_out'),
    '--load_threeddfa', str(path / 'dataset.json'),
    '--gpus', '1',
]

# Collect the options declared by the model, the data module and the Lightning
# trainer on one plain argparse parser ...
parser = ArgumentParser()
parser = NHAStaticTrainer.add_argparse_args(parser)
parser = PanoheadPostDataModule.add_argparse_args(parser)
parser = pl.Trainer.add_argparse_args(parser)

# ... then wrap it in a config-file aware parser so that values can also be
# read from the file passed via --config.
parser = ConfigArgumentParser(parents=[parser], add_help=False)
parser.add_argument('--config', required=True, is_config_file=True)
parser.add_argument("--checkpoint_file", type=str, required=False, default="",
                    help="checkpoint to load model from")

# Always parse the explicit list: falling back to the process arguments would
# pick up the Jupyter kernel's own command line.
args = parser.parse_args(cli_args)

# Overrides applied on top of the parsed values.
args.replace_sampler_ddp = False
args.load_flame = False
args.load_camera = False
# args.flame_lr = list(map(lambda x: x * 5, args.flame_lr))

# Plain-dict view of the same namespace, used to construct the model and data module.
args_dict = vars(args)

In [43]:
# Build the model and the data module from the same parsed options, then call
# the data module's setup() before any dataloader is requested from it.
nha_static = NHAStaticTrainer(**args_dict)
data = PanoheadPostDataModule(**args_dict)
data.setup()



In [47]:
from pytorch_lightning.loggers import TensorBoardLogger

# TensorBoard logger rooted at the configured results directory.
experiment_logger = TensorBoardLogger(
    args_dict["default_root_dir"],
    name="lightning_logs",
)

# Trainer options come from the parsed arguments; the callbacks, epoch limit
# and logger are supplied explicitly as keyword arguments alongside them.
trainer = pl.Trainer.from_argparse_args(
    args,
    callbacks=nha_static.callbacks,
    max_epochs=1000,
    logger=experiment_logger,
)

# Train with single-sample batches and validate with batches of three.
trainer.fit(
    nha_static,
    train_dataloader=data.train_dataloader(batch_size=1),
    val_dataloaders=data.val_dataloader(batch_size=3),
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores

  | Name            | Type               | Params
-------------------------------------------------------
0 | _flame          | FlameHead          | 0     
1 | _offset_mlp     | OffsetMLP          | 243 K 
2 | _normal_encoder | SirenNormalEncoder | 542 K 
3 | _texture        | TextureMLP         | 362 K 
4 | _explFeatures   | MultiTexture       | 4.5 M 
5 | _leaky_hinge    | LeakyHingeLoss     | 0     
6 | _masked_L1      | MaskedCriterion    | 0     
-------------------------------------------------------
6.1 M     Trainable params
0         Non-trainable params
6.1 M     Total params
24.360    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [46]:
# Take the first batch from a fresh training dataloader (the positional 3 is
# presumably the batch size — confirm against train_dataloader's signature)
# and pass it through dict_2_device with the model's current device.
batch = next(iter(data.train_dataloader(3)))
batch = dict_2_device(batch, nha_static.device)

In [47]:
# Manual step outside the Lightning trainer: training mode, model on the GPU,
# then batch preparation followed by flame_step, which returns the loss and a
# dictionary alongside it.
# NOTE(review): `batch` was sent to `nha_static.device` in the previous cell,
# i.e. before this `.to('cuda')` — confirm batch and model end up on the same
# device when the cells are run in order.
nha_static.train()
nha_static.to('cuda')
batch = nha_static.prepare_batch(batch)
loss, loss_dict = nha_static.flame_step(batch)

In [48]:
# Ask the model for its optimizer setup.
# NOTE(review): the later `opt.step()` assumes a single optimizer object is
# returned here rather than a list/tuple — confirm.
opt = nha_static.configure_optimizers()

In [51]:
# Backpropagate the loss from the manual flame_step call. No gradients are
# zeroed beforehand in this notebook, so re-running accumulates them.
loss.backward()

In [53]:
# Apply one update using the gradients computed by the backward pass.
opt.step()

In [56]:
# Display the model's `_expr` parameter (presumably expression coefficients —
# confirm) to check its values after the manual optimizer step.
nha_static._expr

Parameter containing:
tensor([[-3.3870e-04, -2.0333e-04, -1.5875e-04,  4.7671e-05, -1.4735e-04,
          4.1756e-06,  3.3654e-06, -1.4169e-04, -1.6423e-04, -8.9856e-05,
         -1.0247e-05,  1.5844e-05,  5.2633e-05,  5.8535e-10,  2.1190e-05,
         -2.0835e-05,  2.1543e-05, -3.4704e-05,  1.3494e-05, -7.3948e-06,
          2.2493e-06, -9.7095e-06, -2.1290e-05,  7.8369e-06, -3.9857e-06,
         -2.5353e-05,  8.2815e-06,  2.0309e-06,  1.0677e-06, -5.3065e-06,
          1.1174e-06,  4.5276e-06, -1.3084e-05,  2.6055e-06,  7.6864e-06,
         -1.4949e-05,  1.1948e-06,  6.7290e-06, -8.4470e-06,  1.0327e-05,
         -5.4414e-06,  1.0377e-05, -4.0657e-07,  4.0605e-06, -1.5834e-05,
         -2.5066e-06, -8.8227e-06,  2.6279e-06,  2.6711e-06,  1.9540e-06,
          6.3494e-06,  1.3083e-06, -2.5024e-06, -1.2698e-06, -3.2463e-06,
          6.1304e-06,  3.4501e-06, -8.2521e-06, -2.9414e-06, -4.3097e-06,
          2.8530e-06,  1.0158e-05,  3.7476e-06,  3.8932e-06, -1.0727e-05,
          8.9288

# Visualizations

In [100]:
def camera_frustum(c2w, K, depth=1):
    """Compute the apex and four corner points of a camera frustum in world space.

    The corners of the unit image square are back-projected through the inverse
    intrinsics, scaled by ``depth`` and then mapped with ``c2w``.

    Args:
        c2w: Camera-to-world transform (array-like). Column 3 of its first
            three rows is taken as the camera position, and it is multiplied
            with homogeneous 4-vectors, so a 4x4 (or 3x4) matrix is expected.
        K: Invertible 3x3 intrinsics matrix (array-like). The image corners
            used are (0, 0)..(1, 1), which suits intrinsics normalised by the
            image size — TODO confirm against the K passed by callers.
        depth: Scale factor applied to the back-projected corners.

    Returns:
        ``(5, 3)`` ndarray. Row 0 is the camera origin; rows 1-4 are the
        corners in the order (u, v) = (0, 1), (0, 0), (1, 0), (1, 1).

    Raises:
        numpy.linalg.LinAlgError: if ``K`` is singular.
    """
    # Accept nested lists and other array-likes, not only ndarrays.
    c2w = np.asarray(c2w)
    K = np.asarray(K)

    camera_origin = c2w[:3, 3]

    # Homogeneous image-plane corners (u, v, 1). Consecutive rows are adjacent
    # corners of the square, so they can be joined in order to draw its outline.
    image_corners = np.array([
        [0, 1, 1],
        [0, 0, 1],
        [1, 0, 1],
        [1, 1, 1],
    ], dtype=float)

    # Back-project to camera space and push the corners out by `depth`.
    corners_cam = (np.linalg.inv(K) @ image_corners.T).T * depth

    # Append w = 1 so the camera-to-world transform includes its translation.
    corners_cam_h = np.concatenate([corners_cam, np.ones((4, 1))], axis=1)
    corners_world = (c2w @ corners_cam_h.T).T[:, :3]

    return np.vstack([camera_origin, corners_world])

In [101]:
import plotly.graph_objects as go

In [102]:
def add_camera_frustum_trace(fig, c2w, K, depth=1):
    """Draw one camera frustum into a plotly 3D figure.

    Adds five traces to ``fig``: four solid blue edges (with red end markers)
    from the camera origin to each frustum corner, and one dashed blue loop
    joining the four corners.

    Args:
        fig: Figure object whose ``add_trace`` method receives the traces.
        c2w: Camera-to-world transform, forwarded to ``camera_frustum``.
        K: Intrinsics matrix, forwarded to ``camera_frustum``.
        depth: Frustum depth, forwarded to ``camera_frustum``.
    """
    pts = camera_frustum(c2w, K, depth)
    apex = pts[0]

    # Edges from the apex (row 0) to each of the four corners (rows 1-4).
    for corner in (pts[1], pts[2], pts[3], pts[4]):
        edge = go.Scatter3d(
            x=[apex[0], corner[0]],
            y=[apex[1], corner[1]],
            z=[apex[2], corner[2]],
            mode='lines+markers',
            line=dict(color='blue', width=2),
            marker=dict(size=4, color='red'),
        )
        fig.add_trace(edge)

    # Closed outline through the corners; the first corner is repeated at the
    # end so the last segment returns to the start.
    ring = [1, 2, 3, 4, 1]
    outline = go.Scatter3d(
        x=pts[ring, 0],
        y=pts[ring, 1],
        z=pts[ring, 2],
        mode='lines',
        line=dict(color='blue', width=2, dash='dash'),
    )
    fig.add_trace(outline)

In [105]:

# Overlay two sets of 3D points as marker traces in a single plotly figure.
# NOTE(review): `lmk3d_gt` and `lmks3d` are not assigned in any cell visible in
# this notebook — confirm where they come from, otherwise this cell fails on a
# fresh kernel.
fig = go.Figure()

# First set: used as-is (presumably the reference landmarks — confirm).
point_cloud = lmk3d_gt
x, y, z = (point_cloud[:, axis] for axis in range(3))
fig.add_trace(
    go.Scatter3d(
        x=x, y=y, z=z,
        mode='markers',
        marker=dict(size=5, opacity=0.8),
    )
)

# Second set: a tensor, detached and copied to host memory before plotting.
point_cloud = lmks3d.detach().cpu().numpy()
x, y, z = (point_cloud[:, axis] for axis in range(3))
fig.add_trace(
    go.Scatter3d(
        x=x, y=y, z=z,
        mode='markers',
        marker=dict(size=5, opacity=0.8),
    )
)


# for i in range(0, len(front_idx), 5):
#     add_camera_frustum_trace(fig, c2w[front_idx[i]], K[front_idx[i]], depth=0.6)
# # Update figure layout
# fig.update_layout(title='Camera Frustum Visualization',
#                   scene=dict(xaxis=dict(title='X'),
#                              yaxis=dict(title='Y'),
#                              zaxis=dict(title='Z')),
#                   margin=dict(l=0, r=0, b=0, t=0))

# Render the figure.
fig.show()
