In [4]:
# Colab environment setup
import numpy as np
# Install the correct version of Pytorch Geometric.
import torch

def format_pytorch_version(version):
  """Strip the local build tag from a torch version string.

  e.g. '2.0.1+cu118' -> '2.0.1'; a string without '+' is returned unchanged.
  """
  release, _plus, _local_tag = version.partition('+')
  return release

# Base torch release, interpolated into the PyG wheel index URL below.
TORCH_version = torch.__version__
TORCH = format_pytorch_version(TORCH_version)

def format_cuda_version(version):
  """Turn ``torch.version.cuda`` (e.g. '11.8') into a PyG wheel tag (e.g. 'cu118').

  CPU-only torch builds report ``torch.version.cuda`` as None; PyG serves those
  wheels under the 'cpu' tag, so return that instead of failing on ``None.replace``.
  """
  if version is None:
    return 'cpu'
  return 'cu' + version.replace('.', '')

CUDA_version = torch.version.cuda
CUDA = format_cuda_version(CUDA_version)

# Install the PyG extension wheels from the index matching this torch release ({TORCH}) and CUDA tag ({CUDA}).
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-cluster -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-spline-conv -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-geometric

# NOTE(review): unpinned — the Trainer/callback API used later may differ between pytorch-lightning releases; consider pinning a version.
!pip install pytorch-lightning

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [5]:
import os
import random
import numpy as np
import torch


