In [120]:
%load_ext autoreload
%autoreload 2
import os
import torch as tc
import numpy as np
import torch as tc
tc.set_default_tensor_type(tc.FloatTensor)
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
import matplotlib 
matplotlib.rcParams['pdf.fonttype'] = 'truetype'
fontProperties = {'family': 'serif', 'serif': ['Helvetica'], 'weight': 'normal', 'size': 12}
plt.rc('font', **fontProperties)
from matplotlib import gridspec
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.ticker as mtick
import dxchange

import array_ops_test_gpu_rev_rot


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [121]:
def get_cooridnates_stack_for_rotation(array_size, axis=0):
    image_center = [(x - 1) / 2 for x in array_size]
    coords_ls = []
    for this_axis, s in enumerate(array_size):
        if this_axis != axis:
            coord = np.arange(s)
            for i in range(len(array_size)):
                if i != axis and i != this_axis:
                    other_axis = i
                    break
            if other_axis < this_axis:
                coord = np.tile(coord, array_size[other_axis])
            else:
                coord = np.repeat(coord, array_size[other_axis])
            coords_ls.append(coord - image_center[i])
    coord_new = np.stack(coords_ls)
    return coord_new

In [122]:
array_size = [8,8,8]
coord_new = get_cooridnates_stack_for_rotation(array_size)
print(coord_new)

[[-3.5 -3.5 -3.5 -3.5 -3.5 -3.5 -3.5 -3.5 -2.5 -2.5 -2.5 -2.5 -2.5 -2.5
  -2.5 -2.5 -1.5 -1.5 -1.5 -1.5 -1.5 -1.5 -1.5 -1.5 -0.5 -0.5 -0.5 -0.5
  -0.5 -0.5 -0.5 -0.5  0.5  0.5  0.5  0.5  0.5  0.5  0.5  0.5  1.5  1.5
   1.5  1.5  1.5  1.5  1.5  1.5  2.5  2.5  2.5  2.5  2.5  2.5  2.5  2.5
   3.5  3.5  3.5  3.5  3.5  3.5  3.5  3.5]
 [-3.5 -2.5 -1.5 -0.5  0.5  1.5  2.5  3.5 -3.5 -2.5 -1.5 -0.5  0.5  1.5
   2.5  3.5 -3.5 -2.5 -1.5 -0.5  0.5  1.5  2.5  3.5 -3.5 -2.5 -1.5 -0.5
   0.5  1.5  2.5  3.5 -3.5 -2.5 -1.5 -0.5  0.5  1.5  2.5  3.5 -3.5 -2.5
  -1.5 -0.5  0.5  1.5  2.5  3.5 -3.5 -2.5 -1.5 -0.5  0.5  1.5  2.5  3.5
  -3.5 -2.5 -1.5 -0.5  0.5  1.5  2.5  3.5]]


In [123]:
def calculate_original_coordinates_for_rotation(array_size, coord_new, theta, dev=None):
    image_center = [(x - 1) / 2 for x in array_size]
    m0 = tc.tensor([tc.cos(theta), -tc.sin(theta)], device=dev)
    m1 = tc.tensor([tc.sin(theta), tc.cos(theta)], device=dev)
    m_rot = tc.stack([m0, m1])

    coord_old = tc.matmul(m_rot, coord_new)
    coord1_old = coord_old[0, :] + image_center[1]
    coord2_old = coord_old[1, :] + image_center[2]
    coord_old = np.stack([coord1_old, coord2_old], axis=1)
    return coord_old

In [124]:
theta = tc.tensor(1.0)
array_size = [8,8,8]
coord_new = tc.from_numpy(get_cooridnates_stack_for_rotation(array_size, axis=0).astype(np.float32))
print(coord_new.shape)
coord_old = calculate_original_coordinates_for_rotation(array_size, coord_new, theta)
print(coord_old.shape)

torch.Size([2, 64])
(64, 2)


In [125]:
def save_rotation_lookup(array_size, theta_ls, dest_folder=None):

    # create matrix of coordinates
    coord_new = tc.from_numpy(get_cooridnates_stack_for_rotation(array_size, axis=0).astype(np.float32))

    n_theta = len(theta_ls)
    if dest_folder is None:
        dest_folder = 'arrsize_{}_{}_{}_ntheta_{}'.format(array_size[0], array_size[1], array_size[2], n_theta)
    if not os.path.exists(dest_folder):
        os.mkdir(dest_folder)
    for i, theta in enumerate(theta_ls):  #changed from theta_ls[rank:n_theta:n_ranks]
        coord_old = calculate_original_coordinates_for_rotation(array_size, coord_new, theta)
        coord_inv = calculate_original_coordinates_for_rotation(array_size, coord_new, -theta)
        # coord_old_ls are the coordinates in original (0-deg) object frame at each angle, corresponding to each
        # voxel in the object at that angle.
        np.save(os.path.join(dest_folder, '{:.5f}'.format(theta)), coord_old.astype('float16'))
        np.save(os.path.join(dest_folder, '_{:.5f}'.format(theta)), coord_inv.astype('float16'))
    return None

In [126]:
dev= "cpu"
rank = 1
n_ranks = 4
array_size = [8,8,8]
theta_ls = tc.tensor([0.,1.,2.])
save_rotation_lookup(array_size, theta_ls, "rotation_look_up")

