In [1]:
import pickle as pkl
from cgnet.molecule import CGMolecule
import mdtraj as md
import torch
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
from importlib import reload
from mlcg.nn.prior import _Prior, Harmonic, HarmonicBonds, HarmonicAngles, Dihedral
import networkx as nx
from networkx.algorithms.traversal.depth_first_search import *
import matplotlib
from mlcg.geometry.statistics import _symmetrise_angle_interaction, _symmetrise_distance_interaction
from mlcg.data import *
from torch_geometric.data.collate import collate
from mlcg.geometry import *
from copy import deepcopy
from mlcg.utils import tensor2tuple
from scipy.integrate import trapezoid



In [2]:
# Integer bead type for each of the 20 residue names (1-20), followed by
# the backbone atom names N, CA, C, O (21-24). Ids are consecutive and
# follow the order of the tuple below.
embedding_map = {
    bead_name: bead_id
    for bead_id, bead_name in enumerate(
        (
            'ALA', 'CYS', 'ASP', 'GLU', 'PHE',
            'GLY', 'HIS', 'ILE', 'LYS', 'LEU',
            'MET', 'ASN', 'PRO', 'GLN', 'ARG',
            'SER', 'THR', 'VAL', 'TRP', 'TYR',
            'N', 'CA', 'C', 'O',
        ),
        start=1,
    )
}

# Reverse lookup: integer bead type -> name.
embed2res = dict((bead_id, bead_name) for bead_name, bead_id in embedding_map.items())

In [None]:
# Load the peptide metadata dictionary. Later cells index it by an integer
# peptide number and read the 'embeddings', 'resmap', 'residues' and
# 'cg_atoms' entries of the selected peptide.
# NOTE(review): pickle.load can execute arbitrary code; only load files
# from a trusted source.
# The context manager closes the file handle once loading is done.
with open('/net/data02/nickc/prior_force_generators/peptide_cg_meta_dictionary_no_physical_GLY_repul_fix.pkl', 'rb') as pkl_file:
    peptide_dictionary = pkl.load(pkl_file)

In [None]:
# Inspect the raw metadata entry for peptide 2 (this cell only prints; no CG molecule is built here)
print(peptide_dictionary[2])

In [None]:
# Pick the peptide entry to work on and pull out its bead embeddings.
num = 2
entry = peptide_dictionary[num]
embeddings = entry['embeddings']
print(entry['resmap'])

# One 1-based residue number per CG bead: GLY residues contribute 4 beads,
# every other residue contributes 5.
# NOTE(review): presumably the fifth bead is the side-chain bead that GLY
# lacks -- confirm against the CG mapping.
resseq = []
for residue_number, residue in enumerate(entry['residues'], start=1):
    beads_in_residue = 4 if residue.name == 'GLY' else 5
    resseq.extend([residue_number] * beads_in_residue)

# Store the CG molecule back into the same dictionary entry.
entry['cg_molecule'] = CGMolecule(
    names=[atom.name for atom in entry['cg_atoms']],
    resseq=resseq,
    resmap=entry['resmap'],
    bonds='standard',
)

In [None]:
def get_n_pairs(connectivity_matrix, n=3, symmetrise=True):
    """Find atom pairs joined by a shortest path containing exactly ``n`` atoms.

    A graph is built from ``connectivity_matrix`` with networkx. For every
    atom, the shortest paths up to cutoff ``n`` are computed, and the end
    point of each path made of exactly ``n`` nodes (i.e. ``n - 1`` bonds)
    is paired with the starting atom.

    Parameters
    ----------
    connectivity_matrix:
        Tensor exposing ``.numpy()``; passed to ``nx.Graph`` as an
        adjacency matrix.
    n:
        Number of atoms in the path, both end points included.
    symmetrise:
        If truthy, the pairs are passed through
        ``_symmetrise_distance_interaction`` before being returned.

    Returns
    -------
    torch.Tensor
        Integer tensor with two rows; each column is a pair with the
        smaller index in the first row.
    """
    graph = nx.Graph(connectivity_matrix.numpy())
    first_atoms, second_atoms = [], []
    for atom in graph.nodes:
        shortest_paths = nx.single_source_dijkstra_path(graph, atom, cutoff=n)
        for path in shortest_paths.values():
            if len(path) != n:
                continue
            low, high = sorted((atom, path[-1]))
            first_atoms.append(low)
            second_atoms.append(high)

    # An explicit dtype keeps the result integral even when no pair is found.
    pairs = torch.tensor((first_atoms, second_atoms), dtype=torch.long)
    if symmetrise:
        pairs = _symmetrise_distance_interaction(pairs)
    return pairs


