In [3]:
import MDAnalysis as mda
import glob
import re
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


## Making coordinate data (common atoms)

In [4]:
b2ar_traj_path = "/wrk/eurastof/binding_spots_project/gpcr_sampling/b2ar/b2ar_centered_aligned/"
b2ar_common_ndx = "/wrk/eurastof/binding_spots_project/HFSP---Lipid-binding-states/calculations/b2ar_common.ndx"

with open(b2ar_common_ndx) as f:
    lines = "".join(f.readlines())

resids = " ".join(re.findall(r"\d+", lines)[1:])


dirs = glob.glob(f"{b2ar_traj_path}*")
coordinates = []

for d in dirs:

    gro = glob.glob(f"{d}/*gro")[0]
    xtcs = glob.glob(f"{d}/*xtc")
    cosmos = mda.Universe(gro, xtcs)
    common_ca = cosmos.select_atoms(f"bynum {resids}")

    for ts in cosmos.trajectory[0:-1:5]:
        coords = common_ca.positions.flatten()
        coordinates.append(coords.reshape(1, coords.shape[0]))



In [5]:
X = np.concatenate(coordinates)

In [6]:
np.save("./b2ar_common_ca_coordinates.npy", X)

## Autoencoder pipeline

In [1]:
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline

import torch
from torch import nn, optim

import random



In [13]:

class Autoencoder(BaseEstimator, TransformerMixin, nn.Module):

    def __init__(self, in_shape, enc_shape, middle_shape, loss_fn=nn.L1Loss(), lr=1e-3):
        
        super().__init__()
        self.loss_fn = loss_fn
        self.lr = lr

        self.encode = nn.Sequential(
            nn.Linear(in_shape, middle_shape),
            nn.ReLU(True),
            nn.Linear(middle_shape, middle_shape),
            nn.ReLU(True),
            nn.Linear(middle_shape, enc_shape)
        )
        self.decode = nn.Sequential(
            nn.Linear(enc_shape, middle_shape),
            nn.ReLU(True),
            nn.Linear(middle_shape, middle_shape),
            nn.ReLU(True),
            nn.Linear(middle_shape, in_shape)
        )

    def fit(self, X, y=None, n_epochs=10, batch_size=32, verbose=False):

        self.training = True
        X = torch.Tensor(X)
        indices = [i for i in range(X.shape[0])]
        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        for epoch in range(n_epochs):
        
            random.shuffle(indices)
            batches = [i for i in range(0, len(indices), batch_size)]

            for i in range(len(batches) - 1):
                
                batch_X = X[indices[batches[i]:batches[i+1]]]
                self.optimizer.zero_grad()
                
                encoded = self.encode(batch_X)
                decoded = self.decode(encoded)

                loss = self.loss_fn(decoded, batch_X)
                loss.backward()
                self.optimizer.step()
            
            if verbose:
                print(f'epoch {epoch} \t Loss: {loss.item():.4g}')

    def transform(self, X, y=None):

        X = torch.Tensor(X)
        encoded = self.encode(X)

        return encoded



In [14]:
error = nn.L1Loss()
AE = Autoencoder(X.shape[1], 2, 1024, loss_fn=error)
#AE.fit(X, n_epochs=2, verbose=True)

In [15]:
print(AE.get_params())

RecursionError: maximum recursion depth exceeded while calling a Python object

In [103]:
pipe = Pipeline(
    steps=[("AE", Autoencoder(in_shape=X.shape[1], enc_shape=2, middle_shape=1024))]
)

In [104]:
pipe.fit(X)

AttributeError: 'Autoencoder' object has no attribute 'in_shape'

AttributeError: 'Autoencoder' object has no attribute 'in_shape'

AttributeError: 'Autoencoder' object has no attribute 'in_shape'

In [None]:
preds = pipe.predict(X)

In [101]:
ae = Autoencoder(in_shape=X.shape[1], enc_shape=2, middle_shape=1024)
dir(ae)


['T_destination',
 '__annotations__',
 '__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_apply',
 '_backward_hooks',
 '_backward_pre_hooks',
 '_buffers',
 '_call_impl',
 '_check_feature_names',
 '_check_n_features',
 '_forward_hooks',
 '_forward_hooks_with_kwargs',
 '_forward_pre_hooks',
 '_forward_pre_hooks_with_kwargs',
 '_get_backward_hooks',
 '_get_backward_pre_hooks',
 '_get_name',
 '_get_param_names',
 '_get_tags',
 '_is_full_backward_hook',
 '_load_from_state_dict',
 '_load_state_dict_post_hooks',
 '_load_state_dict_pre_hooks',
 '_maybe_warn_non_full_backward_hook',
 '_modules',
 '_more_tags',
 '_name