# Template Notebook for using an Island Model GA with Style Evolution on Islands
Notebook Version: 0.7.0 (15/03/2025)
* include TensorboardEmbedVisualizer
* rename functions, conditionally save images
* show parent history of solutions

## Google Colab Setup

In [None]:
from evolutionary.history import SolutionHistoryKey
# Google Colab: Execute this to install packages and setup drive
!pip install "evolutionary[prompt_embedding] @ git+https://git@github.com/malthee/evolutionary-diffusion.git"

In [None]:
# Mount drive to save results
from google.colab import drive
import evolutionary_imaging.processing as ip
import evolutionary_prompt_embedding.tensorboard_embed_visualizer as ev

drive.mount("/content/drive")
base_path = "/content/drive/MyDrive/evolutionary/"
# Prefix the results folder only once, so that re-running this cell does not
# stack base_path onto a path that already starts with it
if not ip.RESULTS_FOLDER.startswith(base_path):
    ip.RESULTS_FOLDER = base_path + ip.RESULTS_FOLDER
ev.DEFAULT_OUTPUT_FOLDER = base_path + "vis"
# Read by the "Save the run to disk" cell at the end of the notebook
save_run_path = base_path + "saved_runs"

In [None]:
# Check if a GPU is available: prints True when torch can see a CUDA device, False otherwise
import torch
print(torch.cuda.is_available())

## Project Setup

In [None]:
from evolutionary.plotting import plot_fitness_statistics, plot_time_statistics
import evolutionary_imaging.processing as ip
import evolutionary_prompt_embedding.tensorboard_embed_visualizer as ev
from diffusers.utils import logging
from evolutionary_imaging.processing import create_animation_from_generations, create_generation_image_grid, save_images_from_generation
import torch
import os

In [None]:
logging.disable_progress_bar() # Or else your output will be full of diffusers progress bars
logging.set_verbosity_error() # Only errors are logged; raise the verbosity again if you are having problems
os.environ["TOKENIZERS_PARALLELISM"] = "false" # To remove the warning of libraries using tokenizers
# Change the results folder for images and embedding visualization if you want to
# ip.RESULTS_FOLDER = "choose_your_destination"
# ev.DEFAULT_OUTPUT_FOLDER = "choose_your_destination"
# save_run_path = "saved_runs"

# Both flags are read by IslandPostEvaluationCallback.__call__ after every evaluated generation
use_visualizer = True # Set to False if you don't want to use the TensorboardEmbedVisualizer
save_images = True # Set to False if you don't want to save images
print(torch.random.get_rng_state()) # Check torch random state, used across all libraries. Be careful when setting fixed seeds, as this affects not only generation but also variation.

In [None]:
from evolutionary_prompt_embedding.argument_types import PooledPromptEmbedData
from evolutionary_prompt_embedding.image_creation import SDXLPromptEmbeddingImageCreator
from evolutionary_prompt_embedding.variation import \
    UniformGaussianMutatorArguments, PooledUniformGaussianMutator, PooledArithmeticCrossover, PooledUniformCrossover
from evolutionary_prompt_embedding.value_ranges import SDXLTurboEmbeddingRange, SDXLTurboPooledEmbeddingRange
from evolutionary.evolutionary_selectors import TournamentSelector, RouletteWheelSelector, RankSelector
from evolutionary.evaluators import CappedEvaluator, GoalDiminishingEvaluator, MultiObjectiveEvaluator
from evolutionary.algorithms.island_model import IslandModel
from evolutionary.algorithms.ga import GeneticAlgorithm
from evolutionary.algorithms.nsga_ii import NSGA_II, NSGATournamentSelector
from evolutionary_imaging.evaluators import AIDetectionImageEvaluator, AestheticsImageEvaluator, SingleCLIPIQAEvaluator
from evolutionary_prompt_embedding.tensorboard_embed_visualizer import TensorboardEmbedVisualizer, EmbeddingVariant