def set_seed(seed: int = 42) -> None:
    """Seed every RNG the notebook relies on so runs are repeatable.

    Seeds Python's ``random``, NumPy's global RNG, torch (``torch.manual_seed``)
    and the current CUDA device (``torch.cuda.manual_seed``), and switches cuDNN
    to deterministic kernels with autotuning disabled.

    Note: ``PYTHONHASHSEED`` is read by Python only at interpreter start-up, so
    setting it here influences child processes launched afterwards, not string
    hashing inside the already-running kernel.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # cuDNN's autotuner ("benchmark") may pick different kernels run to run.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")

RANDOM_SEED = 3407
set_seed(RANDOM_SEED)

Random seed set as 3407


In [6]:
# Colab only: mount Google Drive so the prepared MISATO files under MyDrive are reachable at /content/gdrive.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [7]:
from pathlib import Path
# Prepared inputs live on the mounted Drive; list them as a quick sanity check.
DATADIR = Path("/content/gdrive/MyDrive/MISATO-experiments/prepared_data")
[*DATADIR.iterdir()]

[PosixPath('/content/gdrive/MyDrive/MISATO-experiments/prepared_data/esm2-embeddings'),
 PosixPath('/content/gdrive/MyDrive/MISATO-experiments/prepared_data/md_esm_if_input.hdf5'),
 PosixPath('/content/gdrive/MyDrive/MISATO-experiments/prepared_data/skipped_pdbids.csv'),
 PosixPath('/content/gdrive/MyDrive/MISATO-experiments/prepared_data/md_adaptabilities.hdf5'),
 PosixPath('/content/gdrive/MyDrive/MISATO-experiments/prepared_data/esm_if_out_frame0.hdf5'),
 PosixPath('/content/gdrive/MyDrive/MISATO-experiments/prepared_data/skipped_pdb'),
 PosixPath('/content/gdrive/MyDrive/MISATO-experiments/prepared_data/data description.gslides')]

In [8]:
# Copy the raw HDF5 inputs from Drive to local disk, where the dataset cell
# (datadir = "data", resolved against cwd — /content on Colab) reads them.
# Idempotent: the directory is created only if missing and each file is copied
# only if it is not already present locally.
import shutil

RAW_DIR = Path("/content/data/raw")
RAW_DIR.mkdir(parents=True, exist_ok=True)
for raw_name in ("esm_if_out_frame0.hdf5", "md_adaptabilities.hdf5"):
  if not (RAW_DIR / raw_name).exists():
    shutil.copy(DATADIR / raw_name, RAW_DIR / raw_name)

In [9]:
import h5py
# DATADIR = Path("prepared_data")
# h5py.File("md")
# embeddings_collection = h5py.File("data/esm_if_out_frame0.hdf5")
# adaptabilities_collection = h5py.File("data/md_adaptabilities.hdf5")

In [10]:
# Clone the MiSaTo repository once (skipped if already present); it provides the
# train/val/test split files read in the next cell.
if not os.path.exists("MiSaTo-dataset"):
  !git clone https://github.com/sab148/MiSaTo-dataset.git

Cloning into 'MiSaTo-dataset'...
remote: Enumerating objects: 784, done.[K
remote: Counting objects: 100% (728/728), done.[K
remote: Compressing objects: 100% (361/361), done.[K
remote: Total 784 (delta 367), reused 665 (delta 345), pack-reused 56[K
Receiving objects: 100% (784/784), 173.98 MiB | 16.64 MiB/s, done.
Resolving deltas: 100% (367/367), done.
Updating files: 100% (90/90), done.


In [11]:
SPLITS_DIR = Path("MiSaTo-dataset/data/MD/splits")

def read_split(filename):
  """Yield the PDB ids listed in a split file, one id per line.

  Surrounding whitespace is stripped and blank lines are skipped, so a stray
  empty line never turns into an empty-string id.
  """
  with open(filename, encoding="utf-8") as f:
    for line in f:
      pdbid = line.strip()
      if pdbid:
        yield pdbid


# Official MiSaTo MD splits. NOTE(review): "train_tinyMD.txt" can be swapped in for the
# train split — presumably a smaller subset for quick runs; confirm.
train_pdbids, test_pdbids, val_pdbids = (
    list(read_split(SPLITS_DIR / split_file))
    for split_file in ("train_MD.txt", "test_MD.txt", "val_MD.txt")
)


In [12]:
# Drive destination for copied logs/checkpoints (see the commented-out !cp cells).
# parents=True so a missing intermediate directory does not raise FileNotFoundError.
SAVEDIR = Path("/content/gdrive/MyDrive/MISATO-experiments/dgcnn_v2_saves")
SAVEDIR.mkdir(parents=True, exist_ok=True)

In [13]:
# pdbid = train_idx[0]
# print(pdbid)
# embeddings = embeddings_collection[pdbid][()]
# adaptabilities = adaptabilities_collection[pdbid][()]

In [14]:
# embeddings.shape, adaptabilities[1]  # .shape
# # baseline: use 

In [25]:
from torch_geometric.data import Dataset
from torch_geometric.data import Data
import os

class AdaptabilityDataset(Dataset):
    """PyG dataset pairing per-protein embeddings with MD adaptability targets.

    Reads two HDF5 files from ``<root>/raw``, both keyed by PDB id:
    ``esm_if_out_frame0.hdf5`` (stored as node features ``x``) and
    ``md_adaptabilities.hdf5`` (stored as targets ``y``). Every protein present
    in both files is saved to ``<root>/processed/<pdbid>.pt`` as a ``Data`` object.

    Args:
        raw_dir: dataset root directory (passed to PyG as ``root``).
        pdbid_list: PDB ids to include, in order; ``None`` means every id in the
            embeddings file, sorted.
        target_pretransform: optional callable applied to the raw adaptabilities
            array before it is converted to a float tensor.
        transform, pre_transform, pre_filter: standard PyG hooks.
        recompute: if False, existing ``<pdbid>.pt`` files are reused as-is
            (``target_pretransform``, ``pre_filter`` and ``pre_transform`` are NOT
            re-applied to them), so set True after changing any of those.
    """

    def __init__(self, raw_dir, pdbid_list=None, target_pretransform=None,
                 transform=None, pre_transform=None, pre_filter=None, recompute=True):
        # These must exist before super().__init__, because PyG's constructor runs process().
        self.pdbid_list = pdbid_list
        self.processed_pdbids = []
        self.pdbid2idx = dict()
        self.file_names = []
        self.recompute = recompute
        self.target_pretransform = target_pretransform
        super().__init__(raw_dir, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        # Bare names: PyG's ``raw_paths`` joins them with ``raw_dir``.
        return ['esm_if_out_frame0.hdf5', 'md_adaptabilities.hdf5']

    @property
    def processed_file_names(self):
        # Empty until process() runs, so PyG always calls process() and the
        # pdbid -> index mapping is rebuilt on every construction.
        return self.file_names

    def process(self):
        embedding_path, adaptabilities_path = self.raw_paths
        with h5py.File(embedding_path, "r") as embeddings_collection, \
             h5py.File(adaptabilities_path, "r") as adaptabilities_collection:
            if self.pdbid_list is not None:
                pdbid_list = self.pdbid_list
            else:
                # Default: every protein that has an embedding, in sorted order.
                pdbid_list = sorted(embeddings_collection.keys())

            for pdbid in pdbid_list:
                # Only proteins with both inputs and targets can become samples.
                if pdbid not in adaptabilities_collection or pdbid not in embeddings_collection:
                    continue

                savepath = os.path.join(self.processed_dir, f'{pdbid}.pt')
                if os.path.exists(savepath) and not self.recompute:
                    self._register(pdbid)
                    continue

                data = self._build_sample(
                    embeddings_collection[pdbid][()],
                    adaptabilities_collection[pdbid][()],
                )
                if self.pre_filter is not None and not self.pre_filter(data):
                    continue
                if self.pre_transform is not None:
                    data = self.pre_transform(data)
                torch.save(data, savepath)
                self._register(pdbid)

    def _build_sample(self, embedding, adaptabilities):
        """Convert one protein's raw arrays into a float32 ``Data(x=..., y=...)``."""
        if self.target_pretransform is not None:
            adaptabilities = self.target_pretransform(adaptabilities)
        return Data(
            x=torch.from_numpy(embedding).to(torch.float),
            y=torch.from_numpy(adaptabilities).to(torch.float),
        )

    def _register(self, pdbid):
        """Append ``pdbid`` as the next sample and record its dataset index."""
        self.pdbid2idx[pdbid] = len(self.processed_pdbids)
        self.processed_pdbids.append(pdbid)
        self.file_names.append(f'{pdbid}.pt')

    def len(self):
        return len(self.file_names)

    def get_indices(self, pdbid_list):
        """Map PDB ids to dataset indices, skipping ids that were not processed.

        Always returns an int64 array (also when empty) because PyG only accepts
        int64 or boolean NumPy arrays as dataset indices.
        """
        return np.asarray(
            [self.pdbid2idx[pdbid] for pdbid in pdbid_list if pdbid in self.pdbid2idx],
            dtype=np.int64,
        )

    def get(self, idx):
        pdbid = self.processed_pdbids[idx]
        # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which may
        # reject pickled Data objects; pass weights_only=False there if loading fails.
        return torch.load(os.path.join(self.processed_dir, f'{pdbid}.pt'))

