# Sequence generation + embeddings (sequential model loading)

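This notebook fetches gene sequences from NCBI, generates sequences from them
with each model and decoding strategy, then computes cross-embeddings (every
model's generated sequences embedded by every model) while keeping only one
model loaded at a time.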

In [6]:
import gc
import pickle
import sys
from pathlib import Path

import pandas as pd
import torch
from tqdm.notebook import tqdm

# Put the parent of the notebook's working directory on sys.path so the
# project modules imported below can be found.
current_dir = Path('.').resolve()
project_root = current_dir.parent
project_root_str = str(project_root)
if project_root_str not in sys.path:
    sys.path.insert(0, project_root_str)

from preparation import get_device, iter_models, load_model
from sequence_generation import (
    fetch_gene_sequences,
    sort_genes_by_length,
    DEFAULT_DECODING_STRATEGIES,
)

print('✅ 모든 모듈 임포트 완료')


✅ 모든 모듈 임포트 완료


In [7]:
# --- Run parameters ---------------------------------------------------------
NCBI_EMAIL = 'dlalswo0321@gmail.com'  # handed to fetch_gene_sequences()
ITERATIONS = 30                       # forwarded as `steps` to model.run()
MASK_RATIO = 0.2                      # forwarded as `mask_ratio` to model.run()

# --- Output locations (created eagerly so later cells can write into them) --
RESULTS_DIR = Path('results')
SEQ_DIR = RESULTS_DIR / 'sequences'
EMB_DIR = RESULTS_DIR / 'embeddings'
for _out_dir in (SEQ_DIR, EMB_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)

# --- Models: display label -> model identifier passed to the loaders --------
model_configs = {
    'DNABERT-2': 'zhihan1996/DNABERT-2-117M',
    'NT-v2-50m': 'InstaDeepAI/nucleotide-transformer-v2-50m-multi-species',
    'NT-v2-500m': 'InstaDeepAI/nucleotide-transformer-v2-500m-multi-species',
}
torch_dtype = 'float32'
model_labels = [*model_configs]

# Echo the effective configuration so it is recorded next to the results.
print(f'NCBI email: {NCBI_EMAIL}')
print(f'ITERATIONS={ITERATIONS}, MASK_RATIO={MASK_RATIO}')
print(f'모델: {model_labels}')


NCBI email: dlalswo0321@gmail.com
ITERATIONS=30, MASK_RATIO=0.2
모델: ['DNABERT-2', 'NT-v2-50m', 'NT-v2-500m']


In [8]:
# Pick the compute device once. Later cells pass it to the model loaders and
# use it to decide which accelerator cache to clear.
device = get_device()
print(f'사용 중인 디바이스: {device}')


PyTorch Version: 2.9.1
Using device: mps
사용 중인 디바이스: mps


In [None]:

# Fetch the gene sequences (NCBI_EMAIL is passed to the fetcher), then run them
# through sort_genes_by_length -- presumably ordering by sequence length; confirm
# against sequence_generation.
gene_selection = sort_genes_by_length(fetch_gene_sequences(NCBI_EMAIL))

# Summarise what was collected: one line per gene with its length in bp.
print('수집된 유전자:')
for gene_name, gene_seq in gene_selection.items():
    print(f'  - {gene_name}: {len(gene_seq)} bp')


PyTorch Version: 2.9.1
Using device: mps
사용 중인 디바이스: mps
Fetching gene sequences from NCBI...
Fetching H4C1 by UID NM_003538...
  ✅ Found and added sequence for H4C1 (Length: 402bp)
Fetching TP53 by UID NM_000546...
  ✅ Found and added sequence for TP53 (Length: 2512bp)
Fetching HBB by UID NM_000518...
  ✅ Found and added sequence for HBB (Length: 628bp)
Fetching HOXC11 by UID NM_014212...
  ✅ Found and added sequence for HOXC11 (Length: 3261bp)
Fetching HOTAIR by UID NR_003716...
  ✅ Found and added sequence for HOTAIR (Length: 2273bp)
Fetching VEGFA by UID NM_003376...
  ✅ Found and added sequence for VEGFA (Length: 3609bp)
Fetching NEAT1 by UID NR_003513...
  ✅ Found and added sequence for NEAT1 (Length: 3190bp)
Fetching NORAD by UID NR_027451...
  ✅ Found and added sequence for NORAD (Length: 5378bp)
Fetching STAT3 by UID NM_139276...
  ✅ Found and added sequence for STAT3 (Length: 4921bp)
Fetching PTEN by UID NM_000314...
  ✅ Found and added sequence for PTEN (Length: 8515bp)
Fetc

In [4]:
def _release_accelerator_cache():
    """Run the garbage collector and clear the CUDA/MPS cache for `device`.

    The backend name is taken from ``str(device)`` up to the first ``:`` so the
    check also matches indexed devices such as ``'cuda:0'`` and ``torch.device``
    objects, which may not compare equal to the bare string ``'cuda'``/``'mps'``.
    """
    gc.collect()
    backend = str(device).split(':', 1)[0]
    if backend == 'cuda':
        torch.cuda.empty_cache()
    elif backend == 'mps':
        torch.mps.empty_cache()


# Expand the decoding-strategy table once; every gene and model uses the same
# (result key, strategy type, temperature, top_k) combinations.
decoding_runs = []
for strategy_base_key, strategy_cfg in DEFAULT_DECODING_STRATEGIES.items():
    strategy_type = strategy_cfg['type']
    top_k = strategy_cfg.get('top_k', 50)
    for temp in strategy_cfg.get('temperatures', [1.0]):
        # Greedy keeps the bare key; every other strategy gets a temperature suffix.
        if strategy_type == 'greedy':
            strategy_key = strategy_base_key
        else:
            strategy_key = f'{strategy_base_key}_t{temp}'
        decoding_runs.append((strategy_key, strategy_type, temp, top_k))

print('시퀀스 생성 중...')
for model_label, model_instance in iter_models(
    device, model_configs, torch_dtype=torch_dtype
):
    # try/finally: the model is released even if the cell is interrupted
    # part-way through a long generation run.
    try:
        print(f'모델: {model_label}')

        # One output directory per model; '/' in a label would otherwise
        # create nested directories.
        model_name = model_label.replace('/', '-')
        model_dir = SEQ_DIR / model_name
        model_dir.mkdir(parents=True, exist_ok=True)

        gene_iter = tqdm(
            gene_selection.items(),
            desc=f'{model_label} genes',
            leave=False,
        )

        for gene_id, original_sequence in gene_iter:
            output_csv = model_dir / f'{gene_id}.csv'
            results_data = {}

            for strategy_key, strategy_type, temp, top_k in decoding_runs:
                results_data[strategy_key] = model_instance.run(
                    sequence=original_sequence,
                    steps=ITERATIONS,
                    mask_ratio=MASK_RATIO,
                    strategy=strategy_type,
                    temperature=temp,
                    top_k=top_k,
                    save_all=True,
                    save_interval=1,
                )
                _release_accelerator_cache()

            # One row per strategy, one column per saved iteration.
            # orient='index' pads rows of unequal length instead of raising,
            # so one short run cannot discard the whole gene's results.
            df = pd.DataFrame.from_dict(results_data, orient='index')
            df.columns = [f'iteration_{i}' for i in range(df.shape[1])]
            df.to_csv(output_csv)
    finally:
        del model_instance
        _release_accelerator_cache()

print('✅ 시퀀스 생성 완료')

# The sequences were just regenerated, so previously pickled embeddings are
# stale. Deletion is best-effort: failures are reported, not raised.
removed = 0
for cache_file in EMB_DIR.glob('**/*.pkl'):
    try:
        cache_file.unlink()
    except OSError as e:
        print(f'  캐시 삭제 실패: {cache_file} ({e})')
    else:
        removed += 1
print(f'🧹 기존 임베딩 캐시 삭제 완료: {removed} files')


시퀀스 생성 중...
📥 Downloading DNABERT-2...


Fetching 12 files:   0%|          | 0/12 [00:00<?, ?it/s]

✅ DNABERT-2 Triton patch applied successfully.
[DNABERT-2] Loading model...
[DNABERT-2] Model loaded successfully.
✅ DNABERT-2 loaded successfully.
모델: DNABERT-2
  H4C1...



 ✓ (4 strategies)
  HBB... ✓ (4 strategies)
  GAPDHP1... ✓ (4 strategies)
  GBAP1...

KeyboardInterrupt: 

In [9]:
# Load every generated-sequence CSV back into memory as
# all_sequences[model_name][gene_id][strategy] -> list of per-iteration values.
all_sequences = {}

for model_key in model_labels:
    model_name = model_key.replace('/', '-')
    model_dir = SEQ_DIR / model_name
    genes_for_model = {}

    # Path.glob yields files in arbitrary order; sorting keeps the gene order
    # (and therefore the pickled embedding dicts) reproducible between runs.
    for csv_file in sorted(model_dir.glob('*.csv')):
        df = pd.read_csv(csv_file, index_col=0)
        # Each row is one decoding strategy; its cells are the saved iterations.
        genes_for_model[csv_file.stem] = {
            strategy: row.tolist() for strategy, row in df.iterrows()
        }

    if not genes_for_model:
        # Surface missing or partial generation runs instead of silently
        # carrying an empty entry into the embedding step.
        print(f'  ⚠️ {model_name}: 시퀀스 CSV를 찾을 수 없음 ({model_dir})')
    all_sequences[model_name] = genes_for_model

print('✅ 시퀀스 로드 완료')


✅ 시퀀스 로드 완료


In [10]:
def build_cross_embeddings(source_sequences, target_model_instance):
    """Embed every generated sequence with ``target_model_instance``.

    Args:
        source_sequences: mapping ``gene_id -> {strategy -> sequence list}``.
        target_model_instance: object exposing ``get_embedding(str)``.

    Returns:
        ``{gene_id: {strategy: [embedding, ...]}}``. Missing (NaN) and empty
        sequences are skipped, so a strategy's embedding list can be shorter
        than its sequence list. A strategy with no usable sequence is omitted;
        every gene id is present, possibly mapping to an empty dict.

    Reads the notebook-level ``device`` to decide which accelerator cache to
    clear after each strategy.
    """
    # Normalise to the backend name so 'cuda:0' and torch.device objects match;
    # those may not compare equal to the bare string 'cuda' / 'mps'.
    device_type = str(device).split(':', 1)[0]

    cross_emb = {}
    for gene_id, strategies in source_sequences.items():
        cross_emb[gene_id] = {}
        for strategy, sequences in strategies.items():
            if not sequences:
                continue
            embeddings = [
                target_model_instance.get_embedding(str(seq))
                for seq in sequences
                if not (pd.isna(seq) or seq == '')
            ]
            if embeddings:
                cross_emb[gene_id][strategy] = embeddings
            # Release memory after each strategy to limit accelerator usage.
            gc.collect()
            if device_type == 'cuda':
                torch.cuda.empty_cache()
            elif device_type == 'mps':
                torch.mps.empty_cache()
    return cross_emb

print('임베딩 생성 중...')
generator_labels = list(model_labels)
evaluator_labels = list(model_labels)

# (generator label, evaluator label) -> output of build_cross_embeddings
cross_embeddings = {}

# Backend name only, so 'cuda:0' or a torch.device still selects the right
# cache-clearing call below.
device_type = str(device).split(':', 1)[0]

# Outer loop is over evaluators so each model is loaded exactly once and only
# one model is held in memory at a time.
for eval_label in evaluator_labels:
    model_instance = load_model(
        device,
        eval_label,
        model_configs[eval_label],
        torch_dtype=torch_dtype,
    )

    # try/finally: release the model even when the cell is interrupted or an
    # embedding call fails.
    try:
        if eval_label == 'DNABERT-2':
            # One throwaway embedding of a 1024-base dummy sequence before the
            # real work. NOTE(review): looks like a warm-up for DNABERT-2;
            # confirm why it is needed. The result is discarded without binding
            # `_`, which IPython reserves for the last cell output.
            model_instance.get_embedding('A' * 1024)

        for gen_label in generator_labels:
            # Sanitise '/' the same way the sequence directories are named so a
            # label containing a slash cannot point into a missing sub-directory.
            gen_name = gen_label.replace('/', '-')
            eval_name = eval_label.replace('/', '-')
            cache_path = EMB_DIR / f'embeddings_{gen_name}__by__{eval_name}.pkl'

            print(f'임베딩 생성: {gen_label} sequences → {eval_label} evaluator')
            cross_emb = build_cross_embeddings(
                all_sequences[gen_name],
                model_instance,
            )
            cross_embeddings[(gen_label, eval_label)] = cross_emb
            with open(cache_path, 'wb') as f:
                pickle.dump(cross_emb, f, protocol=4)
            print(f'  ✓ 저장: {cache_path}')
    finally:
        del model_instance
        gc.collect()
        if device_type == 'cuda':
            torch.cuda.empty_cache()
        elif device_type == 'mps':
            torch.mps.empty_cache()

print('✅ 임베딩 생성 완료')


임베딩 생성 중...
📥 Downloading DNABERT-2...


Fetching 12 files:   0%|          | 0/12 [00:00<?, ?it/s]

✅ DNABERT-2 Triton patch applied successfully.
[DNABERT-2] Loading model...




[DNABERT-2] Model loaded successfully.
임베딩 생성: DNABERT-2 sequences → DNABERT-2 evaluator




KeyboardInterrupt: 