In [146]:
import torch
import torch.nn as nn
import numpy as np
from torch.autograd.functional import jacobian
import tqdm
from tqdm import notebook
import random
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')

In [147]:
torch.set_default_dtype(torch.float64)
torch.set_default_device('cuda')

In [148]:
#@torch.compile(mode = "max-autotune")
def G(gs):
    '''
    :param gs: a list of tensor functions
    :return: a function sending a tensor to the stacked matrix of the functions of that tensor
    '''
    def G_gs(tensor):
        x = torch.squeeze(tensor)
        # print("Function input: ",tensor) # checking the input for debugging
        # print("Function output:" , torch.stack([g(tensor) for g in gs],0))
        return torch.stack([g(x) for g in gs], 0)

    return G_gs

#@torch.compile(mode = "max-autotune")
def J(gs, x):
    '''Returns the Jacobian evaluated at x for a list gs of constraint functions'''
    return jacobian(G(gs), torch.squeeze(x))

In [149]:
bones2 = [
    (1, 2),
    (1, 3),
    (3, 4),
    (4, 5),
    (5, 6),
    (1, 7),
    (7, 8),
    (8, 9),
    (9, 10),
    (11, 12),
    (12, 13),
    (13, 14),
    (15, 16),
    (16, 17),
    (17, 18),
]

In [150]:
@torch.compile(mode = "max-autotune")
def length_constraint(i,j, xinit):
    init = torch.squeeze(xinit)
    def constraint_fn(y):
        x = torch.squeeze(y) # will need to change for batched data
        return (x[3*i]- x[3*j])**2 +(x[3*i+1]- x[3*j+1])**2 + (x[3*i+2]- x[3*j+2])**2 - ((init[3*i]- init[3*j])**2 +(init[3*i+1]- init[3*j+1])**2 + (init[3*i+2]- init[3*j+2])**2)
    return constraint_fn

In [151]:
@torch.compile(mode = "max-autotune")
def length_constraint_2(i, xinit):
    init = torch.squeeze(xinit)
    def constraint_fn(y):
        x = torch.squeeze(y)
        return (x[3*i]- 0)**2 +(x[3*i+1]- 0)**2 + (x[3*i+2]- 2)**2 - ((init[3*i]- 0)**2 +(init[3*i+1]- 0)**2 + (init[3*i+2]- 2)**2)
    return constraint_fn

In [152]:
@torch.compile(mode = "max-autotune")
def cotangent_projection(gs):
    def proj(x):
        G = J(gs,x)
        M = torch.eye(G.size()[1])
        L= torch.eye(G.size()[1]) - G.T @ torch.inverse(G @ M @ G.T) @ G @ torch.inverse(M)
        return L
    return proj

In [153]:
class ScoreNet(nn.Module):
  """A time-dependent score-based model."""


  def __init__(self, embed_dim):
    super().__init__()
    self.lin1 = nn.Linear(57,57)
    self.lin2 = nn.Linear(57, 57)
    self.lin3 = nn.Linear(57, 57)
    self.lin4 = nn.Linear(57, 57)
    self.lin5 = nn.Linear(57,57)
    self.act = torch.nn.Sigmoid()

  #@torch.compile(mode="default")
  def forward(self,x,L):
      # setting the fixed points of x
      l = torch.zeros_like(x)
      l[:,0] = x[:,0]
      l[:,1] = x[:,1]
      l[:,2] = -torch.ones_like(x[:,2])*2 + x[:,2]

      x = x - l
      h = self.lin1(x)
      h = self.act(self.lin2(h))
      h = self.act(self.lin3(h))
      h = self.lin4(h)
      #h = self.act(self.lin5(h))

      # projection
      p = torch.unsqueeze(L@ torch.squeeze(h),0)
      h = p
      # NOT normalizing the output
      #h = h/ t[:,None]

      # setting the force on the fixed point to zero
      l2 = torch.zeros_like(h)
      l2[:,0] = h[:,0]
      l2[:,1] = h[:,1]
      l2[:,2] = h[:,2]
      h = h - l2
      return torch.squeeze(h)

