<a href="https://colab.research.google.com/github/jdhenaos/DL_excercises/blob/main/PMPNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#ProteinMPNN
This notebook is intended as a quick demo, more features to come!

In [2]:
!pip install torch-geometric

Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━[0m [32m61.4/63.1 kB[0m [31m52.3 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m1.3 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m17.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.6.1


In [5]:
import glob
import os
import os.path as osp
from typing import Callable, List, Optional

import torch

from torch_geometric.data import (
    Data,
    InMemoryDataset,
    download_url,
    extract_zip,
)
from torch_geometric.io import fs, read_off


class ModelNet(InMemoryDataset):
    r"""The ModelNet10/40 datasets from the `"3D ShapeNets: A Deep
    Representation for Volumetric Shapes"
    <https://people.csail.mit.edu/khosla/papers/cvpr2015_wu.pdf>`_ paper,
    containing CAD models of 10 and 40 categories, respectively.

    .. note::

        Data objects hold mesh faces instead of edge indices.
        To convert the mesh to a graph, use the
        :obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`.
        To convert the mesh to a point cloud, use the
        :obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to
        sample a fixed number of points on the mesh faces according to their
        face area.

    Args:
        root (str): Root directory where the dataset should be saved.
        name (str, optional): The name of the dataset (:obj:`"10"` for
            ModelNet10, :obj:`"40"` for ModelNet40). (default: :obj:`"10"`)
        train (bool, optional): If :obj:`True`, loads the training dataset,
            otherwise the test dataset. (default: :obj:`True`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)

    **STATS:**

    .. list-table::
        :widths: 20 10 10 10 10 10
        :header-rows: 1

        * - Name
          - #graphs
          - #nodes
          - #edges
          - #features
          - #classes
        * - ModelNet10
          - 4,899
          - ~9,508.2
          - ~37,450.5
          - 3
          - 10
        * - ModelNet40
          - 12,311
          - ~17,744.4
          - ~66,060.9
          - 3
          - 40
    """

    urls = {
        '10':
        'http://vision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip',
        '40': 'http://modelnet.cs.princeton.edu/ModelNet40.zip'
    }

    def __init__(
        self,
        root: str,
        name: str = '10',
        train: bool = True,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        assert name in ['10', '40']
        self.name = name
        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)
        path = self.processed_paths[0] if train else self.processed_paths[1]
        self.load(path)

    @property
    def raw_file_names(self) -> List[str]:
        return [
            'bathtub', 'bed', 'chair', 'desk', 'dresser', 'monitor',
            'night_stand', 'sofa', 'table', 'toilet'
        ]

    @property
    def processed_file_names(self) -> List[str]:
        return ['training.pt', 'test.pt']

    def download(self) -> None:
        path = download_url(self.urls[self.name], self.root)
        extract_zip(path, self.root)
        os.unlink(path)
        folder = osp.join(self.root, f'ModelNet{self.name}')
        fs.rm(self.raw_dir)
        os.rename(folder, self.raw_dir)

        # Delete osx metadata generated during compression of ModelNet10
        metadata_folder = osp.join(self.root, '__MACOSX')
        if osp.exists(metadata_folder):
            fs.rm(metadata_folder)

    def process(self) -> None:
        self.save(self.process_set('train'), self.processed_paths[0])
        self.save(self.process_set('test'), self.processed_paths[1])

    def process_set(self, dataset: str) -> List[Data]:
        categories = glob.glob(osp.join(self.raw_dir, '*', ''))
        categories = sorted([x.split(os.sep)[-2] for x in categories])

        data_list = []
        for target, category in enumerate(categories):
            folder = osp.join(self.raw_dir, category, dataset)
            paths = glob.glob(f'{folder}/{category}_*.off')
            for path in paths:
                data = read_off(path)
                data.y = torch.tensor([target])
                data_list.append(data)

        if self.pre_filter is not None:
            data_list = [d for d in data_list if self.pre_filter(d)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]

        return data_list

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}{self.name}({len(self)})'

In [9]:
path = "./sample_data"
transform = None
pre_transform = None
ModelNet(path, '10', True, transform, pre_transform)

Downloading http://vision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip
Extracting sample_data/ModelNet10.zip
Processing...
Done!


ModelNet10(3991)

# New version of ProteinMPNN dataset

In [63]:
from torch_geometric.data import (
    Data,
    InMemoryDataset,
    download_url,
    extract_tar,
)
from torch_geometric.io import fs
import torch
import os.path as osp
from torch_geometric.utils import one_hot
import pandas as pd

class PMPNNDataset(InMemoryDataset):
  url = 'https://files.ipd.uw.edu/pub/training_sets/pdb_2021aug02_sample.tar.gz'
  dir_name = 'pdb_2021aug02_sample'

  def __init__(self,
               #params,
               root,
               set_type, # 'train', 'val', or 'test'
               transform=None,
               pre_transform=None,
               pre_filter=None,
               log=True,
               force_reload=False
               #name='sample',
               ) -> None:
                 assert set_type in {'train', 'val', 'test'}
                 super().__init__(root, transform, pre_transform, pre_filter, log, force_reload)
                 self.set_type = set_type

  @property
  def raw_file_names(self) -> List[str]:
      return [self.dir_name + '.tar.gz']


  @property
  def raw_paths(self):
      r"""The absolute filepaths that must be present in order to skip
      downloading."""
      files = self.raw_file_names
      # If it is a list of files, the root directory is prepended to each.
      # Otherwise the root directory is simply joined with the filename.
      return [osp.join(self.raw_dir, f) for f in files] or osp.join(self.raw_dir, files)

  def download(self):
        # download is supposed to remove the old directory.
        # download the new zip file
        # put it into its respective folder (raw folder or processed folder)
        path = download_url(self.url, self.root)
        extract_tar(path, self.root)
        os.unlink(path)
        folder = osp.join(self.root)
        fs.rm(self.raw_dir)

In [64]:
path = "./sample_data"
PMPNNDataset(root=path,set_type='train')

Downloading https://files.ipd.uw.edu/pub/training_sets/pdb_2021aug02_sample.tar.gz
Extracting sample_data/pdb_2021aug02_sample.tar.gz


PMPNNDataset()

In [66]:
!ls ./sample_data/pdb_2021aug02_sample

list.csv  pdb  README  test_clusters.txt  valid_clusters.txt


In [3]:
from torch_geometric.data import (
    Data,
    InMemoryDataset,
    download_url,
    extract_zip,
)
from torch_geometric.io import fs
import torch
import os.path as osp
from torch_geometric.utils import one_hot
import pandas as pd

# What does the already preprocessed dataset contain?


class PMPNNDataset(InMemoryDataset):
    # dataset for sample

    url = 'https://files.ipd.uw.edu/pub/training_sets/pdb_2021aug02_sample.tar.gz'
    dir_name = 'pdb_2021aug02_sample'

    def __init__(self,
                 params,
                 root,
                 set_type, # 'train', 'val', or 'test'
                 transform=None,
                 pre_transform=None,
                 pre_filter=None,
                 log=True,
                 force_reload=False,
                 name='sample',
                 ):
        assert set_type in {'train', 'val', 'test'}
        super().__init__(root, transform, pre_transform, pre_filter, log, force_reload)
        self.params = params
        self.set_type = set_type

        self._rbf_max = 0.0
        self._rbf_min = 20.0
        self._rbf_counts = 16


    @property
    def raw_file_names(self):
        # I think we don't have any raw files names. The zip file is not considered raw.
        # In our case the zip is just compressed processed files
        return [self.dir_name + '.tar.gz']

    @property
    def processed_file_names(self):
        # What should we say here do you think? We can't define all files or?
        return ['data.pt']

    @property
    def raw_dir(self):
        # Not used since we dont process raw files
        return os.path.join(self.root, 'raw')

    @property
    def processed_dir(self):
        return os.path.join(self.root, 'processed')

    def download(self):
        # download is supposed to remove the old directory.
        # download the new zip file
        # put it into its respective folder (raw folder or processed folder)
        fs.rm(self.raw_dir)
        path = download_url(self.url, self.root)
        extract_zip(path, self.root)
        os.rename(osp.join(self.root, 'pdb_2021aug02_sample'), self.raw_dir)
        os.unlink(path)


    # def original_process(self):
    #     # Process here should loop through [train, val, test], put them into a
    #     # Data object, and save them.
    #     import torch
    #     import os.path as osp
    #     from torch_geometric.utils import one_hot
    #     amino_acid_list = ["A", "R", "N", "D", "C", "E", "Q", "G", "H", "I", "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V"]

    #     csv_file_path = os.path.join(self.raw_dir, 'list.csv')
    #     sequence_df = pd.read_csv(csv_file_path)

    #     validation_clusters = set(pd.read_csv(os.path.join(self.raw_dir, 'valid_clusters.txt'), header=None)[0])
    #     test_clusters = set(pd.read_csv(os.path.join(self.raw_dir, 'test_clusters.txt'), header=None)[0])

    #     training_data = []
    #     validation_data = []
    #     testing_data = []

    #     pdb_data_dir = osp.join(self.raw_dir, "pdb", "l3")


    #     for _, row in sequence_df.iterrows():

    #       pdb_chain_id = row["CHAINID"]
    #       sequence_cluster_id = row["CLUSTER"]
    #       protein_sequence = row["SEQUENCE"]

    #       pt_file_path = osp.join(pdb_data_dir, pdb_chain_id + ".pt")

    #       if not osp.exists(pt_file_path):
    #         continue

    #       protein_data = torch.load(pt_file_path)

    #       one_hot_encoded_sequence = one_hot(protein_sequence, num_classes=len(amino_acid_list))

    #       # add virtual node here.

    #       # add angles?? In the paper they decided against node features

    #       # input here is one combined chain e.g. positions.shape : torch.Size([447, 14, 3])
    #       pairwise_distance = torch.cdist(protein_data["xyz"], protein_data["xyz"]
    #       top_k_distances, edge_index = torch.topk(pairwise_distance), k=k)

    #       rbf_centers = torch.linspace(self._D_min, self._D_max, self._rbf_counts)
    #       sigma = (self._D_max - self._D_min) / self._rbf_counts
    #       edge_attr = torch.exp(-((top_k_distances.view(-1, 1) - rbf_centers) ** 2) / (2 * sigma ** 2))

    #       protein_graph = Data(
    #           x=one_hot_encoded_sequence,
    #           edge_index=edge_index,
    #           edge_attr=edge_attr,
    #           pos=protein_data["xyz"],
    #           mask=protein_data["mask"] # additional arg, maybe only train have mask??
    #       )

    #       if sequence_cluster_id in validation_clusters:
    #         validation_data.append(protein_graph)
    #       elif sequence_cluster_id in test_clusters:
    #         testing_data.append(protein_graph)
    #       else:
    #         training_data.append(protein_graph)

    #     return {
    #         "train": training_data,
    #         "val": validation_data,
    #         "test": testing_data
    #     }

    def loader_pdb(self, pdb_chid):
        pdb_id = pdb_chid.split('_')[0]
        chid = pdb_chid.split('_')[1]
        asmb_ids = self.chain_asmbs[pdb_chid]
        asmb_candidates = set(asmb_ids)
        if len(asmb_candidates)<1:
            chain = self.chains[pdb_chid]
            L = len(chain['seq'])
            return {'seq'    : chain['seq'],
                    'xyz'    : chain['xyz'],
                    'idx'    : torch.zeros(L).int(),
                    'masked' : torch.Tensor([0]).int(),
                    'label'  : pdb_chid}

        # randomly pick one assembly from candidates
        asmb_i = random.sample(list(asmb_candidates), 1)

        # indices of selected transforms
        asmbs = self.asmb_dict[asmb_i]

        # load relevant chains
        chains = {c:self.chains["%s_%s"%(pdb_chid,c)]
                for asmb in asmbs for c in asmb['asmb_chains']}
        chids = [cid.split(',')[0] for cid in chains]

        asmb_final = {}
        for asmb in asmbs:

            # pick k-th xform
            asmb_chains = asmb['asmb_chains']
            xform = asmb['xformIDX']
            u = xform[:,:3,:3]
            r = xform[:,:3,3]

            # select chains which k-th xform should be applied to

            chains_k = set(asmb_chains.split(','))

            # transform selected chains
            for c in chains_k:
                try:
                    xyz = self.chains[pdb_id+'_'+c]['xyz']
                    xyz_ru = torch.einsum('bij,raj->brai', u, xyz) + r[:,None,None,:]
                    asmb.update({(c,k,i):xyz_i for i,xyz_i in enumerate(xyz_ru)})
                except KeyError:
                    return {'seq': np.zeros(5)}

        # select chains which share considerable similarity to chid
        seqid = asmb['tm'][chids==chid][0,:,1]
        homo = set([ch_j for seqid_j,ch_j in zip(seqid,chids)
                    if seqid_j>self.params['HOMO']])
        # stack all chains in the assembly together
        seq,xyz,idx,masked = "",[],[],[]
        seq_list = []
        for counter,(k,v) in enumerate(asmb.items()):
            seq += chains[k[0]]['seq']
            seq_list.append(chains[k[0]]['seq'])
            xyz.append(v)
            idx.append(torch.full((v.shape[0],),counter))
            if k[0] in homo:
                masked.append(counter)

        return {'seq'    : seq,
                'xyz'    : torch.cat(xyz,dim=0),
                'idx'    : torch.cat(idx,dim=0),
                'masked' : torch.Tensor(masked).int(),
                'label'  : pdb_chid}

    def get_pdb(item, max_length=10000, num_units=1000000):
        init_alphabet = ['A', 'B', 'C', 'D', 'E', 'F', 'G','H', 'I', 'J','K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T','U', 'V','W','X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g','h', 'i', 'j','k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't','u', 'v','w','x', 'y', 'z']
        extra_alphabet = [str(item) for item in list(np.arange(300))]
        chain_alphabet = init_alphabet + extra_alphabet
        c = 0
        c1 = 0

        t0 = time.time()
        item = {k:v[0] for k,v in item.items()}
        c1 += 1
        if 'label' in list(item):
            my_dict = {}
            s = 0
            concat_seq = ''
            concat_N = []
            concat_CA = []
            concat_C = []
            concat_O = []
            concat_mask = []
            coords_dict = {}
            mask_list = []
            visible_list = []
            if len(list(np.unique(item['idx']))) < 352:
                for idx in list(np.unique(item['idx'])):
                    letter = chain_alphabet[idx]
                    res = np.argwhere(item['idx']==idx)
                    initial_sequence= "".join(list(np.array(list(item['seq']))[res][0,]))
                    if initial_sequence[-6:] == "HHHHHH":
                        res = res[:,:-6]
                    if initial_sequence[0:6] == "HHHHHH":
                        res = res[:,6:]
                    if initial_sequence[-7:-1] == "HHHHHH":
                        res = res[:,:-7]
                    if initial_sequence[-8:-2] == "HHHHHH":
                        res = res[:,:-8]
                    if initial_sequence[-9:-3] == "HHHHHH":
                        res = res[:,:-9]
                    if initial_sequence[-10:-4] == "HHHHHH":
                        res = res[:,:-10]
                    if initial_sequence[1:7] == "HHHHHH":
                        res = res[:,7:]
                    if initial_sequence[2:8] == "HHHHHH":
                        res = res[:,8:]
                    if initial_sequence[3:9] == "HHHHHH":
                        res = res[:,9:]
                    if initial_sequence[4:10] == "HHHHHH":
                        res = res[:,10:]
                    if res.shape[1] < 4:
                        pass
                    else:
                        my_dict['seq_chain_'+letter]= "".join(list(np.array(list(item['seq']))[res][0,]))
                        concat_seq += my_dict['seq_chain_'+letter]
                        if idx in item['masked']:
                            mask_list.append(letter)
                        else:
                            visible_list.append(letter)
                        coords_dict_chain = {}
                        all_atoms = np.array(item['xyz'][res,])[0,] #[L, 14, 3]
                        coords_dict_chain['N_chain_'+letter]=all_atoms[:,0,:].tolist()
                        coords_dict_chain['CA_chain_'+letter]=all_atoms[:,1,:].tolist()
                        coords_dict_chain['C_chain_'+letter]=all_atoms[:,2,:].tolist()
                        coords_dict_chain['O_chain_'+letter]=all_atoms[:,3,:].tolist()
                        my_dict['coords_chain_'+letter]=coords_dict_chain
                my_dict['name']= item['label']
                my_dict['masked_list']= mask_list
                my_dict['visible_list']= visible_list
                my_dict['num_of_chains'] = len(mask_list) + len(visible_list)
                my_dict['seq'] = concat_seq

        return my_dict

    def is_valid_cluster(cluster):
        if self.set_type == 'train':
            return cluster not in self.val_clusters and cluster not in self.test_clusters
        elif self.set_type == 'test':
            return cluster in self.test_clusters
        elif self.set_type =='val':
            return cluster in self.val_clusters
        else:
            raise ValueError('invalid dataset type')
    def process(self):
        """
        Creates:
            self.cluster_chains: dict[cluster: list[str (chains)]]
            self.chains: dict[pdb_cid: dict['xyz': Tensor, 'seq': str]]
            self.asmb_dict: dict[asmb_id: list[dict['chain_seq': str, 'xformIDX': Tensor, 'tm': tm]]
            self.chain_asmbs: dict[pdb_cid: list[str(asmb_id)]]

        Potential problem: not checking if asmb_chain is in chains of PDBID file (why would this be necessary, how could an assembly include a chain not found)
        """
        with open(osp.join(self.raw_dir, 'test_clusters.csv'), 'r') as file:
            self.test_clusters = set([line.strip() for line in file])
        with open(osp.join(self.raw_dir, 'validation_clusters.csv'), 'r') as file:
            self.val_clusters = set([line.strip() for line in file])

        # Collecting the clusters and associated chains for processing
        csv_file_path = os.path.join(self.raw_dir, 'list.csv')
        sequence_df = pd.read_csv(csv_file_path)
        sample_sequence_df = sequence_df[sequence_df['CHAINID'].str[1:3]=='l3']
        cluster_chains = {}
        pdb_ids = []
        chain_ids = []
        for _, row in sample_sequence_df.iterrows():
            cluster = row['CLUSTER']
            if is_valid_cluster(cluster):
                if cluster in cluster_chains:
                    cluster_chains[cluster].append(row['CHAINID'])
                else:
                    cluster_chains[cluster] = [row['CHAINID']]
                pdb_ids.append(row['CHAINID'].split('_')[0])
                chain_ids.append(row['CHAINID'])

        self.cluster_chains = cluster_chains
        self.clusters = list(cluster_chains.keys())

        pdb_ids = list(set(pdb_ids))
        pdb_data_dir = osp.join(self.raw_dir, "pdb", "l3")
        pdb_files = [osp.join(pdb_data_dir, pdb_id + '.pt') for pdb_id in pdb_ids]
        chain_files = [osp.join(pdb_data_dir, chain_id + '.pt') for chain_id in chain_ids]

        # pdb_files = [osp.join(pdb_data_dir, f) for f in pdb_data_dir.iterdir() if '_' not in f and f[-3:]=='.pt']
        # chain_files = [osp.join(pdb_data_dir, f) for f in pdb_data_dir.iterdir() if '_' in f and f[-3:]=='.pt']
        asmb_dict = {}
        chain_asmbs = {}
        chains = {}
        for pdb_file in pdb_files:
            pdb_id = osp.splitext(osp.basename(pdb_file))[0]
            pdb_data = torch.load(pdb_file)
            chain_seqs = pdb_data['asmb_chains']
            asmb_ids = pdb_data['asmb_ids']
            xformIDXs = []
            j = 0
            while 'xform%d'%j in pdb_data:
                xformIDXs.append(pdb_data['xform%d'%j])
                j+=1
            tms = pdb_data['tm']
            if (len(tms) != len(chain_seqs) or len(chain_seqs != xformIDXs)) or (len(chain_seqs) != len(asmb_ids)):
                raise ValueError(f'Lengths do not match: chain_seqs: {len(chain_seqs)}, tms: {len(tms)}, xformIDXs: {len(xformIDXs)}, asmb_ids: {len(asmb_ids)}')
            for chain_seq, xformIDX, tm, asmb_id in zip(chain_seqs, xformIDXs, tms, asmb_ids):
                asmb_item = {
                    'chain_seq': chain_seq,
                    'xformIDX': xformIDX,
                    'tm': tm,
                }
                if asmb_id in asmb_dict:
                    asmb_dict[asmb_id].append(asmb_item)
                else:
                    asmb_dict[asmb_id] = asmb_item
                chain_set = set(chain_seq.split(','))
                for chain_id in chain_set:
                    pdb_cid = pdb_id + '_' + chain_id
                    if pdb_cid in chain_asmbs:
                        chain_asmbs[pdb_cid].append(asmb_id)
                    else:
                        chain_asmbs[pdb_cid] = [asmb_id]
        for chain_file in chain_files:
            chain_filename = osp.splitext(osp.basename(pdb_file))[0]
            chain_item = torch.load(chain_file)
            chains[chain_filename] = {
                'xyz': chain_item['xyz'],
                'seq': chain_item['seq']
            }

        self.chains = chains
        self.asmb_dict = asmb_dict
        self.chain_asmbs = chain_asmbs

    def __getitem__(self, idx):
      # think this one is implemented by base class check e.g.
      # https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/datasets/molecule_net.py
      # need to implement masking: probably in get function with better aligned data structures but this is first draft
      # train test split also needs to be figured out
      cluster_id = self.clusters[idx]
      pdbid_chid = random.choice(self.cluster_chains[cluster_id])
      asmb = None

      asmb = loader_pdb(pdbid_chid)
      pdb_dict = get_pdb(asmb)
      return pdb_dict

      # pdb_dict to data object maybe put featurize here




In [None]:
#@title Clone github repo
import json, time, os, sys, glob

if not os.path.isdir("ProteinMPNN"):
  os.system("git clone -q https://github.com/dauparas/ProteinMPNN.git")
sys.path.append('/content/ProteinMPNN')

In [None]:
#@title Setup Model
import matplotlib.pyplot as plt
import shutil
import warnings
import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split, Subset
import copy
import torch.nn as nn
import torch.nn.functional as F
import random
import os.path
from protein_mpnn_utils import loss_nll, loss_smoothed, gather_edges, gather_nodes, gather_nodes_t, cat_neighbors_nodes, _scores, _S_to_seq, tied_featurize, parse_PDB
from protein_mpnn_utils import EncLayer, PositionWiseFeedForward
from protein_mpnn_utils import StructureDataset, StructureDatasetPDB, ProteinMPNN

import torch
import torch.nn as nn
import torch_geometric.nn as pyg_nn
from torch_geometric.utils import scatter
import protein_mpnn_utils

import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import scatter

class PositionWiseFeedForward(nn.Module):
    def __init__(self, num_hidden, num_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(num_hidden, num_ff)
        self.linear2 = nn.Linear(num_ff, num_hidden)
        self.act = nn.GELU()

    def forward(self, x):
        return self.linear2(self.act(self.linear1(x)))

class EncLayer(nn.Module):
    def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30):
        super(EncLayer, self).__init__()
        self.num_hidden = num_hidden
        self.num_in = num_in
        self.scale = scale
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(num_hidden)
        self.norm2 = nn.LayerNorm(num_hidden)
        self.norm3 = nn.LayerNorm(num_hidden)

        self.W1 = nn.Linear(num_hidden + num_in, num_hidden, bias=True)
        self.W2 = nn.Linear(num_hidden, num_hidden, bias=True)
        self.W3 = nn.Linear(num_hidden, num_hidden, bias=True)
        self.W11 = nn.Linear(num_hidden + num_in, num_hidden, bias=True)
        self.W12 = nn.Linear(num_hidden, num_hidden, bias=True)
        self.W13 = nn.Linear(num_hidden, num_hidden, bias=True)
        self.act = torch.nn.GELU()
        self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)

        self.message_passing = MessagePassing(aggr='add')

    def forward(self, h_V, h_E, edge_index, mask_V=None, mask_attend=None):
        """Parallel computation of full transformer layer using PyTorch Geometric."""
        # First message passing layer
        h_message = self.message_passing_propagate(h_V, h_E, edge_index, mask_attend)
        h_V = self.norm1(h_V + self.dropout1(h_message / self.scale))

        # Feed-forward layer
        dh = self.dense(h_V)
        h_V = self.norm2(h_V + self.dropout2(dh))
        if mask_V is not None:
            h_V = mask_V.unsqueeze(-1) * h_V

        # Edge update
        h_E = self.edge_update(h_V, h_E, edge_index)
        return h_V, h_E

    def message_passing_propagate(self, h_V, h_E, edge_index, mask_attend):
        # Prepare edge features by concatenating node and edge features
        row, col = edge_index
        h_EV = torch.cat([h_V[row], h_E], dim=-1)
        h_EV = self.act(self.W1(h_EV))
        h_EV = self.act(self.W2(h_EV))
        h_message = self.W3(h_EV)

        if mask_attend is not None:
            h_message = mask_attend.unsqueeze(-1) * h_message

        # Aggregate messages to target nodes
        h_aggregated = scatter(h_message, col, dim=0, reduce='sum')
        return h_aggregated

    def edge_update(self, h_V, h_E, edge_index):
        # Update edge features using updated node features
        row, col = edge_index
        h_EV = torch.cat([h_V[row], h_E], dim=-1)
        h_EV = self.act(self.W11(h_EV))
        h_EV = self.act(self.W12(h_EV))
        h_message = self.W13(h_EV)
        h_E = self.norm3(h_E + self.dropout3(h_message))
        return h_E

class PositionWiseFeedForward(nn.Module):
    def __init__(self, num_hidden, num_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(num_hidden, num_ff)
        self.linear2 = nn.Linear(num_ff, num_hidden)
        self.act = nn.GELU()

    def forward(self, x):
        return self.linear2(self.act(self.linear1(x)))

class DecLayer(nn.Module):
    def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30):
        super(DecLayer, self).__init__()
        self.num_hidden = num_hidden
        self.num_in = num_in
        self.scale = scale
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(num_hidden)
        self.norm2 = nn.LayerNorm(num_hidden)

        self.W1 = nn.Linear(num_hidden + num_in, num_hidden, bias=True)
        self.W2 = nn.Linear(num_hidden, num_hidden, bias=True)
        self.W3 = nn.Linear(num_hidden, num_hidden, bias=True)
        self.act = nn.GELU()
        self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)

    def forward(self, h_V, h_E, edge_index, mask_V=None, mask_attend=None):
        """Parallel computation of the transformer decoder layer using PyTorch Geometric."""

        # Message passing and node feature update
        h_message = self.message_passing_propagate(h_V, h_E, edge_index, mask_attend)
        h_V = self.norm1(h_V + self.dropout1(h_message / self.scale))

        # Feedforward
        dh = self.dense(h_V)
        h_V = self.norm2(h_V + self.dropout2(dh))

        if mask_V is not None:
            h_V = mask_V.unsqueeze(-1) * h_V

        return h_V

    def message_passing_propagate(self, h_V, h_E, edge_index, mask_attend):
        row, col = edge_index
        h_EV = torch.cat([h_V[row], h_E], dim=-1)

        h_EV = self.act(self.W1(h_EV))
        h_EV = self.act(self.W2(h_EV))
        h_message = self.W3(h_EV)

        if mask_attend is not None:
            h_message = mask_attend.unsqueeze(-1) * h_message

        # Aggregate messages to target nodes
        h_aggregated = scatter(h_message, col, dim=0, reduce='sum')
        return h_aggregated

protein_mpnn_utils.EncLayer = EncLayer
protein_mpnn_utils.DecLayer = DecLayer
protein_mpnn_utils.PositionWiseFeedForward = PositionWiseFeedForward

In [None]:
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")

# Model version and settings
model_name = "v_48_020"  # Example: choose from ["v_48_002", "v_48_010", "v_48_020", "v_48_030"]
backbone_noise = 0.00  # Standard deviation of Gaussian noise to add to backbone atoms

# Paths and model configuration
path_to_model_weights = '/content/ProteinMPNN/vanilla_model_weights'
hidden_dim = 128
num_layers = 3
model_folder_path = path_to_model_weights
# Ensure proper formatting of the model folder path
if model_folder_path[-1] != '/':
    model_folder_path = model_folder_path + '/'
checkpoint_path = model_folder_path + f'{model_name}.pt'

# Load checkpoint
checkpoint = torch.load(checkpoint_path, map_location=device)
print('Number of edges:', checkpoint['num_edges'])
noise_level_print = checkpoint['noise_level']
print(f'Training noise level: {noise_level_print}A')

# Initialize the custom PyTorch Geometric model
model = ProteinMPNN(
    num_letters=21,
    node_features=hidden_dim,
    edge_features=hidden_dim,
    hidden_dim=hidden_dim,
    num_encoder_layers=num_layers,
    num_decoder_layers=num_layers,
    augment_eps=backbone_noise,
    k_neighbors=checkpoint['num_edges']
)

# Move model to device (e.g., GPU if available)
model.to(device)

# Load model weights from checkpoint
model.load_state_dict(checkpoint['model_state_dict'], strict=False)
model.eval()
print("Model loaded")


  checkpoint = torch.load(checkpoint_path, map_location=device)


Number of edges: 48
Training noise level: 0.2A
Model loaded


In [None]:
#@title Helper functions
def make_tied_positions_for_homomers(pdb_dict_list):
    my_dict = {}
    for result in pdb_dict_list:
        all_chain_list = sorted([item[-1:] for item in list(result) if item[:9]=='seq_chain']) #A, B, C, ...
        tied_positions_list = []
        chain_length = len(result[f"seq_chain_{all_chain_list[0]}"])
        for i in range(1,chain_length+1):
            temp_dict = {}
            for j, chain in enumerate(all_chain_list):
                temp_dict[chain] = [i] #needs to be a list
            tied_positions_list.append(temp_dict)
        my_dict[result['name']] = tied_positions_list
    return my_dict

Examples:

1) pdb: 6MRR, homomer: False, designed_chain: A

2) pdb: 1O91, homomer: True, designed_chain: A B C, for correct symmetric tying lenghts of homomer chains should be the same

