In [None]:
import requests
import json
import time
import numpy as np
import os
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import load_model
import pandas as pd
from tqdm import tqdm
import torch
from Bio import SeqIO
from util import *
from framepool import *
# Importing utility functions from the original code
# Assuming these functions are defined in your util.py module
# from util import recover_seq, rev_rna_vocab, encode_seq_framepool, one_hot_all_motif
# from framepool import load_framepool

# Local helper functions. Because they are defined after `from util import *`,
# they take precedence over any same-named helpers exported by util.py.
def reverse_complement(sequence):
    """Return the reverse complement of a DNA sequence.

    Case is preserved for A/C/G/T (``"a"`` -> ``"t"``). ``N``, ``n`` and any
    other character (including RNA ``U``) map to ``"N"``.
    """
    complement = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C',
                  'a': 't', 't': 'a', 'c': 'g', 'g': 'c', 'N': 'N', 'n': 'N'}
    return ''.join(complement.get(base, 'N') for base in reversed(sequence))

def one_hot(seq):
    """One-hot encode a batch of nucleotide sequences.

    Args:
        seq: list-like of sequence strings, or a ``tf.Tensor`` of strings.

    Returns:
        ``tf.float32`` tensor of shape ``(len(seq), len(seq[0]), 4)`` with
        channels ordered A, C, G, T (either case). Any other character, e.g.
        ``N``, encodes as an all-zero row. Every sequence is encoded against
        the first sequence's length: longer ones are truncated, shorter ones
        are zero-padded. An empty batch returns a ``(0, 0, 4)`` tensor.
    """
    if isinstance(seq, tf.Tensor):
        # Elements of a string tensor come back from .numpy() as bytes;
        # decode them to text before indexing characters.
        seq = [
            s.decode("utf-8", errors="replace") if isinstance(s, bytes) else str(s)
            for s in seq.numpy()
        ]

    num_seqs = len(seq)
    if num_seqs == 0:
        return tf.zeros((0, 0, 4), dtype=tf.float32)

    seq_len = len(seq[0])
    seqindex = {'A': 0, 'C': 1, 'G': 2, 'T': 3, 'a': 0, 'c': 1, 'g': 2, 't': 3}
    seq_vec = np.zeros((num_seqs, seq_len, 4), dtype=np.float32)
    for i, thisseq in enumerate(seq):
        # -1 marks characters without a channel; those positions stay all-zero.
        codes = np.fromiter(
            (seqindex.get(base, -1) for base in thisseq[:seq_len]), dtype=np.int64
        )
        positions = np.flatnonzero(codes >= 0)
        seq_vec[i, positions, codes[positions]] = 1.0

    return tf.convert_to_tensor(seq_vec, dtype=tf.float32)

