<a href="https://colab.research.google.com/github/leemnj/iterative-dna-reconstruction/blob/main/gen-seq-and-emb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

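# Sequence generation + embeddings (sequential model loading)

This notebook fetches gene sequences from Ensembl and generates new sequences with each model. It then computes cross-embeddings, meaning every model's generated sequences are embedded by every model. Only one model is loaded at a time, to keep GPU memory usage bounded.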

In [1]:
%%capture
# Colab setup
!git clone https://github.com/leemnj/iterative-dna-reconstruction.git
%cd iterative-dna-reconstruction
!pip -q install -e .
!pip -q install biopython # Install Biopython
!pip list | grep idr

from pathlib import Path
RESULTS_DIR = Path("data")  # ÎòêÎäî "/content/drive/MyDrive/..."
RESULTS_DIR.mkdir(exist_ok=True)

fatal: destination path 'iterative-dna-reconstruction' already exists and is not an empty directory.
/content/iterative-dna-reconstruction
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
  Building editable for idr (pyproject.toml) ... [?25l[?25hdone
idr                                      0.1.0               /content/iterative-dna-reconstruction


In [2]:
# Standard library
import gc
from pathlib import Path

# Third-party
import pandas as pd
import torch
from tqdm import tqdm

# Project package, installed in editable mode by the setup cell.
from idr.utils import get_device, iter_models, load_model
from idr.preparation import (
    DEFAULT_DECODING_STRATEGIES,
    DEFAULT_GENES,
    fetch_gene_sequences,  # NOTE(review): not used by the cells below
    get_sequence_from_ensembl,
    sort_genes_by_length,  # NOTE(review): not used by the cells below
)

print('‚úÖ Î™®Îì† Î™®Îìà ÏûÑÌè¨Ìä∏ ÏôÑÎ£å')

‚úÖ Î™®Îì† Î™®Îìà ÏûÑÌè¨Ìä∏ ÏôÑÎ£å


In [None]:
import os

# The NCBI contact e-mail is read from the NCBI_EMAIL environment variable rather than
# hard-coded, so a personal address is not committed with the notebook. The cells here
# only print it, because sequences are fetched from Ensembl.
NCBI_EMAIL = os.environ.get('NCBI_EMAIL', '')
ITERATIONS = 30   # `steps` passed to model.run() during generation
MASK_RATIO = 0.2  # `mask_ratio` passed to model.run() during generation

# Output layout: data/sequences/<model>/<gene>.csv and data/embeddings/*.pkl
RESULTS_DIR = Path('data')
SEQ_DIR = RESULTS_DIR / 'sequences'
EMB_DIR = RESULTS_DIR / 'embeddings'
for out_dir in (SEQ_DIR, EMB_DIR):
    out_dir.mkdir(parents=True, exist_ok=True)

# Display label -> Hugging Face model id.
model_configs = {
    'DNABERT-2': 'zhihan1996/DNABERT-2-117M',
    'NT-v2-50m': 'InstaDeepAI/nucleotide-transformer-v2-50m-multi-species',
    'NT-v2-500m': 'InstaDeepAI/nucleotide-transformer-v2-500m-multi-species',
}
torch_dtype = 'float32'  # forwarded to iter_models() / load_model()
model_labels = list(model_configs)

print(f'NCBI email: {NCBI_EMAIL or "(not set)"}')
print(f'ITERATIONS={ITERATIONS}, MASK_RATIO={MASK_RATIO}')
print(f'Î™®Îç∏: {model_labels}')


NCBI email: dlalswo0321@gmail.com
ITERATIONS=30, MASK_RATIO=0.2
Î™®Îç∏: ['DNABERT-2', 'NT-v2-50m', 'NT-v2-500m']


In [None]:
# Pick the compute device once with the project helper. Later cells compare `device`
# against 'cuda' / 'mps' to decide which allocator cache to clear.
device = get_device()
device_banner = f'ÏÇ¨Ïö© Ï§ëÏù∏ ÎîîÎ∞îÏù¥Ïä§: {device}'
print(device_banner)


PyTorch Version: 2.9.0+cu126
Using device: cuda
ÏÇ¨Ïö© Ï§ëÏù∏ ÎîîÎ∞îÏù¥Ïä§: cuda


In [None]:
from copy import deepcopy

# Work on a deep copy so the package-level DEFAULT_GENES is never mutated by this cell.
gene_dict = deepcopy(DEFAULT_GENES)

print('ÏÑúÏó¥ Îã§Ïö¥Î°úÎìú ÏãúÏûë...')
for gene, info in gene_dict.items():
    try:
        seq = get_sequence_from_ensembl(info['id'], seq_type='genomic')
        info['sequence'] = seq
        print(f"[{gene}] ÏÑúÏó¥ ÌôïÎ≥¥ ÏôÑÎ£å (Í∏∏Ïù¥: {len(seq)} bp)")
    except Exception as e:
        # Best-effort download: report the error and continue with the remaining genes.
        # Guarantee a 'sequence' entry (None) so the crop cell can skip this gene
        # instead of raising KeyError.
        info.setdefault('sequence', None)
        print(f"[{gene}] ÏóêÎü¨ Î∞úÏÉù: {e}")

print('\n--- Í≤∞Í≥º ÏÉòÌîå (GAPDH) ---')
# GAPDH may be absent or may have failed to download; a preview must not crash the cell.
gapdh_seq = gene_dict.get('GAPDH', {}).get('sequence')
print(gapdh_seq[:50] if gapdh_seq else '(GAPDH sequence unavailable)')


ÏÑúÏó¥ Îã§Ïö¥Î°úÎìú ÏãúÏûë...
[PTEN] ÏÑúÏó¥ ÌôïÎ≥¥ ÏôÑÎ£å (Í∏∏Ïù¥: 109293 bp)
[PTENP1] ÏÑúÏó¥ ÌôïÎ≥¥ ÏôÑÎ£å (Í∏∏Ïù¥: 16585 bp)
[GAPDH] ÏÑúÏó¥ ÌôïÎ≥¥ ÏôÑÎ£å (Í∏∏Ïù¥: 4919 bp)
[GAPDHP1] ÏÑúÏó¥ ÌôïÎ≥¥ ÏôÑÎ£å (Í∏∏Ïù¥: 1005 bp)
[HBB] ÏÑúÏó¥ ÌôïÎ≥¥ ÏôÑÎ£å (Í∏∏Ïù¥: 3932 bp)
[H19] ÏÑúÏó¥ ÌôïÎ≥¥ ÏôÑÎ£å (Í∏∏Ïù¥: 9388 bp)
[RPS29] ÏÑúÏó¥ ÌôïÎ≥¥ ÏôÑÎ£å (Í∏∏Ïù¥: 10008 bp)
[GAS5] ÏÑúÏó¥ ÌôïÎ≥¥ ÏôÑÎ£å (Í∏∏Ïù¥: 17517 bp)
[TP53] ÏÑúÏó¥ ÌôïÎ≥¥ ÏôÑÎ£å (Í∏∏Ïù¥: 25768 bp)
[SNHG1] ÏÑúÏó¥ ÌôïÎ≥¥ ÏôÑÎ£å (Í∏∏Ïù¥: 2044 bp)

--- Í≤∞Í≥º ÏÉòÌîå (GAPDH) ---
GTCACTACCGCAGAGCCTCGAGGAGAAGTTCCCCAACTTTCCCGCCTCTC


In [None]:
# --- 0. Cap long genes to a common length (preprocessing) ---
# Long genomic sequences are cropped so they fit in Colab GPU memory and so that
# genes are compared at comparable lengths.
MAX_SEQ_LEN = 10000  # 10 kb cap; longer sequences keep only their first MAX_SEQ_LEN bases

print("‚úÇÔ∏è Sequence Length Adjustment Check:")
# Build the cropped mapping separately. gene_dict keeps the full downloaded sequences,
# so re-running this cell always starts from the raw data.
gene_selection = {}
for gene, info in gene_dict.items():
    # .get(): a gene whose download failed may have no 'sequence' entry at all.
    seq = info.get('sequence')
    if not isinstance(seq, str) or not seq:
        continue

    seq_len = len(seq)
    if seq_len > MAX_SEQ_LEN:
        print(f"  - {gene}: {seq_len} bp -> {MAX_SEQ_LEN} bp (Cropped)")
        seq = seq[:MAX_SEQ_LEN]
    else:
        print(f"  - {gene}: {seq_len} bp (Keep)")
    gene_selection[gene] = seq


‚úÇÔ∏è Sequence Length Adjustment Check:
  - PTEN: 109293 bp -> 10000 bp (Cropped)
  - PTENP1: 16585 bp -> 10000 bp (Cropped)
  - GAPDH: 4919 bp (Keep)
  - GAPDHP1: 1005 bp (Keep)
  - HBB: 3932 bp (Keep)
  - H19: 9388 bp (Keep)
  - RPS29: 10008 bp -> 10000 bp (Cropped)
  - GAS5: 17517 bp -> 10000 bp (Cropped)
  - TP53: 25768 bp -> 10000 bp (Cropped)
  - SNHG1: 2044 bp (Keep)


In [None]:
import random
import warnings

# Hide the "Increasing alibi size" warning so the progress log stays readable.
# It presumably comes from DNABERT-2 resizing its ALiBi bias for long inputs.
warnings.filterwarnings("ignore", message="Increasing alibi size")

# Re-seeded before every model.run() call, so each (model, gene, strategy) run is
# reproducible regardless of run order. NOTE(review): this covers torch and `random`;
# confirm that model.run() does not draw from another RNG (e.g. numpy).
GENERATION_SEED = 0

# `device` may be a str or a torch.device, possibly indexed ('cuda:0'), and
# `device == 'cuda'` misses both cases. Compare the backend type instead.
device_type = torch.device(device).type


def _release_accelerator_memory():
    """Collect garbage and return cached allocator blocks to the CUDA/MPS driver."""
    gc.collect()
    if device_type == 'cuda':
        torch.cuda.empty_cache()
    elif device_type == 'mps':
        torch.mps.empty_cache()


def _iter_strategy_runs():
    """Yield (strategy_key, strategy_type, temperature, top_k) for every decoding run."""
    for base_key, cfg in DEFAULT_DECODING_STRATEGIES.items():
        strategy_type = cfg['type']
        temperatures = cfg.get('temperatures', [1.0])
        top_k = cfg.get('top_k', 50)
        if strategy_type == 'greedy':
            # The greedy key carries no temperature, so repeated runs overwrote each
            # other. Run once, with the last temperature (the result that used to survive).
            temperatures = list(temperatures)[-1:]
        for temp in temperatures:
            key = base_key if strategy_type == 'greedy' else f'{base_key}_t{temp}'
            yield key, strategy_type, temp, top_k


print('ÏãúÌÄÄÏä§ ÏÉùÏÑ± Ï§ë...')
for model_label, model_instance in iter_models(
    device, model_configs, torch_dtype=torch_dtype
):
    print(f'Î™®Îç∏: {model_label}')

    model_dir = SEQ_DIR / model_label.replace('/', '-')
    model_dir.mkdir(parents=True, exist_ok=True)

    gene_iter = tqdm(gene_selection.items(), desc=f'{model_label} genes', leave=False)
    for gene_id, original_sequence in gene_iter:
        results_data = {}
        for strategy_key, strategy_type, temp, top_k in _iter_strategy_runs():
            torch.manual_seed(GENERATION_SEED)
            random.seed(GENERATION_SEED)
            results_data[strategy_key] = model_instance.run(
                sequence=original_sequence,
                steps=ITERATIONS,
                mask_ratio=MASK_RATIO,
                strategy=strategy_type,
                temperature=temp,
                top_k=top_k,
                save_all=True,
                save_interval=1,
            )
            _release_accelerator_memory()

        # One row per strategy, one column per saved iteration. from_dict(orient='index')
        # pads shorter runs with NaN, whereas DataFrame(results_data).T raises when the
        # runs have different lengths.
        df = pd.DataFrame.from_dict(results_data, orient='index')
        df.columns = [f'iteration_{i}' for i in range(df.shape[1])]
        df.to_csv(model_dir / f'{gene_id}.csv')

    # Drop our reference before iter_models() loads the next model.
    del model_instance
    _release_accelerator_memory()

print('‚úÖ ÏãúÌÄÄÏä§ ÏÉùÏÑ± ÏôÑÎ£å')

# Sequences were just regenerated, so any pickled embeddings are stale. Remove them.
removed = 0
for cache_file in EMB_DIR.glob('**/*.pkl'):
    try:
        cache_file.unlink()
    except OSError as e:
        print(f'  Ï∫êÏãú ÏÇ≠Ï†ú Ïã§Ìå®: {cache_file} ({e})')
    else:
        removed += 1
print(f'üßπ Í∏∞Ï°¥ ÏûÑÎ≤†Îî© Ï∫êÏãú ÏÇ≠Ï†ú ÏôÑÎ£å: {removed} files')


ÏãúÌÄÄÏä§ ÏÉùÏÑ± Ï§ë...
üì• Downloading DNABERT-2...


Fetching 12 files:   0%|          | 0/12 [00:00<?, ?it/s]

.gitattributes: 0.00B [00:00, ?B/s]

bert_padding.py: 0.00B [00:00, ?B/s]

README.md: 0.00B [00:00, ?B/s]

bert_layers.py: 0.00B [00:00, ?B/s]

configuration_bert.py: 0.00B [00:00, ?B/s]

LICENSE: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/904 [00:00<?, ?B/s]

flash_attn_triton.py: 0.00B [00:00, ?B/s]

generation_config.json:   0%|          | 0.00/90.0 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/158 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/468M [00:00<?, ?B/s]

‚úÖ DNABERT-2 Triton patch applied successfully.
[DNABERT-2] Loading model...
[DNABERT-2] Model loaded successfully.
‚úÖ DNABERT-2 loaded successfully.
Î™®Îç∏: DNABERT-2




[NT-v2-50m] Loading model...


tokenizer_config.json:   0%|          | 0.00/129 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/101 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

esm_config.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/InstaDeepAI/nucleotide-transformer-v2-50m-multi-species:
- esm_config.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


modeling_esm.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/InstaDeepAI/nucleotide-transformer-v2-50m-multi-species:
- modeling_esm.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.safetensors:   0%|          | 0.00/224M [00:00<?, ?B/s]

[NT-v2-50m] Model loaded successfully.
‚úÖ NT-v2-50m loaded successfully.
Î™®Îç∏: NT-v2-50m




[NT-v2-500m] Loading model...


tokenizer_config.json:   0%|          | 0.00/129 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/101 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

esm_config.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/InstaDeepAI/nucleotide-transformer-v2-500m-multi-species:
- esm_config.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


modeling_esm.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/InstaDeepAI/nucleotide-transformer-v2-500m-multi-species:
- modeling_esm.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.safetensors:   0%|          | 0.00/1.99G [00:00<?, ?B/s]

[NT-v2-500m] Model loaded successfully.
‚úÖ NT-v2-500m loaded successfully.
Î™®Îç∏: NT-v2-500m




‚úÖ ÏãúÌÄÄÏä§ ÏÉùÏÑ± ÏôÑÎ£å
üßπ Í∏∞Ï°¥ ÏûÑÎ≤†Îî© Ï∫êÏãú ÏÇ≠Ï†ú ÏôÑÎ£å: 0 files


In [None]:
# Load all generated sequences back from the per-model CSVs written by the generation cell:
# all_sequences[model_name][gene_id][strategy] -> list of sequences, one per CSV column.
all_sequences = {}

for model_key in model_labels:
    model_name = model_key.replace('/', '-')
    model_dir = SEQ_DIR / model_name
    # sorted() keeps gene order independent of the filesystem's listing order.
    csv_files = sorted(model_dir.glob('*.csv'))
    if not csv_files:
        # Otherwise this model silently gets no genes, and the embedding step
        # pickles empty results for it.
        print(f'  WARNING: no sequence CSVs found in {model_dir}')

    all_sequences[model_name] = {}
    for csv_file in csv_files:
        df = pd.read_csv(csv_file, index_col=0)
        all_sequences[model_name][csv_file.stem] = {
            strategy: row.tolist() for strategy, row in df.iterrows()
        }

print('‚úÖ ÏãúÌÄÄÏä§ Î°úÎìú ÏôÑÎ£å')


‚úÖ ÏãúÌÄÄÏä§ Î°úÎìú ÏôÑÎ£å


In [None]:
import warnings
warnings.filterwarnings(
    "ignore",
    message="Increasing alibi size"
)

def build_cross_embeddings(source_sequences, target_model_instance, device_type=None):
    """Embed every sequence produced by one generator model with one evaluator model.

    Args:
        source_sequences: mapping gene_id -> strategy -> list of sequences. NaN or
            empty-string entries are skipped.
        target_model_instance: evaluator exposing ``get_embedding(str)``.
        device_type: backend used to decide which allocator cache to clear
            ('cuda', 'mps'; anything else means no cache clearing). Defaults to the
            type of the notebook-level ``device``.

    Returns:
        dict gene_id -> strategy -> list of embeddings, in sequence order. Strategies
        with no usable sequence are omitted. Every gene_id gets a (possibly empty) entry.
    """
    if device_type is None:
        # `device` may be a str or a torch.device (possibly 'cuda:0'). `device == 'cuda'`
        # would not match those, so compare the backend type instead.
        device_type = torch.device(device).type

    cross_emb = {}
    for gene_id, strategies in source_sequences.items():
        gene_emb = cross_emb[gene_id] = {}
        for strategy, sequences in strategies.items():
            if not sequences:
                continue
            embeddings = [
                target_model_instance.get_embedding(str(seq))
                for seq in sequences
                if not (pd.isna(seq) or seq == '')
            ]
            if embeddings:
                gene_emb[strategy] = embeddings
            # Release per-strategy temporaries before the next batch of forward passes.
            gc.collect()
            if device_type == 'cuda':
                torch.cuda.empty_cache()
            elif device_type == 'mps':
                torch.mps.empty_cache()
    return cross_emb

import pickle

# When True, a (generator, evaluator) pair whose pickle already exists is loaded instead
# of recomputed, so an interrupted run resumes where it stopped. The generation cell
# deletes these pickles whenever the sequences are regenerated.
REUSE_EMBEDDING_CACHE = True

print('ÏûÑÎ≤†Îî© ÏÉùÏÑ± Ï§ë...')
generator_labels = list(model_labels)
evaluator_labels = list(model_labels)

# `device` may be a str or a torch.device (possibly 'cuda:0'), so compare backend types.
device_type = torch.device(device).type

cross_embeddings = {}

for eval_label in evaluator_labels:
    model_instance = None  # loaded lazily, so a fully cached evaluator never touches the GPU
    try:
        for gen_label in generator_labels:
            pair = (gen_label, eval_label)
            cache_path = EMB_DIR / f'embeddings_{gen_label}__by__{eval_label}.pkl'

            if REUSE_EMBEDDING_CACHE and cache_path.exists():
                try:
                    # Only unpickle files this cell wrote itself: pickle can execute code.
                    with open(cache_path, 'rb') as f:
                        cross_embeddings[pair] = pickle.load(f)
                    print(f'  cached: {cache_path}')
                    continue
                except (EOFError, pickle.UnpicklingError) as e:
                    print(f'  unreadable cache, recomputing: {cache_path} ({e})')

            if model_instance is None:
                model_instance = load_model(
                    device,
                    eval_label,
                    model_configs[eval_label],
                    torch_dtype=torch_dtype,
                )

            print(f'ÏûÑÎ≤†Îî© ÏÉùÏÑ±: {gen_label} sequences ‚Üí {eval_label} evaluator')
            cross_emb = build_cross_embeddings(
                all_sequences[gen_label.replace('/', '-')],
                model_instance,
            )
            cross_embeddings[pair] = cross_emb

            # Dump to a temp file and then rename it, so an interrupted write never
            # leaves a truncated pickle under the final cache name.
            tmp_path = cache_path.with_name(cache_path.name + '.tmp')
            with open(tmp_path, 'wb') as f:
                pickle.dump(cross_emb, f, protocol=4)
            tmp_path.replace(cache_path)
            print(f'  ‚úì Ï†ÄÏû•: {cache_path}')
    finally:
        # Runs on success and on interruption (e.g. KeyboardInterrupt), so the evaluator
        # does not stay resident on the GPU when the cell is re-run.
        model_instance = None
        gc.collect()
        if device_type == 'cuda':
            torch.cuda.empty_cache()
        elif device_type == 'mps':
            torch.mps.empty_cache()

print('‚úÖ ÏûÑÎ≤†Îî© ÏÉùÏÑ± ÏôÑÎ£å')


ÏûÑÎ≤†Îî© ÏÉùÏÑ± Ï§ë...
üì• Downloading DNABERT-2...


Fetching 12 files:   0%|          | 0/12 [00:00<?, ?it/s]

‚úÖ DNABERT-2 Triton patch applied successfully.
[DNABERT-2] Loading model...
[DNABERT-2] Model loaded successfully.
ÏûÑÎ≤†Îî© ÏÉùÏÑ±: DNABERT-2 sequences ‚Üí DNABERT-2 evaluator


KeyboardInterrupt: 

In [None]:
import os
import shutil
from pathlib import Path

from google.colab import files

# Archive the results directory (RESULTS_DIR, 'data' by default) and send it to the
# browser as a download.
results_path = Path(RESULTS_DIR)
results_dir = str(results_path)
zip_base = '/content/data'  # make_archive appends the '.zip' extension

# Skip a missing or empty directory. is_dir() is checked first, so listdir never sees a file.
if results_path.is_dir() and os.listdir(results_path):
    print(f"Zipping '{results_dir}'...")
    # make_archive writes a fresh archive each time. `zip -r` updates an existing
    # data.zip in place and keeps entries for files since deleted from the directory.
    zip_path = shutil.make_archive(
        zip_base, 'zip', root_dir=str(results_path.parent), base_dir=results_path.name
    )
    print(f"‚úÖ '{results_dir}' zipped to {zip_path}")
    files.download(zip_path)
else:
    print(f"Skipping '{results_dir}' as it is empty or does not exist.")

print("Download process initiated.")