# Pseudogene vs Real Gene Sequence Reconstruction

This notebook loads DNABERT-2 and NT-v2-500m, fetches sequences for
real-pseudogene pairs, applies masking and reconstruction, and saves
the generated sequences to result/sequences.


In [None]:
import sys
import gc
from pathlib import Path

import pandas as pd
import torch

# Make the parent directory and its 'dna-model-collapse' subfolder importable
# so the project modules imported below can be resolved.
current_dir = Path('.').resolve()
project_root = current_dir.parent
dna_module_path = project_root / 'dna-model-collapse'

# Each missing entry is pushed to the front of sys.path, so the entry handled
# last (dna_module_path) ends up with the highest import priority.
extra_paths = (project_root, dna_module_path)
for candidate in map(str, extra_paths):
    if candidate in sys.path:
        continue
    sys.path.insert(0, candidate)

from preparation import get_device, load_models
from sequence_generation import fetch_gene_sequences, DEFAULT_GENES

print('Imports ready')


In [None]:
# Real genes paired with their pseudogene counterparts: (real, pseudo)
GENE_PAIRS = [
    ('GAPDH', 'GAPDHP1'),
    ('PTEN', 'PTENP1'),
    ('TPI1', 'TPI1P1'),
]

# Walk every pair in order and collect a private copy of each gene's default
# metadata; stop at the first gene that DEFAULT_GENES does not know about.
selected_genes = {}
for pair in GENE_PAIRS:
    for symbol in pair:
        if symbol in DEFAULT_GENES:
            # dict(...) makes a shallow copy so DEFAULT_GENES stays untouched
            selected_genes[symbol] = dict(DEFAULT_GENES[symbol])
            continue
        raise KeyError(f'Missing gene in DEFAULT_GENES: {symbol}')

# Summarise what was selected, one line per gene
print('Selected genes:')
for symbol, info in selected_genes.items():
    uid, status, kind = (info.get(field) for field in ('id', 'status', 'type'))
    print(f'- {symbol}: id={uid}, status={status}, type={kind}')


In [None]:
# Set your email for NCBI Entrez.
# The placeholder below is rejected on purpose: replace it before running.
EMAIL = 'your_email@example.com'
email_is_placeholder = (not EMAIL) or ('example.com' in EMAIL)
if email_is_placeholder:
    raise ValueError('Set EMAIL to your NCBI-registered email address')

# Fetch sequences for exactly the genes selected in the previous cell
gene_sequences = fetch_gene_sequences(EMAIL, gene_uids=selected_genes)
sequence_count = len(gene_sequences)
print(f'Loaded {sequence_count} sequences')


In [None]:
# Load models
# Display label -> model identifier string handed to load_models.
# The labels become the keys of `models` returned below is an assumption;
# NOTE(review): confirm against load_models in preparation.py.
model_configs = {
    'DNABERT-2': 'zhihan1996/DNABERT-2-117M',
    'NT-v2-500m': 'InstaDeepAI/nucleotide-transformer-v2-500m-multi-species',
}

# Pick the compute device, then load every configured model onto it
device = get_device()
models = load_models(device, model_configs=model_configs)


In [None]:
# Generate sequences and save results
#
# For every loaded model and every fetched gene sequence, call the model's
# run() once per mask ratio and write one CSV per gene to
# result/sequences/<model>/<gene>.csv. Each CSV row is one strategy key and
# the columns are renamed iteration_0, iteration_1, ...
MASK_RATIOS = [0.15, 0.30, 0.45]
ITERATIONS = 20
STRATEGY = 'sampling'
TEMPERATURE = 1.0
TOP_K = 50

RESULTS_DIR = Path('result/sequences')
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# Reduce the device to its bare type ('cuda', 'mps', 'cpu', ...) so the
# cache-clearing branches below also fire for torch.device objects and for
# indexed strings such as 'cuda:0', not only for the exact strings
# 'cuda' / 'mps'.
# NOTE(review): get_device() lives in preparation.py; confirm what it returns.
device_type = getattr(device, 'type', None) or str(device).partition(':')[0]

print('Generating sequences...')
for model_label, model_instance in models.items():
    # Model labels may contain '/', which must not create nested directories
    model_name = model_label.replace('/', '-')
    model_dir = RESULTS_DIR / model_name
    model_dir.mkdir(parents=True, exist_ok=True)

    print(f'\nModel: {model_label}')
    for gene_id, original_sequence in gene_sequences.items():
        print(f'  {gene_id}...', end='', flush=True)
        results_data = {}

        for mask_ratio in MASK_RATIOS:
            # Build the row label from STRATEGY so it stays truthful if the
            # strategy constant above is changed (identical to the previous
            # hard-coded 'sampling' prefix for the current setting).
            strategy_key = f'{STRATEGY}_t{TEMPERATURE}_m{mask_ratio}'
            generated_sequences = model_instance.run(
                sequence=original_sequence,
                steps=ITERATIONS,
                mask_ratio=mask_ratio,
                strategy=STRATEGY,
                temperature=TEMPERATURE,
                top_k=TOP_K,
                save_all=True,
                save_interval=1,
            )
            results_data[strategy_key] = generated_sequences

            # Only the local name is dropped here; the sequences stay alive
            # in results_data. The collection / cache flush targets whatever
            # temporaries run() left behind.
            del generated_sequences
            gc.collect()
            if device_type == 'cuda':
                torch.cuda.empty_cache()
            elif device_type == 'mps':
                torch.mps.empty_cache()

        # One row per strategy key after the transpose
        df = pd.DataFrame(results_data).T
        df.columns = [f'iteration_{i}' for i in range(df.shape[1])]

        output_csv = model_dir / f'{gene_id}.csv'
        df.to_csv(output_csv)
        print(f' saved {len(df)} strategies')

print(f'\nDone. Results saved under: {RESULTS_DIR}')
