# Template Notebook for testing an Island Model GA with one art epoch's artists/styles per island
Notebook Version: 0.4 (08/03/2024)
* Island identifiers are now shown in the image grid

## Google Colab Setup

In [None]:
# Google Colab: Execute this to install packages and setup drive
!pip install "evolutionary[all] @ git+https://git@github.com/malthee/evolutionary-diffusion.git"

In [None]:
# Mount drive to save results
from google.colab import drive
import evolutionary_imaging.processing as ip
drive.mount("/content/drive")
base_path = "/content/drive/MyDrive/evolutionary/"
# Only prefix once: without this guard, re-running the cell prepends base_path
# again and the results folder ends up at a doubly nested path.
if not ip.RESULTS_FOLDER.startswith(base_path):
    ip.RESULTS_FOLDER = base_path + ip.RESULTS_FOLDER

In [None]:
# Check if GPU is available
import torch
gpu_available = torch.cuda.is_available()  # True when torch can use a CUDA device
print(gpu_available)

## Project Setup

In [None]:
from evolutionary.plotting import plot_fitness_statistics
import evolutionary_imaging.processing as ip
from diffusers.utils import logging
from evolutionary_imaging.processing import create_animation_from_generations, create_generation_image_grid, save_images_from_generation
import torch
import os

In [None]:
# Remove the warning of libraries using tokenizers
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Keep the output readable: no progress bars, errors only.
# Enable verbosity again if you are having problems.
logging.disable_progress_bar()
logging.set_verbosity_error()

# Change the results folder for images if you want to:
# ip.RESULTS_FOLDER = 'choose_your_destination'

class SaveImagesPostEvaluation:
    """Post-evaluation callback that saves a generation's images for one island.

    Implemented as a module-level class (rather than a lambda/closure) so that
    the GA instances holding it, and the island model around them, can be pickled.
    """

    def __init__(self, ident):
        # Island identifier forwarded to save_images_from_generation
        self.ident = ident

    def __call__(self, g, a):
        """Save ``a.population`` for generation ``g`` under this island's ident."""
        population = a.population
        return save_images_from_generation(population, g, self.ident)

# Inspect torch's random state (used across all libraries).
# Be careful with fixed seeds: they affect not only image generation but also variation.
rng_state = torch.random.get_rng_state()
print(rng_state)

In [None]:
from evolutionary_prompt_embedding.argument_types import PooledPromptEmbedData
from evolutionary_prompt_embedding.image_creation import SDXLPromptEmbeddingImageCreator
from evolutionary_prompt_embedding.variation import \
    UniformGaussianMutatorArguments, PooledUniformGaussianMutator, PooledArithmeticCrossover, PooledUniformCrossover
from evolutionary_prompt_embedding.value_ranges import SDXLTurboEmbeddingRange, SDXLTurboPooledEmbeddingRange
from evolutionary.evolutionary_selectors import TournamentSelector, RouletteWheelSelector, RankSelector
from evolutionary.algorithms.island_model import IslandModel
from evolutionary.algorithms.ga import GeneticAlgorithm
from evolutionary_imaging.evaluators import AIDetectionImageEvaluator, AestheticsImageEvaluator

# Run configuration shared by every island
population_size = 4   # individuals per island GA
num_generations = 4   # generations each island GA is configured for
batch_size = 1        # batch size passed to the image creator
elitism = None        # passed as elitism_count; None presumably disables elitism — TODO confirm
inference_steps = 3   # inference steps passed to the SDXL image creator

