In [2]:
import random
import pickle
import pandas as pd
import pyranges as pr
from collections import Counter
from deap import base, creator, tools
import os
import multiprocessing
from multiprocessing import Manager
import uuid


In [3]:

# Root of the cellTransformer checkout. Defaults to the original cluster path; set the
# CELLTRANSFORMER_ROOT environment variable to run from a different location.
PROJECT_ROOT = os.environ.get(
    "CELLTRANSFORMER_ROOT",
    "/gpfs1/tangfuchou_pkuhpc/tangfuchou_coe/jiangzh/cellTransformer",
)

# `pred` lives in code/evol/. It is imported by temporarily making that directory the
# working directory (this relies on the cwd being on sys.path, as in a notebook kernel).
os.chdir(os.path.join(PROJECT_ROOT, "code/evol/"))
from pred import CellTransformer, load_msgpack, map_guide_to_peaks

# All relative data paths used below (./data/..., ../resource/...) resolve from GET/.
os.chdir(os.path.join(PROJECT_ROOT, "GET/"))

# =========================================
# Configuration
# =========================================
CONFIG = {
    "population_size": 50,   # individuals in the GA population
    "num_generations": 50,   # number of GA generations run by main()
    "crossover_prob": 0.7,   # probability that each offspring pair is mated
    # Used by main() both as the per-individual mutation probability and as the
    # per-gene `indpb` passed to custom_mutation.
    "mutation_prob": 0.3,
    "tournament_size": 5,    # tournament size for selection
    "set_size": 3            # number of distinct k-mers per individual
}

KMER_CONFIG = {
    "max_global_hits": 1e4,  # load_data drops k-mers whose genome-wide coverage exceeds this
    "min_local_hits": 3,     # load_data keeps k-mers counted at least this often in the ROI
    "min_effective_hits": 1  # passed as `hit_threshold` to map_guide_to_peaks
}

# Module-level state shared with pool workers: populated by main() in the parent
# process and by init_worker() in each worker process.
kmer_to_peak = None
roi = None
fitness_cache = None
threads = 8  # number of worker processes in the fitness-evaluation Pool


In [4]:


# =========================================
# Data loading
# =========================================
def init_worker(kmer_data, roi_data, cache_dict):
    """Pool initializer that installs shared data into this process's module globals.

    Args:
        kmer_data: k-mer -> peak index, stored as the global ``kmer_to_peak``.
        roi_data: Region of interest, stored as the global ``roi``.
        cache_dict: Shared fitness cache, stored as the global ``fitness_cache``.
    """
    global kmer_to_peak, roi, fitness_cache
    kmer_to_peak, roi, fitness_cache = kmer_data, roi_data, cache_dict

def load_kmer_index():
    """Load the genome-wide 9-mer -> peak index from msgpack.

    The path is relative to the current working directory (GET/ after the config cell).
    """
    index_path = "./data/index/hg38_CATlas_cCREs.9mer.kmer_to_peak_freq.msgpack"
    return load_msgpack(index_path)

def load_data():
    """Build the region of interest (ROI) and the candidate k-mer pool for the GA.

    The ROI is a single 1 Mb interval on chr19 centred on position 55,115,750.
    Candidate k-mers come from the ROI k-mer pickle. They are kept only if their
    genome-wide coverage is at most ``KMER_CONFIG["max_global_hits"]`` and they are
    counted at least ``KMER_CONFIG["min_local_hits"]`` times in the ROI collection.

    Returns:
        tuple: ``(roi, roi_kmer_list)``, a single-interval ``pr.PyRanges`` and the
        list of surviving k-mers in first-seen order.

    Raises:
        ValueError: If no k-mer survives filtering.
        Exception: Any error raised while building the ROI or reading the pickles
            is printed and then re-raised unchanged.
    """
    center = 55115750
    try:
        region = pr.PyRanges(
            chromosomes=["chr19"],
            starts=[int(center - 5e5)],
            ends=[int(center + 5e5)],
        )

        # NOTE(review): pickle.load can execute arbitrary code; only load trusted local files.
        with open('./data/intervention/hg38_CATlas_cCREs.9mer.kmer_cov_by_peak.pkl', 'rb') as fh:
            kmer_cov = pickle.load(fh)
        with open('./data/intervention/AAVS1_1Mb_9mer_HEK293T_novel.pkl', 'rb') as fh:
            roi_kmer = pickle.load(fh)

        # Keep only k-mers that are not too common genome-wide.
        global_limit = KMER_CONFIG["max_global_hits"]
        rare_kmers = {kmer for kmer, cov in kmer_cov.items() if cov <= global_limit}

        # Count ROI occurrences of the rare k-mers. This assumes roi_kmer is an iterable
        # of k-mer occurrences with repeats; if it were a dict, every count would be 1.
        # TODO confirm against the code that produced the pickle.
        local_counts = Counter(kmer for kmer in roi_kmer if kmer in rare_kmers)
        local_min = KMER_CONFIG["min_local_hits"]
        roi_kmer_list = [kmer for kmer, n in local_counts.items() if n >= local_min]

        if not roi_kmer_list:
            raise ValueError("No valid k-mers found after filtering")

        print(f"Available k-mers: {len(roi_kmer_list)}")
        return region, roi_kmer_list
    except Exception as e:
        print(f"Error loading data: {e}")
        raise