class GeneInfoRetriever:
    """Ensembl REST client plus helpers that splice generated 5' UTRs into
    cached promoter sequences.

    Gene records are cached as JSON in ``./.cache/<gene_symbol>_info.json``;
    ``./outputs/`` is created for downstream results.
    """

    def __init__(self, timeout=30):
        """Configure the API endpoint and create the cache/output directories.

        Args:
            timeout: seconds to wait for each HTTP request before giving up.
        """
        self.base_url = "https://rest.ensembl.org"
        self.headers = {"Content-Type": "application/json"}
        self.sleep_time = 0.5  # Respect Ensembl API rate limits
        # Without a timeout a stalled connection would block the run forever.
        self.timeout = timeout

        # Create cache directory if it doesn't exist
        os.makedirs('./.cache/', exist_ok=True)
        os.makedirs('./outputs/', exist_ok=True)

    def _make_request(self, endpoint):
        """GET ``base_url + endpoint`` and return the decoded JSON, or None.

        Sleeps ``sleep_time`` seconds after each completed request to stay
        under Ensembl's rate limit. Non-200 responses, network errors,
        timeouts and undecodable bodies are printed and mapped to ``None``.
        """
        url = self.base_url + endpoint
        try:
            response = requests.get(url, headers=self.headers, timeout=self.timeout)
            time.sleep(self.sleep_time)
            if response.status_code == 200:
                return response.json()
            print(f"Error: {response.status_code} - {response.text}")
            return None
        except Exception as e:
            print(f"Request error: {e}")
            return None

    def get_gene_id(self, gene_symbol, species="homo_sapiens"):
        """Return the Ensembl gene ID for ``gene_symbol``, or None if not found."""
        endpoint = f"/lookup/symbol/{species}/{gene_symbol}"
        response = self._make_request(endpoint)
        return response.get("id") if response else None

    def get_gene_coordinates(self, gene_id):
        """Return ``{chromosome, start, end, strand}`` for ``gene_id``, or None."""
        endpoint = f"/lookup/id/{gene_id}?expand=1"
        response = self._make_request(endpoint)
        if response:
            return {
                "chromosome": response.get("seq_region_name"),
                "start": response.get("start"),
                "end": response.get("end"),
                "strand": response.get("strand")
            }
        return None

    def get_tss_and_utr(self, gene_id):
        """Return TSS and 5' UTR information for the gene's representative transcript.

        The transcript chosen is the one flagged canonical, else the first
        protein-coding one, else the first listed.

        Returns:
            dict with ``tss`` (genomic position), ``strand`` (1 / -1),
            ``chromosome``, ``five_prime_utr`` (``{"start", "end"}`` of the
            TSS-proximal 5' UTR segment, or None) and ``transcript_id``;
            None if the lookup fails or there are no transcripts.
        """
        endpoint = f"/lookup/id/{gene_id}?expand=1&utr=1"
        response = self._make_request(endpoint)
        if not response or "Transcript" not in response:
            return None

        transcripts = response["Transcript"]
        canonical_transcript = next(
            (t for t in transcripts if t.get("is_canonical", 0) == 1), None
        )
        if not canonical_transcript:
            canonical_transcript = next(
                (t for t in transcripts if t.get("biotype") == "protein_coding"), None
            )
        if not canonical_transcript:
            canonical_transcript = transcripts[0] if transcripts else None
        if not canonical_transcript:
            return None

        # TSS is the transcript's 5'-most base: its start on +, its end on -.
        strand = canonical_transcript.get("strand")
        tss = canonical_transcript["start"] if strand == 1 else canonical_transcript["end"]

        # The UTR list may hold several five_prime_UTR segments (e.g. a spliced
        # UTR). Pick the one nearest the TSS (lowest start on +, highest end
        # on -) rather than relying on list order.
        segments = [
            utr for utr in (canonical_transcript.get("UTR") or [])
            if utr.get("object_type") == "five_prime_UTR"
            and utr.get("start") is not None
            and utr.get("end") is not None
        ]
        five_prime_utr = None
        if segments:
            if strand == 1:
                first = min(segments, key=lambda utr: utr["start"])
            else:
                first = max(segments, key=lambda utr: utr["end"])
            five_prime_utr = {"start": first["start"], "end": first["end"]}

        # Verify TSS matches 5' UTR start
        if five_prime_utr:
            expected_tss = five_prime_utr["start"] if strand == 1 else five_prime_utr["end"]
            if expected_tss != tss:
                print(f"Warning: Adjusting TSS from {tss} to match 5' UTR {'start' if strand == 1 else 'end'} ({expected_tss})")
                tss = expected_tss

        return {
            "tss": tss,
            "strand": strand,
            "chromosome": canonical_transcript.get("seq_region_name"),
            "five_prime_utr": five_prime_utr,
            "transcript_id": canonical_transcript.get("id")
        }

    def get_promoter_sequence(self, gene_id, upstream=8000, downstream=4000, species="human", tss_info=None):
        """Fetch the window around the TSS in transcript (5'->3') orientation.

        The window is ``upstream`` bases 5' of the TSS followed by
        ``downstream`` bases starting at the TSS. The TSS therefore sits at
        index ``upstream`` on both strands, unless a + strand window is
        clipped at genomic position 1. Minus-strand genes are requested with
        strand ``-1``.

        Args:
            gene_id: Ensembl gene ID.
            upstream: bases before the TSS.
            downstream: bases from the TSS onward.
            species: Ensembl species name/alias used by the sequence endpoint.
            tss_info: result of ``get_tss_and_utr(gene_id)``; fetched when None.

        Returns:
            ``(sequence, sequence_coords)``. Returns ``(None, None)`` when the
            TSS lookup fails, and ``(None, coords)`` when only the sequence
            request fails.
        """
        if tss_info is None:
            tss_info = self.get_tss_and_utr(gene_id)
        if not tss_info:
            return None, None

        chromosome = tss_info["chromosome"]
        strand = tss_info["strand"]
        tss_position = tss_info["tss"]

        # Calculate region based on strand; the minus window mirrors the plus
        # one so both place the TSS at index `upstream` of the returned sequence.
        if strand == 1:
            seq_start = tss_position - upstream
            seq_end = tss_position + downstream - 1
        else:
            seq_start = tss_position - downstream + 1
            seq_end = tss_position + upstream

        seq_start = max(1, seq_start)

        # Store sequence coordinates
        sequence_coords = {
            "chromosome": chromosome,
            "start": seq_start,
            "end": seq_end,
            "strand": 1 if strand == 1 else -1
        }

        # Validate 5' UTR inclusion
        if tss_info["five_prime_utr"]:
            utr_start = tss_info["five_prime_utr"]["start"]
            utr_end = tss_info["five_prime_utr"]["end"]
            if not (seq_start <= utr_start <= seq_end and seq_start <= utr_end <= seq_end):
                print(f"Warning: 5' UTR ({utr_start}-{utr_end}) not fully within sequence ({seq_start}-{seq_end})")

        # Get sequence
        strand_str = "1" if strand == 1 else "-1"
        endpoint = f"/sequence/region/{species}/{chromosome}:{seq_start}..{seq_end}:{strand_str}"
        response = self._make_request(endpoint)
        return response.get("seq") if response else None, sequence_coords

    def get_gene_info(self, gene_symbol, species="homo_sapiens", output_json="gene_info.json"):
        """Return promoter sequence, TSS, 5' UTR and coordinates for ``gene_symbol``.

        Results are cached in ``./.cache/<gene_symbol>_info.json`` (keyed by
        symbol only). A cached file is returned as-is.

        Args:
            gene_symbol: gene symbol, e.g. ``"MYOC"``.
            species: Ensembl species name used for every lookup.
            output_json: unused; kept for call compatibility.

        Returns:
            the gene-info dict, or ``{"error": message}`` on failure.
        """
        cache_file = os.path.join('./.cache/', f"{gene_symbol}_info.json")

        if not os.path.exists(cache_file):
            # Get gene ID
            gene_id = self.get_gene_id(gene_symbol, species)
            if not gene_id:
                return {"error": f"Gene {gene_symbol} not found"}

            # Get TSS and 5' UTR
            tss_info = self.get_tss_and_utr(gene_id)
            if not tss_info:
                return {"error": "Could not retrieve TSS or transcript information"}

            # Reuse tss_info instead of querying the transcript a second time.
            promoter_sequence, sequence_coords = self.get_promoter_sequence(
                gene_id, species=species, tss_info=tss_info
            )
            if not promoter_sequence:
                return {"error": "Could not retrieve promoter sequence"}

            # Compile gene information
            gene_info = {
                "gene_symbol": gene_symbol,
                "gene_id": gene_id,
                "promoter_sequence": promoter_sequence,
                "sequence_length": len(promoter_sequence),
                "sequence_coordinates": sequence_coords,
                "tss": {
                    "chromosome": tss_info["chromosome"],
                    "position": tss_info["tss"],
                    "strand": "+" if tss_info["strand"] == 1 else "-"
                },
                "five_prime_utr": tss_info["five_prime_utr"],
                "transcript_id": tss_info["transcript_id"]
            }

            # Save to JSON
            try:
                with open(cache_file, "w") as f:
                    json.dump(gene_info, f, indent=2)
                print(f"Saved gene information to {cache_file}")
            except Exception as e:
                print(f"Error saving JSON: {e}")
        else:
            with open(cache_file, "r") as f:
                gene_info = json.load(f)

        return gene_info

    def replace_utr_in_sequence(self, gene_info_file, generated_utrs, target_length=10500, output_prefix="modified_sequence", write_json=False, verbose=False):
        """
        Replace original 5' UTR with generated UTRs, ensuring target_length output.

        ``gene_info_file`` is a JSON record as written by ``get_gene_info``.
        Its ``promoter_sequence`` is in transcript orientation on both strands,
        because ``get_promoter_sequence`` requests minus-strand genes with
        strand -1. Each generated UTR is therefore spliced in unchanged in
        place of the annotated 5' UTR segment. Results longer than
        ``target_length`` are trimmed at the 3' (downstream) end.

        Args:
            gene_info_file: path of the cached gene-info JSON.
            generated_utrs: candidate UTR strings; only lengths 64-128 are used.
            target_length: exact length of every returned sequence.
            output_prefix: prefix of per-UTR JSON files when ``write_json``.
            write_json: also write metadata for each result to
                ``{output_prefix}_{gene_symbol}_utr_{k}.json``.
            verbose: print skipped UTRs and the original UTR length.

        Returns:
            list of modified sequences in input order. A UTR is omitted when
            its length is outside 64-128 or when the result is shorter than
            ``target_length``. Any error is printed and yields ``[]``.
        """
        gene_info = {}
        try:
            # Read gene information
            with open(gene_info_file, "r") as f:
                gene_info = json.load(f)

            original_sequence = gene_info["promoter_sequence"]
            strand = gene_info["tss"]["strand"]
            sequence_coords = gene_info["sequence_coordinates"]
            seq_start = sequence_coords["start"]
            seq_end = sequence_coords["end"]
            five_prime_utr = gene_info["five_prime_utr"]
            gene_symbol = gene_info["gene_symbol"]
            transcript_id = gene_info["transcript_id"]

            if not five_prime_utr:
                print(f"Error: No 5' UTR information available for {gene_symbol}")
                return []

            # Map the UTR's genomic span to sequence indices. On the minus
            # strand index 0 is genomic position seq_end, so the UTR's genomic
            # end (its TSS side) is the first UTR base in the sequence.
            if strand == "+":
                utr_start_genomic = five_prime_utr["start"]
                utr_end_genomic = five_prime_utr["end"]
                utr_start_seq = utr_start_genomic - seq_start
                utr_end_seq = utr_end_genomic - seq_start
            else:
                utr_start_genomic = five_prime_utr["end"]  # TSS
                utr_end_genomic = five_prime_utr["start"]
                utr_start_seq = seq_end - utr_start_genomic
                utr_end_seq = seq_end - utr_end_genomic

            # Validate UTR positions: both must be real indices into the sequence.
            seq_length = len(original_sequence)
            if not (0 <= utr_start_seq <= utr_end_seq < seq_length):
                print(f"Error: 5' UTR coordinates (seq indices {utr_start_seq}-{utr_end_seq}) out of sequence bounds (0-{seq_length}) for {gene_symbol}")
                return []

            original_utr_length = abs(utr_end_genomic - utr_start_genomic) + 1
            if verbose:
                print(f"Original 5' UTR length for {gene_symbol}: {original_utr_length} nt")

            modified_sequences = []
            for i, new_utr in enumerate(generated_utrs):
                new_utr_length = len(new_utr)
                if not 64 <= new_utr_length <= 128:
                    if verbose:
                        print(f"Warning: Generated UTR {i+1} length ({new_utr_length}) outside 64-128nt range for {gene_symbol}")
                    continue

                # Both strands are stored 5'->3', so no reverse complement here.
                new_sequence = (
                    original_sequence[:utr_start_seq]
                    + new_utr
                    + original_sequence[utr_end_seq + 1:]
                )
                if len(new_sequence) < target_length:
                    if verbose:
                        print(f"Error: Sequence too short ({len(new_sequence)} nt) after UTR replacement for {gene_symbol}")
                    continue

                # Fresh copy per UTR so one UTR's trimming never leaks into the
                # coordinates reported for the next one.
                coords = dict(sequence_coords)
                trim_amount = len(new_sequence) - target_length
                if trim_amount > 0:
                    # Drop bases from the 3' end, keeping the promoter and the
                    # UTR at their original offsets. The dropped downstream
                    # bases are the highest genomic positions on + and the
                    # lowest on -.
                    new_sequence = new_sequence[:target_length]
                    if strand == "+":
                        coords["end"] = seq_end - trim_amount
                    else:
                        coords["start"] = seq_start + trim_amount

                if strand == "+":
                    new_utr_end_genomic = utr_start_genomic + new_utr_length - 1
                else:
                    new_utr_end_genomic = utr_start_genomic - new_utr_length + 1

                # Store modified sequence and metadata
                modified_info = {
                    "gene_symbol": gene_symbol,
                    "transcript_id": transcript_id,
                    "modified_sequence": new_sequence,
                    "sequence_length": len(new_sequence),
                    "sequence_coordinates": coords,
                    "tss": gene_info["tss"],
                    "five_prime_utr": {
                        "start": utr_start_genomic,
                        "end": new_utr_end_genomic,
                        "sequence": new_utr
                    },
                    "original_utr_length": original_utr_length,
                    "new_utr_length": new_utr_length,
                    "utr_index": i + 1
                }

                # Save to JSON
                if write_json:
                    output_file = f"{output_prefix}_{gene_symbol}_utr_{i+1}.json"
                    try:
                        output_dir = os.path.dirname(output_file)
                        if output_dir:  # os.makedirs("") raises for a bare file name
                            os.makedirs(output_dir, exist_ok=True)
                        with open(output_file, "w") as f:
                            json.dump(modified_info, f, indent=2)
                        print(f"Saved modified sequence {i+1} for {gene_symbol} to {output_file}")
                    except Exception as e:
                        print(f"Error saving modified sequence {i+1} for {gene_symbol}: {e}")

                modified_sequences.append(new_sequence)

            return modified_sequences

        except Exception as e:
            # gene_info is pre-bound so a failed open/parse still reports cleanly.
            label = gene_info.get("gene_symbol", "unknown") if isinstance(gene_info, dict) else "unknown"
            print(f"Error processing UTR replacement for {label}: {e}")
            return []

    def replace_utr_in_multiple_sequences(self, gene_symbols, generated_utrs, target_length=10500, cache_dir="./.cache", output_prefix="modified_sequence", verbose=False):
        """
        Replace 5' UTRs for multiple genes with generated UTRs.

        Reads ``{cache_dir}/{gene_symbol}_info.json`` for each gene and calls
        ``replace_utr_in_sequence`` with every generated UTR.

        Returns:
            flat list ordered gene-major (all UTRs for the first gene, then
            the next gene). Genes without a cache file, and UTRs skipped by
            ``replace_utr_in_sequence``, contribute nothing. The list can
            therefore be shorter than ``len(gene_symbols) * len(generated_utrs)``.
        """
        all_modified_sequences = []
        n_utrs = len(generated_utrs)
        n_genes = len(gene_symbols)

        for gene_symbol in gene_symbols:
            json_file = os.path.join(cache_dir, f"{gene_symbol}_info.json")
            if not os.path.exists(json_file):
                print(f"Error: Gene info file {json_file} not found")
                continue

            if verbose:
                print(f"\nProcessing gene: {gene_symbol}")
            modified_sequences = self.replace_utr_in_sequence(
                gene_info_file=json_file,
                generated_utrs=generated_utrs,
                target_length=target_length,
                output_prefix=os.path.join(cache_dir, output_prefix),
                verbose=verbose
            )

            if modified_sequences:
                all_modified_sequences.extend(modified_sequences)
            elif verbose:
                print(f"No modified sequences generated for {gene_symbol}")

        expected_count = n_utrs * n_genes
        actual_count = len(all_modified_sequences)
        if verbose:
            print(f"\nGenerated {actual_count} modified sequences (expected: {expected_count})")

        return all_modified_sequences

