# **1. INSTALL DEPENDENCIES**

In [None]:
%%capture
# Install the third-party packages this notebook imports; %%capture hides the pip output.
!pip install -q transformers sentence-transformers deap datasets tqdm matplotlib

# **2. IMPORTS & CONFIG**

In [None]:
import random
import numpy as np
import torch
import logging
from deap import base, creator, tools
from transformers import pipeline, AutoTokenizer
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
from tqdm import tqdm
import matplotlib.pyplot as plt

# Log INFO and above with a timestamp and level on every line.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# Reproducibility: seed the Python, NumPy and PyTorch RNGs with one shared value.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
del _seed_fn

# Genetic-algorithm hyper-parameters.
POPULATION_SIZE = 20  # individuals per generation
GENERATIONS = 10      # evolution steps after the initial population
CX_PROB = 0.5         # chance a pair of offspring is crossed over
MUT_PROB = 0.3        # chance an offspring is passed to the mutation operator
ELITE_SIZE = 2        # best individuals copied unchanged into the next generation

# Hugging Face model identifiers: text generator and sentence embedder.
MODEL_NAME = "gpt2"
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

# **3. DATASET LOADING**

In [None]:
def create_fallback_dataset():
    """Return a tiny built-in article/summary set for when the real dataset cannot be loaded.

    Each element is a dict with ``'article'`` and ``'highlights'`` keys, the same
    field names this notebook reads from CNN/DailyMail rows.
    """
    pairs = (
        (
            "The quick brown fox jumps over the lazy dog. This classic sentence contains all letters in the English alphabet.",
            "The fox sentence demonstrates all English letters.",
        ),
        (
            "Global temperatures continue to rise, with scientists predicting a 1.5°C increase within the next decade.",
            "Climate scientists warn of significant temperature increases.",
        ),
        (
            "Artificial intelligence is transforming industries from healthcare to finance, with new applications emerging daily.",
            "AI is revolutionizing multiple industry sectors.",
        ),
        (
            "Researchers have discovered a potential breakthrough in battery technology that could double electric vehicle range.",
            "New battery tech may greatly improve EV performance.",
        ),
    )
    return [dict(article=article, highlights=summary) for article, summary in pairs]

def load_data():
    """Load the evaluation corpus and its reference summaries.

    Tries the first 20 validation rows of CNN/DailyMail 3.0.0. On any failure it
    logs a warning and uses ``create_fallback_dataset()`` instead.

    Returns:
        ``(rows, highlights)``: the loaded rows (a ``datasets`` split or the
        fallback list of dicts) and the list of each row's ``'highlights'`` text.
    """
    try:
        rows = load_dataset("cnn_dailymail", "3.0.0", split="validation[:20]", trust_remote_code=True)
        logging.info("Loaded CNN/DailyMail dataset successfully.")
        highlights = [row['highlights'] for row in rows]
    except Exception as exc:
        logging.warning(f"Error loading CNN/DailyMail: {exc}")
        rows = create_fallback_dataset()
        highlights = [row['highlights'] for row in rows]
        logging.info("Using fallback dataset.")
    return rows, highlights

dataset, reference_summaries = load_data()

# **4. MODEL LOADING**

In [None]:
# Pipeline device convention: GPU index 0 when CUDA is present, -1 for CPU.
_cuda_ok = torch.cuda.is_available()
device = 0 if _cuda_ok else -1

# Reuse EOS as the padding token (presumably because the GPT-2 tokenizer
# defines none of its own — confirm if MODEL_NAME changes).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

# Sampling settings handed to the pipeline at construction time.
_generation_kwargs = dict(max_new_tokens=80, temperature=0.7, do_sample=True)
generator = pipeline(
    "text-generation",
    model=MODEL_NAME,
    tokenizer=tokenizer,
    device=device,
    **_generation_kwargs,
)

# Sentence embedder used to score generated text against references.
eval_model = SentenceTransformer(EMBED_MODEL_NAME, device='cuda' if _cuda_ok else 'cpu')