In [None]:
import re
from google.colab import files
import numpy as np

#########################
def get_pdb(pdb_code=""):
  if pdb_code is None or pdb_code == "":
    upload_dict = files.upload()
    pdb_string = upload_dict[list(upload_dict.keys())[0]]
    with open("tmp.pdb","wb") as out: out.write(pdb_string)
    return "tmp.pdb"
  else:
    os.system(f"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb")
    return f"{pdb_code}.pdb"

#@markdown ### Input Options
pdb='1O91' #@param {type:"string"}
pdb_path = get_pdb(pdb)
#@markdown - pdb code (leave blank to get an upload prompt)

homomer = True #@param {type:"boolean"}
designed_chain = "A B C" #@param {type:"string"}
fixed_chain = "" #@param {type:"string"}

if designed_chain == "":
  designed_chain_list = []
else:
  designed_chain_list = re.sub("[^A-Za-z]+",",", designed_chain).split(",")

if fixed_chain == "":
  fixed_chain_list = []
else:
  fixed_chain_list = re.sub("[^A-Za-z]+",",", fixed_chain).split(",")

chain_list = list(set(designed_chain_list + fixed_chain_list))

#@markdown - specified which chain(s) to design and which chain(s) to keep fixed.
#@markdown   Use comma:`A,B` to specifiy more than one chain

