"""Generate new sequences conditioned on the homologs found in MSA files.

For every MSA file, a prompt is built from the homolog sequences, a Hugging
Face causal language model samples one new sequence, and the result is
written to a single FASTA file in the output directory.
"""
# Reconstructed from the patch series for src/generate_from_homologs.py:
#   d44927b "generate from homolgs with hf"      (PATCH 1/2)
#   3608c33 "removed unnecessary comments"       (PATCH 2/2)
# The code below is the state of the file after both patches are applied.

import argparse
import os
from glob import glob

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from transformers import SuppressTokensLogitsProcessor

from dayhoff.constants import UL_ALPHABET_PLUS, START_AL, START_UL
from dayhoff.utils import seed_everything
from sequence_models.constants import CAN_AAS, SEP, GAP
from sequence_models.utils import parse_fasta


def _find_overlap(new_seq, seqs):
    """Return the index of the first MSA sequence that overlaps ``new_seq``.

    Two sequences "overlap" when one is a substring of the other. Returns
    ``None`` when no MSA sequence overlaps.

    Empty MSA entries are skipped: ``"" in new_seq`` is always true, so a
    single empty record would otherwise reject every candidate forever.
    An empty ``new_seq`` is still rejected, because it is a substring of any
    non-empty MSA sequence.
    """
    for k, seq in enumerate(seqs):
        if not seq:
            continue
        if new_seq in seq or seq in new_seq:
            return k
    return None


def generate(args: argparse.Namespace) -> None:
    """Sample one new sequence per MSA file and write them all to a FASTA file.

    Args:
        args: Parsed command-line options (see ``main``). Attributes read:
            ``random_seed``, ``device``, ``repo_id``, ``model``, ``no_fa2``,
            ``task``, ``out_dir``, ``temp``, ``min_p``, ``msas_dir``,
            ``include_pattern``, ``msa_file_names``, ``min_seqs_msa``,
            ``max_seqs_msa``, ``max_length`` and, optionally,
            ``max_attempts`` (``None`` or missing means retry without limit).

    Side effects:
        Creates ``args.out_dir`` if needed and (over)writes
        ``<model>_<task>_t<temp>_<min_p>_nom.fasta`` inside it. Progress and
        rejected samples are printed to stdout. ``args`` is not modified.
    """
    seed_everything(args.random_seed)
    set_seed(args.random_seed)
    device = torch.device("cuda:%d" % args.device)

    # Load model and tokenizer.
    # NOTE(review): `use_flash_attention_2` appears to be deprecated in recent
    # transformers releases in favour of `attn_implementation=` - confirm
    # against the pinned transformers version before changing it.
    model = AutoModelForCausalLM.from_pretrained(
        args.repo_id,
        subfolder=args.model,
        use_flash_attention_2=not args.no_fa2,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.repo_id, trust_remote_code=True)

    print("Done initializing model.")
    print("%d parameters" % (sum(p.numel() for p in model.parameters())))

    # Move the model to the GPU, then cast it to bfloat16.
    model = model.to(device)
    model = model.to(torch.bfloat16)

    is_gap_task = "gap" in args.task

    # Restrict sampling to the symbols in CAN_AAS, plus GAP for the gap task
    # or SEP for the sequence task, where SEP also acts as end-of-sequence.
    # NOTE(review): assumes token ids 0..39 cover the whole vocabulary -
    # TODO confirm against len(UL_ALPHABET_PLUS) / the model's vocab size.
    all_tokens = list(range(40))
    allowed_tokens = {UL_ALPHABET_PLUS.index(aa) for aa in CAN_AAS}
    if is_gap_task:
        allowed_tokens.add(UL_ALPHABET_PLUS.index(GAP))
    else:
        eos_id = UL_ALPHABET_PLUS.index(SEP)
        allowed_tokens.add(eos_id)
        model.generation_config.eos_token_id = eos_id
    sup = SuppressTokensLogitsProcessor(
        [t for t in all_tokens if t not in allowed_tokens],
        device=device,
    )

    os.makedirs(args.out_dir, exist_ok=True)
    out_file = os.path.join(
        args.out_dir,
        args.model + '_%s_t%.1f_%.2f_nom.fasta' % (args.task, args.temp, args.min_p),
    )

    if args.msa_file_names is not None:
        msa_files = [
            os.path.join(args.msas_dir, msa_file) for msa_file in args.msa_file_names
        ]
    else:
        # glob() returns files in arbitrary order; sort so that a fixed random
        # seed yields the same output regardless of the filesystem.
        msa_files = sorted(glob(os.path.join(args.msas_dir, args.include_pattern)))

    # Optional cap on rejected samples per MSA; None keeps the retry loop unbounded.
    max_attempts = getattr(args, "max_attempts", None)

    with open(out_file, 'w') as f:
        for msa_path in tqdm(msa_files):
            msa_filename = os.path.basename(msa_path)
            seqs = parse_fasta(msa_path)
            if len(seqs) < args.min_seqs_msa:
                continue

            if is_gap_task:
                tokenize_me = START_AL
                # The generation length follows the first MSA sequence.
                # Kept local so the caller's namespace is left untouched.
                max_length = len(seqs[0]) - 1
            else:
                tokenize_me = START_UL
                max_length = args.max_length
            # The first sequence is left out of the prompt
            # (presumably it is the query - verify against the MSA layout).
            tokenize_me += SEP.join(seqs[1:args.max_seqs_msa]) + SEP

            prompt_ids = tokenizer(
                [tokenize_me], return_tensors="pt", return_token_type_ids=False
            )['input_ids'].to(device)
            # One token more than max_length, exactly as the original code requested.
            max_new_tokens = max_length + 1

            attempt = 0
            while True:
                generated = model.generate(
                    prompt_ids,
                    do_sample=True,
                    logits_processor=[sup],
                    temperature=args.temp,
                    min_p=args.min_p,
                    num_beams=1,
                    max_new_tokens=max_new_tokens,
                    use_cache=True,
                )
                untokenized = tokenizer.batch_decode(generated, skip_special_tokens=False)
                # Gap task: the new sequence is whatever follows the last SEP.
                # Sequence task: generation normally ends with a SEP, so the
                # new sequence is the second-to-last SEP-delimited field. If
                # no closing SEP was produced, that field is a prompt homolog
                # and the overlap check below rejects it.
                if is_gap_task:
                    new_seq = untokenized[0].split(SEP)[-1]
                else:
                    new_seq = untokenized[0].split(SEP)[-2]

                k = _find_overlap(new_seq, seqs)
                if k is None:
                    # splitext copes with any extension, not only ".fasta".
                    f.write(">" + os.path.splitext(msa_filename)[0] + "\n")
                    f.write(new_seq + "\n")
                    f.flush()
                    break

                attempt += 1
                print(attempt, k, msa_filename, len(seqs), new_seq)
                if max_attempts is not None and attempt >= max_attempts:
                    print("Giving up on %s after %d attempts" % (msa_filename, attempt))
                    break