In [26]:
# here are imports
# import os.path as osp

import pytorch_lightning as pl
# import torch
import torch.nn.functional as F
from torchmetrics import PearsonCorrCoef
from torchmetrics import MeanSquaredError

import torch_geometric.transforms as T
from torch_geometric.data.lightning import LightningDataset
# from torch_geometric.datasets import TUDataset
from torch_geometric.nn import GIN, MLP, global_add_pool
from torch_geometric.nn import MLP, DynamicEdgeConv

class DGCNN(pl.LightningModule):
    """Per-node regressor built from three DynamicEdgeConv layers.

    Adapted from the PyG DGCNN segmentation example:
    https://github.com/pyg-team/pytorch_geometric/blob/master/examples/dgcnn_segmentation.py

    Each layer builds a k-NN graph in its own input feature space, so the
    ``edge_index`` argument of ``forward`` is ignored. The three layers' outputs
    are concatenated and mapped to ``out_channels`` values per node.

    Args:
        in_channels: node feature width (default 512).
        out_channels: number of regression targets per node.
        dropout: dropout used in the output MLP.
        k: number of nearest neighbours per DynamicEdgeConv.
        aggr: neighbourhood aggregation passed to DynamicEdgeConv (e.g. 'max').
        lr: Adam learning rate (default 0.01).
    """

    def __init__(self, in_channels: int = 512, out_channels: int = 3,
                 dropout: float = 0.5, k=30, aggr='max', lr: float = 0.01):
        super().__init__()
        self.save_hyperparameters()
        # EdgeConv's MLP sees [x_i, x_j - x_i], i.e. twice the input width.
        self.conv1 = DynamicEdgeConv(MLP([in_channels * 2, 256, 128]), k, aggr)
        self.conv2 = DynamicEdgeConv(MLP([256, 256, 128]), k, aggr)
        self.conv3 = DynamicEdgeConv(MLP([256, 256, 128]), k, aggr)

        # Head over the concatenation of the three 128-wide conv outputs.
        self.mlp = MLP([3 * 128, 256, 128, out_channels], dropout=dropout,
                       norm=None)

        self.train_mse = MeanSquaredError()
        self.val_mse = MeanSquaredError()
        self.test_mse = MeanSquaredError()
        self.test_pearson = PearsonCorrCoef(num_outputs=1)
        # Currently unused (its logging in test_step is commented out).
        self.test_pearson3atoms = PearsonCorrCoef(num_outputs=3)

    def forward(self, x, edge_index, batch):
        x1 = self.conv1(x, batch)
        x2 = self.conv2(x1, batch)
        x3 = self.conv3(x2, batch)
        out = self.mlp(torch.cat([x1, x2, x3], dim=1))
        # NOTE(review): SELU bounds predictions below at about -1.76; confirm this
        # matches the target range.
        return F.selu(out)

    def training_step(self, data, batch_idx):
        y_hat = self(data.x, data.edge_index, data.batch)
        loss = F.mse_loss(y_hat, data.y)
        self.train_mse(y_hat, data.y)
        self.log('train_mse', self.train_mse, prog_bar=True, on_step=False,
                 on_epoch=True)
        return loss

    def validation_step(self, data, batch_idx):
        y_hat = self(data.x, data.edge_index, data.batch)
        self.val_mse(y_hat, data.y)
        self.log('val_mse', self.val_mse, prog_bar=True, on_step=False,
                 on_epoch=True)

    def test_step(self, data, batch_idx):
        y_hat = self(data.x, data.edge_index, data.batch)
        self.test_mse(y_hat, data.y)
        # Pearson over all flattened target values (all nodes and channels pooled).
        self.test_pearson(y_hat.reshape(-1), data.y.reshape(-1))
        self.log('test_mse', self.test_mse, prog_bar=True, on_step=False,
                 on_epoch=True)
        self.log('Test Pearson Correlation Coefficient', self.test_pearson, prog_bar=True, on_step=False,
                  on_epoch=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)



