##### https://github.com/kundtx/MFC-TopoReg

In [1]:
import os 
# Move one directory up (os.pardir == ".."), presumably so the repository's
# `utils` / `processor` packages imported below resolve from the notebook.
os.chdir(os.pardir)

In [None]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import json

from utils.visualization import plot_molecules, plot_molecule
from utils.preprocessing import get_average_trajectory_positions, get_time_distance_matrix
import processor.data as data_processor
import processor.graph as graph_processor
import utils.metrics as metrics
from utils.export import *

In [3]:
import torch

# Pick the compute device once: the default CUDA device when available,
# otherwise the CPU. The name of GPU 0 is reported for the log.
_use_gpu = torch.cuda.is_available()
device = torch.device("cuda" if _use_gpu else "cpu")
if _use_gpu:
    print("GPU is available. Using:", torch.cuda.get_device_name(0))
else:
    print("GPU not available. Using CPU.")

GPU not available. Using CPU.


# Load data

In [None]:
# Integrin system and simulation protocol to load.
# integrin: keys with explicit domain tables below are "aVb3", "alphaVbeta3",
#   "a5b1" and "alpha2bbeta3"; anything else is treated as one "molecule".
# data_type: "clamp" here; "ramp" is presumably the alternative protocol
#   (NOTE(review): confirm accepted values in data_processor.load_data).
integrin, data_type = "a5b1", "clamp"
u, extensions, config = data_processor.load_data(data_type, integrin)

In [None]:
# Define structural domains.
# Each integrin maps domain name -> atom-selection string passed to
# `select_atoms`. The strings select C-alpha atoms; segid A holds the alpha-chain
# domains and segid B the beta-chain domains (see `domain_to_chain`).
# Only the chosen integrin's selections are resolved, against the first
# extension of u[0].
ext = extensions[0]