# Seed works/artists for each art epoch; each inner list becomes one island.
# An island's initial population cycles through its list (indexed modulo its length).
visual_arts_by_epoch = [
    # Prehistoric Art (c. 40,000–4,000 BCE)
    ["Cave paintings at Lascaux (France)", "Cave paintings at Altamira (Spain)", "Venus of Willendorf"],
    
    # Ancient Near East Art (c. 3500–331 BCE)
    ["Sumerian art (e.g., Standard of Ur)", "Akkadian art (e.g., Victory Stele of Naram-Sin)", "Babylonian art (e.g., Ishtar Gate)"],
    
    # Ancient Egyptian Art (c. 3100–30 BCE)
    ["Ancient Egyptian tomb paintings", "Sculptures such as The Great Sphinx of Giza and the Statues of Ramses II at Abu Simbel"],
    
    # Aegean Art (c. 3000–1100 BCE)
    ["Minoan art (e.g., Frescoes at Knossos)", "Mycenaean art (e.g., Gold funerary masks)"],
    
    # Ancient Greek Art (c. 900–31 BCE)
    ["Archaic sculptures (Kouroi and Korai)", "Classical sculptures (e.g., Discobolus, Doryphoros)", "Hellenistic sculptures (e.g., Laocoön and His Sons)"],
    
    # Ancient Roman Art (c. 753 BCE–476 CE)
    ["Roman mural paintings (e.g., Frescoes in Pompeii and Herculaneum)", "Sculpture (e.g., Augustus of Prima Porta)", "Architecture (e.g., Colosseum, Pantheon)"],
    
    # Medieval (5th to 14th century)
    ["Art of Cimabue", "Art of Giotto di Bondone", "Art and illuminated manuscripts of Hildegard of Bingen", "Art of Duccio di Buoninsegna", "Art of Simone Martini"],
    
    # Renaissance (14th to 17th century)
    ["Art and inventions of Leonardo da Vinci", "Sculptures, paintings, and architectural works of Michelangelo Buonarroti", "Paintings of Raphael", "Art and theoretical writings of Albrecht Dürer", "Art of Titian"],
    
    # Baroque (17th century)
    ["Art of Caravaggio", "Paintings and etchings of Rembrandt van Rijn", "Art of Peter Paul Rubens", "Art of Johannes Vermeer", "Art of Diego Velázquez"],
    
    # Neoclassicism & Romanticism (late 18th to mid-19th century)
    ["Art of Jacques-Louis David", "Paintings of Jean-Auguste-Dominique Ingres", "Art of Francisco Goya", "Landscapes and marine paintings of J.M.W. Turner", "Art of Eugène Delacroix"],
    
    # Modern (late 19th to mid-20th century)
    ["Impressionist paintings of Claude Monet", "Post-Impressionist art of Vincent van Gogh", "Cubist works of Pablo Picasso", "Abstract art of Wassily Kandinsky", "Surrealist art of Salvador Dalí"],
    
    # Contemporary (mid-20th century to present)
    ["Abstract expressionist art of Jackson Pollock", "Pop art of Andy Warhol", "Neo-expressionist art of Jean-Michel Basquiat", "Street art of Banksy", "Art and installations of Yayoi Kusama"]
]

# Display names for the islands, in the same order as `visual_arts_by_epoch`
epochs = [
    "Prehistoric",
    "Ancient Near East",
    "Ancient Egypt",
    "Aegean",
    "Ancient Greece",
    "Ancient Rome",
    "Medieval",
    "Renaissance",
    "Baroque",
    "Neoclassicism Romanticism",
    "Modern",
    "Contemporary",
]

# Shared GA components, reused by every island below.
# Value ranges used to clamp mutated embeddings.
embedding_range = SDXLTurboEmbeddingRange()
pooled_embedding_range = SDXLTurboPooledEmbeddingRange()

# Image generation and fitness evaluation
creator = SDXLPromptEmbeddingImageCreator(
    batch_size=batch_size,
    inference_steps=inference_steps,
)
evaluator = AestheticsImageEvaluator()

# Variation: crossover, then separate mutation settings for the prompt embeddings
# and the pooled embeddings, each clamped to its own embedding range.
crossover = PooledArithmeticCrossover(
    crossover_rate=0.5,
    crossover_rate_pooled=0.5,
)
mutation_arguments = UniformGaussianMutatorArguments(
    mutation_rate=0.05,
    mutation_strength=2.5,
    clamp_range=(embedding_range.minimum, embedding_range.maximum),
)
mutation_arguments_pooled = UniformGaussianMutatorArguments(
    mutation_rate=0.05,
    mutation_strength=0.4,
    clamp_range=(pooled_embedding_range.minimum, pooled_embedding_range.maximum),
)
mutator = PooledUniformGaussianMutator(mutation_arguments, mutation_arguments_pooled)

# Parent selection
selector = RankSelector(selection_pressure=2.0)

# `epochs` supplies the display name for each island (index == island ident),
# so it must have exactly one entry per epoch list.
assert len(epochs) == len(visual_arts_by_epoch), "`epochs` needs one display name per entry of `visual_arts_by_epoch`"

ga_instances = []

for island_idx, epoch_works in enumerate(visual_arts_by_epoch):
    if not epoch_works:
        raise ValueError(f"Epoch {epochs[island_idx]!r} has no works to seed its island with")
    work_count = len(epoch_works)
    # Cycle through this epoch's works so the whole initial population is seeded from it.
    # (A distinct comprehension variable avoids shadowing the island index.)
    init_args = [creator.arguments_from_prompt(epoch_works[member_idx % work_count])
                 for member_idx in range(population_size)]
    save_images_post_evaluation = SaveImagesPostEvaluation(island_idx)

    # Create (not run) this island's genetic algorithm; the operators are shared across islands
    ga_instances.append(GeneticAlgorithm(
        population_size=population_size,
        num_generations=num_generations,
        solution_creator=creator,
        evaluator=evaluator,
        mutator=mutator,
        crossover=crossover,
        selector=selector,
        initial_arguments=init_args,
        elitism_count=elitism,
        post_evaluation_callback=save_images_post_evaluation,
    ))

