In [1]:
import os
import numpy as np
import torch
import torch.autograd.profiler as profiler
from torch_batch_svd import svd
import MDAnalysis as md
from numba import jit
from shapeGMM import _traj_tools as traj_tools
from torch_shapeGMM import ShapeGMMTorch
from shapeGMM import gmm_shapes

In [2]:
# Frame stride: keep every `delta`-th frame of the trajectory.
delta = 100
# read trajectory
data_path = '../../DESRES-Trajectory_pnas2012-2f4k-360K-protein/pnas2012-2f4k-360K-protein/'
# CA atoms only, excluding resid 42 and 76
# (presumably the two terminal residues, given 33 CAs are selected — confirm).
selection = "name CA and not resid 42 76"
#selection = "bynum 5:204"
#selection = "all"
# LOAD DATA
prmtopFileName = os.path.join(data_path, 'pnas2012-2f4k-360K-protein.pdb')
trajFiles = [os.path.join(data_path, files)
             for files in sorted(os.listdir(data_path))
             if files.endswith('.dcd')]
coord = md.Universe(prmtopFileName, trajFiles)
sel = coord.select_atoms(selection)

# Number of frames yielded by trajectory[::delta], i.e. ceil(n_frames / delta).
# The former `n_frames // delta + 1` over-counts by one whenever n_frames is an
# exact multiple of delta, which left a trailing uninitialized np.empty row in `traj`.
n_frames_analyzed = len(range(0, coord.trajectory.n_frames, delta))

print("Number of atoms in trajectory:", coord.atoms.n_atoms)
print("Number of frames in trajectory:",coord.trajectory.n_frames)
print("Number of atoms being analyzed:",sel.n_atoms)
print("Number of frames being analyzed:",n_frames_analyzed)

# traj has shape (frames analyzed, atoms selected, 3). Positions are stored
# uncentered (subtracting sel.center_of_geometry() was left disabled).
traj = np.empty((n_frames_analyzed, sel.n_atoms, 3), dtype=float)
count = 0
for ts in coord.trajectory[::delta]:
    traj[count, :, :] = sel.positions
    count += 1
# Guard against any mismatch that would leave uninitialized rows in traj.
assert count == n_frames_analyzed, "strided frame count does not match allocated rows"

Number of atoms in trajectory: 577
Number of frames in trajectory: 1526041
Number of atoms being analyzed: 33
Number of frames being analyzed: 15261


In [3]:
# Two-cluster shape-GMM fit using the torch implementation:
# "chunk" initialization, kronecker covariance, float64 tensors.
# verbose=True prints the per-iteration weights / log-likelihood shown below.
sgmm = ShapeGMMTorch(
    n_clusters=2,
    verbose=True,
    init_cluster_method="chunk",
    covar_type='kronecker',
    dtype=torch.float64,
)
sgmm.fit(traj)