DOMAIN_SELECTIONS = {
    "aVb3": {
        "beta-propeller": "protein and name CA and segid A and resid 1-438",
        "thigh": "protein and name CA and segid A and resid 439-592",
        "loopA": "protein and name CA and segid A and resid 593-601",
        "calf1": "protein and name CA and segid A and resid 602-738",
        "calf2": "protein and name CA and segid A and resid 739-967",
        "transmembrane-alpha": "name CA and segid A and resid 968-984",
        "cytoplasmic-alpha": "name CA and segid A and resid 985-1016",
        "psi": "protein and name CA and segid B and resid 1-56",
        "hybrid": "protein and name CA and segid B and (resid 57-108 or resid 353-433)",
        "betaI": "protein and name CA and segid B and resid 109-352",
        "loopB": "protein and name CA and segid B and resid 434-436",
        "egf1": "protein and name CA and segid B and resid 437-472",
        "egf2": "protein and name CA and segid B and resid 473-522",
        "egf3": "protein and name CA and segid B and resid 523-559",
        "egf4": "protein and name CA and segid B and resid 560-600",
        "betaTD": "protein and name CA and segid B and resid 601-695",
        "transmembrane-beta": "protein and name CA and segid B and resid 696-725",
        "cytoplasmic-beta": "protein and name CA and segid B and resid 726-763",
    },
    # NOTE(review): same integrin as "aVb3" but with shifted calf2 / TM-alpha
    # and betaTD / TM-beta boundaries. The "aVb3" boundaries (calf2 739-967,
    # TM-alpha 968-984, betaTD 601-695, TM-beta 696-725) were previously kept
    # here as commented-out alternatives. Confirm which fits each structure.
    "alphaVbeta3": {
        "beta-propeller": "protein and name CA and segid A and resid 1-438",
        "thigh": "protein and name CA and segid A and resid 439-592",
        "loopA": "protein and name CA and segid A and resid 593-601",
        "calf1": "protein and name CA and segid A and resid 602-738",
        "calf2": "protein and name CA and segid A and resid 739-956",
        "transmembrane-alpha": "name CA and segid A and resid 957-984",
        "cytoplasmic-alpha": "name CA and segid A and resid 985-1016",
        "psi": "protein and name CA and segid B and resid 1-56",
        "hybrid": "protein and name CA and segid B and (resid 57-108 or resid 353-433)",
        "betaI": "protein and name CA and segid B and resid 109-352",
        "loopB": "protein and name CA and segid B and resid 434-436",
        "egf1": "protein and name CA and segid B and resid 437-472",
        "egf2": "protein and name CA and segid B and resid 473-522",
        "egf3": "protein and name CA and segid B and resid 523-559",
        "egf4": "protein and name CA and segid B and resid 560-600",
        "betaTD": "protein and name CA and segid B and resid 601-692",
        "transmembrane-beta": "name CA and segid B and resid 693-725",
        "cytoplasmic-beta": "protein and name CA and segid B and resid 726-763",
    },
    "a5b1": {
        "beta-propeller": "protein and name CA and segid A and resid 1-449",
        "thigh": "protein and name CA and segid A and resid 450-602",
        "loopA": "protein and name CA and segid A and resid 603-611",
        "calf1": "protein and name CA and segid A and resid 612-748",
        "calf2": "protein and name CA and segid A and resid 749-947",
        "betaI": "protein and name CA and segid B and resid 121-360",
        "hybrid": "protein and name CA and segid B and (resid 65-120 or resid 361-441)",
        "psi": "protein and name CA and segid B and resid 1-64",
        "loopB": "protein and name CA and segid B and resid 442-444",
        "egf1": "protein and name CA and segid B and resid 445-480",
        "egf2": "protein and name CA and segid B and resid 481-533",
        "egf3": "protein and name CA and segid B and resid 534-570",
        "egf4": "protein and name CA and segid B and resid 571-610",
        "betaTD": "protein and name CA and segid B and resid 611-703",
    },
    # https://www.cell.com/molecular-cell/fulltext/S1097-2765(08)00839-3
    # Zhu, Jianghai, et al. "Structure of a complete integrin ectodomain in a
    # physiologic resting state and activation and deactivation by applied
    # forces." Molecular cell 32.6 (2008): 849-861.
    "alpha2bbeta3": {
        "beta-propeller": "protein and name CA and segid A and resid 1-451",
        "thigh": "protein and name CA and segid A and resid 452-601",
        "calf1": "protein and name CA and segid A and resid 602-743",
        "calf2": "protein and name CA and segid A and resid 744-965",
        # Ends at 988 so it does not overlap cytoplasmic-alpha (989-1080).
        "transmembrane-alpha": "name CA and segid A and resid 966-988",
        "cytoplasmic-alpha": "name CA and segid A and resid 989-1080",
        "psi": "protein and name CA and segid B and resid 1008-1063",
        "hybrid": "protein and name CA and segid B and (resid 1064-1115 or resid 1360-1440)",
        "betaI": "protein and name CA and segid B and resid 1116-1359",
        "egf1": "protein and name CA and segid B and resid 1444-1479",
        "egf2": "protein and name CA and segid B and resid 1480-1529",
        "egf3": "protein and name CA and segid B and resid 1530-1566",
        "egf4": "protein and name CA and segid B and resid 1567-1607",
        "betaTD": "protein and name CA and segid B and resid 1608-1699",
        "transmembrane-beta": "name CA and segid B and resid 1700-1732",
        "cytoplasmic-beta": "name CA and segid B and resid 1733-1770",
    },
}

# Integrins without a table are treated as one domain covering every C-alpha.
_selections = DOMAIN_SELECTIONS.get(integrin, {"molecule": "protein and name CA"})
_structure = u[0][ext]
domain_to_residues = {
    domain: _structure.select_atoms(selection).atoms.resindices
    for domain, selection in _selections.items()
}

# Invert to residue index -> domain name. If a residue matched several
# domains, the domain listed last wins.
residue_to_domain = {
    residue: domain
    for domain, residues in domain_to_residues.items()
    for residue in residues
}

# Chain (segid) of each named domain: "A" = alpha chain, "B" = beta chain.
domain_to_chain = {
    'beta-propeller': "A",
    'thigh': "A",
    'loopA': "A",
    'calf1': "A",
    'calf2': "A",
    'transmembrane-alpha': "A",
    'cytoplasmic-alpha': "A",
    'psi': "B",
    'hybrid': "B",
    'betaI': "B",
    'loopB': "B",
    'egf1': "B",
    'egf2': "B",
    'egf3': "B",
    'egf4': "B",
    'betaTD': "B",
    'transmembrane-beta': "B",
    'cytoplasmic-beta': "B"
}