def main():
    """Parse the command line and run ``generate``.

    Raises:
        ValueError: if both a non-default ``--include-pattern`` and
            ``--msa-file-names`` are supplied.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True, help="The model name.")
    parser.add_argument("--msas-dir", type=str, required=True, help="The directory containing the MSAs.")
    parser.add_argument("--out-dir", type=str, required=True, help="The directory to save the output.")
    parser.add_argument("--task", type=str, required=True, choices=["gap", "sequence"], help="The task to perform.")
    parser.add_argument("--repo-id", type=str, default='microsoft/dayhoff', help="The repository ID of the model.")
    parser.add_argument("--include-pattern", type=str, default="*", help="glob pattern for MSA files to include from the directory.")
    parser.add_argument("--msa-file-names", nargs='*', type=str, default=None, help="List of MSA file names to include.")
    parser.add_argument("--max-length", type=int, default=768, help="The maximum length of the generated text.")
    parser.add_argument("--max-seqs-msa", type=int, default=57, help="The maximum number of sequences in an MSA.")
    parser.add_argument("--min-seqs-msa", type=int, default=5, help="The minimum number of sequences in an MSA.")
    parser.add_argument("--temp", type=float, default=1.0, help="The temperature for sampling.")
    parser.add_argument("--random-seed", type=int, default=0)
    parser.add_argument("--device", type=int, default=0)
    parser.add_argument("--no-fa2", action="store_true", help="Disable FlashAttention 2")
    parser.add_argument("--min-p", type=float, default=0.0, help="Minimum probability for sampling.")
    parser.add_argument("--max-attempts", type=int, default=None,
                        help="Maximum number of rejected samples per MSA before skipping it (default: no limit).")

    args = parser.parse_args()
    # Can only provide include pattern or msa file names, not both
    if args.include_pattern != "*" and args.msa_file_names is not None:
        raise ValueError("Provide either --include-pattern or --msa-file-names, not both.")
    generate(args)


if __name__ == "__main__":
    main()