In [1]:
# Colab environment setup
import numpy as np
# Install the correct version of Pytorch Geometric.
import torch

def format_pytorch_version(version):
  """Return the part of a PyTorch version string that precedes the first '+'."""
  base_version, _, _ = version.partition('+')
  return base_version

# Installed PyTorch version with everything from the first '+' onwards removed;
# interpolated into the wheel index URLs of the pip commands below.
TORCH_version = torch.__version__
TORCH = format_pytorch_version(TORCH_version)

def format_cuda_version(version):
  """Turn a CUDA version such as '11.8' into the wheel tag 'cu118'.

  Args:
    version: CUDA version string as reported by `torch.version.cuda`, or
      None when PyTorch was built without CUDA support.

  Returns:
    'cu' followed by the version with the dots removed, or 'cpu' when
    `version` is None.
  """
  if version is None:
    # CPU-only PyTorch builds report no CUDA version; without this guard the
    # call below fails with AttributeError on `None.replace`.
    return 'cpu'
  return 'cu' + version.replace('.', '')

# CUDA version reported by torch, formatted as a 'cu...' tag for the wheel
# index URLs below.
CUDA_version = torch.version.cuda
CUDA = format_cuda_version(CUDA_version)

# Install the PyTorch Geometric companion packages from the wheel index page
# named after the installed torch/CUDA combination, then torch-geometric itself.
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-cluster -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-spline-conv -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install -q torch-geometric

# Install esm (straight from the GitHub repository)
!pip install -q git+https://github.com/facebookresearch/esm.git