In [None]:
# Build the molecular graph sequences from the trajectories.
# graph_sequences[rep] is used below as a per-timestep list of graphs.
# NOTE(review): the exact structure of dygraph_sequences, resindices_to_index
# and dist_matrices is defined in processor.graph; confirm there.
graph_sequences, dygraph_sequences, resindices_to_index, dist_matrices = (
    graph_processor.contruct_graph_dygraph(  # sic: helper name is spelled "contruct"
        u=u,
        extensions=extensions,
        config=config,
        residue_to_domain=residue_to_domain,
        node_attributes="coords",
        warm_up_frames=1,
        # NOTE(review): threshold semantics are defined in processor.graph.
        bound_thd=5,
        pval_thd=1e-5,
    )
)

# Run model

In [None]:
import processor.model as model_processor

In [66]:
##########
class Args(dict):
    """Hyper-parameters for the MFC clustering model and its topological retraining.

    Values are stored as attributes (``args.card``), not as dict items; the
    ``dict`` base is kept only for backward compatibility.

    Args:
        n_cluster: Cluster count requested by the caller. Used only when
            ``fixed_n_cluster`` is None.
        file_name: Stored as ``self.file_name``.
        network_type: Model variant name, stored as ``self.network_type``.
        fixed_n_cluster: Cluster count that overrides ``n_cluster``. Defaults
            to 29, the value this class has always used; pass None to use
            ``n_cluster`` instead.
    """

    def __init__(self, n_cluster, file_name=None, network_type="MFC", *,
                 fixed_n_cluster=29) -> None:
        super().__init__()
        self.encoded_space_dim = 50
        # The override used to be silent (n_cluster was ignored); it is now explicit.
        self.n_cluster = n_cluster if fixed_n_cluster is None else fixed_n_cluster
        self.num_epoch = 701  # previously 1000
        self.learning_rate = 0.001  # for the topological stage
        self.LAMBDA = 1
        self.card = 40  # "num of ph considered"; passed to WrcfLayer(card=...). Previously 20.
        self.file_name = file_name
        self.network_type = network_type
        # Presumably the epoch at which the extra loss term starts: the training
        # log shows extra_loss becoming non-zero only after epoch 500.
        self.start_mf = 500

In [None]:
# Model factory bound to `device`; it is called below once per snapshot as
# model_init(adj, features.size(1), args).
model_init = model_processor.InitModel(device=device)

In [69]:
# Convert one replicate's per-timestep graphs (graph_sequences[rep]) into the
# (adj, features, labels) snapshots consumed by the clustering model.
# residue_to_domain maps node id -> domain name and supplies the labels.
rep = 0
snapshot_list, n_cluster = graph_processor.md_graphs_to_mfc_format(
    graph_sequence=graph_sequences[rep],
    residue_to_domain=residue_to_domain,
    device=device,
)

# Sanity prints: number of snapshots (time steps), then number of labels.
for overview_value in (len(snapshot_list), n_cluster):
    print(overview_value)

# Shapes of the first snapshot: adjacency, node features, number of nodes.
adj, features, labels = snapshot_list[0]
print(adj.shape)
print(features.shape)
print(len(labels))

4
14
(1609, 1609)
torch.Size([1609, 1609])
1609


In [70]:
# Extra requirements: pip install gudhi pot
from processor.filtration import WrcfLayer, build_community_graph
from processor.trainer import base_train, retrain_with_topo

# NOTE: Args currently overrides the n_cluster passed here with a fixed value
# (29), so the model does not use the label count computed above.
args = Args(n_cluster, file_name=None, network_type="MFC")

# Per-snapshot accumulators: trained models and [dim-0, dim-1] diagram pairs.
model_list, dgm_list = [], []
# Persistence layers for homology dimensions 0 and 1.
wrcf_layer_dim0, wrcf_layer_dim1 = (WrcfLayer(dim=d, card=args.card) for d in (0, 1))

# [Z, Q, adj, labels] records after base training and after topological retraining.
results_raw, results_topo = [], []

In [71]:
# Stage 1 - base deep-clustering training: one model per snapshot.
# For each trained model, record its outputs and the persistence diagrams of
# the community graph built from its cluster assignments.
for idx, snapshot in enumerate(snapshot_list):
    adj, features, labels = snapshot
    model = model_init(adj, features.size(1), args)
    model_list.append(model)
    base_train(model, features, adj, args, str(idx))

    # Inference only; no autograd graph is needed for the recorded outputs.
    with torch.no_grad():
        # The model returns a 3-tuple; the 2nd and 3rd items are kept as Z and Q.
        _, Z, Q = model(features, adj)
        results_raw.append(
            [tensor.cpu().detach().numpy() for tensor in (Z, Q)] + [adj, labels]
        )
        community_graph = build_community_graph(Q, adj)
        dgm0, dgm1 = (layer(community_graph) for layer in (wrcf_layer_dim0, wrcf_layer_dim1))
        dgm_list.append([dgm0, dgm1])