In [154]:
score_model = torch.nn.DataParallel(ScoreNet(58))

In [155]:
from torch.utils.data import TensorDataset, DataLoader

In [156]:
from torch.optim import Adam

In [157]:
import warnings
warnings.filterwarnings("ignore")


In [158]:
#torch._dynamo.config.verbose=True

In [159]:
#torch._dynamo.config.suppress_errors = True

# Sampling from the reverse SDE

In [160]:
data_numpy = np.load('train_dataset.npy',allow_pickle=True)

In [161]:
data_tensor = torch.tensor(data_numpy, device = 'cpu')

In [162]:
data_tensor[0].size()

torch.Size([1, 19, 3])

In [163]:
data_tensor = torch.stack([torch.flatten(tens) for tens in data_tensor])

In [164]:
data_tensor[0]

tensor([ 0.0000,  0.0000,  2.0000, -0.0415,  0.4913,  1.9698, -0.0302,  0.6276,
         1.9810,  0.0772,  0.3837,  1.7854,  0.1868,  0.0496,  1.7402,  0.3057,
        -0.2179,  1.7439,  0.3301, -0.3405,  1.7480, -0.1814,  0.4888,  2.0441,
        -0.1794,  0.6404,  2.3646, -0.1866,  0.8687,  2.1815, -0.2104,  0.9399,
         2.0815,  0.1104, -0.2030,  1.9271,  0.1536, -0.7239,  1.8919,  0.1890,
        -1.1902,  1.8484,  0.2527, -1.2213,  2.0356, -0.0817, -0.2263,  2.0288,
        -0.0881, -0.7402,  1.9276, -0.0922, -1.2053,  1.8624, -0.1571, -1.2467,
         2.0472])

In [165]:
my_dataset = TensorDataset(data_tensor)

In [166]:
bones = [
    (0, 1),
    (1, 2),
    (1, 3),
    (3, 4),
    (4, 5),
    (5, 6),
    (1, 7),
    (7, 8),
    (8, 9),
    (9, 10),
    (0, 11),
    (11, 12),
    (12, 13),
    (13, 14),
    (0, 15),
    (15, 16),
    (16, 17),
    (17, 18),
]

In [167]:
def angles_to_joints(angles, lengths,bones):
    pose = torch.zeros((19,3))
    pose[0] = torch.tensor([0,0,2])
    for i in range(len(bones)):
        bone = bones[i]
        r = lengths[i]
        phi = torch.tensor(angles[i][0])
        theta = torch.tensor(angles[i][1])
        pose[bone[1]] = pose[bone[0]] + torch.tensor([r*torch.sin(theta)*torch.cos(phi),r*torch.sin(theta)*torch.sin(phi),r*torch.cos(theta)])
    return pose

In [168]:
def lengths(bones,poses):
    lengths = []
    poses = torch.squeeze(poses).reshape((19,3))
    for bone in bones:
        l = poses[bone[0]] -poses[bone[1]]
        l = torch.sqrt(torch.dot(l,l))
        lengths.append(l)
    return lengths

In [169]:
def uniform_generator(x_init,bones):
    random_angles = torch.vstack([torch.rand(18)*torch.pi*2,torch.acos(2*torch.rand(18) -1)]).T
    lengths1 = lengths(bones, x_init)
    return angles_to_joints(random_angles,lengths1,bones)