# =========================================
# Helper: fix duplicates
# =========================================
def fix_duplicates(individual, kmer_list):
    """Make every k-mer in ``individual`` unique, modifying it in place.

    The first occurrence of each k-mer is kept. Each later repeat is replaced by a
    random k-mer from ``kmer_list`` that is not used anywhere else in the individual,
    so genes that were already unique are never disturbed.

    Args:
        individual: Mutable sequence of k-mers (e.g. a DEAP ``Individual``).
        kmer_list: Pool of candidate replacement k-mers.

    Returns:
        The same ``individual`` object.

    Raises:
        ValueError: If a duplicate must be replaced but ``kmer_list`` has no unused
            k-mer left.
    """
    seen = set()
    for i in range(len(individual)):
        kmer = individual[i]
        if kmer in seen:
            # Exclude earlier genes AND the genes still ahead of i. Otherwise the
            # replacement could collide with a later, originally unique gene and
            # force that gene to be re-sampled as well.
            in_use = seen.union(individual[i + 1:])
            choices = [k for k in kmer_list if k not in in_use]
            if not choices:
                raise ValueError(
                    f"fix_duplicates: no unused k-mer left in kmer_list to replace "
                    f"duplicate {kmer!r}"
                )
            kmer = random.choice(choices)
            individual[i] = kmer
        seen.add(kmer)
    return individual


# =========================================
# Fitness function with caching
# =========================================
def fitness(individual):
    """Score an individual with CellTransformer, memoised in ``fitness_cache``.

    Args:
        individual: Sequence of k-mers (guides). The cache key is ``tuple(individual)``,
            so the same k-mers in a different order are cached separately.

    Returns:
        tuple: ``(score,)``, the single-objective fitness tuple DEAP expects.
    """
    cache_key = tuple(individual)
    if cache_key in fitness_cache:
        return (fitness_cache[cache_key],)

    hits = map_guide_to_peaks(
        individual,
        kmer_to_peak,
        hit_threshold=KMER_CONFIG["min_effective_hits"],
    )

    # A fresh run id per evaluation; CellTransformer uses it to name its outputs.
    predictor = CellTransformer(
        guide_list=list(individual),
        peak_hits=hits,
        target_gene=["GFP"],
        output_dir="./data/get_tmp/",
        celltype="HEK293T",
        insert_transgene=True,
        prediction_scope=roi,
        run_id=uuid.uuid4(),
        motif_bed="../resource/hg38.archetype_motifs.v1.0.bed.gz",
        zarr_path="./data/zarr/HEK293T_hPGK1_AAVS1.zarr",
    )
    score = predictor.predict()

    fitness_cache[cache_key] = score
    return (score,)


# =========================================
# Custom mutation
# =========================================
def custom_mutation(individual, indpb, kmer_list):
    """Replace each k-mer, with probability ``indpb``, by one not already in the individual.

    A replacement is drawn uniformly (with multiplicity) from the entries of
    ``kmer_list`` that are not currently in ``individual``. This is the same
    distribution the previous rejection-sampling loop produced. If no such entry
    exists, the gene is left unchanged; the old loop would never terminate in that case.

    Args:
        individual: Mutable sequence of k-mers, modified in place.
        indpb: Per-gene mutation probability.
        kmer_list: Pool of candidate replacement k-mers.

    Returns:
        tuple: ``(individual,)``, following DEAP's mutation-operator convention.
    """
    for i in range(len(individual)):
        if random.random() < indpb:
            # Recomputed per gene because earlier replacements change the individual.
            candidates = [k for k in kmer_list if k not in individual]
            if candidates:
                individual[i] = random.choice(candidates)
    return individual,


# =========================================
# Main GA procedure
# =========================================
def _evaluate_invalid(toolbox, individuals):
    """Evaluate every individual whose fitness is not valid, writing the values back in place.

    Evaluation goes through ``toolbox.map``, which ``main`` binds to the worker pool.

    Returns:
        int: Number of individuals that were evaluated.
    """
    invalid_ind = [ind for ind in individuals if not ind.fitness.valid]
    fitnesses = toolbox.map(toolbox.evaluate, invalid_ind)
    for ind, fit in zip(invalid_ind, fitnesses):
        ind.fitness.values = fit
    return len(invalid_ind)


