In [1]:
from sklearn.model_selection import GridSearchCV
from sklearn.decomposition import PCA
from sklearn.pipeline import Pipeline
from sklearn.mixture import GaussianMixture
from sklearn.manifold import MDS, Isomap
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.cluster import DBSCAN

from scipy.special import kl_div
import MDAnalysis as mda
import sys
import itertools
from scipy.special import rel_entr
import matplotlib.pyplot as plt
import glob
import numpy as np
import itertools
import pickle
from scipy.spatial.distance import pdist, cdist
import random

import torch
import torch.nn as nn
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import StandardScaler

import copy
import skorch


## Making test data

# Extract the selected atom coordinates of every system and cache them on disk.
# Each entry of ``data_dict`` maps a system directory name to a float64 array with one
# row per trajectory frame; a row holds the flattened positions of the selected atoms.
selection = "name CA"
data_dict = {}

for system in glob.glob("../binding_spots_project/gpcr_sampling/b2ar/b2ar_centered_aligned/*"):
    name = system.split("/")[-1]
    # Sort the trajectory chunks so the frame order does not depend on the arbitrary
    # order in which glob lists the files.
    cosmos = mda.Universe(glob.glob(f"{system}/*gro")[0], sorted(glob.glob(f"{system}/*xtc")))
    # The selection string is fixed, so build the AtomGroup once instead of once per frame;
    # its ``positions`` follow the trajectory's current frame.
    selected_atoms = cosmos.select_atoms(selection)
    # Preallocate the full (n_frames, n_coordinates) array rather than growing it with
    # np.concatenate on every frame (quadratic copying) from a zero row that then has
    # to be filtered out again.
    pos = np.empty((len(cosmos.trajectory), selected_atoms.positions.size), dtype=np.float64)
    for frame_idx, ts in enumerate(cosmos.trajectory):
        pos[frame_idx] = selected_atoms.positions.ravel()
    data_dict[name] = pos

with open('data_dict.pkl', 'wb') as f:
    pickle.dump(data_dict, f)
    

In [2]:
# Reload the coordinate dictionary from the pickle file 'data_dict.pkl'.
with open('data_dict.pkl', 'rb') as pkl_file:
    data_dict = pickle.load(pkl_file)


In [3]:
# The two systems whose conformational ensembles are compared.
system1, system2 = "popc", "chol-site-3"

# Stack the systems in an explicit order (system1 first, then system2) so the rows of X
# always line up with the labels in Y, regardless of the key order of data_dict.
X = np.concatenate([data_dict[system1], data_dict[system2]])
# Label every system1 frame -1 and every system2 frame +1.
Y = np.concatenate([np.full(data_dict[system1].shape[0], -1), np.full(data_dict[system2].shape[0], 1)])


In [107]:
# Use the GPU when PyTorch reports CUDA as available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using {device} device.")

Using cuda device.


## FLOW

In [108]:
# Candidate numbers of components for the dimensionality reducer: 2..6 when X has more
# than five columns, otherwise 2..(number of columns - 1).
if X.shape[1] > 5:
    n_comps_dim_reduc = list(range(2, 7))
else:
    n_comps_dim_reduc = list(range(2, X.shape[1]))
# Candidate numbers of mixture components / clusters (currently just 2).
n_comps_cluster = list(range(2, 3))

In [109]:

# Search space for the flow: for each pipeline step, map an estimator instance to the
# hyper-parameter values to try for it (parameter name -> list of candidates).
# Both estimators get a fixed random_state so repeated runs select the same models:
# GaussianMixture initialisation is random, and PCA's default svd_solver="auto" can
# pick a randomized solver depending on the input size.
default_steps = {
    "dim_reducer": {
        PCA(random_state=0): {
            "n_components": n_comps_dim_reduc
        }
    },
    "classifier": {
        GaussianMixture(random_state=0): {
            "n_components": n_comps_cluster,
            "covariance_type": ["full", "spherical", "diag", "tied"]
        }
    }
}


