In [263]:
%matplotlib

import numpy as np
from scipy import ndimage
import scipy
import matplotlib.pyplot as plt
import torch
import skimage
import imageio
import glob


from multiPoseExtraction import MultiPoseExtraction
from PoseExtraction import PoseNormalization
from deepHOG import DeepHOG
import utils
import datasets as ds
from torch.autograd import Variable
from torchvision import datasets, transforms

def tensor_to_coord_tensor(X):
    """
    Turn a batch of single-channel images into per-pixel (x, y, value) rows.

    Pixels are enumerated column by column (the column index is the slow
    index, the row index the fast one).  ``x`` is the column index and ``y``
    is the row index counted from the bottom of the image
    (``height - 1 - row``).

    Args:
        X: 4D tensor of shape (batch, channel, height, width).  The final
            stacking step needs the flattened image to have
            ``width * height`` entries per batch item, i.e. one channel.

    Returns:
        Tensor of shape (batch, 3, width * height) whose three rows hold the
        x coordinates, the y coordinates and the pixel values.
    """
    n_batch, _, n_rows, n_cols = X.shape
    n_pixels = n_cols * n_rows
    target_device = X.device

    def as_batched_plane(index_grid):
        # Flatten a (width, height) index grid and share it across the batch.
        flat = index_grid.contiguous().view(n_pixels).float()
        return flat.expand(n_batch, -1).to(target_device)

    column_index = torch.arange(0, n_cols).view(n_cols, 1)
    flipped_row_index = torch.arange(n_rows, 0, -1) - 1

    xs = as_batched_plane(column_index.expand(n_cols, n_rows))
    ys = as_batched_plane(flipped_row_index.expand(n_cols, n_rows))

    # Swap the two spatial axes so the values are flattened in the same
    # column-by-column order as the coordinate grids above.
    values = X.permute((0, 1, 3, 2)).contiguous().view(n_batch, -1)

    return torch.stack([xs, ys, values]).permute((1, 0, 2))


class Noise(torch.nn.Module):
    """Layer that adds zero-mean Gaussian noise to its input.

    Each forward pass draws fresh standard-normal samples shaped like the
    input, multiplies them by ``scale`` and adds them to the input.  The
    noise is applied on every call; the module's training flag is not read.
    """

    def __init__(self, scale=0.01):
        """Store ``scale``, the multiplier applied to the normal samples."""
        super().__init__()
        self.scale = scale

    def forward(self, x):
        """Return ``x`` plus scaled Gaussian noise of the same shape."""
        # randn_like allocates on x's device already, so no transfer is needed.
        perturbation = torch.randn_like(x) * self.scale
        return x + perturbation
    

def plot_arrow_img(ax, means, orientations, img_shape, arrow_scale=2, color=(1, 0, 0), alpha=0.8):
    """
    Draw one orientation arrow on ``ax``, flipping the y axis for display.

    Args:
        ax: object exposing a matplotlib-style ``arrow`` method.
        means: tensor whose first two entries are the arrow base (x, y);
            the base is drawn at ``(x, img_shape[1] - y)``.
        orientations: tensor whose first two entries are the direction
            (dx, dy); the arrow is drawn along ``(dx, -dy) * arrow_scale``.
        img_shape: sequence; only ``img_shape[1]`` is read.
        arrow_scale: length multiplier; head width and head length are both
            ``arrow_scale / 8``.
        color: edge colour of the arrow.
        alpha: opacity passed through to ``ax.arrow``.
    """
    base_x = means[0].cpu().data.numpy()
    base_y = means[1].cpu().data.numpy()
    direction = orientations.cpu().data.numpy()

    head_size = arrow_scale / 8
    ax.arrow(
        base_x,
        img_shape[1] - base_y,
        direction[0] * arrow_scale,
        -1 * direction[1] * arrow_scale,
        head_width=head_size,
        head_length=head_size,
        # NOTE(review): the fill is always red; only the edge follows
        # ``color``. Confirm whether that is intended.
        fc='red',
        ec=color,
        linewidth=4,
        alpha=alpha,
    )




# Collect the head-pose sample images. In the visible code they are only
# referenced by the commented-out loader below.
img_files = glob.glob('./data/HeadPoseImageDatabase/Person02/*.jpg')

# im = imageio.imread(img_files[np.random.randint(len(img_files))], as_gray=True)
# im = im[20:, 20:-20]

# Synthetic 800x800 test image: zero background with two filled rectangles.
im = np.zeros([800, 800])
# Rectangle covering rows 10..99 and columns 480..779.
im[10:-700, 480:-20] = 1

# Rectangle covering rows 650..779 and columns 20..199.
im[650:-20, 20:-600] = 1

