In [1]:
import os
import numpy as np
import torch
import torch.autograd.profiler as profiler
from torch_batch_svd import svd
import MDAnalysis as md
from numba import jit
from shapeGMM import _traj_tools as traj_tools

In [2]:
delta = 100  # keep every `delta`-th frame of the trajectory
# read trajectory
data_path = '../../DESRES-Trajectory_pnas2012-2f4k-360K-protein/pnas2012-2f4k-360K-protein/'
selection = "name CA and not resid 42 76"
#selection = "bynum 5:204"
#selection = "all"
# LOAD DATA
prmtopFileName = os.path.join(data_path, 'pnas2012-2f4k-360K-protein.pdb')
trajFiles = [os.path.join(data_path, fname) for fname in sorted(os.listdir(data_path)) if fname.endswith('.dcd')]
coord = md.Universe(prmtopFileName, trajFiles)
sel = coord.select_atoms(selection)
# trajectory[::delta] visits ceil(n_frames / delta) frames. The expression
# n_frames // delta + 1 over-counts by one whenever n_frames is a multiple of
# delta, which would leave an uninitialised row in the np.empty buffer below.
n_frames_analyzed = len(range(0, coord.trajectory.n_frames, delta))
print("Number of atoms in trajectory:", coord.atoms.n_atoms)
print("Number of frames in trajectory:", coord.trajectory.n_frames)
print("Number of atoms being analyzed:", sel.n_atoms)
print("Number of frames being analyzed:", n_frames_analyzed)
traj = np.empty((n_frames_analyzed, sel.n_atoms, 3), dtype=float)
# Advancing the trajectory iterator updates sel.positions for the current frame.
for frame_idx, _ts in enumerate(coord.trajectory[::delta]):
    traj[frame_idx, :, :] = sel.positions  # -sel.center_of_geometry()

Number of atoms in trajectory: 577
Number of frames in trajectory: 1526041
Number of atoms being analyzed: 33
Number of frames being analyzed: 15261


In [3]:
import torch_align

dtype = torch.float32
device = torch.device("cuda:0")

# Move the trajectory to the device and center every frame.
# NOTE(review): the helper's return value is discarded, so this relies on
# torch_align.torch_remove_center_of_geometry modifying traj_tensor in place — confirm.
traj_tensor = torch.tensor(traj, dtype=dtype, device=device)
torch_align.torch_remove_center_of_geometry(traj_tensor)

# Uniform per-frame weights that sum to one.
n_frames = traj.shape[0]
weights = np.ones(n_frames)
weights /= np.sum(weights)
weight_tensor = torch.tensor(weights, dtype=dtype, device=device)

avg, precision, lpdet = torch_align.torch_iterative_align_kronecker_weighted(
    traj_tensor,
    weight_tensor,
    dtype=dtype,
    device=device,
    thresh=1e-3,
    verbose=True,
)