# **5. PROMPT POOLS**

In [None]:
# Gene 0 of every individual: the leading instruction of the prompt.
INSTRUCTION_POOL = [
    "Summarize this text",
    "Create a summary of",
    "Briefly explain",
    "Generate a short summary of",
    "Condense this text"
]

# Gene 1 of every individual: the trailing style modifier of the prompt.
STYLE_POOL = [
    "in a professional tone",
    "in simple language",
    "using bullet points",
    "in 3 sentences maximum",
    "focusing on main ideas"
]

def create_individual():
    """Return a fresh random genome ``[instruction, style]``, one gene drawn from each pool in order."""
    return [random.choice(gene_pool) for gene_pool in (INSTRUCTION_POOL, STYLE_POOL)]

# **6. GENETIC OPERATORS**

In [None]:
def mutate_individual(individual, indpb):
    """Resample each gene of ``individual`` in place with probability ``indpb``.

    Position 0 is redrawn from ``INSTRUCTION_POOL``; every other position is
    redrawn from ``STYLE_POOL``. Returns a 1-tuple containing the same list, the
    shape DEAP mutation operators return.
    """
    for position in range(len(individual)):
        if not random.random() < indpb:
            continue
        pool = INSTRUCTION_POOL if position == 0 else STYLE_POOL
        individual[position] = random.choice(pool)
    return (individual,)

def evaluate(individual, k=3, max_article_tokens=768):
    """
    Evaluate prompt on k random samples for robustness.

    For each of ``k`` distinct articles drawn with ``np.random.choice``, the
    prompt ``'<instruction>: "<article>" <style>'`` is sent to the global
    ``generator`` pipeline. The newly generated text is compared with the
    matching entry of ``reference_summaries`` by cosine similarity of
    ``eval_model`` embeddings.

    Args:
        individual: two-element sequence ``[instruction, style]``.
        k: number of articles to sample, capped at the dataset size.
        max_article_tokens: articles longer than this many tokenizer tokens
            are cut before being placed in the prompt. GPT-2's position
            embeddings cover 1024 tokens and the pipeline adds up to 80 new
            tokens. Without a cap, long CNN/DailyMail articles overflow the
            context, generation errors out, and the individual silently
            scores 0.0.

    Returns:
        A 1-tuple ``(average_similarity,)`` as DEAP fitness expects, or
        ``(0.0,)`` if anything fails (the error is logged).
    """
    try:
        n_samples = min(k, len(dataset))
        if n_samples <= 0:
            return (0.0,)
        indices = np.random.choice(len(dataset), size=n_samples, replace=False)
        tok = generator.tokenizer
        similarities = []
        for raw_idx in indices:
            # np.random.choice yields numpy integers; a plain int is unambiguous
            # for both list and datasets.Dataset indexing.
            idx = int(raw_idx)
            article = dataset[idx]['article']
            token_ids = tok.encode(article, add_special_tokens=False)
            if len(token_ids) > max_article_tokens:
                article = tok.decode(token_ids[:max_article_tokens])
            prompt = f'{individual[0]}: "{article}" {individual[1]}'
            # return_full_text=False returns only the continuation, so the
            # prompt does not have to be sliced off by length.
            generated_text = generator(
                prompt, num_return_sequences=1, truncation=True, return_full_text=False
            )[0]['generated_text'].strip()
            emb_gen, emb_ref = eval_model.encode([generated_text, reference_summaries[idx]])
            norm_product = np.linalg.norm(emb_gen) * np.linalg.norm(emb_ref)
            # Guard against a zero-norm embedding instead of producing NaN.
            similarities.append(
                float(np.dot(emb_gen, emb_ref) / norm_product) if norm_product > 0 else 0.0
            )
        return (float(np.mean(similarities)),)
    except Exception as e:
        # Best effort: a failed evaluation scores 0 instead of aborting the GA run.
        logging.warning(f"Error in evaluation: {str(e)[:120]}...")
        return (0.0,)

