In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.autograd
from torch.autograd import Variable
import numpy as np
from scipy.io import loadmat
from scipy.special import sph_harm
import math
import matplotlib.pyplot as plt
from shapely.geometry import Point
from shapely.geometry.polygon import Polygon
from time import time

In [5]:
# torch.device("cuda")
torch.set_default_tensor_type('torch.cuda.FloatTensor')

In [6]:
no_of_ver = 53215

In [7]:
fileName='Dataset/Coarse_Dataset/Exp_Pca.bin'
with open(fileName, mode='rb') as file: # b is important -> binary
#     fileContent = file.read()
    dim_exp = np.fromfile(file, dtype=np.int32, count=1)
    mu_exp = np.zeros(no_of_ver*3)
    base_exp = np.zeros((no_of_ver*3*dim_exp[0]), dtype=float)
    mu_exp = np.fromfile(file, dtype=np.float32, count=3*no_of_ver)
    base_exp = np.fromfile(file, dtype=np.float32, count=3*no_of_ver*dim_exp[0])

In [8]:
A_exp = torch.tensor(np.array(np.resize(base_exp, (no_of_ver*3, dim_exp[0]))), dtype=torch.float32, requires_grad = True)

In [9]:
data = np.loadtxt('Dataset/Coarse_Dataset/std_exp.txt', delimiter=' ')
data = torch.tensor(data[:,np.newaxis], dtype = torch.float32,requires_grad =True )

In [10]:
temp = loadmat('Dataset/3DDFA_Release/Matlab/ModelGeneration/model_info.mat')
temp['tri'].dtype = np.int16
trimIndex = np.array(temp['trimIndex'][:,0], dtype=np.int32)
trim_ind = np.reshape(np.array([3*trimIndex-2,3*trimIndex-1,3*trimIndex])-1,(no_of_ver*3,),'F')
tri_mesh_data = torch.tensor(temp['tri'].T, dtype=torch.long)

In [11]:
print("PyTorch version: ", torch.__version__ )
print("CUDA available: ", torch.cuda.is_available())
print("CUDA version: ", torch.version.cuda)

PyTorch version:  1.1.0
CUDA available:  True
CUDA version:  9.0.176


In [12]:
morph_model = loadmat('Dataset/PublicMM1/01_MorphableModel.mat')

In [13]:
shapePCA = morph_model['shapePC']
shapeMU = morph_model['shapeMU']
shapeSTD = morph_model['shapeEV']

texPCA = morph_model['texPC']
texMU = morph_model['texMU']
texSTD = morph_model['texEV']

In [14]:
p_mu = torch.tensor(shapeMU[trim_ind], requires_grad=True)
b_mu = torch.tensor(texMU[trim_ind], requires_grad = True)
A_alb = torch.tensor(texPCA[trim_ind,:100], requires_grad =True)
A_id = torch.tensor(shapePCA[trim_ind,:100], requires_grad =True)
std_id = torch.tensor(shapeSTD[:100], requires_grad = True)
std_alb = torch.tensor(texSTD[:100], requires_grad = True)
std_exp = data

In [15]:
def P(l, m, x):
    pmm = 1.0
    if m>0:
        somx2 = torch.sqrt((1.0-x)*(1.0+x))
        fact = 1.0
        for i in torch.arange(m):
            pmm = -fact*pmm*somx2
            fact = fact+2.0
    if l==m :
        return pmm
    pmmp1 = x*(2.0*m+1.0)*pmm
    if (l==m+1):
        return pmmp1
    pll = 0.0
    for ll in torch.arange(m+2, l+1):
        pll = ((2.0*ll-1.0)*x*pmmp1 - (ll+m-1.0)*pmm)/(ll-m)
        pmm = pmmp1
        pmmp1 = pll
    return pll

def factorial(n):
    return torch.prod(torch.arange(1,n+1), dtype=torch.float32)

def K(l, m):
    norm_const = ((2.0*l+1.0)*factorial(l-m))/((4.0*np.pi)*factorial(l+m))
    return torch.sqrt(norm_const)