def convert_model(model_, seq_len=10500, n_padded_features=6):
    """Rebuild ``model_`` as a single-input Keras model over one-hot DNA.

    The layers of ``model_`` after ``model_.layers[0]`` are re-applied, in
    list order, to a fresh ``Input(shape=(seq_len, 4))``:

    * ``InputLayer`` entries are skipped.
    * A ``Concatenate`` layer is not called. Instead the running tensor is
      right-padded with ``n_padded_features`` zero columns. The padding spec
      is 2x2, so the tensor must be rank 2 ``(batch, features)`` at that
      point. NOTE(review): this presumably substitutes zeros for a second
      model input that was concatenated there; confirm against the source
      model.

    Assumes ``model_.layers`` is a single linear chain in execution order.

    Args:
        model_: loaded Keras model to convert.
        seq_len: length of the one-hot input sequence.
        n_padded_features: number of zero features that replace the concatenation.

    Returns:
        A new ``tf.keras.Model`` compiled with MSE loss and the Adam optimizer.

    Raises:
        ValueError: if no layer after ``layers[0]`` (other than ``InputLayer``s)
            exists, so there is no output to build a model from.
    """
    inputs = tf.keras.layers.Input(shape=(seq_len, 4))
    paddings = tf.constant([[0, 0], [0, n_padded_features]])
    x = inputs
    for layer in model_.layers[1:]:
        if isinstance(layer, tf.keras.layers.Concatenate):
            x = tf.pad(x, paddings, 'CONSTANT')
        elif not isinstance(layer, tf.keras.layers.InputLayer):
            x = layer(x)

    if x is inputs:
        raise ValueError("convert_model: model_ has no layers to re-apply after its input layer")

    model = tf.keras.Model(inputs=inputs, outputs=x)
    model.compile(loss="mse", optimizer="adam")
    return model