In [110]:
def build_pipelines(step_grid, X):
    """Grid-search every (dim_reducer, classifier) pairing and collect the winners.

    Parameters
    ----------
    step_grid : dict
        Mapping with the keys ``"dim_reducer"`` and ``"classifier"``. Each value maps an
        estimator instance to a dict of ``{parameter name: list of candidate values}``.
    X : array-like
        Data passed to ``GridSearchCV.fit``; no targets are supplied.

    Returns
    -------
    list
        The ``best_estimator_`` of each grid search, one per estimator pairing, in
        ``itertools.product`` order. Empty when either step has no estimators.
    """
    reducer_grids = step_grid["dim_reducer"]
    classifier_grids = step_grid["classifier"]
    best_pipelines = []

    for combo in itertools.product(reducer_grids, classifier_grids):
        reducer, classifier = combo
        print(f"fitting combination: {combo}")

        # Prefix every parameter with its pipeline step name ("<step>__<param>") so it
        # can be routed to the right step by the grid search.
        search_space = {
            f"dim_reducer__{param}": values
            for param, values in reducer_grids[reducer].items()
        }
        search_space.update(
            (f"classifier__{param}", values)
            for param, values in classifier_grids[classifier].items()
        )
        print(f"with params {search_space}")

        candidate = Pipeline(steps=[("dim_reducer", reducer), ("classifier", classifier)])
        search = GridSearchCV(candidate, param_grid=search_space).fit(X)
        best_pipelines.append(search.best_estimator_)

    return best_pipelines



In [111]:

def caluclate_KL(pipeline, X, Y):
    """Compare how two labelled systems populate the clusters predicted by ``pipeline``.

    For each distinct label in ``Y`` the fraction of that label's samples assigned to
    every predicted cluster is computed; the return value is the sum of
    ``rel_entr(first, second)`` over clusters for the first two labels.

    NOTE(review): labels are visited in ``set`` iteration order, so which system acts as
    the first argument of the (asymmetric) relative entropy is not chosen explicitly.
    Labels beyond the second are ignored.

    Parameters
    ----------
    pipeline : object
        Fitted estimator exposing ``predict(X)``.
    X : array-like
        Samples passed to ``pipeline.predict``.
    Y : numpy.ndarray
        System label of every sample in ``X``.

    Returns
    -------
    float
        Summed relative entropy between the two cluster-population distributions.
    """
    assignments = pipeline.predict(X)
    cluster_ids = set(assignments)

    populations_per_system = []
    for label in set(Y.flatten()):
        # Row indices of the samples that belong to this system.
        rows = np.nonzero(Y == label)[0]
        label_assignments = assignments[rows]
        n_label = label_assignments.shape[0]
        fractions = []
        for cluster_id in cluster_ids:
            n_in_cluster = np.count_nonzero(label_assignments == cluster_id)
            fractions.append(n_in_cluster / n_label)
        populations_per_system.append(fractions)

    first, second = populations_per_system[0], populations_per_system[1]  # only two systems atm
    return sum(rel_entr(first, second))
    


In [112]:

def flow(X, Y, step_grid):
    """Build all candidate pipelines and return the one with the largest KL score.

    Parameters
    ----------
    X : array-like
        Samples used both to fit the pipelines and to score them.
    Y : numpy.ndarray
        System label of every sample in ``X``.
    step_grid : dict
        Search space forwarded to ``build_pipelines``.

    Returns
    -------
    tuple
        ``(best_pipe, best_KL)``. If no pipeline scores above 0, the first pipeline is
        returned together with 0.
    """
    candidates = build_pipelines(step_grid=step_grid, X=X)

    # Start from the first candidate; it is only replaced by a strictly better score.
    winner, top_score = candidates[0], 0
    for candidate in candidates:
        score = caluclate_KL(candidate, X, Y)
        if score > top_score:
            winner, top_score = candidate, score

    return winner, top_score
        


**TODO:** feature selection

**TODO:** random seeds