#chain = "A" #@param {type:"string"}
#pdb_path_chains = chain
##@markdown - Define which chain to redesign

#@markdown ### Design Options
num_seqs = 1 #@param ["1", "2", "4", "8", "16", "32", "64"] {type:"raw"}
num_seq_per_target = num_seqs

#@markdown - Sampling temperature for amino acids, T=0.0 means taking argmax, T>>1.0 means sample randomly.
sampling_temp = "0.1" #@param ["0.0001", "0.1", "0.15", "0.2", "0.25", "0.3", "0.5"]



save_score=0                      # 0 for False, 1 for True; save score=-log_prob to npy files
save_probs=0                      # 0 for False, 1 for True; save MPNN predicted probabilites per position
score_only=0                      # 0 for False, 1 for True; score input backbone-sequence pairs
conditional_probs_only=0          # 0 for False, 1 for True; output conditional probabilities p(s_i given the rest of the sequence and backbone)
conditional_probs_only_backbone=0 # 0 for False, 1 for True; if true output conditional probabilities p(s_i given backbone)

batch_size=1                      # Batch size; can set higher for titan, quadro GPUs, reduce this if running out of GPU memory
max_length=20000                  # Max sequence length

out_folder='.'                    # Path to a folder to output sequences, e.g. /home/out/
jsonl_path=''                     # Path to a folder with parsed pdb into jsonl
omit_AAs='X'                      # Specify which amino acids should be omitted in the generated sequence, e.g. 'AC' would omit alanine and cystine.