def select_best(scores, seqs, gc_control=False, GC=-1):
    """For every candidate slot, pick the highest-scoring entry across steps.

    Args:
        scores: ``scores[step][i]`` is the score of candidate ``i`` at ``step``.
        seqs: ``seqs[step][i]`` is the matching candidate sequence.
        gc_control, GC: accepted for interface compatibility; unused.

    Returns:
        ``(selected_seqs, selected_scores)``: one entry per slot ``i``, taken
        from the step with the largest score. Every step, including step 0,
        is considered, and ties keep the earliest step. Empty ``scores``
        gives ``([], [])``.
    """
    if len(scores) == 0:
        return [], []

    selected_scores = []
    selected_seqs = []
    for i in range(len(scores[0])):
        best_step = 0
        for step in range(1, len(scores)):
            # Strict ">" keeps the earliest maximum.
            if scores[step][i] > scores[best_step][i]:
                best_step = step

        selected_scores.append(scores[best_step][i])
        selected_seqs.append(seqs[best_step][i])

    return selected_seqs, selected_scores

def select_best_per_gene(scores, seqs, gc_control=False, GC=-1):
    """For every gene, pick the step whose scores have the highest mean.

    Args:
        scores: ``scores[step][gene]`` is an array-like of per-candidate
            scores for ``gene`` at ``step``; its mean is the step's score.
        seqs: indexed as ``seqs[step][gene]`` for the returned sequences.
            NOTE(review): callers pass per-step candidate lists here, so this
            pairs a gene index with a candidate index; verify the intent.
        gc_control, GC: accepted for interface compatibility; unused.

    Returns:
        ``(selected_seqs, selected_scores)`` with one entry per gene. Every
        step, including step 0, is considered, and ties keep the earliest
        step. Empty ``scores`` gives ``([], [])``.
    """
    if len(scores) == 0:
        return [], []

    selected_scores = []
    selected_seqs = []
    for gene in range(len(scores[0])):
        step_means = [np.mean(step_scores[gene]) for step_scores in scores]
        best_step = 0
        for step in range(1, len(step_means)):
            if step_means[step] > step_means[best_step]:
                best_step = step

        selected_scores.append(step_means[best_step])
        selected_seqs.append(seqs[best_step][gene])

    return selected_seqs, selected_scores