Epoch: 0001 extra_loss= 0.00000 re_loss= 0.69248 train_acc= 0.00166 time= 0.04215
Epoch: 0101 extra_loss= 0.00000 re_loss= 0.65146 train_acc= 0.00231 time= 0.01078
Epoch: 0201 extra_loss= 0.00000 re_loss= 0.60464 train_acc= 0.00958 time= 0.01227
Epoch: 0301 extra_loss= 0.00000 re_loss= 0.57197 train_acc= 0.03030 time= 0.01268
Epoch: 0401 extra_loss= 0.00000 re_loss= 0.54060 train_acc= 0.06470 time= 0.01165
Epoch: 0501 extra_loss= 0.00000 re_loss= 0.51372 train_acc= 0.12049 time= 0.01238
Epoch: 0601 extra_loss= 0.02863 re_loss= 0.49610 train_acc= 0.20428 time= 0.01631
Epoch: 0701 extra_loss= 0.02963 re_loss= 0.48068 train_acc= 0.29125 time= 0.01603
Epoch: 0001 extra_loss= 0.00000 re_loss= 0.69247 train_acc= 0.00167 time= 0.01232
Epoch: 0101 extra_loss= 0.00000 re_loss= 0.65112 train_acc= 0.00196 time= 0.01308
Epoch: 0201 extra_loss= 0.00000 re_loss= 0.60436 train_acc= 0.00819 time= 0.01117
Epoch: 0301 extra_loss= 0.00000 re_loss= 0.57191 train_acc= 0.02966 time= 0.01157
Epoch: 0401 extr

In [74]:
# Stage 2 - topology-regularized retraining.
# Each snapshot's model is retrained against the persistence diagrams of its
# neighbours [previous, next]; None marks a missing neighbour at either end.
# dgm_list[t] is overwritten right after snapshot t is retrained, so snapshot t
# sees the *retrained* diagram of t-1 and the base-training diagram of t+1.
n_snapshots = len(snapshot_list)
for t in range(n_snapshots):
    m = model_list[t]
    adj, features, labels = snapshot_list[t]
    # Conditional lookups make a single-snapshot sequence yield [None, None]
    # instead of an IndexError on dgm_list[t + 1].
    # NOTE(review): confirm retrain_with_topo accepts [None, None].
    gt_dgm = [
        dgm_list[t - 1] if t > 0 else None,
        dgm_list[t + 1] if t < n_snapshots - 1 else None,
    ]
    retrain_with_topo(m, gt_dgm, adj, features, args, str(t))
    with torch.no_grad():
        # The model returns a 3-tuple; the 2nd and 3rd items are kept as Z and Q.
        _, Z, Q = m(features, adj)
        results_topo.append([
            Z.cpu().detach().numpy(),
            Q.cpu().detach().numpy(),
            adj, labels,
        ])
        # Refresh this snapshot's diagrams for use by the next snapshot.
        community_graph = build_community_graph(Q, adj)
        dgm0_new = wrcf_layer_dim0(community_graph)
        dgm1_new = wrcf_layer_dim1(community_graph)
        dgm_list[t] = [dgm0_new, dgm1_new]

Epoch: 0001 extra_loss= 0.00007 re_loss= 0.48054 train_acc= 0.29211 time= 0.38928
Epoch: 0101 extra_loss= 0.00005 re_loss= 0.46433 train_acc= 0.38962 time= 0.35626
Epoch: 0201 extra_loss= 0.00006 re_loss= 0.45230 train_acc= 0.49112 time= 0.32645
Epoch: 0301 extra_loss= 0.00005 re_loss= 0.44307 train_acc= 0.57908 time= 0.29766
Epoch: 0401 extra_loss= 0.00005 re_loss= 0.43590 train_acc= 0.64355 time= 0.28462
Epoch: 0501 extra_loss= 0.00004 re_loss= 0.43010 train_acc= 0.69085 time= 0.30305
Epoch: 0601 extra_loss= 0.00005 re_loss= 0.42515 train_acc= 0.72863 time= 0.28463
Epoch: 0701 extra_loss= 0.00004 re_loss= 0.42083 train_acc= 0.75709 time= 0.27432
Epoch: 0001 extra_loss= 0.00014 re_loss= 0.48171 train_acc= 0.28502 time= 0.32112
Epoch: 0101 extra_loss= 0.00013 re_loss= 0.46561 train_acc= 0.37948 time= 0.33840
Epoch: 0201 extra_loss= 0.00012 re_loss= 0.45346 train_acc= 0.47969 time= 0.37362
Epoch: 0301 extra_loss= 0.00013 re_loss= 0.44394 train_acc= 0.56930 time= 0.39296
Epoch: 0401 extr