# Install biotite
!pip install -q biotite

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.2/10.2 MB[0m [31m22.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.8/4.8 MB[0m [31m20.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m884.9/884.9 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m661.6/661.6 kB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for torch-geometric (pyproject.toml) ... [?25l[?25hdone
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.tom

In [2]:
## Verify that pytorch-geometric is correctly installed
import torch_geometric
import torch_sparse
from torch_geometric.nn import MessagePassing

In [3]:
import esm
# Load the pretrained ESM-IF1 inverse folding model together with its alphabet.
model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
# Run on the GPU when one is available and fall back to the CPU otherwise,
# instead of switching between commented-out variants of this cell by hand.
# use eval mode for deterministic output e.g. without random dropout
model = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()

Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm_if1_gvp4_t16_142M_UR50.pt" to /root/.cache/torch/hub/checkpoints/esm_if1_gvp4_t16_142M_UR50.pt


In [4]:
from pathlib import Path
import pandas as pd
import numpy as np
import h5py

In [5]:
# Let's start with the Google Drive folder that holds the preprocessed files

In [6]:
# Mount Google Drive under /content/gdrive so the paths below resolve.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [7]:
# Project folder inside the mounted Drive; the input and output folders used
# below are created beneath it.
DATADIR = Path("/content/gdrive/MyDrive/MISATO-experiments")


In [8]:
# Working directory on the Colab VM (not referenced again in the visible cells).
WORKDIR = Path("/content/workdir")
WORKDIR.mkdir(exist_ok=True)
# Folder with the preprocessed inputs that later cells read.
INPUTDIR = DATADIR / "preprocessed"
# Folder for the embedding outputs; `exist_ok=True` makes re-running this cell safe.
ESMIFDIR = DATADIR / "esm-if-embeddings-emb"
ESMIFDIR.mkdir(exist_ok=True)

In [9]:
# dff.chain.value_counts()

There is currently a bug in the `get_encoder_output` method, so I had to modify it (I have sent the authors a pull request). Some of the following code may not be in use: for multichain complexes I decided to compute the embeddings for each chain separately rather than use the `_concatenate_coords` method, but I might use it later.

In [10]:

def _concatenate_coords(coords, padding_length=0):
    """
    Args:
        coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
            coordinates representing the backbone of each chain
        target_chain_id: The chain id to sample sequences for
        padding_length: Length of padding between concatenated chains
    Returns:
        Tuple (coords, seq)
            - coords is an L x 3 x 3 array for N, CA, C coordinates, a
              concatenation of the chains with padding in between
            - seq is the extracted sequence, with padding tokens inserted
              between the concatenated chains
    """
    pad_coords = np.full((padding_length, 3, 3), np.nan, dtype=np.float32)
    # For best performance, put the target chain first in concatenation.
    coords_list = [] #[coords[target_chain_id]]
    for chain_id in sorted(coords):
        # if chain_id == target_chain_id:
        #     continue
        coords_list.append(pad_coords)
        coords_list.append(coords[chain_id])
    coords_concatenated = np.concatenate(coords_list, axis=0)
    return coords_concatenated

def get_encoder_output(model, alphabet, coords, seq=None):
    """Run the model's encoder on a single structure.

    The structure is wrapped in a one-element batch, converted with
    `CoordBatchConverter` on the device of the model parameters and passed
    through `model.encoder.forward`.

    Args:
        model: model exposing `parameters()` and an `encoder` attribute.
        alphabet: alphabet handed to `CoordBatchConverter`.
        coords: coordinates of the structure to encode.
        seq: optional sequence placed in the batch entry (default None).

    Returns:
        `encoder_out['encoder_out'][0][1:-1, 0]`: the first encoder output
        with its first and last positions (bos and eos tokens) removed and
        index 0 taken along the second axis.
    """
    target_device = next(model.parameters()).device
    to_batch = esm.inverse_folding.util.CoordBatchConverter(alphabet)
    batch_coords, batch_confidence, _strs, _tokens, batch_mask = to_batch(
        [(coords, None, seq)], device=target_device)
    result = model.encoder.forward(
        batch_coords, batch_mask, batch_confidence, return_all_hiddens=False)
    hidden = result['encoder_out'][0]
    # drop the first and last positions (bos and eos tokens)
    return hidden[1:-1, 0]

def get_encoder_output_for_complex(model, alphabet, coords):
  """Encode a whole complex: concatenate all chains, then run the encoder once.

  Args:
    model: model forwarded to `get_encoder_output`.
    alphabet: alphabet forwarded to `get_encoder_output`.
    coords: per-chain coordinates in the form accepted by
      `_concatenate_coords`.

  Returns:
    The result of `get_encoder_output` for the concatenated coordinates.
  """
  merged_coords = _concatenate_coords(coords)
  return get_encoder_output(model, alphabet, merged_coords)


# target_chain_id = 'A'
# rep = get_encoder_output_for_complex(model, alphabet, coords)
# len(coords[target_chain_id]), rep.shape, [coords[k].shape for k in coords.keys()]

In [11]:
# all_coords = esm.inverse_folding.multichain_util._concatenate_coords(coords, target_chain_id)


In [12]:
# all_coords.shape

In [13]:
# Show the device that holds the first model parameter.
next(model.parameters()).device

device(type='cuda', index=0)

In [14]:
# len(coords[target_chain_id]), rep.shape, np.sum([len(coords[k]) for k in coords.keys()])
# List the files available in the input folder.
!ls {INPUTDIR}

esm_if_input.npz       md_esm_if_10000_.hdf5	  md_test_out.hdf5
md_esm_if_0_5000.hdf5  md_esm_if_5000_10000.hdf5  misato_sequences_info.csv


In [44]:
from tqdm.auto import tqdm
# The preprocessing step stored the data in three smaller parts, so that the
# embedding-extraction script could be written while the rest of the data was
# still being processed.
h5files = [
    # "md_test_out.hdf5"
    "md_esm_if_0_5000.hdf5",
    "md_esm_if_5000_10000.hdf5",
    "md_esm_if_10000_.hdf5"
]
# Maps each top-level HDF5 key (used as pdbid below) to its dataset contents.
data = dict()
for filename in tqdm(h5files):
  h5path = INPUTDIR / filename
  # Open read-only explicitly: the inputs must never be modified, and the
  # default mode differs between h5py versions.
  with h5py.File(h5path, "r") as f:
    for pdbid in tqdm(f.keys(), total=len(f.keys())):
      # `[()]` reads the whole dataset, so the value stays usable after the
      # file is closed.
      data[pdbid] = f[pdbid][()]



  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/4337 [00:00<?, ?it/s]

  0%|          | 0/4263 [00:00<?, ?it/s]

  0%|          | 0/5853 [00:00<?, ?it/s]

In [45]:
# Number of entries loaded into `data`.
len(data)

14453

In [19]:
# np.savez_compressed(INPUTDIR/"esm_if_input.npz", **data)

In [20]:
# Sequence metadata table; the columns used in this notebook are pdbid, chain,
# sequence, full and split_name.
df = pd.read_csv(INPUTDIR / "misato_sequences_info.csv")
# Largest `chain` value per pdbid (every column is max-aggregated first, then
# only `pdbid` and `chain` are kept).
dff = df.groupby("pdbid").max().reset_index()[["pdbid", "chain"]]
# Length of every sequence string.
df["seqlen"] = df["sequence"].map(len)

In [21]:
# For every pdbid, build a list of (chain, start, end) tuples: the chains with
# `full == 0` are laid out back to back in sorted chain order, each occupying
# `seqlen` consecutive positions.
pdb_chain_positions = {}
for pdbid, row in df.groupby("pdbid"):
  row = row[row["full"] == 0]
  chain_lengths = dict(row[["chain", "seqlen"]].values)
  positions = []
  start = 0
  for chain in sorted(row["chain"]):
    end = start + chain_lengths[chain]
    positions.append((chain, start, end))
    start = end
  pdb_chain_positions[pdbid] = positions

# row

In [22]:
# Number of pdbids for which chain positions were computed.
len(pdb_chain_positions)

16972

In [23]:
# Output folder for the per-pdbid .npz files written by the extraction loop.
SAVEDIR = ESMIFDIR / "npz_frame0" # change the folder to "npz" for all the frames
SAVEDIR.mkdir(exist_ok=True)

The function `get_encoder_output_modified` given below produces embeddings in batches for each structure when `nframes > 1`; otherwise it just returns the embeddings (separate arrays for each chain) for the structure in frame 0.

In [24]:



# Converter built once at module level; `get_encoder_output_modified` reads this
# global on every call instead of constructing its own.
batch_converter = esm.inverse_folding.util.CoordBatchConverter(alphabet)


def get_encoder_output_modified(model, alphabet, all_coords, batch_size=8, nframes=1):
    """Embed the first `nframes` frames of a trajectory in mini-batches.

    Args:
        model: model exposing `parameters()` and `encoder.forward_embedding`.
        alphabet: unused; kept for interface compatibility. Conversion goes
            through the module-level `batch_converter`.
        all_coords: coordinates indexed by frame along the first axis
            (presumably frames x L x 3 x 3 backbone coordinates; confirm
            against the caller).
        batch_size: maximum number of frames per forward pass.
        nframes: number of leading frames to embed; values larger than the
            number of available frames are clamped.

    Returns:
        List with one NumPy array per mini-batch, each holding the first value
        returned by `model.encoder.forward_embedding` for that batch.
    """
    device = next(model.parameters()).device
    # Never index past the frames that are actually available.
    frame_count = min(nframes, len(all_coords))
    outputs = []
    for start in range(0, frame_count, batch_size):
        stop = min(start + batch_size, frame_count)
        # Index with the absolute frame number. Indexing relative to the batch
        # would make every batch reuse frames 0..batch_size-1 and would not
        # truncate the last batch to `nframes`.
        batch = [(all_coords[frame], None, None) for frame in range(start, stop)]
        coords, confidence, strs, tokens, padding_mask = batch_converter(
            batch, device=device)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            x, components = model.encoder.forward_embedding(
                coords, padding_mask, confidence)
        outputs.append(x.detach().cpu().numpy())
    return outputs

In [25]:
# Per-pdbid maxima of the three columns needed for the selection below.
dff = df.groupby("pdbid").max()[["chain", "split_name", "seqlen"]]
# pdbids whose largest `chain` value is 0.
sel_pdbids = set(dff.loc[dff["chain"] == 0].index)

In [46]:
from tqdm.auto import tqdm
# Compute embeddings for every loaded structure and store one .npz per pdbid,
# with one array per chain keyed by `str(chain)`. Existing outputs are skipped,
# so the cell can be re-run to resume an interrupted job.
pdbids = sorted(set(data.keys()))# & sel_pdbids)
for pdbid in tqdm(pdbids, total=len(pdbids)):
  path = SAVEDIR / f"{pdbid}.npz"
  if path.exists():
    continue
  if pdbid not in pdb_chain_positions:
    print(f"{pdbid} not in pdb_chain_positions")
    continue
  trajectory = np.asarray(data[pdbid]).astype(np.float32)

  chain_embeddings = dict()
  for chain, b, e in pdb_chain_positions[pdbid]:
    # Positions b:e along the second axis belong to this chain.
    coords = trajectory[:, b:e]
    with torch.no_grad():
      rep = get_encoder_output_modified(model, alphabet, coords, batch_size=1, nframes=1)
    frame_data = np.concatenate(rep)
    chain_embeddings[str(chain)] = frame_data

  # Write to a temporary file first and rename it afterwards: an interrupted
  # run must not leave a truncated .npz behind, because the `path.exists()`
  # check above would skip it on every later run.
  tmp_path = path.with_name(path.name + ".tmp")
  with open(tmp_path, "wb") as tmp_file:
    np.savez_compressed(tmp_file, **chain_embeddings)
  tmp_path.replace(path)
  torch.cuda.empty_cache()

  # break

  0%|          | 0/14453 [00:00<?, ?it/s]

In [47]:
# Count the .npz files currently present in SAVEDIR.
len(list(SAVEDIR.glob("*.npz")))

14453