In [27]:
# trainer.test(test_dataset[0].x, test_dataset[0].edge_index, test_dataset[0].batch)

In [29]:
datadir = "data"
def adaptabilities2distribution_parameters(adaptabilities):
    """Summarise an adaptability trajectory as per-residue mean and std.

    Frame 0 is excluded from the statistics (NOTE(review): presumably the reference
    frame — confirm). For input of shape (frames, num_residues, 3) the result has
    shape (num_residues, 6): the 3 means followed by the 3 standard deviations.
    """
    later_frames = adaptabilities[1:]
    return np.concatenate(
        (np.mean(later_frames, axis=0), np.std(later_frames, axis=0)),
        axis=-1,
    )

# Alternative target: target_pretransform = lambda adaptabilities: adaptabilities[1]
# One dataset over every split's ids; train/val/test are index subsets of it.
dataset_pdbids = train_pdbids + test_pdbids + val_pdbids
full_dataset = AdaptabilityDataset(
    datadir,
    pdbid_list=dataset_pdbids,
    target_pretransform=adaptabilities2distribution_parameters,
    recompute=False,
)


def _select_split(split_pdbids, message):
    """Return (indices, subset) of full_dataset for the given ids and print the subset size."""
    split_idx = full_dataset.get_indices(split_pdbids)
    split_dataset = full_dataset[split_idx]
    print(message, len(split_dataset))
    return split_idx, split_dataset


