In [None]:
import os
import numpy as np
import torch
import torch.autograd.profiler as profiler
from torch_batch_svd import svd
import MDAnalysis as md
from numba import jit
from shapeGMM import _traj_tools as traj_tools

In [None]:
EigenValueThresh = 1E-10

@jit(nopython=True)
def pseudo_lpdet_inv(sigma):
    N = sigma.shape[0]
    e, v = np.linalg.eigh(sigma)
    precision = np.zeros(sigma.shape,dtype=np.float64)
    lpdet = 0.0
    rank = 0
    for i in range(N):
        if (e[i] > EigenValueThresh):
            lpdet += np.log(e[i])
            precision += 1.0/e[i]*np.outer(v[:,i],v[:,i])
            rank += 1
    return lpdet, precision, rank

@jit(nopython=True)
def weight_kabsch_log_lik(x, mu, precision, lpdet):
    # meta data
    n_frames = x.shape[0]
    # compute log Likelihood for all points
    log_lik = 0.0
    for i in range(n_frames):
        #disp = x[i] - mu
        for j in range(3):
            disp = x[i,:,j] - mu[:,j]
            log_lik += np.dot(disp,np.dot(precision,disp))
    log_lik += 3 * n_frames * lpdet
    log_lik /= -2*n_frames
    return log_lik

@jit(nopython=True)
def weight_kabsch_rotate(mobile, target, weights):
    correlation_matrix = np.dot(np.transpose(mobile), np.dot(weights, target))
    V, S, W_tr = np.linalg.svd(correlation_matrix)
    if np.linalg.det(V) * np.linalg.det(W_tr) < 0.0:
        V[:, -1] = -V[:, -1]
    rotation = np.dot(V, W_tr)
    mobile_prime = np.dot(mobile,rotation)
    return mobile_prime

@jit(nopython=True)
def fast_weight_kabsch_rotate(mobile, weights_target):
    correlation_matrix = np.dot(np.transpose(mobile), weights_target)
    V, S, W_tr = np.linalg.svd(correlation_matrix)
    if np.linalg.det(V) * np.linalg.det(W_tr) < 0.0:
        V[:, -1] = -V[:, -1]
    rotation = np.dot(V, W_tr)
    mobile_prime = np.dot(mobile,rotation)
    return mobile_prime

@jit(nopython=True)
def covar_NxN_from_traj(disp):
    # trajectory metadata
    n_frames = disp.shape[0]
    n_atoms = disp.shape[1]
    # declare covar
    covar = np.zeros((n_atoms,n_atoms),np.float64)
    # loop and compute
    for ts in range(n_frames):
        covar += np.dot(disp[ts],disp[ts].T)
    # symmetrize and average covar
    covar /= 3*(n_frames-1)
    # done, return
    return covar

@jit(nopython=True)
def traj_iterative_average_precision_weighted_kabsch(traj_data,thresh=1E-3,max_steps=300):
    # trajectory metadata
    n_frames = traj_data.shape[0]
    n_atoms = traj_data.shape[1]
    nDim = traj_data.shape[2]
    # Initialize with uniform weighted Kabsch
    avg, aligned_pos = traj_iterative_average(traj_data,thresh)
    # compute NxN covar
    covar = covar_NxN_from_traj(aligned_pos-avg)
    # determine precision and pseudo determinant 
    lpdet, precision, rank = pseudo_lpdet_inv(covar)
    # compute log likelihood
    log_lik = weight_kabsch_log_lik(aligned_pos, avg, precision, lpdet)
    # perform iterative alignment and average to converge average
    log_lik_diff = 10+thresh
    step = 0
    while log_lik_diff > thresh and step < max_steps:
        # rezero new average
        new_avg = np.zeros((n_atoms,nDim),dtype=np.float64)
        # align trajectory to average and accumulate new average
        weights_target = np.dot(precision,avg)
        for ts in range(n_frames):
            aligned_pos[ts] = fast_weight_kabsch_rotate(aligned_pos[ts], weights_target)
            new_avg += aligned_pos[ts]
        # finish average
        new_avg /= n_frames
        # compute new Kabsch Weights
        covar = covar_NxN_from_traj(aligned_pos-new_avg)
        # determine precision and pseudo determinant 
        lpdet, precision, rank = pseudo_lpdet_inv(covar)
        # compute log likelihood
        new_log_lik = weight_kabsch_log_lik(aligned_pos, new_avg, precision, lpdet)
        log_lik_diff = np.abs(new_log_lik-log_lik)
        log_lik = new_log_lik
        avg = np.copy(new_avg)
        step += 1
        #print(step, log_lik)
    return avg, aligned_pos, precision, lpdet

###################################
####. Uniform
####################################

@jit(nopython=True)
def kabsch_rotate(mobile, target):
    correlation_matrix = np.dot(np.transpose(mobile), target)
    V, S, W_tr = np.linalg.svd(correlation_matrix)
    if np.linalg.det(V) * np.linalg.det(W_tr) < 0.0:
        V[:, -1] = -V[:, -1]
    rotation = np.dot(V, W_tr)
    mobile_prime = np.dot(mobile,rotation)
    return mobile_prime

@jit(nopython=True)
def uniform_kabsch_log_lik(x, mu):
    # meta data
    n_frames = x.shape[0]
    n_atoms = x.shape[1]
    # compute log Likelihood for all points
    log_lik = 0.0
    sampleVar = 0.0
    for i in range(n_frames):
        for j in range(3):
            disp = x[i,:,j] - mu[:,j]
            temp = np.dot(disp,disp)
            sampleVar += temp
            log_lik += temp
    # finish variance
    sampleVar /= (n_frames-1)*3*(n_atoms-1)
    log_lik /= sampleVar
    log_lik +=  n_frames * 3 * (n_atoms-1) * np.log(sampleVar)
    log_lik /= -2*n_frames
    return log_lik

# compute the average structure from trajectory data
@jit(nopython=True)
def traj_iterative_average(traj_data,thresh=1E-3):
    # trajectory metadata
    n_frames = traj_data.shape[0]
    n_atoms = traj_data.shape[1]
    nDim = traj_data.shape[2]
    # create numpy array of aligned positions
    aligned_pos = np.copy(traj_data).astype(np.float64)
    # start be removing COG translation
    for ts in range(n_frames):
        mu = np.zeros(nDim)
        for atom in range(n_atoms):
            mu += aligned_pos[ts,atom]
        mu /= n_atoms
        aligned_pos[ts] -= mu
    # Initialize average as first frame
    avg = np.copy(aligned_pos[0]).astype(np.float64)
    log_lik = uniform_kabsch_log_lik(aligned_pos,avg)
    # perform iterative alignment and average to converge log likelihood
    log_lik_diff = 10
    count = 1
    while log_lik_diff > thresh:
        # rezero new average
        new_avg = np.zeros((n_atoms,nDim),dtype=np.float64)
        # align trajectory to average and accumulate new average
        for ts in range(n_frames):
            aligned_pos[ts] = kabsch_rotate(aligned_pos[ts], avg)
            new_avg += aligned_pos[ts]
        # finish average
        new_avg /= n_frames
        # compute log likelihood
        new_log_lik = uniform_kabsch_log_lik(aligned_pos,avg)
        #print(new_log_lik)
        log_lik_diff = np.abs(new_log_lik-log_lik)
        log_lik = new_log_lik
        # copy new average
        avg = np.copy(new_avg)
        count += 1
    return avg, aligned_pos