tensor(-52.7160, device='cuda:0', dtype=torch.float64)
tensor(-51.5468, device='cuda:0', dtype=torch.float64)
tensor(-50.8493, device='cuda:0', dtype=torch.float64)
tensor(-50.0606, device='cuda:0', dtype=torch.float64)
tensor(-49.1118, device='cuda:0', dtype=torch.float64)
tensor(-48.1023, device='cuda:0', dtype=torch.float64)
tensor(-47.2618, device='cuda:0', dtype=torch.float64)
tensor(-46.6332, device='cuda:0', dtype=torch.float64)
tensor(-46.1532, device='cuda:0', dtype=torch.float64)
tensor(-45.6956, device='cuda:0', dtype=torch.float64)
tensor(-45.2119, device='cuda:0', dtype=torch.float64)
tensor(-44.7064, device='cuda:0', dtype=torch.float64)
tensor(-44.3232, device='cuda:0', dtype=torch.float64)
tensor(-44.1066, device='cuda:0', dtype=torch.float64)
tensor(-43.9780, device='cuda:0', dtype=torch.float64)
tensor(-43.8858, device='cuda:0', dtype=torch.float64)
tensor(-43.8063, device='cuda:0', dtype=torch.float64)
tensor(-43.7220, device='cuda:0', dtype=torch.float64)
tensor(-43

In [17]:
# Displacement of every frame from the average structure `avg` returned by the
# alignment cell above (avg is broadcast over the frame axis).
disp = traj_tensor - avg
print(disp.shape)

torch.Size([15261, 33, 3])


In [19]:
# Shape check: the x-components of every frame as (n_frames, 1, n_atoms) row
# vectors; the atom count comes from `disp` itself instead of a hard-coded 33.
print(disp[:, :, 0].unsqueeze(1).shape)

torch.Size([15261, 1, 33])


In [20]:
# Per-frame x-component inner product: (n_frames, 1, n_atoms) @ (n_frames, n_atoms, 1)
# -> (n_frames, 1, 1). Uses disp's own shape rather than a global `n_atoms`
# defined in another cell (hidden-state dependency).
torch.matmul(disp[:, :, 0].unsqueeze(1), disp[:, :, 0].unsqueeze(2)).shape

torch.Size([15261, 1, 1])

In [5]:
# NOTE(review): `var` is not assigned in any cell above this one; it is only set by
# the alignment cells further down, so this cell depends on out-of-order kernel state.
print(var)

tensor(4.7810, device='cuda:0', dtype=torch.float64)


In [24]:
def torch_remove_center_of_geometry(traj_tensor, dtype=torch.float32, device=torch.device("cuda:0")):
    """Subtract each frame's center of geometry from its coordinates, in place.

    Parameters
    ----------
    traj_tensor : torch.Tensor
        Coordinates indexed as (frame, atom, component). The mean is taken over
        the atom axis (dim 1) and subtracted from every atom of that frame.
        The tensor is modified in place.
    dtype, device :
        Unused. They are kept so the signature matches the torch_align helper.

    Returns
    -------
    torch.Tensor
        The same ``traj_tensor`` object, now centered. It is returned for
        convenience; callers may ignore it.
    """
    # Accumulate the mean in float64 to limit round-off. keepdim=True gives shape
    # (n_frames, 1, 3), so it broadcasts over atoms. The in-place subtraction casts
    # back to traj_tensor's dtype, exactly as the former per-frame loop did.
    # (The shape is read from traj_tensor itself, not from any global array.)
    cog = torch.mean(traj_tensor.to(torch.float64), 1, True)
    traj_tensor -= cog
    del cog
    torch.cuda.empty_cache()
    return traj_tensor

In [30]:
# Sanity check for the local torch_remove_center_of_geometry: the first print shows
# per-frame centers of geometry before centering, the second after (should be ~0).
dtype = torch.float32
device = torch.device("cuda:0")
traj_tensor = torch.tensor(traj, dtype=dtype, device=device)
print(traj_tensor.to(torch.float64).mean(dim=1))
torch_remove_center_of_geometry(traj_tensor)
print(traj_tensor.to(torch.float64).mean(dim=1))

tensor([[  1.9979,   5.8565,  -3.6840],
        [ -4.8365, -25.3204,  -8.6482],
        [  7.7014,  -2.8374,  12.1676],
        ...,
        [ 17.2672,  -4.2508, -13.3152],
        [-19.9200,  21.4330,  -0.5505],
        [  2.5963, -22.8508,  13.9161]], device='cuda:0', dtype=torch.float64)
tensor([[ 3.5672e-08, -3.6124e-08,  1.1560e-07],
        [ 2.8673e-08, -3.3866e-09,  9.0310e-09],
        [ 4.4252e-08,  6.5023e-08, -3.5672e-08],
        ...,
        [ 5.8702e-09, -1.4675e-08,  8.6923e-08],
        [-8.1279e-08,  6.1411e-08, -1.5578e-08],
        [-1.2643e-08,  6.5023e-08,  1.2282e-07]], device='cuda:0',
       dtype=torch.float64)


In [8]:
# test weighted alignment
import torch_align
n_frames = traj.shape[0]
n_atoms = traj.shape[1]
weights = np.random.random(n_frames)
weights = np.ones(n_frames)
weights /= np.sum(weights)
traj = torch_align.torch_remove_center_of_geometry(traj)
avg_weighted, traj_aligned_weighted, var_weighted = torch_align.torch_align_uniform_weighted(traj, weights, dtype=torch.float32, device=torch.device("cuda:0"), thresh=1e-3, verbose=True)

-140.69652950129063
-123.27849629441
-123.10511106747325
-123.10311220817795
-123.10308607006304


In [6]:
# Unweighted alignment from the torch_align module; its avg/var are compared
# against the uniformly weighted run.
avg, traj_aligned, var = torch_align.torch_align_uniform(
    traj,
    dtype=torch.float32,
    device=torch.device("cuda:0"),
    thresh=1e-3,
    verbose=True,
)

-140.69653134979797
-123.27849755290518
-123.1051121858449
-123.10311314101853
-123.10308647354991


In [39]:
# Same alignment using the torch_align_uniform defined in this notebook
# (it also prints total and log-likelihood timing).
avg, traj_aligned, var = torch_align_uniform(
    traj,
    dtype=torch.float32,
    device=torch.device("cuda:0"),
    thresh=1e-3,
    verbose=True,
)

-140.69653134979797
-123.27849755290518
-123.1051121858449
-123.10311314101853
-123.10308647354991
Total elapsed time: 124.589
log_lik time 0.738 0.593


In [9]:
# Variance from the unweighted alignment vs. the uniformly weighted alignment;
# with uniform weights the two should agree to roughly float32 precision.
print(var, var_weighted)

4.780989984875232 4.780989944686343


In [10]:
# Element-wise difference of the two average structures (unweighted minus uniformly
# weighted); entries are expected to sit at float32 round-off level.
print(avg-avg_weighted)

[[ 9.5367432e-07  0.0000000e+00  0.0000000e+00]
 [ 0.0000000e+00 -2.3841858e-07  0.0000000e+00]
 [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
 [ 0.0000000e+00  1.1920929e-07  0.0000000e+00]
 [ 4.7683716e-07  0.0000000e+00 -1.1920929e-07]
 [ 0.0000000e+00  0.0000000e+00 -2.3841858e-07]
 [ 0.0000000e+00  2.3841858e-07 -1.4901161e-07]
 [ 4.7683716e-07  1.7881393e-07 -2.3841858e-07]
 [ 0.0000000e+00  2.3841858e-07 -1.1920929e-07]
 [ 0.0000000e+00  0.0000000e+00 -5.9604645e-08]
 [ 0.0000000e+00  2.3841858e-07  0.0000000e+00]
 [ 0.0000000e+00  0.0000000e+00  0.0000000e+00]
 [ 0.0000000e+00  2.9802322e-08 -4.7683716e-07]
 [-2.9802322e-08 -5.9604645e-08  9.5367432e-07]
 [-5.9604645e-08 -2.3841858e-07 -4.7683716e-07]
 [ 0.0000000e+00 -1.1920929e-07  0.0000000e+00]
 [ 0.0000000e+00 -1.4901161e-07  4.7683716e-07]
 [ 0.0000000e+00  0.0000000e+00  4.7683716e-07]
 [-4.7683716e-07 -2.3841858e-07  0.0000000e+00]
 [-1.9073486e-06 -1.1920929e-07  2.3841858e-07]
 [ 9.5367432e-07 -2.3841858e-07  1.19209

In [38]:

def _torch_uniform_log_likelihood(disp):
    # meta data
    n_frames = disp.shape[0]
    n_atoms = disp.shape[1]
    # reshape displacement 
    disp = torch.reshape(disp,(n_frames*n_atoms*3,1))
    # compute variance
    var = torch.sum(disp*disp).cpu().numpy()
    var /= n_frames*3*(n_atoms-1)
    # compute log likelihood (per frame)
    log_lik = -1.5*(n_atoms-1)*(np.log(var) + 1)
    return log_lik, var

def torch_align_uniform(traj, dtype=torch.float32, device=torch.device("cuda:0"), thresh=1e-3, verbose=False):
    """Iteratively align every frame to the running (uniform) average structure.

    Each iteration rotates all frames onto the current reference with a batched
    SVD (Kabsch) step. It then recomputes the mean structure and evaluates the
    isotropic log likelihood. It stops once the log likelihood changes by no more
    than ``thresh``. Total and log-likelihood timings (ms) are always printed.

    Parameters
    ----------
    traj : array-like, indexed (frame, atom, component)
        Coordinates; converted to a tensor of ``dtype`` on ``device``.
    dtype : torch.dtype
        Working precision for the coordinates.
    device : torch.device or str
        Any device. Timing synchronises only when it is a CUDA device.
    thresh : float
        Convergence threshold on the absolute change in log likelihood.
    verbose : bool
        Print the log likelihood after every iteration.

    Returns
    -------
    avg : ndarray
        Final average structure.
    traj_aligned : ndarray
        Aligned coordinates.
    var : 0-d ndarray
        Variance reported by ``_torch_uniform_log_likelihood``.
    """
    import time

    # CUDA work is asynchronous: synchronise before reading the wall clock so
    # timings cover the queued kernels. This is a no-op for non-CUDA devices, so
    # the routine no longer requires a GPU just to time itself.
    use_cuda = str(device).startswith("cuda")

    def _sync():
        if use_cuda:
            torch.cuda.synchronize(device)

    log_lik_elapsed = 0.0  # milliseconds
    _sync()
    total_start = time.perf_counter()

    # pass trajectory to device (no copy when it is already the right dtype/device;
    # traj_tensor is only ever rebound below, never written in place)
    traj_tensor = torch.as_tensor(traj, dtype=dtype, device=device)

    # initialize with average as the first frame (arbitrary choice)
    avg = traj_tensor[0]

    # None means "no previous iteration", so the first pass can never look converged.
    old_log_lik = None
    delta_log_lik = float("inf")
    while delta_log_lik > thresh:
        # batched 3x3 correlation matrices: (3, n_atoms) @ (n_frames, n_atoms, 3)
        c_mats = torch.matmul(avg.T, traj_tensor)
        u, s, v = svd(c_mats)
        # Flip the last column of u where det(u)*det(v) = -1, so the result is a
        # proper rotation rather than a reflection.
        prod_dets = torch.linalg.det(u) * torch.linalg.det(v)
        u[:, :, -1] *= prod_dets.unsqueeze(-1)
        rot_mat = torch.transpose(torch.matmul(u, torch.transpose(v, 1, 2)), 1, 2)
        traj_tensor = torch.matmul(traj_tensor, rot_mat)
        # new average structure and float64 displacements from it
        avg = torch.mean(traj_tensor, 0, False)
        disp = (traj_tensor - avg).to(torch.float64)
        # log likelihood and variance (timed separately)
        _sync()
        log_lik_start = time.perf_counter()
        log_lik, var = _torch_uniform_log_likelihood(disp)
        _sync()
        log_lik_elapsed += (time.perf_counter() - log_lik_start) * 1000.0
        delta_log_lik = float("inf") if old_log_lik is None else abs(log_lik - old_log_lik)
        if verbose:
            print(log_lik)
        old_log_lik = log_lik

    _sync()
    total_time = (time.perf_counter() - total_start) * 1000.0
    print("Total elapsed time:",np.round(total_time,3))
    print("log_lik time", np.round(log_lik_elapsed,3), np.round(log_lik_elapsed/total_time*100,3))

    return avg.cpu().numpy(), traj_tensor.cpu().numpy(), var


# Kronecker alignment with weights

In [14]:
import torch_align

# Unweighted Kronecker alignment from the torch_align module; serves as the
# reference result for the weighted implementation defined further down.
avg_k, traj_aligned_k, covar_k, lpdet_k = torch_align.torch_align_kronecker(
    traj,
    dtype=torch.float32,
    device=torch.device("cuda:0"),
    thresh=1e-2,
    verbose=True,
)

-52.715989010402126
-51.54684564122217
-50.849343027704606
-50.06061663526944
-49.11179997354897
-48.10234065492433
-47.261789951434324
-46.63319579144418
-46.15324130695882
-45.69556600354173
-45.2118659987544
-44.7064055980933
-44.323208356653254
-44.1065659604726
-43.978021076940394
-43.88585042749277
-43.8063181398916
-43.72201515484187
-43.61273587022087
-43.447898755076956
-43.185651755806475
-42.80579315509976
-42.38850053136583
-42.06908340540164
-41.86311241482156
-41.72032555904979
-41.6101700456448
-41.51912685730258
-41.44043674303579
-41.37025851289659
-41.3062527890569
-41.24693817800421
-41.191352177481356
-41.13885893490008
-41.08903054238571
-41.04158124037112
-40.99631598591567
-40.95311002158887
-40.91188663727155
-40.87259935109681
-40.8352352238429
-40.799792965338675
-40.766289549389874
-40.734744194862955
-40.70518110203772
-40.67761412009876
-40.65205468778431
-40.62849311903964
-40.6069021107015
-40.587237203044864
-40.569437408896945
-40.55341981763551
-40.539

In [16]:
# Third return value of torch_align.torch_align_kronecker. NOTE(review): the local
# weighted version returns the precision in this slot — confirm whether covar_k is
# the covariance or the precision before comparing them.
print(covar_k)

[[ 3.31032081e-01 -4.38182215e-01  1.65037451e-01 ... -9.27287032e-03
   7.37842507e-03 -3.86456582e-03]
 [-4.38182215e-01  1.18122577e+00 -9.45447593e-01 ...  1.62949663e-03
   1.87891832e-03 -1.38323177e-03]
 [ 1.65037451e-01 -9.45447593e-01  1.56786380e+00 ... -2.58739090e-02
   2.69759649e-02 -4.33100926e-03]
 ...
 [-9.27287032e-03  1.62949663e-03 -2.58739090e-02 ...  1.46806588e+00
  -6.37120145e-01  6.37218212e-03]
 [ 7.37842507e-03  1.87891832e-03  2.69759649e-02 ... -6.37120145e-01
   1.19078585e+00 -6.12696263e-01]
 [-3.86456582e-03 -1.38323177e-03 -4.33100926e-03 ...  6.37218212e-03
  -6.12696263e-01  5.08183947e-01]]


In [8]:
# For an n x n matrix A, det(A kron I_3) = det(A)**3; this and the next cell
# check that identity numerically.
print(np.linalg.det(covar_k)**3)

2.1444792689533237e-20


In [7]:
# Determinant of the full (3n x 3n) Kronecker-structured matrix A kron I_3;
# should match det(A)**3 from the previous cell.
print(np.linalg.det(np.kron(covar_k,np.identity(3))))

2.144479268954352e-20


In [11]:
def _torch_pseudo_inv(sigma, dtype=torch.float64, device=torch.device("cuda:0"),EigenValueThresh=1e-10):
    N = sigma.shape[0]
    e, v = torch.linalg.eigh(sigma)
    pinv = torch.tensor(np.zeros(sigma.shape),dtype=dtype,device=device)
    lpdet = torch.tensor(0.0,dtype=dtype,device=device)
    for i in range(N):
        if (e[i] > EigenValueThresh):
            lpdet += torch.log(e[i])
            pinv += 1.0/e[i]*torch.outer(v[:,i],v[:,i])
    return pinv, lpdet

def _torch_kronecker_weighted_log_lik(disp, weight_tensor, precision, lpdet):
    # meta data
    n_frames = disp.shape[0]
    n_atoms = disp.shape[1]
    # compute log Likelihood for all points
    #log_lik = torch.trace(torch.sum(torch.matmul(torch.transpose(disp,1,2),torch.matmul(precision,disp)),0))
    log_lik = torch.sum(torch.matmul(torch.transpose(torch.reshape(disp[:,:,0],(n_frames,n_atoms,1)),1,2),torch.matmul(precision,torch.reshape(disp[:,:,0],(n_frames,n_atoms,1))))*weight_tensor.view(-1,1,1),0)
    log_lik += torch.sum(torch.matmul(torch.transpose(torch.reshape(disp[:,:,1],(n_frames,n_atoms,1)),1,2),torch.matmul(precision,torch.reshape(disp[:,:,1],(n_frames,n_atoms,1))))*weight_tensor.view(-1,1,1),0)
    log_lik += torch.sum(torch.matmul(torch.transpose(torch.reshape(disp[:,:,2],(n_frames,n_atoms,1)),1,2),torch.matmul(precision,torch.reshape(disp[:,:,2],(n_frames,n_atoms,1))))*weight_tensor.view(-1,1,1),0)
    log_lik *= -0.5
    log_lik -= 1.5 * lpdet
    return log_lik

def torch_align_kronecker_weighted(traj, weights, stride=1000, dtype=torch.float32, device=torch.device("cuda:0"), thresh=1e-3, verbose=False):
    """Iteratively align a trajectory under a weighted Kronecker covariance model.

    Each iteration does three things:
    (1) Rotates every frame onto the current reference with a batched-SVD
        (Kabsch) step.
    (2) Recomputes the weighted average structure and the weighted atom-space
        covariance.
    (3) Pseudo-inverts the covariance; the precision then weights the reference
        for the next rotation.
    It stops once the weighted log likelihood changes by no more than ``thresh``.

    Parameters
    ----------
    traj : array-like, indexed (frame, atom, component)
        Coordinates. Callers in this notebook remove the center of geometry first;
        presumably that is required — TODO confirm.
    weights : array-like, one entry per frame
        Per-frame weights. They are normalised here to sum to 1, because the
        average and covariance below assume normalised weights. Already-normalised
        input is unaffected.
    stride : int
        Frames per chunk when accumulating the covariance (bounds peak memory).
    dtype, device :
        Working dtype/device for the coordinates. Averages, covariance and weights
        are kept in float64.
    thresh : float
        Convergence threshold on the absolute change in log likelihood.
    verbose : bool
        Print the log likelihood after every iteration.

    Returns
    -------
    avg : ndarray (float64)
        Weighted average structure.
    traj_aligned : ndarray
        Aligned coordinates.
    precision : ndarray, shape (n_atoms, n_atoms)
        Pseudo-inverse of the weighted covariance.
    lpdet : 0-d ndarray
        Log pseudo-determinant of the weighted covariance.

    Raises
    ------
    ValueError
        If the weights do not have a positive sum.
    """
    # meta data
    n_frames = traj.shape[0]
    n_atoms = traj.shape[1]

    # pass trajectory to device (traj_tensor is only rebound below, never written in place)
    traj_tensor = torch.as_tensor(traj, dtype=dtype, device=device)
    # Weights stay in float64: they multiply float64 displacements, and
    # 1/n_frames-sized weights lose precision in float32.
    weight_tensor = torch.as_tensor(weights, dtype=torch.float64, device=device)
    weight_total = torch.sum(weight_tensor)
    if not weight_total > 0:
        raise ValueError("weights must have a positive sum")
    weight_tensor = weight_tensor / weight_total
    frame_weights = weight_tensor.view(-1, 1, 1)
    covar_norm = 1.0 / 3.0  # average the covariance over the three components

    # Initial reference: the first frame (arbitrary choice), transposed to (3, n_atoms).
    ref_t = traj_tensor[0].T

    # None means "no previous iteration", so the first pass can never look converged.
    old_log_lik = None
    delta_log_lik = float("inf")
    while delta_log_lik > thresh:
        # batched 3x3 correlation matrices: (3, n_atoms) @ (n_frames, n_atoms, 3)
        c_mats = torch.matmul(ref_t, traj_tensor)
        u, s, v = svd(c_mats)
        # Flip the last column of u where det(u)*det(v) = -1, so the result is a
        # proper rotation rather than a reflection.
        prod_dets = torch.linalg.det(u) * torch.linalg.det(v)
        u[:, :, -1] *= prod_dets.unsqueeze(-1)
        rot_mat = torch.transpose(torch.matmul(u, torch.transpose(v, 1, 2)), 1, 2)
        traj_tensor = torch.matmul(traj_tensor, rot_mat)

        # weighted average structure and displacements, all in float64
        traj64 = traj_tensor.to(torch.float64)
        avg = torch.sum(traj64 * frame_weights, 0)
        disp = traj64 - avg
        del traj64

        # weighted atom-space covariance, accumulated `stride` frames at a time
        covar = torch.zeros((n_atoms, n_atoms), dtype=torch.float64, device=device)
        for start in range(0, n_frames, stride):
            chunk = disp[start:start + stride]
            covar += torch.sum(torch.matmul(chunk, torch.transpose(chunk, 1, 2)) * frame_weights[start:start + stride], 0)
        covar *= covar_norm

        # log likelihood under the current covariance
        precision, lpdet = _torch_pseudo_inv(covar, dtype=torch.float64, device=device)
        log_lik = _torch_kronecker_weighted_log_lik(disp, weight_tensor, precision, lpdet).item()
        delta_log_lik = float("inf") if old_log_lik is None else abs(log_lik - old_log_lik)
        if verbose:
            print(log_lik)
        old_log_lik = log_lik
        # precision-weighted reference for the next rotation step
        ref_t = torch.matmul(avg.T, precision).to(dtype)

    return avg.cpu().numpy(), traj_tensor.cpu().numpy(), precision.cpu().numpy(), lpdet.cpu().numpy()


In [17]:
import torch_align
n_frames = traj.shape[0]
n_atoms = traj.shape[1]
# Set to True to exercise non-uniform weights. The generator is seeded so re-runs
# are reproducible. Previously a random draw was made and immediately overwritten.
USE_RANDOM_WEIGHTS = False
if USE_RANDOM_WEIGHTS:
    weights = np.random.default_rng(0).random(n_frames)
else:
    weights = np.ones(n_frames)
weights /= np.sum(weights)
# NOTE(review): rebinds `traj` to the helper's return value, so re-running this cell
# re-centers an already-centered array — confirm the torch_align helper returns the
# centered coordinates for array input.
traj = torch_align.torch_remove_center_of_geometry(traj)
# The third return value of torch_align_kronecker_weighted is the precision matrix
# (pseudo-inverse of the covariance), despite the covar_k_w name.
avg_k_w, traj_aligned_k_w, covar_k_w, lpdet_k_w = torch_align_kronecker_weighted(
    traj, weights, dtype=torch.float32, device=torch.device("cuda:0"), thresh=1e-2, verbose=True
)

-52.71598691994702
-51.546843140383714
-50.849340455006676
-50.06061458077904
-49.111798186861776
-48.1023388702704
-47.26178880839869
-46.633194327417264
-46.153241027203244
-45.69556600192617
-45.211866439701296
-44.70640657826238
-44.32320925776384
-44.106566709637725
-43.97802138869429
-43.88585045208407
-43.80631820040114
-43.722015108823506
-43.612735976826066
-43.447899483931046
-43.18565124910742
-42.80579266584211
-42.38850135434713
-42.06908417476599
-41.86311250307963
-41.72032509244113
-41.61017083476501
-41.51912757631702
-41.44043741266851
-41.37026041025876
-41.30625556594816
-41.24694079800153
-41.19135506027813
-41.138862194541126
-41.08903369229002
-41.041583041275395
-40.99631705364416
-40.953111338329606
-40.911887990932605
-40.87260203008973
-40.83523808764217
-40.79979666787079
-40.76629262027978
-40.73474630582153
-40.705181555217585
-40.67761548445738
-40.6520540911081
-40.62849104426222
-40.606901297707296
-40.5872383559127
-40.56943922172136
-40.55342499802516

In [18]:
# Log pseudo-determinant from torch_align.torch_align_kronecker vs. the local
# weighted implementation run with uniform weights; expected to agree closely.
print(lpdet_k, lpdet_k_w)

-4.994473916255258 -4.996567204277507


In [19]:
# Element-wise difference between the torch_align result and the local weighted
# result. NOTE(review): covar_k_w is the precision returned by
# torch_align_kronecker_weighted — confirm covar_k is the same quantity.
print(covar_k - covar_k_w)

[[-2.17265299e-05  2.85760171e-05 -1.06941514e-05 ...  5.64416062e-07
  -2.99066525e-07  1.47581722e-07]
 [ 2.85760171e-05 -7.75703521e-05  6.20422586e-05 ... -1.63928173e-07
  -8.21958679e-08 -7.52041227e-08]
 [-1.06941514e-05  6.20422586e-05 -1.02697522e-04 ...  1.60884777e-06
  -1.54430193e-06  3.31240664e-07]
 ...
 [ 5.64416062e-07 -1.63928173e-07  1.60884777e-06 ... -9.62735086e-05
   4.20674589e-05 -6.57640060e-07]
 [-2.99066525e-07 -8.21958679e-08 -1.54430193e-06 ...  4.20674589e-05
  -7.78202861e-05  4.01808260e-05]
 [ 1.47581722e-07 -7.52041227e-08  3.31240664e-07 ... -6.57640060e-07
   4.01808260e-05 -3.33809819e-05]]


In [22]:
# Third return value of torch_align_kronecker_weighted, i.e. precision.cpu().numpy()
# (the pseudo-inverse of the weighted covariance).
print(covar_k_w)

[[ 3.31053808e-01 -4.38210791e-01  1.65048146e-01 ... -9.27343474e-03
   7.37872414e-03 -3.86471340e-03]
 [-4.38210791e-01  1.18130334e+00 -9.45509635e-01 ...  1.62966056e-03
   1.87900052e-03 -1.38315656e-03]
 [ 1.65048146e-01 -9.45509635e-01  1.56796650e+00 ... -2.58755178e-02
   2.69775092e-02 -4.33134050e-03]
 ...
 [-9.27343474e-03  1.62966056e-03 -2.58755178e-02 ...  1.46816215e+00
  -6.37162212e-01  6.37283976e-03]
 [ 7.37872414e-03  1.87900052e-03  2.69775092e-02 ... -6.37162212e-01
   1.19086367e+00 -6.12736444e-01]
 [-3.86471340e-03 -1.38315656e-03 -4.33134050e-03 ...  6.37283976e-03
  -6.12736444e-01  5.08217328e-01]]