# **7. DEAP SETUP**

In [None]:
# Single-objective fitness with a positive weight, i.e. higher similarity is better.
# NOTE(review): re-running this cell re-creates these creator classes; DEAP
# presumably warns about the overwrite — consider a hasattr(creator, ...) guard.
creator.create("FitnessMax", base.Fitness, weights=(1.0,))
# An individual is a plain list ([instruction, style]) carrying a FitnessMax.
creator.create("Individual", list, fitness=creator.FitnessMax)

toolbox = base.Toolbox()
# Individuals are built from create_individual(); populations are lists of them.
toolbox.register("individual", tools.initIterate, creator.Individual, create_individual)
toolbox.register("population", tools.initRepeat, list, toolbox.individual)
toolbox.register("evaluate", evaluate)
# NOTE(review): with 2-gene individuals, two-point crossover can presumably only
# exchange the style gene — confirm against DEAP's cxTwoPoint.
toolbox.register("mate", tools.cxTwoPoint)
# Per-gene resample probability (hard-coded 0.3) inside a mutated individual;
# this is separate from MUT_PROB, which decides whether an individual is mutated at all.
toolbox.register("mutate", mutate_individual, indpb=0.3)
toolbox.register("select", tools.selTournament, tournsize=3)

class ScalarStats(tools.Statistics):
    """Statistics that reduce each individual to one scalar and report avg/min/max.

    ``key`` maps an individual to a value. When it yields tuples (for example
    ``ind.fitness.values``), only the first element of each tuple is used.
    """

    def __init__(self, key):
        super().__init__(key)

    def compile(self, data):
        extracted = list(map(self.key, data))
        unwrap = bool(extracted) and isinstance(extracted[0], tuple)
        scalars = [item[0] for item in extracted] if unwrap else extracted
        summary = {}
        summary["avg"] = np.mean(scalars)
        summary["min"] = np.min(scalars)
        summary["max"] = np.max(scalars)
        return summary

# **8. GA MAIN LOOP**

In [None]:
def main():
    """Run the elitist generational GA over prompt individuals.

    Each generation:
    - draws ``POPULATION_SIZE - ELITE_SIZE`` offspring by tournament and clones them;
    - applies crossover (``CX_PROB`` per pair) and mutation (``MUT_PROB`` per child);
    - re-scores only the offspring whose fitness was invalidated;
    - prepends the ``ELITE_SIZE`` best of the previous population.

    Returns:
        ``(population, hall_of_fame, logbook)``. The hall of fame keeps the 5
        best individuals seen, and the logbook has one record per generation.
    """

    def _score(individuals):
        # Evaluate every individual (with a progress bar) and store its fitness.
        scores = list(tqdm(map(toolbox.evaluate, individuals), total=len(individuals)))
        for member, score in zip(individuals, scores):
            member.fitness.values = score

    population = toolbox.population(n=POPULATION_SIZE)
    hall = tools.HallOfFame(5)
    stats = ScalarStats(lambda ind: ind.fitness.values)
    log = tools.Logbook()
    log.header = ["gen", "nevals", "avg", "min", "max"]

    logging.info("Evaluating initial population...")
    _score(population)
    hall.update(population)
    summary = stats.compile(population)
    log.record(gen=0, nevals=len(population), **summary)
    logging.info(f"Generation 0: Max fitness = {summary['max']:.4f}")

    for gen in range(1, GENERATIONS + 1):
        parents = toolbox.select(population, len(population) - ELITE_SIZE)
        children = [toolbox.clone(parent) for parent in parents]

        # Crossover on consecutive pairs; both children need re-evaluation.
        for first, second in zip(children[::2], children[1::2]):
            if random.random() < CX_PROB:
                toolbox.mate(first, second)
                del first.fitness.values
                del second.fitness.values

        for child in children:
            if random.random() < MUT_PROB:
                toolbox.mutate(child)
                del child.fitness.values

        stale = [child for child in children if not child.fitness.valid]
        _score(stale)

        # Elitism: carry the best of the previous generation over unchanged.
        ranked = sorted(population, key=lambda ind: ind.fitness.values[0], reverse=True)
        population[:] = ranked[:ELITE_SIZE] + children
        hall.update(population)

        summary = stats.compile(population)
        log.record(gen=gen, nevals=len(stale), **summary)
        logging.info(f"Generation {gen}: Max fitness = {summary['max']:.4f}")

    return population, hall, log