In [None]:
def torch_align_uniform(traj, dtype=torch.float64, device=torch.device("cuda:0"), thresh=1e-3, verbose=False):
    # meta data
    n_frames = traj.shape[0]
    n_atoms = traj.shape[1]

    # pass trajectory to device
    traj_tensor = torch.tensor(traj,dtype=dtype,device=device)

    # initialize with average as the first frame (arbitrary choice)
    avg = traj_tensor[0]

    delta_log_lik = thresh+10
    old_log_lik = 0
    while (delta_log_lik > thresh):
        # compute correlation matrices using batched matmul
        c_mats = torch.matmul(avg.T,traj_tensor)
        # perfrom SVD of c_mats using batched SVD
        u, s, v = torch.linalg.svd(c_mats)
        # ensure true rotation by correcting sign of determinant
        prod_dets = torch.linalg.det(u)*torch.linalg.det(v)
        u[:,0,-1] *= prod_dets
        u[:,1,-1] *= prod_dets
        u[:,2,-1] *= prod_dets
        rot_mat = torch.transpose(torch.matmul(u,v),1,2)
        # do rotation
        traj_tensor = torch.matmul(traj_tensor,rot_mat)
        # compute new average
        new_avg = torch.mean(traj_tensor,0,False)
        # compute log likelihood
        disp = traj_tensor - new_avg
        covar = torch.matmul(torch.transpose(disp,1,2),disp)
        var = covar.diagonal(offset=0, dim1=-1, dim2=-2).sum(-1)
        var = torch.sum(var).cpu().numpy()
        log_lik = np.copy(var)
        # finish variance
        var /= (n_frames-1)*3*(n_atoms-1)
        log_lik /= var
        log_lik +=  n_frames * 3 * (n_atoms-1) * np.log(var)
        log_lik /= -2*n_frames
        delta_log_lik = abs(log_lik - old_log_lik)
        if verbose==True:
            print(log_lik)
        old_log_lik = log_lik
        avg = new_avg
    return avg, traj_tensor

def torch_pseudo_inv(sigma, dtype=torch.float32, device=torch.device("cuda:0"),EigenValueThresh=1e-10):
    N = sigma.shape[0]
    e, v = torch.linalg.eigh(sigma)
    pinv = torch.tensor(np.zeros(sigma.shape),dtype=dtype,device=device)
    lpdet = 0.0
    for i in range(N):
        if (e[i] > EigenValueThresh):
            lpdet += torch.log(e[i])
            pinv += 1.0/e[i]*torch.outer(v[:,i],v[:,i])
    return pinv, lpdet

def torch_weighted_log_lik(disp, precision, lpdet):
    # meta data
    n_frames = disp.shape[0]
    n_atoms = disp.shape[1]
    # compute log Likelihood for all points
    #log_lik = torch.trace(torch.sum(torch.matmul(torch.transpose(disp,1,2),torch.matmul(precision,disp)),0))
    log_lik = torch.sum(torch.matmul(torch.transpose(torch.reshape(disp[:,:,0],(n_frames,n_atoms,1)),1,2),torch.matmul(precision,torch.reshape(disp[:,:,0],(n_frames,n_atoms,1)))),0)
    log_lik += torch.sum(torch.matmul(torch.transpose(torch.reshape(disp[:,:,1],(n_frames,n_atoms,1)),1,2),torch.matmul(precision,torch.reshape(disp[:,:,1],(n_frames,n_atoms,1)))),0)
    log_lik += torch.sum(torch.matmul(torch.transpose(torch.reshape(disp[:,:,2],(n_frames,n_atoms,1)),1,2),torch.matmul(precision,torch.reshape(disp[:,:,2],(n_frames,n_atoms,1)))),0)
    log_lik += 3 * n_frames * lpdet
    log_lik /= -2*n_frames
    return log_lik