# One visualizer shared by all islands; every embedding added to it carries four string labels in the column order given here
visualizer = TensorboardEmbedVisualizer[PooledPromptEmbedData, [str, str, str, str]](["Index", "Generation", "Fitness", "Island"])

class IslandPostEvaluationCallback:
    """Per-island post-evaluation hook, written as a class so that it can be pickled.

    When called it optionally saves the images of the current population
    (global ``save_images``) and optionally adds every solution's arguments to
    the shared ``visualizer`` (global ``use_visualizer``).
    """

    def __init__(self, ident: int, description: str):
        # ident is forwarded to save_images_from_generation; description is the "Island" label of each embedding
        self.ident, self.description = ident, description

    def __call__(self, g, a):
        """Process generation ``g`` of algorithm ``a``; only ``a.population`` is read."""
        saved_paths = save_images_from_generation(a.population, g, self.ident) if save_images else None
        if not use_visualizer:
            return
        for index, solution in enumerate(a.population):
            # Label order matches the visualizer columns: Index, Generation, Fitness, Island
            labels = [str(index), str(g), f"{solution.fitness:.3f}", self.description]
            # Attach the saved image only when images were written for this generation
            image_path = saved_paths[index] if saved_paths else None
            visualizer.add_embedding(solution.arguments, labels, image_path)

# Run configuration; the values are passed to the image creator, every GeneticAlgorithm and the IslandModel below
population_size = 8  # Solutions per island (also the number of initial arguments created per island)
num_generations = 30  # Generations each GeneticAlgorithm is configured with
batch_size = 1  # Passed to the image creator as batch_size
elitism = None  # Passed to GeneticAlgorithm as elitism_count
inference_steps = 4  # Passed to the image creator as inference_steps
crossover_rate = 0.9  # Passed to GeneticAlgorithm as crossover_rate
mutation_rate = 0.3  # Passed to GeneticAlgorithm; separate from the mutation_rate=0.05 given to the mutator arguments
migration_size = 1  # Passed to IslandModel as migration_size
migration_interval = 1  # Passed to IslandModel as migration_interval

# One island is created per active entry; the position in this list becomes the island ident
# and the list is reused as ident_mapper when labelling the generation grids.
# Commented-out entries are disabled styles that can be re-enabled.
art_styles = [
    "Prehistoric Art",  # c. 40,000 BCE - 4,000 BCE
    # "Ancient Egyptian Art",  # c. 3,100 BCE - 332 BCE
    "Classical Greek Art",  # c. 480 BCE - 323 BCE
    # "Roman Art",  # c. 500 BCE - 476 CE
    # "Byzantine Art",  # c. 330 CE - 1453 CE
    "Islamic Art",  # c. 7th Century - Present
    # "Romanesque Art",  # c. 1000 CE - 1200 CE
    # "Gothic Art",  # c. 12th Century - 16th Century
    "Renaissance Art",  # c. 14th Century - 17th Century
    # "Baroque Art",  # c. 1600 CE - 1750 CE
    # "Neoclassicism",  # c. 18th Century - Early 19th Century
    "Romanticism",  # c. Late 18th Century - Mid 19th Century
    "Realism",  # c. Mid 19th Century
    "Impressionism",  # c. 1860s - 1880s
    # "Modernism",  # Late 19th Century - 1970s
    "Contemporary Art"  # Post-1945 - Present
]

# Value ranges; their minimum/maximum clamp the mutation and they supply the random tensors for the initial populations
embedding_range = SDXLTurboEmbeddingRange()
pooled_embedding_range = SDXLTurboPooledEmbeddingRange()

# The creator, evaluator and variation operators below are shared by all islands
creator = SDXLPromptEmbeddingImageCreator(batch_size=batch_size, inference_steps=inference_steps)
evaluator = AestheticsImageEvaluator() # For diverse results avoid high selection pressure, the AestheticsEvaluator produces average good-looking images at scores 5-6
crossover = PooledArithmeticCrossover(interpolation_weight=0.5, interpolation_weight_pooled=0.5)
# Separate mutation settings for the prompt embedding and the pooled embedding, each clamped to its own value range
mutation_arguments = UniformGaussianMutatorArguments(mutation_rate=0.05, mutation_strength=1.5,
                                                     clamp_range=(embedding_range.minimum, embedding_range.maximum))
