# 1. Import 

In [1]:
import os, sys
sys.path.append(os.path.abspath(''))
import torch
import matplotlib.pyplot as plt
from StylizedModel.IO.Load import load
# Util function for loading meshes

# Data structures and functions for rendering
from pytorch3d.structures import Meshes
from pytorch3d.vis.plotly_vis import AxisArgs, plot_batch_individually, plot_scene
from pytorch3d.vis.texture_vis import texturesuv_image_matplotlib
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVPerspectiveCameras, 
    PointLights, 
    DirectionalLights, 
    Materials, 
    RasterizationSettings, 
    MeshRenderer, 
    MeshRasterizer,  
    SoftPhongShader,
    TexturesUV,
    TexturesVertex
)
from Generator.GenMesh import GenMesh
from tqdm.notebook import tqdm
from torch import tensor

# Run on the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


# 2. Create detector

In [2]:
from face_detector.detect_face import create_detector
import matplotlib.pyplot as plt
import cv2

image_path = 'data/images/0.png'

detector = create_detector('yolov3')

# cv2.imread reports a missing or unreadable file by returning None instead of
# raising, so fail here with a clear message rather than inside a later OpenCV call.
image = cv2.imread(image_path)
if image is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")

# OpenCV loads images as BGR; plt.imshow interprets the channels as RGB.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# 2048x2048 is the same resolution the rasterizer and the landmark projection
# use in the later cells.
image = cv2.resize(image, (2048, 2048))
# Target image kept on the compute device for the texture-fitting loss.
trg_image = tensor(image).to(device)

# Only the first detection is used; stop with an explicit error if there is none.
preds = detector(image)
if len(preds) == 0:
    raise RuntimeError(f"No face detected in {image_path}")
keypoints = preds[0]['keypoints']

# Show the input image with the detected keypoints drawn in red.
plt.figure(figsize=(10, 10))
implot = plt.imshow(image)
plt.axis("off")
plt.scatter(keypoints[..., 0], keypoints[..., 1], color='r', marker='o')
plt.show()

/home/yaosy/.cache/torch/hub/checkpoints/mmdet_anime-face_yolov3.pth
/home/yaosy/.cache/torch/hub/checkpoints/mmpose_anime-face_hrnetv2.pth
load checkpoint from local path: /home/yaosy/.cache/torch/hub/checkpoints/mmpose_anime-face_hrnetv2.pth


RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

# 3. Fit the camera, mesh, and texture to the image

In [None]:

from Shader.MySoftShader import SoftShader

# Initial camera pose: distance 3.0, elevation 2.0 degrees, azimuth 0.0 degrees.
# R and T are created directly on `device`, so no extra transfer is needed.
R, T = look_at_view_transform(3.0, 2.0, 0.0, device = device)
# Field of view (degrees) starts at 60 and is optimized together with the pose.
fov = tensor(60.0, requires_grad=True, device=device)

cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov = fov)

# Rasterize at 2048x2048 (the size the target image is resized to) with one face
# per pixel and no blur. bin_size and max_faces_per_bin are not set, so
# PyTorch3D's defaults apply.
raster_settings = RasterizationSettings(
    image_size=2048,
    blur_radius=0.0,
    faces_per_pixel=1,
)

# Point light at (2, 3, 3). It is defined here but not passed to the shader below.
lights = PointLights(device=device, location=[[2.0, 3.0, 3.0]])

# Renderer = rasterizer + the project's SoftShader, both bound to the camera above.
renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras,
        raster_settings=raster_settings
    ),
    shader=SoftShader(
        device=device,
        cameras=cameras,
    )
)

mesh = load("data/models/template/template.fbx", device)

# map_location puts the generator's tensors on the active device, so a file
# saved from a GPU session also loads on a CPU-only machine.
# NOTE: torch.load unpickles the file and can execute arbitrary code; only load
# generator files from a trusted source.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# rejects whole-object pickles like this one -- pass weights_only=False there.
mesh_gen: GenMesh = torch.load("data/generators/mesh_gen_rk20.pt", map_location=device)

# Shape parameters start at zero: 20 identity and 6 expression coefficients.
identity = torch.zeros(20, device=device)
expression = torch.zeros(6, device=device)


In [None]:

from StylizedModel.IO.Save import save

# Make pose and shape parameters trainable (fov was created with requires_grad=True).
R.requires_grad = True
T.requires_grad = True
identity.requires_grad = True
expression.requires_grad = True

# NOTE(review): R is updated element-wise by Adam and nothing in this loop
# re-orthonormalizes it, so it can drift away from a valid rotation -- confirm
# that this is acceptable for the fit.
optimizer = torch.optim.Adam([R, T, identity, expression, fov], lr = 1e-3)

# The detected 2D keypoints never change, so build and upload the target once
# instead of on every optimization step.
target_landmarks = tensor(keypoints[..., :2]).to(device)

