In [1]:
import torch
import torch.nn as nn
import numpy as np
from torch.autograd.functional import jacobian
from tqdm import notebook, tqdm
import random
import torch.multiprocessing
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader
import warnings

In [2]:
torch.get_num_threads()

40

In [3]:
warnings.filterwarnings("ignore")

In [4]:
torch.set_default_dtype(torch.float64)
torch.set_default_device('cpu')

In [5]:
@torch.compile(mode = "max-autotune")
def G(gs):
    '''
    :param gs: a list of tensor functions
    :return: a function sending a tensor to the stacked matrix of the functions of that tensor
    '''
    def G_gs(tensor):
        x = torch.squeeze(tensor)
        # print("Function input: ",tensor) # checking the input for debugging
        # print("Function output:" , torch.stack([g(tensor) for g in gs],0))
        return torch.stack([g(x) for g in gs], 0)

    return G_gs

@torch.compile(mode = "max-autotune")
def J(gs, x):
    '''Returns the Jacobian evaluated at x for a list gs of constraint functions'''
    return jacobian(G(gs), torch.squeeze(x))

In [6]:
@torch.compile(mode = "max-autotune")
def rattle_step(x, v1, h, M, gs, e):
    '''
    Defining a function to take a step in the position, velocity form.
    g should be a vector-valued function of constraints.
    :return: x_1, v_1
    '''

    M1 =  torch.inverse(M)

    G1 = G(gs)


    DV = torch.zeros_like(x)

    #DV[-1] = 10  # leaving this out for g-BAOAB
    DV_col = DV.reshape(-1, 1)

    x_col = x.reshape(-1, 1)
    v1_col = v1.reshape(-1, 1)

    # doing Newton-Raphson iterations
    iters = 0
    x2 = x_col + h * v1_col - 0.5*(h**2)* M1 @ DV_col
    Q_col = x2
    Q = torch.squeeze(Q_col)
    J1 = J(gs, torch.squeeze(x_col))

    #print("RATTLE")
    #while torch.any(torch.abs(G1(Q)) > e):
    for iters in range(3):
        J2 = J(gs, torch.squeeze(Q))
        R = J2 @ M1 @ J1.t()
        dL = torch.inverse(R) @ G1(Q)
        #print(f"Q = {Q}")
        Q=Q- M1 @ J1.t() @ dL
    #print("CONVERGED")
    #print(f"Updating v1_col, Jacobian {J(gs,torch.squeeze(x_col))}")
    #print(f"Updating v1_col, Jacobian^T {J(gs,torch.squeeze(x_col)).t()}")

    # half step for velocity
    Q_col = Q.reshape(-1,1)
    v1_half = (Q_col - x_col)/h
    x_col = Q_col
    J1 = J(gs, torch.squeeze(x_col))

    # getting the level
    J2 = J(gs, torch.squeeze(Q))
    P = J1 @ M1 @ J1.t()
    T = J1 @ (2/h * v1_half - M1 @ DV_col)

    #solving the linear system
    L = torch.linalg.solve(P,T)

    v1_col = v1_half - h/2 * DV_col - h/2 * J2.t()@L


    # print(f"Error = {G1(x_col + h*( v1_col + h/2 * torch.inverse(M) @ J1.reshape(-1,1) @ lam))}")
    # # updating v
    # print(f"lam = {lam}")
    # print(f"Updating v1_col, Jacobian^T {J(gs,torch.squeeze(x_col)).t}")

    return torch.squeeze(x_col), torch.squeeze(v1_col)

In [7]:
@torch.compile(mode = "max-autotune")
def gBAOAB_step_exact(q_init,p_init,F, gs, h,M, gamma, k, kr,e):
    # setting up variables
    M1 = torch.inverse(M)
    R = torch.randn(len(q_init))
    p = p_init
    q = q_init
    a2 = torch.exp(torch.tensor(-gamma*h))
    b2 = torch.sqrt(k*(1-a2**(2)))

    # doing the initial p-update
    J1 = J(gs,torch.squeeze(q))
    G = J1
    L1 = torch.eye(len(q_init)) - torch.transpose(G,0,1) @ torch.inverse(G@ M1@ torch.transpose(G,0,1)) @ G @ M1
    p =p-  h/2 * L1 @ F(q)


    # doing the first RATTLE step
    for i in range(kr):
      q, p = rattle_step(q, p, h/2*kr, M, gs, e)


    # the second p-update - (O-step in BAOAB)
    J2 = J(gs,torch.squeeze(q))
    G = J2
    L2 = torch.eye(len(q_init)) - torch.transpose(G,0,1) @ torch.inverse(G@ M1@ torch.transpose(G,0,1)) @ G @ M1
    p = a2* p + b2* M**(1/2) @L2 @ M**(1/2) @ R

    # doing the second RATTLE step
    for i in range(kr):
      q, p = rattle_step(q, p, h/2*kr, M, gs, e)


    # the final p update
    J3= J(gs,torch.squeeze(q))
    G = J3
    L3 = torch.eye(len(q_init)) - torch.transpose(G,0,1) @ torch.inverse(G@ M1@ torch.transpose(G,0,1)) @ G @ M1
    p = p-  h/2 * L3 @ F(q)

    return q,p