def get_n_paths(connectivity_matrix, n=3, symmetrise=True):
    """Collect shortest paths that contain exactly ``n`` atoms.

    A graph is built from ``connectivity_matrix`` with networkx. For every
    atom, ``nx.single_source_dijkstra_path`` is run with cutoff ``n``, and
    every returned path made of exactly ``n`` nodes (``n - 1`` bonds) is
    kept. Only one shortest path per (source, target) pair is returned by
    networkx, so alternative routes between the same two atoms are not
    enumerated.

    Parameters
    ----------
    connectivity_matrix:
        Tensor exposing ``.numpy()``; passed to ``nx.Graph`` as an
        adjacency matrix.
    n:
        Number of atoms per path (3 for angles, 4 for dihedrals).
    symmetrise:
        If truthy, reversed duplicates are merged with the angle
        (``n == 3``) or dihedral (``n == 4``) symmetrisation helper.

    Returns
    -------
    torch.Tensor
        Integer tensor with ``n`` rows; each column is one path.

    Raises
    ------
    NotImplementedError
        If ``symmetrise`` is truthy and ``n`` is neither 3 nor 4.
    """
    # Use the same truthiness test here and below, so that an unsupported
    # n can never fall through without being symmetrised or converted.
    if symmetrise and n not in (3, 4):
        raise NotImplementedError("Set symmetrise to False for n !=3/4.")

    graph = nx.Graph(connectivity_matrix.numpy())
    final_paths = [[] for _ in range(n)]
    for atom in graph.nodes:
        n_hop_paths = nx.single_source_dijkstra_path(graph, atom, cutoff=n)
        for path in n_hop_paths.values():
            if len(path) != n:
                continue
            for position, sub_atom in enumerate(path):
                final_paths[position].append(sub_atom)

    # An explicit dtype keeps the result integral even when no path is found.
    final_paths = torch.tensor(final_paths, dtype=torch.long)
    if symmetrise:
        if n == 3:
            final_paths = _symmetrise_angle_interaction(final_paths)
        else:
            final_paths = _symmetrise_dihedral_interaction(final_paths)
    return final_paths


def get_connectivity_matrix(topology: "Topology", directed: bool = False) -> torch.Tensor:
    """Produce a full connectivity matrix from the bonded edge list.

    Parameters
    ----------
    topology:
        Object exposing ``bonds`` (two equally long index sequences: first
        and second atom of each bond) and ``n_atoms``.
    directed:
        If True only ``matrix[i, j]`` is set for a bond ``(i, j)``;
        otherwise ``matrix[j, i]`` is set as well.

    Returns
    -------
    torch.Tensor
        ``(n_atoms, n_atoms)`` tensor with 1 where two atoms are bonded
        and 0 elsewhere.

    Raises
    ------
    ValueError
        If the topology holds no bonds.
    """
    if len(topology.bonds[0]) == 0 and len(topology.bonds[1]) == 0:
        raise ValueError("No bonds in the topology.")

    connectivity_matrix = torch.zeros((topology.n_atoms, topology.n_atoms))
    bond_index = torch.as_tensor(np.array(topology.bonds), dtype=torch.long)
    first, second = bond_index[0], bond_index[1]

    # One indexed assignment per direction replaces the per-bond loop.
    connectivity_matrix[first, second] = 1
    if not directed:
        connectivity_matrix[second, first] = 1
    return connectivity_matrix