def SH(m, l, phi, theta):
    sqrt2 = torch.sqrt(torch.tensor(2.0, dtype=torch.float32))
    if m==0:
        return K(l,0)*P(l,m,torch.cos(theta))
    elif m>0:
        return sqrt2*K(l,m)*torch.cos(m*phi)*P(l,m,torch.cos(theta))
    else:
        return sqrt2*K(l,-m)*torch.sin(-m*phi)*P(l,-m,torch.cos(theta))

In [16]:
def rot_mat(pitch, yaw, roll):
    Ry = torch.tensor([[torch.cos(pitch),0,torch.sin(pitch)],[0,1,0],[-torch.sin(pitch),0,torch.cos(pitch)]], requires_grad = True)
    Rx = torch.tensor([[1,0,0],[0,torch.cos(roll),torch.sin(roll)],[0,-torch.sin(roll),torch.cos(roll)]], requires_grad = True)
    Rz = torch.tensor([[torch.cos(yaw),torch.sin(yaw),0],[-torch.sin(yaw),torch.cos(yaw),0],[0,0,1]], requires_grad = True)
    R = Rz@Ry@Rx
    
    return R

def sh_basis(norm):
#     if torch.is_tensor(n):
#         norm = n.cpu().detach().numpy()
#     else:
#         norm = n
    theta = norm[1] #Polar angle
    phi = norm[0] #Azimuth angle
    sh = torch.zeros((9,), dtype=torch.float32, requires_grad=True)
    count = 0
    for l in torch.arange(3):
        for m in torch.arange(-l,l+1):
#             if m==0:
#                 sh[count]=np.real(sph_harm(m,l,phi,theta))
#             elif m>0:
#                 sh[count]=np.sqrt(2)*np.real(sph_harm(m,l,phi,theta))
#             else:
#                 sh[count]=np.sqrt(2)*np.imag(sph_harm(m,l,phi,theta))
#             count = count+1
                sh[count] = SH(m,l,phi,theta)
                count = count+1
            
    return sh

In [17]:
def barycentric_weights(p, tri_p):
    #http://blackpawn.com/texts/pointinpoly/
    
    a = tri_p[0,:]
    b = tri_p[1,:]
    c = tri_p[2,:]
    
    v0 = c - a
    v1 = b - a
    v2 = p - a
    
#     dot00 = torch.dot(v0,v0)
#     dot01 = torch.dot(v0,v1)
#     dot02 = torch.dot(v0,v2)
#     dot11 = torch.dot(v1,v1)
#     dot12 = torch.dot(v1,v2)
    
#     denom = dot00*dot11 - dot01*dot01
    
#     if denom == 0:
#         u = v = 0
#     else:
#         u = (dot11*dot02-dot01*dot12)/denom
#         v = (dot00*dot12-dot01*dot02)/denom
    
    A = torch.stack((v0,v1)).t()
    B = v2[:,None]
    
    X, LU = torch.solve(B,A)
    u = X[0,0]
    v = X[1,0]
        
    weights = torch.tensor([1-u-v, u, v], dtype=torch.float32, requires_grad = True)
        
    return weights

def world_to_image(q_world, h, w):
#     temp = np.array([w/2,h/2-h+1,0])
#     q_image = (q_world + temp)*[1,-1,1]
    q_image = q_world.clone()
    
    q_image[:,0] = q_image[:,0] + w/2
    q_image[:,1] = q_image[:,1] + h/2
    q_image[:,1] = h - q_image[:,1] - 1
    
    return q_image

def rasterize(q, q_depth, tri_mesh_data, h, w):
    depth_info = -math.inf*torch.ones((h,w))
#     depth_info = {}
    tri_ind_info = -torch.ones((h,w), dtype = torch.int32)
    bary_wts_info = torch.zeros((h,w,3), requires_grad = True)
    