In [8]:
@torch.compile(mode = "max-autotune")
def gBAOAB_integrator(q_init,p_init,F, gs, h,M, gamma, k, steps,kr,e):
    positions = []
    velocities = []
    q = q_init
    p = p_init
    for i in range(steps):
        q, p = gBAOAB_step_exact(q,p, F,gs, h,M, gamma, k,kr,e)
        positions.append(q)
        velocities.append(p)

    return positions, velocities

In [9]:
@torch.compile(mode = "max-autotune")
def multi_gBAOAB_integrator(q_inits,p_inits,F, gs, h,M, gamma, k,ts,kr,e):
    positions = []
    velocities = []
    for ind in len(q_inits):
      q = q_inits[i]
      p = p_inits[i]
      steps = int(ts[i]//h)
      for i in range(steps):
        q, p = gBAOAB_step_exact(q,p, F,gs, h,M, gamma, k,kr,e)
      positions.append(q)
      velocities.append(p)

    return torch.stack(positions), torch.stack(velocities)

In [10]:
bones2 = [
    (1, 2),
    (1, 3),
    (3, 4),
    (4, 5),
    (5, 6),
    (1, 7),
    (7, 8),
    (8, 9),
    (9, 10),
    (11, 12),
    (12, 13),
    (13, 14),
    (15, 16),
    (16, 17),
    (17, 18),
]

In [11]:
@torch.compile(mode = "max-autotune")
def length_constraint(i,j, xinit):
    init = torch.squeeze(xinit)
    def constraint_fn(y):
        x = torch.squeeze(y) # will need to change for batched data
        return (x[3*i]- x[3*j])**2 +(x[3*i+1]- x[3*j+1])**2 + (x[3*i+2]- x[3*j+2])**2 - ((init[3*i]- init[3*j])**2 +(init[3*i+1]- init[3*j+1])**2 + (init[3*i+2]- init[3*j+2])**2)
    return constraint_fn

In [12]:
@torch.compile(mode = "max-autotune")
def length_constraint_2(i, xinit):
    init = torch.squeeze(xinit)
    def constraint_fn(y):
        x = torch.squeeze(y)
        return (x[3*i]- 0)**2 +(x[3*i+1]- 0)**2 + (x[3*i+2]- 2)**2 - ((init[3*i]- 0)**2 +(init[3*i+1]- 0)**2 + (init[3*i+2]- 2)**2)
    return constraint_fn

In [13]:
@torch.compile(mode = "max-autotune")
def cotangent_projection(gs):
    def proj(x):
        G = J(gs,x)
        M = torch.eye(G.size()[1])
        L= torch.eye(G.size()[1]) - G.T @ torch.inverse(G @ M @ G.T) @ G @ torch.inverse(M)
        return L
    return proj

In [14]:
class GaussianFourierProjection(nn.Module):
  """Gaussian random features for encoding time steps."""
  def __init__(self, embed_dim, scale=30.):
    super().__init__()
    # Randomly sample weights during initialization. These weights are fixed
    # during optimization and are not trainable.
    self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)

  def forward(self, x):
    x_proj = x[:, None] * self.W[None, :] * 2 * torch.pi
    return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)