In [None]:
def _symmetrise_distance_interaction(unique_interaction_types):
    """Distance based interactions are symmetric w.r.t. the direction hence
    the need for only considering interactions (a,b) with a < b.
    """
    mask = unique_interaction_types[0] > unique_interaction_types[1]
    ee = unique_interaction_types[0, mask]
    unique_interaction_types[0, mask] = unique_interaction_types[1, mask]
    unique_interaction_types[1, mask] = ee
    unique_interaction_types = torch.unique(unique_interaction_types, dim=1)
    return unique_interaction_types


def _symmetrise_angle_interaction(unique_interaction_types):
    """For angles defined as::
      2---3
     /
    1
    atom 1 and 3 can be exchanged without changing the angle so the resulting
    interaction is symmetric w.r.t such transformations. Hence the need for only
    considering interactions (a,b,c) with a < c.
    """
    mask = unique_interaction_types[0] > unique_interaction_types[2]
    ee = unique_interaction_types[0, mask]
    unique_interaction_types[0, mask] = unique_interaction_types[2, mask]
    unique_interaction_types[2, mask] = ee
    unique_interaction_types = torch.unique(unique_interaction_types, dim=1)
    return unique_interaction_types


def _symmetrise_dihedral_interaction(unique_interaction_types):
    """For dihedrals defined as::
      2---3---4
     /
    1
    atoms [(1,2,3,4),(4,3,2,1) can be exchanged without changing the dihedral 
    so the resulting interaction is symmetric w.r.t such transformations. 
    Hence the need for only considering interactions (a,b,c,d) with a < d.
    """
    mask = unique_interaction_types[0] > unique_interaction_types[3]
    ee0 = unique_interaction_types[0, mask]
    ee1 = unique_interaction_types[1, mask]
    ee2 = unique_interaction_types[2, mask]
    ee3 = unique_interaction_types[3, mask]
    unique_interaction_types[0, mask] = ee3
    unique_interaction_types[1, mask] = ee2
    unique_interaction_types[2, mask] = ee1
    unique_interaction_types[3, mask] = ee0
    unique_interaction_types = torch.unique(unique_interaction_types, dim=1)
    return unique_interaction_types



# Dispatch table: interaction order (number of atoms involved) -> function
# that canonicalises and de-duplicates the interaction keys of that order.
_symmetrise_map = {
    2: _symmetrise_distance_interaction,
    3: _symmetrise_angle_interaction,
    4: _symmetrise_dihedral_interaction,
}

# Dispatch table: interaction order -> function returning the key with its
# members in reversed order, as a new torch.long tensor.
_flip_map = {
    2: lambda tup: torch.tensor([tup[1], tup[0]], dtype=torch.long),
    3: lambda tup: torch.tensor([tup[2], tup[1], tup[0]], dtype=torch.long),
    4: lambda tup: torch.tensor([tup[3], tup[2], tup[1], tup[0]], dtype=torch.long),
}


def _get_all_unique_keys(unique_types, order):
    """Enumerate every symmetry-distinct interaction key of a given order.

    All ``order``-fold combinations of ``unique_types`` are generated (one
    per column), canonicalised with the symmetrisation function registered
    for ``order`` in ``_symmetrise_map``, and de-duplicated.
    """
    factors = [unique_types] * order
    candidate_keys = torch.cartesian_prod(*factors).t()
    symmetriser = _symmetrise_map[order]
    return torch.unique(symmetriser(candidate_keys), dim=1)