#     for i in range(h):
#         for j in range(w):
#             depth_info[(i,j)] = -math.inf
#             bary_wts_info[(i,j)] = 0
    
    for i in range(len(tri_mesh_data)):
        print('Rasterizing triangle: ', i+1)
        tri_ver_ind = tri_mesh_data[i,:]
        
        umin = max(int(torch.ceil(torch.min(q[tri_ver_ind, 0]))), 0) #torch.min(lmks_2d[:,0])
        umax = min(int(torch.floor(torch.max(q[tri_ver_ind, 0]))), w-1) #torch.max(lmks_2d[:,0])
        
        vmin = max(int(torch.ceil(torch.min(q[tri_ver_ind, 1]))), 0) #torch.min(lmks_2d[:,0])
        vmax = min(int(torch.floor(torch.max(q[tri_ver_ind, 1]))), h-1)
        
        if umax<umin or vmax<vmin:
            continue
        else:
            for u in range(umin, umax+1):
                for v in range(vmin, vmax+1):
                    weights = barycentric_weights(torch.tensor([u, v], dtype = torch.float32, requires_grad =True), q[tri_ver_ind, :])
                    if (weights<0).all():
                        continue
                    else:
                        depth = torch.dot(weights, q_depth[tri_ver_ind])
                        if depth > depth_info[(u, v)]:
                            depth_info[u, v] = depth
                            tri_ind_info[u, v] = i
                            bary_wts_info[u, v,:] = weights
                            
    return tri_ind_info, bary_wts_info

def cart2sph(n):
#     phi = np.arctan2(n[1],n[0]) #arctan(y/x)
#     theta = np.arccos(n[2]) #arccos(z)
    if torch.is_tensor(n):
        norm = n.cpu().detach().numpy()
    else:
        norm = n
    phi = np.arctan2(norm[1],norm[0])
    theta = np.arccos(norm[2])
    return torch.tensor([phi, theta], dtype=torch.float32, requires_grad = True)

def calculate_normal(tri_ind_info, tri_mesh_data, centroid, q, h, w):
    normal_xyz = torch.zeros((h, w, 3), requires_grad = True)
    normal_sph = torch.zeros((h, w, 2), requires_grad = True)
    
    for i in range(h):
        for j in range(w):
            tri_ver = q[tri_mesh_data[tri_ind_info[i, j]-1, :],:]
            a = tri_ver[0,:]
            b = tri_ver[1,:]
            c = tri_ver[2,:]
            normal_xyz[i,j,:] = torch.cross(a-b, b-c)/torch.norm(torch.cross(a-b, b-c))
            if torch.dot(torch.mean(tri_ver, 0)-centroid, normal_xyz[i,j,:])<0:
                normal_xyz[i,j,:] *= -1
            normal_sph[i,j,:] = cart2sph(normal_xyz[i,j,:])
    return normal_sph

def render_color_image(q, albedo, tri_mesh_data, gamma, h, w):
    image = torch.zeros((h,w,3), dtype=torch.float32, requires_grad = True)
    alb = torch.zeros((h,w,3), dtype=torch.float32, requires_grad = True)
    centroid = torch.mean(q,0)
    
    st = time()
#     tri_ind_info, bary_wts_info = rasterize(q[:,:2], q[:,2], tri_mesh_data, h, w)
    tri_ind_info, bary_wts_info = rasterize_triangles(q, tri_mesh_data, h, w)
    print('Rasterization Done!- %f seconds' %(time()-st))
    
    n_sph = calculate_normal(tri_ind_info, tri_mesh_data, centroid, q, h, w)
    
#     for i in range(h):
#         for j in range(w):
#             sh_func = sh_basis(n_sph[i,j,:])
#             alb[i,j,:] = albedo[tri_mesh_data[tri_ind_info[i, j]-1, :],:].t()@bary_wts_info[(i,j)]
#             image[i,j,:] = alb[i,j,:]*(gamma.t()@sh_func.squeeze())
            
    return albedo, bary_wts_info

