In [1]:
import torch
import torch.nn as nn
import numpy as np
from torch.autograd.functional import jacobian
import tqdm
from tqdm import notebook
import random
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
import matplotlib.pyplot as plt

In [2]:
# Every tensor created below without an explicit dtype/device is a float64
# tensor on the CPU.
torch.set_default_device('cpu')
torch.set_default_dtype(torch.float64)

In [3]:
#@torch.compile(mode = "max-autotune")
def G(gs):
    '''
    Build the stacked constraint map for a list of constraint functions.

    :param gs: a list of tensor functions, each applied to the squeezed input
    :return: a function that squeezes its input, evaluates every function in
             ``gs`` on it and stacks the results along a new leading dimension
    '''
    def G_gs(tensor):
        squeezed = torch.squeeze(tensor)
        values = [constraint(squeezed) for constraint in gs]
        return torch.stack(values, dim=0)

    return G_gs

#@torch.compile(mode = "max-autotune")
def J(gs, x):
    '''
    Jacobian of the stacked constraints ``G(gs)`` evaluated at ``x``.

    :param gs: list of constraint functions (see ``G``)
    :param x: evaluation point; squeezed before differentiation
    :return: d G(gs) / dx at the squeezed ``x``; when every g returns a scalar
             and ``x`` squeezes to 1-D this is a (len(gs), x.numel()) matrix
    '''
    stacked_constraints = G(gs)
    point = torch.squeeze(x)
    return jacobian(stacked_constraints, point)

In [4]:
@torch.compile(mode = "max-autotune")
def rattle_step(x, v1, h, M, gs, e):
    '''
    One RATTLE step for constrained dynamics in position/velocity form.

    The unconstrained position update is pulled back onto the manifold
    {q : G(gs)(q) = 0} with three quasi-Newton iterations (the Jacobian at
    the start point is reused on the right), then the velocity is chosen so
    that it is tangent to the manifold at the new position, i.e.
    J(x_new) @ v_new == 0 for any (symmetric) mass matrix M.

    :param x: current position (reshaped to a column vector internally)
    :param v1: current velocity, same number of elements as ``x``
    :param h: step size
    :param M: square mass matrix of size x.numel()
    :param gs: list of constraint functions (see ``G``)
    :param e: constraint tolerance; currently unused because a fixed number of
              Newton iterations is run (kept for interface compatibility)
    :return: x_1, v_1 as squeezed tensors
    '''
    M1 = torch.inverse(M)
    G1 = G(gs)

    # Potential gradient. It is zero here: the g-BAOAB A-step is a free drift.
    DV_col = torch.zeros_like(x).reshape(-1, 1)

    x_col = x.reshape(-1, 1)
    v1_col = v1.reshape(-1, 1)

    # Unconstrained position update, then Newton-Raphson projection onto the
    # constraint manifold along M^-1 J(x)^T.
    Q = torch.squeeze(x_col + h * v1_col - 0.5 * (h ** 2) * M1 @ DV_col)
    J1 = J(gs, torch.squeeze(x_col))
    M1_J1t = M1 @ J1.t()  # loop-invariant
    for _ in range(3):
        J2 = J(gs, Q)
        R = J2 @ M1_J1t
        dL = torch.linalg.solve(R, G1(Q))
        Q = Q - M1_J1t @ dL

    # Half-step velocity implied by the constrained position update.
    Q_col = Q.reshape(-1, 1)
    v1_half = (Q_col - x_col) / h
    x_col = Q_col

    # Velocity projection at the new position: pick the multiplier L so that
    # J_new @ v1_col == 0. Both the force and the constraint force are scaled
    # by M^-1, consistent with the position update above (for M = I this is
    # identical to the unscaled form).
    J_new = J(gs, torch.squeeze(x_col))
    M1_Jt = M1 @ J_new.t()
    P = J_new @ M1_Jt
    T = J_new @ (2 / h * v1_half - M1 @ DV_col)
    L = torch.linalg.solve(P, T)

    v1_col = v1_half - h / 2 * M1 @ DV_col - h / 2 * M1_Jt @ L

    return torch.squeeze(x_col), torch.squeeze(v1_col)

In [5]:
def _gbaoab_projector(Jac, M1):
    '''Cotangent projector I - Jac^T (Jac M^-1 Jac^T)^-1 Jac M^-1 (M1 = M^-1).'''
    n = Jac.shape[1]
    return torch.eye(n) - Jac.t() @ torch.linalg.solve(Jac @ M1 @ Jac.t(), Jac @ M1)


@torch.compile(mode = "max-autotune")
def gBAOAB_step_exact(q_init,p_init,F, gs, h,M, gamma, k, kr,e):
    '''
    One constrained g-BAOAB step: B (half kick), A (drift), O (noise), A, B.

    :param q_init: current position (1-D tensor)
    :param p_init: current momentum/velocity (1-D tensor, same length)
    :param F: callable q -> vector; each B half-step applies
              p <- p - h/2 * L @ F(q) with L the cotangent projector at q
    :param gs: list of constraint functions
    :param h: time step
    :param M: mass matrix; ``M ** (1/2)`` is taken elementwise, which is the
              matrix square root only when M is diagonal
    :param gamma: friction coefficient (O-step damping exp(-gamma * h))
    :param k: scale of the O-step noise variance
    :param kr: number of RATTLE sub-steps per A half-step
    :param e: tolerance forwarded to rattle_step
    :return: (q, p) after the step
    '''
    M1 = torch.inverse(M)
    noise = torch.randn(len(q_init))  # drawn first, as before, to keep the RNG stream unchanged
    a2 = torch.exp(torch.tensor(-gamma * h))
    b2 = torch.sqrt(k * (1 - a2 ** 2))
    # Each A half-step lasts h/2 and is split into kr RATTLE sub-steps.
    # (The former `h/2*kr` evaluated to (h/2)*kr, i.e. kr times too long.)
    h_sub = h / (2 * kr)

    q, p = q_init, p_init

    # B: half kick with the force projected onto the cotangent space.
    p = p - h / 2 * _gbaoab_projector(J(gs, torch.squeeze(q)), M1) @ F(q)

    # A: first constrained drift.
    for _ in range(kr):
        q, p = rattle_step(q, p, h_sub, M, gs, e)

    # O: Ornstein-Uhlenbeck update with noise projected onto the cotangent space.
    L2 = _gbaoab_projector(J(gs, torch.squeeze(q)), M1)
    p = a2 * p + b2 * M ** (1 / 2) @ L2 @ M ** (1 / 2) @ noise

    # A: second constrained drift.
    for _ in range(kr):
        q, p = rattle_step(q, p, h_sub, M, gs, e)

    # B: final half kick.
    p = p - h / 2 * _gbaoab_projector(J(gs, torch.squeeze(q)), M1) @ F(q)

    return q, p

In [6]:
@torch.compile(backend="eager")
def gBAOAB_integrator(q_init,p_init,F, gs, h,M, gamma, k, steps,kr,e):
    '''
    Run ``steps`` g-BAOAB steps from (q_init, p_init) and record the path.

    :return: (positions, velocities), two lists with one entry per step
             (the initial state is not included)
    '''
    q, p = q_init, p_init
    positions, velocities = [], []
    for _ in range(steps):
        q, p = gBAOAB_step_exact(q, p, F, gs, h, M, gamma, k, kr, e)
        positions.append(q)
        velocities.append(p)
    return positions, velocities

In [7]:
@torch.compile(backend="eager")
def multi_gBAOAB_integrator(q_inits,p_inits,F, gs, h,M, gamma, k,ts,kr,e):
    '''
    Integrate several independent trajectories and return their final states.

    Trajectory n starts at (q_inits[n], p_inits[n]) and is advanced
    ``int(ts[n] // h)`` g-BAOAB steps.

    The previous version iterated over ``len(q_inits)`` (an int, so a
    TypeError) and indexed with an undefined, loop-shadowed ``i``.

    :return: (positions, velocities), each stacked along a new leading
             dimension with one row per trajectory (final state only)
    :raises ValueError: if q_inits, p_inits and ts differ in length
    '''
    if not (len(q_inits) == len(p_inits) == len(ts)):
        raise ValueError("q_inits, p_inits and ts must have the same length")
    positions = []
    velocities = []
    for q, p, t_final in zip(q_inits, p_inits, ts):
        for _ in range(int(t_final // h)):
            q, p = gBAOAB_step_exact(q, p, F, gs, h, M, gamma, k, kr, e)
        positions.append(q)
        velocities.append(p)

    return torch.stack(positions), torch.stack(velocities)

In [8]:
# Bones whose two endpoints are both non-root joints. Bones touching joint 0
# are not listed; joints 1, 11 and 15 are instead tied to the fixed point
# (0, 0, 2) via length_constraint_2 where the constraints are assembled.
# Each path (a, b, c, ...) contributes the consecutive edges (a, b), (b, c), ...
_limb_paths = [
    (1, 2),
    (1, 3, 4, 5, 6),
    (1, 7, 8, 9, 10),
    (11, 12, 13, 14),
    (15, 16, 17, 18),
]
bones2 = [edge for path in _limb_paths for edge in zip(path, path[1:])]

In [9]:
def length_constraint(i,j, xinit):
    '''
    Build a bone-length constraint between joints ``i`` and ``j``.

    Configurations are flat vectors (x0, y0, z0, x1, y1, z1, ...), so joint n
    occupies entries 3n..3n+2.

    Not wrapped in torch.compile: this is a factory and the returned closure
    was never compiled by the decorator anyway.

    :param i: first joint index
    :param j: second joint index
    :param xinit: reference configuration. Its squared i-j distance is the
                  rest value, computed once here (later in-place changes to
                  ``xinit`` are not picked up).
    :return: g(y) = |y_i - y_j|^2 - |xinit_i - xinit_j|^2, zero when the bone
             has its reference length. g squeezes y, so y must hold a single
             configuration (no batching).
    '''
    init = torch.squeeze(xinit)
    # Hoisted: the rest length does not change between constraint evaluations.
    rest_sq = (init[3*i] - init[3*j])**2 + (init[3*i+1] - init[3*j+1])**2 + (init[3*i+2] - init[3*j+2])**2

    def constraint_fn(y):
        x = torch.squeeze(y)  # unbatched only
        current_sq = (x[3*i] - x[3*j])**2 + (x[3*i+1] - x[3*j+1])**2 + (x[3*i+2] - x[3*j+2])**2
        return current_sq - rest_sq
    return constraint_fn

In [10]:
def length_constraint_2(i, xinit, anchor=(0, 0, 2)):
    '''
    Build a constraint fixing the distance of joint ``i`` from a fixed point.

    Configurations are flat vectors (x0, y0, z0, x1, ...), so joint i
    occupies entries 3i..3i+2. Not wrapped in torch.compile, because this is
    a factory whose returned closure was never compiled.

    :param i: joint index
    :param xinit: reference configuration; the squared distance of joint i
                  from ``anchor`` is the rest value, computed once here
    :param anchor: fixed point (ax, ay, az); defaults to (0, 0, 2), the
                   previously hard-coded value
    :return: g(y) = |y_i - anchor|^2 - |xinit_i - anchor|^2 (y is squeezed;
             single configuration only)
    '''
    ax, ay, az = anchor
    init = torch.squeeze(xinit)
    # Hoisted: the reference distance is invariant across evaluations.
    rest_sq = (init[3*i] - ax)**2 + (init[3*i+1] - ay)**2 + (init[3*i+2] - az)**2

    def constraint_fn(y):
        x = torch.squeeze(y)
        return (x[3*i] - ax)**2 + (x[3*i+1] - ay)**2 + (x[3*i+2] - az)**2 - rest_sq
    return constraint_fn

In [11]:
def cotangent_projection(gs, M=None):
    '''
    Factory for the projector onto the cotangent space of the constraint manifold.

    Not wrapped in torch.compile: this is a factory; the returned ``proj`` was
    never compiled by the decorator.

    :param gs: list of constraint functions (see ``G``/``J``)
    :param M: optional square mass matrix. When omitted the identity is used,
              which matches the previous behaviour without building and
              inverting an identity matrix.
    :return: proj(x) -> (L, Jac), where Jac = J(gs, x) and
             L = I - Jac^T (Jac M^-1 Jac^T)^-1 Jac M^-1
    '''
    def proj(x):
        Jac = J(gs, x)
        n = Jac.shape[1]
        JM1 = Jac if M is None else Jac @ torch.inverse(M)
        # Solve instead of forming the inverse of the Gram matrix.
        L = torch.eye(n) - Jac.t() @ torch.linalg.solve(JM1 @ Jac.t(), JM1)
        return L, Jac
    return proj

In [12]:
class GaussianFourierProjection(nn.Module):
    """Random Fourier features for embedding the diffusion time.

    The frequencies ``W`` (embed_dim // 2 of them, drawn from N(0, scale^2))
    are sampled once at construction and stored as a parameter with
    requires_grad=False, so the optimizer never changes them.

    ``forward`` indexes its input as ``x[:, None]``; for a 1-D tensor of n
    times the result has shape (n, 2 * (embed_dim // 2)), with sines first
    and then cosines.
    """

    def __init__(self, embed_dim, scale=30.):
        super().__init__()
        frequencies = torch.randn(embed_dim // 2) * scale
        self.W = nn.Parameter(frequencies, requires_grad=False)

    def forward(self, x):
        angles = x.unsqueeze(1) * self.W.unsqueeze(0) * 2 * torch.pi
        return torch.cat((angles.sin(), angles.cos()), dim=-1)

In [13]:
class ScoreNet(nn.Module):
    """Time-dependent score model for a 57-coordinate pose (19 joints x 3).

    ``forward(x, t, L, G)``:
      * x: pose batch indexed as x[:, k]. The projection step squeezes the
        hidden state, so in practice this is a batch of one, shape (1, 57).
      * t: times, embedded with GaussianFourierProjection.
      * L: (57, 57) cotangent projector applied to the network output.
      * G: constraint Jacobian; not used in the computation (kept for the
        call signature).
    The root joint (coordinates 0..2) is pinned: its input is shifted to
    (0, 0, 2) and its output is zeroed. Returns the squeezed projected score.

    ``embed_dim`` must be even: the Fourier features have 2 * (embed_dim // 2)
    entries and feed an nn.Linear(embed_dim, embed_dim).
    """

    def __init__(self, embed_dim):
        super().__init__()
        n = 57  # 19 joints x 3 coordinates
        # Layers are created in the same order and under the same attribute
        # names as before, so parameter initialisation and checkpoints match.
        self.embed = nn.Sequential(GaussianFourierProjection(embed_dim=embed_dim), nn.Linear(embed_dim, embed_dim))
        self.lin_embed = nn.Linear(embed_dim, n)
        self.lin_embed2 = nn.Linear(embed_dim, n)
        self.lin1 = nn.Linear(n, n)
        self.lin2 = nn.Linear(n, n)
        self.lin3 = nn.Linear(n, n)
        self.lin4 = nn.Linear(n, n)
        self.lin5 = nn.Linear(n, n)
        self.act = torch.nn.Sigmoid()

    def forward(self, x, t, L, G):
        # Pin the root joint: subtracting this offset maps coordinates 0..2 to
        # (0, 0, 2) and leaves every other coordinate unchanged.
        offset = torch.zeros_like(x)
        offset[:, :2] = x[:, :2]
        offset[:, 2] = x[:, 2] - 2
        x = x - offset

        t_feat = self.act(self.embed(t))
        h = self.lin1(x) + self.lin_embed(t_feat)
        h = self.act(self.lin2(h))
        h = self.act(self.lin3(h)) + self.lin_embed2(t_feat)
        h = self.act(self.lin4(h))
        h = self.lin5(h)

        # Project onto the cotangent space (squeeze => batch of one).
        h = torch.unsqueeze(L @ torch.squeeze(h), 0)

        # The output is not normalised by t. Zero the force on the pinned root joint.
        root_part = torch.zeros_like(h)
        root_part[:, :3] = h[:, :3]
        h = h - root_part
        return torch.squeeze(h)

In [14]:
# Score network (Fourier time-embedding width 58, which must be even).
# NOTE(review): DataParallel prefixes state_dict keys with "module.", so the
# checkpoints saved below must be loaded into a DataParallel-wrapped ScoreNet.
score_model = nn.DataParallel(ScoreNet(embed_dim=58))

In [15]:
from torch.utils.data import TensorDataset, DataLoader

In [16]:
from torch.optim import Adam

In [17]:
import warnings
warnings.filterwarnings("ignore")


In [18]:
#torch._dynamo.config.verbose=True

In [19]:
#torch._dynamo.config.suppress_errors = True

# Sampling from the reverse SDE

In [20]:
# Simulated trajectory shards. The earlier shards (0-100 ... 750-800) are not
# used. NOTE(review): allow_pickle=True unpickles the files and can run
# arbitrary code, so only load trusted files.
data10, data11, data12, data13, data14, data15 = [
    np.load(path, allow_pickle=True)
    for path in (
        '1000-2000.npy',
        '2000-2500.npy',
        '2500-3000.npy',
        '3000-3500.npy',
        '3500-4000.npy',
        '4000-4500.npy',
    )
]

In [21]:
# Remaining shards; '4500-5000.npy' is not loaded.
data17, data18, data19 = (
    np.load(path, allow_pickle=True)
    for path in ('5000-5500.npy', '5500-6000.npy', '6000-6500.npy')
)

In [22]:
def _shard_to_numpy(shard):
    """Convert a loaded shard (a sequence of trajectories, each a sequence of
    torch tensors) into one numpy array. Assumes every trajectory has the same
    length and frame shape; otherwise numpy cannot build a regular array."""
    return np.array([np.array([frame.numpy() for frame in trajectory]) for trajectory in shard])


numpy10 = _shard_to_numpy(data10)
numpy11 = _shard_to_numpy(data11)
numpy12 = _shard_to_numpy(data12)
numpy13 = _shard_to_numpy(data13)
numpy14 = _shard_to_numpy(data14)
numpy15 = _shard_to_numpy(data15)
numpy17 = _shard_to_numpy(data17)
numpy18 = _shard_to_numpy(data18)
numpy19 = _shard_to_numpy(data19)

In [23]:
# All shards stacked along the trajectory axis (shard 16 is not loaded).
data_numpy = np.concatenate(
    [numpy10, numpy11, numpy12, numpy13, numpy14, numpy15, numpy17, numpy18, numpy19],
    axis=0,
)

In [24]:
# Copy the whole dataset into one CPU tensor (torch.tensor copies and keeps
# the numpy dtype).
data_tensor = torch.tensor(data_numpy, device=torch.device('cpu'))

In [25]:
# Each item is a 1-tuple (trajectory,), hence `pw[0]` in the training loops.
my_dataset = TensorDataset(data_tensor)

In [26]:
#@title Define the loss function (double click to expand or collapse)
def loss2(model, xs,eps=1e-5):
  """Implicit score-matching loss for a mini-batch of trajectories.

  For every trajectory ``x`` in ``xs`` one frame index ``idx`` is drawn
  uniformly from 0 .. len(x) - 1, and the model is evaluated at that frame
  with time ``idx / 100``. The per-trajectory loss is

      1/2 * |s|^2 + trace(ds/dx),   s = model(x[idx], idx / 100, L, Jac)

  where L and Jac are the cotangent projector and the constraint Jacobian at
  that frame. Bone-length constraints use the first frame x[0] as the
  reference configuration.

  Args:
    model: time-dependent score model called as model(x, t, L, G).
    xs: mini-batch of trajectories; ``x[idx]`` selects frames of one trajectory
      (presumably each x has shape (n_steps, 57) -- TODO confirm).
    eps: unused; kept for interface compatibility.

  Returns:
    Mean loss over the batch as a 1-element tensor, differentiable w.r.t. the
    model parameters.

  Raises:
    ValueError: if ``xs`` is empty.
  """
  if len(xs) == 0:
    raise ValueError("loss2 needs a non-empty batch")
  loss = torch.zeros(1)
  tr = 0
  s_tr = 0
  for x in xs:
    ref = torch.squeeze(x[0]).detach()
    gs = [length_constraint(i, j, ref) for (i, j) in bones2]
    gs.extend(length_constraint_2(joint, ref) for joint in (1, 11, 15))
    L_fn = cotangent_projection(gs)

    # Uniform frame index in [0, len(x) - 1]. The former round(rand*len - 1)
    # could give -1, which selected the last frame while passing the negative
    # time -0.01, and it gave both end points half weight.
    idx = torch.randint(len(x), (1,), device=x.device)
    t = idx.to(torch.get_default_dtype()) / 100
    sim_x = x[idx]

    L, Jac = L_fn(sim_x)  # projection matrix and constraint Jacobian
    score = model(sim_x, t, L, Jac).cpu()

    # A single Jacobian w.r.t. the input only (t, L and Jac held fixed),
    # built with create_graph so the trace term backpropagates to the model
    # parameters. It was previously computed three times, each time also
    # w.r.t. t, L and Jac.
    jac = torch.autograd.functional.jacobian(lambda y: model(y, t, L, Jac), sim_x, create_graph=True)
    trace = torch.trace(torch.squeeze(jac))

    tr = tr + torch.abs(trace.detach())
    s_tr = s_tr + trace.detach()
    loss = loss + 1/2 * torch.linalg.norm(score)**2 + trace
  if tr < 0.1:
    print(f"sum absolute trace = {tr}, sum trace = {s_tr}")
  return (loss/len(xs))

In [None]:
## Training hyper-parameters (Colab form fields)
lr=1e-3 #@param {'type':'number'}
batch_size =  20 #@param {'type':'integer'}
n_epochs = 20

dataloader = DataLoader(my_dataset, batch_size=batch_size, shuffle=True)
optimizer = Adam(score_model.parameters(), lr=lr)
tqdm_epoch = tqdm.notebook.trange(n_epochs)

i = 0               # global batch counter; the follow-up training cell keeps incrementing it
epoch_losses = []   # running mean loss within the current epoch, one entry per batch
for epoch in tqdm_epoch:
    t_dl = tqdm.tqdm(dataloader)
    avg_loss = 0.   # holds the running *sum* until divided by num_items
    num_items = 0
    for pw in t_dl:
        x = pw[0].cpu()
        i += 1
        loss = loss2(score_model, x)
        batch_loss = loss.item()
        t_dl.set_description(f"Loss = {batch_loss}")
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        n_batch = x.shape[0]
        avg_loss += batch_loss * n_batch
        num_items += n_batch
        epoch_losses.append(avg_loss / num_items)
    tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
    # One checkpoint file per epoch.
    torch.save(score_model.state_dict(), f"ckpt_improved_sig_3_epoch{epoch}.pth")

  0%|          | 0/20 [00:00<?, ?it/s]


  0%|                                                               | 0/250 [00:00<?, ?it/s][ANo CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'

Loss = 1.6969762248829485:   0%|                                    | 0/250 [00:06<?, ?it/s][A

sum absolute trace = 0.009552652679137325, sum trace = 0.009552652679137325



Loss = 1.6969762248829485:   0%|                            | 1/250 [00:06<27:55,  6.73s/it][A
Loss = 1.394326041584087:   0%|                             | 1/250 [00:10<27:55,  6.73s/it][A

sum absolute trace = 0.006118226521412266, sum trace = 0.0015278494348590493



Loss = 1.394326041584087:   1%|▏                            | 2/250 [00:11<22:08,  5.36s/it][A
Loss = 1.1292785262872609:   1%|▏                           | 2/250 [00:14<22:08,  5.36s/it][A

sum absolute trace = 0.00647975985596769, sum trace = -0.005668370447452744



Loss = 1.1292785262872609:   1%|▎                           | 3/250 [00:15<19:50,  4.82s/it][A
Loss = 0.9065022098771479:   1%|▎                           | 3/250 [00:19<19:50,  4.82s/it][A

sum absolute trace = 0.015248653657364135, sum trace = -0.015248653657364135



Loss = 0.9065022098771479:   2%|▍                           | 4/250 [00:19<18:56,  4.62s/it][A
Loss = 0.699523239462926:   2%|▍                            | 4/250 [00:23<18:56,  4.62s/it][A

sum absolute trace = 0.02220457199542365, sum trace = -0.02220457199542365



Loss = 0.699523239462926:   2%|▌                            | 5/250 [00:24<18:33,  4.55s/it][A
Loss = 0.5481671064933235:   2%|▌                           | 5/250 [00:27<18:33,  4.55s/it][A

sum absolute trace = 0.03216534878181921, sum trace = -0.03216534878181921



Loss = 0.5481671064933235:   2%|▋                           | 6/250 [00:28<18:13,  4.48s/it][A
Loss = 0.41695359937454973:   2%|▋                          | 6/250 [00:32<18:13,  4.48s/it][A

sum absolute trace = 0.04170637920381852, sum trace = -0.04170637920381852



Loss = 0.41695359937454973:   3%|▊                          | 7/250 [00:32<18:12,  4.49s/it][A
Loss = 0.30027962225451643:   3%|▊                          | 7/250 [00:37<18:12,  4.49s/it][A

sum absolute trace = 0.05167847241761557, sum trace = -0.05167847241761557



Loss = 0.30027962225451643:   3%|▊                          | 8/250 [00:37<18:27,  4.58s/it][A
Loss = 0.2546021662456566:   3%|▉                           | 8/250 [00:42<18:27,  4.58s/it][A

sum absolute trace = 0.06178752237999361, sum trace = -0.06178752237999361



Loss = 0.2546021662456566:   4%|█                           | 9/250 [00:42<19:13,  4.79s/it][A
Loss = 0.17997792657218242:   4%|▉                          | 9/250 [00:46<19:13,  4.79s/it][A

sum absolute trace = 0.07170972806512321, sum trace = -0.07170972806512321



Loss = 0.17997792657218242:   4%|█                         | 10/250 [00:47<18:47,  4.70s/it][A
Loss = 0.14115820664117287:   4%|█                         | 10/250 [00:51<18:47,  4.70s/it][A

sum absolute trace = 0.07987548657225661, sum trace = -0.07987548657225661



Loss = 0.14115820664117287:   4%|█▏                        | 11/250 [00:52<18:59,  4.77s/it][A
Loss = 0.10457517201865855:   4%|█▏                        | 11/250 [00:56<18:59,  4.77s/it][A

sum absolute trace = 0.08798737943103781, sum trace = -0.08798737943103781



Loss = 0.10457517201865855:   5%|█▏                        | 12/250 [00:56<18:38,  4.70s/it][A
Loss = 0.0789442387024926:   5%|█▎                         | 12/250 [01:00<18:38,  4.70s/it][A

sum absolute trace = 0.09675604606442412, sum trace = -0.09675604606442412



Loss = 0.0789442387024926:   5%|█▍                         | 13/250 [01:01<18:18,  4.64s/it][A
Loss = 0.06390580791956343:   5%|█▎                        | 13/250 [01:05<18:18,  4.64s/it][A
Loss = 0.06390580791956343:   6%|█▍                        | 14/250 [01:05<18:12,  4.63s/it][A
Loss = 0.05218523466273979:   6%|█▍                        | 14/250 [01:09<18:12,  4.63s/it][A
Loss = 0.05218523466273979:   6%|█▌                        | 15/250 [01:10<17:57,  4.58s/it][A
Loss = 0.04199963743322947:   6%|█▌                        | 15/250 [01:14<17:57,  4.58s/it][A
Loss = 0.04199963743322947:   6%|█▋                        | 16/250 [01:14<17:33,  4.50s/it][A
Loss = 0.03537252289177329:   6%|█▋                        | 16/250 [01:18<17:33,  4.50s/it][A
Loss = 0.03537252289177329:   7%|█▊                        | 17/250 [01:19<17:11,  4.43s/it][A
Loss = 0.031882958305119974:   7%|█▋                       | 17/250 [01:22<17:11,  4.43s/it][A
Loss = 0.031882958305119974:   7%|█▊   

Loss = -0.0607418988584346:  22%|█████▋                    | 55/250 [04:20<15:31,  4.78s/it][A
Loss = -0.0607418988584346:  22%|█████▊                    | 56/250 [04:21<15:18,  4.74s/it][A
Loss = -0.06537108723575481:  22%|█████▌                   | 56/250 [04:25<15:18,  4.74s/it][A
Loss = -0.06537108723575481:  23%|█████▋                   | 57/250 [04:25<14:55,  4.64s/it][A
Loss = -0.0695162840898216:  23%|█████▉                    | 57/250 [04:30<14:55,  4.64s/it][A
Loss = -0.0695162840898216:  23%|██████                    | 58/250 [04:30<15:11,  4.75s/it][A
Loss = -0.07444151367535515:  23%|█████▊                   | 58/250 [04:34<15:11,  4.75s/it][A
Loss = -0.07444151367535515:  24%|█████▉                   | 59/250 [04:35<15:02,  4.73s/it][A
Loss = -0.07944195691124455:  24%|█████▉                   | 59/250 [04:39<15:02,  4.73s/it][A
Loss = -0.07944195691124455:  24%|██████                   | 60/250 [04:40<14:59,  4.73s/it][A
Loss = -0.08595791008048563:  24%|██████

In [None]:
# epoch_losses holds, per batch, the running mean loss of the current epoch
# (it restarts at every epoch), indexed by global batch number.
fig, ax = plt.subplots()
ax.plot(epoch_losses)
ax.set(title="Training loss (running mean within each epoch)",
       xlabel="batch (global index)",
       ylabel="mean loss so far in epoch");

In [None]:
## Second training run: continues from the current score_model weights with
## a fresh Adam optimizer; epoch_losses and the counter i keep accumulating.
# NOTE(review): lr=1e-45 makes Adam's steps (about lr in size) far smaller
# than float64 resolution for typical weights, so this run leaves the
# parameters unchanged. This is probably a typo (1e-4 / 1e-5?); confirm.
lr=1e-45 #@param {'type':'number'}
batch_size =  20 #@param {'type':'integer'}
n_epochs = 20

dataloader = DataLoader(my_dataset, batch_size=batch_size, shuffle=True)
optimizer = Adam(score_model.parameters(), lr=lr)
tqdm_epoch = tqdm.notebook.trange(n_epochs)

for epoch in tqdm_epoch:
    t_dl = tqdm.tqdm(dataloader)
    avg_loss = 0.   # running sum until divided by num_items
    num_items = 0
    for pw in t_dl:
        x = pw[0].cpu()
        i += 1
        loss = loss2(score_model, x)
        batch_loss = loss.item()
        t_dl.set_description(f"Loss = {batch_loss}")
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        n_batch = x.shape[0]
        avg_loss += batch_loss * n_batch
        num_items += n_batch
        epoch_losses.append(avg_loss / num_items)
    tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
    # The same file is overwritten every epoch.
    torch.save(score_model.state_dict(), 'ckpt_improved_sig_3.pth')

In [None]:
# Stand-alone (not DataParallel-wrapped) instance; no later cell shown here uses it.
model = ScoreNet(embed_dim=58)

In [None]:
# Scratch input: a random 57-vector with gradient tracking enabled.
x = torch.rand(57).requires_grad_()
x

# Generating a random initial pose (uniformly random bone directions)

In [None]:
# Plain numeric array (pickle loading stays disabled, numpy's default).
# NOTE(review): the same file is loaded again further down, together with x_init.
data = np.load('train_dataset.npy', allow_pickle=False)

In [None]:
# Parent -> child bones of the 19-joint skeleton rooted at joint 0.
# Each path (a, b, c, ...) is a kinematic chain contributing the edges
# (a, b), (b, c), ...; the order guarantees every parent is listed before
# its children, which angles_to_joints relies on.
_skeleton_paths = [
    (0, 1, 2),
    (1, 3, 4, 5, 6),
    (1, 7, 8, 9, 10),
    (0, 11, 12, 13, 14),
    (0, 15, 16, 17, 18),
]
bones = [edge for path in _skeleton_paths for edge in zip(path, path[1:])]

In [None]:
def angles_to_joints(angles, lengths,bones, num_joints=19, root=(0, 0, 2)):
    '''
    Build joint positions from per-bone spherical angles and bone lengths.

    Bones are processed in order. Bone n = (a, b) places joint b at
        pose[a] + r_n * (sin(theta) cos(phi), sin(theta) sin(phi), cos(theta)),
    so every parent joint must be placed before its children.

    :param angles: per-bone (phi, theta) pairs (tensor or nested sequence)
    :param lengths: per-bone lengths, aligned with ``bones``
    :param bones: list of (parent, child) joint indices
    :param num_joints: rows of the returned pose (default 19, as before)
    :param root: position of joint 0 (default (0, 0, 2), as before)
    :return: tensor of shape (num_joints, 3); joints not reached stay at 0
    :raises IndexError: if ``angles`` or ``lengths`` are shorter than ``bones``
    '''
    pose = torch.zeros((num_joints, 3))
    pose[0] = torch.as_tensor(root, dtype=pose.dtype)
    for n, bone in enumerate(bones):
        r = lengths[n]
        # as_tensor avoids the copy (and warning) of torch.tensor(<tensor>).
        phi = torch.as_tensor(angles[n][0])
        theta = torch.as_tensor(angles[n][1])
        # stack keeps the autograd graph, which torch.tensor([...]) dropped.
        offset = torch.stack([r*torch.sin(theta)*torch.cos(phi),
                              r*torch.sin(theta)*torch.sin(phi),
                              r*torch.cos(theta)])
        pose[bone[1]] = pose[bone[0]] + offset
    return pose

In [None]:
def lengths(bones,poses):
    '''
    Euclidean length of every bone in a single pose.

    :param bones: list of (a, b) joint index pairs
    :param poses: one pose, squeezed and reshaped internally to (19, 3)
    :return: list of 0-d tensors, one per bone, in ``bones`` order
    '''
    joints = torch.squeeze(poses).reshape((19, 3))

    def _bone_length(a, b):
        diff = joints[a] - joints[b]
        return torch.sqrt(torch.dot(diff, diff))

    return [_bone_length(bone[0], bone[1]) for bone in bones]

In [None]:
def uniform_generator(x_init,bones):
    '''
    Random pose with the bone lengths of ``x_init`` and independent, uniformly
    distributed bone directions.

    One (phi, theta) pair is drawn per bone (previously hard-coded to 18
    bones): phi ~ U[0, 2*pi) and theta = acos(U[-1, 1]), which is uniform on
    the unit sphere.

    :param x_init: reference pose (57 coordinates) supplying the bone lengths
    :param bones: list of (parent, child) joint index pairs, parents first
    :return: (19, 3) pose tensor from angles_to_joints
    '''
    n_bones = len(bones)
    random_angles = torch.vstack([torch.rand(n_bones) * torch.pi * 2,
                                  torch.acos(2 * torch.rand(n_bones) - 1)]).T
    lengths1 = lengths(bones, x_init)
    return angles_to_joints(random_angles, lengths1, bones)

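Since the forward SDE was plain Brownian motion, $dX_t = dW_t$, its time reversal $Y_s = X_{T-s}$ is driven by the score alone:

$$dY_s = \nabla_y \log p_{T-s}(Y_s)\,ds + d\bar W_s.$$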

In [None]:
# Reference pose: the first training sample provides the bone lengths used
# for the random initial pose below.
data = np.load('train_dataset.npy')
x_init = torch.tensor(data[0]).squeeze()

In [None]:
# `ckpt` must be defined in this cell for it to run on a fresh kernel.
# weights_only=True restricts unpickling to tensors/containers (a plain state_dict).
# NOTE(review): confirm the intended file; the training cells above write
# ckpt_improved_sig_3*.pth, not ckpt.pth.
ckpt = torch.load('ckpt.pth', map_location='cpu', weights_only=True)
score_model.load_state_dict(ckpt)

In [None]:
# Random starting pose with the reference bone lengths of x_init.
q_init = uniform_generator(x_init=torch.squeeze(x_init), bones=bones)

In [None]:
# If torch.compile fails to compile a function, fall back to eager execution
# instead of raising (applies to the compiled helpers used below).
torch._dynamo.config.suppress_errors = True

In [None]:
def run_simulation(q_init,steps, verbose=True):
  '''
  Integrate the learned reverse dynamics from ``q_init`` with constrained g-BAOAB.

  Uses h = 1/steps, and the model time runs 1 - h, 1 - 2h, ..., h (steps - 1
  steps). Each step calls gBAOAB_step_exact with gamma = k = kr = 1 and
  tolerance 1e-11, using the score model (at the current time) as the force.
  Bone-length constraints take ``q_init`` as the reference configuration.

  Not wrapped in torch.compile: the body is a long Python loop with printing
  and a nested closure, which would only fragment into graph breaks.
  gBAOAB_step_exact is compiled on its own.

  :param q_init: initial pose; flattened to a 1-D coordinate vector
  :param steps: number of time steps (h = 1 / steps)
  :param verbose: print the step index and the displacement of q per step
  :return: list of positions, one per step (initial position excluded)
  '''
  h = 1/steps
  positions = []
  q = torch.flatten(torch.squeeze(q_init))
  gs = [length_constraint(i,j,torch.squeeze(q)) for (i,j) in bones2]
  gs.extend(length_constraint_2(joint, torch.squeeze(q)) for joint in (1, 11, 15))
  p = torch.zeros_like(q)
  M = torch.eye(q.size()[0])
  L_fn = cotangent_projection(gs)
  current_time = torch.tensor([1.0])  # rebound every step; read by force() at call time

  def force(y):
    # gBAOAB_step_exact applies p <- p - h/2 * L @ force(q), so returning
    # -score pushes the momentum along +score.
    with torch.no_grad():
      x = torch.unsqueeze(y,0)
      # L_fn returns (projector, Jacobian); the model needs both. Previously
      # the tuple was passed as L and G was omitted, so forward() raised a TypeError.
      L, Jac = L_fn(x)
      score = -score_model(x, current_time, L, Jac)
    return torch.squeeze(score)

  for step in range(1,steps):
      if verbose:
        print(step)
      current_time = torch.tensor([1-(step)*h])
      q_prev = q
      q, p = gBAOAB_step_exact(q,p,force, gs, h,M, 1,1,1,10**(-11))
      if verbose:
        print(f"Movement of q: {torch.linalg.norm(q-q_prev)}")
      positions.append(q)
  return positions

In [None]:
# 10000 steps, i.e. h = 1e-4.
positions = run_simulation(q_init, steps=10000)

In [None]:
# Stack the trajectory (one flat coordinate vector per step) into a 2-D array and save it.
pos = np.array([step_q.detach().numpy() for step_q in positions])
np.save('reverse_pos_new.npy', pos)