In [172]:
#@title Define the loss function (double click to expand or collapse)
def loss2(model, xs,eps=torch.tensor([1e-9]), lam =1):
    """The loss function for moser flows
    """
    # xs is a batch (batch_size, 57) of position
    loss = torch.tensor([0.])
    loss.requires_grad_()

    # first the loss for the first part
    for pose in xs:
        # x is the list of simulated qs
        q = torch.squeeze(x[0]).detach() # x is unsqueezed I can't remember why
        gs = [length_constraint(i,j,torch.squeeze(q)) for (i,j) in bones2]
        gs.append(length_constraint_2(1,torch.squeeze(q)))
        gs.append(length_constraint_2(11,torch.squeeze(q)))
        gs.append(length_constraint_2(15,torch.squeeze(q)))
        L_fn = cotangent_projection(gs)
        sim_x = torch.unsqueeze(pose,0)
        L = L_fn(sim_x)
        divergence = torch.trace(torch.squeeze(torch.autograd.functional.jacobian(model, inputs=(sim_x,L), create_graph=True,strict=True)[0]))
        exp = max(eps,(torch.pi*4)**(-18) - divergence)
        loss = loss + torch.log(exp)/len(xs)


    # and now the second part??
    for i in range(30):
        random_pose = torch.flatten(uniform_generator(q,bones))
        sim_y = torch.unsqueeze(random_pose, 0)
        L = L_fn(sim_y)
        divergence = torch.trace(torch.squeeze(torch.autograd.functional.jacobian(model, inputs=(sim_y,L), create_graph=True,strict=True)[0]))
        l = eps - min(eps, (torch.pi*4)**(-18) - divergence)
        loss = loss + lam/30 * l
    return loss