def torch_align_weighted(traj, stride=1000, dtype=torch.float64, device=torch.device("cuda:0"), thresh=1e-3, verbose=False):
    # timing data
    total_start = torch.cuda.Event(enable_timing=True)
    total_stop = torch.cuda.Event(enable_timing=True)
    pass_data_start = torch.cuda.Event(enable_timing=True)
    pass_data_stop = torch.cuda.Event(enable_timing=True)
    cmats_start = torch.cuda.Event(enable_timing=True)
    cmats_stop = torch.cuda.Event(enable_timing=True)
    cmats_elapsed = 0.0
    svd_start = torch.cuda.Event(enable_timing=True)
    svd_stop = torch.cuda.Event(enable_timing=True)
    svd_elapsed = 0.0
    det_start = torch.cuda.Event(enable_timing=True)
    det_stop = torch.cuda.Event(enable_timing=True)
    det_elapsed = 0.0
    finish_rot_start = torch.cuda.Event(enable_timing=True)
    finish_rot_stop = torch.cuda.Event(enable_timing=True)
    finish_rot_elapsed = 0.0
    rot_start = torch.cuda.Event(enable_timing=True)
    rot_stop = torch.cuda.Event(enable_timing=True)
    rot_elapsed = 0.0
    covar_start = torch.cuda.Event(enable_timing=True)
    covar_stop = torch.cuda.Event(enable_timing=True)
    covar_elapsed = 0.0
    pinv_start = torch.cuda.Event(enable_timing=True)
    pinv_stop = torch.cuda.Event(enable_timing=True)
    pinv_elapsed = 0.0
    log_lik_start = torch.cuda.Event(enable_timing=True)
    log_lik_stop = torch.cuda.Event(enable_timing=True)
    log_lik_elapsed = 0.0
    total_start.record()

    
    # meta data
    n_frames = traj.shape[0]
    n_atoms = traj.shape[1]
    
    pass_data_start.record()
    # pass trajectory to device
    traj_tensor = torch.tensor(traj,dtype=dtype,device=device)
    covar_norm = torch.tensor(1/(3*(n_frames-1)),dtype=dtype,device=device)
    pass_data_stop.record()
    
    # initialize with average as the first frame (arbitrary choice)
    weighted_avg = traj_tensor[0].T
    
    delta_log_lik = thresh+10
    old_log_lik = 0
    while (delta_log_lik > thresh):

        # compute correlation matrices using batched matmul
        cmats_start.record()
        c_mats = torch.matmul(weighted_avg,traj_tensor)
        cmats_stop.record()
        torch.cuda.synchronize()
        cmats_elapsed += cmats_start.elapsed_time(cmats_stop)
        # perfrom SVD of c_mats using batched SVD
        svd_start.record()
        #u, s, v = torch.linalg.svd(c_mats)
        u, s, v = svd(c_mats)
        svd_stop.record()
        torch.cuda.synchronize()
        svd_elapsed += svd_start.elapsed_time(svd_stop)
        # ensure true rotation by correcting sign of determinant 
        det_start.record()
        prod_dets = torch.linalg.det(u)*torch.linalg.det(v)
        det_stop.record()
        torch.cuda.synchronize()
        det_elapsed += det_start.elapsed_time(det_stop)
        #
        finish_rot_start.record()
        u[:,0,-1] *= prod_dets
        u[:,1,-1] *= prod_dets
        u[:,2,-1] *= prod_dets
        #rot_mat = torch.transpose(torch.matmul(u,v),1,2)
        rot_mat = torch.transpose(torch.matmul(u,torch.transpose(v,1,2)),1,2)
        finish_rot_stop.record()
        torch.cuda.synchronize()
        finish_rot_elapsed += finish_rot_start.elapsed_time(finish_rot_stop)
        # do rotation
        rot_start.record()
        traj_tensor = torch.matmul(traj_tensor,rot_mat)
        rot_stop.record()
        torch.cuda.synchronize()
        rot_elapsed += rot_start.elapsed_time(rot_stop)
        # compute new average
        avg = torch.mean(traj_tensor.to(torch.float64),0,False)
        disp = traj_tensor.to(torch.float64) - avg
        # compute covar using strided data
        # zero strided arrays
        covar_start.record()
        covar = torch.tensor(np.zeros((n_atoms,n_atoms)),dtype=torch.float64,device=device)
        for frame in range(0,n_frames,stride):
            covar += torch.sum(torch.matmul(disp[frame:frame+stride],torch.transpose(disp[frame:frame+stride],1,2)),0)
        covar *= covar_norm
        covar_stop.record()
        torch.cuda.synchronize()
        covar_elapsed += covar_start.elapsed_time(covar_stop)
        # log likelihood
        pinv_start.record()
        precision, lpdet = torch_pseudo_inv(covar,dtype=torch.float64)
        pinv_stop.record()
        torch.cuda.synchronize()
        pinv_elapsed += pinv_start.elapsed_time(pinv_stop)
        log_lik_start.record()
        log_lik = torch_weighted_log_lik(disp, precision, lpdet).cpu().numpy()[0][0]
        log_lik_stop.record()
        torch.cuda.synchronize()
        log_lik_elapsed += log_lik_start.elapsed_time(log_lik_stop)
        delta_log_lik = abs(log_lik - old_log_lik)
        if verbose == True:
            print(log_lik)
        old_log_lik = log_lik
        weighted_avg = torch.matmul(avg.T,precision).to(torch.float32)
    # finish timing
    total_stop.record()
    torch.cuda.synchronize()
    total_time = total_start.elapsed_time(total_stop)
    print("Total elapsed time:",np.round(total_time,3))
    print("Time to send data:", np.round(pass_data_start.elapsed_time(pass_data_stop),3), np.round(pass_data_start.elapsed_time(pass_data_stop)/total_time*100,3))
    print("cmat time:", np.round(cmats_elapsed,3), np.round(cmats_elapsed/total_time*100,3))
    print("svd time:", np.round(svd_elapsed,3), np.round(svd_elapsed/total_time*100,3))
    print("det time", np.round(det_elapsed,3), np.round(det_elapsed/total_time*100,3))
    print("finish rot time:", np.round(finish_rot_elapsed,3), np.round(finish_rot_elapsed/total_time*100,3))
    print("rot time:", np.round(rot_elapsed,3), np.round(rot_elapsed/total_time*100,3))
    print("covar time:", np.round(covar_elapsed,3), np.round(covar_elapsed/total_time*100,3))
    print("pinv time", np.round(pinv_elapsed,3), np.round(pinv_elapsed/total_time*100,3)) 
    print("log_lik time", np.round(log_lik_elapsed,3), np.round(log_lik_elapsed/total_time*100,3))        
        
        
    return avg, traj_tensor, precision, lpdet