# im = ndimage.rotate(im, 34, mode='constant')
# Gaussian blur with sigma 1, applied before the derivative filters.
im = ndimage.gaussian_filter(im, 1)

# Sobel derivative along each array axis; np.hypot combines them into the
# gradient magnitude used as the edge map.
sx = ndimage.sobel(im, axis=0, mode='constant')
sy = ndimage.sobel(im, axis=1, mode='constant')
sob = np.hypot(sx, sy)
# sob = im

# The noise amplitude is 0.00, so sob is left unchanged; random numbers are
# still drawn from NumPy's global generator.
sob += 0.00*np.random.random(im.shape)

def left_svd_tensor(T, eps=0.001):
    """
    Closed-form eigen-decomposition of the 2x2 matrices ``T @ T^T``.

    For every batch item the matrix ``T @ T^T`` is divided by its determinant
    plus ``eps``.  The eigenvalues of the scaled matrix come from the
    trace/determinant quadratic formula, and the eigenvectors from the
    ``(L - d, c)`` formula for a matrix ``[[a, b], [c, d]]``.  Eigenvectors
    of ``T @ T^T`` are left singular vectors of ``T``.

    Args:
        T: tensor of shape (batch, 2, N).
        eps: constant added to the determinant before dividing by it.

    Returns:
        Tuple ``(U, confidence, L1, L2)`` where

        * ``U`` has shape (batch, 2, 2); ``U[:, 0]`` is the unit-length
          vector for ``L1`` and ``U[:, 1]`` the one for ``L2``.  When the
          off-diagonal entry is 0 the formula gives a zero vector for the
          eigenvalue equal to the lower-right entry, which stays zero after
          normalisation.
        * ``confidence`` is the constant ``1``.
        * ``L1`` and ``L2`` have shape (batch,) and are the larger and the
          smaller eigenvalue of the scaled matrix.
    """
    gram = torch.bmm(T, T.permute(0, 2, 1))

    det = gram[:, 0, 0] * gram[:, 1, 1] - gram[:, 0, 1] * gram[:, 1, 0]
    scaled = torch.mul(1 / (det.unsqueeze(1).unsqueeze(2) + eps), gram)

    trace = scaled[:, 0, 0] + scaled[:, 1, 1]
    scaled_det = scaled[:, 0, 0] * scaled[:, 1, 1] - scaled[:, 0, 1] * scaled[:, 1, 0]

    # relu zeroes a negative discriminant so the square root never yields NaN.
    root = torch.sqrt(torch.nn.functional.relu(trace ** 2 - 4 * scaled_det))
    L1 = (trace + root) / 2
    L2 = (trace - root) / 2

    lower_right = scaled[:, 1, 1]
    lower_left = scaled[:, 1, 0]
    v1 = torch.stack([L1 - lower_right, lower_left])
    v2 = torch.stack([L2 - lower_right, lower_left])

    # (2 vectors, 2 components, batch) -> (batch, 2 vectors, 2 components)
    U = torch.stack([v1, v2]).permute(2, 0, 1)
    U = torch.nn.functional.normalize(U, dim=2)

    confidence = 1
    return U, confidence, L1, L2


def get_orientation_vectors(img, mean):
    """
    Compute orientation vectors of an image's intensity around a point.

    The image is converted to per-pixel (x, y, value) rows with
    ``tensor_to_coord_tensor``.  Each pixel's offset from ``mean`` is divided
    by its squared L1 norm, multiplied by the square root of the absolute
    pixel value, and the whole array is divided by its total sum.  The result
    is handed to ``left_svd_tensor``.

    Args:
        img: 2D NumPy array (it is expanded to a 1x1xHxW float tensor).
        mean: NumPy array holding the reference point (x, y) in the
            coordinate frame produced by ``tensor_to_coord_tensor``.

    Returns:
        Tuple ``(orientations, (L1, L2), mu_avg)``: the first, third and
        fourth outputs of ``left_svd_tensor``, plus ``mu_avg``, the
        centroid of the pixel coordinates weighted by absolute pixel value,
        shaped (1, 2, 1).
    """
    epsilon = 0.0001
    coord_tensor = tensor_to_coord_tensor(torch.from_numpy(img).unsqueeze(0).unsqueeze(1).float())
    W = torch.abs(coord_tensor[:, 2:3, :])
    X = coord_tensor[:, :2, :]

    WX = torch.mul(W, X)

    # epsilon belongs in the denominator: it keeps an all-zero image from
    # producing 0 / 0 instead of merely offsetting the quotient.
    mu_avg = (torch.sum(WX, dim=2) / (torch.sum(W, dim=2) + epsilon)).unsqueeze(2)
    mu_W = torch.from_numpy(mean).unsqueeze(0).unsqueeze(2).float()
    # Shape printouts kept from the original for interactive inspection.
    print(mu_W.shape)
    print(X.shape)

    XC = X - mu_W
    # keepdim makes the per-pixel broadcast explicit.  The clamp only alters
    # pixels whose squared L1 distance is below epsilon; without it a pixel
    # lying exactly on `mean` yields 0 / 0 = NaN, which then propagates
    # through WXC.sum() into every output.
    squared_l1 = torch.norm(XC, p=1, dim=1, keepdim=True) ** 2
    XC = XC / torch.clamp(squared_l1, min=epsilon)

    WXC = torch.mul(torch.sqrt(W), XC)
    WXC = WXC / WXC.sum()

    orientations, _, L1, L2 = left_svd_tensor(WXC)

    return orientations, (L1, L2), mu_avg