In [127]:
def read_origin_coords(src_folder, theta, reverse=False):

    if not reverse:
        coords = np.load(os.path.join(src_folder, '{:.5f}.npy'.format(theta)), allow_pickle=True)
    else:
        coords = np.load(os.path.join(src_folder, '_{:.5f}.npy'.format(theta)), allow_pickle=True)
    return coords

In [128]:
coords = read_origin_coords("rotation_look_up", 1.)
print(coords.shape)

(64, 2)


In [133]:
def apply_rotation_transpose(obj, coord_old, interpolation='bilinear', axis=0, device=None):
    """
    Find the result of applying the transpose of the rotation-interpolation matrix defined by coord_old. Used to
    calculate the VJP of rotation operation.
    :param obj: Tensor.
    :param coord_old: The same variable as is passed to apply_rotation.
    """
    obj = obj.permute(1,2,3,0)
    s = obj.shape
    axes_rot = []
    for i in range(len(obj.shape)):
        if i != axis and i <= 2:
            axes_rot.append(i)

    coord_old_1 = coord_old[:, 0]
    coord_old_2 = coord_old[:, 1]

    # Clip coords, so that edge values are used for out-of-array indices
    coord_old_1 = tc.clamp(coord_old_1, 0, s[axes_rot[0]] - 1)
    coord_old_2 = tc.clamp(coord_old_2, 0, s[axes_rot[1]] - 1)

    coord_old_floor_1 = tc.floor(coord_old_1).type(tc.int64)
    coord_old_ceil_1 = coord_old_floor_1 + 1
    coord_old_floor_2 = tc.floor(coord_old_2).type(tc.int64)
    coord_old_ceil_2 = coord_old_floor_2 + 1

    obj_rot = tc.zeros_like(obj, requires_grad=False)
    fac_ff = (coord_old_ceil_1 - coord_old_1) * (coord_old_ceil_2 - coord_old_2)
    fac_fc = (coord_old_ceil_1 - coord_old_1) * (coord_old_2 - coord_old_floor_2)
    fac_cf = (coord_old_1 - coord_old_floor_1) * (coord_old_ceil_2 - coord_old_2)
    fac_cc = (coord_old_1 - coord_old_floor_1) * (coord_old_2 - coord_old_floor_2)
    print(fac_cc.shape)
    
    fac_ff = tc.stack([fac_ff] * s[-1], axis=1)
    fac_fc = tc.stack([fac_fc] * s[-1], axis=1)
    fac_cf = tc.stack([fac_cf] * s[-1], axis=1)
    fac_cc = tc.stack([fac_cc] * s[-1], axis=1)
    print(fac_cc.shape)
    
    for i_slice in range(s[axis]):
        slicer_ff = [i_slice, i_slice, i_slice]
        slicer_ff[axes_rot[0]] = coord_old_floor_1
        slicer_ff[axes_rot[1]] = coord_old_floor_2
        
        slicer_fc = [i_slice, i_slice, i_slice]
        slicer_fc[axes_rot[0]] = coord_old_floor_1
        slicer_fc[axes_rot[1]] = tc.clamp(coord_old_ceil_2, 0, s[axes_rot[1]] - 1)
        
        slicer_cf = [i_slice, i_slice, i_slice]
        slicer_cf[axes_rot[0]] = tc.clamp(coord_old_ceil_1, 0, s[axes_rot[0]] - 1)
        slicer_cf[axes_rot[1]] = coord_old_floor_2
        
        slicer_cc = [i_slice, i_slice, i_slice]
        slicer_cc[axes_rot[0]] = tc.clamp(coord_old_ceil_1, 0, s[axes_rot[0]] - 1)
        slicer_cc[axes_rot[1]] = tc.clamp(coord_old_ceil_2, 0, s[axes_rot[1]] - 1)

        slicer_obj = [slice(None), slice(None), slice(None)]
        slicer_obj[axis] = i_slice
        print(slicer_obj)
        obj_slice = tc.reshape(obj[slicer_obj], [-1, s[-1]])  # originally: [-1,2]
        obj_rot[tuple(slicer_ff)] += obj_slice * fac_ff
        obj_rot[tuple(slicer_fc)] += obj_slice * fac_fc
        obj_rot[tuple(slicer_cf)] += obj_slice * fac_cf
        obj_rot[tuple(slicer_cc)] += obj_slice * fac_cc
        
    obj = obj.permute(3,0,1,2)
    return obj_rot

In [134]:
obj = tc.ones(3,8,8,8)
coord_old = tc.from_numpy(read_origin_coords("rotation_look_up", 1.0, reverse=True)).type(tc.float)
device = "cpu"

In [135]:
obj_rot = apply_rotation_transpose(obj, coord_old, interpolation='bilinear', axis=0, device=None)

torch.Size([64])
torch.Size([64, 3])
[0, slice(None, None, None), slice(None, None, None)]
[1, slice(None, None, None), slice(None, None, None)]
[2, slice(None, None, None), slice(None, None, None)]
[3, slice(None, None, None), slice(None, None, None)]
[4, slice(None, None, None), slice(None, None, None)]
[5, slice(None, None, None), slice(None, None, None)]
[6, slice(None, None, None), slice(None, None, None)]
[7, slice(None, None, None), slice(None, None, None)]
