# Template Notebook for testing Genetic Algorithms exploring the Prompt Embedding space
Notebook Version: 0.3 (06/03/2024)
* added Google Colab support

## Google Colab Setup

In [None]:
# Google Colab: Execute this to install packages and setup drive
!pip install "evolutionary[all] @ git+https://git@github.com/malthee/evolutionary-diffusion.git"

In [None]:
# Mount Google Drive and store results below base_path so they outlive the Colab session
from google.colab import drive
import evolutionary_imaging.processing as ip
drive.mount("/content/drive")
base_path = "/content/drive/MyDrive/evolutionary/"
# Prefix the results folder only once: without this guard, re-running the cell
# would put base_path in front of an already-prefixed path again.
if not ip.RESULTS_FOLDER.startswith(base_path):
    ip.RESULTS_FOLDER = base_path + ip.RESULTS_FOLDER

In [None]:
# Report whether PyTorch can see a CUDA device
import torch
gpu_available = torch.cuda.is_available()
print(gpu_available)

## Project Setup

In [None]:
from evolutionary.plotting import plot_fitness_statistics
import evolutionary_imaging.processing as ip
from diffusers.utils import logging
from evolutionary_imaging.processing import create_animation_from_generations, create_generation_image_grid, save_images_from_generation
import torch
import os

In [None]:
# Silence the warning raised by libraries that use tokenizers
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Keep the cell output readable: no progress bars, and only errors are logged.
# Turn these back on if you are having problems and need more detail.
logging.disable_progress_bar()
logging.set_verbosity_error()

# Optional: choose a different results folder for the generated images
# ip.RESULTS_FOLDER = 'choose_your_destination'

def save_images_post_evaluation(g, a):
    """Callback that stores the images of the current population.

    Forwards ``a.population`` and ``g`` to ``save_images_from_generation``.
    Presumably ``g`` is the generation index and ``a`` the running algorithm
    instance -- confirm against the library's callback contract.
    """
    population = a.population
    save_images_from_generation(population, g)
    
# Show torch's random state, which is used across all libraries.
# Be careful when setting a fixed seed: it affects variation as well as generation.
rng_state = torch.random.get_rng_state()
print(rng_state)

In [None]:
from evolutionary_prompt_embedding.argument_types import PooledPromptEmbedData
from evolutionary_prompt_embedding.image_creation import SDXLPromptEmbeddingImageCreator
from evolutionary_prompt_embedding.variation import \
    UniformGaussianMutatorArguments, PooledUniformGaussianMutator, PooledArithmeticCrossover
from evolutionary_prompt_embedding.value_ranges import SDXLTurboEmbeddingRange, SDXLTurboPooledEmbeddingRange
from evolutionary.evolutionary_selectors import TournamentSelector
from evolutionary.algorithms.ga import GeneticAlgorithm
from evolutionary_imaging.evaluators import AIDetectionImageEvaluator, AestheticsImageEvaluator, CLIPScoreEvaluator, SingleCLIPIQAEvaluator

# --- Run hyperparameters ---
# Size of the evolution loop
num_generations = 100
population_size = 100
# Passed to GeneticAlgorithm as elitism_count
elitism = 1
# Forwarded to the image creator
inference_steps = 1
batch_size = 1

# Min/max value ranges for the prompt embedding and the pooled prompt embedding
embedding_range = SDXLTurboEmbeddingRange()
pooled_embedding_range = SDXLTurboPooledEmbeddingRange()

# Building blocks of the genetic algorithm
creator = SDXLPromptEmbeddingImageCreator(
    batch_size=batch_size,
    inference_steps=inference_steps,
)
evaluator = AIDetectionImageEvaluator()
crossover = PooledArithmeticCrossover(crossover_rate=0.5, crossover_rate_pooled=0.5)

# Mutation settings: one set for the embedding and one for the pooled embedding,
# each clamped to the minimum/maximum of its own value range
embedding_clamp = (embedding_range.minimum, embedding_range.maximum)
mutation_arguments = UniformGaussianMutatorArguments(
    mutation_rate=0.05,
    mutation_strength=3,
    clamp_range=embedding_clamp,
)
pooled_embedding_clamp = (pooled_embedding_range.minimum, pooled_embedding_range.maximum)
mutation_arguments_pooled = UniformGaussianMutatorArguments(
    mutation_rate=0.05,
    mutation_strength=0.7,
    clamp_range=pooled_embedding_clamp,
)
mutator = PooledUniformGaussianMutator(mutation_arguments, mutation_arguments_pooled)
selector = TournamentSelector(tournament_size=3)

def _random_embed_data():
    """Build one PooledPromptEmbedData from random tensors drawn within both value ranges."""
    prompt_embedding = embedding_range.random_tensor_in_range()
    pooled_prompt_embedding = pooled_embedding_range.random_tensor_in_range()
    return PooledPromptEmbedData(prompt_embedding, pooled_prompt_embedding)


# Initial arguments: a random population of *reasonable* prompt embeddings,
# one entry per individual
init_args = [_random_embed_data() for _ in range(population_size)]

