From bc982432b374b3e8b18cb85bc108067b19c3dafb Mon Sep 17 00:00:00 2001 From: Kevin Kaichuang Yang Date: Tue, 19 Nov 2024 15:55:46 -0500 Subject: [PATCH 1/6] Refactor common functions to utils.py. --- dayhoff/model.py | 23 +++++-- dayhoff/utils.py | 147 ++++++++++++++++++++++++++++++++++++++++- src/generate.py | 137 ++++---------------------------------- src/train-msa.py | 166 ++++------------------------------------------- src/valid.py | 107 ++---------------------------- 5 files changed, 193 insertions(+), 387 deletions(-) diff --git a/dayhoff/model.py b/dayhoff/model.py index 6710f1d..2481e05 100644 --- a/dayhoff/model.py +++ b/dayhoff/model.py @@ -8,7 +8,7 @@ import torch.nn.functional as F from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel -from dayhoff.constants import MSA_ALPHABET_PLUS, TaskType +from dayhoff.constants import UL_ALPHABET_PLUS, TaskType from dayhoff.losses import OAMaskedCrossEntropyLoss OTHER_METRICS_KEY = "other_metrics" @@ -120,6 +120,13 @@ def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> dict: } return outputs + def inference(self, src: torch.Tensor) -> torch.Tensor: + self.module.eval() + with torch.inference_mode(): + output = self.module(src) + output = output["logits"] + return output + class MSAModelWithMetrics(nn.Module): """ @@ -213,7 +220,7 @@ def _create_bytenet( if pretrained: raise ValueError("Pretrained models not supported for ByteNet") - n_tokens = len(MSA_ALPHABET_PLUS) + n_tokens = len(UL_ALPHABET_PLUS) d_embed = model_config["d_embed"] d_model = model_config["d_model"] n_layers = model_config["n_layers"] @@ -250,6 +257,8 @@ def _get_hf_model( model_config: Optional[dict] = None, pretrained: bool = False, trust_remote_code: bool = False, + use_flash_attention_2: bool = False, + alphabet=UL_ALPHABET_PLUS ) -> nn.Module: if model_config and pretrained: # can't overwrite the config of a pretrained model @@ -283,19 +292,19 @@ def _get_hf_model( # ensure the vocab size is a multiple of 8 to maximize tensor core utilization model_config["vocab_size"] = ( - np.ceil(len(MSA_ALPHABET_PLUS) / 8).astype(int).item() * 8 + np.ceil(len(alphabet) / 8).astype(int).item() * 8 ) # TODO: This could be bad if alphabet gets bigger - model_config["pad_token_id"] = MSA_ALPHABET_PLUS.index( + model_config["pad_token_id"] = alphabet.index( MSA_PAD ) # FIXME: MSA_PAD or pad_token_id (which is mask_id in bytenet - model_config["bos_token_id"] = MSA_ALPHABET_PLUS.index(START) - model_config["eos_token_id"] = MSA_ALPHABET_PLUS.index(STOP) + model_config["bos_token_id"] = alphabet.index(START) + model_config["eos_token_id"] = alphabet.index(STOP) # merge the updates into the default config config = type(config).from_dict({**config.to_dict(), **model_config}) model = AutoModelForCausalLM.from_config( - config, trust_remote_code=trust_remote_code + config, trust_remote_code=trust_remote_code, use_flash_attention_2=use_flash_attention_2 ) return model diff --git a/dayhoff/utils.py b/dayhoff/utils.py index 7fbdf9b..2ca22be 100644 --- a/dayhoff/utils.py +++ b/dayhoff/utils.py @@ -1,4 +1,17 @@ +import json +import os +import random +from typing import Optional, Tuple + import numpy as np +import torch +from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict +import torch.distributed.checkpoint as dcp +import torch.nn as nn + +from evodiff.utils import Tokenizer +from dayhoff.model import _get_hf_model, ARDiffusionModel +from dayhoff.constants import UL_ALPHABET_PLUS def cosine_anneal_with_warmup(n_warmup_steps, 
n_anneal_steps, final_ratio=0.0): @@ -9,4 +22,136 @@ def get_lr(step): return step / n_warmup_steps else: return final_ratio + 0.5 * (1 - final_ratio) * (1 + np.cos((step - n_warmup_steps) * np.pi / n_anneal_steps)) - return get_lr \ No newline at end of file + return get_lr + + +def get_latest_dcp_checkpoint_path(ckpt_dir: str, last_step: int = -1) -> Optional[str]: + ckpt_path = None + if last_step == -1: + print("last step") + for dir_name in os.listdir(ckpt_dir): + if "dcp" in dir_name: + step = int(dir_name.split("dcp_")[-1]) + if step > last_step: + ckpt_path = os.path.join(ckpt_dir, dir_name) + last_step = step + else: + print("else") + ckpt_path = os.path.join(ckpt_dir, f"dcp_{last_step}") + return ckpt_path + + +def load_msa_config_and_model(config_fpath, alphabet=UL_ALPHABET_PLUS, use_flash_attention_2=True): + with open(config_fpath, "r") as f: + config = json.load(f) + + tokenizer = Tokenizer(protein_alphabet=alphabet) + model_config = config["model_config"] + pretrained = model_config.pop("pretrained", False) + success = False + while not success: + try: + model = _get_hf_model( + "ai21labs/Jamba-v0.1", + tokenizer.pad_id, + pretrained=pretrained, + model_config=model_config, + trust_remote_code=True, + use_flash_attention_2=use_flash_attention_2, + alphabet=UL_ALPHABET_PLUS + ) + success = True + except FileNotFoundError: + pass + block = {type(layer) for layer in model.model.layers} + aux_loss_weight = config.get("aux_loss_weight", 0.0) + model = ARDiffusionModel(model, aux_loss_weight=aux_loss_weight) + return config, tokenizer, model, block + + +def seed_everything(seed): + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + +def save_checkpoint( + out_dir: str, + model: nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler._LRScheduler, + step: int, + epoch: int, + tokens: int, + sequences: int, + iterations: int, + rank: int +) -> None: + out_path = os.path.join(out_dir, f"dcp_{step}") + print(f"Saving checkpoint to {out_path}", rank, flush=True) + model_state, optim_state = get_state_dict(model, optimizer) + sd = { + "model_state_dict": model_state, + "optimizer_state_dict": optim_state, + } + fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(out_path) + _ = dcp.save(sd, storage_writer=fs_storage_writer) + sched_state = scheduler.state_dict() + sd = { + "step": step, + "tokens": tokens, + "sequences": sequences, + "scheduler_state_dict": sched_state, + "epoch": epoch, + "iterations": iterations + } + torch.save(sd, os.path.join(out_path, "scheduler%d.pt" %rank)) + + +def load_checkpoint( + model, optimizer, scheduler, ckpt_dir: str, last_step: int = -1, fast_forward=True, rank: int = 0 +) -> Tuple[int, int, int, int, int]: + ckpt_path = get_latest_dcp_checkpoint_path(ckpt_dir, last_step=last_step) + if ckpt_path: + print(f"Loading weights from {ckpt_path}...") + fs_storage_reader = torch.distributed.checkpoint.FileSystemReader(ckpt_path) + + model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer) + state_dict = { + "model_state_dict": model_state_dict, + "optimizer_state_dict": optimizer_state_dict, + } + dcp.load( + state_dict=state_dict, + storage_reader=fs_storage_reader, + ) + # sets our state dicts on the model and optimizer, now that we've loaded + set_state_dict( + model, + optimizer, + model_state_dict=model_state_dict, + optim_state_dict=optimizer_state_dict, + ) + if os.path.exists(os.path.join(ckpt_path, 
"scheduler%d.pt" %rank)): + sd = torch.load( + os.path.join(ckpt_path, "scheduler%d.pt" %rank), map_location=torch.device("cpu") + ) + else: + sd = torch.load( + os.path.join(ckpt_path, "scheduler.pt"), map_location=torch.device("cpu") + ) + scheduler.load_state_dict(sd["scheduler_state_dict"]) + epoch = sd["epoch"] + if "iterations" in sd: + its = sd["iterations"] + else: + its = 0 + epoch += 100 + if not fast_forward: + epoch = epoch + 102 + its = 0 + return epoch, sd["step"], sd["tokens"], sd["sequences"], its + else: + return 0, 0, 0, 0, 0 \ No newline at end of file diff --git a/src/generate.py b/src/generate.py index a7a1a7d..886d5f4 100644 --- a/src/generate.py +++ b/src/generate.py @@ -3,32 +3,24 @@ import os import random from typing import Optional, Tuple +from tqdm import tqdm + -from esm.modules import AxialTransformerLayer -from evodiff.utils import Tokenizer -from evodiff.metrics import MaskedAccuracyMSA import numpy as np -from sequence_models.esm import MSATransformer -from sequence_models.losses import MaskedCrossEntropyLossMSA -from sequence_models.utils import warmup, transformer_lr + import torch import torch.distributed as dist import torch.distributed.checkpoint as dcp from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict from torch.optim import Adam from torch.optim.lr_scheduler import LambdaLR -from tqdm import tqdm - -from dayhoff.constants import MSA_ALPHABET_PLUS, END_AL -from dayhoff.model import MSAModelWithMetrics, _get_hf_model -def seed_everything(seed: int) -> None: - random.seed(seed) - os.environ["PYTHONHASHSEED"] = str(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) +from sequence_models.utils import warmup, transformer_lr +from sequence_models.constants import STOP +from dayhoff.constants import UL_ALPHABET_PLUS, END_AL, END_UL +from dayhoff.utils import (load_msa_config_and_model, get_latest_dcp_checkpoint_path, + load_checkpoint, seed_everything) # default to a single-GPU setup if not present @@ -38,104 +30,6 @@ def seed_everything(seed: int) -> None: DEVICE = torch.device(f"cuda:{RANK}") print("device", DEVICE) -def load_msa_config_and_model(config_fpath): - with open(config_fpath, "r") as f: - config = json.load(f) - n_tokens = len(MSA_ALPHABET_PLUS) - - tokenizer = Tokenizer(protein_alphabet=MSA_ALPHABET_PLUS) - accu_func = MaskedAccuracyMSA() - loss_func = MaskedCrossEntropyLossMSA(ignore_index=tokenizer.pad_id) - if config["model_type"] == "jamba": - model_config = config["model_config"] - pretrained = model_config.pop("pretrained", False) - model = _get_hf_model( - "ai21labs/Jamba-v0.1", - tokenizer.pad_id, - pretrained=pretrained, - model_config=model_config, - trust_remote_code=True, - ) - block = {type(layer) for layer in model.model.layers} - causal = True # must be true for jamba - elif config["model_type"] == "msa_transformer": - n_layers = config["n_layers"] - d_hidden = config["d_hidden"] - n_heads = config["n_heads"] - d_embed = config["d_embed"] - tie_weights = config.get("tie_weights", 0.0) # true if not empty - print("tie_weights", tie_weights) - # config["tie_weights"] = tie_weights # save - model = MSATransformer( - d_embed, - d_hidden, - n_layers, - n_heads, - use_ckpt=True, - n_tokens=n_tokens, - padding_idx=tokenizer.pad_id, - mask_idx=tokenizer.mask_id, - tie_weights=tie_weights, - ) - block = {AxialTransformerLayer} - causal = config.get("causal", False) # true if not empty - else: - raise Exception("Unknown model: {}".format(config["model"])) - aux_loss_weight 
= config.get("aux_loss_weight", 0.0) - config["causal"] = causal # save - model = MSAModelWithMetrics( - model, - loss_func, - accu_func, - tokenizer.pad_id, - tokenizer, - aux_loss_weight=aux_loss_weight, - model_type=config["model_type"], - ) - return config, tokenizer, model, block, causal - - -def get_latest_dcp_checkpoint_path(ckpt_dir: str, last_step: int = -1) -> Optional[str]: - ckpt_path = None - if last_step == -1: - print("last step") - for dir_name in os.listdir(ckpt_dir): - if "dcp" in dir_name: - step = int(dir_name.split("dcp_")[-1]) - if step > last_step: - ckpt_path = os.path.join(ckpt_dir, dir_name) - last_step = step - else: - print("else") - ckpt_path = os.path.join(ckpt_dir, f"dcp_{last_step}") - return ckpt_path - - -def load_checkpoint(model, optimizer, scheduler, ckpt_dir: str, last_step: int = -1) -> Tuple[int, int, int, int]: - ckpt_path = get_latest_dcp_checkpoint_path(ckpt_dir, last_step=last_step) - print(ckpt_path) - if ckpt_path: - print(f"Loading weights from {ckpt_path}...") - fs_storage_reader = torch.distributed.checkpoint.FileSystemReader(ckpt_path) - - model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer) - state_dict = {"model_state_dict": model_state_dict, "optimizer_state_dict": optimizer_state_dict} - dcp.load( - state_dict=state_dict, - storage_reader=fs_storage_reader, - ) - # sets our state dicts on the model and optimizer, now that we've loaded - set_state_dict(model, optimizer, model_state_dict=model_state_dict, optim_state_dict=optimizer_state_dict) - checkpoint_path = os.path.join(ckpt_path, "scheduler.pt") - if os.path.exists(os.path.join(ckpt_path, "scheduler0.pt")): - checkpoint_path = os.path.join(ckpt_path, "scheduler0.pt") - sd = torch.load(checkpoint_path, map_location=torch.device("cpu")) - scheduler.load_state_dict(sd["scheduler_state_dict"]) - - # sequences must optionally return 0 for backwards compatibility with old checkpoints - return sd["epoch"] + 1, sd["step"], sd["tokens"], sd.get("sequences", 0) - else: - return 0, 0, 0, 0 def generate(args: argparse.Namespace) -> None: @@ -146,19 +40,14 @@ def generate(args: argparse.Namespace) -> None: #print("Initializing model...", RANK) # load model parameters from config file - config, tokenizer, model, block, causal = load_msa_config_and_model(os.path.join(args.in_fpath, "config.json")) + config, tokenizer, model, block = load_msa_config_and_model(os.path.join(args.in_fpath, "config.json")) #if args.verbose: #print("Done initializing model.", RANK) - lr = config["lr"] - weight_decay = 0 # filler , doesnt matter - optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay) warmup_steps = max(config["warmup_steps"], 1) - lr_func = transformer_lr(warmup_steps) - scheduler = LambdaLR(optimizer, lr_func) # Load model and optimizer onto CPU - initial_epoch, total_steps, total_tokens, total_seqs = load_checkpoint( - model, optimizer, scheduler, args.in_fpath, args.checkpoint_step + initial_epoch, total_steps, total_tokens, total_seqs, _ = load_checkpoint( + model, None, None, args.in_fpath, args.checkpoint_step, rank=RANK ) # Move only model to GPU model = model.to(DEVICE) @@ -179,7 +68,7 @@ def generate(args: argparse.Namespace) -> None: for s in tqdm(range(args.n_generations)): if args.verbose: - print(MSA_ALPHABET_PLUS) + print(UL_ALPHABET_PLUS) print(tokenizer.a_to_i) print(tokenizer.i_to_a) # Start from START token @@ -191,7 +80,7 @@ def generate(args: argparse.Namespace) -> None: for i in tqdm(range(max_len)): if reach_stop == False: # Add residues until 
it predicts STOP token or hits max seq len prediction = model.inference(sample) - p = prediction[:, -1, : len(MSA_ALPHABET_PLUS)] # predict next token + p = prediction[:, -1, : len(UL_ALPHABET_PLUS)] # predict next token p = torch.nn.functional.softmax(p / args.temp, dim=1) # exp p_sample = torch.multinomial(p, num_samples=1).to(DEVICE) sample = torch.cat((sample, p_sample), dim=1) diff --git a/src/train-msa.py b/src/train-msa.py index 4611ad2..11096f8 100644 --- a/src/train-msa.py +++ b/src/train-msa.py @@ -3,24 +3,18 @@ import functools import json import os -import random -from typing import Optional, Tuple +from typing import Tuple import time -from evodiff.utils import Tokenizer -from evodiff.metrics import MaskedAccuracyMSA import numpy as np -from sequence_models.losses import MaskedCrossEntropyLossMSA from sequence_models.samplers import SortishSampler, ClusteredSortishSampler, ApproxBatchSampler -from sequence_models.utils import warmup, transformer_lr +from sequence_models.utils import transformer_lr import torch import torch.nn as nn import torch.distributed as dist from torch.distributed.fsdp import ( BackwardPrefetch, FullyShardedDataParallel as FSDP, - FullOptimStateDictConfig, - FullStateDictConfig, ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType, @@ -31,25 +25,21 @@ from torch.optim import Adam from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader +import torch.distributed as dist +from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict +import torch.distributed.checkpoint as dcp import wandb + from dayhoff.activation_checkpointing import apply_activation_checkpointing -from dayhoff.collators import MSAOAMaskCollator, MSAARCollator -from dayhoff.constants import UL_ALPHABET_PLUS +from dayhoff.collators import MSAARCollator from dayhoff.datasets import OpenProteinDataset, UniRefDataset -from dayhoff.model import ARDiffusionModel, _get_hf_model from dayhoff.samplers import ApproxBatchSamplerMSA -from dayhoff.utils import cosine_anneal_with_warmup +from dayhoff.utils import (cosine_anneal_with_warmup, load_msa_config_and_model, + get_latest_dcp_checkpoint_path, seed_everything, load_checkpoint, save_checkpoint) -import torch.distributed as dist -from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict -import torch.distributed.checkpoint as dcp -from wandb import config -# FLOP Profiler -# from fvcore.nn import FlopCountAnalysis -# from flops_profiler.profiler import FlopsProfiler # default to a single-GPU setup if not present RANK = int(os.environ["RANK"]) @@ -59,35 +49,6 @@ OTHER_METRICS_KEY = "other_metrics" -def load_msa_config_and_model(config_fpath): - with open(config_fpath, "r") as f: - config = json.load(f) - n_tokens = len(UL_ALPHABET_PLUS) - - tokenizer = Tokenizer(protein_alphabet=UL_ALPHABET_PLUS) - accu_func = MaskedAccuracyMSA() - loss_func = MaskedCrossEntropyLossMSA(ignore_index=tokenizer.pad_id) - model_config = config["model_config"] - pretrained = model_config.pop("pretrained", False) - success = False - while not success: - try: - model = _get_hf_model( - "ai21labs/Jamba-v0.1", - tokenizer.pad_id, - pretrained=pretrained, - model_config=model_config, - trust_remote_code=True, - ) - success = True - except FileNotFoundError: - pass - block = {type(layer) for layer in model.model.layers} - causal = True # must be true for jamba - aux_loss_weight = config.get("aux_loss_weight", 0.0) - config["causal"] = causal # save - model = ARDiffusionModel(model, 
aux_loss_weight=aux_loss_weight) - return config, tokenizer, model, block, causal def is_amlt(): @@ -253,108 +214,6 @@ def get_msa_dataloader(config, tokenizer, args): return dl_train -def seed_everything(seed): - random.seed(seed) - os.environ["PYTHONHASHSEED"] = str(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - - -def load_checkpoint( - model, optimizer, scheduler, ckpt_dir: str, last_step: int = -1, fast_forward=True -) -> Tuple[int, int, int, int, int]: - ckpt_path = get_latest_dcp_checkpoint_path(ckpt_dir, last_step=last_step) - if ckpt_path: - print(f"Loading weights from {ckpt_path}...") - fs_storage_reader = torch.distributed.checkpoint.FileSystemReader(ckpt_path) - - model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer) - state_dict = { - "model_state_dict": model_state_dict, - "optimizer_state_dict": optimizer_state_dict, - } - dcp.load( - state_dict=state_dict, - storage_reader=fs_storage_reader, - ) - # sets our state dicts on the model and optimizer, now that we've loaded - set_state_dict( - model, - optimizer, - model_state_dict=model_state_dict, - optim_state_dict=optimizer_state_dict, - ) - if os.path.exists(os.path.join(ckpt_path, "scheduler%d.pt" %RANK)): - sd = torch.load( - os.path.join(ckpt_path, "scheduler%d.pt" %RANK), map_location=torch.device("cpu") - ) - else: - sd = torch.load( - os.path.join(ckpt_path, "scheduler.pt"), map_location=torch.device("cpu") - ) - scheduler.load_state_dict(sd["scheduler_state_dict"]) - epoch = sd["epoch"] - if "iterations" in sd: - its = sd["iterations"] - else: - its = 0 - epoch += 100 - if not fast_forward: - epoch = epoch + 102 - its = 0 - return epoch, sd["step"], sd["tokens"], sd["sequences"], its - else: - return 0, 0, 0, 0, 0 - - -def save_checkpoint( - out_dir: str, - model: nn.Module, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler._LRScheduler, - step: int, - epoch: int, - tokens: int, - sequences: int, - iterations: int -) -> None: - out_path = os.path.join(out_dir, f"dcp_{step}") - print(f"Saving checkpoint to {out_path}", RANK, flush=True) - model_state, optim_state = get_state_dict(model, optimizer) - sd = { - "model_state_dict": model_state, - "optimizer_state_dict": optim_state, - } - fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(out_path) - _ = dcp.save(sd, storage_writer=fs_storage_writer) - sched_state = scheduler.state_dict() - sd = { - "step": step, - "tokens": tokens, - "sequences": sequences, - "scheduler_state_dict": sched_state, - "epoch": epoch, - "iterations": iterations - } - torch.save(sd, os.path.join(out_path, "scheduler%d.pt" %RANK)) - - -def get_latest_dcp_checkpoint_path(ckpt_dir: str, last_step: int = -1) -> Optional[str]: - ckpt_path = None - if last_step == -1: - print("last step") - for dir_name in os.listdir(ckpt_dir): - if "dcp" in dir_name: - step = int(dir_name.split("dcp_")[-1]) - if step > last_step: - ckpt_path = os.path.join(ckpt_dir, dir_name) - last_step = step - else: - print("else") - ckpt_path = os.path.join(ckpt_dir, f"dcp_{last_step}") - return ckpt_path - def epoch( model: nn.Module, @@ -420,7 +279,8 @@ def epoch( epoch=current_epoch, tokens=total_tokens, sequences=total_seq, - iterations=i + iterations=i, + rank=RANK ) if args.cosine and total_steps == terminate_steps: return total_steps, total_tokens, total_seq @@ -448,7 +308,7 @@ def train(args): if args.verbose: print("Initializing model...", RANK, flush=True) - config, tokenizer, model, block, causal = 
load_msa_config_and_model( + config, tokenizer, model, block = load_msa_config_and_model( args.config_fpath ) @@ -539,7 +399,7 @@ def train(args): # load state initial_epoch, total_steps, total_tokens, total_seqs, current_it = load_checkpoint( - model, optimizer, scheduler, args.out_fpath, args.last_step, fast_forward=args.no_msas + model, optimizer, scheduler, args.out_fpath, args.last_step, fast_forward=args.no_msas, rank=RANK ) # override from config optimizer.param_groups[0]["lr"] = config["lr"] * lr_func(total_steps + 1) diff --git a/src/valid.py b/src/valid.py index c649d7c..fd8acee 100644 --- a/src/valid.py +++ b/src/valid.py @@ -1,35 +1,25 @@ import argparse import functools -import json import os -import random -from typing import Optional, Sequence, Tuple - -from evodiff.utils import Tokenizer -import numpy as np -from sequence_models.utils import transformer_lr +from typing import Sequence, Tuple import torch import torch.distributed as dist -from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict from torch.distributed.fsdp import ( BackwardPrefetch, FullyShardedDataParallel as FSDP, MixedPrecision, ShardingStrategy, ) -import torch.distributed.checkpoint as dcp from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy import torch.nn as nn -from torch.optim import Adam -from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader import wandb import pandas as pd from dayhoff.collators import MSAARCollator -from dayhoff.constants import UL_ALPHABET_PLUS from dayhoff.datasets import UniRefDataset, OpenProteinDataset -from dayhoff.model import ARDiffusionModel, _get_hf_model, OTHER_METRICS_KEY +from dayhoff.model import OTHER_METRICS_KEY +from dayhoff.utils import (load_msa_config_and_model, seed_everything, load_checkpoint, get_latest_dcp_checkpoint_path) @@ -43,32 +33,6 @@ def is_amlt() -> bool: return os.environ.get("AMLT_OUTPUT_DIR", None) is not None -def load_config_and_model(config_fpath): - with open(config_fpath, "r") as f: - config = json.load(f) - tokenizer = Tokenizer(protein_alphabet=UL_ALPHABET_PLUS) - model_config = config["model_config"] - pretrained = model_config.pop("pretrained", False) - success = False - while not success: - try: - model = _get_hf_model( - "ai21labs/Jamba-v0.1", - tokenizer.pad_id, - pretrained=pretrained, - model_config=model_config, - trust_remote_code=True, - ) - success = True - except FileNotFoundError: - pass - block = {type(layer) for layer in model.model.layers} - causal = True # must be true for jamba - aux_loss_weight = config.get("aux_loss_weight", 0.0) - config["causal"] = causal # save - model = ARDiffusionModel(model, aux_loss_weight=aux_loss_weight) - return config, tokenizer, model, block - def get_val_dataloader(config, tokenizer, args): collator = MSAARCollator( @@ -115,13 +79,6 @@ def get_val_dataloader(config, tokenizer, args): return dl -def seed_everything(seed: int) -> None: - random.seed(seed) - os.environ["PYTHONHASHSEED"] = str(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - def step( model: nn.Module, @@ -165,55 +122,6 @@ def epoch( return total_loss / total_tokens, total_accu / total_tokens -def get_latest_dcp_checkpoint_path(ckpt_dir: str, last_step: int = -1) -> Optional[str]: - ckpt_path = None - if last_step == -1: - for dir_name in os.listdir(ckpt_dir): - if "dcp_" in dir_name: - step = int(dir_name.split('dcp_')[-1]) - if step > last_step: - ckpt_path = os.path.join(ckpt_dir, dir_name) - last_step = step - 
else: - ckpt_path = os.path.join(ckpt_dir, f'dcp_{last_step}') - return ckpt_path - - -def load_checkpoint(model, optimizer, scheduler, ckpt_dir: str, last_step: int = -1) -> Tuple[int, int, int, int]: - ckpt_path = get_latest_dcp_checkpoint_path(ckpt_dir, last_step=last_step) - if ckpt_path: - print(f"Loading weights from {ckpt_path}...", flush=True) - fs_storage_reader = torch.distributed.checkpoint.FileSystemReader(ckpt_path) - - model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer) - state_dict = { - "model_state_dict": model_state_dict, - "optimizer_state_dict": optimizer_state_dict - } - dcp.load( - state_dict=state_dict, - storage_reader=fs_storage_reader, - ) - # sets our state dicts on the model and optimizer, now that we've loaded - set_state_dict( - model, - optimizer, - model_state_dict=model_state_dict, - optim_state_dict=optimizer_state_dict - ) - if os.path.exists(os.path.join(ckpt_path, 'scheduler.pt')): - sd = torch.load(os.path.join(ckpt_path, 'scheduler.pt'), map_location=torch.device("cpu")) - else: - sd = torch.load(os.path.join(ckpt_path, 'scheduler0.pt'), map_location=torch.device("cpu")) - - scheduler.load_state_dict(sd["scheduler_state_dict"]) - - # sequences must optionally return 0 for backwards compatibility with old checkpoints - return sd["epoch"] + 1, sd["step"], sd["tokens"], sd.get("sequences", 0) - else: - return 0, 0, 0, 0 - - def train(args: argparse.Namespace) -> None: print(f"Starting job on rank {RANK} with local rank {LOCAL_RANK} and world size {WORLD_SIZE}") seed_everything(0) @@ -223,7 +131,7 @@ def train(args: argparse.Namespace) -> None: torch.cuda.set_device(LOCAL_RANK) if args.verbose: print("Initializing model...", RANK) - config, tokenizer, model, blk_types = load_config_and_model(args.out_fpath + 'config.json') + config, tokenizer, model, blk_types = load_msa_config_and_model(args.out_fpath + 'config.json') if RANK == 0: wandb.init(config=config, mode='online') if args.verbose: @@ -268,11 +176,6 @@ def train(args: argparse.Namespace) -> None: mixed_precision=mixed_precision, backward_prefetch=bwd_prefetch, ) - lr = config["lr"] - warmup_steps = max(config["warmup_steps"], 1) - optimizer = Adam(model.parameters(), lr=lr, weight_decay=config.get("weight_decay", 0.0)) - lr_func = transformer_lr(warmup_steps) - scheduler = LambdaLR(optimizer, lr_func) results = pd.DataFrame(columns=['nsteps', 'ce_loss', 'accuracy']) if args.gigaref: out_fname = args.out_fpath + 'gigaref.csv' @@ -307,7 +210,7 @@ def train(args: argparse.Namespace) -> None: accu = r['accuracy'].values[0] else: # load the state - _ = load_checkpoint(model, optimizer, scheduler, args.out_fpath, step) + _ = load_checkpoint(model, None, None, args.out_fpath, step, rank=RANK) dl_valid.sampler.set_epoch(0) model = model.eval() loss, accu = epoch(model, dl_valid) From a35ffecdc6506ae558fabe3b14c03a3229af52ff Mon Sep 17 00:00:00 2001 From: Kevin Kaichuang Yang Date: Tue, 19 Nov 2024 17:14:25 -0500 Subject: [PATCH 2/6] Allow loading without optimizer and scheduler. 
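load_checkpoint previously required live optimizer and scheduler objects even for
inference-only callers such as generate.py and valid.py. Passing None for either now
skips the corresponding state restore, using get_model_state_dict/set_model_state_dict
when only the model weights are needed. A minimal weights-only sketch (the checkpoint
paths are placeholders, not from this repo):

    from dayhoff.utils import load_msa_config_and_model, load_checkpoint

    config, tokenizer, model, block = load_msa_config_and_model("/ckpts/run/config.json")
    # optimizer=None and scheduler=None load model weights only
    epoch, step, tokens, seqs, its = load_checkpoint(model, None, None, "/ckpts/run", rank=0)
    model = model.eval()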
--- dayhoff/utils.py | 60 +++++++++++++++++++++++++++++++----------------- 1 file changed, 39 insertions(+), 21 deletions(-) diff --git a/dayhoff/utils.py b/dayhoff/utils.py index 2ca22be..3157043 100644 --- a/dayhoff/utils.py +++ b/dayhoff/utils.py @@ -5,7 +5,8 @@ import numpy as np import torch -from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict +from torch.distributed.checkpoint.state_dict import (get_state_dict, set_state_dict, + get_model_state_dict, set_model_state_dict) import torch.distributed.checkpoint as dcp import torch.nn as nn @@ -58,7 +59,7 @@ def load_msa_config_and_model(config_fpath, alphabet=UL_ALPHABET_PLUS, use_flash model_config=model_config, trust_remote_code=True, use_flash_attention_2=use_flash_attention_2, - alphabet=UL_ALPHABET_PLUS + alphabet=alphabet ) success = True except FileNotFoundError: @@ -117,32 +118,49 @@ def load_checkpoint( if ckpt_path: print(f"Loading weights from {ckpt_path}...") fs_storage_reader = torch.distributed.checkpoint.FileSystemReader(ckpt_path) - - model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer) - state_dict = { - "model_state_dict": model_state_dict, - "optimizer_state_dict": optimizer_state_dict, - } - dcp.load( - state_dict=state_dict, - storage_reader=fs_storage_reader, - ) - # sets our state dicts on the model and optimizer, now that we've loaded - set_state_dict( - model, - optimizer, - model_state_dict=model_state_dict, - optim_state_dict=optimizer_state_dict, - ) + if optimizer is not None: + model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer) + state_dict = { + "model_state_dict": model_state_dict, + "optimizer_state_dict": optimizer_state_dict, + } + dcp.load( + state_dict=state_dict, + storage_reader=fs_storage_reader, + ) + # sets our state dicts on the model and optimizer, now that we've loaded + set_state_dict( + model, + optimizer, + model_state_dict=model_state_dict, + optim_state_dict=optimizer_state_dict, + ) + else: + model_state_dict = get_model_state_dict(model) + state_dict = { + "model_state_dict": model_state_dict, + } + dcp.load( + state_dict=state_dict, + storage_reader=fs_storage_reader, + ) + # sets our state dicts on the model, now that we've loaded + set_model_state_dict( + model, + model_state_dict=model_state_dict + ) if os.path.exists(os.path.join(ckpt_path, "scheduler%d.pt" %rank)): sd = torch.load( os.path.join(ckpt_path, "scheduler%d.pt" %rank), map_location=torch.device("cpu") ) - else: + elif os.path.exists(os.path.join(ckpt_path, "scheduler.pt" %rank)): sd = torch.load( os.path.join(ckpt_path, "scheduler.pt"), map_location=torch.device("cpu") ) - scheduler.load_state_dict(sd["scheduler_state_dict"]) + else: + return 0, 0, 0, 0, 0 + if scheduler is not None: + scheduler.load_state_dict(sd["scheduler_state_dict"]) epoch = sd["epoch"] if "iterations" in sd: its = sd["iterations"] From 5fbce33b25365335f70aa565d3dff78e1f157208 Mon Sep 17 00:00:00 2001 From: Kevin Kaichuang Yang Date: Wed, 20 Nov 2024 16:54:56 -0500 Subject: [PATCH 3/6] Small bugfixes. 
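Two fixes in dayhoff/utils.py: drop the leftover debug prints in
get_latest_dcp_checkpoint_path, and repair the scheduler fallback introduced in the
previous commit, which formatted a string with no placeholder:

    os.path.exists(os.path.join(ckpt_path, "scheduler.pt" %rank))  # raises TypeError

The stray %rank is removed so the plain scheduler.pt fallback is actually reachable.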
--- dayhoff/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/dayhoff/utils.py b/dayhoff/utils.py index 3157043..88967a0 100644 --- a/dayhoff/utils.py +++ b/dayhoff/utils.py @@ -29,7 +29,6 @@ def get_lr(step): def get_latest_dcp_checkpoint_path(ckpt_dir: str, last_step: int = -1) -> Optional[str]: ckpt_path = None if last_step == -1: - print("last step") for dir_name in os.listdir(ckpt_dir): if "dcp" in dir_name: step = int(dir_name.split("dcp_")[-1]) @@ -37,7 +36,6 @@ def get_latest_dcp_checkpoint_path(ckpt_dir: str, last_step: int = -1) -> Option ckpt_path = os.path.join(ckpt_dir, dir_name) last_step = step else: - print("else") ckpt_path = os.path.join(ckpt_dir, f"dcp_{last_step}") return ckpt_path @@ -153,7 +151,7 @@ def load_checkpoint( sd = torch.load( os.path.join(ckpt_path, "scheduler%d.pt" %rank), map_location=torch.device("cpu") ) - elif os.path.exists(os.path.join(ckpt_path, "scheduler.pt" %rank)): + elif os.path.exists(os.path.join(ckpt_path, "scheduler.pt")): sd = torch.load( os.path.join(ckpt_path, "scheduler.pt"), map_location=torch.device("cpu") ) From 9f2757eb86425feee134ee981bdbe2eb3636fe35 Mon Sep 17 00:00:00 2001 From: Kevin Kaichuang Yang Date: Wed, 20 Nov 2024 16:55:09 -0500 Subject: [PATCH 4/6] More robust unconditional generation. --- src/generate.py | 165 ++++++++++++++++++++++++++++++------------------ 1 file changed, 105 insertions(+), 60 deletions(-) diff --git a/src/generate.py b/src/generate.py index 886d5f4..4aaf484 100644 --- a/src/generate.py +++ b/src/generate.py @@ -1,24 +1,20 @@ import argparse +import datetime import json import os import random from typing import Optional, Tuple from tqdm import tqdm +import re import numpy as np +from transformers import SuppressTokensLogitsProcessor import torch -import torch.distributed as dist -import torch.distributed.checkpoint as dcp -from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict -from torch.optim import Adam -from torch.optim.lr_scheduler import LambdaLR - -from sequence_models.utils import warmup, transformer_lr -from sequence_models.constants import STOP -from dayhoff.constants import UL_ALPHABET_PLUS, END_AL, END_UL +from sequence_models.constants import START, STOP, CAN_AAS, SEP, GAP, MSA_PAD +from dayhoff.constants import UL_ALPHABET_PLUS, END_AL, END_UL, START_AL, START_UL from dayhoff.utils import (load_msa_config_and_model, get_latest_dcp_checkpoint_path, load_checkpoint, seed_everything) @@ -34,16 +30,13 @@ def generate(args: argparse.Namespace) -> None: #print(f"Starting job on rank {RANK} with local rank {LOCAL_RANK} and world size {WORLD_SIZE}") - seed_everything(args.random_seed) - dist.init_process_group(backend="nccl") - #if args.verbose: - #print("Initializing model...", RANK) + seed_everything(args.random_seed + RANK) # load model parameters from config file - config, tokenizer, model, block = load_msa_config_and_model(os.path.join(args.in_fpath, "config.json")) - #if args.verbose: - #print("Done initializing model.", RANK) - warmup_steps = max(config["warmup_steps"], 1) + config, tokenizer, model, block = load_msa_config_and_model(os.path.join(args.in_fpath, "config.json"), + use_flash_attention_2=True) + if args.verbose: + print("Done initializing model.", RANK) # Load model and optimizer onto CPU initial_epoch, total_steps, total_tokens, total_seqs, _ = load_checkpoint( @@ -51,57 +44,101 @@ def generate(args: argparse.Namespace) -> None: ) # Move only model to GPU model = model.to(DEVICE) + model = model.to(torch.bfloat16) + 
all_tokens = list(range(40)) + allowed_tokens = [UL_ALPHABET_PLUS.index(aa) for aa in CAN_AAS] if args.task == "sequence": if args.start_rev: - start = tokenizer.stop_id - stop = tokenizer.start_id + start_seq = tokenizer.stop_id + eos_id = int(tokenizer.start_id) else: - start = tokenizer.start_id - stop = tokenizer.stop_id + start_seq = tokenizer.start_id + eos_id = int(tokenizer.stop_id) max_len = config["max_len"] - elif args.task == "msa": - start = tokenizer.start_id - stop = tokenizer.tokenize(END_AL) + elif args.task == "gap_no_query": + if args.start_rev: + start_seq = UL_ALPHABET_PLUS.index(END_AL) + eos_id = UL_ALPHABET_PLUS.index(START_AL) + else: + start_seq = UL_ALPHABET_PLUS.index(START_AL) + eos_id = UL_ALPHABET_PLUS.index(END_AL) + max_len = config["n_sequences"] * config["max_seq_len"] + elif args.task == "indel_no_query": + if args.start_rev: + start_seq = UL_ALPHABET_PLUS.index(END_UL) + eos_id = UL_ALPHABET_PLUS.index(START_UL) + else: + start_seq = UL_ALPHABET_PLUS.index(START_UL) + eos_id = UL_ALPHABET_PLUS.index(END_UL) + max_len = config["n_sequences"] * config["max_seq_len"] + elif args.task == "gap_query": + max_len = config["n_sequences"] * config["max_seq_len"] + if args.start_rev: + start_seq = UL_ALPHABET_PLUS.index(STOP) + eos_id = UL_ALPHABET_PLUS.index(START_AL) + all_tokens += [UL_ALPHABET_PLUS.index(START), UL_ALPHABET_PLUS.index(END_AL)] + else: + start_seq = UL_ALPHABET_PLUS.index(START) + eos_id = UL_ALPHABET_PLUS.index(END_AL) + all_tokens += [UL_ALPHABET_PLUS.index(STOP), UL_ALPHABET_PLUS.index(START_AL)] + elif args.task == "indel_query": max_len = config["n_sequences"] * config["max_seq_len"] - - untokenized_out = [] - - for s in tqdm(range(args.n_generations)): - if args.verbose: - print(UL_ALPHABET_PLUS) - print(tokenizer.a_to_i) - print(tokenizer.i_to_a) - # Start from START token - batch_size = 1 - sample = torch.full((batch_size, 1), start, dtype=torch.long).to(DEVICE) - - # Iterate over each residue until STOP or max length - reach_stop = False # initialize - for i in tqdm(range(max_len)): - if reach_stop == False: # Add residues until it predicts STOP token or hits max seq len - prediction = model.inference(sample) - p = prediction[:, -1, : len(UL_ALPHABET_PLUS)] # predict next token - p = torch.nn.functional.softmax(p / args.temp, dim=1) # exp - p_sample = torch.multinomial(p, num_samples=1).to(DEVICE) - sample = torch.cat((sample, p_sample), dim=1) - if args.verbose: - print(tokenizer.untokenize(sample[0])) - if p_sample == stop: - reach_stop = True - else: - break - # print(sample) - untokenized = tokenizer.untokenize(sample[0]) - print("final sequence: ", untokenized) if args.start_rev: - untokenized_out.append(untokenized[::-1]) # append forward sequence - # print("fixed", untokenized[::-1]) + start_seq = UL_ALPHABET_PLUS.index(STOP) + eos_id = UL_ALPHABET_PLUS.index(START_UL) + all_tokens += [UL_ALPHABET_PLUS.index(START), UL_ALPHABET_PLUS.index(END_UL)] else: - untokenized_out.append(untokenized) + start_seq = UL_ALPHABET_PLUS.index(START) + eos_id = UL_ALPHABET_PLUS.index(END_UL) + all_tokens += [UL_ALPHABET_PLUS.index(STOP), UL_ALPHABET_PLUS.index(START_UL)] + if "gap" in args.task or "indel" in args.task: + allowed_tokens += [UL_ALPHABET_PLUS.index(SEP)] + if "gap" in args.task: + allowed_tokens += [UL_ALPHABET_PLUS.index(GAP)] + allowed_tokens += [eos_id] + seps = [SEP, START, STOP, END_UL, START_UL, END_AL, START_AL] + seps_regex = "|".join(seps) + start = torch.tensor([[start_seq]]).to(DEVICE) + start = 
torch.repeat_interleave(start, args.batch_size, dim=0) + model.module.generation_config.eos_token_id = eos_id + sup = SuppressTokensLogitsProcessor([t for t in all_tokens if not t in allowed_tokens], device=DEVICE) + if args.start_rev: + task = args.task + ".rev" + else: + task = args.task + ".fwd" + out_dir = os.path.join(args.out_fpath, args.model_name + '_' + str(total_steps) + "_" + task + '_t%.1f' %args.temp) + if RANK == 0: + os.makedirs(out_dir, exist_ok=True) + for s in tqdm(range(args.n_generations // args.batch_size)): + generated = model.module.generate(start, do_sample=True, logits_processor=[sup], + temperature=args.temp, num_beams=1, max_new_tokens=max_len, + use_cache=True) + untokenized = [tokenizer.untokenize(g) for g in generated] if args.task == "sequence": - with open(args.out_fpath + "/generated_samples.fasta", "a") as f: - f.write(">3BCOOLED_SEQUENCE_" + str(s) + "\n" + str(untokenized[1:-1]) + "\n") - + for n, unt in enumerate(untokenized): + n_gen = s * args.batch_size + n + print(unt, flush=True) + with open(os.path.join(out_dir, 'rank%d.fasta' %RANK), "a") as f: + f.write(">%d_%d\n" %(RANK, n_gen)) + if args.start_rev: + unt = unt[::-1] + f.write(unt.replace(START, "").replace(STOP, "").replace(MSA_PAD, "") + "\n") + else: + for n, unt in enumerate(untokenized): + n_gen = s * args.batch_size + n + with open(os.path.join(out_dir, 'rank%d_%d.fasta' %(RANK, n_gen)), "w") as f: + unt = unt.replace(MSA_PAD, "")[1:-1] # Strip out whatever stop and start + # Replace all things in the middle with new lines + for sep in seps: + unt = unt.replace(sep, " ") + unt = unt.split() + for i, seq in enumerate(unt): + f.write(">%d\n" %i) + if args.start_rev: + seq = seq[::-1] + f.write(seq + "\n") + print(">%d" %i) + print(seq, flush=True) @@ -109,6 +146,7 @@ def main(): parser = argparse.ArgumentParser() parser.add_argument("in_fpath", type=str) # location of checkpoint parser.add_argument("out_fpath", type=str) # location to write to + parser.add_argument("model_name", type=str) parser.add_argument("--verbose", action="store_true") parser.add_argument("--checkpoint_step", type=int, default=-1) parser.add_argument("--n_generations", type=int, default=100) @@ -116,7 +154,14 @@ def main(): parser.add_argument("--temp", type=float, default=1.0) # parser.add_argument("--random_seed", type=int, default=0) # parser.add_argument("--start_rev", action="store_true") + parser.add_argument("--dir", type=str, default="") + parser.add_argument("--batch_size", type=int, default=1) + args = parser.parse_args() + if args.dir == "fwd": + args.start_rev = False + elif args.dir == "rev": + args.start_rev = True generate(args) From b8d3358f2c7f89ca140b8a8bbc834a4104003e5d Mon Sep 17 00:00:00 2001 From: Kevin Kaichuang Yang Date: Wed, 8 Jan 2025 16:22:31 -0500 Subject: [PATCH 5/6] Cleanup utilities and a bunch of analysis code. 
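analysis/fpd.py adds two distribution-distance helpers for comparing embedding sets:
mmd_rbf (maximum mean discrepancy under an RBF kernel) and calculate_fid (Fréchet
distance computed from means and covariances). A minimal sketch of the intended use;
note that fpd.py runs its analysis at module scope, so the helpers would need to be
lifted into an importable module first (the module name below is hypothetical), and
the arrays here are random stand-ins for real embeddings:

    import numpy as np
    from dayhoff.analysis_utils import mmd_rbf, calculate_fid  # hypothetical module

    rng = np.random.default_rng(0)
    gen = rng.standard_normal((1000, 1024))    # generated-sequence embeddings
    nat = rng.standard_normal((1000, 1024))    # natural-sequence embeddings
    mmd = mmd_rbf(gen, nat, gamma=1e-3) * 100  # gamma and x100 scaling as in fpd.py
    fd = calculate_fid(gen, nat)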
--- analysis/embed.py | 33 ++++ analysis/extract_test_fastas.py | 86 +++++++++ analysis/fpd.py | 214 ++++++++++++++++++++++ analysis/plot_valid.py | 308 ++++++++++++++++++++++++++++++++ analysis/plot_zs.py | 66 +++++++ analysis/zeroshot.py | 245 +++++++++++++++++++++++++ dayhoff/collators.py | 8 +- dayhoff/datasets.py | 37 ++++ src/cgenerate.py | 205 +++++++++++++++++++++ src/generate.py | 8 +- src/valid_position.py | 191 ++++++++++++++++++++ 11 files changed, 1398 insertions(+), 3 deletions(-) create mode 100644 analysis/embed.py create mode 100644 analysis/extract_test_fastas.py create mode 100644 analysis/fpd.py create mode 100644 analysis/plot_valid.py create mode 100644 analysis/plot_zs.py create mode 100644 analysis/zeroshot.py create mode 100644 src/cgenerate.py create mode 100644 src/valid_position.py diff --git a/analysis/embed.py b/analysis/embed.py new file mode 100644 index 0000000..7343201 --- /dev/null +++ b/analysis/embed.py @@ -0,0 +1,33 @@ +import os +import subprocess + + + +# in_dir = '/home/kevyan/generations/sequences/' +# out_dir = '/home/kevyan/generations/generations_to_embed/' +# directories = os.listdir(in_dir) +# for directory in directories: +# if 'sequence' in directory: +# print(directory) +# p = subprocess.run('cat ' + in_dir + directory + '/*' + '>' + out_dir + directory + '.fasta', shell=True) + +# in_dir = '/home/kevyan/generations/generations_to_embed/' +# out_dir = '/home/kevyan/generations/proteinfer/' +# fastas = os.listdir(in_dir) +# for fasta in fastas: +# p = subprocess.run('python /home/kevyan/src/proteinfertorch/bin/get_embeddings.py --data-path %s --weights-dir samirchar/proteinfertorch-go-random-13731645 --num-embedding-partitions 1 --output-dir ~/generations/embeddings/%s/' %(in_dir + fasta, fasta[:-6]), shell=True) +# +# for fasta in fastas: +# p = subprocess.run('python /home/kevyan/src/ProtTrans/Embedding/prott5_embedder.py --input %s --output ~/generations/protbert/%s.h5 --model ProtBert-BFD --per_protein 1' %(in_dir + fasta, fasta[:-6]), shell=True) + +in_dir = '/home/kevyan/generations/natural_sequences/' +out_dir = '/home/kevyan/generations/proteinfer/' +fastas = os.listdir(in_dir) +# for fasta in fastas: +# p = subprocess.run('python /home/kevyan/src/proteinfertorch/bin/get_embeddings.py --data-path %s --weights-dir samirchar/proteinfertorch-go-random-13731645 --num-embedding-partitions 1 --output-dir ~/generations/proteinfer/%s/' %(in_dir + fasta, fasta[:-6]), shell=True) + +for fasta in fastas: + if 'gigaref' not in fasta: + continue + print(fasta) + p = subprocess.run('python /home/kevyan/src/ProtTrans/Embedding/prott5_embedder.py --input %s --output ~/generations/protbert/%s.h5 --model ProtBert-BFD --per_protein 1' %(in_dir + fasta, fasta[:-6]), shell=True) \ No newline at end of file diff --git a/analysis/extract_test_fastas.py b/analysis/extract_test_fastas.py new file mode 100644 index 0000000..2327932 --- /dev/null +++ b/analysis/extract_test_fastas.py @@ -0,0 +1,86 @@ +import datetime +import os +from tqdm import tqdm + + +import numpy as np + +import torch + +from sequence_models.constants import OTHER_AAS, AMB_AAS +from dayhoff.utils import seed_everything +from dayhoff.datasets import UniRefDataset + + +# default to a single-GPU setup if not present +RANK = int(os.environ["RANK"]) +WORLD_SIZE = int(os.environ["WORLD_SIZE"]) +DEVICE = torch.device(f"cuda:{RANK}") + + +seed_everything(0) + + +def generate() -> None: + data_seq_dir = '/mnt/data/protein/' + data_name = 'uniref50_202401' + split_names = ['valid', 'train', 'test', 
'rtest'] + n = 10000 + for split_name in split_names: + print(data_name, split_name, datetime.datetime.now(), flush=True) + ds_train = UniRefDataset(os.path.join(data_seq_dir, data_name + '/'), split_name, + max_len=2048) + with open(os.path.join('/mnt/checkpoints/evodiff/generations/', data_name + "_" + split_name + "_10k.fasta"), 'w') as f: + idx = np.arange(len(ds_train)) + np.random.shuffle(idx) + successes = 0 + i = -1 + with tqdm(total=n) as pbar: + while successes < n: + i += 1 + seq = ds_train[idx[i]][0] + for aa in OTHER_AAS + AMB_AAS: + if aa in seq: + break + else: + f.write(">%d\n" %i) + f.write(seq + "\n") + successes += 1 + pbar.update(1) + + data_name = 'gigaref' + split_names = ['train', 'test'] + for split_name in split_names: + print(data_name, split_name, datetime.datetime.now(), flush=True) + ds_train = UniRefDataset(data_seq_dir + data_name + '/', split_name, + max_len=2048, split_file=data_seq_dir + data_name + '/' + 'no_singletons/splits.json') + with open(os.path.join('/mnt/checkpoints/evodiff/generations/', data_name + "_" + split_name + "_10k.fasta"), 'w') as f: + idx = np.arange(len(ds_train)) + np.random.shuffle(idx) + successes = 0 + i = -1 + with tqdm(total=n) as pbar: + while successes < n: + i += 1 + seq = ds_train[idx[i]][0] + for aa in OTHER_AAS +AMB_AAS: + if aa in seq: + break + else: + f.write(">%d\n" %i) + f.write(seq + "\n") + successes += 1 + pbar.update(1) + + + + + + +def main(): + if RANK == 0: + generate() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/analysis/fpd.py b/analysis/fpd.py new file mode 100644 index 0000000..c03f978 --- /dev/null +++ b/analysis/fpd.py @@ -0,0 +1,214 @@ +import os +from tqdm import tqdm +import h5py +import numpy as np +import pandas as pd +from scipy import linalg +import torch +from sklearn import metrics +import matplotlib.pyplot as plt +import seaborn as sns + +sns.set_style('white') + +def mmd_rbf(X, Y, gamma=1.0): + """MMD using rbf (gaussian) kernel (i.e., k(x,y) = exp(-gamma * ||x-y||^2 / 2)) + + Arguments: + X {[n_sample1, dim]} -- [X matrix] + Y {[n_sample2, dim]} -- [Y matrix] + + Keyword Arguments: + gamma {float} -- [kernel parameter] (default: {1.0}) + + Returns: + [scalar] -- [MMD value] + """ + XX = metrics.pairwise.rbf_kernel(X, X, gamma) + YY = metrics.pairwise.rbf_kernel(Y, Y, gamma) + XY = metrics.pairwise.rbf_kernel(X, Y, gamma) + return XX.mean() + YY.mean() - 2 * XY.mean() + + +def calculate_fid(act1, act2, eps=1e-6): + """calculate frechet inception distance""" + # calculate mean and covariance statistics + mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False) + mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False) + # calculate sum squared difference between means + ssdiff = np.sum((mu1 - mu2) ** 2.0) + # calculate sqrt of product between cov + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ( + "fid calculation produces singular product; " + "adding %s to diagonal of cov estimates" + ) % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + # check and correct imaginary numbers from sqrt + if np.iscomplexobj(covmean): + covmean = covmean.real + # calculate score + fid = ssdiff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean) + return fid + +# Baselines +natural_sets = ['uniref_train', 'uniref_valid', 'gigaref_train', 'gigaref_test'] +# natural_sets = ['train_GO', 'test_GO'] +natural_files = { + 
'uniref_train': 'uniref50_202401_train_10k', + 'uniref_test': 'uniref50_202401_test_10k', + 'uniref_valid': 'uniref50_202401_valid_10k', + 'uniref_rtest': 'uniref50_202401_rtest_10k', + 'gigaref_train': 'gigaref_train_10k', + 'gigaref_test': 'gigaref_test_10k', + 'train_GO': 'train_GO', + 'test_GO': 'test_10000_GO' +} +embedding_dir = '/home/kevyan/generations/proteinfer/' +embedding_dict = {s: torch.load(embedding_dir + natural_files[s] + '/partition_0.pt').numpy() + for s in natural_sets} +gamma = 1e-3 +mult = 100 +mmd_dict = {} +fpd_dict = {} +for i, s in enumerate(natural_sets): + ei = embedding_dict[s] + for j, s2 in enumerate(natural_sets): + if i > j: + ej = embedding_dict[s2] + mmd = mmd_rbf(ei[:], ej[:], gamma=gamma) * mult + fpd = calculate_fid(ei[:], ej[:], eps=1e-6) + print(s, s2, mmd, fpd) + mmd_dict[s + ':' + s2] = mmd + fpd_dict[s + ':' + s2] = fpd + + +pb_embedding_dir = '/home/kevyan/generations/protbert/' +pb_embedding_dict = {} +for s in natural_sets: + fn = os.path.join(pb_embedding_dir, natural_files[s] + '.h5') + f = h5py.File(fn, 'r') + pb_embedding_dict[s] = np.array([f[k] for k in f.keys()]) +pb_gamma = 1 +mult = 100 +pb_mmd_dict = {} +pb_fpd_dict = {} +for i, s in enumerate(natural_sets): + ei = pb_embedding_dict[s] + for j, s2 in enumerate(natural_sets): + if i > j: + ej = pb_embedding_dict[s2] + mmd = mmd_rbf(ei[:], ej[:], gamma=pb_gamma) * mult + fpd = calculate_fid(ei[:], ej[:], eps=1e-6) + print(s, s2, mmd, fpd) + pb_mmd_dict[s + ':' + s2] = mmd + pb_fpd_dict[s + ':' + s2] = fpd + +models = os.listdir(embedding_dir) +models = [m for m in models if 'jamba' in m] +model_name = { + 'jamba-3b-indel-gigaclust-120k-2': '3b-msa-gigaclust', + 'jamba-3b-cooldown': '3b-msa-uniref90-cooldown', + 'jamba-3b-cooldown7': '3b-msa-uniref90-cooldown', + 'jamba-170m-10mnovelty-36w': '170m-1novelty', + 'jamba-170m-seq-36w': '170m-uniref50', + 'jamba-170m-10mrmsd-36w': '170m-rmsd', + 'jamba-170m-10mbothfilter-36w': '170m-bothfilter', + 'jamba-3b-seq-sam-biar-fsdp-tok90k': '3b-uniref90', + 'jamba-170m-10mnofilter-36w': '170m-nofilter', + 'jamba-170m-seqsam-36w': '170m-uniref90', + 'jamba-170m-gigaclust-36w': '170m-gigaclust' +} +df = pd.DataFrame(columns=[ + 'name', + 'direction', + 'temperature', + 'step', + 'proteinfer_mmd_to_uniref', + 'proterinfer_mmd_to_gigaref', + 'protbert_mmd_to_uniref', + 'protbert_mmd_to_gigaref', + 'proteinfer_fd_to_uniref', + 'proteinfer_fd_to_gigaref', + 'protbert_fd_to_uniref', + 'protbert_fd_to_gigaref', +]) +for i, m in tqdm(enumerate(models)): + # d = m.split('_') + # df.loc[i, 'name'] = model_name[d[0]] + # df.loc[i, 'step'] = int(d[1]) + # df.loc[i, 'direction'] = d[2].split('.')[1] + # df.loc[i, 'temperature'] = float(d[3][1:]) + # emb = torch.load(embedding_dir + m + '/partition_0.pt').numpy() + # if np.isnan(emb).any(): + # emb = emb[np.isnan(emb).sum(axis=1) == 0] + # df.loc[i, 'proteinfer_mmd_to_uniref'] = mmd_rbf(emb, embedding_dict['uniref_valid'], gamma=gamma) * mult + # df.loc[i, 'proteinfer_mmd_to_gigaref'] = mmd_rbf(emb, embedding_dict['gigaref_test'], gamma=gamma) * mult + # df.loc[i, 'proteinfer_fd_to_uniref'] = calculate_fid(emb, embedding_dict['uniref_valid']) + # df.loc[i, 'proteinfer_fd_to_gigaref'] = calculate_fid(emb, embedding_dict['gigaref_test']) + emb = h5py.File(pb_embedding_dir + '/' + m + '.h5') + emb = np.array([emb[k] for k in emb.keys()]) + if np.isnan(emb).any(): + emb = emb[np.isnan(emb).sum(axis=1) == 0] + df.loc[i, 'protbert_mmd_to_uniref'] = mmd_rbf(emb, pb_embedding_dict['uniref_valid'], gamma=pb_gamma) * mult 
+ df.loc[i, 'protbert_mmd_to_gigaref'] = mmd_rbf(emb, pb_embedding_dict['gigaref_test'], gamma=pb_gamma) * mult + df.loc[i, 'protbert_fd_to_uniref'] = calculate_fid(emb, pb_embedding_dict['uniref_valid']) + df.loc[i, 'protbert_fd_to_gigaref'] = calculate_fid(emb, pb_embedding_dict['gigaref_test']) +df.to_csv('/home/kevyan/generations/fpd.csv', index=False) + +models_to_plot = ['3b-msa-gigaclust', '3b-msa-uniref90-cooldown', '3b-uniref', '170m-uniref90', '170m-gigaclust'] +uniref_hue_order = ['3b-uniref', '3b-msa-uniref90-cooldown', '170m-uniref90', '3b-msa-gigaclust', '170m-gigaclust'] +plot_me = df[(df['name'].isin(models_to_plot)) & (df['temperature'] > 0.8) & (df['temperature'] < 1.2)] + +fig, axs = plt.subplots(1, 2, figsize=(12, 4)) +_ = sns.lineplot(plot_me, x='temperature', y='proteinfer_mmd_to_uniref', + ax=axs[0], hue='name', style='direction', hue_order=uniref_hue_order, legend=False) +_ = axs[0].axhline(mmd_dict['uniref_valid:uniref_train'], color='gray') +_ = sns.lineplot(plot_me, x='temperature', y='proteinfer_mmd_to_gigaref', + ax=axs[1], hue='name', style='direction', legend=True, hue_order=uniref_hue_order) +_ = axs[1].axhline(mmd_dict['gigaref_test:gigaref_train'], color='gray') +# for ax in axs: +# _ = ax.set_ylim([-0.01, 0.6]) +_ = axs[1].legend(bbox_to_anchor=(1.1, 1.)) +_ = fig.savefig('/home/kevyan/generations/proteinfer_mmd.png', dpi=300, bbox_inches='tight') + +fig, axs = plt.subplots(1, 2, figsize=(12, 4)) +_ = sns.lineplot(plot_me, x='temperature', y='proteinfer_fd_to_uniref', + ax=axs[0], hue='name', style='direction', hue_order=uniref_hue_order, legend=False) +_ = axs[0].axhline(fpd_dict['uniref_valid:uniref_train'], color='gray') +_ = sns.lineplot(plot_me, x='temperature', y='proteinfer_fd_to_gigaref', + ax=axs[1], hue='name', style='direction', legend=True, hue_order=uniref_hue_order) +_ = axs[1].axhline(fpd_dict['gigaref_test:gigaref_train'], color='gray') +# +# for ax in axs: +# _ = ax.set_ylim([-0.01, 0.4]) +_ = axs[1].legend(bbox_to_anchor=(1.1, 1.)) +_ = fig.savefig('/home/kevyan/generations/proteinfer_fpd.png', dpi=300, bbox_inches='tight') + + +plot_me = df[(df['name'].isin(models_to_plot))] # & (df['temperature'] > 0.8) & (df['temperature'] < 1.2)] + +fig, axs = plt.subplots(1, 2, figsize=(12, 4)) +_ = sns.lineplot(plot_me, x='temperature', y='protbert_mmd_to_uniref', + ax=axs[0], hue='name', style='direction', hue_order=uniref_hue_order, legend=False) +_ = axs[0].axhline(pb_mmd_dict['uniref_valid:uniref_train'], color='gray') +_ = sns.lineplot(plot_me, x='temperature', y='protbert_mmd_to_gigaref', + ax=axs[1], hue='name', style='direction', legend=True, hue_order=uniref_hue_order) +_ = axs[1].axhline(pb_mmd_dict['gigaref_test:gigaref_train'], color='gray') +_ = axs[1].legend(bbox_to_anchor=(1.1, 1.)) +_ = fig.savefig('/home/kevyan/generations/protbert_mmd.png', dpi=300, bbox_inches='tight') +plot_me = df[(df['name'].isin(models_to_plot)) & (df['temperature'] > 0.7)] # & (df['temperature'] < 1.2)] + +fig, axs = plt.subplots(1, 2, figsize=(12, 4)) +_ = sns.lineplot(plot_me, x='temperature', y='protbert_fd_to_uniref', + ax=axs[0], hue='name', style='direction', hue_order=uniref_hue_order, legend=False) +_ = axs[0].axhline(pb_fpd_dict['uniref_valid:uniref_train'], color='gray') +_ = sns.lineplot(plot_me, x='temperature', y='protbert_fd_to_gigaref', + ax=axs[1], hue='name', style='direction', legend=True, hue_order=uniref_hue_order) +_ = axs[1].axhline(pb_fpd_dict['gigaref_test:gigaref_train'], color='gray') +_ = axs[1].legend(bbox_to_anchor=(1.1, 1.)) +_ 
= fig.savefig('/home/kevyan/generations/protbert_fpd.png', dpi=300, bbox_inches='tight') +df[df['name'] == '3b-msa-uniref90-cooldown'][['direction', 'temperature', 'protbert_fd_to_uniref']] \ No newline at end of file diff --git a/analysis/plot_valid.py b/analysis/plot_valid.py new file mode 100644 index 0000000..10d6bab --- /dev/null +++ b/analysis/plot_valid.py @@ -0,0 +1,308 @@ +import os + +import torch +import numpy as np +from matplotlib import pyplot as plt +import seaborn as sns +import pandas as pd + +_ = sns.set_style('white') + +world_size = 8 +models = ['jamba-3b-seq-sam-biar-fsdp-tok90k', 'jamba-170m-seqsam-36w'] +checkpoints = { + 'jamba-3b-seq-sam-biar-fsdp-tok90k': [10000, 25000, 43300 ], + 'jamba-170m-seqsam-36w': [10000, 40000, 76000] +} +directions = ['forward', 'reverse'] +out_fpath = '/home/kevyan/generations/' +pal = sns.color_palette() +for model in models: + for direction in directions: + fig1, ax1 = plt.subplots(1, 1) + for i, checkpoint in enumerate(checkpoints[model]): + ces = [] + for rank in range(world_size): + out_file = os.path.join(out_fpath, "valid_" + model + '_' + str(checkpoint) + "_" + "uniref" + "_" + direction + "_%d.pt" %rank) + dat = torch.load(out_file) + ces.append(dat["ce"]) + ces = torch.cat(ces) + ces = np.array(ces) + ces[ces == 0] = np.nan + ces = ces[:, :-1] + ce_by_pos = np.nanmean(ces, axis=0) + x = np.arange(len(ce_by_pos)) + std_by_pos = np.nanstd(ces, axis=0) + se_by_pos = std_by_pos / np.sqrt(np.isfinite(ces).sum(axis=0)) + _ = ax1.plot(x, ce_by_pos, "-", label=str(checkpoint), color=pal[i], alpha=0.7) + # _ = ax1.fill_between(x, ce_by_pos + se_by_pos, ce_by_pos - se_by_pos, alpha=0.3, color=pal[i]) + _ = ax1.set_xlabel('position') + _ = ax1.set_ylabel('cross-entropy') + _ = ax1.legend() + ax2 = ax1.twinx() + n = np.isfinite(ces).sum(axis=0) + _ = ax2.plot(x, n, "-", color="gray") + _ = ax2.set_ylabel('n') + _ = fig1.savefig(os.path.join(out_fpath, model + "_" + "uniref" + "_" + direction + ".png"), dpi=300, bbox_inches="tight") + + +models = ['jamba-170m-gigaclust-36w'] +checkpoints = { + 'jamba-170m-gigaclust-36w': [10000, 40000, 76000] +} +directions = ['forward'] +out_fpath = '/home/kevyan/generations/' +pal = sns.color_palette() +for model in models: + for direction in directions: + fig1, ax1 = plt.subplots(1, 1) + for i, checkpoint in enumerate(checkpoints[model]): + ces = [] + for rank in range(world_size): + out_file = os.path.join(out_fpath, "valid_" + model + '_' + str(checkpoint) + "_" + "gigaref" + "_" + direction + "_%d.pt" %rank) + dat = torch.load(out_file) + ces.append(dat["ce"]) + ces = torch.cat(ces) + ces = np.array(ces) + ces[ces == 0] = np.nan + ces = ces[:, :-1] + ce_by_pos = np.nanmean(ces, axis=0) + x = np.arange(len(ce_by_pos)) + std_by_pos = np.nanstd(ces, axis=0) + se_by_pos = std_by_pos / np.sqrt(np.isfinite(ces).sum(axis=0)) + _ = ax1.plot(x, ce_by_pos, "-", label=str(checkpoint), color=pal[i], alpha=0.7) + # _ = ax1.fill_between(x, ce_by_pos + se_by_pos, ce_by_pos - se_by_pos, alpha=0.3, color=pal[i]) + _ = ax1.set_xlabel('position') + _ = ax1.set_ylabel('cross-entropy') + _ = ax1.legend() + ax2 = ax1.twinx() + n = np.isfinite(ces).sum(axis=0) + _ = ax2.plot(x, n, "-", color="gray") + _ = ax2.set_ylabel('n') + _ = fig1.savefig(os.path.join(out_fpath, model + "_" + "gigaref" + "_" + direction + ".png"), dpi=300, bbox_inches="tight") + + +models = ['jamba-3b-indel-gigaclust-120k-2', 'jamba-3b-cooldown', 'jamba-3b-cooldown7'] +checkpoints = { + 'jamba-3b-indel-gigaclust-120k-2': [10000, 25000, 52000 ], + 
'jamba-3b-cooldown': [12000], + 'jamba-3b-cooldown7': [25000] +} +model_name = { + 'jamba-3b-indel-gigaclust-120k-2': 'indel-gigaclust', + 'jamba-3b-cooldown': 'indel-uniref-cooldown', + 'jamba-3b-cooldown7': 'indel-uniref-cooldown' +} +total_steps = 0 +direction = 'forward' +out_fpath = '/home/kevyan/generations/' +pal = sns.color_palette() +fig1, ax1 = plt.subplots(1, 1) +pal_counter = 0 +for model in models: + for i, checkpoint in enumerate(checkpoints[model]): + ces = [] + if model == 'jamba-3b-indel-gigaclust-120k-2': + total_steps = checkpoint + else: + total_steps += checkpoint + for rank in range(world_size): + out_file = os.path.join(out_fpath, "valid_" + model + '_' + str(checkpoint) + "_" + "uniref" + "_" + direction + "_%d.pt" %rank) + dat = torch.load(out_file) + ces.append(dat["ce"]) + ces = torch.cat(ces) + ces = np.array(ces) + ces[ces == 0] = np.nan + ces = ces[:, :-1] + ce_by_pos = np.nanmean(ces, axis=0) + x = np.arange(len(ce_by_pos)) + std_by_pos = np.nanstd(ces, axis=0) + se_by_pos = std_by_pos / np.sqrt(np.isfinite(ces).sum(axis=0)) + _ = ax1.plot(x, ce_by_pos, "-", label=model_name[model] + '_' + str(total_steps), color=pal[pal_counter], alpha=0.7) + pal_counter += 1 + # _ = ax1.fill_between(x, ce_by_pos + se_by_pos, ce_by_pos - se_by_pos, alpha=0.3, color=pal[i]) + _ = ax1.set_xlabel('position') + _ = ax1.set_ylabel('cross-entropy') + _ = ax1.legend() + _ = ax1.set_ylim([1.5, 2.9]) + ax2 = ax1.twinx() + n = np.isfinite(ces).sum(axis=0) + _ = ax2.plot(x, n, "-", color="gray") + _ = ax2.set_ylabel('n') + _ = fig1.savefig(os.path.join(out_fpath, "jamba-3b-combined_" + "uniref" + "_" + direction + ".png"), dpi=300, bbox_inches="tight") + + +models = ['jamba-3b-indel-gigaclust-120k-2', 'jamba-3b-cooldown', 'jamba-3b-cooldown7'] +checkpoints = { + 'jamba-3b-indel-gigaclust-120k-2': [10000, 25000, 52000 ], + 'jamba-3b-cooldown': [12000], + 'jamba-3b-cooldown7': [25000] +} +model_name = { + 'jamba-3b-indel-gigaclust-120k-2': 'indel-gigaclust', + 'jamba-3b-cooldown': 'indel-uniref-cooldown', + 'jamba-3b-cooldown7': 'indel-uniref-cooldown' +} +total_steps = 0 +direction = 'forward' +out_fpath = '/home/kevyan/generations/' +pal = sns.color_palette() +fig1, ax1 = plt.subplots(1, 1) +pal_counter = 0 +for model in models: + for i, checkpoint in enumerate(checkpoints[model]): + ces = [] + if model == 'jamba-3b-indel-gigaclust-120k-2': + total_steps = checkpoint + else: + total_steps += checkpoint + for rank in range(world_size): + out_file = os.path.join(out_fpath, "valid_" + model + '_' + str(checkpoint) + "_" + "gigaref" + "_" + direction + "_%d.pt" %rank) + dat = torch.load(out_file) + ces.append(dat["ce"]) + ces = torch.cat(ces) + ces = np.array(ces) + ces[ces == 0] = np.nan + ces = ces[:, :-1] + ce_by_pos = np.nanmean(ces, axis=0) + x = np.arange(len(ce_by_pos)) + std_by_pos = np.nanstd(ces, axis=0) + se_by_pos = std_by_pos / np.sqrt(np.isfinite(ces).sum(axis=0)) + _ = ax1.plot(x, ce_by_pos, "-", label=model_name[model] + '_' + str(total_steps), color=pal[pal_counter], alpha=0.7) + pal_counter += 1 + # _ = ax1.fill_between(x, ce_by_pos + se_by_pos, ce_by_pos - se_by_pos, alpha=0.3, color=pal[i]) + _ = ax1.set_xlabel('position') + _ = ax1.set_ylabel('cross-entropy') + _ = ax1.legend() + _ = ax1.set_ylim([0.5, 5]) + ax2 = ax1.twinx() + n = np.isfinite(ces).sum(axis=0) + _ = ax2.plot(x, n, "-", color="gray") + _ = ax2.set_ylabel('n') + _ = fig1.savefig(os.path.join(out_fpath, "jamba-3b-combined_" + "gigaref" + "_" + direction + ".png"), dpi=300, bbox_inches="tight") + + +model 
= 'jamba-3b-cooldown7' +checkpoints = { + 'jamba-3b-cooldown7': [25000], +} +model_name = { + 'jamba-3b-cooldown7': 'Dayhoff-3B' +} +tasks = ["indel", "gap"] +total_steps = 0 +direction = 'forward' +out_fpath = '/home/kevyan/generations/' +pal = sns.color_palette() +fig1, ax1 = plt.subplots(1, 1) +ax2 = ax1.twinx() +pal_counter = 0 +for task in tasks: + for i, checkpoint in enumerate(checkpoints[model]): + ces = [] + for rank in range(world_size): + out_file = os.path.join(out_fpath, "valid_" + model + '_' + str(checkpoint) + "_" + task + "_" + direction + "_%d.pt" %rank) + dat = torch.load(out_file) + ces.append(dat["ce"]) + ces = torch.cat(ces) + ces = np.array(ces) + ces[ces == 0] = np.nan + ces = ces[:, :65000] + ce_by_pos = np.nanmean(ces, axis=0) + x = np.arange(len(ce_by_pos)) + _ = ax1.plot(x, ce_by_pos, "-", label=task, color=pal[pal_counter], alpha=0.7) + pal_counter += 1 + # _ = ax1.fill_between(x, ce_by_pos + se_by_pos, ce_by_pos - se_by_pos, alpha=0.3, color=pal[i]) + n = np.isfinite(ces).sum(axis=0) + _ = ax2.plot(x, n, "-", color="gray", alpha=0.7) +_ = ax1.set_xlabel('position') +_ = ax1.set_ylabel('cross-entropy') +_ = ax1.legend() +_ = ax2.set_ylabel('n') +_ = fig1.savefig(os.path.join(out_fpath, "jamba-3b-combined" + "msas" + "_" + direction + ".png"), dpi=300, bbox_inches="tight") + +model = 'jamba-3b-cooldown7' +checkpoints = { + 'jamba-3b-cooldown7': [25000], +} +model_name = { + 'jamba-3b-cooldown7': 'Dayhoff-3B' +} +tasks = ["indel", "gap"] +total_steps = 0 +direction = 'forward' +out_fpath = '/home/kevyan/generations/' +pal = sns.color_palette() +fig1, ax1 = plt.subplots(1, 1) +ax2 = ax1.twinx() +pal_counter = 0 +for task in tasks: + for i, checkpoint in enumerate(checkpoints[model]): + ces = [] + for rank in range(world_size): + out_file = os.path.join(out_fpath, "valid_long_" + model + '_' + str(checkpoint) + "_" + task + "_" + direction + "_%d.pt" %rank) + dat = torch.load(out_file) + ces.append(dat["ce"]) + ces = torch.cat(ces) + ces = np.array(ces) + ces[ces == 0] = np.nan + ces = ces[:, :130000] + ce_by_pos = np.nanmean(ces, axis=0) + x = np.arange(len(ce_by_pos)) + _ = ax1.plot(x, ce_by_pos, "-", label=task, color=pal[pal_counter], alpha=0.7) + pal_counter += 1 + # _ = ax1.fill_between(x, ce_by_pos + se_by_pos, ce_by_pos - se_by_pos, alpha=0.3, color=pal[i]) + n = np.isfinite(ces).sum(axis=0) + _ = ax2.plot(x, n, "-", color="gray", alpha=0.7) +_ = ax1.set_xlabel('position') +_ = ax1.set_ylabel('cross-entropy') +_ = ax1.legend() +_ = ax2.set_ylabel('n') +_ = fig1.savefig(os.path.join(out_fpath, "jamba-3b-combined" + "_long_msas" + "_" + direction + ".png"), dpi=300, bbox_inches="tight") + + + +ces.shape +np.nanmax(ces) +np.nanmin(ces) +np.exp(14) / 1e6 +ces_by_seq = np.nanmean(ces[1:-1], axis=1) +ces_by_seq.shape +ces_by_seq.argmax() +ces_by_seq.min() + +world_size = 8 +models = ['jamba-3b-seq-sam-biar-fsdp-tok90k'] +checkpoints = [10000, 25000, 43300 ] +direction = 'forward' +out_fpath = '/home/kevyan/generations/' +pal = sns.color_palette() +for model in models: + fig1, ax1 = plt.subplots(1, 1) + ces_by_step = [] + for i, checkpoint in enumerate(checkpoints): + ces = [] + for rank in range(world_size): + out_file = os.path.join(out_fpath, "valid_" + model + '_' + str(checkpoint) + "_" + "uniref" + "_" + direction + "_%d.pt" %rank) + dat = torch.load(out_file) + ces.append(dat["ce"]) + ces = torch.cat(ces) + ces_by_step.append(ces) + ces_by_step = torch.stack(ces_by_step) + # normalize by position and step +ces_by_step = np.array(ces_by_step) +ces = 
ces_by_step[:, :, 1:-1].transpose(1, 2, 0).reshape(-1, 3) +m = np.isnan(ces) +m = m.sum(axis=1) == 0 +ces_sel = ces[m] +idx = np.random.choice(len(ces_sel), 200, replace=False) +fig, ax = plt.subplots(1, 1) +for i in idx: + _ = ax.plot(checkpoints, ces_sel[i], color=pal[3], alpha=0.2) +_ = ax.set_xlabel("step") +_ = ax.set_ylabel("cross-entropy") +_ = fig.savefig(os.path.join(out_fpath, model + "_" + "uniref" + "_" + direction + "_step.png"), dpi=300, bbox_inches="tight") + + + +ces[10] diff --git a/analysis/plot_zs.py b/analysis/plot_zs.py new file mode 100644 index 0000000..8dc1230 --- /dev/null +++ b/analysis/plot_zs.py @@ -0,0 +1,66 @@ +import os + +import torch +import numpy as np +from matplotlib import pyplot as plt +import seaborn as sns +import pandas as pd + +_ = sns.set_style('white') + +world_size = 8 +model_names = [ + # 'dayhoff-3b-msa-gigaref', + 'dayhoff-3b-msa-uniref90-cooldown', + 'dayhoff-170m-1novelty', + 'dayhoff-170m-uniref50', + 'dayhoff-170m-rmsd', + 'dayhoff-170m-bothfilter', + 'dayhoff-3b-uniref90', + 'dayhoff-170m-nofilter', + 'dayhoff-170m-uniref90', + 'dayhoff-170m-gigaref' +] +dmss = ['indels', 'substitutions'] + +out_fpath = '/home/kevyan/generations/proteingym/' +pal = sns.color_palette() +fig, ax = plt.subplots() +dfs = [] +for model in model_names: + for dms in dmss: + for rank in range(world_size): + df_path = os.path.join(out_fpath, dms, model + '_{}.csv'.format(rank)) + if os.path.exists(df_path): + df = pd.read_csv(df_path) + if 'seq_spearman' in df: + df['both_spearman'] = df['indel_spearman'] + fixed_indel = [] + fixed_seq = [] + for d in df['seq_spearman'].values: + split_d = d.split('.') + if '-' in split_d[1]: + fixed_seq.append(float('.'.join(split_d[:2])[:-2])) + if len(split_d) == 3: + fixed_indel.append(float('-0.' + split_d[2])) + else: + fixed_indel.append(np.nan) + else: + fixed_seq.append(float('.'.join(split_d[:2])[:-1])) + if len(split_d) == 3: + fixed_indel.append(float('0.' 
+ split_d[2])) + else: + fixed_indel.append(np.nan) + df['en_spearman'] = fixed_seq + df['indel_spearman'] = fixed_indel + df['model'] = model + df['dms'] = dms + dfs.append(df) +df = pd.concat(dfs, ignore_index=True) +df.groupby('model').agg({'en_spearman': [np.nanmean], 'indel_spearman': [np.nanmean], 'both_spearman': [np.nanmean]}) +df[df['dms'] == 'indels'].groupby('model').agg({'en_spearman': [np.nanmean], 'indel_spearman': [np.nanmean], 'both_spearman': [np.nanmean]}) +df[df['dms'] == 'substitutions'].groupby('model').agg({'en_spearman': [np.nanmean], 'indel_spearman': [np.nanmean], 'both_spearman': [np.nanmean]}) + + +df[df['model'] == 'dayhoff-3b-msa-uniref90-cooldown']['indel_spearman'] +df.columns \ No newline at end of file diff --git a/analysis/zeroshot.py b/analysis/zeroshot.py new file mode 100644 index 0000000..d32c214 --- /dev/null +++ b/analysis/zeroshot.py @@ -0,0 +1,245 @@ +import argparse +import functools +import os +from typing import Sequence, Tuple +from tqdm import tqdm + +import torch +import torch.distributed as dist + +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader + +from scipy.stats import spearmanr + +import pandas as pd +import numpy as np +from sequence_models.utils import parse_fasta +from dayhoff.datasets import ListDataset +from dayhoff.collators import MSAARCollator +from dayhoff.model import OTHER_METRICS_KEY +from dayhoff.utils import load_msa_config_and_model, seed_everything, load_checkpoint + + + +# default to a single-GPU setup if not present +RANK = int(os.environ["RANK"]) +LOCAL_RANK = int(os.environ["LOCAL_RANK"]) +WORLD_SIZE = int(os.environ["WORLD_SIZE"]) +OFFSET = int(os.getenv("OFFSET", 0)) +DEVICE = torch.device(f"cuda:{LOCAL_RANK + OFFSET}") + + +def is_amlt() -> bool: + return os.environ.get("AMLT_OUTPUT_DIR", None) is not None + + + + +def zero_shot( + dms_dir, + out_dir, + msa_dir, + model: nn.Module, + tokenizer, + args, +): + subst_files = os.listdir(dms_dir) + fw_collator = MSAARCollator(tokenizer, flip_prob=0.0) + bw_collator = MSAARCollator(tokenizer, flip_prob=1.0) + os.makedirs(out_dir, exist_ok=True) + summary_path = os.path.join(out_dir, args.model_name + '_%d.csv' %RANK) + if not os.path.exists(summary_path): + with open(summary_path, 'w') as f: + if args.no_seq: + f.write(','.join(['assay', 'indel_spearman']) + '\n') + elif not args.msa: + f.write(','.join(['assay', 'fw_spearman', 'bw_spearman', 'en_spearman']) + '\n') + else: + f.write(','.join(['assay', 'fw_spearman', 'bw_spearman', 'seq_spearman', + 'indel_spearman', 'en_spearman']) + '\n') + for j, file in enumerate(tqdm(subst_files)): + if j % WORLD_SIZE != RANK: + continue + assay_name = file.split('.csv')[0] + if args.no_seq: + df_out_file = os.path.join(out_dir, args.model_name + '_no_seq_' + assay_name + '.csv') + else: + df_out_file = os.path.join(out_dir, args.model_name + '_' + assay_name + '.csv') + if os.path.exists(df_out_file): + continue + df_in = pd.read_csv(os.path.join(dms_dir, file)) + df_out = pd.DataFrame() + if 'mutant' in df_in: + df_out['mutant'] = df_in['mutant'] + else: + df_out['mutated_sequence'] = df_in['mutated_sequence'] + df_out['DMS_score'] = df_in['DMS_score'] + df_out['assay'] = assay_name + sequences = list(df_in["mutated_sequence"]) + ds = ListDataset(sequences) + if not args.no_seq: + dl = DataLoader(ds, batch_size=1, collate_fn=fw_collator, num_workers=4, shuffle=False) + for i, batch in enumerate(dl): + src, tgt = batch + src = src.to(DEVICE) + tgt = tgt.to(DEVICE) + outputs = 
model(src, tgt)
+                with torch.no_grad():
+                    ce = outputs[OTHER_METRICS_KEY]['ce_loss'].item()
+                df_out.loc[i, args.model_name + '_fw_score'] = -ce
+            dl = DataLoader(ds, batch_size=1, collate_fn=bw_collator, num_workers=8, shuffle=False)
+            for i, batch in enumerate(dl):
+                src, tgt = batch
+                src = src.to(DEVICE)
+                tgt = tgt.to(DEVICE)
+                outputs = model(src, tgt)
+                with torch.no_grad():
+                    ce = outputs[OTHER_METRICS_KEY]['ce_loss'].item()
+                df_out.loc[i, args.model_name + '_bw_score'] = -ce
+            df_out[args.model_name + '_seq_score'] = (df_out[args.model_name + '_fw_score'] + df_out[
+                args.model_name + '_bw_score']) / 2
+            fw_spearman = spearmanr(df_out[args.model_name + '_fw_score'], df_out['DMS_score']).statistic
+            bw_spearman = spearmanr(df_out[args.model_name + '_bw_score'], df_out['DMS_score']).statistic
+            seq_spearman = spearmanr(df_out[args.model_name + '_seq_score'], df_out['DMS_score']).statistic
+            print(assay_name, fw_spearman, bw_spearman, seq_spearman)
+            with open(summary_path, 'a') as f:
+                f.write(','.join([assay_name, str(fw_spearman), str(bw_spearman), str(seq_spearman)]))
+        if not args.msa:
+            df_out[args.model_name + '_score'] = df_out[args.model_name + '_seq_score']
+            df_out.to_csv(df_out_file, index=False)
+            with open(summary_path, 'a') as f:
+                f.write('\n')
+        else:
+            msa_files = os.listdir(msa_dir)
+            protein_name = '_'.join(assay_name.split('_')[:2])
+            if protein_name == 'ANCSZ_Hobbs':
+                protein_name = 'ANCSZ'
+            for msa_file in msa_files:
+                if msa_file.startswith(protein_name):
+                    break
+            else:
+                print(protein_name)
+            seqs = parse_fasta(os.path.join(msa_dir, msa_file))
+            collator = MSAARCollator(tokenizer, flip_prob=0.0)
+            if not args.no_seq:
+                replicates = 4
+            else:
+                replicates = 1
+            msas = []
+            for rep in range(replicates):
+                msa_idx = np.random.choice(len(seqs) - 1, size=63) + 1
+                msa = [seqs[i].replace('-', '').replace('.', '').upper() for i in msa_idx]
+                msa_src, msa_tgt = collator([[None, msa]])
+                msa_src = msa_src.to(DEVICE)
+                msas.append(msa_src)
+            dl = DataLoader(ds, batch_size=1, collate_fn=fw_collator, num_workers=4, shuffle=False)
+            for i, batch in enumerate(dl):
+                src, tgt = batch
+                src = src.to(DEVICE)
+                tgt = tgt.to(DEVICE)
+                n, ell = src.shape
+                # Slice the target once, outside the replicate loop; rebinding
+                # tgt inside the loop would shrink it on every pass.
+                tgt_flat = tgt[0, 1:]
+                for rep, msa_src in enumerate(msas):
+                    combined_src = torch.cat([msa_src, src], dim=1)
+                    with torch.no_grad():
+                        outputs = model.module(combined_src)['logits'][0, -ell:-1]
+                    ce = torch.nn.functional.cross_entropy(outputs, tgt_flat).item()
+                    if rep == 0:
+                        df_out.loc[i, args.model_name + '_indel_score'] = 0.0  # create the column once before accumulating
+                    if replicates > 1:
+                        df_out.loc[i, args.model_name + '_indel_score%d' % rep] = -ce
+                    df_out.loc[i, args.model_name + '_indel_score'] += -ce / len(msas)
+            indel_spearman = spearmanr(df_out[args.model_name + '_indel_score'], df_out['DMS_score']).statistic
+            if not args.no_seq:
+                df_out[args.model_name + '_score'] = (df_out[args.model_name + '_seq_score'] + df_out[
+                    args.model_name + '_indel_score']) / 2
+                en_spearman = spearmanr(df_out[args.model_name + '_score'], df_out['DMS_score']).statistic
+                print(assay_name, indel_spearman, en_spearman)
+                with open(summary_path, 'a') as f:
+                    f.write(',' + ','.join([str(indel_spearman), str(en_spearman)]) + '\n')
+            else:
+                print(assay_name, indel_spearman)
+                with open(summary_path, 'a') as f:
+                    f.write(','.join([assay_name, str(indel_spearman)]) + '\n')
+            df_out.to_csv(df_out_file, index=False)
+
+
+def train(args: argparse.Namespace) -> None:
+    print(f"Starting job on rank {RANK} with local rank {LOCAL_RANK} and world size {WORLD_SIZE}")
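+    # Note: every rank is seeded with the same value, presumably because
+    # zero_shot shards work by assay index (j % WORLD_SIZE == RANK), so a
+    # shared seed keeps the random MSA subsampling reproducible across re-runs.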
seed_everything(0) + + dist.init_process_group(backend="nccl") + # get the config, tokenizer, and model + torch.cuda.set_device(LOCAL_RANK) + config, tokenizer, model, block = load_msa_config_and_model(os.path.join(args.model_path, "config.json"), + use_flash_attention_2=(not args.no_fa2)) + print("Done initializing model.", RANK) + + # Load model and optimizer onto CPU + initial_epoch, total_steps, total_tokens, total_seqs, _ = load_checkpoint( + model, None, None, args.model_path, args.checkpoint_step, rank=RANK + ) + # Move only model to GPU + model = model.to(DEVICE) + model = model.to(torch.bfloat16) + model = model.eval() + + padding_idx = tokenizer.pad_id # PROTEIN_ALPHABET.index(PAD) + print("Using {} as padding index".format(padding_idx)) + print("Using {} as masking index".format(tokenizer.mask_id)) + print(f"Model has {sum(p.numel() for p in model.parameters())} trainable parameters.") + + + # Get files + subst_dir = os.path.join(args.data_root, "DMS_ProteinGym_substitutions") + indel_dir = os.path.join(args.data_root, "DMS_ProteinGym_indels") + msa_dir = os.path.join(args.data_root, "DMS_msa_files") + zero_shot(indel_dir, os.path.join(args.out_fpath, 'indels'), msa_dir, model, tokenizer, args) + zero_shot(subst_dir, os.path.join(args.out_fpath, 'substitutions'), msa_dir, model, tokenizer, args) + + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("out_fpath", type=str) + parser.add_argument("data_root", type=str) + parser.add_argument("model_path", type=str) + parser.add_argument("model_name", type=str) + parser.add_argument("checkpoint_step", type=int) + parser.add_argument("--no_fa2", action="store_true") + parser.add_argument("--msa", action="store_true") + parser.add_argument("--no_seq", action="store_true") + + + model_name_dict = { + 'jamba-3b-indel-gigaclust-120k-2': 'dayhoff-3b-msa-gigaref', + 'jamba-3b-cooldown': 'dayhoff-3b-msa-uniref90-cooldown', + 'jamba-3b-cooldown7': 'dayhoff-3b-msa-uniref90-cooldown', + 'jamba-170m-10mnovelty-36w': 'dayhoff-170m-1novelty', + 'jamba-170m-seq-36w': 'dayhoff-170m-uniref50', + 'jamba-170m-10mrmsd-36w': 'dayhoff-170m-rmsd', + 'jamba-170m-10mbothfilter-36w': 'dayhoff-170m-bothfilter', + 'jamba-3b-seq-sam-biar-fsdp-tok90k': 'dayhoff-3b-uniref90', + 'jamba-170m-10mnofilter-36w': 'dayhoff-170m-nofilter', + 'jamba-170m-seqsam-36w': 'dayhoff-170m-uniref90', + 'jamba-170m-gigaclust-36w': 'dayhoff-170m-gigaref' + } + + args = parser.parse_args() + if args.model_name in model_name_dict: + args.model_name = model_name_dict[args.model_name] + train(args) + + +if __name__ == "__main__": + main() + + + + diff --git a/dayhoff/collators.py b/dayhoff/collators.py index ff7eda3..32e237c 100644 --- a/dayhoff/collators.py +++ b/dayhoff/collators.py @@ -321,7 +321,8 @@ def __init__( pad_to_multiple_of: Optional[int] = None, flip_prob: int = 0.5, query_last_prob: int = 0.5, - trim_to: int = None + trim_to: int = None, + trim_to2: int = None ) -> None: self.tokenizer = tokenizer self.pad_to_mult = pad_to_multiple_of @@ -341,6 +342,7 @@ def __init__( self.flip_prob = flip_prob self.query_last_prob = query_last_prob self.trim_to = trim_to + self.trim_to2 = trim_to2 def __call__(self, batch_msa: "list") -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -425,6 +427,10 @@ def __call__(self, batch_msa: "list") -> Tuple[torch.Tensor, torch.Tensor, torch if ell > 512 * 64: src = src[:, :512 * 64] tgt = tgt[:, :512 * 64] + if self.trim_to2 is not None: + n, ell = src.shape + src = src[:, :self.trim_to2] + tgt = tgt[:, :self.trim_to2] 
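+        # trim_to2 is a second, caller-specified cap applied after the fixed
+        # 512 * 64 token limit above; e.g. valid_position.py passes
+        # trim_to2=130000 to bound very long concatenated MSAs.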
return src, tgt diff --git a/dayhoff/datasets.py b/dayhoff/datasets.py index c2e6afb..84fe566 100644 --- a/dayhoff/datasets.py +++ b/dayhoff/datasets.py @@ -89,6 +89,19 @@ def msa_subsampling(sliced_msa, n_sequences, selection_type): return msa_sequences, random_idx # Returns aligned sequences and their indices +class ListDataset(Dataset): + + def __init__(self, data): + super().__init__() + self.data = data + + def __getitem__(self, item): + return (self.data[item], ) + + def __len__(self): + return len(self.data) + + class UniRefDataset(Dataset): """ Dataset that pulls from UniRef/Uniclust downloads. @@ -197,6 +210,30 @@ def __init__( self.indices = metadata.index.values.tolist() self.depths = metadata["depth"].values.tolist() self.lengths = metadata["length"].values.tolist() + if split == "rtest": + exclude = [ + "alignments_41/B8KZ16/merged.a3m", + "alignments_41/A0A1V6BYG1/merged.a3m", + "alignments_41/A3T2H5/merged.a3m", + "alignments_41/C0CU71/merged.a3m", + "alignments_41/U6L0S1/merged.a3m", + "alignments_41/A0A2D3P9P4/merged.a3m", + "alignments_41/C4YSB1/merged.a3m", + "alignments_41/W2Z7Y7/merged.a3m", + "alignments_41/A0A1G2B126/merged.a3m", + "alignments_41/A0A233HTX8/merged.a3m" + ] + keep_idx = [] + for i, fn in enumerate(self.filenames): + for exc in exclude: + if exc in fn: + break + else: + keep_idx.append(i) + self.filenames = [self.filenames[i] for i in keep_idx] + self.indices = [self.indices[i] for i in keep_idx] + self.depths = [self.depths[i] for i in keep_idx] + self.lengths = [self.lengths[i] for i in keep_idx] self.n_sequences = n_sequences self.max_seq_len = max_seq_len self.selection_type = selection_type diff --git a/src/cgenerate.py b/src/cgenerate.py new file mode 100644 index 0000000..6e2201b --- /dev/null +++ b/src/cgenerate.py @@ -0,0 +1,205 @@ +import argparse +import datetime +import json +import os +import random +from typing import Optional, Tuple +from tqdm import tqdm +import re + + +import numpy as np +from transformers import SuppressTokensLogitsProcessor + +import torch +from torch.utils.data import DataLoader, DistributedSampler + +from sequence_models.constants import START, STOP, CAN_AAS, SEP, GAP, MSA_PAD +from dayhoff.constants import UL_ALPHABET_PLUS, END_AL, END_UL, START_AL, START_UL +from dayhoff.utils import (load_msa_config_and_model, get_latest_dcp_checkpoint_path, + load_checkpoint, seed_everything) +from dayhoff.datasets import OpenProteinDataset +from dayhoff.collators import MSAARCollator + + +# default to a single-GPU setup if not present +RANK = int(os.environ["RANK"]) +WORLD_SIZE = int(os.environ["WORLD_SIZE"]) +DEVICE = torch.device(f"cuda:{RANK}") +print("device", DEVICE) + + +def get_msa_dataloader(config, tokenizer, args, task): + msa_data_dir = args.data_fpath + # load the dataset + num_replicas = WORLD_SIZE + sampler_rank = RANK + print("Preparing MSAs", flush=True) + no_query_frac = 0. + if "query" in task: + query_last_frac = 1.0 + else: + query_last_frac = 0. 
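+    # The task string is parsed by substring: "query" places the query
+    # sequence last, "rev" reverses the token order, and "indel" selects
+    # unaligned (gap-free) MSAs instead of aligned ones.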
+ if "rev" in task: + flip_prob = 1.0 + else: + flip_prob = 0.0 + if "indel" in task: + indel_frac = 1.0 + else: + indel_frac = 0.0 + ds_train = OpenProteinDataset( + msa_data_dir, + "rtest", + "max_hamming", + config["n_sequences"], + config["max_seq_len"], + gap_fraction=config["gap_fraction"], + is_amlt=True, + indel_frac=indel_frac, + no_query_frac=no_query_frac, + ) + + trim_to = config["msa_max_tokens"] + sampler = DistributedSampler(ds_train, num_replicas=num_replicas, rank=sampler_rank, shuffle=False) + num_workers = 8 + collater = MSAARCollator( + tokenizer=tokenizer, + pad_to_multiple_of=config["pad_to_multiple_of"], + query_last_prob=query_last_frac, + flip_prob=flip_prob, + trim_to=trim_to + ) + dl_train = DataLoader( + dataset=ds_train, + sampler=sampler, + collate_fn=collater, + num_workers=num_workers, + pin_memory=False, + batch_size=1 + ) + return ds_train, dl_train + + + +def generate(args: argparse.Namespace) -> None: + seed_everything(args.random_seed + RANK) + + # load model parameters from config file + config, tokenizer, model, block = load_msa_config_and_model(os.path.join(args.in_fpath, "config.json"), + use_flash_attention_2=True) + if args.verbose: + print("Done initializing model.", RANK) + + # Load model and optimizer onto CPU + initial_epoch, total_steps, total_tokens, total_seqs, _ = load_checkpoint( + model, None, None, args.in_fpath, args.checkpoint_step, rank=RANK + ) + # Move only model to GPU + model = model.to(DEVICE) + model = model.to(torch.bfloat16) + all_tokens = list(range(40)) + allowed_tokens = [UL_ALPHABET_PLUS.index(aa) for aa in CAN_AAS] + if "gap" in args.task: + allowed_tokens += [UL_ALPHABET_PLUS.index(GAP)] + if "query" in args.task: + if args.start_rev: + bos_id = UL_ALPHABET_PLUS.index(STOP) + eos_id = UL_ALPHABET_PLUS.index(START) + else: + bos_id = UL_ALPHABET_PLUS.index(START) + eos_id = UL_ALPHABET_PLUS.index(STOP) + elif "homologs" in args.task: + allowed_tokens += [UL_ALPHABET_PLUS.index(SEP)] + if "indel" in args.task: + if args.start_rev: + bos_id = UL_ALPHABET_PLUS.index(END_UL) + eos_id = UL_ALPHABET_PLUS.index(START_UL) + else: + bos_id = UL_ALPHABET_PLUS.index(START_UL) + eos_id = UL_ALPHABET_PLUS.index(END_UL) + elif "gap" in args.task: + if args.start_rev: + bos_id = UL_ALPHABET_PLUS.index(END_AL) + eos_id = UL_ALPHABET_PLUS.index(START_AL) + else: + bos_id = UL_ALPHABET_PLUS.index(START_AL) + eos_id = UL_ALPHABET_PLUS.index(END_AL) + else: + raise ValueError("Unknown task") + else: + raise ValueError("Unknown task") + max_len = config["n_sequences"] * config["max_seq_len"] + allowed_tokens += [eos_id] + seps = [SEP, START, STOP, END_UL, START_UL, END_AL, START_AL] + model.module.generation_config.eos_token_id = eos_id + sup = SuppressTokensLogitsProcessor([t for t in all_tokens if not t in allowed_tokens], device=DEVICE) + if args.start_rev: + task = args.task + ".rev" + else: + task = args.task + ".fwd" + out_dir = os.path.join(args.out_fpath, args.model_name + '_' + str(total_steps) + "_" + task + '_t%.1f' %args.temp) + if RANK == 0: + os.makedirs(out_dir, exist_ok=True) + ds, dl = get_msa_dataloader(config, tokenizer, args, task) + dl.sampler.set_epoch(0) + for i, batch in enumerate(tqdm(dl)): + filename = ".".join(ds.filenames[(i * WORLD_SIZE + RANK) % len(ds)].split("/")[-3:-1]) + ".fasta" + filename = os.path.join(out_dir, filename) + if os.path.exists(filename): + continue + print(filename, flush=True) + src, tgt = batch + src = src.to(DEVICE) + idx = torch.where(src == bos_id)[1] + if len(idx) == 0: + continue + 
prompt = src[:, :idx + 1] + generated = model.module.generate(prompt, do_sample=True, logits_processor=[sup], + temperature=args.temp, num_beams=1, max_new_tokens=max_len, + use_cache=True) + untokenized = tokenizer.untokenize(generated[0]) + print(untokenized) + for sep in seps: + untokenized = untokenized.replace(sep, " ") + untokenized = untokenized.split() + if "query" in args.task: + untokenized = untokenized[::-1] + print("\n".join(untokenized)) + with open(filename, "w") as f: + for i, seq in enumerate(untokenized): + if args.start_rev: + seq = seq[::-1] + f.write(">%d\n" %i) + f.write(seq + "\n") + if i == 0 and "query" in args.task: + f.write(">original_query\n") + f.write(tokenizer.untokenize(src[0, idx + 1:-1]) + "\n") + + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("in_fpath", type=str) # location of checkpoint + parser.add_argument("out_fpath", type=str) # location to write to + parser.add_argument("model_name", type=str) + parser.add_argument("data_fpath", type=str) # Location with MSAs + parser.add_argument("--verbose", action="store_true") + parser.add_argument("--checkpoint_step", type=int, default=-1) + parser.add_argument("--n_generations", type=int, default=1) + parser.add_argument("--task", type=str, default="query-indel") + parser.add_argument("--temp", type=float, default=1.0) # + parser.add_argument("--random_seed", type=int, default=0) # + parser.add_argument("--start_rev", action="store_true") + parser.add_argument("--dir", type=str, default="") + + args = parser.parse_args() + if args.dir == "fwd": + args.start_rev = False + elif args.dir == "rev": + args.start_rev = True + generate(args) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/generate.py b/src/generate.py index 4aaf484..0514c57 100644 --- a/src/generate.py +++ b/src/generate.py @@ -97,7 +97,6 @@ def generate(args: argparse.Namespace) -> None: allowed_tokens += [UL_ALPHABET_PLUS.index(GAP)] allowed_tokens += [eos_id] seps = [SEP, START, STOP, END_UL, START_UL, END_AL, START_AL] - seps_regex = "|".join(seps) start = torch.tensor([[start_seq]]).to(DEVICE) start = torch.repeat_interleave(start, args.batch_size, dim=0) model.module.generation_config.eos_token_id = eos_id @@ -109,6 +108,11 @@ def generate(args: argparse.Namespace) -> None: out_dir = os.path.join(args.out_fpath, args.model_name + '_' + str(total_steps) + "_" + task + '_t%.1f' %args.temp) if RANK == 0: os.makedirs(out_dir, exist_ok=True) + # if args.task == "sequence": + # # wipe the output file + # with open(os.path.join(out_dir, 'rank%d.fasta' % RANK), "w") as f: + # pass + for s in tqdm(range(args.n_generations // args.batch_size)): generated = model.module.generate(start, do_sample=True, logits_processor=[sup], temperature=args.temp, num_beams=1, max_new_tokens=max_len, @@ -117,7 +121,7 @@ def generate(args: argparse.Namespace) -> None: if args.task == "sequence": for n, unt in enumerate(untokenized): n_gen = s * args.batch_size + n - print(unt, flush=True) + # print(unt, flush=True) with open(os.path.join(out_dir, 'rank%d.fasta' %RANK), "a") as f: f.write(">%d_%d\n" %(RANK, n_gen)) if args.start_rev: diff --git a/src/valid_position.py b/src/valid_position.py new file mode 100644 index 0000000..bcf377f --- /dev/null +++ b/src/valid_position.py @@ -0,0 +1,191 @@ +import argparse +import datetime +import json +import os +import random +from typing import Optional, Tuple +from tqdm import tqdm +import re + + +import numpy as np +from transformers import 
SuppressTokensLogitsProcessor + +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader + +from sequence_models.constants import START, STOP, CAN_AAS, SEP, GAP, MSA_PAD +from dayhoff.constants import UL_ALPHABET_PLUS, END_AL, END_UL, START_AL, START_UL +from dayhoff.utils import (load_msa_config_and_model, get_latest_dcp_checkpoint_path, + load_checkpoint, seed_everything) +from dayhoff.collators import MSAARCollator +from dayhoff.datasets import UniRefDataset, OpenProteinDataset + +# default to a single-GPU setup if not present +RANK = int(os.environ["RANK"]) +#LOCAL_RANK = int(os.environ["LOCAL_RANK"]) +WORLD_SIZE = int(os.environ["WORLD_SIZE"]) +DEVICE = torch.device(f"cuda:{RANK}") +print("device", DEVICE) + + +def get_val_dataloader(config, tokenizer, args): + collator = MSAARCollator( + tokenizer=tokenizer, + pad_to_multiple_of=config["pad_to_multiple_of"], + query_last_prob=1.0, + flip_prob=args.flip_prob, + trim_to2=130000 + ) + batch_size = 256 + if args.task == 'gigaref': + data_seq_dir = args.data_root + 'gigaref/' + if args.train: + split = "train" + else: + split = "test" + ds_train = UniRefDataset(data_seq_dir, split, + max_len=config["max_len"], split_file=data_seq_dir + 'no_singletons/splits.json') + elif args.task in ['gap', 'indel']: + if args.task == 'gap': + indel_frac = 0.0 + elif args.task == 'indel': + indel_frac = 1.0 + ds_train = OpenProteinDataset( + args.data_root + "openfold/", + "valid", + "max_hamming", + 10000, + 768, + gap_fraction=config["gap_fraction"], + is_amlt=True, + indel_frac=indel_frac, + no_query_frac=config["no_query_frac"], + ) + batch_size = 2 + else: + if args.train: + split = "train" + else: + split = "valid" + data_seq_dir = args.data_root + 'uniref50_202401/' + ds_train = UniRefDataset(data_seq_dir, split, + max_len=config["max_len"]) + + sampler = torch.utils.data.DistributedSampler(ds_train, num_replicas=WORLD_SIZE, rank=RANK, shuffle=False) + dl = DataLoader( + dataset=ds_train, batch_size=batch_size, sampler=sampler, num_workers=8, collate_fn=collator, pin_memory=True + ) + + return dl + + +def validate(args: argparse.Namespace) -> None: + #print(f"Starting job on rank {RANK} with local rank {LOCAL_RANK} and world size {WORLD_SIZE}") + seed_everything(args.random_seed + RANK) + + # load model parameters from config file + config, tokenizer, model, block = load_msa_config_and_model(os.path.join(args.in_fpath, "config.json"), + use_flash_attention_2=True) + print("Done initializing model.", RANK) + + # Load model and optimizer onto CPU + initial_epoch, total_steps, total_tokens, total_seqs, _ = load_checkpoint( + model, None, None, args.in_fpath, args.checkpoint_step, rank=RANK + ) + print("Done loading model.", RANK, flush=True) + + # Move only model to GPU + model = model.to(DEVICE) + model = model.to(torch.bfloat16) + model = model.eval() + print("Getting dataloader.", RANK, flush=True) + dl = get_val_dataloader(config, tokenizer, args) + print("Done getting dataloader.", RANK, flush=True) + write_count = 0 + seqs = [] + ces = [] + for batch in tqdm(dl): + batch = [el.to(DEVICE) for el in batch] + src, tgt = batch + n, ell = src.shape + # step through model + with torch.no_grad(): + print(1, len(batch), src.shape, tgt.shape, flush=True) + outputs = model.module(src) + print(2, flush=True) + logits = outputs["logits"] + print(3, logits.shape, flush=True) + sliced_logits = logits[:, :-1, :].reshape(-1, logits.shape[-1]) + print(4, flush=True) + sliced_tgt = tgt[:, 1:].flatten() + print(5, 
sliced_logits.shape, sliced_tgt.shape, flush=True) + ce = F.cross_entropy(sliced_logits, sliced_tgt, reduction='none') + print(6, flush=True) + ce = ce.view(n, -1) + n, ell = ce.shape + if args.task not in ["gap", "indel"]: + diff = config['max_len'] + 1 - ell + else: + diff = 130000 - ell + ce = F.pad(ce.detach().cpu(), (0, diff)) + ces.append(ce) + print(7, flush=True) + for s in src: + seq = tokenizer.untokenize(s) + seq = "".join([i for i in seq if i.isalpha()]) + seqs.append(seq) + if args.train and len(seqs) > 1024000: + out_file = os.path.join(args.out_fpath, "train_" + args.model_name + '_' + str( + total_steps) + "_" + args.task + "_" + args.dir + "_%d_%d.pt" % (RANK, write_count)) + write_count += 1 + torch.save( + { + "sequence": seqs, + "ce": torch.cat(ces) + }, + out_file + ) + ces = [] + seqs = [] + if args.train: + out_file = os.path.join(args.out_fpath, "train_" + args.model_name + '_' + str( + total_steps) + "_" + args.task + "_" + args.dir + "_%d_%d.pt" % (RANK, write_count)) + else: + out_file = os.path.join(args.out_fpath, "valid_long_" + args.model_name + '_' + str( + total_steps) + "_" + args.task + "_" + args.dir + "_%d.pt" %RANK) + torch.save( + { + "sequence": seqs, + "ce": torch.cat(ces) + }, + out_file + ) + + + + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("in_fpath", type=str) # location of checkpoint + parser.add_argument("out_fpath", type=str) # location to write to + parser.add_argument("data_root", type=str) # location for data + parser.add_argument("model_name", type=str) + parser.add_argument("--checkpoint_step", type=int, default=-1) + parser.add_argument("--task", type=str, default="uniref") + parser.add_argument("--random_seed", type=int, default=0) # + parser.add_argument("--dir", type=str, default="forward") + parser.add_argument("--train", action="store_true") + args = parser.parse_args() + + if args.dir == "reverse": + args.flip_prob = 1.0 + else: + args.flip_prob = 0.0 + validate(args) + + +if __name__ == "__main__": + main() \ No newline at end of file From 0b2a4910a5e1106998737394ea74b2879b13d507 Mon Sep 17 00:00:00 2001 From: Kevin Kaichuang Yang Date: Wed, 12 Feb 2025 11:14:04 -0500 Subject: [PATCH 6/6] Make writing outputs more efficient. 
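
Buffer generated sequences in memory and flush them to a per-rank, per-seed
FASTA file in batches, instead of reopening the output file once per sequence.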
---
 src/generate.py | 23 +++++++++++++++--------
 1 file changed, 15 insertions(+), 8 deletions(-)

diff --git a/src/generate.py b/src/generate.py
index 0514c57..61348e4 100644
--- a/src/generate.py
+++ b/src/generate.py
@@ -15,7 +15,7 @@
 from sequence_models.constants import START, STOP, CAN_AAS, SEP, GAP, MSA_PAD
 from dayhoff.constants import UL_ALPHABET_PLUS, END_AL, END_UL, START_AL, START_UL
-from dayhoff.utils import (load_msa_config_and_model, get_latest_dcp_checkpoint_path,
+from dayhoff.utils import (load_msa_config_and_model,
                            load_checkpoint, seed_everything)
@@ -112,21 +112,28 @@ def generate(args: argparse.Namespace) -> None:
     # if args.task == "sequence":
     #     # wipe the output file
     #     with open(os.path.join(out_dir, 'rank%d.fasta' % RANK), "w") as f:
     #         pass
-
+    unwritten_generations = []
+    unwritten_ns = []
     for s in tqdm(range(args.n_generations // args.batch_size)):
         generated = model.module.generate(start, do_sample=True, logits_processor=[sup],
                                           temperature=args.temp, num_beams=1, max_new_tokens=max_len,
                                           use_cache=True)
         untokenized = [tokenizer.untokenize(g) for g in generated]
         if args.task == "sequence":
             for n, unt in enumerate(untokenized):
                 n_gen = s * args.batch_size + n
-                # print(unt, flush=True)
-                with open(os.path.join(out_dir, 'rank%d.fasta' %RANK), "a") as f:
-                    f.write(">%d_%d\n" %(RANK, n_gen))
-                    if args.start_rev:
-                        unt = unt[::-1]
-                    f.write(unt.replace(START, "").replace(STOP, "").replace(MSA_PAD, "") + "\n")
+                if args.start_rev:
+                    unt = unt[::-1]
+                unwritten_generations.append(unt)
+                unwritten_ns.append(n_gen)
+            # flush in batches; use >= rather than == so the flush still fires
+            # when batch_size does not divide 100 evenly
+            if len(unwritten_generations) >= 100:
+                with open(os.path.join(out_dir, 'rank%d_seed%d.fasta' % (RANK, args.random_seed)), "a") as f:
+                    for uwg, nwn in zip(unwritten_generations, unwritten_ns):
+                        f.write(">%d_%d_%d\n" % (RANK, nwn, args.random_seed))
+                        f.write(uwg.replace(START, "").replace(STOP, "").replace(MSA_PAD, "") + "\n")
+                unwritten_generations = []
+                unwritten_ns = []
         else:
             for n, unt in enumerate(untokenized):
                 n_gen = s * args.batch_size + n