In [57]:
def isPointInTri(point, tri_points):
    ''' Judge whether the point is in the triangle
    Method:
        http://blackpawn.com/texts/pointinpoly/
    Args:
        point: (2,). [u, v] or [x, y] 
        tri_points: (3 vertices, 2 coords). three vertices(2d points) of a triangle. 
    Returns:
        bool: true for in triangle
    '''
    tp = tri_points

    # vectors
    v0 = tp[2,:] - tp[0,:]
    v1 = tp[1,:] - tp[0,:]
    v2 = point - tp[0,:]
    A = torch.stack((v0,v1)).t()
    B = v2[:,None]

    X = torch.inverse(A)@B

    u = X[0,0]
    v = X[1,0]
    # dot products
    dot00 = torch.dot(v0.t(), v0)
    dot01 = torch.dot(v0.t(), v1)
    dot02 = torch.dot(v0.t(), v2)
    dot11 = torch.dot(v1.t(), v1)
    dot12 = torch.dot(v1.t(), v2)

    # barycentric coordinates
    if dot00*dot11 - dot01*dot01 == 0:
        inverDeno = 0
    else:
        inverDeno = 1/(dot00*dot11 - dot01*dot01)

    u = (dot11*dot02 - dot01*dot12)*inverDeno
    v = (dot00*dot12 - dot01*dot02)*inverDeno

    # check if point in triangle
    return (u >= 0) & (v >= 0) & (u + v < 1)

def get_point_weight(point, tri_points):
    ''' Get the weights of the position
    Methods: https://gamedev.stackexchange.com/questions/23743/whats-the-most-efficient-way-to-find-barycentric-coordinates
     -m1.compute the area of the triangles formed by embedding the point P inside the triangle
     -m2.Christer Ericson's book "Real-Time Collision Detection". faster.(used)
    Args:
        point: (2,). [u, v] or [x, y] 
        tri_points: (3 vertices, 2 coords). three vertices(2d points) of a triangle. 
    Returns:
        w0: weight of v0
        w1: weight of v1
        w2: weight of v3
     '''
    tp = tri_points
    # vectors
    v0 = tp[2,:] - tp[0,:]
    v1 = tp[1,:] - tp[0,:]
    v2 = point - tp[0,:]
    
    A = torch.stack((v0,v1)).t()
    B = v2[:,None]
    
    X = torch.inverse(A)@B
    u = X[0,0]
    v = X[1,0]
    # dot products
#     dot00 = torch.dot(v0.t(), v0)
#     dot01 = torch.dot(v0.t(), v1)
#     dot02 = torch.dot(v0.t(), v2)
#     dot11 = torch.dot(v1.t(), v1)
#     dot12 = torch.dot(v1.t(), v2)

#     # barycentric coordinates
#     if dot00*dot11 - dot01*dot01 == 0:
#         inverDeno = 0
#     else:
#         inverDeno = 1/(dot00*dot11 - dot01*dot01)

#     u = (dot11*dot02 - dot01*dot12)*inverDeno
#     v = (dot00*dot12 - dot01*dot02)*inverDeno

    u = X[0,0]
    v = X[1,0]
    w0 = 1 - u - v
    w1 = v
    w2 = u
    return w0,w1,w2