def _get_bin_centers(a, nbins):
    bin_centers = torch.zeros((nbins,), dtype=torch.float64)
    a_min = a.min()
    a_max = a.max()
    delta = (a_max - a_min) / nbins
    bin_centers = (
        a_min
        + 0.5 * delta
        + torch.arange(0, nbins, dtype=torch.float64) * delta
    )
    return bin_centers


def _interaction_type_mask(interaction_types, key):
    """Boolean mask selecting the columns of ``interaction_types`` equal to ``key``."""
    return torch.all(interaction_types == key.reshape(-1, 1), dim=0)


def compute_statistics(
    data: AtomicData,
    target: str,
    beta: float,
    TargetPrior: _Prior = Harmonic,
    nbins: int = 100,
):
    """Histogram a prior feature per interaction type and fit the prior.

    For every symmetry-distinct combination of atom types, the feature
    values (computed by ``TargetPrior.compute_features``) of all matching
    interactions are histogrammed, converted to a free-energy estimate
    ``-log(counts) / beta`` over the non-empty bins, and handed to
    ``TargetPrior.fit_from_potential_estimates``.

    An interaction matches a key when its atom types equal the key read in
    either direction. The keys are symmetrised but the neighbour-list
    entries are not, so comparing against the key alone would silently drop
    every interaction whose atoms are listed in the opposite order.

    Parameters
    ----------
    data:
        Collated data providing ``atom_types``, ``pos`` and
        ``neighbor_list[target]["index_mapping"]``.
    target:
        Name of the neighbour list to analyse.
    beta:
        Inverse thermal energy used to scale the free-energy estimate.
    TargetPrior:
        Prior class providing ``compute_features`` and
        ``fit_from_potential_estimates``.
    nbins:
        Number of histogram bins.

    Returns
    -------
    dict
        Maps each type tuple (and its reversed tuple) to the fitted
        parameters, extended with ``"p"`` / ``"p_bin"`` (normalised
        histogram and its bin centres) and ``"V"`` / ``"V_bin"``
        (free-energy estimate and the centres of the non-empty bins).
        Type combinations without any matching interaction are skipped.
    """
    unique_types = torch.unique(data.atom_types)
    mapping = data.neighbor_list[target]["index_mapping"]
    order = mapping.shape[0]
    unique_keys = _get_all_unique_keys(unique_types, order)

    values = TargetPrior.compute_features(data.pos, mapping)
    interaction_types = torch.vstack(
        [data.atom_types[mapping[ii]] for ii in range(order)]
    )

    statistics = {}
    for unique_key in unique_keys.t():
        # Select the interactions of this type, in either atom ordering.
        selected = _interaction_type_mask(
            interaction_types, unique_key
        ) | _interaction_type_mask(interaction_types, unique_key.flip(0))
        val = values[selected]
        if len(val) == 0:
            # This type combination does not occur in the data.
            continue

        bin_centers = _get_bin_centers(val, nbins)
        hist = torch.histc(val, bins=nbins)

        # Empty bins are excluded: log(0) is undefined.
        populated = hist > 0
        bin_centers_nz = bin_centers[populated]
        ncounts_nz = hist[populated]
        dG_nz = -torch.log(ncounts_nz) / beta
        params = TargetPrior.fit_from_potential_estimates(bin_centers_nz, dG_nz)
        kk = tensor2tuple(unique_key)
        statistics[kk] = params

        statistics[kk]["p"] = hist / trapezoid(
            hist.cpu().numpy(), x=bin_centers.cpu().numpy()
        )
        statistics[kk]["p_bin"] = bin_centers
        statistics[kk]["V"] = dG_nz
        statistics[kk]["V_bin"] = bin_centers_nz

        # Register the same statistics under the reversed key as well.
        kf = tensor2tuple(_flip_map[order](unique_key))
        statistics[kf] = deepcopy(statistics[kk])

    return statistics