# **9. RESULTS & VISUALIZATION**

In [None]:
def plot_results(log):
    """Plot average and maximum fitness per generation from a DEAP logbook and show the figure."""
    generations, averages, best = (log.select(field) for field in ("gen", "avg", "max"))

    _, ax = plt.subplots(figsize=(10, 6))
    for series, label in ((averages, "Average Fitness"), (best, "Max Fitness")):
        ax.plot(generations, series, label=label)
    ax.set_xlabel("Generation")
    ax.set_ylabel("Fitness (Cosine Similarity)")
    ax.set_title("Prompt Optimization Progress")
    ax.legend()
    ax.grid()
    plt.show()

def display_top_prompts(hall_of_fame, dataset, generator, max_article_tokens=768):
    """
    Print each hall-of-fame prompt with its fitness and one example generation.

    The prompt at 1-based rank ``r`` is demonstrated on article
    ``(r - 1) % len(dataset)``.

    Args:
        hall_of_fame: iterable of ``[instruction, style]`` individuals carrying
            ``fitness.values``; printed in iteration order.
        dataset: indexable rows with an ``'article'`` field.
        generator: text-generation pipeline. Its ``tokenizer`` is used to cap
            the article length.
        max_article_tokens: articles longer than this many tokens are cut before
            prompting. Otherwise a long article plus the 80 generated tokens
            overflows GPT-2's 1024-token context, and this function (which has
            no error handling) crashes the report.
    """
    entries = list(hall_of_fame)
    print(f"\nTop {len(entries)} Optimized Prompts:")
    n_articles = len(dataset)
    tok = generator.tokenizer
    for rank, ind in enumerate(entries, start=1):
        print(f"\nRank {rank} (Fitness: {ind.fitness.values[0]:.4f}):")
        print(f'{ind[0]}: "[TEXT]" {ind[1]}')
        if n_articles == 0:
            # No article to demonstrate on; avoids a modulo by zero.
            continue
        article = dataset[(rank - 1) % n_articles]['article']
        token_ids = tok.encode(article, add_special_tokens=False)
        if len(token_ids) > max_article_tokens:
            article = tok.decode(token_ids[:max_article_tokens])
        full_prompt = f'{ind[0]}: "{article}" {ind[1]}'
        # return_full_text=False returns only the newly generated continuation.
        output = generator(
            full_prompt, num_return_sequences=1, truncation=True, return_full_text=False
        )[0]['generated_text']
        print("\nExample Output:")
        print(output.strip())