def main(seed=None):
    """Run the genetic algorithm that searches for a high-scoring set of k-mers.

    Each individual is a list of ``CONFIG["set_size"]`` distinct k-mers sampled from
    the ROI k-mer pool returned by ``load_data``, and its fitness is the score from
    ``fitness``. Every evaluated offspring generation is appended to a uniquely named
    CSV file ``ga_individuals_<uuid>.csv`` in the current directory.

    Args:
        seed: Optional seed for Python's ``random`` module, making initialisation,
            selection, crossover and mutation reproducible. ``None`` (the default)
            leaves the RNG untouched.

    Returns:
        The best individual seen over the whole run (``HallOfFame(1)[0]``).
    """
    global kmer_to_peak, roi, fitness_cache, threads

    if seed is not None:
        random.seed(seed)

    # Load data
    kmer_to_peak = load_kmer_index()
    roi, roi_kmer_list = load_data()

    # A Manager-backed dict lets every pool worker read and write one shared cache.
    manager = Manager()
    fitness_cache = manager.dict()

    # DEAP registers these classes on the global `creator` module. Guarding them
    # avoids overwriting them when main() is re-run in the same kernel.
    if not hasattr(creator, "FitnessMax"):
        creator.create("FitnessMax", base.Fitness, weights=(1.0,))
    if not hasattr(creator, "Individual"):
        creator.create("Individual", list, fitness=creator.FitnessMax)

    toolbox = base.Toolbox()
    toolbox.register("individual", tools.initIterate, creator.Individual,
                     lambda: random.sample(roi_kmer_list, CONFIG["set_size"]))
    toolbox.register("population", tools.initRepeat, list, toolbox.individual)

    toolbox.register("mate", tools.cxUniform, indpb=0.5)
    # NOTE(review): CONFIG["mutation_prob"] is used both as the per-gene rate here and
    # as the per-individual rate in the loop below, so the effective per-gene rate
    # is mutation_prob**2. Confirm this is intended.
    toolbox.register("mutate", custom_mutation, kmer_list=roi_kmer_list, indpb=CONFIG["mutation_prob"])
    toolbox.register("select", tools.selTournament, tournsize=CONFIG["tournament_size"])
    toolbox.register("evaluate", fitness)

    # Evaluate fitness in parallel; each worker receives the shared data via init_worker.
    pool = multiprocessing.Pool(
        processes=threads,
        initializer=init_worker,
        initargs=(kmer_to_peak, roi, fitness_cache)
    )
    toolbox.register("map", pool.map)

    try:
        # Initialize population & statistics
        pop = toolbox.population(n=CONFIG["population_size"])
        hof = tools.HallOfFame(1)
        stats = tools.Statistics(lambda ind: ind.fitness.values)
        stats.register("avg", lambda fits: sum(f[0] for f in fits) / len(fits))
        stats.register("max", lambda fits: max(f[0] for f in fits))
        logbook = tools.Logbook()
        logbook.header = ["gen", "best_fitness", "avg_fitness"]

        # Initialize CSV file for storing individual data
        output_file = f"ga_individuals_{uuid.uuid4()}.csv"
        pd.DataFrame(columns=["Generation", "Individual", "Fitness"]) \
            .to_csv(output_file, index=False)

        # Evaluate the initial population. Without this, the first tournament compares
        # individuals whose fitness is still empty, so selection is effectively random.
        _evaluate_invalid(toolbox, pop)
        hof.update(pop)

        # Evolutionary loop
        for gen in range(CONFIG["num_generations"]):
            offspring = toolbox.select(pop, len(pop))
            offspring = list(map(toolbox.clone, offspring))

            # Crossover on consecutive pairs
            for child1, child2 in zip(offspring[::2], offspring[1::2]):
                if random.random() < CONFIG["crossover_prob"]:
                    toolbox.mate(child1, child2)
                    fix_duplicates(child1, roi_kmer_list)
                    fix_duplicates(child2, roi_kmer_list)
                    del child1.fitness.values
                    del child2.fitness.values

            # Mutation
            for mutant in offspring:
                if random.random() < CONFIG["mutation_prob"]:
                    toolbox.mutate(mutant)
                    fix_duplicates(mutant, roi_kmer_list)
                    del mutant.fitness.values

            # Fitness evaluation (only individuals changed by variation)
            _evaluate_invalid(toolbox, offspring)

            # Append this generation's individuals to the CSV
            df = pd.DataFrame([
                [gen, str(ind), ind.fitness.values[0]]
                for ind in offspring
            ], columns=["generation", "individual", "fitness"])
            df.to_csv(output_file, mode='a', header=False, index=False)

            pop[:] = offspring
            hof.update(pop)

            # Logging: best_fitness is the best seen so far (hall of fame)
            record = stats.compile(pop)
            logbook.record(gen=gen, best_fitness=hof[0].fitness.values[0], avg_fitness=record["avg"])
            print(logbook.stream)
    except BaseException:
        # On an error or interrupt, stop the workers right away instead of waiting on them.
        pool.terminate()
        pool.join()
        raise

    pool.close()
    pool.join()
    print(f"Individual data saved to {output_file}")
    return hof[0]