def optimize_gene_set(initial_genes, target_genes, batch_size=100, steps=50, lr=0.001, verbose=False):
    """Optimize 5' UTR sequences for a set of genes and evaluate on target genes."""
    # Constants and model paths
    UTR_LEN = 128
    DIM = 40
    N_GENES = 8
    BATCH_SIZE = batch_size
    STEPS = steps

    SEQ_BATCH = N_GENES
    UTR_LEN = 128
    DIM = 40
    LR = lr  # Convert to the exponential format used in the original code
    gpath = './../../models/checkpoint_3000.h5'  # GAN model
    exp_path = './../../models/humanMedian_trainepoch.11-0.426.h5'  # Expression model
    mrl_path = './../../models/utr_model_combined_residual_new.h5'  # MRL model

    # Set device for TensorFlow
    if torch.cuda.is_available():
        os.environ['CUDA_VISIBLE_DEVICES'] = '0'
        device = 'cuda'
    else:
        device = 'cpu'
    
    model = tf.keras.models.load_model(exp_path)

    model = convert_model(model)

    wgan = tf.keras.models.load_model(gpath)

    """
    Data:
    """

    gene_names = ["MYOC", "TIGD4", "ATP6V1B2", "TAGLN", "COX7A2L", "IFNGR2", "TNFRSF21", "SETD6"]

    noise = tf.Variable(tf.random.normal(shape=[BATCH_SIZE,40]))

    tf.random.set_seed(25)

    diffs = []
    init_exps = []

    opt_exps = []

    orig_vals = []

    noise = tf.Variable(tf.random.normal(shape=[BATCH_SIZE,40]))
    noise_small = tf.random.normal(shape=[BATCH_SIZE,40],stddev=1e-5)

    optimizer = tf.keras.optimizers.Adam(learning_rate=LR)

    '''
    Optimization takes place here.
    '''

    bind_scores_list = []
    bind_scores_means = []
    sequences_list = []

    means = []
    maxes = []

    iters_ = []

    OPTIMIZE = True

    DNA_SEL = False

    retriever = GeneInfoRetriever()
    refs = []
    for i in range(len(gene_names)):
        output_json = f"{gene_names[i]}_info.json"

        if not os.path.exists(os.path.join('./.cache/',output_json)):

            # Retrieve gene information
            gene_info = retriever.get_gene_info(gene_names[i], output_json=output_json)

            if "error" in gene_info:
                print(f"Error: {gene_info['error']}")
            else:
                refs.append(gene_info["promoter_sequence"])    
        else:
            with open(os.path.join('./.cache/',output_json), "r") as f:
                gene_info = json.load(f)
                refs.append(gene_info["promoter_sequence"])

    sequences_init = wgan(noise)

    gen_seqs_init = sequences_init.numpy().astype('float')

    seqs_gen_init = recover_seq(gen_seqs_init, rev_rna_vocab)

    seqs_init = retriever.replace_utr_in_multiple_sequences(gene_names, seqs_gen_init, target_length=10500, cache_dir="./.cache", output_prefix="modified_sequence")

    tf.shape(seqs_init)

    seqs_init = one_hot(seqs_init)

    pred_init = model(seqs_init) 

    pred_init = tf.reshape(pred_init,(SEQ_BATCH,-1))

    initial_exp_per_gene = tf.reduce_mean(pred_init,axis=1)

    init_t = tf.reduce_mean(pred_init,axis=0)

    init_t = init_t.numpy().astype('float')


    ########### MRL and TE check before optimization ###################

    # seqs_mrl = tf.convert_to_tensor(np.array([encode_seq_framepool(seq) for seq in seqs_gen_init]),dtype=tf.float32)
    # seqs_te =  torch.transpose(torch.tensor(np.array(one_hot_all_motif(seqs_gen_init),dtype=np.float32)),2,1).float().to(device)

    # mrl_preds_init = mrl_model(seqs_mrl).numpy().astype('float')
    # te_preds_init = te_model.forward(seqs_te).cpu().data.numpy()

    ####################################################################

    STEPS = STEPS

    seqs_collection = []
    seqs_collection_genes = []
    scores_collection = []
    scores_collection_genes = []
    if OPTIMIZE:

        
        iter_ = 0
        for opt_iter in tqdm(range(STEPS)):
            
            with tf.GradientTape() as gtape:
                gtape.watch(noise)
                
                sequences = wgan(noise)

                seqs_gen = recover_seq(sequences, rev_rna_vocab)
                seqs_collection.append(seqs_gen)

                g1_ = tf.zeros_like(sequences)

                scores_collection_temp = []
                means_temp = []
                maxes_temp = []


                for gene in gene_names:

                    seqs_dna = retriever.replace_utr_in_sequence(f"./.cache/{gene}_info.json", seqs_gen, target_length=10500, output_prefix="modified_sequence")               
                
                    seqs = one_hot(seqs_dna)
                    
                    with tf.GradientTape() as ptape:
                        ptape.watch(seqs)

                        pred =  model(seqs)
                        t = tf.reshape(pred,(-1))
                        mx = np.amax(t.numpy().astype('float'),axis=0)
                        mx = np.max(mx)
                        

                        scores_collection_temp.append(t.numpy().astype('float'))
                        nt = t.numpy().astype('float')
                        maxes_temp.append(mx)
                        means_temp.append(np.sum(t)/BATCH_SIZE)

                    g1 = ptape.gradient(pred,seqs)
                    g1 = tf.math.scalar_mul(-1.0, g1)
                    g1 = tf.slice(g1,[0,7000,0],[-1,128,-1])

                    tmp_g = g1.numpy().astype('float')
                    tmp_seqs = seqs_gen

                    # # Before the loop
                    # print("Length of tmp_seqs:", len(tmp_seqs))
                    # print("Shape of tmp_g:", tmp_g.shape)

                    # Initialize tmp_lst with correct size
                    batch_size = min(len(tmp_seqs), tmp_g.shape[0])
                    tmp_lst = np.zeros(shape=(batch_size, 128, 5))

                    # Loop with safe range
                    for i in range(batch_size):
                        len_ = min(len(tmp_seqs[i]), tmp_g.shape[1])  # Prevent exceeding tmp_g's dimensions
                        edited_g = tmp_g[i][:len_, :]
                        edited_g = np.pad(edited_g, ((0, 128-len_), (0, 1)), 'constant')
                        tmp_lst[i] = edited_g

                    g1 = tf.convert_to_tensor(tmp_lst, dtype=tf.float32)

                    g1_ = tf.math.add(g1, g1_)

                scores_collection.append(np.mean(scores_collection_temp,axis=0))
                scores_collection_genes.append(scores_collection_temp)
                means.append(np.mean(means_temp))
                maxes.append(np.max(maxes_temp))

                g2 = gtape.gradient(sequences,noise,output_gradients=g1_)


            a1 = g2 + noise_small
            change = [(a1,noise)]

            optimizer.apply_gradients(change)

            iters_.append(iter_)
            iter_ += 1

        sequences_opt = wgan(noise)

        gen_seqs_opt = sequences_opt.numpy().astype('float')

        seqs_gen_opt = recover_seq(gen_seqs_opt, rev_rna_vocab)

        seqs_opt = retriever.replace_utr_in_multiple_sequences(gene_names, seqs_gen_opt, target_length=10500, cache_dir="./.cache", output_prefix="modified_sequence")

        seqs_opt = one_hot(seqs_opt)

        pred_opt = model(seqs_opt)

        pred_opt = tf.reshape(pred_opt,(SEQ_BATCH,-1))


        t = tf.reduce_mean(pred_opt,axis=0)
        opt_t = t.numpy().astype('float')

    ########### MRL and TE check after optimization ####################

    # seqs_mrl = tf.convert_to_tensor(np.array([encode_seq_framepool(seq) for seq in seqs_gen_opt]),dtype=tf.float32)
    # seqs_te =  torch.transpose(torch.tensor(np.array(one_hot_all_motif(seqs_gen_opt),dtype=np.float32)),2,1).float().to(device)

    # mrl_preds_opt = mrl_model(seqs_mrl).numpy().astype('float')
    # te_preds_opt = te_model.forward(seqs_te).cpu().data.numpy()

    ####################################################################

    best_seqs, best_scores = select_best(scores_collection, seqs_collection)
    _, best_scores_per_gene = select_best_per_gene(scores_collection_genes, seqs_collection)

    with open('./outputs/mul_init_exps.txt', 'w') as f:
        for item in init_t:
            f.write(f'{item}\n')

    with open('./outputs/mul_best_exps.txt', 'w') as f:
        for item in best_scores:
            f.write(f'{item}\n')

    with open('./outputs/mul_opt_exps.txt', 'w') as f:
        for item in opt_t:
            f.write(f'{item}\n')

    with open('./outputs/mul_best_seqs.txt', 'w') as f:
        for item in best_seqs:
            f.write(f'{item}\n')

    with open('./outputs/mul_init_seqs.txt', 'w') as f:
        for item in seqs_gen_init:
            f.write(f'{item}\n')

    # Compute average Log TPM per gene
    init_log_tpm_initial = tf.reduce_mean(pred_init, axis=1).numpy().astype('float')
    opt_log_tpm_initial = tf.reduce_mean(pred_opt, axis=1).numpy().astype('float')
    opt_log_tpm_initial = best_scores

    # Compute overall average Log TPM across target genes
    avg_init_log_tpm_initial = np.average(init_log_tpm_initial)
    avg_opt_log_tpm_initial = np.average(opt_log_tpm_initial)

    # Convert Log TPM to TPM for percentage improvement
    # Assuming Log TPM is base-10 (common for TPM), TPM = 10^LogTPM
    avg_init_tpm_initial = np.power(10, avg_init_log_tpm_initial)
    avg_opt_tpm_initial = np.power(10, avg_opt_log_tpm_initial)

    # Compute improvement
    log_tpm_diff_initial = avg_opt_log_tpm_initial - avg_init_log_tpm_initial
    tpm_improvement_initial = avg_opt_tpm_initial - avg_init_tpm_initial
    # Percentage improvement based on TPM: ((opt - init) / init) * 100
    if avg_init_tpm_initial != 0:  # Avoid division by zero
        tpm_percent_change_initial = (tpm_improvement_initial / avg_init_tpm_initial) * 100
    else:
        tpm_percent_change_initial = float('inf') if tpm_improvement_initial > 0 else 0.0

    # Handle negative and positive percentages
    percent_str = f"{tpm_percent_change_initial:.2f}%"
    if tpm_percent_change_initial < 0:
        percent_str = f"{tpm_percent_change_initial:.2f}% (decrease)"
    elif tpm_percent_change_initial > 0:
        percent_str = f"+{tpm_percent_change_initial:.2f}% (increase)"

    # Print evaluation results
    print("\nEvaluation of Optimization on Original Genes (Log TPM):")
    print("\nExpression Levels (Log TPM):")
    print(f"  Average Initial Log TPM: {avg_init_log_tpm_initial:.4f} (TPM: {avg_init_tpm_initial:.4f})")
    print(f"  Average Optimized Log TPM: {avg_opt_log_tpm_initial:.4f} (TPM: {avg_opt_tpm_initial:.4f})")
    print(f"  Log TPM Difference: {log_tpm_diff_initial:.4f}")
    print(f"  TPM Improvement: {tpm_improvement_initial:.4f} ({percent_str})")


    print("Genes:")
    print(gene_names)
    print(f"Average Initial Expression: {np.average(init_t)}")
    print(f"Best Expression: {np.average(best_scores)}")



    """#TODO Evaluate Optimization on Target Genes """

    # Define a new list of target genes
    target_genes = ["ANTXR2", "NFIL3", "UNC13D", "DHRS2", "RPS13", "HBD", "METAP1D", "NCALD"]

    # Initialize the retriever for gene information
    retriever = GeneInfoRetriever()

    # Retrieve promoter sequences for target genes
    target_refs = []
    for gene in target_genes:
        output_json = f"{gene}_info.json"
        cache_path = os.path.join('./.cache/', output_json)
        
        if not os.path.exists(cache_path):
            # Retrieve gene information
            gene_info = retriever.get_gene_info(gene, output_json=output_json)
            if "error" in gene_info:
                print(f"Error retrieving info for {gene}: {gene_info['error']}")
                target_refs.append(None)  # Handle errors gracefully
            else:
                target_refs.append(gene_info["promoter_sequence"])
        else:
            with open(cache_path, "r") as f:
                gene_info = json.load(f)
                target_refs.append(gene_info["promoter_sequence"])

    # Filter out any None entries (failed retrievals)
    valid_indices = [i for i, ref in enumerate(target_refs) if ref is not None]
    target_genes = [target_genes[i] for i in valid_indices]
    target_refs = [target_refs[i] for i in valid_indices]

    if not target_genes:
        print("No valid target genes retrieved. Exiting evaluation.")
    else:
        # Use initial and optimized sequences from the original optimization
        seqs_gen_init = seqs_gen_init  # From original gene optimization
        seqs_gen_opt = best_seqs    # From original gene optimization

        # Replace UTRs in target genes' promoter sequences with initial and optimized sequences
        seqs_init_target = retriever.replace_utr_in_multiple_sequences(
            target_genes, seqs_gen_init, target_length=10500, cache_dir="./.cache", output_prefix="target_modified_sequence"
        )
        seqs_opt_target = retriever.replace_utr_in_multiple_sequences(
            target_genes, seqs_gen_opt, target_length=10500, cache_dir="./.cache", output_prefix="target_modified_sequence"
        )

        # One-hot encode the sequences
        seqs_init_target = one_hot(seqs_init_target)
        seqs_opt_target = one_hot(seqs_opt_target)

        # Predict expression levels (Log TPM) using the model
        pred_init_target = model(seqs_init_target)
        pred_opt_target = model(seqs_opt_target)

        # Reshape predictions to (len(target_genes), -1)
        pred_init_target = tf.reshape(pred_init_target, (len(target_genes), -1))
        pred_opt_target = tf.reshape(pred_opt_target, (len(target_genes), -1))

        # Compute average Log TPM per gene
        init_log_tpm_target = tf.reduce_mean(pred_init_target, axis=1).numpy().astype('float')
        opt_log_tpm_target = tf.reduce_mean(pred_opt_target, axis=1).numpy().astype('float')

        # Compute overall average Log TPM across target genes
        avg_init_log_tpm_target = np.average(init_log_tpm_target)
        avg_opt_log_tpm_target = np.average(opt_log_tpm_target)

        # Convert Log TPM to TPM for percentage improvement
        # Assuming Log TPM is base-10 (common for TPM), TPM = 10^LogTPM
        avg_init_tpm_target = np.power(10, avg_init_log_tpm_target)
        avg_opt_tpm_target = np.power(10, avg_opt_log_tpm_target)

        # Compute improvement
        log_tpm_diff_target = avg_opt_log_tpm_target - avg_init_log_tpm_target
        tpm_improvement_target = avg_opt_tpm_target - avg_init_tpm_target
        # Percentage improvement based on TPM: ((opt - init) / init) * 100
        if avg_init_tpm_target != 0:  # Avoid division by zero
            tpm_percent_change_target = (tpm_improvement_target / avg_init_tpm_target) * 100
        else:
            tpm_percent_change_target = float('inf') if tpm_improvement_target > 0 else 0.0

        # Handle negative and positive percentages
        percent_str = f"{tpm_percent_change_target:.2f}%"
        if tpm_percent_change_target < 0:
            percent_str = f"{tpm_percent_change_target:.2f}% (decrease)"
        elif tpm_percent_change_target > 0:
            percent_str = f"+{tpm_percent_change_target:.2f}% (increase)"

        # Print evaluation results
        print("\nEvaluation of Optimization on Target Genes (Log TPM):")
        print(f"Original Genes: {gene_names}")
        print(f"Target Genes: {target_genes}")
        print("\nExpression Levels (Log TPM):")
        print(f"  Average Initial Log TPM: {avg_init_log_tpm_target:.4f} (TPM: {avg_init_tpm_target:.4f})")
        print(f"  Average Optimized Log TPM: {avg_opt_log_tpm_target:.4f} (TPM: {avg_opt_tpm_target:.4f})")
        print(f"  Log TPM Difference: {log_tpm_diff_target:.4f}")
        print(f"  TPM Improvement: {tpm_improvement_target:.4f} ({percent_str})")

        # Save evaluation results to a file
        with open('./outputs/target_genes_evaluation.txt', 'w') as f:
            f.write("Evaluation of Optimization on Target Genes (Log TPM)\n")
            f.write(f"Original Genes: {gene_names}\n")
            f.write(f"Target Genes: {target_genes}\n\n")
            f.write("Expression Levels (Log TPM):\n")
            f.write(f"  Average Initial Log TPM: {avg_init_log_tpm_target:.4f} (TPM: {avg_init_tpm_target:.4f})\n")
            f.write(f"  Average Optimized Log TPM: {avg_opt_log_tpm_target:.4f} (TPM: {avg_opt_tpm_target:.4f})\n")
            f.write(f"  Log TPM Difference: {log_tpm_diff_target:.4f}\n")
            f.write(f"  TPM Improvement: {tpm_improvement_target:.4f} ({percent_str})\n")

        # Optional: Per-gene breakdown
        print("\nPer-Gene Expression Levels (Log TPM):")
        for gene, init_log, opt_log in zip(target_genes, init_log_tpm_target, opt_log_tpm_target):
            init_tpm = np.power(10, init_log)
            opt_tpm = np.power(10, opt_log)
            tpm_diff = opt_tpm - init_tpm
            if init_tpm != 0:
                gene_percent = (tpm_diff / init_tpm) * 100
            else:
                gene_percent = float('inf') if tpm_diff > 0 else 0.0
            gene_percent_str = f"{gene_percent:.2f}%"
            if gene_percent < 0:
                gene_percent_str = f"{gene_percent:.2f}% (decrease)"
            elif gene_percent > 0:
                gene_percent_str = f"+{gene_percent:.2f}% (increase)"
            print(f"  {gene}: Initial Log TPM = {init_log:.4f} (TPM: {init_tpm:.4f}), "
                f"Optimized Log TPM = {opt_log:.4f} (TPM: {opt_tpm:.4f}), "
                f"TPM Improvement = {tpm_diff:.4f} ({gene_percent_str})")
    

    initial_target_expression = tf.cast(init_log_tpm_target, dtype=tf.float64)
    initial_expression = tf.cast(initial_exp_per_gene, dtype=tf.float64)
    optimized_expression_initial = tf.convert_to_tensor(best_scores_per_gene, dtype=tf.float64)

    optimized_expression_target = tf.cast(opt_log_tpm_target, dtype=tf.float64)

    # Store optimization results
    results = {
        'initial_genes': initial_genes,
        'target_genes': target_genes,
        'initial_expression': initial_expression,
        'initial_target_expression': initial_target_expression,
        'optimized_expression_initial': optimized_expression_initial,
        'optimized_expression_target': optimized_expression_target,
        'improvement_initial': 0.0,  # Will be calculated
        'improvement_target': 0.0,  # Will be calculated
        'percent_improvement_initial': 0.0,  # Will be calculated
        'percent_improvement_target': 0.0  # Will be calculated
    }
    
    # Calculate improvements
    results['improvement_initial'] = np.power(10,results['optimized_expression_initial']) - np.power(10,results['initial_expression'])
    results['improvement_target'] = np.power(10,results['optimized_expression_target']) - np.power(10,results['initial_target_expression'])
    results['percent_improvement_initial'] = (results['improvement_initial'] / results['initial_expression']) * 100
    results['percent_improvement_target'] = (results['improvement_target'] / results['initial_target_expression']) * 100
    
    # Save optimization data
    os.makedirs('./outputs', exist_ok=True)
    output_prefix = f"optimization_{initial_genes[0]}_{target_genes[0]}"
    
    # Save sequences
    with open(f'./outputs/{output_prefix}_init_seqs_mix.txt', 'w') as f:
        for item in seqs_gen_init:
            f.write(f'{item}\n')
    
    with open(f'./outputs/{output_prefix}_best_seqs_mix.txt', 'w') as f:
        for item in best_seqs:
            f.write(f'{item}\n')
    
    # Save expression values
    with open(f'./outputs/{output_prefix}_init_exps_mix.txt', 'w') as f:
        for item in init_t:
            f.write(f'{item}\n')
    
    with open(f'./outputs/{output_prefix}_best_exps_mix.txt', 'w') as f:
        for item in best_scores:
            f.write(f'{item}\n')
    
    with open(f'./outputs/{output_prefix}_opt_exps_mix.txt', 'w') as f:
        for item in opt_t:
            f.write(f'{item}\n')
    
    if verbose:
        print(f"\nOptimization completed for {len(initial_genes)} initial genes and {len(target_genes)} target genes")
        print(f"Initial genes: {initial_genes}")
        print(f"Target genes: {target_genes}")
        print(f"Initial expression (training genes): {results['initial_expression']}")
        print(f"Initial expression (target genes): {results['initial_target_expression']}")
        print(f"Optimized expression (training genes): {results['optimized_expression_initial']}")
        print(f"Optimized expression (target genes): {results['optimized_expression_target']}")
        print(f"Improvement (training genes): {results['improvement_initial']} ({results['percent_improvement_initial']}%)")
        print(f"Improvement (target genes): {results['improvement_target']} ({results['percent_improvement_target']}%)")
    
    return results