for i in tqdm(range(20000)):

    optimizer.zero_grad()
    mesh.verts = mesh_gen.get(identity, expression)

    # Project the model landmarks into 2048x2048 screen space with the current camera.
    mesh_landmarks = mesh_gen.get_landmarks(identity, expression)
    mesh_landmarks = cameras.transform_points_screen(mesh_landmarks, image_size=(2048, 2048), R = R, T = T, fov = fov)

    # Squared 2D landmark error plus L2 penalties on the shape coefficients.
    loss = (target_landmarks - mesh_landmarks[..., :2]).pow(2.0).sum()
    loss += 5000.0 * identity.pow(2.0).sum()
    loss += 1000.0 * expression.pow(2.0).sum()

    loss.backward()
    optimizer.step()
    # Project back onto the allowed ranges: expression >= 0, fov within [10, 110].
    with torch.no_grad():
        expression.clamp_(0.0)
        fov.clamp_(10.0, 110.0)

    if i % 1000 == 0:
        print(identity)
        print(expression)
        print(fov)
        # Preview only: render without autograd so the 2048x2048 image does not
        # keep a computation graph alive on the GPU.
        with torch.no_grad():
            images = renderer(mesh.to_textured_p3d(), R = R, T = T, fov = fov)
        plt.figure(figsize=(10, 10))
        implot = plt.imshow(images[0, ..., :3].detach().cpu().numpy())
        # Red: projected model landmarks. Blue: detected keypoints.
        plt.scatter(mesh_landmarks[..., 0].detach().cpu().numpy(), mesh_landmarks[..., 1].detach().cpu().numpy(), color='r', marker='o')
        plt.scatter(keypoints[..., 0], keypoints[..., 1], color='b', marker='o')
        plt.axis("off")
        plt.show()

# mesh.verts was last computed before the final optimizer step; rebuild it from
# the final parameters (graph-free) so the saved mesh matches the fitted result.
with torch.no_grad():
    mesh.verts = mesh_gen.get(identity, expression)

# Make sure the output directory exists before writing the fitted mesh.
os.makedirs('tmp', exist_ok=True)
save(mesh, 'tmp/wtf.fbx')

In [None]:

# map_location puts the generator's tensors on the active device, so a file
# saved from a GPU session also loads on a CPU-only machine.
# NOTE: torch.load unpickles the file and can execute arbitrary code; only load
# generator files from a trusted source.
text_gen = torch.load("data/generators/text_gen_rk20.pt", map_location=device)
text_gen.to(device)

# Freeze everything fitted in the landmark stage, including fov. detach()
# already returns tensors that do not require grad, so only the texture
# parameters below receive gradients.
R = R.detach()
T = T.detach()
identity = identity.detach()
expression = expression.detach()
fov = fov.detach()
mesh.verts = mesh.verts.detach()

# Texture parameters: 8 generator coefficients plus a per-texel residual.
ptexture = torch.zeros(8, requires_grad=True, device = device)
add_texture = torch.zeros((1024, 1024, 3), requires_grad=True, device = device)
optimizer = torch.optim.AdamW([ptexture, add_texture], lr = 1e-2)

# trg_image holds 8-bit values (0..255); scale once to [0, 1] rather than on
# every iteration.
target_rgb = trg_image / 255.0

for i in tqdm(range(200)):

    optimizer.zero_grad()

    # Generated texture plus a residual limited to +/-0.01 per texel.
    texture = text_gen.get(ptexture) + add_texture.clamp(-0.01, 0.01)
    image = renderer(mesh.to_textured_p3d(texture), R = R, T = T, fov = fov)[0]
    # Photometric error, weighted by the rendered alpha channel so that only
    # pixels covered by the mesh contribute.
    loss = ((image[..., :3] - target_rgb) * image[..., 3:]).pow(2.0).sum()
    loss += 100 * ptexture.pow(2.0).sum()
    loss.backward()
    optimizer.step()

    if i % 40 == 0:
        print(ptexture)
        plt.figure(figsize=(10, 10))
        implot = plt.imshow(image[..., :3].clone().detach().cpu().numpy() )
        #implot = plt.imshow((image[..., :3].clone().detach().cpu().numpy() + trg_image.detach().cpu().numpy() / 256) / 2.0)
        # plt.scatter(mesh_landmarks[..., 0].detach().cpu().numpy(), mesh_landmarks[..., 1].detach().cpu().numpy(), color='r', marker='o')
        # plt.scatter(keypoints[..., 0], keypoints[..., 1], color='b', marker='o')
        plt.axis("off")
        plt.show()

# cv2.imwrite expects 8-bit BGR data for JPEG; a float texture would be cast
# without scaling and come out almost black with red and blue swapped.
# Assumes texture is (H, W, 3) RGB in [0, 1], matching how the render is
# compared against the RGB target above -- TODO confirm.
os.makedirs("./tmp", exist_ok=True)
texture_u8 = (texture.detach().clamp(0.0, 1.0) * 255.0).round().to(torch.uint8)
cv2.imwrite("./tmp/save.jpg", texture_u8.flip(-1).cpu().numpy())