mutation_arguments_pooled = UniformGaussianMutatorArguments(mutation_rate=0.05, mutation_strength=0.4,
                                                            clamp_range=(pooled_embedding_range.minimum, pooled_embedding_range.maximum))
mutator = PooledUniformGaussianMutator(mutation_arguments, mutation_arguments_pooled)
selector = RouletteWheelSelector() # Less selection pressure, more exploration

ga_instances = []

# Settings that are identical for every island
shared_ga_settings = dict(
    population_size=population_size,
    num_generations=num_generations,
    solution_creator=creator,
    evaluator=evaluator,
    mutator=mutator,
    crossover=crossover,
    selector=selector,
    crossover_rate=crossover_rate,
    mutation_rate=mutation_rate,
    elitism_count=elitism,
)

# Crossover used only for seeding: blends the style embedding with a random one (weights 0.8),
# intended to keep the initial solutions close to the style
init_crossover = PooledArithmeticCrossover(interpolation_weight=0.8, interpolation_weight_pooled=0.8)
for i, style in enumerate(art_styles):
    style_arg = creator.arguments_from_prompt(style)
    # Build the initial population of this island, one blended argument per individual
    init_args = []
    while len(init_args) < population_size:
        random_arg = PooledPromptEmbedData(embedding_range.random_tensor_in_range(),
                                           pooled_embedding_range.random_tensor_in_range())
        init_args.append(init_crossover.crossover(style_arg, random_arg))
    post_evaluation_callback = IslandPostEvaluationCallback(i, style)

    # The island ident is the index of its style in art_styles
    ga_instances.append(GeneticAlgorithm(
        initial_arguments=init_args,
        post_evaluation_callback=post_evaluation_callback,
        ident=i,
        **shared_ga_settings,
    ))

In [None]:
# Combine the per-style genetic algorithms into one island model, using the migration settings configured above
island_model = IslandModel(
    ga_instances,
    migration_size=migration_size,
    migration_interval=migration_interval,
)

In [None]:
# Run the island model; the following cell indexes the result in parallel with art_styles (one best solution per island)
best_solutions = island_model.run()

In [None]:
from diffusers.utils import make_image_grid

# Show the best solution of every island
for i, best_solution in enumerate(best_solutions):
    print(f"Highest scoring solution for epoch {art_styles[i]}: {best_solution.fitness}")