def rasterize_triangles(vertices, triangles, h, w):
    ''' 
    Args:
        vertices: [nver, 3]
        triangles: [ntri, 3]
        h: height
        w: width
    Returns:
        depth_buffer: [h, w] saves the depth, here, the bigger the z, the fronter the point.
        triangle_buffer: [h, w] saves the tri id(-1 for no triangle). 
        barycentric_weight: [h, w, 3] saves corresponding barycentric weight.

    # Each triangle has 3 vertices & Each vertex has 3 coordinates x, y, z.
    # h, w is the size of rendering
    '''
    # initial 
    depth_buffer = {}#torch.zeros([h, w]) - 999999. #+ torch.min(vertices[2,:]) - 999999. # set the initial z to the farest position
    triangle_buffer = torch.zeros([h, w], dtype = torch.int32) - 1  # if tri id = -1, the pixel has no triangle correspondance
    barycentric_weight = {}#torch.zeros([h, w, 3], dtype = torch.float32)  # 
    
    for i in range(h):
        for j in range(w):
            depth_buffer[(i,j)] = -math.inf
            barycentric_weight[(i,j)] = 0
    
    for i in range(triangles.shape[0]):
        print('Rasterzing: ',i+1)
        tri = triangles[i, :] # 3 vertex indices

        # the inner bounding box
        umin = max(int(torch.ceil(torch.min(vertices[tri, 0]))), 0)
        umax = min(int(torch.floor(torch.max(vertices[tri, 0]))), w-1)

        vmin = max(int(torch.ceil(torch.min(vertices[tri, 1]))), 0)
        vmax = min(int(torch.floor(torch.max(vertices[tri, 1]))), h-1)

        if umax<umin or vmax<vmin:
            continue

        for u in range(umin, umax+1):
            for v in range(vmin, vmax+1):
                if not isPointInTri(torch.tensor([u, v],dtype = torch.float32, requires_grad=True), vertices[tri, :2]): 
                    continue
                w0, w1, w2 = get_point_weight(torch.tensor([u, v], dtype = torch.float32, requires_grad=True), vertices[tri, :2]) # barycentric weight
                point_depth = w0*vertices[tri[0], 2] + w1*vertices[tri[1], 2] + w2*vertices[tri[2], 2]
                if point_depth > depth_buffer[v, u]:
                    depth_buffer[(v, u)] = point_depth
                    triangle_buffer[v, u] = i
                    barycentric_weight[(v, u)] = torch.tensor([w0, w1, w2], dtype=torch.float32, requires_grad=True)

    return triangle_buffer, barycentric_weight

In [58]:
I_in = plt.imread('Dataset/300W-Convert/300W-Original/afw/134212_1.jpg')
I_in=torch.tensor(I_in[160:384,704:928,:],dtype = torch.float32, requires_grad = True)

In [59]:

h=w=224
chi = torch.rand(312,1,requires_grad = True, dtype=torch.float32)
t = torch.zeros(3,1, requires_grad = True, dtype=torch.float32)
count = 1
while True:
    print("Iteration No: ", count)
    chi_prev = chi
    
    al_id = chi_prev[0:100]
    al_exp = chi_prev[100:179]
    al_alb = chi_prev[179:279]
    [s, pitch, yaw, roll] = chi_prev[279:283,0]
    t[:2] = chi_prev[283:285]
    r = chi_prev[285:]
    gamma_r = r[:9]
    gamma_g = r[9:18]
    gamma_b = r[18:]
    gamma = torch.reshape(r,(3,9)).t()
    
    
    p = p_mu + torch.matmul(A_id,al_id) + torch.matmul(A_exp,al_exp)
    b = b_mu + torch.matmul(A_alb,al_alb)
    vertex = torch.reshape(p, (no_of_ver, 3))
    albedo = torch.reshape(b, (no_of_ver, 3))
    
    if count == 1:
        s = 150/(torch.max(vertex) - torch.min(vertex))
    
    R = rot_mat(pitch, yaw, roll)
    q_world = s*R@vertex.t() + t
    q_image = world_to_image(q_world.t(), h, w)
    print('Rendering Image...')
    st = time()
    albe, bary_wts_info = render_color_image(q_image, albedo, tri_mesh_data, gamma, h, w)
#     I_rend = render_color_image(q_image, albedo, tri_mesh_data, gamma, h, w)
#     print('Rendering Done! - %f seconds' %(time()-st) )
#     w_l = 100
#     w_r = 5e-5
#     #E=torch.norm(I_rend - I_in)**2 + torch.norm(al_id/std_id)**2 + w_r*torch.norm(al_alb/std_alb)**2 + torch.norm(al_exp/std_exp)**2
#     E_con_r = torch.tensor([torch.sqrt(1/28241)*torch.norm(I_rend - I_in)], dtype=torch.float32)
#     E_lan_r = torch.sqrt(w_l/self.no_of_lmks)*torch.norm(lmks_2d - q_image[lmks_3d_ind[0,:],:2], axis=1)
    