pssm_multi=0.0                    # A value between [0.0, 1.0], 0.0 means do not use pssm, 1.0 ignore MPNN predictions
pssm_threshold=0.0                # A value between -inf + inf to restric per position AAs
pssm_log_odds_flag=0               # 0 for False, 1 for True
pssm_bias_flag=0                   # 0 for False, 1 for True


##############################################################

folder_for_outputs = out_folder

NUM_BATCHES = num_seq_per_target//batch_size
BATCH_COPIES = batch_size
temperatures = [float(item) for item in sampling_temp.split()]
omit_AAs_list = omit_AAs
alphabet = 'ACDEFGHIKLMNPQRSTVWYX'

omit_AAs_np = np.array([AA in omit_AAs_list for AA in alphabet]).astype(np.float32)

chain_id_dict = None
fixed_positions_dict = None
pssm_dict = None
omit_AA_dict = None
bias_AA_dict = None
tied_positions_dict = None
bias_by_res_dict = None
bias_AAs_np = np.zeros(len(alphabet))


###############################################################
pdb_dict_list = parse_PDB(pdb_path, input_chain_list=chain_list)
dataset_valid = StructureDatasetPDB(pdb_dict_list, truncate=None, max_length=max_length)

chain_id_dict = {}
chain_id_dict[pdb_dict_list[0]['name']]= (designed_chain_list, fixed_chain_list)