In [5]:
# Hand-picked candidate set of three 9-mers used to smoke-test the pipeline below.
individual = [
    "GGCAGGGGG",
    "GAGGGAGGA",
    "AGCAGCAGC",
]
roi, roi_kmer_list = load_data()

Available k-mers: 655


In [6]:
# Map the hand-picked individual's k-mers to peaks and cache the result so the slow
# prediction in the next cell can be re-run on its own.
# (The former module-level `global` statement was a no-op and has been dropped.)
kmer_to_peak = load_kmer_index()

peak_hits = map_guide_to_peaks(
    individual,
    kmer_to_peak,
    hit_threshold=KMER_CONFIG["min_effective_hits"],
)

# Create the scratch directory if needed before writing the cache file.
os.makedirs("./data/get_tmp/", exist_ok=True)
with open("./data/get_tmp/peak_hits.pkl", "wb") as f:
    pickle.dump(peak_hits, f)


Loading ./data/index/hg38_CATlas_cCREs.9mer.kmer_to_peak_freq.msgpack...


In [7]:
# Reload the peak hits cached by the previous cell; the prediction below is the slow step.
with open("./data/get_tmp/peak_hits.pkl", "rb") as f:
    peak_hits = pickle.load(f)

# A fresh run id per prediction; it appears in the zarr group and output dir names.
run_id = uuid.uuid4()
ct = CellTransformer(
    guide_list=list(individual),
    peak_hits=peak_hits,
    target_gene=["GFP"],
    celltype="HEK293T",
    prediction_scope=roi,
    insert_transgene=True,
    output_dir="./data/get_tmp/",
    run_id=run_id,
    motif_bed="../resource/hg38.archetype_motifs.v1.0.bed.gz",
    zarr_path="./data/zarr/HEK293T_hPGK1_AAVS1.zarr",
    # NOTE(review): debug-only override; fitness() in the GA does not pass this argument.
    num_region_per_sample=202,
)
score = ct.predict()
print(score)

Running prediction for HEK293T with run ID 82fa1969-ef31-4313-a16b-f4c73c054850...


[W::hts_idx_load3] The index file is older than the data file: ../resource/hg38.archetype_motifs.v1.0.bed.gz.tbi


Read 282 motifs from main zarr.
Written increment 'activation_82fa1969-ef31-4313-a16b-f4c73c054850' to ./data/zarr/HEK293T_hPGK1_AAVS1.zarr/added/activation_82fa1969-ef31-4313-a16b-f4c73c054850 with shape (24, 282)
TSS and dummy expression annotated for 'added/activation_82fa1969-ef31-4313-a16b-f4c73c054850' with celltype 'HEK293T'.




Load ckpt from ./data/checkpoints/finetune_fetal_adult_leaveout_astrocyte.checkpoint-best.pth
Load state_dict by model_key = model


/gpfs1/tangfuchou_pkuhpc/tangfuchou_coe/jiangzh/software/anaconda3/envs/get/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /gpfs1/tangfuchou_pkuhpc/tangfuchou_coe/jiangzh/soft ...


Initial number of peaks: 101589
Total peaks after adding: 101613
No 'deleted' group found in Zarr file. No peaks removed.
Loaded region motifs for celltype HEK293T: 101613 peaks


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 41.57it/s]

Leave out chromosomes: []
Input chromosomes: ['chr1', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr2', 'chr20', 'chr21', 'chr22', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chrX']



/gpfs1/tangfuchou_pkuhpc/tangfuchou_coe/jiangzh/software/anaconda3/envs/get/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:362: The dirpath has changed from '/mnt/e/SHARE/cellTransformer/GET/data/get_output/HEK293T/regionEmb_head_finetune_binary/checkpoints' to '/gpfs1/tangfuchou_pkuhpc/tangfuchou_coe/jiangzh/cellTransformer/GET/data/get_tmp/HEK293T/intervention_82fa1969-ef31-4313-a16b-f4c73c054850/checkpoints', therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded.
/gpfs1/tangfuchou_pkuhpc/tangfuchou_coe/jiangzh/software/anaconda3/envs/get/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=47` in the `DataLoader` to improve performance.


Predicting: |                                                                                                 …

1.6332805156707764