#     E = torch.cat((E_con_r,E_lan_r))
#     E.backward()
#     J = chi_prev.grad.t()
    break
#     chi_next = chi_prev - torch.pinverse(J.t()@J)@J.t()*E
#     err = torch.norm(chi_next - chi_prev)
#     chi_prev = chi_next
#     count=count+1
#     chi_prev.grad.zero_
#     print('Error: ', err)
#     if err<1:
#         break

Iteration No:  1
Rendering Image...
Rasterzing:  1
Rasterzing:  2
Rasterzing:  3
Rasterzing:  4
Rasterzing:  5
Rasterzing:  6
Rasterzing:  7
Rasterzing:  8
Rasterzing:  9
Rasterzing:  10
Rasterzing:  11
Rasterzing:  12
Rasterzing:  13
Rasterzing:  14
Rasterzing:  15
Rasterzing:  16
Rasterzing:  17
Rasterzing:  18
Rasterzing:  19
Rasterzing:  20
Rasterzing:  21
Rasterzing:  22
Rasterzing:  23
Rasterzing:  24
Rasterzing:  25
Rasterzing:  26
Rasterzing:  27
Rasterzing:  28
Rasterzing:  29
Rasterzing:  30
Rasterzing:  31
Rasterzing:  32
Rasterzing:  33
Rasterzing:  34
Rasterzing:  35
Rasterzing:  36
Rasterzing:  37
Rasterzing:  38
Rasterzing:  39
Rasterzing:  40
Rasterzing:  41
Rasterzing:  42
Rasterzing:  43
Rasterzing:  44
Rasterzing:  45
Rasterzing:  46
Rasterzing:  47
Rasterzing:  48
Rasterzing:  49
Rasterzing:  50
Rasterzing:  51
Rasterzing:  52
Rasterzing:  53
Rasterzing:  54
Rasterzing:  55
Rasterzing:  56
Rasterzing:  57
Rasterzing:  58
Rasterzing:  59
Rasterzing:  60
Rasterzing:  

Rasterzing:  512
Rasterzing:  513
Rasterzing:  514
Rasterzing:  515
Rasterzing:  516
Rasterzing:  517
Rasterzing:  518
Rasterzing:  519
Rasterzing:  520
Rasterzing:  521
Rasterzing:  522
Rasterzing:  523
Rasterzing:  524
Rasterzing:  525
Rasterzing:  526
Rasterzing:  527
Rasterzing:  528
Rasterzing:  529
Rasterzing:  530
Rasterzing:  531
Rasterzing:  532
Rasterzing:  533
Rasterzing:  534
Rasterzing:  535
Rasterzing:  536
Rasterzing:  537
Rasterzing:  538
Rasterzing:  539
Rasterzing:  540
Rasterzing:  541
Rasterzing:  542
Rasterzing:  543
Rasterzing:  544
Rasterzing:  545
Rasterzing:  546
Rasterzing:  547
Rasterzing:  548
Rasterzing:  549
Rasterzing:  550
Rasterzing:  551
Rasterzing:  552
Rasterzing:  553
Rasterzing:  554
Rasterzing:  555
Rasterzing:  556
Rasterzing:  557
Rasterzing:  558
Rasterzing:  559
Rasterzing:  560
Rasterzing:  561
Rasterzing:  562
Rasterzing:  563
Rasterzing:  564
Rasterzing:  565
Rasterzing:  566
Rasterzing:  567
Rasterzing:  568
Rasterzing:  569
Rasterzing:  5

KeyboardInterrupt: 

In [51]:
A = torch.tensor([[-0.1231,  0.3298],
        [ 0.2749,  0.1079]], dtype= torch.float32, requires_grad=True)

In [52]:
A.type(torch.cuda.FloatTensor)

tensor([[-0.1231,  0.3298],
        [ 0.2749,  0.1079]], requires_grad=True)

In [53]:
torch.pinverse(A)

RuntimeError: cuda runtime error (11) : invalid argument at /pytorch/aten/src/THC/generic/THCTensorMath.cu:35