print(chain_id_dict)
for chain in chain_list:
  l = len(pdb_dict_list[0][f"seq_chain_{chain}"])
  print(f"Length of chain {chain} is {l}")

if homomer:
  tied_positions_dict = make_tied_positions_for_homomers(pdb_dict_list)
else:
  tied_positions_dict = None

{'1O91': (['A', 'B', 'C'], [])}
Length of chain A is 131
Length of chain B is 131
Length of chain C is 131


In [None]:
from torch_geometric.data import Data, Batch
import torch
import itertools
import numpy as np

def tied_featurize(batch, device, chain_dict, fixed_position_dict=None, omit_AA_dict=None,
                   tied_positions_dict=None, pssm_dict=None, bias_by_res_dict=None,
                   ca_only=False, edge_threshold=10.0):
    """ Featurize the batch into a PyTorch Geometric compatible format. """
    alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
    B = len(batch)

    data_list = []  # List to store Data objects for each protein

    for i, b in enumerate(batch):
        # Get sequence length
        L = len(b['seq'])

        # Determine visible and masked chains
        if chain_dict:
            masked_chains, visible_chains = chain_dict[b['name']]  # List of masked and visible chains
        else:
            masked_chains = [item[-1:] for item in list(b) if item[:10] == 'seq_chain_']
            visible_chains = []

        # Prepare node features (coordinates, chain encoding, PSSM, etc.)
        node_features = []
        sequence = []
        residue_idx = []
        chain_encoding = []
        pssm_coef_list = []
        pssm_bias_list = []
        pssm_log_odds_list = []
        bias_by_res_list = []

        global_idx_start = 0
        for chain in visible_chains + masked_chains:
            chain_seq = ''.join([aa if aa != '-' else 'X' for aa in b[f'seq_chain_{chain}']])
            chain_length = len(chain_seq)
            coords = b[f'coords_chain_{chain}']  # Residue coordinates for the chain

            # If ca_only, use CA coordinates only
            if ca_only:
                chain_coords = np.array(coords[f'CA_chain_{chain}'])
                chain_coords = chain_coords[:, None, :]  # [L, 1, 3]
            else:
                chain_coords = np.stack([coords[c] for c in ['N_chain_' + chain, 'CA_chain_' + chain, 'C_chain_' + chain, 'O_chain_' + chain]], axis=1)

            node_features.append(torch.tensor(chain_coords, dtype=torch.float32))
            sequence += [alphabet.index(aa) for aa in chain_seq]
            residue_idx.append(torch.arange(global_idx_start, global_idx_start + chain_length))
            chain_encoding.append(torch.full((chain_length,), i + 1, dtype=torch.int32))  # Chain encoding

            global_idx_start += chain_length

            # Handle PSSM and bias information
            pssm_coef = pssm_dict[b['name']][chain]['pssm_coef'] if pssm_dict else np.zeros(chain_length)
            pssm_bias = pssm_dict[b['name']][chain]['pssm_bias'] if pssm_dict else np.zeros([chain_length, 21])
            pssm_log_odds = pssm_dict[b['name']][chain]['pssm_log_odds'] if pssm_dict else 10000.0 * np.ones([chain_length, 21])
            bias_by_res = bias_by_res_dict[b['name']][chain] if bias_by_res_dict else np.zeros([chain_length, 21])

            pssm_coef_list.append(torch.tensor(pssm_coef, dtype=torch.float32))
            pssm_bias_list.append(torch.tensor(pssm_bias, dtype=torch.float32))
            pssm_log_odds_list.append(torch.tensor(pssm_log_odds, dtype=torch.float32))
            bias_by_res_list.append(torch.tensor(bias_by_res, dtype=torch.float32))

        node_features = torch.cat(node_features, dim=0)  # Concatenate all node features for the protein
        sequence = torch.tensor(sequence, dtype=torch.long)
        residue_idx = torch.cat(residue_idx, dim=0)
        chain_encoding = torch.cat(chain_encoding, dim=0)

        pssm_coef_all = torch.cat(pssm_coef_list, dim=0)
        pssm_bias_all = torch.cat(pssm_bias_list, dim=0)
        pssm_log_odds_all = torch.cat(pssm_log_odds_list, dim=0)
        bias_by_res_all = torch.cat(bias_by_res_list, dim=0)

        # Create edge index based on sequence adjacency (residue connectivity)
        edge_index = []
        for j in range(L - 1):
            edge_index.append([j, j + 1])
            edge_index.append([j + 1, j])

        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()  # [2, num_edges]

        # Optionally, create edge features (e.g., distances between residues)
        edge_attr = None
        if edge_threshold:
            distances = torch.norm(node_features[:, 1, :] - node_features[:, 0, :], dim=-1)  # CA distance as an example
            edges_within_threshold = (distances < edge_threshold).nonzero(as_tuple=False)
            edge_attr = distances[edges_within_threshold]
            edge_index = edges_within_threshold.t().contiguous()  # Update edge index based on threshold

        # Pack everything into a Data object
        data = Data(
            x=node_features,  # Node features (coordinates, etc.)
            edge_index=edge_index.to(device),  # Edge indices
            edge_attr=edge_attr.to(device) if edge_attr is not None else None,  # Optional edge features
            sequence=sequence.to(device),  # Sequence of residues (as indices)
            pssm_coef=pssm_coef_all.to(device),
            pssm_bias=pssm_bias_all.to(device),
            pssm_log_odds=pssm_log_odds_all.to(device),
            bias_by_res=bias_by_res_all.to(device),
            chain_encoding=chain_encoding.to(device),
            residue_idx=residue_idx.to(device)
        )

        data_list.append(data)

    # Combine all data objects into a batch
    batch_data = Batch.from_data_list(data_list)

    return batch_data