In [None]:
def compute_dihedral_statistics(
    data: AtomicData,
    target: str,
    beta: float,
    TargetPrior: _Prior = Dihedral,
    nbins: int = 100,
):
    """Compute per-type statistics with ``Dihedral`` as the default prior.

    This is a thin wrapper around ``compute_statistics``. The previous body
    was a copy of that function, differing only in the default prior and in
    extra debug printing, so the work is delegated to keep a single
    implementation.

    Parameters and return value are those of ``compute_statistics``.
    """
    return compute_statistics(
        data,
        target,
        beta,
        TargetPrior=TargetPrior,
        nbins=nbins,
    )

In [None]:
# Build the topology from the CG molecule and attach the bead types.
topo = Topology()
my_top = topo.from_mdtraj(peptide_dictionary[num]['cg_molecule'].topology)
my_top.types = embeddings

# Derive angles (3-atom paths) and dihedrals (4-atom paths) from the bond graph.
cmat = get_connectivity_matrix(my_top)
print(cmat.shape)
bonded_angles = get_n_paths(cmat, n=3)
bonded_dihedrals = get_n_paths(cmat, n=4, symmetrise=True)

# These two neighbour lists are requested before angles/dihedrals are registered.
full_graph = my_top.neighbor_list("fully connected")
bonds = my_top.neighbor_list("bonds")

# Register every angle (one per column), then request the angle neighbour list.
for angle_column in bonded_angles.t():
    my_top.add_angle(*angle_column.numpy())
angles = my_top.neighbor_list("angles")

# Same for the dihedrals.
for dihedral_column in bonded_dihedrals.t():
    my_top.add_dihedral(*dihedral_column.numpy())
dihedrals = my_top.neighbor_list("dihedrals")


for neighbor_list in (full_graph, bonds, angles, dihedrals):
    print(neighbor_list)

In [None]:
# Show the index mapping of each neighbour list, then their shapes.
for neighbor_list in (full_graph, bonds, angles, dihedrals):
    print(neighbor_list['index_mapping'])

for neighbor_list in (full_graph, bonds, angles, dihedrals):
    print(neighbor_list['index_mapping'].shape)

# Short aliases used by the following cells.
full_map, bonds_map, angles_map, dihedrals_map = (
    neighbor_list['index_mapping']
    for neighbor_list in (full_graph, bonds, angles, dihedrals)
)

In [None]:
# Non-bonded pairs: every pair of the fully connected graph, minus the
# physically bonded pairs and the 1-3 pairs (the two end atoms of an angle).
# Pairs are stored as sorted tuples, so the direction does not matter and
# each membership test is a constant-time set lookup instead of a scan over
# all bonds and angles.
excluded_pairs = {tuple(sorted(bond)) for bond in bonds_map.t().tolist()}
excluded_pairs.update(
    tuple(sorted((angle[0], angle[2]))) for angle in angles_map.t().tolist()
)

kept_pairs = [
    ordered_pair
    for ordered_pair in (tuple(sorted(pair)) for pair in full_map.t().tolist())
    if ordered_pair not in excluded_pairs
]

# reshape(-1, 2) keeps a well-formed (2, 0) result when nothing is left.
non_bonds = torch.tensor(kept_pairs, dtype=torch.long).reshape(-1, 2).t().contiguous()
non_bonds = _symmetrise_distance_interaction(non_bonds)
print(non_bonds.shape)

In [None]:
# Build a networkx graph from the connectivity matrix `cmat`; it is drawn in the next cell.
my_graph = nx.Graph(cmat.numpy())

In [None]:
print(my_graph)
# Label every node with its own index; colour the nodes by bead embedding.
labels = {int(node_text): node_text for node_text in map(str, my_graph.nodes)}
positions = nx.spring_layout(my_graph, seed=0)  # fixed seed for a reproducible layout
nx.draw(
    my_graph,
    pos=positions,
    node_color=embeddings,
    labels=labels,
    cmap=matplotlib.cm.tab20b,
)


