In [481]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from time import time
from tqdm import tqdm


class PKBD(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(PKBD, self).__init__()
        self.fc = nn.Linear(input_dim, output_dim)
        self.output_dim = output_dim
        
    def forward(self, x):
        mu = self.fc(x)
        normm = mu.norm(dim=-1, keepdim=True)
        rho = normm / (1 + normm)
        mu = mu / normm
        return mu, rho

def pkbd_log_likelihood(mu, rho, x, W):
    d = x.shape[-1]
    term1 = (1 - rho ** 2).log()  # log(1 - rho^2)
    term2 = 1 + rho ** 2 - 2 * rho * ((mu.unsqueeze(1) @ Y.unsqueeze(2)).squeeze(-1))  # (1 + rho^2 - 2*rho*mu^T x)^(d/2)
    log_likelihood = (d/2)*term2.log() - term1
    #print(log_likelihood.shape, W.shape)
    return (log_likelihood * W).sum()

# Example usage
input_dim = 1
output_dim = 3
EPOCHS = 100
LR = 0.5
batch_size = 32

In [345]:
Y = np.load('Y.npy')
Y = torch.tensor(Y, dtype=torch.float32)
Y = Y[:1000, :]
Y.shape

torch.Size([1000, 3])

In [346]:
X = torch.ones(Y.shape[0], input_dim, dtype=torch.float32)
X = X[:1000, :]
X

tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
      

In [347]:
W = torch.ones(Y.shape[0], input_dim, dtype=torch.float32)/1000
W.shape
W

tensor([[0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0.0010],
        [0

In [517]:
model = PKBD(input_dim, output_dim)
optimizer = optim.Adam(model.parameters(), lr=LR)

In [518]:
%%time
model.train()
for epoch in range(EPOCHS):
    optimizer.zero_grad()
    mu, rho = model(X)
    print(rho[1].item())
    loss = pkbd_log_likelihood(mu, rho, Y, W)
    print(loss.item())
    loss.backward()
    optimizer.step()

0.5910823941230774
0.17256538569927216
0.6548534631729126
-1.1215004920959473
0.7294043302536011
-2.143233299255371
0.7844153046607971
-2.386915683746338
0.8187183737754822
-2.677828550338745
0.8420624732971191
-3.0683343410491943
0.85946124792099
-3.3039000034332275
0.8728229403495789
-3.3805766105651855
0.8832686543464661
-3.409651279449463
0.8915784358978271
-3.472324848175049
0.8983380198478699
-3.5657448768615723
0.9039567112922668
-3.6553268432617188
0.908701479434967
-3.7211878299713135
0.9127556681632996
-3.7571096420288086
0.9162515997886658
-3.7687950134277344
0.9192866683006287
-3.7738137245178223
0.9219375252723694
-3.7879037857055664
0.924267590045929
-3.8151562213897705
0.9263291358947754
-3.849391460418701
0.9281648397445679
-3.879033327102661
0.9298087954521179
-3.8945746421813965
0.9312869906425476
-3.89546537399292
0.9326194524765015
-3.890187978744507
0.9338226318359375
-3.889383316040039
0.9349113702774048
-3.8988020420074463
0.9358997344970703
-3.916536331176758
0.

In [483]:
model.fc.weight

Parameter containing:
tensor([[10.0375],
        [ 0.1044],
        [ 0.3620]], requires_grad=True)

In [529]:
model = PKBD(input_dim, output_dim)
optimizer = optim.LBFGS(model.parameters(), lr=LR, max_iter=20)

In [530]:
%%time
model.train()
for epoch in range(1):
    def closure():
        optimizer.zero_grad()
        mu, rho = model(X)
        print(rho[1].item())
        loss = pkbd_log_likelihood(mu, rho, Y, W)
        print(f'Loss {loss.item()}')
        loss.backward()
        return loss
    optimizer.step(closure)

0.6090788841247559
Loss 0.4048592448234558
0.5963284373283386
Loss 0.018465563654899597
0.6345394253730774
Loss -2.0414209365844727
0.8175244331359863
Loss 0.16586364805698395
0.811237633228302
Loss -0.36425501108169556
0.8042029142379761
Loss -1.1842683553695679
0.7979008555412292
Loss -2.532008409500122
0.7996063828468323
Loss -2.8257181644439697
0.8064548969268799
Loss -3.0726265907287598
0.816735029220581
Loss -3.187577247619629
0.829311728477478
Loss -3.2708189487457275
0.8505144119262695
Loss -3.400203227996826
0.8735232353210449
Loss -3.546751022338867
0.8942664861679077
Loss -3.664421796798706
0.9127225279808044
Loss -3.808732032775879
0.9231880903244019
Loss -3.8751323223114014
0.93205326795578
Loss -3.9232492446899414
0.9385175108909607
Loss -3.95131254196167
0.9432503581047058
Loss -3.966712236404419
0.9465882778167725
Loss -3.974252700805664
CPU times: total: 0 ns
Wall time: 12 ms


In [512]:
torch.utils.data.get_worker_info()

In [None]:

loss_list = []
def train(W):
    max_accuracy = 0
    pbar = tqdm(total=EPOCHS)
    begin = time()
    for index in range(EPOCHS):
        model.train()
        for i, (x, y) in enumerate(train_dataloader):
            optimizer.zero_grad()
            mu, rho = model(x)
            loss = pkbd_log_likelihood(mu, rho, y, W)

            loss_list.append(loss.item())

            loss.backward()

            optimizer.step()
        
        pbar.update()

    end = time()
    time_cost = round((end - begin) / 60, 2)

In [55]:
d = 3
mu, rho = model(X)
mu, rho

term1 = (1 - rho ** 2).log()
term2 =  1 + rho**2 - 2 * rho* ((mu.unsqueeze(1) @ Y.unsqueeze(2)).squeeze(-1))
res = (d/2)*term2.log() - term1
print(res.shape)
res.mean()

torch.Size([2000, 1])


tensor(0.0896, grad_fn=<MeanBackward0>)