train_idx, train_dataset = _select_split(train_pdbids, "Train dataset length:")
val_idx, val_dataset = _select_split(val_pdbids, "Val dataset length:")
test_idx, test_dataset = _select_split(test_pdbids, "Test dataset length:")


Processing...


Train dataset length: 11759
Val dataset length: 1342
Test dataset length: 1352


Done!


In [30]:
# prot = torch.load("data/processed/10GS.pt")
# prot
from pytorch_lightning.callbacks import BatchSizeFinder, EarlyStopping

In [None]:
from IPython.core.interactiveshell import Logger
from pytorch_lightning.loggers import CSVLogger

NUM_EPOCHS = 50
BATCH_SIZE = 16 # T4
# Number of optimizer steps in one epoch, i.e. log about once per epoch.
NSTEPS = max(len(train_dataset) // BATCH_SIZE, 1)
print("log every n steps", NSTEPS)

# The test split doubles as the prediction dataset.
datamodule = LightningDataset(
    train_dataset,
    val_dataset=val_dataset,
    test_dataset=test_dataset,
    pred_dataset=test_dataset,
    batch_size=BATCH_SIZE,
    num_workers=1,
)
logger = CSVLogger("logs", name="pyg_dgcnn_next_frame")

# 6 outputs per residue: 3 means + 3 stds (see adaptabilities2distribution_parameters).
model = DGCNN(out_channels=6)

# NOTE: the next three objects are built but not passed to the dataset or Trainer.
pre_transform = T.NormalizeScale()
strategy = pl.strategies.DDPStrategy(accelerator='gpu')
batch_size_finder = BatchSizeFinder()

checkpoint = pl.callbacks.ModelCheckpoint(
    monitor='val_mse', save_top_k=3,
    mode='min',
    # No extension here: Lightning appends ".ckpt" to the formatted name itself.
    filename='dgcnn-{epoch:02d}-{step:03d}-{val_mse:.2f}'
)
early_stopping = EarlyStopping(
    'val_mse', mode='min',
    check_finite=True,
    patience=5
)
# Append batch_size_finder here to let Lightning tune the batch size.
callbacks = [
    checkpoint,
    early_stopping,
]

trainer_kwargs = dict(
    max_epochs=NUM_EPOCHS,
    log_every_n_steps=NSTEPS,
    callbacks=callbacks,
    logger=logger,
)
if torch.cuda.is_available():
    devices = torch.cuda.device_count()
    trainer_kwargs["devices"] = devices
trainer = pl.Trainer(**trainer_kwargs)

trainer.fit(model, datamodule)
trainer.test(ckpt_path='best', datamodule=datamodule)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name               | Type             | Params
--------------------------------------------------------
0 | conv1              | DynamicEdgeConv  | 295 K 
1 | conv2              | DynamicEdgeConv  | 99.2 K
2 | conv3              | DynamicEdgeConv  | 99.2 K
3 | mlp                | MLP              | 132 K 
4 | train_mse          | MeanSquaredError | 0     
5 | val_mse            | MeanSquaredError | 0     
6 | test_mse           | MeanSquaredError | 0     
7 | test_pearson       | PearsonCorrCoef  | 0     
8 | test_pearson3atoms | Pear

log every n steps 734


  rank_zero_warn(


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

In [32]:
# del train_dataset
# del test_dataset
# del val_dataset
# torch.cuda.empty_cache()
# del datamodule
# del trainer
# del model

In [33]:

# !cp -r logs {SAVEDIR}

In [34]:
# trainer.test(ckpt_path='best',)

In [None]:
# !cp {SAVEDIR/"logs/pyg_dgcnn_next_frame/version_0/checkpoints"}