Number of frames being analyzed: 15261
Number of particles being analyzed: 33
Number of dimensions (must be 3): 3
Initializing clustering using method: chunk
Weights from initial clusters in fit: [0.50003276 0.49996724]
1 [0.736 0.264] -42.428
2 [0.702 0.298] -9.915
3 [0.677 0.323] 3.745
4 [0.662 0.338] 5.054
5 [0.657 0.343] 5.673
6 [0.655 0.345] 5.811
7 [0.652 0.348] 8.807
8 [0.65 0.35] 8.893
9 [0.649 0.351] 8.952
10 [0.648 0.352] 8.991
11 [0.647 0.353] 9.019
12 [0.647 0.353] 9.039
13 [0.646 0.354] 9.054
14 [0.645 0.355] 9.066
15 [0.644 0.356] 9.077
16 [0.643 0.357] 9.087
17 [0.641 0.359] 9.101
18 [0.638 0.362] 9.126
19 [0.635 0.365] 9.173
20 [0.631 0.369] 9.231
21 [0.627 0.373] 9.329
22 [0.623 0.377] 9.416
23 [0.621 0.379] 9.463
24 [0.62 0.38] 9.485
25 [0.619 0.381] 9.498
26 [0.618 0.382] 9.505
27 [0.618 0.382] 9.51
28 [0.617 0.383] 9.515
29 [0.617 0.383] 9.52
30 [0.616 0.384] 9.523
31 [0.615 0.385] 9.526
32 [0.615 0.385] 9.528
33 [0.615 0.385] 9.53
34 [0.615 0.385] 9.531
35 [0.614 0

In [12]:
# Reference two-cluster fit using the original (non-torch) shapeGMM code,
# with "uniform" initialization and the fit_weighted entry point.
# NOTE(review): log_thresh=15.0 is close to 1e-3 * the number of analyzed frames
# (15261) — presumably a frame-scaled threshold; confirm the intent.
sgmm_old = gmm_shapes.ShapeGMM(
    n_clusters=2,
    verbose=True,
    log_thresh=15.0,
    init_cluster_method="uniform",
)
fit_traj = sgmm_old.fit_weighted(traj)

Number of frames being analyzed: 15261
Number of particles being analyzed: 33
Number of dimensions (must be 3): 3
Initializing clustering using method: uniform
Weights from initial clusters in fit_weighted: [0.50003276 0.49996724]
0 [0.65611611 0.34388389] -542683.368109674
1 [0.68559448 0.31440552] -124849.74323508382
2 [0.66152117 0.33847883] 122326.82415978817
3 [0.65328896 0.34671104] 137503.80063696523
4 [0.65059016 0.34940984] 138751.30609375035
5 [0.64901371 0.35098629] 139011.0170284908
6 [0.64785434 0.35214566] 139080.80548396613
7 [0.64698078 0.35301922] 139119.60963216814
8 [0.64617509 0.35382491] 139145.12355020083
9 [0.64539883 0.35460117] 139170.5047556147
10 [0.64461704 0.35538296] 139198.0954637449
11 [0.64382073 0.35617927] 139229.66480666792
12 [0.64296243 0.35703757] 139265.690556446
13 [0.64161703 0.35838297] 139320.50172731993
14 [0.6395769 0.3604231] 139475.98201389724
15 [0.6366522 0.3633478] 139870.74152640378
16 [0.63360488 0.36639512] 140579.00799935538
17 [0.

In [11]:
# Scratch check: 1e-3 times the number of analyzed frames, derived from traj instead
# of a hard-coded frame count so it stays correct if delta or the selection change.
# (Presumably used to choose log_thresh — confirm.)
1e-3 * traj.shape[0]

15.261000000000001

In [4]:
def torch_fit(data=None):
    """Fit a 2-cluster, kronecker-covariance ShapeGMMTorch model (the %timeit target).

    Parameters
    ----------
    data : array-like, optional
        Trajectory array passed to ``ShapeGMMTorch.fit``. When omitted, the
        notebook-level ``traj`` is used (looked up at call time), which matches
        the original zero-argument behavior.

    Returns
    -------
    ShapeGMMTorch
        The fitted model, so a benchmarked run can also be inspected.
    """
    if data is None:
        data = traj
    sgmm = ShapeGMMTorch(n_clusters=2, verbose=False, init_cluster_method="chunk",
                         covar_type='kronecker', dtype=torch.float64)
    sgmm.fit(data)
    return sgmm

In [5]:
def cpu_fit(data=None):
    """Fit a 2-cluster model with the original shapeGMM code (the CPU %timeit target).

    Parameters
    ----------
    data : array-like, optional
        Trajectory array passed to ``ShapeGMM.fit_weighted``. When omitted, the
        notebook-level ``traj`` is used (looked up at call time), which matches
        the original zero-argument behavior.

    Returns
    -------
    gmm_shapes.ShapeGMM
        The fitted model, so a benchmarked run can also be inspected.
    """
    if data is None:
        data = traj
    sgmm_old = gmm_shapes.ShapeGMM(n_clusters=2, log_thresh=15.0, verbose=False,
                                   init_cluster_method="uniform")
    # The return value of fit_weighted was previously bound to an unused local.
    sgmm_old.fit_weighted(data)
    return sgmm_old

In [6]:
# Benchmark the ShapeGMMTorch fit wrapped by torch_fit().
# NOTE(review): the original note here only said "other svd" — presumably this run
# used an alternative SVD backend (torch_batch_svd's `svd` is imported at the top);
# record which SVD implementation these timings correspond to.
# NOTE(review): each run prints its own timing breakdown even though torch_fit passes
# verbose=False, so the output is long and the %timeit summary line is cut off.
%timeit torch_fit()

Total elapsed time: 13236.952
Time to send data: 1391.375 10.511
Expectation time: 4430.921 33.474
Gamma time: 6.014 0.045
Maximization time: 7254.096 54.802
Total elapsed time: 13304.563
Time to send data: 1374.199 10.329
Expectation time: 4468.481 33.586
Gamma time: 6.078 0.046
Maximization time: 7302.383 54.886
Total elapsed time: 13309.149
Time to send data: 1382.954 10.391
Expectation time: 4464.14 33.542
Gamma time: 6.085 0.046
Maximization time: 7302.38 54.867
Total elapsed time: 13314.797
Time to send data: 1382.644 10.384
Expectation time: 4465.836 33.54
Gamma time: 6.145 0.046
Maximization time: 7306.168 54.873
Total elapsed time: 13319.37
Time to send data: 1384.287 10.393
Expectation time: 4465.302 33.525
Gamma time: 6.092 0.046
Maximization time: 7309.332 54.877
Total elapsed time: 13323.999
Time to send data: 1382.672 10.377
Expectation time: 4464.677 33.509
Gamma time: 6.095 0.046
Maximization time: 7302.837 54.81
Total elapsed time: 13400.501
Time to send data: 1383.374

In [7]:
# Benchmark the reference (non-torch) shapeGMM fit wrapped by cpu_fit().
%timeit cpu_fit()

2min 18s ± 1.6 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
# NOTE(review): this repeats the previous cpu_fit() benchmark and shares its execution
# count (In [7]). The two runs report different times (2min 18s vs 2min 6s); keep one,
# or label which configuration each run measured.
%timeit cpu_fit()

2min 6s ± 380 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