def run_multiple_optimizations(gene_sets, batch_size=100, steps=50, lr=0.001):
    """Optimize each gene set in turn and collect the per-set results.

    Args:
        gene_sets: Sequence of dicts, each holding an ``'initial'`` and a
            ``'target'`` list of gene symbols.
        batch_size: Forwarded unchanged to ``optimize_gene_set``.
        steps: Forwarded unchanged to ``optimize_gene_set``.
        lr: Learning rate, forwarded unchanged to ``optimize_gene_set``.

    Returns:
        A list holding one ``optimize_gene_set`` result per gene set, in the
        same order as ``gene_sets``.
    """
    collected = []
    for position, spec in enumerate(gene_sets, start=1):
        print(f"\n=== Optimizing Gene Set {position}/{len(gene_sets)} ===")
        outcome = optimize_gene_set(
            initial_genes=spec['initial'],
            target_genes=spec['target'],
            batch_size=batch_size,
            steps=steps,
            lr=lr,
            verbose=True,
        )
        collected.append(outcome)
    return collected

def plot_optimization_results(results, output_path='./outputs/optimization_results.png'):
    """Create a grouped bar plot of the mean percentage improvement per gene set.

    Each entry of ``results`` contributes two bars:
    - the mean of its ``'percent_improvement_initial'`` values (training genes);
    - the mean of its ``'percent_improvement_target'`` values (target genes).

    The figure is saved to ``output_path``, shown, and then closed.

    Args:
        results: List of per-set result dicts. The percent-improvement entries
            may be scalars, lists, NumPy arrays or eager TensorFlow tensors.
            Different sets may hold different numbers of genes.
        output_path: File the figure is written to. Missing parent
            directories are created.
    """
    if not results:
        print("No optimization results to plot.")
        return

    sets = [f"Set {i+1}" for i in range(len(results))]

    # Average each set on its own. Stacking all sets into one array and
    # reducing over axis=1 fails whenever two sets have different gene counts
    # (e.g. target genes whose promoter retrieval failed are dropped).
    initial_improvements = [
        float(np.mean(np.asarray(r['percent_improvement_initial']))) for r in results
    ]
    target_improvements = [
        float(np.mean(np.asarray(r['percent_improvement_target']))) for r in results
    ]

    fig, ax = plt.subplots(figsize=(12, 8))

    bar_width = 0.35
    r1 = np.arange(len(sets))
    r2 = r1 + bar_width

    ax.bar(r1, initial_improvements, width=bar_width, label='Initial Genes', color='blue', alpha=0.7)
    ax.bar(r2, target_improvements, width=bar_width, label='Target Genes', color='green', alpha=0.7)

    ax.set_xlabel('Gene Sets', fontsize=12)
    ax.set_ylabel('Expression Improvement (%)', fontsize=12)
    ax.set_title('5\' UTR Optimization Results Across Gene Sets', fontsize=14)
    ax.set_xticks(r1 + bar_width / 2)
    ax.set_xticklabels(sets)

    # Reference line separating improvements from regressions.
    ax.axhline(y=0, color='black', linestyle='-', alpha=0.3)

    # Scale the label offset with the data rather than a fixed 5 points, and
    # put labels of negative bars below the bar end instead of inside the bar.
    largest = max(abs(v) for v in initial_improvements + target_improvements)
    pad = 0.02 * largest if largest > 0 else 0.1
    for xs, values in ((r1, initial_improvements), (r2, target_improvements)):
        for x, v in zip(xs, values):
            if v >= 0:
                ax.text(x, v + pad, f"{v:.1f}%", ha='center', va='bottom')
            else:
                ax.text(x, v - pad, f"{v:.1f}%", ha='center', va='top')

    ax.legend()
    ax.yaxis.grid(True, linestyle='--', alpha=0.7)

    fig.tight_layout()
    # savefig does not create directories itself.
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
    fig.savefig(output_path, dpi=300)
    print(f"Plot saved to {output_path}")

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def plot_expression_levels(results, output_path='./outputs/expression_levels.png'):
    """Create a grouped bar plot of mean expression before and after optimization.

    For every gene set, four bars are drawn, each the mean over that set's
    genes:
    - initial training-gene expression (``'initial_expression'``);
    - initial target-gene expression (``'initial_target_expression'``);
    - optimized training-gene expression (``'optimized_expression_initial'``);
    - optimized target-gene expression (``'optimized_expression_target'``).

    The figure is saved to ``output_path``, shown, and then closed.

    Args:
        results: List of per-set result dicts. Each expression entry may be a
            scalar, list, NumPy array or eager TensorFlow tensor. Different
            sets may hold different numbers of genes.
        output_path: File the figure is written to. Missing parent
            directories are created.
    """
    if not results:
        print("No optimization results to plot.")
        return

    sets = [f"Set {i+1}" for i in range(len(results))]

    def set_means(key):
        # One mean per set. np.asarray accepts tensors as well as arrays, and
        # per-set reduction works even when sets have different gene counts.
        return [float(np.mean(np.asarray(r[key]))) for r in results]

    series = [
        (set_means('initial_expression'), 'Initial (Training Genes)', 'gray'),
        (set_means('initial_target_expression'), 'Initial (Target Genes)', 'purple'),
        (set_means('optimized_expression_initial'), 'Optimized (Training Genes)', 'blue'),
        (set_means('optimized_expression_target'), 'Optimized (Target Genes)', 'green'),
    ]

    fig, ax = plt.subplots(figsize=(14, 8))

    bar_width = 0.2
    base = np.arange(len(sets))

    for slot, (values, label, color) in enumerate(series):
        xs = base + slot * bar_width
        ax.bar(xs, values, width=bar_width, label=label, color=color, alpha=0.7)
        for x, v in zip(xs, values):
            # Log TPM can be negative; label those bars below their end.
            if v >= 0:
                ax.text(x, v + 0.1, f"{v:.2f}", ha='center', va='bottom')
            else:
                ax.text(x, v - 0.1, f"{v:.2f}", ha='center', va='top')

    ax.set_xlabel('Gene Sets', fontsize=12)
    ax.set_ylabel('Expression Level (log TPM)', fontsize=12)
    ax.set_title('Expression Levels Before and After Optimization', fontsize=14)
    # Centre the tick under the four bars of each group.
    ax.set_xticks(base + bar_width * 1.5)
    ax.set_xticklabels(sets)

    ax.legend()
    ax.yaxis.grid(True, linestyle='--', alpha=0.7)

    fig.tight_layout()
    # savefig does not create directories itself.
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
    fig.savefig(output_path, dpi=300)
    print(f"Plot saved to {output_path}")

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def main():
    """Optimize eight predefined gene sets, then save, plot and summarize the results.

    Writes ``./outputs/optimization_results.csv`` and calls
    ``plot_optimization_results`` and ``plot_expression_levels``. Then prints:
    - the average percentage improvement, taken as the mean of the per-set
      means;
    - the best-performing initial and target gene sets.
    """
    def mean_of(value):
        # Collapse one set's per-gene values (scalar/list/array/eager tensor)
        # to a single float.
        return float(np.mean(np.asarray(value)))

    def to_plain(value):
        # Convert tensors/arrays to Python lists so the CSV stores numbers
        # rather than tensor reprs such as "tf.Tensor([...], shape=...)".
        if hasattr(value, 'numpy'):
            value = value.numpy()
        if isinstance(value, (np.ndarray, np.generic)):
            return value.tolist()
        return value

    # Define 8 gene sets (initial and target genes)
    gene_sets = [
        {
            'initial': ["MYOC", "TIGD4", "ATP6V1B2", "TAGLN", "COX7A2L", "IFNGR2", "TNFRSF21", "SETD6"],
            'target': ["ANTXR2", "NFIL3", "UNC13D", "DHRS2", "RPS13", "HBD", "METAP1D", "NCALD"]
        },
        {
            'initial': ["BRCA1", "TNFAIP3", "TRIM36", "TEX55", "LEMD2", "LSG1", "SGIP1", "MAD2L1"],
            'target': ["DAZL", "PPARG", "CDKN1A", "BAX", "MDM2", "BCL2", "TP53", "VEGFA"]
        },
        {
            'initial': ["IL6", "TNF", "IFNG", "IL10", "TGFB1", "IL2", "IL4", "IL17A"],
            'target': ["CD4", "CD8A", "CD19", "CD3E", "CD14", "FOXP3", "CTLA4", "PD1"]
        },
        {
            'initial': ["SOX2", "POU5F1", "NANOG", "KLF4", "MYC", "LIN28", "DNMT3B", "ZFP42"],
            'target': ["GATA1", "GATA2", "TAL1", "RUNX1", "MYB", "GYPA", "HBB", "HBA1"]
        },
        {
            'initial': ["INSR", "IGF1R", "IRS1", "PIK3CA", "AKT1", "MTOR", "PTEN", "GSK3B"],
            'target': ["PPARG", "ADIPOQ", "LEP", "FABP4", "PLIN1", "UCP1", "ADRB3", "CPT1A"]
        },
        {
            'initial': ["APP", "PSEN1", "PSEN2", "APOE", "MAPT", "BACE1", "GSK3B", "CDK5"],
            'target': ["BDNF", "NGF", "GDNF", "NTF3", "NTRK1", "NTRK2", "NGFR", "GFAP"]
        },
        {
            'initial': ["PCNA", "CCNA2", "CCNB1", "CDK1", "CDK2", "E2F1", "RB1", "TP53"],
            'target': ["CDKN1A", "CDKN1B", "CDKN2A", "CDKN2B", "ATM", "ATR", "CHEK1", "CHEK2"]
        },
        {
            'initial': ["HIF1A", "VEGFA", "EGFR", "KDR", "FLT1", "NRP1", "ANGPT1", "TEK"],
            'target': ["CDH5", "PECAM1", "VWF", "ICAM1", "VCAM1", "SELE", "EDN1", "NOS3"]
        }
    ]

    # Run optimization for all gene sets
    print("Starting optimization for 8 gene sets...")
    results = run_multiple_optimizations(gene_sets, steps=1000, lr=0.005)

    # Save results to a CSV file
    os.makedirs('./outputs', exist_ok=True)
    results_df = pd.DataFrame([{key: to_plain(val) for key, val in r.items()} for r in results])
    results_df.to_csv('./outputs/optimization_results.csv', index=False)
    print("Results saved to ./outputs/optimization_results.csv")

    # Create plots to visualize results
    plot_optimization_results(results)
    plot_expression_levels(results)

    # Reduce each set to one number first: sets may hold different numbers of
    # genes (failed target retrievals are dropped), so the raw per-gene values
    # cannot be stacked into one rectangular array.
    initial_means = [mean_of(r['percent_improvement_initial']) for r in results]
    target_means = [mean_of(r['percent_improvement_target']) for r in results]

    avg_initial_improvement = float(np.mean(initial_means))
    avg_target_improvement = float(np.mean(target_means))

    print("\n=== Summary Statistics ===")
    print(f"Average improvement for initial genes: {avg_initial_improvement:.2f}%")
    print(f"Average improvement for target genes: {avg_target_improvement:.2f}%")
    if avg_initial_improvement != 0:
        print(f"Ratio of target to initial improvement: {avg_target_improvement/avg_initial_improvement:.2f}")
    else:
        print("Ratio of target to initial improvement: undefined (initial improvement is 0)")

    # argmax over the per-set means gives a set index. argmax over the raw
    # per-gene vectors would return a flattened gene index, which selects the
    # wrong set or overruns gene_sets.
    best_initial_set_idx = int(np.argmax(initial_means))
    best_target_set_idx = int(np.argmax(target_means))

    print("\n=== Best Performing Gene Sets ===")
    print(f"Best initial gene set: Set {best_initial_set_idx+1}")
    print(f"  Initial genes: {gene_sets[best_initial_set_idx]['initial']}")
    print(f"  Improvement: {initial_means[best_initial_set_idx]:.2f}%")

    print(f"Best target gene set: Set {best_target_set_idx+1}")
    print(f"  Target genes: {gene_sets[best_target_set_idx]['target']}")
    print(f"  Improvement: {target_means[best_target_set_idx]:.2f}%")

# Script entry point: run the full multi-gene-set optimization experiment
# only when this module is executed directly, not when it is imported.
if __name__ == "__main__":
    main()