In [None]:
# Combine the per-epoch GAs into a single island model.
# NOTE(review): presumably `migration_size` individuals migrate between islands
# every `migration_interval` generations — confirm against IslandModel's docs.
island_model = IslandModel(ga_instances, migration_size=1, migration_interval=2)

In [None]:
# Run the island model. The result is consumed in the next cell by enumerating it
# alongside `epochs`, i.e. presumably one best solution per island, in island order — confirm.
best_solutions = island_model.run()

In [None]:
from diffusers.utils import make_image_grid

# Show the best solution's fitness for every island (one island per epoch)
for island_idx, best_solution in enumerate(best_solutions):
    print(f"Best solution for epoch {epochs[island_idx]}: {best_solution.fitness}")

# make_image_grid requires len(images) == rows * cols exactly. Use two rows only
# when the image count splits evenly; otherwise fall back to a single row.
best_images = [image for solution in best_solutions for image in solution.result.images]
grid_rows = 2 if len(best_images) % 2 == 0 else 1
make_image_grid(best_images, grid_rows, len(best_images) // grid_rows)

## Visualize the evolution

In [None]:
# Render one image grid per generation, labelling islands through the `epochs` names.
# NOTE(review): this iterates generations 0..num_generations-1, while the fallback
# cell reads folder `num_generations`; confirm which numbering the library uses.
for generation in range(num_generations):
    create_generation_image_grid(generation, max_images=10, ident_mapper=epochs)

In [None]:
# Build an animation from the saved generations and print the returned location
# (presumably the output file path).
video_loc = create_animation_from_generations(num_generations)
print(video_loc)

## Plot fitness statistics

In [None]:
# Plot the best/worst/average fitness recorded by the island model
# (presumably one value per generation — confirm).
plot_fitness_statistics(num_generations, island_model.best_fitness, island_model.worst_fitness, island_model.avg_fitness)

In [None]:
!jupyter nbconvert --to html ga_notebook.ipynb

### Save the run to disk

In [None]:
import pickle
import os
from datetime import datetime

# Pickle the whole island model into a timestamped file under saved_runs/
save_dir = "saved_runs"
os.makedirs(save_dir, exist_ok=True)
timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
output_file = os.path.join(save_dir, f"island_model_{timestamp}.pkl")
with open(output_file, "wb") as run_file:
    pickle.dump(island_model, run_file)
print(f"Run saved to {output_file}")
    

### Load the run from disk 
The notebook and library versions should match those used for the saved run. Only load files you trust: unpickling can execute arbitrary code.

In [None]:
import pickle
import os

# Name of a file written by the save cell, e.g. "island_model_<timestamp>.pkl"
run_filename = "insert_filename"
# Security: pickle.load can execute arbitrary code — only load runs you created or trust.
with open(os.path.join("saved_runs", run_filename), "rb") as run_file:
    island_model = pickle.load(run_file)

## Fallback functions for when something went wrong

### Access Best Solution from Disk

In [None]:
import os
import glob
import evolutionary_imaging.processing as ip
from PIL import Image

num_generations = 22  # Set this to the number of generations you ran (if you didn't finish)
generation_dir = os.path.join(ip.RESULTS_FOLDER, f"{num_generations}")
image_files = glob.glob(os.path.join(generation_dir, "*.png"))
# Fail with an actionable message instead of an IndexError on image_files[0]
if not image_files:
    raise FileNotFoundError(
        f"No .png images found in {generation_dir!r}; check num_generations and ip.RESULTS_FOLDER"
    )
# Order by the library's fitness-from-filename key, highest first
image_files.sort(key=ip.fitness_filename_sorting_key, reverse=True)
best_image_file = image_files[0]
print(best_image_file)
Image.open(best_image_file)

### If ffmpeg is not installed, create a GIF instead

In [None]:
from evolutionary_imaging.processing import create_animation_from_generations_pil
# GIF fallback for when ffmpeg is unavailable. It uses the current `num_generations`,
# which the fallback cell above may have overwritten.
video_loc = create_animation_from_generations_pil(num_generations)
print(video_loc)