In [15]:
class ScoreNet(nn.Module):
  """A time-dependent score-based model."""


  def __init__(self, embed_dim):
    super().__init__()
    self.embed = nn.Sequential(GaussianFourierProjection(embed_dim=embed_dim),nn.Linear(embed_dim, embed_dim))
    self.lin_embed = nn.Linear(embed_dim,57)
    self.lin1 = nn.Linear(57,57)
    self.lin2 = nn.Linear(57, 57)
    self.lin3 = nn.Linear(57, 57)
    self.lin4 = nn.Linear(57, 57)
    self.lin5 = nn.Linear(57,57)
    self.act = lambda x : torch.sigmoid(x)
    nn.init.normal_(self.lin_embed.weight, mean=0, std=0.2)
    nn.init.normal_(self.lin_embed.bias, mean=0, std=0.1)
    nn.init.normal_(self.lin1.weight, mean=0, std=0.2)
    nn.init.normal_(self.lin1.bias, mean=0, std=0.1)
    nn.init.normal_(self.lin2.weight, mean=0, std=0.2)
    nn.init.normal_(self.lin2.bias, mean=0, std=0.1)
    nn.init.normal_(self.lin3.weight, mean=0, std=0.2)
    nn.init.normal_(self.lin3.bias, mean=0, std=0.1)
    nn.init.normal_(self.lin4.weight, mean=0, std=0.2)
    nn.init.normal_(self.lin4.bias, mean=0, std=0.1)
    nn.init.normal_(self.lin5.weight, mean=0, std=0.2)
    nn.init.normal_(self.lin5.bias, mean=0, std=0.1)
  def forward(self,x,t,L):
      # setting the fixed points of x
      l = torch.zeros_like(x)
      l[:,0] = x[:,0]
      l[:,1] = x[:,1]
      l[:,2] = -torch.ones_like(x[:,2])*2 + x[:,2]

      x = x - l
      embed = self.act(self.embed(t))
      h = self.lin1(x)
      h = h+ self.lin_embed(embed)
      h = self.act(self.lin2(h))
      #h = self.act(self.lin3(h))
      #h = self.act(self.lin4(h))
      #h = self.act(self.lin5(h))

      # projection
      p = torch.unsqueeze(L@ torch.squeeze(h),0)
      h = p
      # NOT normalizing the output
      #h = h/ t[:,None]

      # setting the force on the fixed point to zero
      l2 = torch.zeros_like(h)
      l2[:,0] = h[:,0]
      l2[:,1] = h[:,1]
      l2[:,2] = -torch.ones_like(h[:,2])*2 + h[:,2]
      h = h - l2
      return torch.squeeze(h)

In [16]:
score_model = torch.nn.DataParallel(ScoreNet(58))

## Sampling the reverse SDE

In [17]:
data1 = np.load('train_dataset.npy',allow_pickle=True)
numpy_data = np.squeeze(data1)
torch_data = torch.tensor(numpy_data)

In [18]:
torch._dynamo.config.suppress_errors = True

In [19]:
@torch.compile(mode = "max-autotune")
def force(x):
  return(torch.zeros_like(x))

In [20]:
@torch.compile(mode = "max-autotune")
def get_data(data, i, j, save_at, already_in_there=np.array([])):
    if len(already_in_there) ==0:
        positions =[]
    else:
        positions = np.ndarray.tolist(already_in_there)
    h = 0.01
    for x in tqdm(data[i+ len(already_in_there):j]):
      q = torch.squeeze(torch.flatten(x))
      p = torch.zeros_like(q)
      M = torch.eye(len(q))
      gs = [length_constraint(i,j,torch.squeeze(q)) for (i,j) in bones2]
      gs.append(length_constraint_2(1,torch.squeeze(q)))
      gs.append(length_constraint_2(11,torch.squeeze(q)))
      gs.append(length_constraint_2(15,torch.squeeze(q)))
      qs,_ = gBAOAB_integrator(q,p,force, gs, h,M, 1, 1, 1000,1,10**(-13))
      ts = []
      for qold in qs:
        qnew = qold
        qnew[0] = 0
        qnew[1] = 0
        qnew[2] = 2
        ts.append(qnew)
      positions.append(ts)
      pos_array= np.array(positions)
      np.save(save_at, pos_array)
    return ts

In [21]:
already = np.load('t10_1000-2000.npy', allow_pickle = True)

In [None]:
qs = get_data(torch_data,1000,2000,'t10_1000-2000.npy', already)

  0%|          | 0/745 [00:00<?, ?it/s][2023-08-24 20:47:43,588] torch._dynamo.convert_frame: [ERROR] WON'T CONVERT <graph break in get_data> /tmp/ipykernel_1680046/3124810454.py line 8 
due to: 
Traceback (most recent call last):
  File "/home/tkelly/anaconda3/lib/python3.9/site-packages/tqdm/utils.py", line 75, in __eq__
    return self._comparable == other._comparable
AttributeError: 'function' object has no attribute '_comparable'

Set torch._dynamo.config.verbose=True for more information


   function: 'G_gs' (/tmp/ipykernel_1680046/2395838894.py:7)
   reasons:  ___check_obj_id(gs[0], 140377215852112)
to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.
   function: 'gBAOAB_step_exact' (/tmp/ipykernel_1680046/3710533717.py:1)
   reasons:  ___check_obj_id(gs[0], 140377215852112)
to diagnose recompilation issues, see https://pytorch.org/docs/master/dynamo/troubleshooting.html.
   function: 'J' (/tmp/ipykernel_1680046/2395838894.py:15)
 

In [None]:
n = np.load('t10_0-1000.npy', allow_pickle= True)