In [81]:
# Colab environment setup
import numpy as np
# Install the PyTorch Geometric wheels that match the installed torch and CUDA versions.
import torch

def format_pytorch_version(version):
  """Return the part of a torch version string that precedes the first '+'.

  A string without '+' is returned unchanged.
  """
  base, _, _ = version.partition('+')
  return base

# Installed torch version with anything from the first '+' onwards removed;
# interpolated into the wheel index URLs of the pip commands in this cell.
TORCH_version = torch.__version__
TORCH = format_pytorch_version(TORCH_version)

def format_cuda_version(version):
  """Turn a CUDA version string into the tag used in the wheel index URL.

  Args:
    version: CUDA version such as '11.8', or None when the torch build
      has no CUDA support (``torch.version.cuda`` is None in that case).

  Returns:
    'cu' followed by the version with the dots removed (e.g. 'cu118'),
    or 'cpu' when ``version`` is None.
  """
  if version is None:
    # CPU-only torch build: there is no CUDA version to format.
    return 'cpu'
  return 'cu' + version.replace('.', '')

# CUDA version reported by the installed torch build, reformatted into the
# tag that is interpolated into the wheel index URLs below.
CUDA_version = torch.version.cuda
CUDA = format_cuda_version(CUDA_version)

# The four compiled extension packages are looked up (-f) in the wheel index
# selected by the {TORCH} and {CUDA} values computed above; -q keeps pip quiet.
# NOTE(review): none of these installs pins a package version.
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-cluster -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-spline-conv -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-geometric

# Training framework used by the model cell further down (imported as `pl`).
!pip install pytorch-lightning

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [82]:
import os
import random
import h5py
import numpy as np
import torch


def set_seed(seed: int = 42) -> None:
    """Seed every random number generator this notebook uses.

    Seeds NumPy, Python's ``random`` module and torch (including the
    explicit CUDA seeding call), switches cuDNN to deterministic mode with
    benchmarking disabled, stores the seed in the ``PYTHONHASHSEED``
    environment variable and prints a confirmation line.

    Args:
        seed: Integer seed applied to all generators.
    """
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # cuDNN needs both switches set for repeatable results.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # Fixed value for the hash seed environment variable.
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")

# Seed used for this run; applied immediately so later cells start from a
# known RNG state.
RANDOM_SEED = 3407
set_seed(RANDOM_SEED)

Random seed set as 3407


In [3]:
# Colab-only: mount Google Drive under /content/gdrive so the prepared data
# referenced in the following cells is reachable as ordinary files.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [4]:
from pathlib import Path
# Folder on the mounted Drive that holds the prepared input files.
DATADIR = Path("/content/gdrive/MyDrive/MISATO-experiments/prepared_data")
# Last expression of the cell: display the folder contents as a sanity check.
list(DATADIR.iterdir())

[PosixPath('/content/gdrive/MyDrive/MISATO-experiments/prepared_data/esm2-embeddings'),
 PosixPath('/content/gdrive/MyDrive/MISATO-experiments/prepared_data/md_esm_if_input.hdf5'),
 PosixPath('/content/gdrive/MyDrive/MISATO-experiments/prepared_data/skipped_pdbids.csv'),
 PosixPath('/content/gdrive/MyDrive/MISATO-experiments/prepared_data/md_adaptabilities.hdf5'),
 PosixPath('/content/gdrive/MyDrive/MISATO-experiments/prepared_data/esm_if_out_frame0.hdf5'),
 PosixPath('/content/gdrive/MyDrive/MISATO-experiments/prepared_data/skipped_pdb'),
 PosixPath('/content/gdrive/MyDrive/MISATO-experiments/prepared_data/data description.gslides')]

In [5]:
!mkdir data
!cp /content/gdrive/MyDrive/MISATO-experiments/prepared_data/esm_if_out_frame0.hdf5 /content/data
!cp /content/gdrive/MyDrive/MISATO-experiments/prepared_data/md_adaptabilities.hdf5 /content/data

In [7]:
# Fetch the MiSaTo dataset repository; the split lists read below live in it.
# NOTE(review): no commit or tag is pinned, so the splits follow the default branch.
!git clone https://github.com/sab148/MiSaTo-dataset.git