In [None]:
with torch.no_grad():
    print('Generating sequences...')
    for ix, protein in enumerate(dataset_valid):
        score_list = []
        all_probs_list = []
        all_log_probs_list = []
        S_sample_list = []
        batch_clones = [copy.deepcopy(protein) for i in range(BATCH_COPIES)]

        # Instead of X, S, mask, now we use Data objects
        data_list = tied_featurize(
            batch=batch_clones,                     # Your batch of data
            device=device,                   # CPU or GPU device
            chain_dict=None,                 # Chain dictionary or None
            fixed_position_dict=None,        # Fixed position dictionary or None
            omit_AA_dict=None,               # Omit amino acids dictionary or None
            tied_positions_dict=None,        # Tied positions dictionary or None
            pssm_dict=None,                  # PSSM dictionary or None
            bias_by_res_dict=None,           # Bias by residue dictionary or None
            ca_only=False                    # Whether to use only CA atoms
        )

        # Iterate over the Data objects
        for data in data_list:
            data = data_list
            data = data.to(device)

            randn_1 = torch.randn(data.x.shape, device=data.x.device)
            log_probs = model(data.x, data.edge_index, data.edge_attr, data.mask, randn_1)

            # Compute the mask for loss as per the new model's structure
            mask_for_loss = data.mask

            # Compute scores (adapt this to work with the geometric model)
            scores = _scores(S, log_probs, mask_for_loss)
            native_score = scores.cpu().data.numpy()

            # Continue with sampling, adjusting sampling method to use graph data
            for temp in temperatures:
                for j in range(NUM_BATCHES):
                    randn_2 = torch.randn(data.x.shape, device=data.x.device)

                    sample_dict = model.sample(data.x, randn_2, data.edge_index, ...)
                    S_sample = sample_dict["S"]

                    log_probs = model(data.x, data.edge_index, data.edge_attr, S_sample, randn_2)
                    mask_for_loss = data.mask

                    scores = _scores(S_sample, log_probs, mask_for_loss)
                    scores = scores.cpu().data.numpy()

                    all_probs_list.append(sample_dict["probs"].cpu().data.numpy())
                    all_log_probs_list.append(log_probs.cpu().data.numpy())
                    S_sample_list.append(S_sample.cpu().data.numpy())

                    # Continue with other parts of sequence recovery and output printing...
                    for b_ix in range(BATCH_COPIES):
                      masked_chain_length_list = masked_chain_length_list_list[b_ix]
                      masked_list = masked_list_list[b_ix]
                      seq_recovery_rate = torch.sum(torch.sum(torch.nn.functional.one_hot(S[b_ix], 21)*torch.nn.functional.one_hot(S_sample[b_ix], 21),axis=-1)*mask_for_loss[b_ix])/torch.sum(mask_for_loss[b_ix])
                      seq = _S_to_seq(S_sample[b_ix], chain_M[b_ix])
                      score = scores[b_ix]
                      score_list.append(score)
                      native_seq = _S_to_seq(S[b_ix], chain_M[b_ix])
                      if b_ix == 0 and j==0 and temp==temperatures[0]:
                          start = 0
                          end = 0
                          list_of_AAs = []
                          for mask_l in masked_chain_length_list:
                              end += mask_l
                              list_of_AAs.append(native_seq[start:end])
                              start = end
                          native_seq = "".join(list(np.array(list_of_AAs)[np.argsort(masked_list)]))
                          l0 = 0
                          for mc_length in list(np.array(masked_chain_length_list)[np.argsort(masked_list)])[:-1]:
                              l0 += mc_length
                              native_seq = native_seq[:l0] + '/' + native_seq[l0:]
                              l0 += 1
                          sorted_masked_chain_letters = np.argsort(masked_list_list[0])
                          print_masked_chains = [masked_list_list[0][i] for i in sorted_masked_chain_letters]
                          sorted_visible_chain_letters = np.argsort(visible_list_list[0])
                          print_visible_chains = [visible_list_list[0][i] for i in sorted_visible_chain_letters]
                          native_score_print = np.format_float_positional(np.float32(native_score.mean()), unique=False, precision=4)
                          line = '>{}, score={}, fixed_chains={}, designed_chains={}, model_name={}\n{}\n'.format(name_, native_score_print, print_visible_chains, print_masked_chains, model_name, native_seq)
                          print(line.rstrip())
                      start = 0
                      end = 0
                      list_of_AAs = []
                      for mask_l in masked_chain_length_list:
                          end += mask_l
                          list_of_AAs.append(seq[start:end])
                          start = end

                      seq = "".join(list(np.array(list_of_AAs)[np.argsort(masked_list)]))
                      l0 = 0
                      for mc_length in list(np.array(masked_chain_length_list)[np.argsort(masked_list)])[:-1]:
                          l0 += mc_length
                          seq = seq[:l0] + '/' + seq[l0:]
                          l0 += 1
                      score_print = np.format_float_positional(np.float32(score), unique=False, precision=4)
                      seq_rec_print = np.format_float_positional(np.float32(seq_recovery_rate.detach().cpu().numpy()), unique=False, precision=4)
                      line = '>T={}, sample={}, score={}, seq_recovery={}\n{}\n'.format(temp,b_ix,score_print,seq_rec_print,seq)
                      print(line.rstrip())

all_probs_concat = np.concatenate(all_probs_list)
all_log_probs_concat = np.concatenate(all_log_probs_list)
S_sample_concat = np.concatenate(S_sample_list)

Generating sequences...


AttributeError: 'GlobalStorage' object has no attribute 'mask'

In [None]:
data

DataBatch(x=[393, 4, 3], edge_index=[1, 393], edge_attr=[393, 1], sequence=[393], pssm_coef=[393], pssm_bias=[393, 21], pssm_log_odds=[393, 21], bias_by_res=[393, 21], chain_encoding=[393], residue_idx=[393], batch=[393], ptr=[2])

In [None]:
#@markdown ### Amino acid probabilties
import plotly.express as px
fig = px.imshow(np.exp(all_log_probs_concat).mean(0).T,
                labels=dict(x="positions", y="amino acids", color="probability"),
                y=list(alphabet),
                template="simple_white"
               )

fig.update_xaxes(side="top")


fig.show()

In [None]:
#@markdown ### Sampling temperature adjusted amino acid probabilties
import plotly.express as px
fig = px.imshow(all_probs_concat.mean(0).T,
                labels=dict(x="positions", y="amino acids", color="probability"),
                y=list(alphabet),
                template="simple_white"
               )

fig.update_xaxes(side="top")


fig.show()

In [None]:

#@title Clone github repo
import json, time, os, sys, glob

if not os.path.isdir("ProteinMPNN"):
  os.system("git clone -q https://github.com/dauparas/ProteinMPNN.git")
sys.path.append('/content/ProteinMPNN')

#@title Setup Model
import matplotlib.pyplot as plt
import shutil
import warnings
import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split, Subset
import copy
import torch.nn as nn
import torch.nn.functional as F
import random
import os.path
from protein_mpnn_utils import loss_nll, loss_smoothed, gather_edges, gather_nodes, gather_nodes_t, cat_neighbors_nodes, _scores, _S_to_seq, tied_featurize, parse_PDB
from protein_mpnn_utils import StructureDataset, StructureDatasetPDB, ProteinMPNN

device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
#v_48_010=version with 48 edges 0.10A noise
model_name = "v_48_020" #@param ["v_48_002", "v_48_010", "v_48_020", "v_48_030"]


backbone_noise=0.00               # Standard deviation of Gaussian noise to add to backbone atoms

path_to_model_weights='/content/ProteinMPNN/vanilla_model_weights'
hidden_dim = 128
num_layers = 3
model_folder_path = path_to_model_weights
if model_folder_path[-1] != '/':
    model_folder_path = model_folder_path + '/'
checkpoint_path = model_folder_path + f'{model_name}.pt'

checkpoint = torch.load(checkpoint_path, map_location=device)
print('Number of edges:', checkpoint['num_edges'])
noise_level_print = checkpoint['noise_level']
print(f'Training noise level: {noise_level_print}A')
model = ProteinMPNN(num_letters=21, node_features=hidden_dim, edge_features=hidden_dim, hidden_dim=hidden_dim, num_encoder_layers=num_layers, num_decoder_layers=num_layers, augment_eps=backbone_noise, k_neighbors=checkpoint['num_edges'])
model.to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print("Model loaded")

#@title Helper functions
def make_tied_positions_for_homomers(pdb_dict_list):
    my_dict = {}
    for result in pdb_dict_list:
        all_chain_list = sorted([item[-1:] for item in list(result) if item[:9]=='seq_chain']) #A, B, C, ...
        tied_positions_list = []
        chain_length = len(result[f"seq_chain_{all_chain_list[0]}"])
        for i in range(1,chain_length+1):
            temp_dict = {}
            for j, chain in enumerate(all_chain_list):
                temp_dict[chain] = [i] #needs to be a list
            tied_positions_list.append(temp_dict)
        my_dict[result['name']] = tied_positions_list
    return my_dict


import re
from google.colab import files
import numpy as np

#########################
def get_pdb(pdb_code=""):
  if pdb_code is None or pdb_code == "":
    upload_dict = files.upload()
    pdb_string = upload_dict[list(upload_dict.keys())[0]]
    with open("tmp.pdb","wb") as out: out.write(pdb_string)
    return "tmp.pdb"
  else:
    os.system(f"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb")
    return f"{pdb_code}.pdb"

#@markdown ### Input Options
pdb='1O91' #@param {type:"string"}
pdb_path = get_pdb(pdb)
#@markdown - pdb code (leave blank to get an upload prompt)

homomer = True #@param {type:"boolean"}
designed_chain = "A B C" #@param {type:"string"}
fixed_chain = "" #@param {type:"string"}

if designed_chain == "":
  designed_chain_list = []
else:
  designed_chain_list = re.sub("[^A-Za-z]+",",", designed_chain).split(",")

if fixed_chain == "":
  fixed_chain_list = []
else:
  fixed_chain_list = re.sub("[^A-Za-z]+",",", fixed_chain).split(",")

chain_list = list(set(designed_chain_list + fixed_chain_list))

#@markdown - specified which chain(s) to design and which chain(s) to keep fixed.
#@markdown   Use comma:`A,B` to specifiy more than one chain

#chain = "A" #@param {type:"string"}
#pdb_path_chains = chain
##@markdown - Define which chain to redesign

#@markdown ### Design Options
num_seqs = 1 #@param ["1", "2", "4", "8", "16", "32", "64"] {type:"raw"}
num_seq_per_target = num_seqs

#@markdown - Sampling temperature for amino acids, T=0.0 means taking argmax, T>>1.0 means sample randomly.
sampling_temp = "0.1" #@param ["0.0001", "0.1", "0.15", "0.2", "0.25", "0.3", "0.5"]



save_score=0                      # 0 for False, 1 for True; save score=-log_prob to npy files
save_probs=0                      # 0 for False, 1 for True; save MPNN predicted probabilites per position
score_only=0                      # 0 for False, 1 for True; score input backbone-sequence pairs
conditional_probs_only=0          # 0 for False, 1 for True; output conditional probabilities p(s_i given the rest of the sequence and backbone)
conditional_probs_only_backbone=0 # 0 for False, 1 for True; if true output conditional probabilities p(s_i given backbone)