In [None]:
import processor.metrics as pr_metrics

# ================= Metrics (Modularity & Conductance) =================
def _hard_clusters_from_Q(Q_np: np.ndarray) -> np.ndarray:
    """Collapse a row-wise assignment matrix to one cluster id per row.

    Each row's id is the column of its largest entry (first column on ties).
    Q_np is presumably the model's soft assignments, one row per node.
    """
    return np.asarray(Q_np).argmax(axis=1)

def compute_metrics_time_series(results):
    """Score each snapshot's partition with modularity and conductance.

    Args:
        results: Iterable of ``[Z, Q, adj, labels]`` records as built by the
            training loops. Z is unused. Q is hard-assigned via argmax, and
            ``labels`` is scored alongside as a reference partition ("_gt").

    Returns:
        dict with per-snapshot lists "modularity", "conductance",
        "modularity_gt" and "conductance_gt", plus their float means under
        the matching "*_mean" keys.
    """
    modularity, conductance = [], []
    modularity_gt, conductance_gt = [], []
    for _, Q_np, adj_sp, labels in results:
        predicted = _hard_clusters_from_Q(Q_np)
        reference = np.asarray(labels)
        # adj_sp is passed to pr_metrics unchanged (presumably a scipy.sparse matrix).
        scoring_plan = (
            (predicted, modularity, conductance),
            (reference, modularity_gt, conductance_gt),
        )
        for clusters, mod_out, cond_out in scoring_plan:
            mod_out.append(float(pr_metrics.modularity(adj_sp, clusters)))
            cond_out.append(float(pr_metrics.conductance(adj_sp, clusters)))
    return {
        "modularity": modularity,
        "conductance": conductance,
        "modularity_gt": modularity_gt,
        "conductance_gt": conductance_gt,
        "modularity_mean": float(np.mean(modularity)),
        "conductance_mean": float(np.mean(conductance)),
        "modularity_gt_mean": float(np.mean(modularity_gt)),
        "conductance_gt_mean": float(np.mean(conductance_gt)),
    }

In [None]:
# Score both training stages with the same metrics: base first, then topological.
metrics_base, metrics_topo = (
    compute_metrics_time_series(stage_results)
    for stage_results in (results_raw, results_topo)
)

In [None]:

# Display per-snapshot and mean metrics after base training.
metrics_base

{'modularity': [0.6254347401071695,
  0.6327878477828665,
  0.6365867697841664,
  0.6330318787661444],
 'conductance': [0.33286516853932585,
  0.32525629077353213,
  0.3235294117647059,
  0.32680652680652683],
 'modularity_gt': [0.8386057763119134,
  0.8368505005953988,
  0.8379323406922682,
  0.8379464358485337],
 'conductance_gt': [0.012640449438202247,
  0.013513513513513514,
  0.013071895424836602,
  0.012587412587412588],
 'modularity_mean': 0.6319603091100867,
 'conductance_mean': 0.32711434947102264,
 'modularity_gt_mean': 0.8378337633620285,
 'conductance_gt_mean': 0.012953317740991238}

In [77]:
# Display per-snapshot and mean metrics after topological retraining.
metrics_topo

{'modularity': [0.6283104686557534,
  0.6282854423018965,
  0.6473568730324373,
  0.6530068843355556],
 'conductance': [0.32771535580524347,
  0.3294501397949674,
  0.3118580765639589,
  0.303962703962704],
 'modularity_gt': [0.8386057763119134,
  0.8368505005953988,
  0.8379323406922682,
  0.8379464358485337],
 'conductance_gt': [0.012640449438202247,
  0.013513513513513514,
  0.013071895424836602,
  0.012587412587412588],
 'modularity_mean': 0.6392399170814107,
 'conductance_mean': 0.31824656903171844,
 'modularity_gt_mean': 0.8378337633620285,
 'conductance_gt_mean': 0.012953317740991238}