Cloning into 'MiSaTo-dataset'...
remote: Enumerating objects: 784, done.[K
remote: Counting objects: 100% (728/728), done.[K
remote: Compressing objects: 100% (361/361), done.[K
remote: Total 784 (delta 368), reused 663 (delta 345), pack-reused 56[K
Receiving objects: 100% (784/784), 173.98 MiB | 33.84 MiB/s, done.
Resolving deltas: 100% (368/368), done.
Updating files: 100% (90/90), done.


In [84]:
# Directory inside the cloned MiSaTo-dataset repository that holds the split files.
SPLITS_DIR = Path("MiSaTo-dataset/data/MD/splits")

def read_split(filename):
  """Yield the entries of a split file, one per non-empty line.

  Each line is stripped of surrounding whitespace. Lines that are empty
  after stripping (e.g. a trailing blank line) are skipped so that they do
  not turn into empty-string identifiers.

  Args:
    filename: Path of the text file to read.

  Yields:
    The stripped content of every non-blank line, in file order.
  """
  with open(filename) as f:
    for line in f:
      line = line.strip()
      if line:
        yield line


# Materialise the three "tinyMD" split files as lists so they can be
# concatenated and iterated more than once.
train_pdbids = list(read_split(SPLITS_DIR / "train_tinyMD.txt"))
test_pdbids = list(read_split(SPLITS_DIR / "test_tinyMD.txt"))
val_pdbids = list(read_split(SPLITS_DIR / "val_tinyMD.txt"))


In [87]:
from torch_geometric.data import Dataset
from torch_geometric.data import Data
import os

class AdaptabilityDataset(Dataset):
    """Dataset pairing per-structure embeddings with adaptability targets.

    ``process`` reads two HDF5 files (see ``raw_file_names``). For every
    requested PDB id that is present in both files it builds
    ``Data(x=embedding, y=adaptabilities)`` and saves it as
    ``<processed_dir>/<pdbid>.pt``. ``get`` loads those files back.

    The id -> index bookkeeping (``processed_pdbids``, ``pdbid2idx``,
    ``file_names``) lives in memory only and is filled by ``process``.
    NOTE(review): the class therefore relies on ``process`` running when the
    dataset is constructed; ``processed_file_names`` is empty at that point.

    Args:
        raw_dir: Root directory forwarded to ``Dataset.__init__``.
        pdbid_list: PDB ids to process, in order. ``None`` means every key
            of the embeddings file, sorted.
        target_pretransform: Optional callable applied to the adaptabilities
            tensor before the ``Data`` object is built.
        transform: Forwarded to ``Dataset.__init__``.
        pre_transform: Forwarded to ``Dataset.__init__``; also applied to
            each ``Data`` object before it is saved.
        pre_filter: Forwarded to ``Dataset.__init__``; ``Data`` objects for
            which it returns a falsy value are not saved.
    """

    def __init__(self, raw_dir, pdbid_list=None, target_pretransform=None,
                 transform=None, pre_transform=None, pre_filter=None):
        self.pdbid_list = pdbid_list
        # PDB ids written by process(), in dataset index order.
        self.processed_pdbids = []
        # PDB id -> position in processed_pdbids.
        self.pdbid2idx = dict()
        # list of pdb ids in data split
        self.file_names = []
        self.target_pretransform = target_pretransform
        # The attributes above must exist before the base constructor runs,
        # because process() reads and fills them.
        super().__init__(raw_dir, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        """Paths of the embeddings file and the adaptabilities file, in that order.

        NOTE(review): process() opens these paths as given, i.e. relative to
        the current working directory.
        """
        return [
            os.path.join("data", 'esm_if_out_frame0.hdf5'),
            os.path.join("data", 'md_adaptabilities.hdf5')
        ]

    @property
    def processed_file_names(self):
        """Entries recorded by process(); empty until process() has run.

        NOTE(review): the entries are bare PDB ids, while the files on disk
        are named ``<pdbid>.pt``.
        """
        return self.file_names

    def process(self):
        """Build and save one ``Data`` object per usable PDB id."""
        embedding_file_name, adaptabilities_file_name = self.raw_file_names
        # Both files are only read, so open them read-only explicitly.
        with h5py.File(embedding_file_name, "r") as embeddings_collection, \
             h5py.File(adaptabilities_file_name, "r") as adaptabilities_collection:
            if self.pdbid_list is not None:
                pdbid_list = self.pdbid_list
            else:
                # The open file handle is a local variable, not an attribute.
                pdbid_list = sorted(embeddings_collection.keys())

            for pdbid in pdbid_list:
                # An id needs both an embedding and a target to be usable.
                if pdbid not in adaptabilities_collection:
                    continue
                if pdbid not in embeddings_collection:
                    continue

                embedding = embeddings_collection[pdbid][()]
                adaptabilities = adaptabilities_collection[pdbid][()]
                embedding = torch.from_numpy(embedding).to(torch.float)
                adaptabilities = torch.from_numpy(adaptabilities).to(torch.float)
                if self.target_pretransform is not None:
                    adaptabilities = self.target_pretransform(adaptabilities)
                data = Data(x=embedding, y=adaptabilities)

                if self.pre_filter is not None and not self.pre_filter(data):
                    continue
                if self.pre_transform is not None:
                    data = self.pre_transform(data)
                torch.save(data, os.path.join(self.processed_dir, f'{pdbid}.pt'))

                # The index is the position the id is about to take in
                # processed_pdbids, so the two structures cannot drift apart.
                self.pdbid2idx[pdbid] = len(self.processed_pdbids)
                self.processed_pdbids.append(pdbid)
                self.file_names.append(pdbid)

    def len(self):
        """Number of examples written by process()."""
        return len(self.file_names)

    def get_indices(self, pdbid_list):
        """Map PDB ids to dataset indices, silently dropping unknown ids.

        Returns:
            An int64 NumPy array. The dtype is set explicitly so that an
            empty result is still an integer array usable for indexing.
        """
        return np.asarray(
            [self.pdbid2idx[pdbid] for pdbid in pdbid_list if pdbid in self.pdbid2idx],
            dtype=np.int64,
        )

    def get(self, idx):
        """Load the saved ``Data`` object of the example at position ``idx``."""
        pdbid = self.processed_pdbids[idx]
        data = torch.load(os.path.join(self.processed_dir, f'{pdbid}.pt'))
        return data

In [88]:
import pytorch_lightning as pl
# import torch
import torch.nn.functional as F
from torchmetrics import MeanSquaredError

import torch_geometric.transforms as T
from torch_geometric.data.lightning import LightningDataset

from torch_geometric.nn import MLP, DynamicEdgeConv

class DGCNN(pl.LightningModule):
    """Three stacked DynamicEdgeConv layers plus an MLP head, trained with MSE.

    Adapted from
    https://github.com/pyg-team/pytorch_geometric/blob/master/examples/dgcnn_segmentation.py

    Args:
        in_channels: Width of the input features ``x``.
        out_channels: Width of the final MLP layer.
        hidden_channels: Currently unused; the layer widths are fixed.
        num_layers: Currently unused; there are always three conv layers.
        dropout: Dropout probability of the MLP head.
        k: Neighbour count passed to every ``DynamicEdgeConv``.
        aggr: Aggregation passed to every ``DynamicEdgeConv``.
    """

    def __init__(self, in_channels: int=512, out_channels: int=3,
                 hidden_channels: int = 64, num_layers: int = 3,
                 dropout: float = 0.5, k=30, aggr='max'):
        super().__init__()
        # Each conv MLP takes twice the incoming feature width
        # (in_channels*2 for the first layer, 2 * 128 = 256 afterwards).
        self.conv1 = DynamicEdgeConv(MLP([in_channels*2, 256, 128]), k, aggr)
        self.conv2 = DynamicEdgeConv(MLP([256, 256, 128]), k, aggr)
        self.conv3 = DynamicEdgeConv(MLP([256, 256, 128]), k, aggr)

        # The head consumes the concatenation of the three 128-wide conv
        # outputs. The dropout rate comes from the constructor argument,
        # whose default is 0.5.
        self.mlp = MLP([3 * 128, 256, 128, out_channels], dropout=dropout,
                       norm=None)

        # One metric object per stage so their accumulated states stay separate.
        self.train_metric = MeanSquaredError()
        self.val_metric = MeanSquaredError()
        self.test_metric = MeanSquaredError()

    def forward(self, x, edge_index, batch):
        """Compute predictions for ``x``.

        ``edge_index`` is accepted but not used: every conv layer is called
        with ``batch`` only.
        """
        x1 = self.conv1(x, batch)
        x2 = self.conv2(x1, batch)
        x3 = self.conv3(x2, batch)
        out = self.mlp(torch.cat([x1, x2, x3], dim=1))
        # NOTE(review): SELU is applied to the final output — confirm this is
        # intended for the target's range.
        return F.selu(out)

    def training_step(self, data, batch_idx):
        """Return the MSE loss for one batch and update the epoch-level train metric."""
        y_hat = self(data.x, data.edge_index, data.batch)
        loss = F.mse_loss(y_hat, data.y)
        self.train_metric(y_hat, data.y)
        self.log('train_metric', self.train_metric, prog_bar=True, on_step=False,
                 on_epoch=True)
        return loss

    def validation_step(self, data, batch_idx):
        """Update the epoch-level validation metric ('val_metric') for one batch."""
        y_hat = self(data.x, data.edge_index, data.batch)
        self.val_metric(y_hat, data.y)
        self.log('val_metric', self.val_metric, prog_bar=True, on_step=False,
                 on_epoch=True)

    def test_step(self, data, batch_idx):
        """Update the epoch-level test metric ('test_metric') for one batch."""
        y_hat = self(data.x, data.edge_index, data.batch)
        self.test_metric(y_hat, data.y)
        self.log('test_metric', self.test_metric, prog_bar=True, on_step=False,
                 on_epoch=True)

    def configure_optimizers(self):
        """Adam over all parameters with a fixed learning rate of 0.01."""
        return torch.optim.Adam(self.parameters(), lr=0.01)



In [71]:
# Root directory handed to AdaptabilityDataset.
datadir = "data"
# Keep only the entry at index 1 of each adaptabilities tensor as the target.
# NOTE(review): what index 1 stands for is not visible in this notebook — confirm.
target_pretransform = lambda adaptabilities: adaptabilities[1]
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Build ONE dataset over all three splits, then carve the splits out by index.
dataset_pdbids = train_pdbids + test_pdbids + val_pdbids
full_dataset = AdaptabilityDataset(
    datadir, pdbid_list=dataset_pdbids, target_pretransform=target_pretransform
)

# get_indices drops ids that did not make it into the dataset, so each printed
# length can be smaller than the corresponding split list.
train_idx = full_dataset.get_indices(train_pdbids)
train_dataset = full_dataset[train_idx]
print("Train dataset length:", len(train_dataset))

val_idx = full_dataset.get_indices(val_pdbids)
val_dataset = full_dataset[val_idx]
print("Val dataset length:", len(val_dataset))

test_idx = full_dataset.get_indices(test_pdbids)
test_dataset = full_dataset[test_idx]
print("Test dataset length:", len(test_dataset))


Processing...


Train dataset length: 19
Val dataset length: 5
Test dataset length: 4


Done!


In [74]:
# Wrap the three split datasets so Lightning can build the data loaders.
datamodule = LightningDataset(
    train_dataset, val_dataset, test_dataset,
    batch_size=64, 
    num_workers=1
)

# Model with its default constructor arguments.
model = DGCNN()
# NOTE(review): pre_transform is created here but not used anywhere in this cell.
pre_transform = T.NormalizeScale()

# NOTE(review): `devices` is only assigned when CUDA is available and is not
# passed to the Trainer below.
if torch.cuda.is_available():
  devices = torch.cuda.device_count()
# NOTE(review): `strategy` is built unconditionally, but the Trainer argument
# that would use it is commented out.
strategy = pl.strategies.DDPStrategy(accelerator='gpu')
# Keep only the checkpoint with the lowest 'val_metric'.
checkpoint = pl.callbacks.ModelCheckpoint(
    monitor='val_metric', save_top_k=1,
    mode='min')
trainer = pl.Trainer(# strategy=strategy, 
                      # devices='cpu',
                      max_epochs=10,
                      log_every_n_steps=1, 
                     callbacks=[checkpoint])

# Train, then evaluate the best checkpoint (per the callback above) on the test split.
trainer.fit(model, datamodule)
trainer.test(ckpt_path='best', datamodule=datamodule)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.callbacks.model_summary:
  | Name         | Type             | Params
--------------------------------------------------
0 | conv1        | DynamicEdgeConv  | 295 K 
1 | conv2        | DynamicEdgeConv  | 99.2 K
2 | conv3        | DynamicEdgeConv  | 99.2 K
3 | mlp          | MLP              | 131 K 
4 | train_metric | MeanSquaredError | 0     
5 | val_metric   | MeanSquaredError | 0     
6 | test_metric  | MeanSquaredError | 0     
--------------------------------------------------
626 K     Trainable params
0         Non-trainable params
626 K     Total params
2.504     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.
INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at /content/lightning_logs/version_10/checkpoints/epoch=5-step=6.ckpt
INFO:pytorch_lightning.utilities.rank_zero:Loaded model weights from the checkpoint at /content/lightning_logs/version_10/checkpoints/epoch=5-step=6.ckpt


Testing: 0it [00:00, ?it/s]

[{'test_metric': 0.3390730023384094}]