batch_size=1                      # Batch size; can set higher for titan, quadro GPUs, reduce this if running out of GPU memory
max_length=20000                  # Max sequence length

out_folder='.'                    # Path to a folder to output sequences, e.g. /home/out/
jsonl_path=''                     # Path to a folder with parsed pdb into jsonl
omit_AAs='X'                      # Specify which amino acids should be omitted in the generated sequence, e.g. 'AC' would omit alanine and cystine.

pssm_multi=0.0                    # A value between [0.0, 1.0], 0.0 means do not use pssm, 1.0 ignore MPNN predictions
pssm_threshold=0.0                # A value between -inf + inf to restric per position AAs
pssm_log_odds_flag=0               # 0 for False, 1 for True
pssm_bias_flag=0                   # 0 for False, 1 for True


##############################################################

folder_for_outputs = out_folder

NUM_BATCHES = num_seq_per_target//batch_size
BATCH_COPIES = batch_size
temperatures = [float(item) for item in sampling_temp.split()]
omit_AAs_list = omit_AAs
alphabet = 'ACDEFGHIKLMNPQRSTVWYX'

omit_AAs_np = np.array([AA in omit_AAs_list for AA in alphabet]).astype(np.float32)

chain_id_dict = None
fixed_positions_dict = None
pssm_dict = None
omit_AA_dict = None
bias_AA_dict = None
tied_positions_dict = None
bias_by_res_dict = None
bias_AAs_np = np.zeros(len(alphabet))


###############################################################
pdb_dict_list = parse_PDB(pdb_path, input_chain_list=chain_list)
dataset_valid = StructureDatasetPDB(pdb_dict_list, truncate=None, max_length=max_length)

chain_id_dict = {}
chain_id_dict[pdb_dict_list[0]['name']]= (designed_chain_list, fixed_chain_list)

print(chain_id_dict)
for chain in chain_list:
  l = len(pdb_dict_list[0][f"seq_chain_{chain}"])
  print(f"Length of chain {chain} is {l}")

if homomer:
  tied_positions_dict = make_tied_positions_for_homomers(pdb_dict_list)
else:
  tied_positions_dict = None

  checkpoint = torch.load(checkpoint_path, map_location=device)


Number of edges: 48
Training noise level: 0.2A
Model loaded
{'1O91': (['A', 'B', 'C'], [])}
Length of chain B is 131
Length of chain C is 131
Length of chain A is 131


In [None]:
#@title RUN
with torch.no_grad():
  print('Generating sequences...')
  for ix, protein in enumerate(dataset_valid):
    score_list = []
    all_probs_list = []
    all_log_probs_list = []
    S_sample_list = []
    batch_clones = [copy.deepcopy(protein) for i in range(BATCH_COPIES)]
    X, S, mask, lengths, chain_M, chain_encoding_all, chain_list_list, visible_list_list, masked_list_list, masked_chain_length_list_list, chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask, tied_pos_list_of_lists_list, pssm_coef, pssm_bias, pssm_log_odds_all, bias_by_res_all, tied_beta = tied_featurize(batch_clones, device, chain_id_dict, fixed_positions_dict, omit_AA_dict, tied_positions_dict, pssm_dict, bias_by_res_dict)
    pssm_log_odds_mask = (pssm_log_odds_all > pssm_threshold).float() #1.0 for true, 0.0 for false
    name_ = batch_clones[0]['name']

    randn_1 = torch.randn(chain_M.shape, device=X.device)
    log_probs = model(X, S, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_1)
    mask_for_loss = mask*chain_M*chain_M_pos
    scores = _scores(S, log_probs, mask_for_loss)
    native_score = scores.cpu().data.numpy()

    for temp in temperatures:
        for j in range(NUM_BATCHES):
            randn_2 = torch.randn(chain_M.shape, device=X.device)
            if tied_positions_dict == None:
                sample_dict = model.sample(X, randn_2, S, chain_M, chain_encoding_all, residue_idx, mask=mask, temperature=temp, omit_AAs_np=omit_AAs_np, bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos, omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef, pssm_bias=pssm_bias, pssm_multi=pssm_multi, pssm_log_odds_flag=bool(pssm_log_odds_flag), pssm_log_odds_mask=pssm_log_odds_mask, pssm_bias_flag=bool(pssm_bias_flag), bias_by_res=bias_by_res_all)
                S_sample = sample_dict["S"]
            else:
                sample_dict = model.tied_sample(X, randn_2, S, chain_M, chain_encoding_all, residue_idx, mask=mask, temperature=temp, omit_AAs_np=omit_AAs_np, bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos, omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef, pssm_bias=pssm_bias, pssm_multi=pssm_multi, pssm_log_odds_flag=bool(pssm_log_odds_flag), pssm_log_odds_mask=pssm_log_odds_mask, pssm_bias_flag=bool(pssm_bias_flag), tied_pos=tied_pos_list_of_lists_list[0], tied_beta=tied_beta, bias_by_res=bias_by_res_all)
            # Compute scores
                S_sample = sample_dict["S"]
            log_probs = model(X, S_sample, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_2, use_input_decoding_order=True, decoding_order=sample_dict["decoding_order"])
            mask_for_loss = mask*chain_M*chain_M_pos
            scores = _scores(S_sample, log_probs, mask_for_loss)
            scores = scores.cpu().data.numpy()
            all_probs_list.append(sample_dict["probs"].cpu().data.numpy())
            all_log_probs_list.append(log_probs.cpu().data.numpy())
            S_sample_list.append(S_sample.cpu().data.numpy())
            for b_ix in range(BATCH_COPIES):
                masked_chain_length_list = masked_chain_length_list_list[b_ix]
                masked_list = masked_list_list[b_ix]
                seq_recovery_rate = torch.sum(torch.sum(torch.nn.functional.one_hot(S[b_ix], 21)*torch.nn.functional.one_hot(S_sample[b_ix], 21),axis=-1)*mask_for_loss[b_ix])/torch.sum(mask_for_loss[b_ix])
                seq = _S_to_seq(S_sample[b_ix], chain_M[b_ix])
                score = scores[b_ix]
                score_list.append(score)
                native_seq = _S_to_seq(S[b_ix], chain_M[b_ix])
                if b_ix == 0 and j==0 and temp==temperatures[0]:
                    start = 0
                    end = 0
                    list_of_AAs = []
                    for mask_l in masked_chain_length_list:
                        end += mask_l
                        list_of_AAs.append(native_seq[start:end])
                        start = end
                    native_seq = "".join(list(np.array(list_of_AAs)[np.argsort(masked_list)]))
                    l0 = 0
                    for mc_length in list(np.array(masked_chain_length_list)[np.argsort(masked_list)])[:-1]:
                        l0 += mc_length
                        native_seq = native_seq[:l0] + '/' + native_seq[l0:]
                        l0 += 1
                    sorted_masked_chain_letters = np.argsort(masked_list_list[0])
                    print_masked_chains = [masked_list_list[0][i] for i in sorted_masked_chain_letters]
                    sorted_visible_chain_letters = np.argsort(visible_list_list[0])
                    print_visible_chains = [visible_list_list[0][i] for i in sorted_visible_chain_letters]
                    native_score_print = np.format_float_positional(np.float32(native_score.mean()), unique=False, precision=4)
                    line = '>{}, score={}, fixed_chains={}, designed_chains={}, model_name={}\n{}\n'.format(name_, native_score_print, print_visible_chains, print_masked_chains, model_name, native_seq)
                    print(line.rstrip())
                start = 0
                end = 0
                list_of_AAs = []
                for mask_l in masked_chain_length_list:
                    end += mask_l
                    list_of_AAs.append(seq[start:end])
                    start = end

                seq = "".join(list(np.array(list_of_AAs)[np.argsort(masked_list)]))
                l0 = 0
                for mc_length in list(np.array(masked_chain_length_list)[np.argsort(masked_list)])[:-1]:
                    l0 += mc_length
                    seq = seq[:l0] + '/' + seq[l0:]
                    l0 += 1
                score_print = np.format_float_positional(np.float32(score), unique=False, precision=4)
                seq_rec_print = np.format_float_positional(np.float32(seq_recovery_rate.detach().cpu().numpy()), unique=False, precision=4)
                line = '>T={}, sample={}, score={}, seq_recovery={}\n{}\n'.format(temp,b_ix,score_print,seq_rec_print,seq)
                print(line.rstrip())