# Assemble the genetic algorithm from the components defined above
# (the run itself is started in a later cell)
ga = GeneticAlgorithm(
    # Search budget
    population_size=population_size,
    num_generations=num_generations,
    elitism_count=elitism,
    # Starting population
    initial_arguments=init_args,
    # Creating and scoring solutions
    solution_creator=creator,
    evaluator=evaluator,
    # Selection and variation
    selector=selector,
    crossover=crossover,
    mutator=mutator,
    # Hook that saves the population's images
    post_evaluation_callback=save_images_post_evaluation,
)

In [None]:
# Start the evolution. The returned solution is treated as the best one by the cells below.
# save_images_post_evaluation was registered as the post-evaluation callback of this run.
best_solution = ga.run()

In [None]:
from diffusers.utils import make_image_grid

# Print the fitness of the best solution, then display its images in a single row
print(best_solution.fitness)
best_images = best_solution.result.images
make_image_grid(best_images, rows=1, cols=batch_size)

## Compare to directly generating it with the prompt (for CLIP-Score)

In [None]:
from evolutionary_prompt_embedding.image_creation import SDXLPromptEmbeddingImageCreator
# Separate creator for the comparison; batch_size=4 matches the 2x2 image grid built in the next cell
creator_compare = SDXLPromptEmbeddingImageCreator(batch_size=4, inference_steps=inference_steps)

In [None]:
from diffusers.utils import make_image_grid
# `prompt` is not defined anywhere else in this template, so fall back to a
# placeholder instead of failing with a NameError on a fresh kernel. Replace the
# placeholder (or define `prompt` in an earlier cell) with the text to compare against.
prompt = globals().get("prompt", "insert_your_prompt")
# Create a solution directly from the prompt and score it with the same evaluator as the GA
args = creator_compare.arguments_from_prompt(prompt)
solution = creator_compare.create_solution(args)
print(evaluator.evaluate(solution.result))
# 2x2 grid, matching the batch size of 4 used for creator_compare
make_image_grid(solution.result.images, 2, 2)

## Visualize the evolution

In [None]:
# Build one image grid per generation, passing max_images=10 for each grid
for generation_index in range(num_generations):
    create_generation_image_grid(generation_index, max_images=10)

In [None]:
# Create an animation from the generations and print the returned value
# (presumably the location of the created file -- confirm against the library).
# If ffmpeg is not installed, use the GIF fallback cell at the end of the notebook.
video_loc = create_animation_from_generations(num_generations)
print(video_loc)

## Plot fitness statistics

In [None]:
# Plot the best, worst and average fitness values recorded on the GA object
plot_fitness_statistics(num_generations, ga.best_fitness, ga.worst_fitness, ga.avg_fitness)

## Save notebook and components

In [None]:
# Export this notebook to HTML. The file name is hard-coded: adjust it if the notebook was renamed.
!jupyter nbconvert --to html ga_notebook.ipynb

### Save the run to disk

In [None]:
import pickle
import os
from datetime import datetime

# Pickle the whole GA object into a timestamped file inside the saved_runs folder
run_dir = "saved_runs"
os.makedirs(run_dir, exist_ok=True)
timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
output_file = os.path.join(run_dir, f"ga_{timestamp}.pkl")
with open(output_file, "wb") as run_file:
    pickle.dump(ga, run_file)
print(f"Run saved to {output_file}")

### Load the run from disk 
The notebook and library versions should match those used when the run was saved.

In [None]:
import pickle
import os

# Replace "insert_filename" with the name of a file written by the save cell.
# Caution: pickle.load can execute arbitrary code, so only load run files you created yourself.
# The loaded object is bound to `run`, while the plotting cell above reads from `ga`.
with open(os.path.join("saved_runs", "insert_filename"), "rb") as f:
    run = pickle.load(f)

## Fallback functions for when something went wrong

### Access Best Solution from Disk

In [None]:
import os
import glob
import evolutionary_imaging.processing as ip
from PIL import Image

# Note: this overwrites the num_generations value used by the other cells.
num_generations = 42  # Set this to the number of generations you ran (if you didn't finish)
generation_dir = os.path.join(ip.RESULTS_FOLDER, f"{num_generations}")
# Order the generation's PNG files by ip.fitness_filename_sorting_key, highest key first
image_files = sorted(
    glob.glob(os.path.join(generation_dir, "*.png")),
    key=ip.fitness_filename_sorting_key,
    reverse=True,
)
# Fail with an actionable message instead of a bare IndexError when nothing was found
if not image_files:
    raise FileNotFoundError(
        f"No PNG images found in '{generation_dir}'. "
        "Check num_generations and ip.RESULTS_FOLDER."
    )
best_image_file = image_files[0]
print(best_image_file)
Image.open(best_image_file)

### If ffmpeg is not installed, create a GIF instead

In [None]:
from evolutionary_imaging.processing import create_animation_from_generations_pil
# Fallback for the animation step when ffmpeg is missing; per the section heading this creates a GIF.
# Uses whatever num_generations currently holds (the previous fallback cell may have changed it).
video_loc = create_animation_from_generations_pil(num_generations)
print(video_loc)