# Collect all result images and derive the grid shape from their actual count:
# make_image_grid requires rows * cols to equal the number of images, so fall back
# to a single row when the count is odd (e.g. an odd number of active art styles).
best_images = [image for solution in best_solutions for image in solution.result.images]
grid_rows = 2 if len(best_images) % 2 == 0 else 1
make_image_grid(best_images, grid_rows, len(best_images) // grid_rows)

## Print the history of a solution

In [None]:
from IPython.display import IFrame, display
from evolutionary.history import SolutionHistoryKey, SolutionHistoryItem
from evolutionary_imaging.family_tree import visualize_family_tree
from PIL import Image

# Render the ancestry (family tree) of one solution and show it inline
history_format = "png" # Supports graphviz formats
# Solution to inspect: index within the population, generation, and island ident (position in art_styles; 5 is "Realism" with the default list)
history_key = SolutionHistoryKey(index=0, generation=4, ident=5) # Index, Generation, Ident of Island
dot = visualize_family_tree(history=island_model.statistics.solution_history, root_key=history_key, depth=6, format=history_format)
# Writes the rendering under the name "family_tree" and returns the path of the rendered file
history_path = dot.render(filename="family_tree", cleanup=True, format=history_format)
Image.open(history_path) # PNG
# display(IFrame("family_tree.pdf", width=1000, height=1000)) # PDF
#print(island_model.statistics.history_string(key=history_key, depth=3)) # Textual

## Visualize Embeddings with the Tensorboard Embedding Projector

In [None]:
# This will save the embeddings and the metadata to your disk
visualizer.generate_visualization(
    sprite_single_image_dim=(80, 80),
    #filter_predicate=lambda e, l, i: int(l[0]) < 3, # Adjust this to filter embeddings if needed
)

In [None]:
# Load the TensorBoard notebook extension and open it on the visualizer's output folder
%load_ext tensorboard
%tensorboard --logdir={visualizer.output_folder}

## Create a video from the generational progress showing the top images for each island

In [None]:
# Create one image grid per completed generation, grouped by island and labelled with the art style names
grid_options = dict(
    images_per_row=4,
    max_images=16,
    label_fontsize=10,
    ident_mapper=art_styles,
    group_by_ident=True,
)
for gen in range(island_model.completed_generations):
    create_generation_image_grid(gen, **grid_options)

In [None]:
# Create the animation for all completed generations and print the location that is returned
video_loc = create_animation_from_generations(island_model.completed_generations)
print(video_loc)

## Plot statistics

In [None]:
# Plot best, worst and average fitness over the completed generations; `stats` is reused by the next cell
stats = island_model.statistics
plot_fitness_statistics(island_model.completed_generations, stats.best_fitness, stats.worst_fitness, stats.avg_fitness)

In [None]:
# Plot evaluation, creation and post-evaluation times; requires `stats` from the previous cell
plot_time_statistics(stats.evaluation_time, stats.creation_time, stats.post_evaluation_time)

### Save the run to disk

In [None]:
import pickle
import os
from datetime import datetime

# Fall back to a local folder when no save path was configured earlier in the notebook (e.g. outside Colab)
if "save_run_path" not in globals():
    save_run_path = "saved_runs"
os.makedirs(save_run_path, exist_ok=True)
# The timestamp in the file name keeps earlier runs from being overwritten
output_file = os.path.join(save_run_path, f"island_model_v0_7_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl")
with open(output_file, "wb") as f:
    pickle.dump(island_model, f)
print(f"Run saved to {output_file}")
    

### Load the run from disk 
The notebook and library versions should match those used to create the saved run.

In [None]:
import pickle
import os

# Read from the folder the save cell writes to (save_run_path, e.g. on Google Drive when set up for Colab);
# fall back to the local default folder when no save path has been configured in this session
load_run_path = globals().get("save_run_path", "saved_runs")
# NOTE: unpickling can execute arbitrary code, so only load run files you created yourself.
# pickle stores classes by reference, so the cell defining IslandPostEvaluationCallback must have been run first.
with open(os.path.join(load_run_path, "insert_filename.pkl"), "rb") as f:
    island_model = pickle.load(f)

## Fallback functions for when something went wrong

### Access Best Solution from Disk

In [None]:
import os
import glob
import evolutionary_imaging.processing as ip
from PIL import Image

# Look up the highest-scoring saved image of one generation directly on disk
num_generations = island_model.completed_generations  # Set this to the number of generations you ran (if you didn't finish)
# NOTE(review): other cells iterate range(completed_generations), which suggests 0-based generation folders;
# if so, the last folder would be num_generations - 1 - confirm against the saved results
generation_dir = os.path.join(ip.RESULTS_FOLDER, f"{num_generations}")
image_files = glob.glob(os.path.join(generation_dir, "*.png"))
# Sort descending with the library's fitness sorting key so the first entry is the best image
image_files.sort(key=ip.fitness_filename_sorting_key, reverse=True)
# Raises IndexError when the folder contains no PNG files
print(image_files[0])
Image.open(image_files[0])

### If ffmpeg is not installed, create a GIF instead

In [None]:
from evolutionary_imaging.processing import create_animation_from_generations_pil
# Uses num_generations as currently defined: the configured value, or the one set in the "Access Best Solution from Disk" cell if that ran
video_loc = create_animation_from_generations_pil(num_generations)
print(video_loc)