all_probs_concat = np.concatenate(all_probs_list)
all_log_probs_concat = np.concatenate(all_log_probs_list)
S_sample_concat = np.concatenate(S_sample_list)

Generating sequences...
>1O91, score=1.3138, fixed_chains=[], designed_chains=['A', 'B', 'C'], model_name=v_48_020
EMPAFTAELTVPFPPVGAPVKFDKLLYNGRQNYNPQTGIFTCEVPGVYYFAYHVHCKGGNVWVALFKNNEPMMYTYDEYKKGFLDQASGSAVLLLRPGDQVFLQMPSEQAAGLYAGQYVHSSFSGYLLYPM/EMPAFTAELTVPFPPVGAPVKFDKLLYNGRQNYNPQTGIFTCEVPGVYYFAYHVHCKGGNVWVALFKNNEPMMYTYDEYKKGFLDQASGSAVLLLRPGDQVFLQMPSEQAAGLYAGQYVHSSFSGYLLYPM/EMPAFTAELTVPFPPVGAPVKFDKLLYNGRQNYNPQTGIFTCEVPGVYYFAYHVHCKGGNVWVALFKNNEPMMYTYDEYKKGFLDQASGSAVLLLRPGDQVFLQMPSEQAAGLYAGQYVHSSFSGYLLYPM
>T=0.1, sample=0, score=0.6546, seq_recovery=0.4733
TVEAFTALLTTANPAVGTPVKFNTLIYNGGNVYDPATGVFTCKTEGIYLFNWVLYCYGNDLHAVLMKNDTPILNQYLQNVDGKINQVSGSAILELKKGDKVYVKIPSSSANGLYASSTNHSYFSGYLLYPL/TVEAFTALLTTANPAVGTPVKFNTLIYNGGNVYDPATGVFTCKTEGIYLFNWVLYCYGNDLHAVLMKNDTPILNQYLQNVDGKINQVSGSAILELKKGDKVYVKIPSSSANGLYASSTNHSYFSGYLLYPL/TVEAFTALLTTANPAVGTPVKFNTLIYNGGNVYDPATGVFTCKTEGIYLFNWVLYCYGNDLHAVLMKNDTPILNQYLQNVDGKINQVSGSAILELKKGDKVYVKIPSSSANGLYASSTNHSYFSGYLLYPL


In [None]:
data

NameError: name 'data' is not defined

In [None]:
!wget https://files.ipd.uw.edu/pub/training_sets/pdb_2021aug02_sample.tar.gz
!tar -xvzf pdb_2021aug02_sample.tar.gz

--2024-10-29 18:40:08--  https://files.ipd.uw.edu/pub/training_sets/pdb_2021aug02_sample.tar.gz
Resolving files.ipd.uw.edu (files.ipd.uw.edu)... 128.95.160.134, 128.95.160.135, 2607:4000:406::160:134, ...
Connecting to files.ipd.uw.edu (files.ipd.uw.edu)|128.95.160.134|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 49690915 (47M) [application/octet-stream]
Saving to: ‘pdb_2021aug02_sample.tar.gz.1’


2024-10-29 18:40:10 (48.5 MB/s) - ‘pdb_2021aug02_sample.tar.gz.1’ saved [49690915/49690915]

./pdb_2021aug02_sample/
./pdb_2021aug02_sample/README
./pdb_2021aug02_sample/list.csv
./pdb_2021aug02_sample/pdb/
./pdb_2021aug02_sample/pdb/l3/
./pdb_2021aug02_sample/pdb/l3/5l3p.pt
./pdb_2021aug02_sample/pdb/l3/5l3g_A.pt
./pdb_2021aug02_sample/pdb/l3/5l3f.pt
./pdb_2021aug02_sample/pdb/l3/5l3r_B.pt
./pdb_2021aug02_sample/pdb/l3/4l3o_G.pt
./pdb_2021aug02_sample/pdb/l3/1l3b_E.pt
./pdb_2021aug02_sample/pdb/l3/3l3t_C.pt
./pdb_2021aug02_sample/pdb/l3/6l3y_A.pt
./pdb_2021aug02

In [None]:
import json, time, os, sys, glob

if not os.path.isdir("ProteinMPNN"):
  os.system("git clone -q https://github.com/dauparas/ProteinMPNN.git")
sys.path.append('/content/ProteinMPNN')

In [None]:
import json, time, os, sys, glob
import shutil
import warnings
import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader
import queue
import copy
import torch.nn as nn
import torch.nn.functional as F
import random
import os.path
import subprocess
from concurrent.futures import ProcessPoolExecutor
from training.utils import worker_init_fn, get_pdbs, loader_pdb, build_training_clusters, PDB_dataset, StructureDataset, StructureLoader
from training.model_utils import featurize, loss_smoothed, loss_nll, get_std_opt, ProteinMPNN


In [None]:
data_path = "/content/pdb_2021aug02_sample"
params = {
    "LIST"    : f"{data_path}/list.csv",
    "VAL"     : f"{data_path}/valid_clusters.txt",
    "TEST"    : f"{data_path}/test_clusters.txt",
    "DIR"     : f"{data_path}",
    "DATCUT"  : "2030-Jan-01",
    "RESCUT"  : 3.5, #resolution cutoff for PDBs
    "HOMO"    : 0.70 #min seq.id. to detect homo chains
}

In [None]:
LOAD_PARAM = {'batch_size': 1,
                  'shuffle': True,
                  'pin_memory':False,
                  'num_workers': 4}

In [None]:
train, valid, test = build_training_clusters(params,False)

In [None]:
train_set = PDB_dataset(list(train.keys()), loader_pdb, train, params)
train_loader = torch.utils.data.DataLoader(train_set, worker_init_fn=worker_init_fn, **LOAD_PARAM)
valid_set = PDB_dataset(list(valid.keys()), loader_pdb, valid, params)
valid_loader = torch.utils.data.DataLoader(valid_set, worker_init_fn=worker_init_fn, **LOAD_PARAM)



In [None]:
from torch_geometric.data import Data

for batch in train_loader:
  if len(batch.keys()) == 5:
         print(batch["label"])

  meta = torch.load(PREFIX+".pt")
  chains = {c:torch.load("%s_%s.pt"%(PREFIX,c))


['6l3h_A']


  meta = torch.load(PREFIX+".pt")
  chains = {c:torch.load("%s_%s.pt"%(PREFIX,c))


['6l30_A']


  meta = torch.load(PREFIX+".pt")
  chains = {c:torch.load("%s_%s.pt"%(PREFIX,c))


['1l3p_A']
['6l35_M']
['5l3s_B']
['3l30_A']
['4l3a_A']
['4l3f_A']
['3l3s_D']
['4l3r_A']
['4l3k_B']
['7l3l_A']


  meta = torch.load(PREFIX+".pt")
  chains = {c:torch.load("%s_%s.pt"%(PREFIX,c))


['4l3u_A']
['3l39_A']
['3l32_B']
['5l3x_B']
['5l3q_D']
['6l3q_A']
['5l3x_A']
['1l3b_A']
['5l35_D']


{12123: [['5naf_A', '057113'],
  ['5naf_B', '057113'],
  ['5naf_C', '057113'],
  ['5naf_D', '057113'],
  ['4lg9_A', '057231']],
 21385: [['5nai_A', '067953'],
  ['5nbk_A', '053744'],
  ['5nbk_B', '053744'],
  ['5nhz_A', '052880'],
  ['5nhz_B', '052880'],
  ['5ni0_A', '052880'],
  ['5ni0_B', '052880'],
  ['4nq2_A', '052879'],
  ['4nq4_A', '038638'],
  ['4nq5_A', '038638'],
  ['4nq6_A', '038638'],
  ['2nxa_A', '106143'],
  ['2nyp_A', '106143'],
  ['6ny7_A', '028379'],
  ['2nze_A', '038637'],
  ['2nze_B', '038637'],
  ['2nzf_A', '106144'],
  ['5o2e_A', '018443'],
  ['5o2e_B', '018443'],
  ['5o2f_A', '028383'],
  ['5o2f_B', '028383'],
  ['6o3r_A', '028379'],
  ['6o5t_A', '025654'],
  ['6o5t_B', '025654'],
  ['5o7n_A', '060873'],
  ['5o7n_B', '060873'],
  ['6ogo_A', '018815'],
  ['6ogo_B', '018815'],
  ['6ogo_C', '018815'],
  ['1bc2_A', '102174'],
  ['1bc2_B', '102174'],
  ['2bc2_A', '102178'],
  ['2bc2_B', '102178'],
  ['6ol8_A', '018813'],
  ['6ol8_B', '018813'],
  ['3bc2_A', '102178'],
 

In [None]:
!pwd

/content