In [None]:
# Next, we need to load the opep data, get some statistics, and fit some priors
temperature = 350  # K
#: Boltzmann constant in kcal/mol/K
kB = 0.0019872041
beta = (1/(temperature*kB))


# The peptide number is now actually substituted into the file name. The
# original string had no placeholder, so `.format(num)` was a no-op and the
# path was hard-coded to peptide 0002 (which is what num=2 still produces).
# NOTE(review): assumes the files use a zero-padded four-digit index -- confirm.
coords = np.load('/net/data02/nickc/octapeptides/brooke_map_cg_data/opep_{:04d}_cg_coords_no_physical_GLY_repul_fix.npy'.format(num))
print(coords.shape)
print(len(embeddings))

prior_nls = {'bonds':bonds, 'angles':angles, 'dihedrals':dihedrals}

# One AtomicData object per frame, all sharing the same prior neighbour lists.
data_list = [
    AtomicData.from_points(
        pos=torch.tensor(frame),
        atom_types=torch.tensor(embeddings),
        neighborlist=prior_nls,
    )
    for frame in coords
]

print(data_list[-1])
print(data_list[-1]['atom_types'])

In [None]:
# Merge all frames into one batched data object; the data class is taken
# from the first frame.
data_cls = data_list[0].__class__
datas, slices, _ = collate(data_cls, data_list=data_list, increment=True, add_batch=True)

In [None]:
# Quick look at the collated batch.
for summary in (
    datas,
    datas.n_atoms,
    datas.neighbor_list.keys(),
    datas['atom_types'],
):
    print(summary)

In [None]:
temperature = 350  # K
#: Boltzmann constant in kcal/mol/K
kB = 0.0019872041
beta = 1 / (temperature * kB)

# Per-type statistics for each prior term, using the matching prior class.
bond_stats = compute_statistics(
    datas,
    'bonds',
    beta=beta,
    TargetPrior=HarmonicBonds,
)
angle_stats = compute_statistics(
    datas,
    'angles',
    beta=beta,
    TargetPrior=HarmonicAngles,
)
dihedral_stats = compute_statistics(
    datas,
    'dihedrals',
    beta=beta,
    TargetPrior=Dihedral,
)

In [None]:
# List the interaction-type keys for which bond statistics were stored.
print(bond_stats.keys())
# print(bond_stats)

In [None]:
# One step curve per bond type, labelled with the two bead names.
for key, stats in bond_stats.items():
    print(key)
    name = " - ".join(embed2res[bead_type] for bead_type in key[:2])
    plt.step(
        stats['p_bin'].numpy(),
        stats['p'].numpy(),
        where='mid',
        linestyle='-',
        label=str(name),
    )
plt.legend(loc='best')
plt.show()

In [None]:
# List the interaction-type keys for which angle statistics were stored.
print(angle_stats.keys())
# print(angle_stats)

In [None]:
# One step curve per angle type, labelled with the three bead names.
for key, stats in angle_stats.items():
    name = " - ".join(embed2res[bead_type] for bead_type in key[:3])
    plt.step(
        stats['p_bin'].numpy(),
        stats['p'].numpy(),
        where='mid',
        linestyle='-',
        label=name,
    )
plt.legend(loc='best')
plt.show()

In [None]:
# Show the dihedral keys, then dump the full dihedral statistics dictionary.
print(dihedral_stats.keys())
print(dihedral_stats)

In [None]:
# One step curve per dihedral type, labelled with the four bead names.
fig, ax = plt.subplots(figsize=(12, 8))
for key, stats in dihedral_stats.items():
    name = " - ".join(embed2res[bead_type] for bead_type in key[:4])
    # `ax` is the current axes, so this draws where plt.step would.
    ax.step(
        stats['p_bin'].numpy(),
        stats['p'].numpy(),
        where='mid',
        linestyle='-',
        label=name,
    )
ax.legend(loc='best')
plt.show()