def calculate_metrics(prompt, text, reference, max_article_tokens=768):
    """
    Generate a summary of ``text`` with ``prompt`` and score it against ``reference``.

    Uses the global ``generator`` pipeline and ``eval_model`` embedder.

    Args:
        prompt: two-element sequence ``[instruction, style]``.
        text: article to embed in the prompt.
        reference: reference summary for ``text``.
        max_article_tokens: ``text`` is cut to this many tokenizer tokens first.
            Otherwise a long article plus the generated tokens overflows
            GPT-2's 1024-token context and this call raises.

    Returns:
        ``(similarity, word_count, diversity, generated_text)``:
        - similarity: cosine similarity of the two embeddings (0.0 if either is zero-norm);
        - word_count: number of whitespace-separated words in the output;
        - diversity: unique lower-cased words divided by word_count (0 for empty output).
    """
    tok = generator.tokenizer
    token_ids = tok.encode(text, add_special_tokens=False)
    if len(token_ids) > max_article_tokens:
        text = tok.decode(token_ids[:max_article_tokens])
    full_prompt = f'{prompt[0]}: "{text}" {prompt[1]}'
    # return_full_text=False returns only the continuation, not the prompt.
    generated_text = generator(
        full_prompt, num_return_sequences=1, truncation=True, return_full_text=False
    )[0]['generated_text'].strip()
    emb_gen, emb_ref = eval_model.encode([generated_text, reference])
    norm_product = np.linalg.norm(emb_gen) * np.linalg.norm(emb_ref)
    similarity = float(np.dot(emb_gen, emb_ref) / norm_product) if norm_product > 0 else 0.0
    word_count = len(generated_text.split())
    unique_words = len(set(generated_text.lower().split()))
    diversity = unique_words / word_count if word_count > 0 else 0
    return similarity, word_count, diversity, generated_text

def describe_similarity_change(init_sim, opt_sim):
    """Return a human-readable description of the change from ``init_sim`` to ``opt_sim``.

    The percentage is relative to ``abs(init_sim)``, so a negative baseline
    (cosine similarity can be negative) still reports the correct direction.
    A zero baseline falls back to the absolute change instead of an infinite
    percentage.
    """
    delta = opt_sim - init_sim
    if init_sim == 0:
        return f"{delta:+.4f} absolute change in similarity to reference (baseline similarity was 0)"
    pct = delta / abs(init_sim) * 100
    direction = "increase" if pct >= 0 else "decrease"
    return f"{abs(pct):.2f}% {direction} in similarity to reference"


def quantitative_comparison(hall_of_fame, dataset, reference_summaries):
    """Compare a random baseline prompt against the best evolved prompt on the first article.

    Prints each prompt's generated text, similarity, word count and lexical
    diversity (from ``calculate_metrics``), followed by the relative change in
    similarity. Returns early with a message if there is nothing to compare.
    """
    print("\nQuantitative Comparison:")
    print("="*50)
    initial_prompt = create_individual()
    if len(hall_of_fame) == 0 or len(dataset) == 0 or len(reference_summaries) == 0:
        print("Nothing to compare: hall of fame, dataset or references are empty.")
        return
    optimized_prompt = hall_of_fame[0]
    sample_text = dataset[0]['article']
    ref = reference_summaries[0]

    init_sim, init_words, init_div, init_gen = calculate_metrics(initial_prompt, sample_text, ref)
    opt_sim, opt_words, opt_div, opt_gen = calculate_metrics(optimized_prompt, sample_text, ref)

    print(f"\nInitial Prompt: {initial_prompt[0]}: \"[TEXT]\" {initial_prompt[1]}")
    print(f"Generated: {init_gen}")
    print(f"Similarity: {init_sim:.4f} | Word Count: {init_words} | Diversity: {init_div:.4f}")

    print(f"\nOptimized Prompt: {optimized_prompt[0]}: \"[TEXT]\" {optimized_prompt[1]}")
    print(f"Generated: {opt_gen}")
    print(f"Similarity: {opt_sim:.4f} | Word Count: {opt_words} | Diversity: {opt_div:.4f}")

    # The original always said "increase" and printed "inf%" for a zero baseline.
    print(f"\nChange: {describe_similarity_change(init_sim, opt_sim)}")

# **10. RUN EVERYTHING**

In [None]:
if __name__ == "__main__":
    # Run the GA, then report in three steps:
    # - plot the avg/max fitness curves;
    # - print the hall-of-fame prompts, each with an example generation;
    # - compare a random prompt against the best evolved one on the first article.
    final_pop, hall_of_fame, stats_log = main()
    plot_results(stats_log)
    display_top_prompts(hall_of_fame, dataset, generator)
    quantitative_comparison(hall_of_fame, dataset, reference_summaries)