From 4f232f167f09e3300a80d604585d327bc7184edf Mon Sep 17 00:00:00 2001 From: Jinhua Zhu <1462540095@qq.com> Date: Thu, 2 Mar 2023 14:57:57 +0800 Subject: [PATCH] Musse (#102) * Uni-Fold Musse * mv plm to ssmultimer * update readme * udpatereadme * ss to musse * rename fn * fix monomer bug --- README.md | 30 ++ infer_lm.py | 381 ++++++++++++++ run_unifold_musse.sh | 30 ++ unifold/__init__.py | 1 + unifold/config.py | 30 ++ unifold/data/process.py | 95 +++- unifold/homo_search.py | 10 +- unifold/inference.py | 80 ++- unifold/loss.py | 32 +- unifold/modules/__init__.py | 3 +- unifold/modules/alphafold.py | 81 ++- unifold/modules/auxillary_heads.py | 5 +- unifold/msa/utils.py | 12 +- unifold/musse/__init__.py | 2 + unifold/musse/dataset.py | 478 ++++++++++++++++++ unifold/musse/model.py | 47 ++ unifold/musse/modules/__init__.py | 0 unifold/musse/modules/alphafold.py | 242 +++++++++ unifold/musse/modules/auxiliary_heads.py | 7 + unifold/musse/modules/embedders.py | 90 ++++ unifold/musse/modules/evoformer.py | 165 ++++++ unifold/musse/plm/__init__.py | 1 + unifold/musse/plm/dict_esm.txt | 33 ++ unifold/musse/plm/model/__init__.py | 3 + unifold/musse/plm/model/bert.py | 365 +++++++++++++ .../musse/plm/model/multihead_attention.py | 260 ++++++++++ unifold/musse/plm/model/rotary_embedding.py | 69 +++ .../musse/plm/model/transformer_encoder.py | 235 +++++++++ .../plm/model/transformer_encoder_layer.py | 111 ++++ unifold/pack_feat.py | 149 ++++++ unifold/task.py | 9 +- 31 files changed, 2969 insertions(+), 87 deletions(-) create mode 100644 infer_lm.py create mode 100644 run_unifold_musse.sh create mode 100644 unifold/musse/__init__.py create mode 100644 unifold/musse/dataset.py create mode 100644 unifold/musse/model.py create mode 100644 unifold/musse/modules/__init__.py create mode 100644 unifold/musse/modules/alphafold.py create mode 100644 unifold/musse/modules/auxiliary_heads.py create mode 100644 unifold/musse/modules/embedders.py create mode 100644 unifold/musse/modules/evoformer.py create mode 100755 unifold/musse/plm/__init__.py create mode 100755 unifold/musse/plm/dict_esm.txt create mode 100755 unifold/musse/plm/model/__init__.py create mode 100755 unifold/musse/plm/model/bert.py create mode 100755 unifold/musse/plm/model/multihead_attention.py create mode 100755 unifold/musse/plm/model/rotary_embedding.py create mode 100755 unifold/musse/plm/model/transformer_encoder.py create mode 100755 unifold/musse/plm/model/transformer_encoder_layer.py create mode 100644 unifold/pack_feat.py diff --git a/README.md b/README.md index 8dd7b0d..f14f6fb 100644 --- a/README.md +++ b/README.md @@ -243,6 +243,36 @@ bash run_uf_symmetry.sh \ to inference with UF-Symmetry. **Note that the input FASTA file should contain the sequences of the asymmetric unit only, and a symmetry group must be specified for the model.** +## Run Uni-Fold Musse + +### Installing Uni-Fold MuSSe + +Clone & install unifold. +```shell +git clone --single-branch -b Musse git@github.com:dptech-corp/Uni-Fold.git unifold_musse +cd unifold_musse +pip install -e . +``` +### Downloading the pre-trained model parameters +Use the following command to download the parameters of our further pre-trained protein language model and single sequence protein complex predictor: +```shell +# the protein language model +wget https://bioos-hermite-beijing.tos-cn-beijing.volces.com/unifold_model/unifold_musse/plm.pt + +# the protein complex predictor +wget https://bioos-hermite-beijing.tos-cn-beijing.volces.com/unifold_model/unifold_musse/mp.pt +``` + +### Running Uni-Fold MuSSe +Run the following command to predict the structure of the target fasta: +```shell +bash run_unifold_musse.sh \ + /path/to/the/input.fasta \ # target fasta file + /path/to/the/output/directory/ \ # output directory + /path/to/multimer_model_parameters.pt \ # multimer predictor parameters + /path/to/pretrain_lm_parameters.pt # language model parameters + +``` ## Inference on Hermite We provide covenient structure prediction service on [Hermiteā„¢](https://hermite.dp.tech/), a new-generation drug design platform powered by AI, physics, and computing. Users only need to upload sequences of protein monomers and multimers to obtain the predicted structures from Uni-Fold, acompanied by various analyzing tools. [Click here](https://docs.google.com/document/d/1iFdezkKJVuhyqN3WvzsC7-422T-zf18IhP7M9CBj5gs) for more information of how to use Hermiteā„¢. diff --git a/infer_lm.py b/infer_lm.py new file mode 100644 index 0000000..a82c6f7 --- /dev/null +++ b/infer_lm.py @@ -0,0 +1,381 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import pathlib +from pathlib import Path +import sys +import gc +import os +import time +import re +from unicore import options +from functools import lru_cache +from unicore.data import Dictionary +import numpy as np +import torch +from unifold.musse.plm.model.bert import BertModel +import time + +import pickle +import gzip + +import torch.utils.data.distributed +from tqdm import tqdm + +HHBLITS_AA_TO_ID = { + "[CLS]": 0, + "[PAD]": 1, + "[SEP]": 2, + "[UNK]": 3, + "L": 4, + "A": 5, + "G": 6, + "V": 7, + "S": 8, + "E": 9, + "R": 10, + "T": 11, + "I": 12, + "D": 13, + "P": 14, + "K": 15, + "Q": 16, + "N": 17, + "F": 18, + "Y": 19, + "M": 20, + "H": 21, + "W": 22, + "C": 23, + "X": 24, + "B": 25, + "U": 26, + "Z": 27, + "O": 28, + ".": 29, + "-": 30, + "": 31, + "[MASK]": 32, +} +AA_TO_ESM = HHBLITS_AA_TO_ID + + +def parse_fasta(fasta_string: str): + """Parses FASTA string and returns list of strings with amino-acid sequences. + + Arguments: + fasta_string: The string contents of a FASTA file. + + Returns: + A tuple of two lists: + * A list of sequences. + * A list of sequence descriptions taken from the comment lines. In the + same order as the sequences. + """ + sequences = [] + descriptions = [] + index = -1 + for line in fasta_string.splitlines(): + line = line.strip() + if line.startswith("#"): + continue + if line.startswith(">"): + index += 1 + descriptions.append(line[1:]) # Remove the '>' at the beginning. + sequences.append("") + continue + elif not line: + continue # Skip blank lines. + sequences[index] += line + + return sequences, descriptions + + +class MultimerDataset: + def __init__(self, inputfn): + self.key = Path(inputfn).stem + fasta_str = open(inputfn).read() + input_seqs, _ = parse_fasta(fasta_str) + self.seqs = input_seqs + + def __len__(self): + return 1 + + @lru_cache(maxsize=16) + def __getitem__(self, idx): + assert idx == 0 + sequences = self.seqs + all_chain_features = self.get_chain_features(sequences) + is_same_entity = self.is_same_entity_func(all_chain_features) + has_same_sequence = self.has_same_sequence_func(all_chain_features) + tokenized_data = [AA_TO_ESM[tok] for sequence in sequences for tok in sequence] + tokenized_data = [0] + tokenized_data + [2] + sequence = torch.from_numpy(np.array(tokenized_data)).long() + return self.key, sequence, is_same_entity, has_same_sequence + + def get_chain_features(self, sequences): + chain_features = {} + all_chain_features = [] + seq_to_entity_id = {} + chain_id = 1 + for seq in sequences: + chain_features = {} + if str(seq) not in seq_to_entity_id: + seq_to_entity_id[str(seq)] = len(seq_to_entity_id) + 1 + chain_features["seq_length"] = len(seq) + chain_features["seq"] = seq + chain_features["asym_id"] = chain_id * np.ones(chain_features["seq_length"]) + chain_features["entity_id"] = seq_to_entity_id[str(seq)] * np.ones( + chain_features["seq_length"] + ) + + all_chain_features.append(chain_features) + chain_id += 1 + return all_chain_features + + def is_same_entity_func(self, all_chain_features): + entity_id_list = [] + for chain_id in range(1, len(all_chain_features) + 1): + if all_chain_features[chain_id - 1] is None: + continue + else: + entity_id_list.append(all_chain_features[chain_id - 1]["entity_id"]) + + entity_id = np.concatenate(entity_id_list) + seq_length = len(entity_id) + mask = entity_id.reshape(seq_length, 1) == entity_id.reshape(1, seq_length) + mask = torch.tensor(mask) + is_same_entity = torch.zeros(seq_length, seq_length).long() + is_same_entity = is_same_entity.masked_fill_(mask, 1) + + is_same_entity_bos_eos = torch.zeros( + is_same_entity.shape[0] + 2, is_same_entity.shape[0] + 2 + ).long() + + is_same_entity_bos_eos[1:-1, 1:-1] = is_same_entity + return is_same_entity_bos_eos + + def has_same_sequence_func(self, all_chain_features): + asym_id_list = [] + for chain_id in range(1, len(all_chain_features) + 1): + if all_chain_features[chain_id - 1] is None: + continue + else: + asym_id_list.append(all_chain_features[chain_id - 1]["asym_id"]) + + asym_id = np.concatenate(asym_id_list) + seq_length = len(asym_id) + mask = asym_id.reshape(seq_length, 1) == asym_id.reshape(1, seq_length) + mask = torch.tensor(mask) + has_same_sequence = torch.zeros(seq_length, seq_length).long() + has_same_sequence = has_same_sequence.masked_fill_(mask, 1) + + has_same_sequence_bos_eos = torch.zeros( + has_same_sequence.shape[0] + 2, has_same_sequence.shape[0] + 2 + ).long() + + has_same_sequence_bos_eos[1:-1, 1:-1] = has_same_sequence + return has_same_sequence_bos_eos.long() + + +def load_model_ensemble_and_task( + filenames, +): + + from unicore import tasks + + filename = filenames[0] + state = torch.load(filename, map_location=torch.device("cpu")) + args = state["args"] + dictionary = Dictionary.load(os.path.join("unifold/musse/plm", "dict_esm.txt")) + model = BertModel(args, dictionary) + + def upgrade_state_dict(state_dict): + """Removes prefixes 'model.encoder.sentence_encoder.' and 'model.encoder.'.""" + prefixes = ["encoder."] # ["encoder.sentence_encoder.", "encoder."] + pattern = re.compile("^" + "|".join(prefixes)) + state_dict = { + pattern.sub("", name): param for name, param in state_dict.items() + } + + prefixes = [ + "sentence_encoder.embed_tokens" + ] # ["encoder.sentence_encoder.", "encoder."] + pattern = re.compile("^" + "|".join(prefixes)) + state_dict = { + pattern.sub("embed_tokens", name): param + for name, param in state_dict.items() + } + return state_dict + + state["model"] = upgrade_state_dict(state["model"]) + model.load_state_dict(state["model"], strict=True, model_args=args) + return model + + +def create_parser(): + parser = argparse.ArgumentParser( + description="Extract per-token representations and model outputs for sequences in a FASTA file" # noqa + ) + parser.add_argument("--input", type=str, default="") + parser.add_argument( + "--path", # model_location + type=str, + help="PyTorch model file OR name of pretrained model to download (see README for models)", + ) + parser.add_argument( + "--bf16", + action="store_true", + default=False, + help="where to use bf16", + ) + parser.add_argument( + "--fp16", + action="store_true", + default=False, + help="where to use fp16", + ) + # parser.add_argument( + # "fasta_file", + # type=pathlib.Path, + # help="FASTA file on which to extract representations", + # ) + parser.add_argument( + "--output_dir", + type=pathlib.Path, + help="output directory for extracted representations", + ) + + parser.add_argument( + "--toks_per_batch", type=int, default=4096, help="maximum batch size" + ) + parser.add_argument( + "--repr_layers", + type=int, + default=[-1], + nargs="+", + help="layers indices from which to extract representations (0 to num_layers, inclusive)", + ) + parser.add_argument( + "--include", + type=str, + nargs="+", + choices=["mean", "per_tok", "bos", "contacts", "attentions"], + help="specify which representations to return", + required=True, + ) + parser.add_argument( + "--truncate", + action="store_true", + help="Truncate sequences longer than 1024 to match the training setup", + ) + parser.add_argument( + "--user-dir", + default=None, + help="path to a python module containing custom extensions (tasks and/or architectures)", + ) + parser.add_argument("--batch_size", default=1, type=int, help="not used") + + parser.add_argument( + "--nogpu", action="store_true", help="Do not use GPU even if available" + ) + parser.add_argument( + "--local_rank", default=-1, type=int, help="node rank for distributed training" + ) + return parser + + +def main(args): + if args.bf16: + raise "not support bf16" + + this_rank = 0 + + if this_rank == 0: + print(f"model_path: {args.path}") + model = load_model_ensemble_and_task( + [args.path], + ) + + model = model.half() + + + gc.collect() + + if this_rank == 0: + print("loaded model successfully") + + model.eval() + + dataset = MultimerDataset(args.input) + + data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False) + if this_rank == 0: + print(f"{len(dataset)} sequences") + torch.cuda.set_device(this_rank) + + model = model.cuda(this_rank) + gc.collect() + + args.output_dir.mkdir(parents=True, exist_ok=True) + # return_contacts = "contacts" in args.include + # need_head_weights = "attentions" in args.include + + num_layers = 36 + # args.repr_layers = [i+1 for i in range(args.num_layers)] + + assert all(-(num_layers + 1) <= i <= num_layers for i in args.repr_layers) + repr_layers = [(i + num_layers + 1) % (num_layers + 1) for i in args.repr_layers] + + t0 = time.time() + with torch.no_grad(): + for batch_idx, (labels, sequence, is_same_entity, has_same_sequence) in tqdm( + enumerate(data_loader), total=len(data_loader) + ): + + assert sequence.shape[0] == 1 + sequence = sequence.cuda(this_rank, non_blocking=True) + is_same_entity = is_same_entity.cuda(this_rank, non_blocking=True) + has_same_sequence = has_same_sequence.cuda(this_rank, non_blocking=True) + + + out = model( + sequence, + is_same_entity=is_same_entity, + has_same_sequence=has_same_sequence, + features_only=True, + ) + + + representations = torch.stack( + [out[1].to(device="cpu"), out[36].to(device="cpu")], dim=2 + ) + + for i, label in enumerate(labels): + args.output_file = ( + args.output_dir + / label + / f"{label}.esm2_multimer_finetune_emb.pkl.gz" + ) + args.output_file.parent.mkdir(parents=True, exist_ok=True) + result = {} # {"label": label} + + result["token"] = representations[i, 1:-1].clone().numpy() + assert result["token"].shape[1] == 2 + assert result["token"].shape[2] == 2560 + + + pickle.dump(result, gzip.GzipFile(args.output_file, "wb"), protocol=4) + t1 = time.time() + if this_rank == 0: + print(f"total inference time for {len(dataset)} samples: {t1-t0}s") + + +if __name__ == "__main__": + parser = create_parser() + args = options.parse_args_and_arch(parser) + main(args) diff --git a/run_unifold_musse.sh b/run_unifold_musse.sh new file mode 100644 index 0000000..689aa92 --- /dev/null +++ b/run_unifold_musse.sh @@ -0,0 +1,30 @@ +set -e +fasta_path=$1 +output_dir_base=$2 +param_path=$3 +lm_param_path=$4 + +echo "Starting generating reprentations from the further pre-trained protein language model" +python infer_lm.py \ + --input $fasta_path \ + --path $lm_param_path \ + --output_dir $output_dir_base \ + --repr_layers 36 --include per_tok --toks_per_batch 512 --fp16 + +echo "Starting preparing features for fasta sequences" +python unifold/pack_feat.py \ + --fasta_path=$fasta_path \ + --output_dir=$output_dir_base + +echo "Starting prediction..." +fasta_file=$(basename $fasta_path) +target_name=${fasta_file%.fa*} +python unifold/inference.py \ + --model_name=unifold_musse \ + --param_path=$param_path \ + --data_dir=$output_dir_base \ + --target_name=$target_name \ + --output_dir=$output_dir_base + +echo "done" + \ No newline at end of file diff --git a/unifold/__init__.py b/unifold/__init__.py index 4b9220f..5df866f 100644 --- a/unifold/__init__.py +++ b/unifold/__init__.py @@ -3,3 +3,4 @@ import argparse from . import task, model, loss +from . import musse \ No newline at end of file diff --git a/unifold/config.py b/unifold/config.py index 497db12..ce14292 100644 --- a/unifold/config.py +++ b/unifold/config.py @@ -21,6 +21,7 @@ inf = mlc.FieldReference(3e4, field_type=float) use_templates = mlc.FieldReference(True, field_type=bool) is_multimer = mlc.FieldReference(False, field_type=bool) +use_musse = mlc.FieldReference(False, field_type=bool) def base_config(): @@ -56,6 +57,7 @@ def base_config(): "msa_chains": [N_MSA, None], "msa_row_mask": [N_MSA], "num_recycling_iters": [], + "num_ensembles": [], "pseudo_beta": [N_RES, None], "pseudo_beta_mask": [N_RES], "residue_index": [N_RES], @@ -94,6 +96,7 @@ def base_config(): "sym_id": [N_RES], "entity_id": [N_RES], "num_sym": [N_RES], + "token": [N_RES, None, None], "asym_len": [None], "cluster_bias_mask": [N_MSA], }, @@ -117,6 +120,7 @@ def base_config(): "msa_cluster_features": True, "reduce_msa_clusters_by_max_templates": True, "resample_msa_in_recycling": True, + "train_max_date": "2022-04-30", "template_features": [ "template_all_atom_positions", "template_sum_probs", @@ -133,6 +137,8 @@ def base_config(): "between_segment_residues", "deletion_matrix", "num_recycling_iters", + "num_ensembles", + "token", "crop_and_fix_size_seed", ], "recycling_features": [ @@ -160,6 +166,7 @@ def base_config(): ], "use_templates": use_templates, "is_multimer": is_multimer, + "use_musse": use_musse, "use_template_torsion_angles": use_templates, "max_recycling_iters": max_recycling_iters, }, @@ -240,6 +247,7 @@ def base_config(): }, "model": { "is_multimer": is_multimer, + "use_musse": use_musse, "input_embedder": { "tf_dim": 22, "msa_dim": 49, @@ -256,6 +264,12 @@ def base_config(): "num_bins": 15, "inf": 1e8, }, + "esm2_embedder": { + "token_dim": 2560, + "d_pair": d_pair, + "d_msa": d_msa, + "dropout": 0.1, + }, "template": { "distogram": { "min_bin": 3.25, @@ -591,6 +605,22 @@ def multimer(c): recursive_set(c, "max_msa_clusters", 256) c.data.train.crop_size = 384 c.loss.violation.weight = 0.5 + elif name == "unifold_musse": + c = multimer(c) + recursive_set(c, "use_musse", True) + recursive_set(c, "d_msa", 1024) + recursive_set(c, "d_single", 1024) + recursive_set(c, "num_heads_msa", 32) + recursive_set(c, "num_heads_ipa", 32) + c.model.template.enabled = False + c.model.template.embed_angles = False + c.model.extra_msa.enabled = False + recursive_set(c, "use_templates", False) + recursive_set(c, "use_template_torsion_angles", False) + c.loss.masked_msa.weight = 0.0 + c.loss.repr_norm.weight = 0.0 + c.model.heads.pae.disable_enhance_head = False + c.model.esm2_embedder.token_dim = 2560 elif name == "multimer_af2": recursive_set(c, "max_extra_msa", 1152) recursive_set(c, "max_msa_clusters", 256) diff --git a/unifold/data/process.py b/unifold/data/process.py index 5661231..8e50607 100644 --- a/unifold/data/process.py +++ b/unifold/data/process.py @@ -23,7 +23,9 @@ def nonensembled_fns(common_cfg, mode_cfg): ] ) operators.append( - data_ops.make_hhblits_profile_v2 if v2_feature else data_ops.make_hhblits_profile + data_ops.make_hhblits_profile_v2 + if v2_feature + else data_ops.make_hhblits_profile ) if common_cfg.use_templates: operators.extend( @@ -48,10 +50,60 @@ def nonensembled_fns(common_cfg, mode_cfg): operators.append(data_ops.make_atom14_masks) operators.append(data_ops.make_target_feat) + return operators +def single_fns(common_cfg, mode_cfg, crop_and_fix_size_seed): + """Input pipeline data transformers that are for single sequence""" + v2_feature = common_cfg.v2_feature + operators = [] + + operators.extend( + [ + data_ops.cast_to_64bit_ints, + data_ops.squeeze_features, + data_ops.make_seq_mask, + ] + ) + + operators.append(data_ops.make_atom14_masks) + operators.append(data_ops.make_target_feat) + + crop_feats = dict(common_cfg.features) + if mode_cfg.fixed_size: + if mode_cfg.crop: + if common_cfg.is_multimer: + crop_fn = data_ops.crop_to_size_multimer( + crop_size=mode_cfg.crop_size, + shape_schema=crop_feats, + seed=crop_and_fix_size_seed, + spatial_crop_prob=mode_cfg.spatial_crop_prob, + ca_ca_threshold=mode_cfg.ca_ca_threshold, + ) + else: + crop_fn = data_ops.crop_to_size_single( + crop_size=mode_cfg.crop_size, + shape_schema=crop_feats, + seed=crop_and_fix_size_seed, + ) + operators.append(crop_fn) + + operators.append(data_ops.select_feat(crop_feats)) + + operators.append( + data_ops.make_fixed_size( + crop_feats, + None, + None, + mode_cfg.crop_size, + None, + ) + ) + return operators + + def crop_and_fix_size_fns(common_cfg, mode_cfg, crop_and_fix_size_seed): operators = [] if common_cfg.reduce_msa_clusters_by_max_templates: @@ -204,9 +256,43 @@ def wrap_ensemble_fn(data, i): # add a dummy dim to align with recycling features tensors = {k: torch.stack([tensors[k]], dim=0) for k in tensors} tensors.update(ensemble_tensors) + tensors["num_ensembles"] = torch.tensor([num_ensembles]) return tensors +def process_features_single(tensors, common_cfg, mode_cfg): + """Based on the config, apply filters and transformations to the data.""" + is_distillation = bool(tensors.get("is_distillation", 0)) + multimer_mode = common_cfg.is_multimer + crop_and_fix_size_seed = int(tensors["crop_and_fix_size_seed"]) + + def wrap_ensemble_fn(data, i): + """Function to be mapped over the ensemble dimension.""" + d = data.copy() + return d + + nonensembled = single_fns(common_cfg, mode_cfg, crop_and_fix_size_seed) + + if mode_cfg.supervised and (not multimer_mode or is_distillation): + nonensembled.extend(label_transform_fn()) + + tensors = compose(nonensembled)(tensors) + + num_recycling = int(tensors["num_recycling_iters"]) + 1 + num_ensembles = mode_cfg.num_ensembles + + ensemble_tensors = map_fn( + lambda x: wrap_ensemble_fn(tensors, x), + torch.arange(num_ensembles), + ) + ensemble_tensors["num_recycling_iters"] = ensemble_tensors["num_recycling_iters"][ + 0:1 + ] + ensemble_tensors["seq_length"] = ensemble_tensors["seq_length"][0:1] + ensemble_tensors["num_ensembles"] = torch.tensor([num_ensembles]) + return ensemble_tensors + + @data_ops.curry1 def compose(x, fs): for f in fs: @@ -223,7 +309,7 @@ def pad_then_stack( for v in values: if v.shape[0] < size: res = values[0].new_zeros(size, *v.shape[1:]) - res[:v.shape[0], ...] = v + res[: v.shape[0], ...] = v else: res = v new_values.append(res) @@ -231,14 +317,13 @@ def pad_then_stack( new_values = values return torch.stack(new_values, dim=0) + def map_fn(fun, x): ensembles = [fun(elem) for elem in x] features = ensembles[0].keys() ensembled_dict = {} for feat in features: - ensembled_dict[feat] = pad_then_stack( - [dict_i[feat] for dict_i in ensembles] - ) + ensembled_dict[feat] = pad_then_stack([dict_i[feat] for dict_i in ensembles]) return ensembled_dict diff --git a/unifold/homo_search.py b/unifold/homo_search.py index f00a32f..1bb5331 100644 --- a/unifold/homo_search.py +++ b/unifold/homo_search.py @@ -170,7 +170,7 @@ def generate_pkl_features( # Get features. features_output_path = os.path.join( - output_dir, "{}.feature.pkl.gz".format(chain_id) + output_dir, "{}.feature.pkl.gz".format(fasta_name) ) if not os.path.exists(features_output_path): t_0 = time.time() @@ -184,7 +184,7 @@ def generate_pkl_features( # Get uniprot if use_uniprot: uniprot_output_path = os.path.join( - output_dir, "{}.uniprot.pkl.gz".format(chain_id) + output_dir, "{}.uniprot.pkl.gz".format(fasta_name) ) if not os.path.exists(uniprot_output_path): t_0 = time.time() @@ -201,7 +201,7 @@ def generate_pkl_features( logging.info("Final timings for %s: %s", fasta_name, timings) - timings_output_path = os.path.join(output_dir, "{}.timings.json".format(chain_id)) + timings_output_path = os.path.join(output_dir, "{}.timings.json".format(fasta_name)) with open(timings_output_path, "w") as f: f.write(json.dumps(timings, indent=4)) @@ -277,8 +277,8 @@ def main(argv): os.makedirs(output_dir) chain_order_path = os.path.join(output_dir, "chains.txt") with open(chain_order_path, "w") as f: - f.write("A") - fasta_names = [fasta_name] + f.write(f"{fasta_name}_A") + fasta_names = [f"{fasta_name}_A"] fasta_paths = [fasta_path] # Check for duplicate FASTA file names. diff --git a/unifold/inference.py b/unifold/inference.py index 7cd2a8e..30729bd 100644 --- a/unifold/inference.py +++ b/unifold/inference.py @@ -11,12 +11,15 @@ import pickle from unifold.config import model_config from unifold.modules.alphafold import AlphaFold +from unifold.musse.modules.alphafold import AlphaFoldMusse from unifold.data import residue_constants, protein from unifold.dataset import load_and_process, UnifoldDataset +from unifold.musse.dataset import load_and_process as load_and_process_musse from unicore.utils import ( tensor_tree_map, ) + def get_device_mem(device): if device != "cpu" and torch.cuda.is_available(): cur_device = torch.cuda.current_device() @@ -26,19 +29,20 @@ def get_device_mem(device): else: return 40 + def automatic_chunk_size(seq_len, device, is_bf16): total_mem_in_GB = get_device_mem(device) - factor = math.sqrt(total_mem_in_GB/40.0*(0.55 * is_bf16 + 0.45))*0.95 - if seq_len < int(1024*factor): + factor = math.sqrt(total_mem_in_GB / 40.0 * (0.55 * is_bf16 + 0.45)) * 0.95 + if seq_len < int(1024 * factor): chunk_size = 256 block_size = None - elif seq_len < int(2048*factor): + elif seq_len < int(2048 * factor): chunk_size = 128 block_size = None - elif seq_len < int(3072*factor): + elif seq_len < int(3072 * factor): chunk_size = 64 block_size = None - elif seq_len < int(4096*factor): + elif seq_len < int(4096 * factor): chunk_size = 32 block_size = 512 else: @@ -46,30 +50,47 @@ def automatic_chunk_size(seq_len, device, is_bf16): block_size = 256 return chunk_size, block_size + def load_feature_for_one_target( config, data_folder, seed=0, is_multimer=False, use_uniprot=False ): if not is_multimer: uniprot_msa_dir = None - sequence_ids = ["A"] + sequence_ids = open(os.path.join(data_folder, "chains.txt")).readline().split() if use_uniprot: uniprot_msa_dir = data_folder else: uniprot_msa_dir = data_folder sequence_ids = open(os.path.join(data_folder, "chains.txt")).readline().split() - batch, _ = load_and_process( - config=config.data, - mode="predict", - seed=seed, - batch_idx=None, - data_idx=0, - is_distillation=False, - sequence_ids=sequence_ids, - monomer_feature_dir=data_folder, - uniprot_msa_dir=uniprot_msa_dir, - is_monomer=(not is_multimer), - ) + if config.data.common.use_musse: + batch, _ = load_and_process_musse( + config=config.data, + mode="predict", + seed=seed, + batch_idx=None, + data_idx=0, + is_distillation=False, + sequence_ids=sequence_ids, + feature_dir=data_folder, + is_monomer=(not is_multimer), + emb_dir=data_folder, + msa_feature_dir=None, + template_feature_dir=None, + ) + else: + batch, _ = load_and_process( + config=config.data, + mode="predict", + seed=seed, + batch_idx=None, + data_idx=0, + is_distillation=False, + sequence_ids=sequence_ids, + monomer_feature_dir=data_folder, + uniprot_msa_dir=uniprot_msa_dir, + is_monomer=(not is_multimer), + ) batch = UnifoldDataset.collater([batch]) return batch @@ -83,7 +104,11 @@ def main(args): if args.sample_templates: # enable template samples for diversity config.data.predict.subsample_templates = True - model = AlphaFold(config) + model = ( + AlphaFold(config) + if not config.data.common.use_musse + else AlphaFoldMusse(config) + ) print("start to load params {}".format(args.param_path)) state_dict = torch.load(args.param_path)["ema"]["params"] @@ -125,10 +150,8 @@ def main(args): seq_len = batch["aatype"].shape[-1] # faster prediction with large chunk/block size chunk_size, block_size = automatic_chunk_size( - seq_len, - args.model_device, - args.bf16 - ) + seq_len, args.model_device, args.bf16 + ) model.globals.chunk_size = chunk_size model.globals.block_size = block_size @@ -152,9 +175,8 @@ def to_float(x): if not args.save_raw_output: score = ["plddt", "ptm", "iptm", "iptm+ptm"] out = { - k: v for k, v in raw_out.items() - if k.startswith("final_") or k in score - } + k: v for k, v in raw_out.items() if k.startswith("final_") or k in score + } else: out = raw_out del raw_out @@ -181,10 +203,12 @@ def to_float(x): plddts[cur_save_name] = str(mean_plddt) if is_multimer: ptms[cur_save_name] = str(np.mean(out["iptm+ptm"])) - with open(os.path.join(output_dir, cur_save_name + '.pdb'), "w") as f: + with open(os.path.join(output_dir, cur_save_name + ".pdb"), "w") as f: f.write(protein.to_pdb(cur_protein)) if args.save_raw_output: - with gzip.open(os.path.join(output_dir, cur_save_name + '_outputs.pkl.gz'), 'wb') as f: + with gzip.open( + os.path.join(output_dir, cur_save_name + "_outputs.pkl.gz"), "wb" + ) as f: pickle.dump(out, f) del out diff --git a/unifold/loss.py b/unifold/loss.py index 407b6ce..67be08e 100644 --- a/unifold/loss.py +++ b/unifold/loss.py @@ -39,11 +39,11 @@ def forward(self, model, batch, reduce=True): # return config in model. out, config = model(batch) - num_recycling = batch["msa_feat"].shape[0] - + num_recycling = batch["num_recycling_iters"] + 1 + # remove recyling dim batch = tensor_tree_map(lambda t: t[-1, ...], batch) - + loss, sample_size, logging_output = self.loss(out, batch, config) logging_output["num_recycling"] = num_recycling return loss, sample_size, logging_output @@ -52,12 +52,12 @@ def loss(self, out, batch, config): if "violation" not in out.keys() and config.violation.weight: out["violation"] = find_structural_violations( - batch, out["sm"]["positions"], **config.violation) + batch, out["sm"]["positions"], **config.violation + ) if "renamed_atom14_gt_positions" not in out.keys(): - batch.update( - compute_renamed_ground_truth(batch, out["sm"]["positions"])) - + batch.update(compute_renamed_ground_truth(batch, out["sm"]["positions"])) + loss_dict = {} loss_fns = { "chain_centre_mass": lambda: chain_centre_mass_loss( @@ -143,14 +143,14 @@ def loss(self, out, batch, config): with torch.no_grad(): seq_len = torch.sum(batch["seq_mask"].float(), dim=-1) seq_length_weight = seq_len**0.5 - + assert ( len(seq_length_weight.shape) == 1 and seq_length_weight.shape[0] == bsz ), seq_length_weight.shape - + for loss_name, loss_fn in loss_fns.items(): weight = config[loss_name].weight - if weight > 0.: + if weight > 0.0: loss = loss_fn() # always use float type for loss assert loss.dtype == torch.float, loss.dtype @@ -159,7 +159,7 @@ def loss(self, out, batch, config): if any(torch.isnan(loss)) or any(torch.isinf(loss)): logging.warning(f"{loss_name} loss is NaN. Skipping...") loss = loss.new_tensor(0.0, requires_grad=True) - + cum_loss = cum_loss + weight * loss for key in loss_dict: @@ -207,11 +207,11 @@ def forward(self, model, batch, reduce=True): # return config in model. out, config = model(features) - num_recycling = features["msa_feat"].shape[0] - + num_recycling = features["num_recycling_iters"] + 1 + # remove recycling dim features = tensor_tree_map(lambda t: t[-1, ...], features) - + # perform multi-chain permutation alignment. if labels: with torch.no_grad(): @@ -230,12 +230,12 @@ def forward(self, model, batch, reduce=True): ) new_labels.append(cur_new_labels) new_labels = data_utils.collate_dict(new_labels, dim=0) - + # check for consistency of label and feature. assert (new_labels["aatype"] == features["aatype"]).all() features.update(new_labels) loss, sample_size, logging_output = self.loss(out, features, config) logging_output["num_recycling"] = num_recycling - + return loss, sample_size, logging_output diff --git a/unifold/modules/__init__.py b/unifold/modules/__init__.py index eeb8ed9..36af787 100644 --- a/unifold/modules/__init__.py +++ b/unifold/modules/__init__.py @@ -4,4 +4,5 @@ set_jit_fusion_options, ) -set_jit_fusion_options() \ No newline at end of file +set_jit_fusion_options() +from .alphafold import AlphaFold \ No newline at end of file diff --git a/unifold/modules/alphafold.py b/unifold/modules/alphafold.py index ce8229d..b0aac4e 100644 --- a/unifold/modules/alphafold.py +++ b/unifold/modules/alphafold.py @@ -64,25 +64,27 @@ def __init__(self, config): self.template_pair_stack = TemplatePairStack( **template_config["template_pair_stack"], ) + + self.enable_template_pointwise_attention = template_config[ + "template_pointwise_attention" + ].enabled + if self.enable_template_pointwise_attention: + self.template_pointwise_att = TemplatePointwiseAttention( + **template_config["template_pointwise_attention"], + ) + else: + self.template_proj = TemplateProjection( + **template_config["template_pointwise_attention"], + ) else: self.template_pair_stack = None - self.enable_template_pointwise_attention = template_config[ - "template_pointwise_attention" - ].enabled - if self.enable_template_pointwise_attention: - self.template_pointwise_att = TemplatePointwiseAttention( - **template_config["template_pointwise_attention"], + if extra_msa_config.enabled: + self.extra_msa_embedder = ExtraMSAEmbedder( + **extra_msa_config["extra_msa_embedder"], ) - else: - self.template_proj = TemplateProjection( - **template_config["template_pointwise_attention"], + self.extra_msa_stack = ExtraMSAStack( + **extra_msa_config["extra_msa_stack"], ) - self.extra_msa_embedder = ExtraMSAEmbedder( - **extra_msa_config["extra_msa_embedder"], - ) - self.extra_msa_stack = ExtraMSAStack( - **extra_msa_config["extra_msa_stack"], - ) self.evoformer = EvoformerStack( **config["evoformer_stack"], ) @@ -106,14 +108,14 @@ def __make_input_float__(self): def half(self): super().half() - if (not getattr(self, "inference", False)): + if not getattr(self, "inference", False): self.__make_input_float__() self.dtype = torch.half return self def bfloat16(self): super().bfloat16() - if (not getattr(self, "inference", False)): + if not getattr(self, "inference", False): self.__make_input_float__() self.dtype = torch.bfloat16 return self @@ -135,6 +137,7 @@ def set_alphafold_original_mode(module): def inference_mode(self): def set_inference_mode(module): setattr(module, "inference", True) + self.apply(set_inference_mode) def __convert_input_dtype__(self, batch): @@ -144,7 +147,16 @@ def __convert_input_dtype__(self, batch): batch[key] = batch[key].type(self.dtype) return batch - def embed_templates_pair_core(self, batch, z, pair_mask, tri_start_attn_mask, tri_end_attn_mask, templ_dim, multichain_mask_2d): + def embed_templates_pair_core( + self, + batch, + z, + pair_mask, + tri_start_attn_mask, + tri_end_attn_mask, + templ_dim, + multichain_mask_2d, + ): if self.config.template.template_pair_embedder.v2_feature: t = build_template_pair_feat_v2( batch, @@ -185,7 +197,10 @@ def embed_templates_pair_core(self, batch, z, pair_mask, tri_start_attn_mask, tr def embed_templates_pair( self, batch, z, pair_mask, tri_start_attn_mask, tri_end_attn_mask, templ_dim ): - if self.config.template.template_pair_embedder.v2_feature and "asym_id" in batch: + if ( + self.config.template.template_pair_embedder.v2_feature + and "asym_id" in batch + ): multichain_mask_2d = ( batch["asym_id"][..., :, None] == batch["asym_id"][..., None, :] ) @@ -194,7 +209,15 @@ def embed_templates_pair( multichain_mask_2d = None if self.training or self.enable_template_pointwise_attention: - t = self.embed_templates_pair_core(batch, z, pair_mask, tri_start_attn_mask, tri_end_attn_mask, templ_dim, multichain_mask_2d) + t = self.embed_templates_pair_core( + batch, + z, + pair_mask, + tri_start_attn_mask, + tri_end_attn_mask, + templ_dim, + multichain_mask_2d, + ) if self.enable_template_pointwise_attention: t = self.template_pointwise_att( t, @@ -216,17 +239,29 @@ def embed_templates_pair( if n_templ <= 0: t = None else: - template_batch = { k: v for k, v in batch.items() if k.startswith("template_") } + template_batch = { + k: v for k, v in batch.items() if k.startswith("template_") + } + def embed_one_template(i): def slice_template_tensor(t): s = [slice(None) for _ in t.shape] s[batch_templ_dim] = slice(i, i + 1) return t[s] + template_feats = tensor_tree_map( slice_template_tensor, template_batch, ) - t = self.embed_templates_pair_core(template_feats, z, pair_mask, tri_start_attn_mask, tri_end_attn_mask, templ_dim, multichain_mask_2d) + t = self.embed_templates_pair_core( + template_feats, + z, + pair_mask, + tri_start_attn_mask, + tri_end_attn_mask, + templ_dim, + multichain_mask_2d, + ) return t t = embed_one_template(0) @@ -404,7 +439,7 @@ def iteration_evoformer_structure_module( outputs["pred_frame_tensor"] = outputs["sm"]["frames"][-1] # use float32 for numerical stability - if (not getattr(self, "inference", False)): + if not getattr(self, "inference", False): m_1_prev = m[..., 0, :, :].float() z_prev = z.float() x_prev = outputs["final_atom_positions"].float() diff --git a/unifold/modules/auxillary_heads.py b/unifold/modules/auxillary_heads.py index 335f7bd..56617ef 100644 --- a/unifold/modules/auxillary_heads.py +++ b/unifold/modules/auxillary_heads.py @@ -43,8 +43,9 @@ def forward(self, outputs): distogram_logits = self.distogram(outputs["pair"]) aux_out["distogram_logits"] = distogram_logits - masked_msa_logits = self.masked_msa(outputs["msa"]) - aux_out["masked_msa_logits"] = masked_msa_logits + if self.masked_msa is not None: + masked_msa_logits = self.masked_msa(outputs["msa"]) + aux_out["masked_msa_logits"] = masked_msa_logits if self.config.experimentally_resolved.enabled: exp_res_logits = self.experimentally_resolved(outputs["single"]) diff --git a/unifold/msa/utils.py b/unifold/msa/utils.py index 52716c9..e8a2cd8 100644 --- a/unifold/msa/utils.py +++ b/unifold/msa/utils.py @@ -9,6 +9,7 @@ def get_chain_id_map( sequences: Sequence[str], descriptions: Sequence[str], + fasta_name: str ): """ Makes a mapping from PDB-format chain ID to sequence and description, @@ -20,15 +21,16 @@ def get_chain_id_map( unique_seqs.append(seq) chain_id_map = { - chain_id: {"descriptions": [], "sequence": seq} + f"{fasta_name}_{chain_id}": {"descriptions": [], "sequence": seq} for chain_id, seq in zip(protein.PDB_CHAIN_IDS, unique_seqs) } chain_order = [] for seq, des in zip(sequences, descriptions): chain_id = protein.PDB_CHAIN_IDS[unique_seqs.index(seq)] - chain_id_map[chain_id]["descriptions"].append(des) - chain_order.append(chain_id) + chain_name = f"{fasta_name}_{chain_id}" + chain_id_map[chain_name]["descriptions"].append(des) + chain_order.append(chain_name) return chain_id_map, chain_order @@ -54,7 +56,7 @@ def divide_multi_chains( f"Got {len(sequences)} chains." ) - chain_id_map, chain_order = get_chain_id_map(sequences, descriptions) + chain_id_map, chain_order = get_chain_id_map(sequences, descriptions, fasta_name) output_dir = os.path.join(output_dir_base, fasta_name) if not os.path.exists(output_dir): @@ -75,7 +77,7 @@ def divide_multi_chains( temp_names = [] temp_paths = [] for chain_id in chain_id_map.keys(): - temp_name = fasta_name + "_{}".format(chain_id) + temp_name = chain_id temp_path = os.path.join(output_dir, temp_name + ".fasta") des = "chain_{}".format(chain_id) seq = chain_id_map[chain_id]["sequence"] diff --git a/unifold/musse/__init__.py b/unifold/musse/__init__.py new file mode 100644 index 0000000..bfbe3d3 --- /dev/null +++ b/unifold/musse/__init__.py @@ -0,0 +1,2 @@ +from . import model +from .dataset import UnifoldSingleMultimerDataset \ No newline at end of file diff --git a/unifold/musse/dataset.py b/unifold/musse/dataset.py new file mode 100644 index 0000000..263a744 --- /dev/null +++ b/unifold/musse/dataset.py @@ -0,0 +1,478 @@ +from unifold.dataset import * +from unifold.data.process import process_features_single +import gzip +import pickle + + +@utils.lru_cache(maxsize=8, copy=True) +def load_emb( + sequence_id: str, + monomer_feature_dir: str, + emb_dir: str, +) -> NumpyDict: + + monomer_feature = utils.load_pickle( + os.path.join(monomer_feature_dir, f"{sequence_id}.feature.pkl.gz") + ) + + chain_feature = {} + chain_feature["aatype"] = np.argmax(monomer_feature["aatype"], axis=-1).astype( + np.int32 + ) + chain_feature["sequence"] = monomer_feature["sequence"] + + chain_feature["seq_length"] = np.array([len(chain_feature["aatype"])]) + chain_feature["residue_index"] = np.arange(0, len(chain_feature["aatype"])) + + return chain_feature + + +def merge_multi_emb(all_chain_features): + merge_features = {} + num_chains = len(all_chain_features) + for key in all_chain_features[0]: + if key not in ["sequence", "resolution", "pair", "seq_length"]: + merge_features[key] = np.concatenate( + [x[key] for x in all_chain_features], axis=0 + ) + + merge_features["seq_length"] = np.asarray( + merge_features["aatype"].shape[0], dtype=np.int32 + ) + return merge_features + + +def load_crop_emb( + features, asymid_2_seq, per_asym_residue_index, emb_dir, mode="train" +): + total_len = features["aatype"].shape[-1] + all_pair = None + all_token = None + offset = 0 + for asym_id in per_asym_residue_index: + crop_idx = per_asym_residue_index[asym_id] + seq = asymid_2_seq[asym_id] + emb_feature = utils.load_pickle( + os.path.join(emb_dir, f"{seq}.esm2_multimer_finetune_emb.pkl.gz") + ) + token = torch.from_numpy(emb_feature["token"]) + pair = torch.from_numpy(emb_feature["pair"]) + if all_token is None: + all_token = token.new_zeros(total_len, token.shape[-1]) + if all_pair is None: + all_pair = pair.new_zeros(total_len, total_len, pair.shape[-3]) + if mode != "predict": + token = torch.index_select(token, 0, crop_idx) + pair = torch.index_select(pair, -1, crop_idx) + pair = torch.index_select(pair, -2, crop_idx) + pair = pair.permute(1, 2, 0) + cur_len = token.shape[0] + all_token[offset : offset + cur_len, :] = token + all_pair[offset : offset + cur_len, offset : offset + cur_len, :] = pair + offset += cur_len + + features["token"] = all_token[None, ...] + features["pair"] = all_pair[None, ...] + return features + + +def load_assembly_esm(embdir, assem_name): + fn = os.path.join(embdir, f"{assem_name}.esm2_multimer_finetune_emb.pkl.gz") + if not os.path.exists(fn): + fn = os.path.join(embdir, f"{assem_name}.esm2_3b_emb.pkl.gz") + if not os.path.exists(fn): + fn = os.path.join(embdir, f"{assem_name}.esm2_15b_emb.pkl.gz") + with gzip.GzipFile(fn, "rb") as f: + embeddings = pickle.load(f) + return embeddings["token"] + + +def load( + sequence_ids: List[str], + feature_dir, + msa_feature_dir: str, + emb_dir: str, + template_feature_dir: str, + uniprot_msa_feature_dir: Optional[str] = None, + label_ids: Optional[List[str]] = None, + label_dir: Optional[str] = None, + symmetry_operations: Optional[List[Operation]] = None, + is_monomer: bool = False, + train_max_date: Optional[str] = None, + is_distillation=False, +) -> NumpyExample: + + if is_distillation: + assemb_name = sequence_ids[0] + else: + try: + assemb_name = label_ids[0].split("_")[0] + except: + assemb_name = sequence_ids[0].split("_")[0] + embeddings = load_assembly_esm(emb_dir, assemb_name) + + all_chain_features = [load_emb(s, feature_dir, emb_dir) for s in sequence_ids] + assert embeddings.shape[0] == sum( + [len(chain_feature["aatype"]) for chain_feature in all_chain_features] + ), "embedding shape error {} {} {}".format( + str(label_ids), + embeddings.shape[0], + sum([len(chain_feature["aatype"]) for chain_feature in all_chain_features]), + ) + curpos = 0 + for feat in all_chain_features: + offset = len(feat["aatype"]) + feat["token"] = embeddings[curpos : curpos + offset] + curpos += offset + + if label_ids is not None: + # load labels + assert len(label_ids) == len(sequence_ids) + assert label_dir is not None + if symmetry_operations is None: + symmetry_operations = ["I" for _ in label_ids] + all_chain_labels = [ + load_single_label(l, label_dir, o) + for l, o in zip(label_ids, symmetry_operations) + ] + # update labels into features to calculate spatial cropping etc. + [f.update(l) for f, l in zip(all_chain_features, all_chain_labels)] + + all_chain_features = add_assembly_features(all_chain_features) + + # get labels back from features, as add_assembly_features may alter the order of inputs. + if label_ids is not None: + all_chain_labels = [ + { + k: f[k] + for k in ["aatype", "all_atom_positions", "all_atom_mask", "resolution"] + } + for f in all_chain_features + ] + else: + all_chain_labels = None + + asym_len = np.array( + [int(c["seq_length"]) for c in all_chain_features], dtype=np.int64 + ) + all_chain_features = merge_multi_emb(all_chain_features) + all_chain_features["asym_len"] = asym_len + + return all_chain_features, all_chain_labels + + +def process( + config: mlc.ConfigDict, + mode: str, + features: NumpyDict, + labels: Optional[List[NumpyDict]] = None, + seed: int = 0, + batch_idx: Optional[int] = None, + data_idx: Optional[int] = None, + is_distillation: bool = False, + emb_dir: str = None, + **kwargs, +) -> TorchExample: + + if mode == "train": + assert batch_idx is not None + with data_utils.numpy_seed(seed, batch_idx, key="recycling"): + num_iters = np.random.randint(0, config.common.max_recycling_iters + 1) + use_clamped_fape = np.random.rand() < config[mode].use_clamped_fape_prob + else: + num_iters = config.common.max_recycling_iters + use_clamped_fape = 1 + + features["num_recycling_iters"] = int(num_iters) + features["use_clamped_fape"] = int(use_clamped_fape) + features["is_distillation"] = int(is_distillation) + if is_distillation and "msa_chains" in features: + features.pop("msa_chains") + + num_res = int(features["seq_length"]) + cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res) + + if labels is not None: + features["resolution"] = labels[0]["resolution"].reshape(-1) + + with data_utils.numpy_seed(seed, data_idx, key="protein_feature"): + features["crop_and_fix_size_seed"] = np.random.randint(0, 63355) + features = utils.filter(features, desired_keys=feature_names) + features = {k: torch.tensor(v) for k, v in features.items()} + with torch.no_grad(): + features = process_features_single(features, cfg.common, cfg[mode]) + + if labels is not None: + labels = [{k: torch.tensor(v) for k, v in l.items()} for l in labels] + with torch.no_grad(): + labels = process_labels(labels) + + return features, labels + + +def load_and_process( + config: mlc.ConfigDict, + mode: str, + seed: int = 0, + batch_idx: Optional[int] = None, + data_idx: Optional[int] = None, + is_distillation: bool = False, + **load_kwargs, +): + try: + is_monomer = ( + is_distillation + if "is_monomer" not in load_kwargs + else load_kwargs.pop("is_monomer") + ) + features, labels = load( + **load_kwargs, + is_monomer=is_monomer, + train_max_date=config.common.train_max_date, + is_distillation=is_distillation, + ) + features, labels = process( + config, + mode, + features, + labels, + seed, + batch_idx, + data_idx, + is_distillation, + **load_kwargs, + ) + return features, labels + except Exception as e: + print("Error loading data", load_kwargs, e) + raise e + + +class UnifoldSingleMultimerDataset(UnifoldMultimerDataset): + def __init__( + self, + args: mlc.ConfigDict, + seed: int, + config: mlc.ConfigDict, + data_path: str, + mode: str = "train", + max_step: Optional[int] = None, + disable_sd: bool = False, + json_prefix: str = "", + **kwargs, + ): + super().__init__( + args, + seed, + config, + data_path, + mode, + max_step, + disable_sd, + json_prefix, + **kwargs, + ) + + def load_sample_weight(multi_label, dir, mode): + cluster_size = load_json( + os.path.join( + self.path, dir, json_prefix + mode + "_cluster_size_all.json" + ) + ) + seq_length = load_json( + os.path.join( + self.path, dir, json_prefix + mode + "_seq_length_all.json" + ) + ) + sample_weight = {} + keys = multi_label.keys() if multi_label is not None else seq_length.keys() + for seq in keys: + sample_weight[seq] = get_seq_sample_weight( + seq_length[seq], cluster_size[seq] + ) + return sample_weight + + if mode == "train" and not disable_sd: + self.sd_sample_weight = load_sample_weight(None, "sd", "sd") + logger.info( + "load {} self-distillation samples.".format(len(self.sd_sample_weight)) + ) + ( + self.sd_feature_path, + self.sd_msa_feature_path, + self.sd_template_feature_path, + self.sd_label_path, + ) = load_folders(self.path, mode="sd") + else: + self.sd_sample_weight = None + if self.sd_sample_weight is not None: + ( + self.sd_num_chain, + self.sd_chain_keys, + self.sd_sample_prob, + ) = self.cal_sample_weight(self.sd_sample_weight) + + self.data_path = data_path + self.pdb_assembly = json.load( + open( + os.path.join( + self.data_path, + "traineval", + json_prefix + mode + "_mmcif_assembly.json", + ) + ) + ) + self.pdb_chains = self.get_chains(self.inverse_multi_label) + self.uniprot_msa_feature_path = os.path.join( + self.data_path, "traineval", "uniprot_features" + ) + self.max_chains = args.max_chains + + def filter_pdb_assembly(pdb_assembly, config): + + new_pdb_assembly = {} + if config.data.common.feature_src: + filter_keys = json.load( + open( + os.path.join( + self.data_path, + "traineval", + f"{config.data.common.feature_src}_filter", + json_prefix + "filtered_" + mode + "_keys.json", + ) + ) + ) + else: + filter_keys = json.load( + open( + os.path.join( + self.data_path, + "traineval", + json_prefix + "filtered_" + mode + "_keys.json", + ) + ) + ) + for pdb_id in pdb_assembly: + if pdb_id in filter_keys: + # print(f"filter {pdb_id} too long") + continue + content = pdb_assembly[pdb_id] + new_content = {"chains": [], "opers": []} + has_content = False + for i, chain in enumerate(content["chains"]): + if (pdb_id + "_" + chain) in self.inverse_multi_label: + new_content["chains"].append(chain) + new_content["opers"].append(content["opers"][i]) + has_content = True + if has_content: + new_pdb_assembly[pdb_id] = new_content + return new_pdb_assembly + + self.pdb_assembly = filter_pdb_assembly(self.pdb_assembly, config=config) + + def load_chain_cluster_size(mode): + cluster_size = load_json( + os.path.join( + self.path, "traineval", json_prefix + mode + "_cluster_size.json" + ) + ) + seq_cnt = {} + for pdb_id in self.pdb_assembly: + for chain in self.pdb_assembly[pdb_id]["chains"]: + seq = self.inverse_multi_label[pdb_id + "_" + chain] + if seq not in seq_cnt: + seq_cnt[seq] = 0 + seq_cnt[seq] += 1 + new_cluster_size = {} + for seq in seq_cnt: + assert seq in cluster_size, seq + assert seq in seq_cnt, seq + new_cluster_size[seq] = cluster_size[seq] * seq_cnt[seq] + return new_cluster_size + + chain_cluster_size = load_chain_cluster_size(mode) + + def cal_pdb_sample_weight(mode, pdb_assembly, cluster_size): + seq_length = load_json( + os.path.join( + self.path, "traineval", json_prefix + mode + "_seq_length.json" + ) + ) + sample_weight = {} + total_seq_length = {} + for pdb_id in pdb_assembly: + cur_sample_weight = 0.0 + cur_seq_length = 0 + for chain in pdb_assembly[pdb_id]["chains"]: + seq = self.inverse_multi_label[pdb_id + "_" + chain] + cur_sample_weight += get_chain_sample_weight(cluster_size[seq]) + cur_seq_length += seq_length[seq] + # avoid too large sample weights + sample_weight[pdb_id] = min(cur_sample_weight, 2.0) + total_seq_length[pdb_id] = cur_seq_length + return (sample_weight, total_seq_length) + + self.sample_weight, total_seq_length = cal_pdb_sample_weight( + mode, self.pdb_assembly, chain_cluster_size + ) + self.pdb_assembly, self.sample_weight = self.filter_pdb_by_max_chains( + self.pdb_assembly, self.sample_weight, self.max_chains, total_seq_length + ) + self.num_pdb, self.pdb_keys, self.sample_prob = self.cal_sample_weight( + self.sample_weight + ) + if config.data.common.feature_src: + self.emb_path = os.path.join( + self.data_path, "traineval", f"esms_{config.data.common.feature_src}" + ) + self.sd_emb_path = os.path.join( + self.data_path, "sd", f"esms_{config.data.common.feature_src}" + ) + else: + self.emb_path = os.path.join(self.data_path, "traineval", "esms") + self.sd_emb_path = os.path.join(self.data_path, "sd", "esms") + + def __getitem__(self, idx): + label_id, is_distillation = self.sample_pdb(idx) + if is_distillation: + label_ids = [label_id] + sequence_ids = [label_id] + monomer_feature_path, label_path, emb_path = ( + self.sd_feature_path, + self.sd_label_path, + self.sd_emb_path, + ) + symmetry_operations = None + else: + pdb_id = label_id + label_ids = [ + pdb_id + "_" + id for id in self.pdb_assembly[pdb_id]["chains"] + ] + symmetry_operations = [t for t in self.pdb_assembly[pdb_id]["opers"]] + sequence_ids = [ + self.inverse_multi_label[chain_id] for chain_id in label_ids + ] + monomer_feature_path, label_path, emb_path = ( + self.feature_path, + self.label_path, + self.emb_path, + ) + + return load_and_process( + self.config, + self.mode, + self.seed, + batch_idx=(idx // self.batch_size), + data_idx=idx, + is_distillation=is_distillation, + sequence_ids=sequence_ids, + feature_dir=monomer_feature_path, + msa_feature_dir=None, + template_feature_dir=None, + uniprot_msa_feature_dir=None, + emb_dir=emb_path, + label_ids=label_ids, + label_dir=label_path, + symmetry_operations=symmetry_operations, + is_monomer=False, + ) diff --git a/unifold/musse/model.py b/unifold/musse/model.py new file mode 100644 index 0000000..aea6bee --- /dev/null +++ b/unifold/musse/model.py @@ -0,0 +1,47 @@ +from unicore.models import BaseUnicoreModel, register_model, register_model_architecture +from unifold.config import model_config +from .modules.alphafold import AlphaFoldMusse + + +@register_model("af2_single") +class AlphafoldSingleModel(BaseUnicoreModel): + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument( + "--model-name", + help="choose the model config", + ) + + def __init__(self, args): + super().__init__() + base_architecture(args) + self.args = args + config = model_config( + self.args.model_name, + train=True, + ) + self.model = AlphaFoldSingle(config) + self.config = config + + def half(self): + self.model = self.model.half() + return self + + def bfloat16(self): + self.model = self.model.bfloat16() + return self + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + return cls(args) + + def forward(self, batch, **kwargs): + outputs = self.model.forward(batch) + return outputs, self.config.loss + + +@register_model_architecture("af2_single", "af2_single") +def base_architecture(args): + args.model_name = getattr(args, "model_name", "single_multimer") diff --git a/unifold/musse/modules/__init__.py b/unifold/musse/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/unifold/musse/modules/alphafold.py b/unifold/musse/modules/alphafold.py new file mode 100644 index 0000000..4bf383c --- /dev/null +++ b/unifold/musse/modules/alphafold.py @@ -0,0 +1,242 @@ +from unifold.modules.alphafold import * +from .embedders import InputEmbedderSingle, Esm2Embedder +from .evoformer import EvoformerStackSingle +from .auxiliary_heads import AuxiliaryHeadsSingle + + +class AlphaFoldMusse(AlphaFold): + def __init__(self, config): + super().__init__(config) + config = config.model + self.input_embedder = InputEmbedderSingle( + **config["input_embedder"], + use_chain_relative=config.is_multimer, + ) + self.esm2_embedder = Esm2Embedder( + **config["esm2_embedder"], + ) + + self.evoformer = EvoformerStackSingle( + **config["evoformer_stack"], + ) + self.aux_heads = AuxiliaryHeadsSingle( + config["heads"], + ) + + def __make_input_float__(self): + super().__make_input_float__() + self.esm2_embedder = self.esm2_embedder.float() + + def iteration_evoformer(self, feats, m_1_prev, z_prev, x_prev): + batch_dims = feats["target_feat"].shape[:-2] + n = feats["target_feat"].shape[-2] + seq_mask = feats["seq_mask"] + pair_mask = seq_mask[..., None] * seq_mask[..., None, :] + msa_mask = seq_mask.unsqueeze(-2) + + m, z = self.input_embedder( + feats["target_feat"], + None, + ) + + if m_1_prev is None: + m_1_prev = m.new_zeros( + (*batch_dims, n, self.config.input_embedder.d_msa), + requires_grad=False, + ) + if z_prev is None: + z_prev = z.new_zeros( + (*batch_dims, n, n, self.config.input_embedder.d_pair), + requires_grad=False, + ) + if x_prev is None: + x_prev = z.new_zeros( + (*batch_dims, n, residue_constants.atom_type_num, 3), + requires_grad=False, + ) + x_prev = pseudo_beta_fn(feats["aatype"], x_prev, None) + + z += self.recycling_embedder.recyle_pos(x_prev) + + m, z = self.esm2_embedder(m, z, feats["token"]) + + m_1_prev_emb, z_prev_emb = self.recycling_embedder( + m_1_prev, + z_prev, + ) + + m[..., 0, :, :] += m_1_prev_emb + + z += z_prev_emb + + z += self.input_embedder.relpos_emb( + feats["residue_index"].long(), + feats.get("sym_id", None), + feats.get("asym_id", None), + feats.get("entity_id", None), + feats.get("num_sym", None), + ) + + m = m.type(self.dtype) + z = z.type(self.dtype) + tri_start_attn_mask, tri_end_attn_mask = gen_tri_attn_mask(pair_mask, self.inf) + + if self.config.template.enabled: + template_mask = feats["template_mask"] + if torch.any(template_mask): + z = residual( + z, + self.embed_templates_pair( + feats, + z, + pair_mask, + tri_start_attn_mask, + tri_end_attn_mask, + templ_dim=-4, + ), + self.training, + ) + + if self.config.extra_msa.enabled: + a = self.extra_msa_embedder(build_extra_msa_feat(feats)) + extra_msa_row_mask = gen_msa_attn_mask( + feats["extra_msa_mask"], + inf=self.inf, + gen_col_mask=False, + ) + z = self.extra_msa_stack( + a, + z, + msa_mask=feats["extra_msa_mask"], + chunk_size=self.globals.chunk_size, + block_size=self.globals.block_size, + pair_mask=pair_mask, + msa_row_attn_mask=extra_msa_row_mask, + msa_col_attn_mask=None, + tri_start_attn_mask=tri_start_attn_mask, + tri_end_attn_mask=tri_end_attn_mask, + ) + + if self.config.template.embed_angles: + template_1d_feat, template_1d_mask = self.embed_templates_angle(feats) + m = torch.cat([m, template_1d_feat], dim=-3) + msa_mask = torch.cat([feats["msa_mask"], template_1d_mask], dim=-2) + + msa_row_mask, msa_col_mask = gen_msa_attn_mask( + msa_mask, + inf=self.inf, + ) + + m, z, s = self.evoformer( + m, + z, + msa_mask=msa_mask, + pair_mask=pair_mask, + msa_row_attn_mask=msa_row_mask, + msa_col_attn_mask=msa_col_mask, + tri_start_attn_mask=tri_start_attn_mask, + tri_end_attn_mask=tri_end_attn_mask, + chunk_size=self.globals.chunk_size, + block_size=self.globals.block_size, + ) + return m, z, s, msa_mask, m_1_prev_emb, z_prev_emb + + def iteration_evoformer_structure_module( + self, batch, m_1_prev, z_prev, x_prev, cycle_no, num_recycling, num_ensembles=1 + ): + z, s = 0, 0 + n_seq = 1 + assert num_ensembles >= 1 + for ensemble_no in range(num_ensembles): + idx = cycle_no * num_ensembles + ensemble_no + fetch_cur_batch = lambda t: t[min(t.shape[0] - 1, idx), ...] + feats = tensor_tree_map(fetch_cur_batch, batch) + m, z0, s0, msa_mask, m_1_prev_emb, z_prev_emb = self.iteration_evoformer( + feats, m_1_prev, z_prev, x_prev + ) + z += z0 + s += s0 + del z0, s0 + if num_ensembles > 1: + z /= float(num_ensembles) + s /= float(num_ensembles) + + outputs = {} + + outputs["msa"] = m[..., :n_seq, :, :] + outputs["pair"] = z + outputs["single"] = s + + # norm loss + if (not getattr(self, "inference", False)) and num_recycling == (cycle_no + 1): + delta_msa = m + delta_msa[..., 0, :, :] = delta_msa[..., 0, :, :] - m_1_prev_emb.detach() + delta_pair = z - z_prev_emb.detach() + outputs["delta_msa"] = delta_msa + outputs["delta_pair"] = delta_pair + outputs["msa_norm_mask"] = msa_mask + + outputs["sm"] = self.structure_module( + s, + z, + feats["aatype"], + mask=feats["seq_mask"], + ) + outputs["final_atom_positions"] = atom14_to_atom37( + outputs["sm"]["positions"], feats + ) + outputs["final_atom_mask"] = feats["atom37_atom_exists"] + outputs["pred_frame_tensor"] = outputs["sm"]["frames"][-1] + + # use float32 for numerical stability + if not getattr(self, "inference", False): + m_1_prev = m[..., 0, :, :].float() + z_prev = z.float() + x_prev = outputs["final_atom_positions"].float() + else: + m_1_prev = m[..., 0, :, :] + z_prev = z + x_prev = outputs["final_atom_positions"] + + return outputs, m_1_prev, z_prev, x_prev + + def forward(self, batch): + + m_1_prev = batch.get("m_1_prev", None) + z_prev = batch.get("z_prev", None) + x_prev = batch.get("x_prev", None) + + is_grad_enabled = torch.is_grad_enabled() + + num_iters = int(batch["num_recycling_iters"]) + 1 + num_ensembles = int(batch["num_ensembles"]) + if self.training: + # don't use ensemble during training + assert num_ensembles == 1 + + # convert dtypes in batch + batch = self.__convert_input_dtype__(batch) + for cycle_no in range(num_iters): + is_final_iter = cycle_no == (num_iters - 1) + with torch.set_grad_enabled(is_grad_enabled and is_final_iter): + ( + outputs, + m_1_prev, + z_prev, + x_prev, + ) = self.iteration_evoformer_structure_module( + batch, + m_1_prev, + z_prev, + x_prev, + cycle_no=cycle_no, + num_recycling=num_iters, + num_ensembles=num_ensembles, + ) + if not is_final_iter: + del outputs + + if "asym_id" in batch: + outputs["asym_id"] = batch["asym_id"][0, ...] + outputs.update(self.aux_heads(outputs)) + return outputs diff --git a/unifold/musse/modules/auxiliary_heads.py b/unifold/musse/modules/auxiliary_heads.py new file mode 100644 index 0000000..e95bb90 --- /dev/null +++ b/unifold/musse/modules/auxiliary_heads.py @@ -0,0 +1,7 @@ +from unifold.modules.auxillary_heads import * + + +class AuxiliaryHeadsSingle(AuxiliaryHeads): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.masked_msa = None diff --git a/unifold/musse/modules/embedders.py b/unifold/musse/modules/embedders.py new file mode 100644 index 0000000..6e096b7 --- /dev/null +++ b/unifold/musse/modules/embedders.py @@ -0,0 +1,90 @@ +from unifold.modules.embedders import * + + +class InputEmbedderSingle(InputEmbedder): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.linear_msa_m = None + + def forward( + self, + tf: torch.Tensor, + msa: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # [*, N_res, d_pair] + if self.tf_dim == 21: + # multimer use 21 target dim + tf = tf[..., 1:] + # convert type if necessary + tf = tf.type(self.linear_tf_z_i.weight.dtype) + # msa = msa.type(self.linear_tf_z_i.weight.dtype) + # n_clust = msa.shape[-3] + n_clust = 1 + + # msa_emb = self.linear_msa_m(msa) + # target_feat (aatype) into msa representation + tf_m = ( + self.linear_tf_m(tf) + .unsqueeze(-3) + .expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1))) # expand -3 dim + ) + msa_emb = tf_m + + tf_emb_i = self.linear_tf_z_i(tf) + tf_emb_j = self.linear_tf_z_j(tf) + pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :] + + return msa_emb, pair_emb + + +class Esm2Embedder(nn.Module): + def __init__( + self, + token_dim: int, + d_msa: int, + d_pair: int, + dropout: float, + **kwargs, + ): + super(Esm2Embedder, self).__init__() + + + self.linear_token = Linear(token_dim, d_msa) + # self.linear_pair = Linear(pair_dim, d_pair) + self.combine = nn.Parameter(torch.tensor([0.0, 2.3])) + self.dropout = dropout + + def forward( + self, + m, + z, + token: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + mask_shape = token.shape[:-2] + token = token[..., None, :, :, :] + token = token.type(self.linear_token.weight.dtype) + token = torch.einsum( + "...nh,n->...h", + token, + nn.functional.softmax(self.combine.float(), dim=0).type( + self.linear_token.weight.dtype + ), + ) + # pair = pair.type(self.linear_pair.weight.dtype) + + with torch.no_grad(): + token_mask = ( + torch.rand(mask_shape, dtype=token.dtype, device=token.device) + >= self.dropout + ).type(token.dtype) + + # pair_mask = token_mask[..., None, :] * token_mask[..., None] + token_mask = token_mask[..., None, :, None] # / (1.0 - self.dropout) + # pair_mask = pair_mask[..., None] # / (1.0 - self.dropout) + + token = token * token_mask + # pair = pair * pair_mask + m = residual(m, self.linear_token(token), self.training) + # z = residual(z, self.linear_pair(pair), self.training) + return m, z + diff --git a/unifold/musse/modules/evoformer.py b/unifold/musse/modules/evoformer.py new file mode 100644 index 0000000..3bc62ae --- /dev/null +++ b/unifold/musse/modules/evoformer.py @@ -0,0 +1,165 @@ +from unifold.modules.evoformer import * + + +class EvoformerIterationSingle(EvoformerIteration): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.msa_att_col = None + + def forward( + self, + m: torch.Tensor, + z: torch.Tensor, + msa_mask: torch.Tensor, + pair_mask: torch.Tensor, + msa_row_attn_mask: torch.Tensor, + msa_col_attn_mask: Optional[torch.Tensor], + tri_start_attn_mask: torch.Tensor, + tri_end_attn_mask: torch.Tensor, + chunk_size: Optional[int] = None, + block_size: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + if self.outer_product_mean_first: + z = residual( + z, + self.outer_product_mean(m, mask=msa_mask, chunk_size=chunk_size), + self.training, + ) + + m = bias_dropout_residual( + self.msa_att_row, + m, + self.msa_att_row( + m, z=z, attn_mask=msa_row_attn_mask, chunk_size=chunk_size + ), + self.row_dropout_share_dim, + self.msa_dropout, + self.training, + ) + # if self._is_extra_msa_stack: + # m = residual( + # m, self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size), + # self.training + # ) + # else: + # m = bias_dropout_residual( + # self.msa_att_col, + # m, + # self.msa_att_col(m, attn_mask=msa_col_attn_mask, chunk_size=chunk_size), + # self.col_dropout_share_dim, + # self.msa_dropout, + # self.training, + # ) + m = residual(m, self.msa_transition(m, chunk_size=chunk_size), self.training) + if not self.outer_product_mean_first: + z = residual( + z, + self.outer_product_mean(m, mask=msa_mask, chunk_size=chunk_size), + self.training, + ) + + z = tri_mul_residual( + self.tri_mul_out, + z, + self.tri_mul_out(z, mask=pair_mask, block_size=block_size), + self.row_dropout_share_dim, + self.pair_dropout, + self.training, + block_size=block_size, + ) + + z = tri_mul_residual( + self.tri_mul_in, + z, + self.tri_mul_in(z, mask=pair_mask, block_size=block_size), + self.row_dropout_share_dim, + self.pair_dropout, + self.training, + block_size=block_size, + ) + + z = bias_dropout_residual( + self.tri_att_start, + z, + self.tri_att_start(z, attn_mask=tri_start_attn_mask, chunk_size=chunk_size), + self.row_dropout_share_dim, + self.pair_dropout, + self.training, + ) + + z = bias_dropout_residual( + self.tri_att_end, + z, + self.tri_att_end(z, attn_mask=tri_end_attn_mask, chunk_size=chunk_size), + self.col_dropout_share_dim, + self.pair_dropout, + self.training, + ) + z = residual(z, self.pair_transition(z, chunk_size=chunk_size), self.training) + return m, z + + +class EvoformerStackSingle(EvoformerStack): + def __init__( + self, + d_msa: int, + d_pair: int, + d_hid_msa_att: int, + d_hid_opm: int, + d_hid_mul: int, + d_hid_pair_att: int, + d_single: int, + num_heads_msa: int, + num_heads_pair: int, + num_blocks: int, + transition_n: int, + msa_dropout: float, + pair_dropout: float, + outer_product_mean_first: bool, + inf: float, + eps: float, + _is_extra_msa_stack: bool = False, + **kwargs + ): + super().__init__( + d_msa, + d_pair, + d_hid_msa_att, + d_hid_opm, + d_hid_mul, + d_hid_pair_att, + d_single, + num_heads_msa, + num_heads_pair, + num_blocks, + transition_n, + msa_dropout, + pair_dropout, + outer_product_mean_first, + inf, + eps, + _is_extra_msa_stack, + **kwargs + ) + self.blocks = SimpleModuleList() + for _ in range(num_blocks): + self.blocks.append( + EvoformerIterationSingle( + d_msa=d_msa, + d_pair=d_pair, + d_hid_msa_att=d_hid_msa_att, + d_hid_opm=d_hid_opm, + d_hid_mul=d_hid_mul, + d_hid_pair_att=d_hid_pair_att, + num_heads_msa=num_heads_msa, + num_heads_pair=num_heads_pair, + transition_n=transition_n, + msa_dropout=msa_dropout, + pair_dropout=pair_dropout, + outer_product_mean_first=outer_product_mean_first, + inf=inf, + eps=eps, + _is_extra_msa_stack=_is_extra_msa_stack, + ) + ) diff --git a/unifold/musse/plm/__init__.py b/unifold/musse/plm/__init__.py new file mode 100755 index 0000000..00a6d51 --- /dev/null +++ b/unifold/musse/plm/__init__.py @@ -0,0 +1 @@ +import importlib \ No newline at end of file diff --git a/unifold/musse/plm/dict_esm.txt b/unifold/musse/plm/dict_esm.txt new file mode 100755 index 0000000..5effb61 --- /dev/null +++ b/unifold/musse/plm/dict_esm.txt @@ -0,0 +1,33 @@ +[CLS] +[PAD] +[SEP] +[UNK] +L +A +G +V +S +E +R +T +I +D +P +K +Q +N +F +Y +M +H +W +C +X +B +U +Z +O +. +- + +[MASK] \ No newline at end of file diff --git a/unifold/musse/plm/model/__init__.py b/unifold/musse/plm/model/__init__.py new file mode 100755 index 0000000..b89bc03 --- /dev/null +++ b/unifold/musse/plm/model/__init__.py @@ -0,0 +1,3 @@ +from pathlib import Path +import importlib + diff --git a/unifold/musse/plm/model/bert.py b/unifold/musse/plm/model/bert.py new file mode 100755 index 0000000..6d1037e --- /dev/null +++ b/unifold/musse/plm/model/bert.py @@ -0,0 +1,365 @@ +import logging +import re + +import torch +import torch.nn as nn +import torch.nn.functional as F +from unicore import utils +from unicore.models import BaseUnicoreModel, register_model, register_model_architecture +from unicore.data import Dictionary +from unicore.modules import LayerNorm # , init_bert_params +from .transformer_encoder import TransformerEncoder + + +logger = logging.getLogger(__name__) + +def init_position_params(module): + if not getattr(module, 'can_global_init', True): + return + def normal_(data): + data.copy_( + data.cpu().normal_(mean=0.0, std=0.02).to(data.device) + ) + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + module.weight.data.zero_() + + +@register_model("bert") +class BertModel(BaseUnicoreModel): + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument( + "--encoder-layers", type=int, metavar="L", help="num encoder layers" + ) + parser.add_argument( + "--encoder-embed-dim", + type=int, + metavar="H", + help="encoder embedding dimension", + ) + parser.add_argument( + "--encoder-ffn-embed-dim", + type=int, + metavar="F", + help="encoder embedding dimension for FFN", + ) + parser.add_argument( + "--encoder-attention-heads", + type=int, + metavar="A", + help="num encoder attention heads", + ) + parser.add_argument( + "--activation-fn", + choices=utils.get_available_activation_fns(), + help="activation function to use", + ) + parser.add_argument( + "--pooler-activation-fn", + choices=utils.get_available_activation_fns(), + help="activation function to use for pooler layer", + ) + parser.add_argument( + "--emb-dropout", type=float, metavar="D", help="dropout probability for embeddings" + ) + parser.add_argument( + "--dropout", type=float, metavar="D", help="dropout probability" + ) + parser.add_argument( + "--attention-dropout", + type=float, + metavar="D", + help="dropout probability for attention weights", + ) + parser.add_argument( + "--activation-dropout", + type=float, + metavar="D", + help="dropout probability after activation in FFN", + ) + parser.add_argument( + "--pooler-dropout", + type=float, + metavar="D", + help="dropout probability in the masked_lm pooler layers", + ) + parser.add_argument( + "--max-seq-len", type=int, help="number of positional embeddings to learn" + ) + parser.add_argument( + "--post-ln", type=bool, help="use post layernorm or pre layernorm" + ) + parser.add_argument("--ignore-inter-rotary", type=bool, default="") + parser.add_argument("--share-pos-emb", type=bool, default="") + + + def __init__(self, args, dictionary): + super().__init__() + base_architecture(args) + self.args = args + self.padding_idx = dictionary.pad() + self.mask_idx = dictionary.index('[MASK]') + self.embed_tokens = nn.Embedding(len(dictionary), args.encoder_embed_dim, self.padding_idx) + # self.embed_positions = nn.Embedding(args.max_seq_len, args.encoder_embed_dim) + self.sentence_encoder = TransformerEncoder( + encoder_layers=args.encoder_layers, + embed_dim=args.encoder_embed_dim, + ffn_embed_dim=args.encoder_ffn_embed_dim, + attention_heads=args.encoder_attention_heads, + emb_dropout=args.emb_dropout, + dropout=args.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + max_seq_len=args.max_seq_len, + activation_fn=args.activation_fn, + rel_pos=True, + rel_pos_bins=32, + max_rel_pos=128, + post_ln=args.post_ln, + ignore_inter_rotary=args.ignore_inter_rotary, + share_pos_emb=args.share_pos_emb, + ) + + self.lm_head = RobertaLMHead(embed_dim=args.encoder_embed_dim, + output_dim=len(dictionary), + activation_fn=args.activation_fn, + weight=self.embed_tokens.weight, + ) + # self.freq_head = FreqHead(embed_dim=args.encoder_embed_dim, + # output_dim=22, + # activation_fn=args.activation_fn, + # # weight=self.embed_tokens.weight, + # ) + self.classification_heads = nn.ModuleDict() + self.apply(init_position_params) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + return cls(args, task.dictionary) + + # def __make_freq_head_float__(self): + # self.freq_head = self.freq_head.float() + + def load_state_dict( + self, + state_dict, + strict=True, + model_args = None, + ): + def upgrade_state_dict(state_dict): + """Removes prefixes 'model.encoder.sentence_encoder.' and 'model.encoder.'.""" + prefixes = ["encoder."]# ["encoder.sentence_encoder.", "encoder."] + pattern = re.compile("^" + "|".join(prefixes)) + state_dict = {pattern.sub("", name): param for name, param in state_dict.items()} + + prefixes = ["sentence_encoder.embed_tokens"]# ["encoder.sentence_encoder.", "encoder."] + pattern = re.compile("^" + "|".join(prefixes)) + state_dict = {pattern.sub("embed_tokens", name): param for name, param in state_dict.items()} + return state_dict + + state_dict = upgrade_state_dict(state_dict) + + return super().load_state_dict(state_dict, strict) + + def half(self): + super().half() + # if (not getattr(self, "inference", False)): + # self.__make_freq_head_float__() + self.dtype = torch.half + return self + + def forward( + self, + src_tokens, + is_same_entity, + has_same_sequence, + masked_tokens=None, + features_only=False, + classification_head_name=None, + return_attn=False, + **kwargs + ): + if classification_head_name is not None: + features_only = True + # print("src_tokens:", src_tokens, src_tokens.shape) + # print("is_same_entity:", is_same_entity, is_same_entity.shape) + # print("has_same_sequence:", has_same_sequence[:, -20:, :], has_same_sequence.shape) + # None.shape + padding_mask = src_tokens.eq(self.padding_idx) + x = self.embed_tokens(src_tokens) + # x += self.embed_positions.weight[:src_tokens.size(1), :] + x.masked_fill_((src_tokens == self.mask_idx).unsqueeze(-1), 0.0) + # x: B x T x C + if not self.training: + mask_ratio_train = 0.15 * 0.8 + src_lengths = (~padding_mask).sum(-1) + mask_ratio_observed = (src_tokens == self.mask_idx).sum(-1).to(x.dtype) / src_lengths + x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None] + + if not padding_mask.any(): + padding_mask = None + + x = self.sentence_encoder(x, padding_mask=padding_mask, return_attn=return_attn, is_same_entity=is_same_entity, has_same_sequence=has_same_sequence, features_only=features_only) + + if return_attn: + _, attn = x + return attn + if not features_only: + x = self.lm_head(x, masked_tokens) + # freq_x = self.freq_head(x) + if classification_head_name is not None: + x = self.classification_heads[classification_head_name](x) + + return x + + + def register_classification_head( + self, name, num_classes=None, inner_dim=None, **kwargs + ): + """Register a classification head.""" + if name in self.classification_heads: + prev_num_classes = self.classification_heads[name].out_proj.out_features + prev_inner_dim = self.classification_heads[name].dense.out_features + if num_classes != prev_num_classes or inner_dim != prev_inner_dim: + logger.warning( + 're-registering head "{}" with num_classes {} (prev: {}) ' + "and inner_dim {} (prev: {})".format( + name, num_classes, prev_num_classes, inner_dim, prev_inner_dim + ) + ) + self.classification_heads[name] = BertClassificationHead( + input_dim=self.args.encoder_embed_dim, + inner_dim=inner_dim or self.args.encoder_embed_dim, + num_classes=num_classes, + activation_fn=self.args.pooler_activation_fn, + pooler_dropout=self.args.pooler_dropout, + ) + + +class RobertaLMHead(nn.Module): + """Head for masked language modeling.""" + + def __init__(self, embed_dim, output_dim, activation_fn, weight=None): + super().__init__() + self.dense = nn.Linear(embed_dim, embed_dim) + self.activation_fn = utils.get_activation_fn(activation_fn) + self.layer_norm = LayerNorm(embed_dim) + + if weight is None: + weight = nn.Linear(embed_dim, output_dim, bias=False).weight + self.weight = weight + self.bias = nn.Parameter(torch.zeros(output_dim)) + + def forward(self, features, masked_tokens=None, **kwargs): + # Only project the masked tokens while training, + # saves both memory and computation + if masked_tokens is not None: + features = features[masked_tokens, :] + features = features.type(self.dense.weight.dtype) + x = self.dense(features) + x = self.activation_fn(x) + x = self.layer_norm(x) + # project back to size of vocabulary with bias + x = F.linear(x, self.weight) + self.bias + return x + +class FreqHead(nn.Module): + """Head for masked language modeling.""" + + def __init__(self, embed_dim, output_dim, activation_fn, weight=None): + super().__init__() + self.dense = nn.Linear(embed_dim, embed_dim) + self.activation_fn = utils.get_activation_fn(activation_fn) + self.layer_norm = LayerNorm(embed_dim) + + if weight is None: + weight = nn.Linear(embed_dim, output_dim, bias=False).weight + self.weight = weight + self.bias = nn.Parameter(torch.zeros(output_dim)) + + def forward(self, features, masked_tokens=None, **kwargs): + # Only project the masked tokens while training, + # saves both memory and computation + if masked_tokens is not None: + features = features[masked_tokens, :] + features = features.type(self.dense.weight.dtype) + x = self.layer_norm(features) + x = self.dense(x) + x = self.activation_fn(x) + # project back to size of vocabulary with bias + x = F.linear(x, self.weight) + self.bias + return x + + +class BertClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, + input_dim, + inner_dim, + num_classes, + activation_fn, + pooler_dropout, + ): + super().__init__() + self.dense = nn.Linear(input_dim, inner_dim) + self.activation_fn = utils.get_activation_fn(activation_fn) + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + + def forward(self, features, **kwargs): + x = features[:, 0, :] # take token (equiv. to [CLS]) + x = self.dropout(x) + x = self.dense(x) + x = self.activation_fn(x) + x = self.dropout(x) + x = self.out_proj(x) + return x + +@register_model_architecture("bert", "bert") +def base_architecture(args): + args.encoder_layers = getattr(args, "encoder_layers", 33) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1280) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 5120) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 20) + args.dropout = getattr(args, "dropout", 0.0) + args.emb_dropout = getattr(args, "emb_dropout", 0.0) + args.attention_dropout = getattr(args, "attention_dropout", 0.0) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) + args.max_seq_len = getattr(args, "max_seq_len", 1024) + args.activation_fn = getattr(args, "activation_fn", "gelu") + args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") + args.post_ln = getattr(args, "post_ln", False) + args.ignore_inter_rotary = getattr(args, "ignore_inter_rotary", False) + + +@register_model_architecture("bert", "bert_base") +def bert_base_architecture(args): + base_architecture(args) + + +@register_model_architecture("bert", "bert_large") +def bert_large_architecture(args): + args.encoder_layers = getattr(args, "encoder_layers", 24) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) + base_architecture(args) + + +@register_model_architecture("bert", "xlm") +def xlm_architecture(args): + args.encoder_layers = getattr(args, "encoder_layers", 16) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1280) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1280 * 4) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16) + base_architecture(args) \ No newline at end of file diff --git a/unifold/musse/plm/model/multihead_attention.py b/unifold/musse/plm/model/multihead_attention.py new file mode 100755 index 0000000..e444b65 --- /dev/null +++ b/unifold/musse/plm/model/multihead_attention.py @@ -0,0 +1,260 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, Optional + +import torch +from torch import Tensor, nn +from .rotary_embedding import RotaryEmbedding +from unicore.modules.softmax_dropout import softmax_dropout + + +class SelfMultiheadAttention(nn.Module): + def __init__( + self, + embed_dim, + num_heads, + dropout=0.1, + bias=True, + scaling_factor=1, + ignore_inter_rotary=False, + share_pos_emb=False, + ): + super().__init__() + self.embed_dim = embed_dim + + self.num_heads = num_heads + self.dropout = dropout + + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = (self.head_dim * scaling_factor) ** -0.5 + + # self.in_proj = nn.Linear(embed_dim, embed_dim * 3, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.rot_emb = RotaryEmbedding(dim=self.head_dim) + self.is_same_entity_emb = nn.Embedding(2, num_heads) + self.has_same_sequence_emb = nn.Embedding(2, num_heads) + self.ignore_inter_rotary = ignore_inter_rotary + self.share_pos_emb = share_pos_emb + + def forward( + self, + query, + key_padding_mask: Optional[Tensor] = None, + attn_bias: Optional[Tensor] = None, + is_same_entity: Optional[Tensor] = None, + has_same_sequence: Optional[Tensor] = None, + multimer_pos_emb: Optional[Tensor] = None, + return_attn: bool = False, + ) -> Tensor: + + bsz, tgt_len, embed_dim = query.size() + assert embed_dim == self.embed_dim + + # q, k, v = self.in_proj(query).chunk(3, dim=-1) + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + + q = ( + q.view(bsz, tgt_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + .view(bsz * self.num_heads, -1, self.head_dim) + * self.scaling + ) + if k is not None: + k = ( + k.view(bsz, -1, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + .view(bsz * self.num_heads, -1, self.head_dim) + ) + if v is not None: + v = ( + v.view(bsz, -1, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + .view(bsz * self.num_heads, -1, self.head_dim) + ) + + assert k is not None + src_len = k.size(1) + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.ignore_inter_rotary: + q_raw, k_raw = q, k + attn_weights_raw = torch.bmm(q_raw, k_raw.transpose(1, 2)) # bsz * head, tgt_len, src_len + + q, k = self.rot_emb(q, k) + attn_weights = torch.bmm(q, k.transpose(1, 2)) + + + if self.ignore_inter_rotary: + + has_same_sequence_head = has_same_sequence.bool().unsqueeze(-1).expand(-1, -1, -1, self.num_heads).permute(0, 3, 1, 2).contiguous().view(bsz*self.num_heads, tgt_len, src_len) + attn_diff = (attn_weights - attn_weights_raw).view(bsz, self.num_heads, -1).mean(-1).detach() + attn_weights_raw = (attn_weights_raw.view(bsz, self.num_heads, tgt_len, src_len) + attn_diff[..., None, None]).view(bsz*self.num_heads, tgt_len, src_len) + attn_weights = torch.where(has_same_sequence_head, attn_weights, attn_weights_raw) + + if not self.share_pos_emb: + multimer_pos_emb = self.is_same_entity_emb(is_same_entity) + self.has_same_sequence_emb(has_same_sequence) + multimer_pos_emb = multimer_pos_emb \ + .permute(0, 3, 1, 2).contiguous() \ + .view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = attn_weights + multimer_pos_emb + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights.masked_fill_( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf") + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if not return_attn: + attn = softmax_dropout( + attn_weights, self.dropout, self.training, bias=attn_bias, + ) + else: + if attn_bias is not None: + attn_weights += attn_bias + attn = softmax_dropout( + attn_weights, self.dropout, self.training, inplace=False, + ) + + o = torch.bmm(attn, v) + assert list(o.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + o = ( + o.view(bsz, self.num_heads, tgt_len, self.head_dim) + .transpose(1, 2) + .contiguous() + .view(bsz, tgt_len, embed_dim) + ) + o = self.out_proj(o) + if not return_attn: + return o + else: + return o, attn_weights, attn + + +class CrossMultiheadAttention(nn.Module): + def __init__( + self, + embed_dim, + num_heads, + dropout=0.1, + bias=True, + scaling_factor=1, + ): + super().__init__() + self.embed_dim = embed_dim + + self.num_heads = num_heads + self.dropout = dropout + + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = (self.head_dim * scaling_factor) ** -0.5 + + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def forward( + self, + query, + key, + value, + key_padding_mask: Optional[Tensor] = None, + attn_bias: Optional[Tensor] = None, + ) -> Tensor: + + bsz, tgt_len, embed_dim = query.size() + assert embed_dim == self.embed_dim + + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + + q = ( + q.view(bsz, tgt_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + .view(bsz * self.num_heads, -1, self.head_dim) + * self.scaling + ) + if k is not None: + k = ( + k.view(bsz, -1, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + .view(bsz * self.num_heads, -1, self.head_dim) + ) + if v is not None: + v = ( + v.view(bsz, -1, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + .view(bsz * self.num_heads, -1, self.head_dim) + ) + + assert k is not None + src_len = k.size(1) + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights.masked_fill_( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf") + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn = softmax_dropout(attn_weights, self.dropout, self.training, bias=attn_bias) + + o = torch.bmm(attn, v) + assert list(o.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + + o = ( + o.view(bsz, self.num_heads, tgt_len, self.head_dim) + .transpose(1, 2) + .contiguous() + .view(bsz, tgt_len, embed_dim) + ) + o = self.out_proj(o) + return o \ No newline at end of file diff --git a/unifold/musse/plm/model/rotary_embedding.py b/unifold/musse/plm/model/rotary_embedding.py new file mode 100755 index 0000000..e862196 --- /dev/null +++ b/unifold/musse/plm/model/rotary_embedding.py @@ -0,0 +1,69 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch + + +def rotate_half(x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(x, cos, sin): + cos = cos[:, : x.shape[-2], :] + sin = sin[:, : x.shape[-2], :] + + return (x * cos) + (rotate_half(x) * sin) + + +class RotaryEmbedding(torch.nn.Module): + """ + The rotary position embeddings from RoFormer_ (Su et. al). + A crucial insight from the method is that the query and keys are + transformed by rotation matrices which depend on the relative positions. + Other implementations are available in the Rotary Transformer repo_ and in + GPT-NeoX_, GPT-NeoX was an inspiration + .. _RoFormer: https://arxiv.org/abs/2104.09864 + .. _repo: https://github.com/ZhuiyiTechnology/roformer + .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox + .. warning: Please note that this embedding is not registered on purpose, as it is transformative + (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis + """ + + def __init__(self, dim: int, *_, **__): + super().__init__() + # Generate and save the inverse frequency buffer (non trainable) + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + + self._seq_len_cached = None + self._cos_cached = None + self._sin_cached = None + + def _update_cos_sin_tables(self, x, seq_dimension=1): + seq_len = x.shape[seq_dimension] + + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if seq_len != self._seq_len_cached or self._cos_cached.device != x.device: + self._seq_len_cached = seq_len + t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + + self._cos_cached = emb.cos()[None, :, :] + self._sin_cached = emb.sin()[None, :, :] + + return self._cos_cached, self._sin_cached + + def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2) + + return ( + apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached), + apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached), + ) diff --git a/unifold/musse/plm/model/transformer_encoder.py b/unifold/musse/plm/model/transformer_encoder.py new file mode 100755 index 0000000..3c31246 --- /dev/null +++ b/unifold/musse/plm/model/transformer_encoder.py @@ -0,0 +1,235 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import logging +from typing import Optional + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from unicore.modules import LayerNorm +from functools import partial +from torch.utils.checkpoint import checkpoint +from .transformer_encoder_layer import TransformerLayer + +logger = logging.getLogger(__name__) + +def checkpoint_sequential( + functions, + input, +): + def wrap_tuple(a): + return (a,) if type(a) is not tuple else a + + def exec(func, a): + return wrap_tuple(func(*a)) + + def get_wrap_exec(func): + def wrap_exec(*a): + return exec(func, a) + + return wrap_exec + + input = wrap_tuple(input) + + is_grad_enabled = torch.is_grad_enabled() + + if is_grad_enabled: + for func in functions: + input = checkpoint(get_wrap_exec(func), *input) + else: + for func in functions: + input = exec(func, input) + return input + +def init_bert_params(module): + if not getattr(module, 'can_global_init', True): + return + def normal_(data): + data.copy_( + data.cpu().normal_(mean=0.0, std=0.02).to(data.device) + ) + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +def relative_position_bucket(relative_position, num_buckets=32, max_distance=128): + sign = torch.sign(relative_position) + num_buckets //= 2 + n = torch.abs(relative_position) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = n < max_exact + max_bucket_val = num_buckets - 1 - max_exact + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + val_if_large = max_exact + torch.ceil( + torch.log(n.float() / max_exact) / math.log((max_distance - 1) / max_exact) * (max_bucket_val) + ).long() + val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + ret = torch.where(is_small, n, val_if_large) * sign + return ret + + +class TransformerEncoder(nn.Module): + def __init__( + self, + encoder_layers: int = 6, + embed_dim: int = 768, + ffn_embed_dim: int = 3072, + attention_heads: int = 8, + emb_dropout: float = 0.1, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.0, + max_seq_len: int = 256, + activation_fn: str = "gelu", + rel_pos: bool = True, + rel_pos_bins: int = 32, + max_rel_pos: int = 128, + post_ln: bool = False, + ignore_inter_rotary = False, + share_pos_emb = False, + ) -> None: + + super().__init__() + self.emb_dropout = emb_dropout + self.max_seq_len = max_seq_len + self.embed_dim = embed_dim + self.attention_heads = attention_heads + # self.emb_layer_norm = LayerNorm(self.embed_dim) + if not post_ln: + self.emb_layer_norm_after = LayerNorm(self.embed_dim) + else: + self.emb_layer_norm_after = None + + self.share_pos_emb = share_pos_emb + + self.layers = nn.ModuleList( + [ + TransformerLayer( + embed_dim=self.embed_dim, + ffn_embed_dim=ffn_embed_dim, + attention_heads=attention_heads, + dropout=dropout, + attention_dropout=attention_dropout, + activation_dropout=activation_dropout, + activation_fn=activation_fn, + post_ln=post_ln, + ignore_inter_rotary=ignore_inter_rotary, + share_pos_emb=share_pos_emb, + ) + for _ in range(encoder_layers) + ] + ) + if share_pos_emb: + self.is_same_entity_emb = nn.Embedding(2, attention_heads) + self.has_same_sequence_emb = nn.Embedding(2, attention_heads) + + self.rel_pos = rel_pos + + + def forward( + self, + emb: torch.Tensor, + return_attn=False, + features_only=False, + attn_mask: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + is_same_entity: Optional[torch.Tensor] = None, + has_same_sequence: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + seq_len = emb.size(1) + bsz = emb.size(0) + # x = self.emb_layer_norm(emb) + x = F.dropout(emb, p=self.emb_dropout, training=self.training) + hidden_representations = {} + # account for padding while computing the representation + if padding_mask is not None: + x = x * (1 - padding_mask.unsqueeze(-1).type_as(x)) + + if self.share_pos_emb: + multimer_pos_emb = self.is_same_entity_emb(is_same_entity) + self.has_same_sequence_emb(has_same_sequence) + + + if attn_mask is not None and padding_mask is not None: + # merge key_padding_mask and attn_mask + attn_mask = attn_mask.view(x.size(0), -1, seq_len, seq_len) + attn_mask.masked_fill_( + padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf") + ) + attn_mask = attn_mask.view(-1, seq_len, seq_len) + padding_mask = None + + attn_probs_list = [] + if not self.training and not return_attn: + if self.share_pos_emb: + multimer_pos_emb = multimer_pos_emb.permute(0, 3, 1, 2).contiguous() \ + .view(bsz * self.attention_heads, seq_len, seq_len) + for layer_id, layer in enumerate(self.layers): + x = layer(x, padding_mask=padding_mask, attn_bias=attn_mask, return_attn=return_attn, is_same_entity=is_same_entity, has_same_sequence=has_same_sequence, multimer_pos_emb=multimer_pos_emb) + if self.share_pos_emb: + x, multimer_pos_emb = x + if features_only and layer_id == 0: + hidden_representations[layer_id + 1] = x + + elif return_attn: + if self.share_pos_emb: + multimer_pos_emb = multimer_pos_emb.permute(0, 3, 1, 2).contiguous() \ + .view(bsz * self.attention_heads, seq_len, seq_len) + for layer in self.layers: + x, attn_weights, attn_probs = layer(x, padding_mask=padding_mask, attn_bias=attn_mask, return_attn=return_attn, is_same_entity=is_same_entity, has_same_sequence=has_same_sequence, multimer_pos_emb=multimer_pos_emb) + if attn_probs.dim() == 3: + attn_probs = attn_probs.unsqueeze(0) + attn_probs_list.append(attn_probs) # B*H, L, L + + + else: + blocks = [ + partial( + b, + padding_mask=padding_mask, + is_same_entity=is_same_entity, + has_same_sequence=has_same_sequence, + # multimer_pos_emb=multimer_pos_emb, + # attn_bias=attn_mask, + ) + for b in self.layers + ] + if self.share_pos_emb: + multimer_pos_emb = multimer_pos_emb.permute(0, 3, 1, 2).contiguous() \ + .view(bsz * self.attention_heads, seq_len, seq_len) + x = checkpoint_sequential( + blocks, + input=(x, multimer_pos_emb), + )[0] + else: + x = checkpoint_sequential( + blocks, + input=x, + )[0] + + if self.emb_layer_norm_after is not None: + x = self.emb_layer_norm_after(x) + + if features_only: + hidden_representations[layer_id + 1] = x + assert 36 in hidden_representations + assert 1 in hidden_representations + if return_attn: + return x, torch.cat(attn_probs_list, dim=0) # num_layer, B*H, L, L + elif features_only: + return hidden_representations + else: + return x \ No newline at end of file diff --git a/unifold/musse/plm/model/transformer_encoder_layer.py b/unifold/musse/plm/model/transformer_encoder_layer.py new file mode 100755 index 0000000..e8a2604 --- /dev/null +++ b/unifold/musse/plm/model/transformer_encoder_layer.py @@ -0,0 +1,111 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, Optional + +import torch +import torch.nn.functional as F +from unicore import utils +from torch import nn +from unicore.modules import LayerNorm +from .multihead_attention import SelfMultiheadAttention + +class TransformerLayer(nn.Module): + """ + Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained + models. + """ + + def __init__( + self, + embed_dim: int = 768, + ffn_embed_dim: int = 3072, + attention_heads: int = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.0, + activation_fn: str = "gelu", + post_ln = False, + ignore_inter_rotary = False, + share_pos_emb = False, + ) -> None: + super().__init__() + + # Initialize parameters + self.embed_dim = embed_dim + self.attention_heads = attention_heads + self.attention_dropout = attention_dropout + + self.dropout = dropout + self.activation_dropout = activation_dropout + self.activation_fn = utils.get_activation_fn(activation_fn) + + self.self_attn = SelfMultiheadAttention( + self.embed_dim, + attention_heads, + dropout=attention_dropout, + ignore_inter_rotary=ignore_inter_rotary, + share_pos_emb=share_pos_emb, + ) + # layer norm associated with the self attention layer + self.self_attn_layer_norm = LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, ffn_embed_dim) + self.fc2 = nn.Linear(ffn_embed_dim, self.embed_dim) + self.final_layer_norm = LayerNorm(self.embed_dim) + self.post_ln = post_ln + self.share_pos_emb = share_pos_emb + + + def forward( + self, + x: torch.Tensor, + multimer_pos_emb: Optional[torch.Tensor] = None, + attn_bias: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + is_same_entity: Optional[torch.Tensor] = None, + has_same_sequence: Optional[torch.Tensor] = None, + return_attn: bool=False, + ) -> torch.Tensor: + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer implementation. + """ + assert attn_bias is None + residual = x + if not self.post_ln: + x = self.self_attn_layer_norm(x) + x = self.self_attn( + query=x, + key_padding_mask=padding_mask, + attn_bias=attn_bias, + return_attn=return_attn, + is_same_entity=is_same_entity, + has_same_sequence=has_same_sequence, + multimer_pos_emb=multimer_pos_emb, + ) + if return_attn: + x, attn_weights, attn_probs = x + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + if self.post_ln: + x = self.self_attn_layer_norm(x) + residual = x + if not self.post_ln: + x = self.final_layer_norm(x) + x = self.fc1(x) + x = self.activation_fn(x) + x = F.dropout(x, p=self.activation_dropout, training=self.training) + x = self.fc2(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = residual + x + if self.post_ln: + x = self.final_layer_norm(x) + if not return_attn and not self.share_pos_emb: + return x + elif not return_attn and self.share_pos_emb: + return x, multimer_pos_emb + else: + return x, attn_weights, attn_probs \ No newline at end of file diff --git a/unifold/pack_feat.py b/unifold/pack_feat.py new file mode 100644 index 0000000..35189fc --- /dev/null +++ b/unifold/pack_feat.py @@ -0,0 +1,149 @@ +# Copyright 2022 DP Technology +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run CPU MSA & template searching to get pickled features.""" +import json +import os +import pickle +from pathlib import Path +import time +import gzip + +from absl import app +from absl import flags +from absl import logging + +from unifold.data.utils import compress_features +from unifold.msa import parsers +from unifold.msa import pipeline +from unifold.msa import templates +from unifold.msa.utils import divide_multi_chains +from unifold.msa.tools import hmmsearch +from unifold.msa.pipeline import make_sequence_features + +logging.set_verbosity(logging.INFO) + +flags.DEFINE_string( + "fasta_path", + None, + "Path to FASTA file, If a FASTA file contains multiple sequences, " + "then it will be divided into several single sequences. ", +) + +flags.DEFINE_string( + "output_dir", None, "Path to a directory that will " "store the results." +) + +FLAGS = flags.FLAGS + + + +def _check_flag(flag_name: str, other_flag_name: str, should_be_set: bool): + if should_be_set != bool(FLAGS[flag_name].value): + verb = "be" if should_be_set else "not be" + raise ValueError( + f"{flag_name} must {verb} set when running with " + f'"--{other_flag_name}={FLAGS[other_flag_name].value}".' + ) + + +def generate_pkl_features( + fasta_path: str, + fasta_name: str, + output_dir_base: str, +): + """ + Predicts structure using AlphaFold for the given sequence. + """ + timings = {} + output_dir = os.path.join(output_dir_base, fasta_name.split("_")[0]) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + chain_id = fasta_name.split("_")[1] if len(fasta_name.split("_")) > 1 else "A" + input_seqs, input_descs = parsers.parse_fasta(open(fasta_path, "r").read()) + assert len(input_seqs) == 1 + input_seq = input_seqs[0] + input_desc = input_descs[0] + num_res = len(input_seq) + feature_dict = make_sequence_features(input_seq, input_desc, num_res) + + + # Get features. + features_output_path = os.path.join( + output_dir, "{}.feature.pkl.gz".format(fasta_name) + ) + if not os.path.exists(features_output_path): + t_0 = time.time() + + timings["features"] = time.time() - t_0 + feature_dict = compress_features(feature_dict) + pickle.dump(feature_dict, gzip.GzipFile(features_output_path, "wb"), protocol=4) + + + + logging.info("Final timings for %s: %s", fasta_name, timings) + + timings_output_path = os.path.join(output_dir, "{}.timings.json".format(fasta_name)) + with open(timings_output_path, "w") as f: + f.write(json.dumps(timings, indent=4)) + + +def main(argv): + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + fasta_path = FLAGS.fasta_path + fasta_name = Path(fasta_path).stem + input_fasta_str = open(fasta_path).read() + input_seqs, input_descs = parsers.parse_fasta(input_fasta_str) + if len(input_seqs) > 1: + temp_names, temp_paths = divide_multi_chains( + fasta_name, FLAGS.output_dir, input_seqs, input_descs + ) + fasta_names = temp_names + fasta_paths = temp_paths + else: + output_dir = os.path.join(FLAGS.output_dir, fasta_name) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + chain_order_path = os.path.join(output_dir, "chains.txt") + with open(chain_order_path, "w") as f: + f.write("A") + fasta_names = [fasta_name] + fasta_paths = [fasta_path] + + # Check for duplicate FASTA file names. + if len(fasta_names) != len(set(fasta_names)): + raise ValueError("All FASTA paths must have a unique basename.") + + # Predict structure for each of the sequences. + for i, fasta_path in enumerate(fasta_paths): + fasta_name = fasta_names[i] + generate_pkl_features( + fasta_path=fasta_path, + fasta_name=fasta_name, + output_dir_base=FLAGS.output_dir, + ) + + +if __name__ == "__main__": + flags.mark_flags_as_required( + [ + "fasta_path", + "output_dir", + ] + ) + + app.run(main) diff --git a/unifold/task.py b/unifold/task.py index 594abf6..1f9ae18 100644 --- a/unifold/task.py +++ b/unifold/task.py @@ -7,6 +7,7 @@ import numpy as np from unifold.dataset import UnifoldDataset, UnifoldMultimerDataset +from unifold.musse import UnifoldSingleMultimerDataset from unicore.data import data_utils from unicore.tasks import UnicoreTask, register_task @@ -54,7 +55,10 @@ def load_dataset(self, split, combine=False, **kwargs): split (str): name of the split (e.g., train, valid, test) """ if self.config.model.is_multimer: - data_class = UnifoldMultimerDataset + if self.config.data.common.use_musse: + data_class = UnifoldSingleMultimerDataset + else: + data_class = UnifoldMultimerDataset else: data_class = UnifoldDataset if split == "train": @@ -75,7 +79,8 @@ def load_dataset(self, split, combine=False, **kwargs): self.config, self.args.data, mode="eval", - max_step=None, + max_step=128, + disable_sd=True, json_prefix=self.args.json_prefix, )