In [1]:
delta = 100
# read trajectory
data_path = '../../DESRES-Trajectory_pnas2012-2f4k-360K-protein/pnas2012-2f4k-360K-protein/'
selection = "name CA and not resid 42 76"
#selection = "bynum 5:204"
#selection = "all"
# LOAD DATA
prmtopFileName =  data_path + 'pnas2012-2f4k-360K-protein.pdb'
trajFiles = [data_path + files for files in sorted(os.listdir(data_path)) if files.endswith('.dcd')]
coord = md.Universe(prmtopFileName,trajFiles)
sel = coord.select_atoms(selection)
print("Number of atoms in trajectory:", coord.atoms.n_atoms)
print("Number of frames in trajectory:",coord.trajectory.n_frames)
print("Number of atoms being analyzed:",sel.n_atoms)
print("Number of frames being analyzed:",coord.trajectory.n_frames//delta+1)
traj = np.empty((coord.trajectory.n_frames//delta+1,sel.n_atoms,3),dtype=float)
count = 0
for ts in coord.trajectory[::delta]:
    traj[count,:,:] = sel.positions#-sel.center_of_geometry()
    count += 1

NameError: name 'os' is not defined

In [5]:
device = "cuda:0"
dtype = torch.float32
n_frames = traj.shape[0]
n_atoms = traj.shape[1]
traj_tensor = torch.tensor(traj,device=device,dtype=dtype)

In [6]:
weights = torch.rand(traj_tensor.shape[0],device=device,dtype=dtype)
weights /= torch.sum(weights)

In [7]:
weighted_traj_tensor = traj_tensor * weights.view(-1,1,1)

In [8]:
cog = torch.mean(traj_tensor.to(torch.float64),1,False)
for i in range(n_frames):
    traj_tensor[i] -= cog[i]

In [11]:
import torch_align
weights = np.random.random(n_frames)
weights = np.ones(n_frames)
weights /= np.sum(weights)
traj = torch_align.torch_remove_center_of_geometry(traj)
avg, traj_aligned, var = torch_align.torch_align_uniform_weighted(traj, weights, dtype=torch.float32, device=torch.device("cuda:0"), thresh=1e-3, verbose=True)

-16348837.959430242
-15531123.18149861
-15497119.07433867
-15492621.025134778
-15491950.004127992
-15491848.498388756
-15491833.744463844
-15491831.920243865
-15491832.072908204
-15491832.512128977
-15491833.017908301
-15491833.51090449
-15491833.97587127
-15491834.447599491
-15491834.836695466
-15491835.324206468
-15491835.910154177
-15491836.523675513
-15491837.047464041
-15491837.59126326
-15491838.101383595


KeyboardInterrupt: 

In [10]:
avg, traj_aligned, var = torch_align.torch_align_uniform(traj, dtype=torch.float32, device=torch.device("cuda:0"), thresh=1e-3, verbose=True)

-834.1488922383661
-807.3570153494501
-806.242894933703
-806.0955213000307
-806.0735394438618
-806.0702125201793
-806.0697255455059


In [13]:
print(np.mean(avg,axis=0))

[ 6.4134599e-07 -3.3378601e-08  7.7724457e-07]


In [5]:
# 500 atoms
avg, traj_tensor, precision, lpdet = torch_align_weighted(traj, stride=1000, dtype=torch.float32, thresh=1e-1,verbose=True)

1615.269047981963
1616.469061621868
1617.0964701249018
1617.6490082505707
1618.254407529615
1618.960147786537
1619.7835008427235
1620.714753004388
1621.7119772963201
1622.755573077188
1623.9491541220204
1625.5445074923439
1627.0878908232316
1628.0785293317797
1628.6879712761076
1629.1253120625206
1629.4897005570117
1629.8223896401505
1630.1414698472288
1630.4548632248109
1630.7656392472777
1631.0746687479846
1631.382013539995
1631.6881812555848
1631.9938572829417
1632.2996118800368
1632.6055486717657
1632.9110507884225
1633.214721852412
1633.5146075103469
1633.8084145206226
1634.0937829146653
1634.3685528901647
1634.6308439924503
1634.879164249419
1635.11244170654
1635.330032765024
1635.531503272703
1635.7169243629646
1635.8863803629667
1636.0405299717108
1636.180033947446
1636.3057394363816
1636.4186176780406
1636.5196821815227
1636.6100136994635
Total elapsed time: 113934.992
Time to send data: 1902.282 1.67
cmat time: 314.023 0.276
svd time: 4160.431 3.652
det time 65.977 0.058
fini

In [5]:
# 500 atoms
avg, traj_tensor, precision, lpdet = torch_align_weighted(traj, stride=1000, dtype=torch.float64, thresh=1e-1,verbose=True)

[1615.26906394]
[1616.46908269]
[1617.09649695]
[1617.64902408]
[1618.25443667]
[1618.96021042]
[1619.7835538]
[1620.71484196]
[1621.71209721]
[1622.75570667]
[1623.94930442]
[1625.54466987]
[1627.08812959]
[1628.07885124]
[1628.68839091]
[1629.12580391]
[1629.49025011]
[1629.82302365]
[1630.14213479]
[1630.45554422]
[1630.7663164]
[1631.07527688]
[1631.38265599]
[1631.68887914]
[1631.99461957]
[1632.30045507]
[1632.60647828]
[1632.9120592]
[1633.21581911]
[1633.51578853]
[1633.80967076]
[1634.09511871]
[1634.36995633]
[1634.63231864]
[1634.88071863]
[1635.1140647]
[1635.33165029]
[1635.53312969]
[1635.71848556]
[1635.88799002]
[1636.04215961]
[1636.18170571]
[1636.3074832]
[1636.42044101]
[1636.52157775]
[1636.61190464]
Total elapsed time: 130677.414
Time to send data: 1907.239 1.46
cmat time: 2730.575 2.09
svd time: 18103.582 13.854
det time 65.231 0.05
finish rot time: 43.592 0.033
rot time: 1264.62 0.968
covar time: 81896.704 62.671
pinv time 3282.186 2.512
log_lik time 20966.368 1

In [5]:
# 300 atoms
avg, traj_tensor, precision, lpdet = torch_align_weighted(traj, stride=1000, dtype=torch.float64, thresh=1e-1)

Total elapsed time: 56076.418
Time to send data: 1815.451 3.237
cmat time: 1403.715 2.503
svd time: 15490.049 27.623
det time 53.721 0.096
finish rot time: 36.798 0.066
rot time: 637.885 1.138
covar time: 26631.44 47.491
pinv time 1558.666 2.78
log_lik time 8239.968 14.694


In [28]:
avg, traj_tensor, precision, lpdet = torch_align_weighted(traj, stride=2000, dtype=torch.float64, thresh=1e-1)

Total elapsed time: 19038.641
Time to send data: 37.006 0.194
cmat time: 434.964 2.285
svd time: 4754.733 24.974
det time 34.481 0.181
finish rot time: 251.87 1.323
rot time: 1244.773 6.538
covar time: 8139.555 42.753
pinv time 1545.31 8.117
log_lik time 2524.889 13.262


In [25]:
avg, traj_tensor, precision, lpdet = torch_align_weighted(traj, stride=1000, dtype=torch.float64, thresh=1e-1,verbose=True)

809.5294473324997
810.6901828632483
811.2408730983344
811.7088432690385
812.2422930512154
812.936115159636
813.8240392197268
814.9028413027304
816.08243905417
817.3270442116196
818.8241387158067
820.4750231593073
821.6887257448295
822.4901105566704
823.0669065180793
823.5247873013549
823.9149732073458
824.2652522756376
824.5919397604284
824.9034085138934
825.201617757907
825.4845406743818
825.7496713545357
825.9966205748309
826.2272974312441
826.4445131784191
826.6506877438136
826.8473608382009
827.0352894854583
827.2147357909146
827.385718723962
827.5481697340275
827.7020096422756
827.8471811043071
827.9836633400139
828.1114835057755
828.2307291134119
828.3415598131147
828.4442145380855
828.5390106770772
Total elapsed time: 27120.996
Time to send data: 37.044 0.137
cmat time: 435.321 1.605
svd time: 4755.049 17.533
det time 34.573 0.127
finish rot time: 253.015 0.933
rot time: 1249.983 4.609
covar time: 8156.458 30.074
pinv time 1550.682 5.718
log_lik time 10568.907 38.969


In [27]:
%timeit avg, traj_tensor, precision, lpdet = torch_align_weighted(traj, stride=1000, dtype=torch.float64, thresh=1e-1)

Total elapsed time: 19055.691
Time to send data: 37.094 0.195
cmat time: 436.186 2.289
svd time: 4749.793 24.926
det time 34.553 0.181
finish rot time: 252.093 1.323
rot time: 1244.751 6.532
covar time: 8157.424 42.808
pinv time 1548.319 8.125
log_lik time 2524.414 13.248
Total elapsed time: 19028.754
Time to send data: 36.6 0.192
cmat time: 428.482 2.252
svd time: 4732.062 24.868
det time 33.896 0.178
finish rot time: 251.968 1.324
rot time: 1245.18 6.544
covar time: 8161.769 42.892
pinv time 1543.372 8.111
log_lik time 2524.688 13.268
Total elapsed time: 19028.189
Time to send data: 36.649 0.193
cmat time: 428.461 2.252
svd time: 4731.741 24.867
det time 33.783 0.178
finish rot time: 251.933 1.324
rot time: 1245.099 6.543
covar time: 8161.897 42.894
pinv time 1543.338 8.111
log_lik time 2524.633 13.268
Total elapsed time: 19029.225
Time to send data: 36.588 0.192
cmat time: 428.455 2.252
svd time: 4731.765 24.866
det time 33.888 0.178
finish rot time: 251.947 1.324
rot time: 1245.023

In [None]:
%timeit avg, aligned_pos = traj_iterative_average(traj)

In [31]:
%timeit avg, traj_tensor = torch_align_weighted(traj, dtype=torch.float64)

6.23 s ± 1.52 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [28]:
avg, traj_tensor = torch_align(traj)

-2427.3916
-2256.5024
-2253.7893
-2253.718
-2253.7158
-2253.7158


In [21]:
%timeit avg, traj_tensor = torch_align(traj)

1.74 s ± 1.86 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [28]:
c_cpu = np.dot(traj[1].T,traj[0])
print(c_cpu)
u_cpu, s_cpu, v_cpu = np.linalg.svd(c_cpu)
if np.linalg.det(u_cpu) * np.linalg.det(v_cpu) < 0.0:
        #u_cpu[:, -1] = -u_cpu[:, -1]
        u_cpu = np.dot(u_cpu,np.diag([1.0,1.0,-1.0]))
rotation = np.dot(u_cpu, v_cpu) 
print(rotation)

[[  213.33044541   431.88569363 -1170.92205131]
 [  119.93508995    35.12541585   -63.43951665]
 [-1557.06528901  -969.14976134  2794.62486927]]
[[-0.6907233   0.11074025 -0.71458934]
 [ 0.15082826 -0.94440545 -0.29214582]
 [-0.70721436 -0.30957219  0.63562009]]


In [37]:
rot_mat_t = torch.transpose(rot_mat,1,2)

In [38]:
print(rot_mat_t.shape)

torch.Size([15261, 3, 3])


In [40]:
prod_dets = torch.linalg.det(u)*torch.linalg.det(v)
print(rot_mat_t[1],prod_dets[1])

tensor([[-0.6907,  0.1107, -0.7146],
        [ 0.1508, -0.9444, -0.2921],
        [-0.7072, -0.3096,  0.6356]], device='cuda:0', dtype=torch.float64) tensor(1.0000, device='cuda:0', dtype=torch.float64)


In [66]:
print(u_cpu)
print(v_cpu)

[[-0.34221098  0.92233097  0.17943588]
 [-0.03270435 -0.20254182  0.97872736]
 [ 0.93905382  0.32906291  0.09947627]]
[[-0.43267199 -0.29771531  0.85097623]
 [-0.90034254  0.19155203 -0.39075713]
 [ 0.04667185  0.93523976  0.35092495]]


In [18]:
print(u[1])

tensor([[-0.4327,  0.9003, -0.0467],
        [ 0.2977,  0.1916,  0.9352],
        [-0.8510, -0.3908,  0.3509]], device='cuda:0', dtype=torch.float64)


In [19]:
print(v[1])

tensor([[ 0.3422,  0.0327, -0.9391],
        [ 0.9223, -0.2025,  0.3291],
        [ 0.1794,  0.9787,  0.0995]], device='cuda:0', dtype=torch.float64)


In [20]:
print(torch.matmul(u[1],v[1]))

tensor([[ 0.6740, -0.2422,  0.6979],
        [ 0.4464,  0.8863, -0.1235],
        [-0.5887,  0.3948,  0.7054]], device='cuda:0', dtype=torch.float64)


# testing svd alternatives

In [38]:
dtype = torch.float32
device = "cuda:0"
thresh = 1e-1
# meta data
n_frames = traj.shape[0]
n_atoms = traj.shape[1]
    
# pass trajectory to device
traj_tensor = torch.tensor(traj,dtype=dtype,device=device)

# initialize with average as the first frame (arbitrary choice)
avg = traj_tensor[0]
    
delta_log_lik = thresh+10
old_log_lik = 0
while (delta_log_lik > thresh):
    # compute correlation matrices using batched matmul
    c_mats = torch.matmul(avg.T,traj_tensor)
    # perfrom SVD of c_mats using batched SVD
    #u, s, v = torch.linalg.svd(c_mats)
    u, s, v = svd(c_mats)
    # ensure true rotation by correcting sign of determinant 
    prod_dets = torch.linalg.det(u)*torch.linalg.det(v)
    u[:,0,-1] *= prod_dets
    u[:,1,-1] *= prod_dets
    u[:,2,-1] *= prod_dets
    #rot_mat = torch.transpose(torch.matmul(u,v),1,2)
    rot_mat = torch.transpose(torch.matmul(u,torch.transpose(v,1,2)),1,2)
    # do rotation
    traj_tensor = torch.matmul(traj_tensor,rot_mat)
    # compute new average
    avg = torch.mean(traj_tensor,0,False)
    # compute log likelihood
    disp = traj_tensor - avg
    covar = torch.matmul(torch.transpose(disp,1,2),disp)
    var = covar.diagonal(offset=0, dim1=-1, dim2=-2).sum(-1)
    var = torch.sum(var).cpu().numpy()
    log_lik = np.copy(var)
    # finish variance
    var /= (n_frames-1)*3*(n_atoms-1)
    log_lik /= var
    log_lik +=  n_frames * 3 * (n_atoms-1) * np.log(var)
    log_lik /= -2*n_frames
    delta_log_lik = abs(log_lik - old_log_lik)
    print(log_lik)
    old_log_lik = log_lik

-421.31265
-366.2488
-365.7177
-365.71213


In [35]:
u_v02, s_v02, v_v02 = svd(c_mats)

In [36]:
print(u[0])
print(u_v02[0])

tensor([[-0.4821,  0.8731, -0.0729],
        [-0.2263, -0.0438,  0.9731],
        [ 0.8464,  0.4856,  0.2187]], device='cuda:0')
tensor([[-0.4821,  0.8731, -0.0729],
        [-0.2263, -0.0438,  0.9731],
        [ 0.8464,  0.4856,  0.2187]], device='cuda:0')


In [37]:
print(v[0])
print(v_v02[0])

tensor([[-0.4833, -0.2262,  0.8457],
        [ 0.8722, -0.0410,  0.4874],
        [-0.0756,  0.9732,  0.2171]], device='cuda:0')
tensor([[-0.4833,  0.8722, -0.0756],
        [-0.2262, -0.0410,  0.9732],
        [ 0.8457,  0.4874,  0.2171]], device='cuda:0')


## pytorch

In [29]:
def torch_align(traj, dtype=torch.float32, device=torch.device("cuda:0"), thresh=1e-3):
    # meta data
    n_frames = traj.shape[0]
    n_atoms = traj.shape[1]
    
    # pass trajectory to device
    traj_tensor = torch.tensor(traj,dtype=dtype,device=device)

    # initialize with average as the first frame (arbitrary choice)
    avg = traj_tensor[0]
    
    delta_log_lik = thresh+10
    old_log_lik = 0
    while (delta_log_lik > thresh):
        # compute correlation matrices using batched matmul
        c_mats = torch.matmul(avg.T,traj_tensor)
        # perfrom SVD of c_mats using batched SVD
        u, s, v = torch.linalg.svd(c_mats)
        # ensure true rotation by correcting sign of determinant 
        prod_dets = torch.linalg.det(u)*torch.linalg.det(v)
        u[:,0,-1] *= prod_dets
        u[:,1,-1] *= prod_dets
        u[:,2,-1] *= prod_dets
        rot_mat = torch.transpose(torch.matmul(u,v),1,2)
        # do rotation
        traj_tensor = torch.matmul(traj_tensor,rot_mat)
        # compute new average
        new_avg = torch.mean(traj_tensor,0,False)
        # compute log likelihood
        disp = traj_tensor - new_avg
        covar = torch.matmul(torch.transpose(disp,1,2),disp)
        var = covar.diagonal(offset=0, dim1=-1, dim2=-2).sum(-1)
        var = torch.sum(var).cpu().numpy()
        log_lik = np.copy(var)
        # finish variance
        var /= (n_frames-1)*3*(n_atoms-1)
        log_lik /= var
        log_lik +=  n_frames * 3 * (n_atoms-1) * np.log(var)
        log_lik /= -2*n_frames
        delta_log_lik = abs(log_lik - old_log_lik)
        #print(log_lik)
        old_log_lik = log_lik
        avg = new_avg
    torch.cuda.empty_cache()
    return avg, traj_tensor

In [9]:
def torch_pseudo_inv(sigma, dtype=torch.float32, device=torch.device("cuda:0"),EigenValueThresh=1e-10):
    N = sigma.shape[0]
    e, v = torch.linalg.eigh(sigma)
    pinv = torch.tensor(np.zeros(sigma.shape),dtype=dtype,device=device)
    lpdet = 0.0
    for i in range(N):
        if (e[i] > EigenValueThresh):
            lpdet += torch.log(e[i])
            pinv += 1.0/e[i]*torch.outer(v[:,i],v[:,i])
    return pinv, lpdet

def torch_weighted_log_lik(disp, precision, lpdet):
    # meta data
    n_frames = disp.shape[0]
    # compute log Likelihood for all points
    log_lik = torch.trace(torch.sum(torch.matmul(torch.transpose(disp,1,2),torch.matmul(precision,disp)),0))
    log_lik += 3 * n_frames * lpdet
    log_lik /= -2*n_frames
    return log_lik

def torch_align_weighted(traj, stride=1000, dtype=torch.float32, device=torch.device("cuda:0"), thresh=1e-3):
    # meta data
    n_frames = traj.shape[0]
    n_atoms = traj.shape[1]
    
    # pass trajectory to device
    traj_tensor = torch.tensor(traj,dtype=dtype,device=device)
    covar_norm = torch.tensor(1/(3*(n_frames-1)),dtype=dtype,device=device)
    
    # initialize with average as the first frame (arbitrary choice)
    weighted_avg = traj_tensor[0].T
    
    delta_log_lik = thresh+10
    old_log_lik = 0
    while (delta_log_lik > thresh):
        # zero strided arrays
        new_avg = torch.tensor(np.zeros((n_atoms,3)),dtype=dtype,device=device)
        covar = torch.tensor(np.zeros((n_atoms,n_atoms)),dtype=dtype,device=device)
        # compute correlation matrices using batched matmul
        c_mats = torch.matmul(weighted_avg,traj_tensor)
        # perfrom SVD of c_mats using batched SVD
        u, s, v = torch.linalg.svd(c_mats)
        # ensure true rotation by correcting sign of determinant 
        prod_dets = torch.linalg.det(u)*torch.linalg.det(v)
        u[:,0,-1] *= prod_dets
        u[:,1,-1] *= prod_dets
        u[:,2,-1] *= prod_dets
        rot_mat = torch.transpose(torch.matmul(u,v),1,2)
        # do rotation
        traj_tensor = torch.matmul(traj_tensor,rot_mat)
        # compute new average
        new_avg += torch.mean(traj_tensor,0,False)
        # compute covar using strided data
        disp = traj_tensor - new_avg
        for frame in range(0,n_frames,stride):
            covar += torch.sum(torch.matmul(disp[frame:frame+stride],torch.transpose(disp[frame:frame+stride],1,2)),0)
        covar *= covar_norm
        # log likelihood
        precision, lpdet = torch_pseudo_inv(covar,dtype=dtype)
        log_lik = torch_weighted_log_lik(disp, precision, lpdet).cpu().numpy()
        delta_log_lik = abs(log_lik - old_log_lik)
        #print(log_lik)
        old_log_lik = log_lik
        avg = new_avg
        weighted_avg = torch.matmul(avg.T,precision)
    torch.cuda.empty_cache()
    return avg, traj_tensor, precision, lpdet

In [26]:
avg, traj_tensor = torch_align_weighted(traj, dtype=torch.float64)

-31.94574728367148
-30.720572987578883
-29.888099270568546
-28.914862407963614
-27.72567057409137
-26.54803221710204
-25.654201702063702
-24.991456978642876
-24.4954475943521
-24.1445170672913
-23.90600187357529
-23.75157807074262
-23.647350228027406
-23.56893211374824
-23.50250214736581
-23.43982919375617
-23.37518317966918
-23.305486802320896
-23.228855365261232
-23.141506407294727
-23.036776628102132
-22.907388960042773
-22.74601974171156
-22.5608779136198
-22.363899508763144
-22.165069786622073
-21.96974524843784
-21.77976107970529
-21.592999118528272
-21.401560472152056
-21.203461072180914
-20.997227046367776
-20.780061870964737
-20.54963810576432
-20.30510509129724
-20.047269161335453
-19.78106663499441
-19.519129383262342
-19.280147593517402
-19.079221852562686
-18.921209396976884
-18.80518507210785
-18.723735210592245
-18.667916319729347
-18.63016177366346
-18.60481963309143
-18.58788630239982
-18.57660189857164
-18.569092711678515
-18.56409800587304
-18.56077433739349
-18.5585

In [11]:
def torch_align_timing(traj, dtype=torch.float32, thresh=1e-3):
    total_start = torch.cuda.Event(enable_timing=True)
    total_stop = torch.cuda.Event(enable_timing=True)
    pass_data_start = torch.cuda.Event(enable_timing=True)
    pass_data_stop = torch.cuda.Event(enable_timing=True)
    cmats_start = torch.cuda.Event(enable_timing=True)
    cmats_stop = torch.cuda.Event(enable_timing=True)
    cmats_elapsed = 0.0
    svd_start = torch.cuda.Event(enable_timing=True)
    svd_stop = torch.cuda.Event(enable_timing=True)
    svd_elapsed = 0.0
    det_start = torch.cuda.Event(enable_timing=True)
    det_stop = torch.cuda.Event(enable_timing=True)
    det_elapsed = 0.0
    finish_rot_start = torch.cuda.Event(enable_timing=True)
    finish_rot_stop = torch.cuda.Event(enable_timing=True)
    finish_rot_elapsed = 0.0
    rot_start = torch.cuda.Event(enable_timing=True)
    rot_stop = torch.cuda.Event(enable_timing=True)
    rot_elapsed = 0.0
    log_lik_start = torch.cuda.Event(enable_timing=True)
    log_lik_stop = torch.cuda.Event(enable_timing=True)
    log_lik_elapsed = 0.0
    total_start.record()
    # meta data
    n_frames = traj.shape[0]
    n_atoms = traj.shape[1]
    
    # define device and type
    # device = torch.device("cpu")
    device = torch.device("cuda:0") # run on GPU
    
    # pass trajectory to device
    pass_data_start.record()
    traj_tensor = torch.tensor(traj,dtype=dtype,device=device)
    pass_data_stop.record()
    
    # initialize with average as the first frame (arbitrary choice)
    avg = traj_tensor[0]
    
    delta_log_lik = thresh+10
    old_log_lik = 0
    while (delta_log_lik > thresh):
        # compute correlation matrices using batched matmul
        cmats_start.record()
        c_mats = torch.matmul(avg.T,traj_tensor)
        cmats_stop.record()
        torch.cuda.synchronize()
        cmats_elapsed += cmats_start.elapsed_time(cmats_stop)
        # perfrom SVD of c_mats using batched SVD
        svd_start.record()
        u, s, v = torch.linalg.svd(c_mats)
        #u, s, v = svd(c_mats)
        svd_stop.record()
        torch.cuda.synchronize()
        svd_elapsed += svd_start.elapsed_time(svd_stop)
        # determine multiplier matrices based on determinants of u and v
        det_start.record()
        prod_dets = torch.linalg.det(u)*torch.linalg.det(v)
        det_stop.record()
        torch.cuda.synchronize()
        det_elapsed += det_start.elapsed_time(det_stop)
        # finish rot mat
        finish_rot_start.record()
        u[:,0,-1] *= prod_dets
        u[:,1,-1] *= prod_dets
        u[:,2,-1] *= prod_dets
        rot_mat = torch.transpose(torch.matmul(u,v),1,2)
        finish_rot_stop.record()
        torch.cuda.synchronize()
        finish_rot_elapsed += finish_rot_start.elapsed_time(finish_rot_stop)
        # do rotation
        rot_start.record()
        traj_tensor = torch.matmul(traj_tensor,rot_mat)
        rot_stop.record()
        torch.cuda.synchronize()
        rot_elapsed += rot_start.elapsed_time(rot_stop)
        # compute new average
        new_avg = torch.mean(traj_tensor,0,False)
        # compute log likelihood
        log_lik_start.record()
        disp = traj_tensor - new_avg
        covar = torch.matmul(torch.transpose(disp,1,2),disp)
        # batched trace
        var = covar.diagonal(offset=0, dim1=-1, dim2=-2).sum(-1)
        var = torch.sum(var).cpu().numpy()
        log_lik_stop.record()
        torch.cuda.synchronize()
        log_lik_elapsed += log_lik_start.elapsed_time(log_lik_stop)
        log_lik = np.copy(var)
        # finish variance
        var /= (n_frames-1)*3*(n_atoms-1)
        log_lik /= var
        log_lik +=  n_frames * 3 * (n_atoms-1) * np.log(var)
        log_lik /= -2*n_frames
        delta_log_lik = abs(log_lik - old_log_lik)
        print(log_lik)
        old_log_lik = log_lik
        avg = new_avg
    total_stop.record()
    torch.cuda.synchronize()
    total_time = total_start.elapsed_time(total_stop)
    print("Total elapsed time:",np.round(total_time,3))
    print("Time to send data:", np.round(pass_data_start.elapsed_time(pass_data_stop),3), np.round(pass_data_start.elapsed_time(pass_data_stop)/total_time*100,3))
    print("cmat time:", np.round(cmats_elapsed,3), np.round(cmats_elapsed/total_time*100,3))
    print("svd time:", np.round(svd_elapsed,3), np.round(svd_elapsed/total_time*100,3))
    print("det time", np.round(det_elapsed,3), np.round(det_elapsed/total_time*100,3))
    print("finish rot time:", np.round(finish_rot_elapsed,3), np.round(finish_rot_elapsed/total_time*100,3))
    print("rot time:", np.round(rot_elapsed,3), np.round(rot_elapsed/total_time*100,3))
    print("log_lik time", np.round(log_lik_elapsed,3), np.round(log_lik_elapsed/total_time*100,3))
    torch.cuda.empty_cache()
    return avg, traj_tensor

In [6]:
avg, traj_tensor = torch_align_timing(traj,dtype=torch.float32)

-1017.97107
-896.409
-895.2302
-895.21765
-895.2174
Total elapsed time: 1163.313
Time to send data: 125.895 10.822
cmat time: 23.603 2.029
svd time: 849.246 73.002
det time 8.412 0.723
finish rot time: 4.975 0.428
rot time: 58.765 5.052
log_lik time 86.678 7.451


In [12]:
avg, traj_tensor = torch_align_timing(traj,dtype=torch.float64)

-1017.971000791841
-896.4089968251327
-895.2301694209661
-895.2175451191622
-895.2173761191052
Total elapsed time: 5112.051
Time to send data: 140.011 2.739
cmat time: 203.255 3.976
svd time: 4349.359 85.081
det time 8.891 0.174
finish rot time: 7.124 0.139
rot time: 94.229 1.843
log_lik time 298.575 5.841


In [33]:
avg, traj_tensor = torch_align_timing(traj,dtype=torch.float64)

-1017.9710007918412
-896.4089968251325
-895.2301694209661
-895.2175451191622
-895.2173761191052
Total elapsed time: 5104.857421875
Time to send data: 141.21116638183594 2.766
cmat time: 203.85001754760742 3.993
svd time: 4363.8043212890625 85.483
det time 9.19593596458435 0.18
finish rot time: 13.390719890594482 0.262
rot time: 94.53641510009766 1.852
log_lik time 267.14771270751953 5.233


In [28]:
torch.cuda.get_device_name()

'NVIDIA GeForce RTX 2080 SUPER'

In [6]:
traj_tensor = torch.tensor(traj,device=device)

In [7]:
n_frames = traj.shape[0]

In [8]:
diag = torch.tensor(np.ones((n_frames,3)),dtype=dtype,device=device)

In [9]:
print(diag.shape)

torch.Size([15261, 3])


In [10]:
print(traj_tensor.shape)

torch.Size([15261, 33, 3])


In [57]:
c_mats = torch.matmul(traj_tensor[0].T,traj_tensor)
print(c_mats.shape)

torch.Size([15261, 3, 3])


In [58]:
print(c_mats[1])

tensor([[  213.3304,   119.9351, -1557.0653],
        [  431.8857,    35.1254,  -969.1498],
        [-1170.9221,   -63.4395,  2794.6249]], device='cuda:0',
       dtype=torch.float64)


In [59]:
u, s, v = torch.linalg.svd(c_mats)
print(u.shape)

torch.Size([15261, 3, 3])


In [72]:
rot_mat = torch.matmul(u,v)
print(rot_mat[1])

tensor([[-0.6740,  0.2422, -0.6979],
        [ 0.4464,  0.8863, -0.1235],
        [-0.5887,  0.3948,  0.7054]], device='cuda:0', dtype=torch.float64)


In [73]:
det_u = torch.linalg.det(u)
det_v = torch.linalg.det(v)

In [74]:
print(prod_dets[1])

tensor(-1.0000, device='cuda:0', dtype=torch.float64)


In [78]:
prod_dets = det_u*det_v
diag[:,0] = prod_dets
mult = torch.diag_embed(diag)
u = torch.matmul(mult,u)
rot_mat = torch.matmul(u,v)

In [79]:
print(rot_mat[1])

tensor([[ 0.6740, -0.2422,  0.6979],
        [ 0.4464,  0.8863, -0.1235],
        [ 0.5887, -0.3948, -0.7054]], device='cuda:0', dtype=torch.float64)


In [None]:
traj_tensor = torch.matmul()

In [48]:
def torch_max_likelihood_align_uniform(traj_tensor,device,thresh=1E-3,dtype=torch.float64):
    # trajectory metadata
    n_frames = traj_tensor.shape[0]
    n_atoms = traj_tensor.shape[1]
    # start be removing COG translation from every frame
    for ts in range(n_frames):
        mu = torch.tensor([0,0,0],device=device,dtype=dtype)
        for atom in range(n_atoms):
            mu += traj_tensor[ts,atom]
        mu /= n_atoms
        traj_tensor[ts] -= mu
    # Initialize average as first frame
    avg = torch.clone(traj_tensor[0])
    log_lik = torch_uniform_kabsch_log_lik(traj_tensor,avg)
    # perform iterative alignment and average to converge log likelihood
    log_lik_diff = 10
    count = 1
    while log_lik_diff > thresh:
        # rezero new average
        new_avg = torch.zeros_like(avg,device=device,dtype=dtype)
        # align trajectory to average and accumulate new average
        for ts in range(n_frames):
            traj_tensor[ts] = torch_kabsch_rotate(traj_tensor[ts], avg)
            new_avg += traj_tensor[ts]
        # finish average
        new_avg /= n_frames
        # compute log likelihood
        new_log_lik = torch_uniform_kabsch_log_lik(traj_tensor,avg)
        log_lik_diff = torch.abs(new_log_lik-log_lik)
        log_lik = new_log_lik
        # copy new average
        avg = torch.clone(new_avg)
        count += 1
    return avg

def torch_kabsch_rotate(mobile, target):
    correlation_matrix = torch.matmul(torch.transpose(mobile,0,1), target)
    V, S, W_tr = torch.linalg.svd(correlation_matrix)
    if torch.linalg.det(V) * torch.linalg.det(W_tr) < 0.0:
        V[:, -1] = -V[:, -1]
    rotation = torch.matmul(V, W_tr)
    mobile_prime = torch.matmul(mobile,rotation)
    return mobile_prime


def torch_uniform_kabsch_log_lik(x, mu):
    # meta data
    n_frames = x.shape[0]
    n_atoms = x.shape[1]
    # compute log Likelihood for all points
    log_lik = 0.0
    sampleVar = 0.0
    for i in range(n_frames):
        for j in range(3):
            disp = x[i,:,j] - mu[:,j]
            temp = torch.matmul(disp,disp)
            sampleVar += temp
            log_lik += temp
    # finish variance
    sampleVar /= (n_frames-1)*3*(n_atoms-1)
    log_lik /= sampleVar
    log_lik +=  n_frames * 3 * (n_atoms-1) * torch.log(sampleVar)
    log_lik *= -0.5
    return log_lik

In [49]:
torch_avg = torch_max_likelihood_align_uniform(traj_tensor,device)

In [33]:
# CPU
%timeit torch_avg = torch_max_likelihood_align_uniform(traj_tensor)

12.6 s ± 269 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [53]:
# GPU
%timeit torch_avg = torch_max_likelihood_align_uniform(traj_tensor,device)

24.1 s ± 180 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [55]:
print(torch_avg)

tensor([[-8.8136e+00,  2.7761e+00,  5.6701e+00],
        [-8.6363e+00,  1.6984e+00,  4.5916e+00],
        [-9.8360e+00,  1.3530e+00,  2.9926e+00],
        [-8.3056e+00,  1.7471e+00,  1.6843e+00],
        [-6.7849e+00, -6.3208e-03,  1.9070e+00],
        [-7.7955e+00, -1.8435e+00,  1.3085e+00],
        [-7.7394e+00, -1.6182e+00, -8.2500e-01],
        [-5.1752e+00, -1.9200e+00, -1.1867e+00],
        [-3.9311e+00, -3.7921e+00,  4.5359e-01],
        [-4.3799e+00, -5.1262e+00,  2.3399e+00],
        [-3.6644e+00, -3.5878e+00,  4.1838e+00],
        [-2.7841e+00, -2.0271e+00,  5.6479e+00],
        [-1.2592e+00, -5.7948e-01,  6.6520e+00],
        [ 1.1829e-01, -2.0474e+00,  7.9320e+00],
        [ 1.2045e+00, -2.8577e+00,  5.9995e+00],
        [ 2.3072e+00, -8.1139e-01,  4.9222e+00],
        [ 3.7223e+00, -2.7553e-01,  6.7896e+00],
        [ 5.1303e+00, -1.9867e+00,  6.2959e+00],
        [ 6.3809e+00, -8.9586e-01,  3.9409e+00],
        [ 7.7158e+00,  5.5063e-01,  2.0799e+00],
        [ 7.9290e+00