In [1]:
import torch
import numpy as np

In [2]:
# Step up one directory level; presumably this is what makes the project-local
# `model` package (imported in the next cell) resolvable — verify against the repo layout.
# NOTE: not idempotent — every re-run of this cell moves one more level up.
%cd ..

/nfs/homedirs/fuchsgru/MastersThesis


In [3]:
from model.orthogonal import *

In [54]:
def fasthpp(V, X, stop_recursion=3):
  """
    Multiply X by a product of d Householder reflections using blocked WY merging.

    Column i of V is the Householder vector v_i and defines the reflection
    H_i = I - 2 * v_i @ v_i.T.  The function returns H_0 @ H_1 @ ... @ H_{d-1} @ X.
    The vectors are NOT normalised here, so the overall transform is orthogonal
    only if every column of V has unit norm.

    V: square (d, d) tensor of Householder vectors; d must be a power of two
       (pad the matrix otherwise).
    X: rectangular (d, bs) tensor to be transformed.
    stop_recursion: index of the last merge level to run.  Levels are numbered
                    from 0 and level c merges blocks of 2**c reflections into
                    blocks of 2**(c+1), so merging stops at block size
                    2**(stop_recursion + 1) (the default 3 gives blocks of 16).
                    If None, or if the value never matches a level, merging
                    continues until a single (d, d) block remains.

    Returns a (d, bs) tensor.  V and X are not modified in place.
    Raises AssertionError if V is not a square matrix or d is not a power of two.
  """
  # Validate before allocating anything; a non-square V would otherwise fail later
  # with an opaque view() error (or silently compute garbage when d == 1).
  assert V.dim() == 2 and V.shape[0] == V.shape[1], "V should be a square (d, d) matrix. "
  d = V.shape[0]

  # Only works for powers of two.
  assert (d & (d-1)) == 0 and d != 0, "d should be power of two. You can just pad the matrix. "

  # Row i of Y_ is Householder vector i.  No explicit clones are needed: the
  # transpose is only ever read, and the scaling already allocates a new tensor.
  Y_ = V.T
  W_ = -2 * Y_

  # Step 1: compute (Y, W)s by merging!
  # Invariant: for every block of k consecutive rows,
  #   I + W_block.T @ Y_block == product of the k reflections in that block.
  k = 1
  n_levels = d.bit_length() - 1  # exact log2(d) for a power of two, no float round-off
  for c in range(n_levels):
    k_2 = k
    k *= 2

    Y_blocks = Y_.view(d // k_2, k_2, d)
    W_blocks = W_.view(d // k_2, k_2, d)
    W_left, W_right = W_blocks[0::2], W_blocks[1::2]

    # (I + W1.T Y1)(I + W2.T Y2) = I + W1.T Y1 + (W2 + W2 Y1.T W1).T Y2,
    # so only the right-hand W of each adjacent pair needs a correction term.
    m1_ = Y_blocks[0::2] @ torch.transpose(W_right, 1, 2)  # (pairs, k_2, k_2)
    m2_ = torch.transpose(W_left, 1, 2) @ m1_              # (pairs, d, k_2)

    # Clone before the in-place add: the previous W_ was used by the matmuls
    # above and must stay intact for their backward pass.
    W_ = W_blocks.clone()
    W_[1::2] += torch.transpose(m2_, 1, 2)
    W_ = W_.view(d, d)

    if stop_recursion is not None and c == stop_recursion: break

  # Step 2: apply the d // k merged blocks to X, last block first, so that block 0
  # ends up left-most in the product.  After a full merge k == d and this is the
  # single update X + W_.T @ (Y_ @ X).
  for i in range(d // k - 1, -1, -1):
    X = X + W_[i*k: (i+1)*k].T @ (Y_[i*k: (i+1)*k] @ X)
  return X

In [55]:
# Seed the generator so that re-running this cell always produces the same inputs
# (and therefore the same gradients when they are inspected further down).
torch.manual_seed(0)

# [D, batch] input; a freshly created tensor does not track gradients, so no flag has to be cleared.
x = torch.randn(64, 72)
# Householder vectors (one per column); kept as a leaf tensor so that .grad is populated by backward().
v = (0.1 * torch.randn(64, 64)).requires_grad_(True)

In [56]:
# Apply the Householder product parameterised by v to x; stop_recursion=3 exercises
# the blocked path (merging stops early and the remaining blocks are applied one by one).
y_pred = fasthpp(v, x, stop_recursion=3)

In [57]:
# Reduce the output to a scalar and back-propagate, to check that gradients flow through fasthpp into v.
# NOTE: re-running this cell without recomputing y_pred raises (the graph is freed after backward),
# and re-running both cells accumulates into v.grad rather than replacing it.
y_pred.sum().backward()

In [59]:
# Display the gradient of sum(y_pred) with respect to v; a populated tensor (not None) shows fasthpp is differentiable in v.
v.grad

tensor([[ 1.2937e+01, -2.9528e+00, -1.6258e+01,  ...,  1.3506e+01,
          2.0690e-05, -1.1856e+00],
        [ 1.3862e+01, -2.3585e+00, -2.4804e+01,  ...,  7.7521e+00,
          1.2090e+00,  6.0216e+00],
        [ 5.4788e+00, -5.0078e+00, -1.1050e+01,  ..., -1.6087e+01,
          1.1258e+00,  8.1878e+00],
        ...,
        [ 9.3238e+00, -4.7141e+00, -2.0566e+01,  ...,  3.9409e+00,
         -3.6661e+00, -2.0735e+01],
        [-2.8815e+00, -1.0664e+01, -5.8704e+00,  ...,  9.8769e+00,
         -3.5049e+00, -2.0599e+01],
        [ 6.0672e+00, -7.3961e+00, -1.6536e+01,  ...,  8.4623e+00,
         -3.2601e+00, -1.9084e+01]])