In [5]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch

In [6]:
def hardmax(W, H):
    """Return the largest off-diagonal margin of the score matrix ``W @ H.T``.

    For every row ``i`` the diagonal score ``(W @ H.T)[i, i]`` is subtracted
    from all scores in that row, the diagonal itself is masked out with
    ``-inf``, and the maximum over the remaining entries is returned.

    Args:
        W: 2-D tensor whose rows are multiplied against the rows of ``H``.
        H: 2-D tensor; it must have as many rows as ``W`` so that the score
            matrix is square.

    Returns:
        A 0-dim tensor holding ``max_{i != j} (W[i] @ H[j] - W[i] @ H[i])``.
        When there is only a single row, every entry is masked and the
        result is ``-inf``.
    """
    scores = W @ H.T
    # Broadcasting the (K, 1) column of diagonal scores subtracts row i's
    # diagonal entry from every entry of row i.
    margins = scores - torch.diag(scores).unsqueeze(1)
    # Mask the whole diagonal in one vectorized op instead of K separate
    # in-place element writes.
    diag_mask = torch.eye(margins.shape[0], dtype=torch.bool, device=margins.device)
    return margins.masked_fill(diag_mask, float("-inf")).max()

In [7]:
def minimize(W, H, alpha=0.1, tol=1e-6, max_iter=500000):
    """
    Use gradient descent to minimize ``hardmax(W, H)``.

    After every gradient step the rows of ``W`` and ``H`` are rescaled to
    unit L2 norm. The step size decays linearly from ``alpha`` on the first
    iteration to 0 on the last one.

    Args:
        W: 2-D tensor; starting point for the first factor.
        H: 2-D tensor; starting point for the second factor.
        alpha: initial (largest) step size.
        tol: the loop stops early once the gradient norms of BOTH ``W`` and
            ``H`` fall below this value.
        max_iter: maximum number of iterations (must be at least 1).

    Returns:
        ``(f, W, H)``: ``f`` is the objective evaluated at the start of the
        last executed iteration (i.e. before that iteration's update), and
        ``W`` / ``H`` are the optimized leaf tensors (``requires_grad=True``).

    Raises:
        ValueError: if ``max_iter`` is smaller than 1.

    Note:
        The optimized leaves share storage with the tensors passed in
        (``detach()`` does not copy), so the caller's data is updated in
        place. If the caller passes the same tensor for ``W`` and ``H``,
        both leaves update the same underlying memory.
    """
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1")

    # Linearly decaying schedule: alpha at i == 0, 0 at i == max_iter - 1.
    lr_sched = np.linspace(0, alpha, num=max_iter)
    lr_sched = lr_sched[::-1]

    # Fresh autograd leaves viewing the caller's data (replaces the
    # deprecated torch.autograd.Variable wrapper with the same semantics).
    W = W.detach().requires_grad_(True)
    H = H.detach().requires_grad_(True)

    for i in range(max_iter):
        lr = float(lr_sched[i])
        f = hardmax(W, H)
        # f = torch.nn.functional.cross_entropy(W@H.T*1, torch.arange(0, W.shape[0]).type(torch.LongTensor).to(W.device))
        f.backward()

        # Converged only when both gradients are small. The previous check
        # tested merely that H's gradient norm was non-zero.
        if torch.norm(W.grad) < tol and torch.norm(H.grad) < tol:
            break

        with torch.no_grad():
            # Gradient step followed by re-scaling every row to unit norm.
            W -= lr * W.grad
            W /= torch.norm(W, dim=1, keepdim=True)
            W.grad.zero_()
            H -= lr * H.grad
            H /= torch.norm(H, dim=1, keepdim=True)
            H.grad.zero_()

            if i % 5000 == 0:
                # Largest row-wise distance between W and H after the update.
                max_diff = torch.max(torch.norm(W - H, dim=1)).item()
                print("iteration " + str(i).zfill(7) + " lr: %.3f" % lr + " f_value: %.8f" % f.item() + " max difference: %.5f" % max_diff)
    return f, W, H

In [8]:
# Alternative (d, K) settings kept for reference:
#d_list = [3]#[3,4,8,7,6,5]
#K_list = [12]#[12,120,240,56,27,16]
#d_K_pair = [(7,56), (6,27), (5,16), (4, 120), (8,240)]
d_K_pair = [(21,162)]
cos_list = []
lr = 0.1
seed = 0
gpu_index = 3
output_dir = "./WWT_matrix"

# Prefer the requested GPU, but fall back to any GPU / the CPU so the cell
# still runs on machines with fewer devices.
if torch.cuda.is_available() and torch.cuda.device_count() > gpu_index:
    device = f"cuda:{gpu_index}"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Fix the random initialization so a re-run reproduces the same result.
torch.manual_seed(seed)

# Create the output directory up front: a missing directory would otherwise
# only surface as an error after the (long) optimization has finished.
os.makedirs(output_dir, exist_ok=True)

for (d, K) in d_K_pair:
    print(f"d: {d}, K: {K}")
    W = torch.randn((K, d)).to(device)
    #W_np = np.load("./WWT_matrix/d21_K162.npy")
    #W = torch.tensor(W_np).to(device)
    # Start from rows of unit L2 norm.
    W /= torch.norm(W, dim=1, keepdim=True)
    # NOTE(review): H is the very same tensor object as W, not a copy --
    # presumably intentional (tied factors, consistent with the logged
    # "max difference" of 0); confirm.
    H = W

    init_f = hardmax(W, H)
    print('init_f: ', init_f)
    minimizer, W, H = minimize(W, H, alpha=lr)

    # Gram matrix of the optimized rows; saved before the diagonal is masked.
    WWT = (W @ W.T).detach().cpu().numpy()
    np.save(os.path.join(output_dir, f"d{d}_K{K}.npy"), WWT)

    # Ignore self-similarities when looking for the largest entry.
    np.fill_diagonal(WWT, -np.inf)
    max_cos = np.max(WWT)
    print("max cosine value:", max_cos)
    cos_list.append(max_cos)

print(cos_list)

d: 21, K: 162
init_f:  tensor(-0.7563, device='cuda:3')
iteration 0000000 lr: 0.100 f_value: -0.75628614 max difference: 0.00000
iteration 0005000 lr: 0.097 f_value: -0.70484447 max difference: 0.00000
iteration 0010000 lr: 0.095 f_value: -0.71426630 max difference: 0.00000
iteration 0015000 lr: 0.092 f_value: -0.71601117 max difference: 0.00000
iteration 0020000 lr: 0.090 f_value: -0.70593172 max difference: 0.00000
iteration 0025000 lr: 0.087 f_value: -0.72524649 max difference: 0.00000
iteration 0030000 lr: 0.085 f_value: -0.74361879 max difference: 0.00000
iteration 0035000 lr: 0.082 f_value: -0.77657616 max difference: 0.00000
iteration 0040000 lr: 0.080 f_value: -0.79640341 max difference: 0.00000
iteration 0045000 lr: 0.077 f_value: -0.82122707 max difference: 0.00000
iteration 0050000 lr: 0.075 f_value: -0.85032749 max difference: 0.00000
iteration 0055000 lr: 0.072 f_value: -0.87039530 max difference: 0.00000
iteration 0060000 lr: 0.070 f_value: -0.87776345 max difference: 0.0