# get_orientation_vectors(sob, np.array((50, 50)))

# Shape of the edge map; the click handler reads it to flip the y axis.
img_shape = sob.shape

# Show the edge map in a new figure with the axis decorations hidden.
fig ,ax = plt.subplots()
ax.imshow(sob)
ax.axis('off')
ax.set_title('Sobel filter', fontsize=20)

def onclick(event):
    """
    Mouse-click callback: draw the orientation estimate for the clicked point.

    The click position is converted to the flipped-y frame, passed to
    ``get_orientation_vectors`` together with the edge map ``sob``, and three
    things are drawn on ``ax``: a red star at the click, a green dot at the
    returned weighted centroid, and a yellow arrow along the first returned
    orientation vector whose length scales with ``L1 / L2``.

    Args:
        event: matplotlib mouse event; ``xdata`` and ``ydata`` are read.

    Returns:
        None.  Clicks without data coordinates are ignored.
    """
    # Clicks outside the axes carry no data coordinates (xdata/ydata are
    # None); previously these raised a TypeError in the subtraction below.
    if event.xdata is None or event.ydata is None:
        return

    # NOTE(review): the flip uses img_shape[1] (width), while
    # tensor_to_coord_tensor numbers rows as height - 1 - row.  The two agree
    # up to one pixel for the square image used here - confirm before using a
    # non-square image.
    mean = np.asarray([event.xdata, img_shape[1] - event.ydata])
    orientation, (L1, L2), mu_avg = get_orientation_vectors(sob, mean.copy())
    print(mu_avg.shape)
    mean_position = mu_avg[0, :, 0].numpy()
    arrow_scale = (L1 / L2).item() * 20
    orientation1 = orientation[0, 0, :]

    ax.plot(event.xdata, event.ydata, 'r*')
    ax.scatter(mean_position[0], img_shape[1] - mean_position[1], s=200, c='g')
    plot_arrow_img(ax, torch.Tensor(mean), orientation1, img_shape, color='y', arrow_scale=arrow_scale, alpha=0.8)
    plt.draw()  # redraw

# Register onclick for mouse-button presses on the figure canvas; the value
# returned by mpl_connect is kept in cid.
cid = fig.canvas.mpl_connect('button_press_event', onclick)



Using matplotlib backend: TkAgg
torch.Size([1, 2, 1])
torch.Size([1, 2, 640000])
torch.Size([1, 2, 1])
torch.Size([1, 2, 1])
torch.Size([1, 2, 640000])
torch.Size([1, 2, 1])
torch.Size([1, 2, 1])
torch.Size([1, 2, 640000])
torch.Size([1, 2, 1])
torch.Size([1, 2, 1])
torch.Size([1, 2, 640000])
torch.Size([1, 2, 1])
torch.Size([1, 2, 1])
torch.Size([1, 2, 640000])
torch.Size([1, 2, 1])
torch.Size([1, 2, 1])
torch.Size([1, 2, 640000])
torch.Size([1, 2, 1])
torch.Size([1, 2, 1])
torch.Size([1, 2, 640000])
torch.Size([1, 2, 1])
torch.Size([1, 2, 1])
torch.Size([1, 2, 640000])
torch.Size([1, 2, 1])
torch.Size([1, 2, 1])
torch.Size([1, 2, 640000])
torch.Size([1, 2, 1])
torch.Size([1, 2, 1])
torch.Size([1, 2, 640000])
torch.Size([1, 2, 1])
torch.Size([1, 2, 1])
torch.Size([1, 2, 640000])
torch.Size([1, 2, 1])
torch.Size([1, 2, 1])
torch.Size([1, 2, 640000])
torch.Size([1, 2, 1])
torch.Size([1, 2, 1])
torch.Size([1, 2, 640000])
torch.Size([1, 2, 1])
torch.Size([1, 2, 1])
torch.Size([1, 2, 64000