In [None]:
## size of a mini-batch
## learning rate
lr=1e-5 #@param {'type':'number'}
batch_size =  20 #@param {'type':'integer'}
## learning rate
dataloader = DataLoader(my_dataset,batch_size=batch_size, shuffle=True,generator=torch.Generator(device='cuda'))
n_epochs = 20
tqdm_epoch = tqdm.notebook.trange(n_epochs)
optimizer = Adam(score_model.parameters(), lr=lr)
i = 0
epoch_losses =[]
for epoch in tqdm_epoch:
    t_dl =tqdm.tqdm(dataloader)
    avg_loss = 0.
    num_items = 0
    for pw in t_dl:
        x = pw[0].to('cuda')
        i += 1
        loss = loss2(score_model, x)
        t_dl.set_description(f"Loss = {loss.item()}")
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        avg_loss += loss.item() * x.shape[0]
        num_items += x.shape[0]
        epoch_losses.append(avg_loss / num_items)
    tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
    torch.save(score_model.state_dict(), 'ckpt_moser_flow.pth')

  0%|          | 0/20 [00:00<?, ?it/s]


  0%|          | 0/2000 [00:00<?, ?it/s][A
Loss = -7.225596184098042:   0%|          | 0/2000 [00:10<?, ?it/s][A
Loss = -7.225596184098042:   0%|          | 1/2000 [00:11<6:35:22, 11.87s/it][A
Loss = -8.208980653145924:   0%|          | 1/2000 [00:22<6:35:22, 11.87s/it][A
Loss = -8.208980653145924:   0%|          | 2/2000 [00:23<6:32:19, 11.78s/it][A
Loss = -7.426888056030357:   0%|          | 2/2000 [00:33<6:32:19, 11.78s/it][A
Loss = -7.426888056030357:   0%|          | 3/2000 [00:35<6:31:27, 11.76s/it][A
Loss = -11.190118251153708:   0%|          | 3/2000 [00:45<6:31:27, 11.76s/it][A
Loss = -11.190118251153708:   0%|          | 4/2000 [00:46<6:28:47, 11.69s/it][A
Loss = -10.301072675413861:   0%|          | 4/2000 [00:57<6:28:47, 11.69s/it][A
Loss = -10.301072675413861:   0%|          | 5/2000 [00:58<6:28:38, 11.69s/it][A
Loss = -8.199716565234308:   0%|          | 5/2000 [01:09<6:28:38, 11.69s/it] [A
Loss = -8.199716565234308:   0%|          | 6/2000 [01:10<6:33:53, 11

In [None]:
torch.save(score_model.state_dict(), 'ckpt_moser_flow.pth')

In [None]:
import matplotlib.pyplot as plt

In [None]:
q=data_tensor[0]

In [None]:
for i in range(100):
    random_pose = torch.flatten(uniform_generator(q,bones))
    q = torch.squeeze(random_pose) # x is unsqueezed I can't remember why
    gs = [length_constraint(i,j,torch.squeeze(q)) for (i,j) in bones2]
    gs.append(length_constraint_2(1,torch.squeeze(q)))
    gs.append(length_constraint_2(11,torch.squeeze(q)))
    gs.append(length_constraint_2(15,torch.squeeze(q)))
    L_fn = cotangent_projection(gs)
    sim_x = torch.unsqueeze(random_pose,0)
    L = L_fn(sim_x)
    print(density(sim_x))

In [None]:
### now finding the density
def density(x,L):
    divergence = torch.trace(torch.squeeze(torch.autograd.functional.jacobian(score_model, inputs=(x,L), create_graph=True,strict=True)[0]))
    return 1/(4*torch.pi)**(18) - divergence

In [None]:
plt.plot(epoch_losses)

In [None]:
## size of a mini-batch
## learning rate
lr=1e-3 #@param {'type':'number'}
batch_size =  20 #@param {'type':'integer'}
## learning rate
dataloader = DataLoader(my_dataset,batch_size=batch_size, shuffle=True)
n_epochs = 20
tqdm_epoch = tqdm.notebook.trange(n_epochs)
optimizer = Adam(score_model.parameters(), lr=lr)
i = 0
for epoch in tqdm_epoch:
    t_dl =tqdm.tqdm(dataloader)
    avg_loss = 0.
    num_items = 0
    for pw in t_dl:
        x = pw[0].cpu()
        i += 1
        loss = loss2(score_model, x)
        t_dl.set_description(f"Loss = {loss.item()}")
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        avg_loss += loss.item() * x.shape[0]
        num_items += x.shape[0]
        epoch_losses.append(avg_loss / num_items)
    tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
    torch.save(score_model.state_dict(), 'ckpt_improved.pth')

In [None]:
model = ScoreNet(58)

In [None]:
x = torch.rand(57)
x.requires_grad_()

# generating a uniform position

In [None]:
data = np.load('train_dataset.npy')

since the original SDE was just Brownian motion, the reverse is simply:

In [None]:
data = np.load('train_dataset.npy')
x_init = torch.squeeze(torch.tensor(data[0]))

In [None]:
ckpt = torch.load('ckpt.pth')
score_model.load_state_dict(ckpt)

In [None]:
q_init = uniform_generator(torch.squeeze(x_init),bones)

In [None]:
torch._dynamo.config.suppress_errors = True

In [None]:
@torch.compile(mode = "max-autotune")
def run_simulation(q_init,steps):
  h = 1/steps
  positions = []
  q = torch.flatten(torch.squeeze(q_init))
  gs = [length_constraint(i,j,torch.squeeze(q)) for (i,j) in bones2]
  gs.append(length_constraint_2(1,torch.squeeze(q)))
  gs.append(length_constraint_2(11,torch.squeeze(q)))
  gs.append(length_constraint_2(15,torch.squeeze(q)))
  p = torch.zeros_like(q)
  M = torch.eye(q.size()[0])
  L_fn = cotangent_projection(gs)
  @torch.compile()
  def force(y):
    with torch.no_grad():
      x = torch.unsqueeze(y,0)
      L = L_fn(x)
      score = -score_model(x,time,L)
    return torch.squeeze(score)
  for step in range(1,steps):
      print(step)
      time = torch.tensor([1-(step)*h])
      q1 = q
      q, p = gBAOAB_step_exact(q,p,force, gs, h,M, 1,1,1,10**(-11))
      print(f"Movement of q: {torch.linalg.norm(q-q1)}")
      positions.append(q)
  return positions

In [None]:
positions = run_simulation(q_init, 10000)

In [None]:
pos = np.array([p.detach() for p in positions])
np.save('reverse_pos_new.npy',pos)