# Template Notebook for using an Island Model GA with Style Evolution on Islands
Notebook Version: 0.5 (27/03/2024)
* update to new evolutionary library, changes to GA, plotting statistics

## Google Colab Setup

In [None]:
# Google Colab: Execute this to install packages and setup drive
!pip install "evolutionary[all] @ git+https://git@github.com/malthee/evolutionary-diffusion.git"

In [None]:
# Mount drive to save results
from google.colab import drive
import evolutionary_imaging.processing as ip
drive.mount("/content/drive")
base_path = "/content/drive/MyDrive/evolutionary/"
# Store results below base_path on the mounted drive.
# Guarded so that re-running this cell does not prepend base_path a second time.
if not ip.RESULTS_FOLDER.startswith(base_path):
    ip.RESULTS_FOLDER = base_path + ip.RESULTS_FOLDER

In [None]:
# Check whether PyTorch can see a CUDA GPU (prints True or False)
import torch
print(torch.cuda.is_available())

## Project Setup

In [None]:
from evolutionary.plotting import plot_fitness_statistics, plot_time_statistics
import evolutionary_imaging.processing as ip
from diffusers.utils import logging
from evolutionary_imaging.processing import create_animation_from_generations, create_generation_image_grid, save_images_from_generation
import torch
import os

In [None]:
logging.disable_progress_bar() # Otherwise the cell output fills up with progress bars
logging.set_verbosity_error() # Only errors are logged; raise the verbosity again if you are having problems
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Removes the warning emitted by libraries using tokenizers
# To store the images somewhere else, change the results folder:
# ip.RESULTS_FOLDER = 'choose_your_destination'

class SaveImagesPostEvaluation:
    """Post-evaluation callback that saves the images of a population.

    Written as a class (rather than a lambda or closure) so that the callback
    can be pickled.
    """

    def __init__(self, ident):
        # Identifier that is forwarded unchanged with every save call.
        self.ident = ident

    def __call__(self, g, a):
        """Save the images of ``a.population`` for generation ``g``.

        Returns whatever ``save_images_from_generation`` returns.
        """
        current_population = a.population
        generation = g
        return save_images_from_generation(current_population, generation, self.ident)

# Show torch's random state, which is used across all libraries.
# Be careful when setting fixed seeds: they affect not only image generation but also variation.
print(torch.random.get_rng_state())

In [None]:
from evolutionary_prompt_embedding.argument_types import PooledPromptEmbedData
from evolutionary_prompt_embedding.image_creation import SDXLPromptEmbeddingImageCreator
from evolutionary_prompt_embedding.variation import \
    UniformGaussianMutatorArguments, PooledUniformGaussianMutator, PooledArithmeticCrossover, PooledUniformCrossover
from evolutionary_prompt_embedding.value_ranges import SDXLTurboEmbeddingRange, SDXLTurboPooledEmbeddingRange
from evolutionary.evolutionary_selectors import TournamentSelector, RouletteWheelSelector, RankSelector
from evolutionary.evaluators import CappedEvaluator, GoalDiminishingEvaluator, MultiObjectiveEvaluator
from evolutionary.algorithms.island_model import IslandModel
from evolutionary.algorithms.ga import GeneticAlgorithm
from evolutionary.algorithms.nsga_ii import NSGA_II, NSGATournamentSelector
from evolutionary_imaging.evaluators import AIDetectionImageEvaluator, AestheticsImageEvaluator, SingleCLIPIQAEvaluator

population_size = 10  # Individuals per GA; every GA below is created with this population size
num_generations = 100  # Generations per GA; also reused by the visualization and plotting cells
batch_size = 1  # Passed to the image creator; also used to size the best-solution image grid
elitism = None  # Passed to every GA as elitism_count
inference_steps = 4  # Passed to the image creator
crossover_rate = 0.9  # Passed to every GA as crossover_rate
mutation_rate = 0.3  # Passed to every GA as mutation_rate (separate from the mutation_rate inside the mutator arguments)

# One GA (island) is created per entry. The list index is used as the island identifier when saving
# images and is mapped back to the epoch name when reporting, so keep the order stable within a run.
art_epochs = [
    "Prehistoric Art",  # c. 40,000 BCE - 4,000 BCE
    "Ancient Egyptian Art",  # c. 3,100 BCE - 332 BCE
    "Classical Greek Art",  # c. 480 BCE - 323 BCE
    "Roman Art",  # c. 500 BCE - 476 CE
    "Byzantine Art",  # c. 330 CE - 1453 CE
    "Islamic Art",  # c. 7th Century - Present
    "Romanesque Art",  # c. 1000 CE - 1200 CE
    "Gothic Art",  # c. 12th Century - 16th Century
    "Renaissance Art",  # c. 14th Century - 17th Century
    "Baroque Art",  # c. 1600 CE - 1750 CE
    "Neoclassicism",  # c. 18th Century - Early 19th Century
    "Romanticism",  # c. Late 18th Century - Mid 19th Century
    "Realism",  # c. Mid 19th Century
    "Impressionism",  # c. 1860s - 1880s
    "Modernism",  # Late 19th Century - 1970s
    "Contemporary Art"  # Post-1945 - Present
]

# Value ranges of the prompt embedding and the pooled prompt embedding; their minimum/maximum bound the mutations below
embedding_range = SDXLTurboEmbeddingRange()
pooled_embedding_range = SDXLTurboPooledEmbeddingRange()

creator = SDXLPromptEmbeddingImageCreator(batch_size=batch_size, inference_steps=inference_steps)
# At high scores the AestheticsImageEvaluator is biased towards specific styles, so treat it as quality control.
# It is left open here, i.e. it is not wrapped in a capping evaluator.
evaluator = AestheticsImageEvaluator()
crossover = PooledArithmeticCrossover(interpolation_weight=0.8, interpolation_weight_pooled=0.8) # Try to keep more of the original style
# Mutation settings for the prompt embedding; clamp_range is set to the bounds of the embedding range
mutation_arguments = UniformGaussianMutatorArguments(mutation_rate=0.05, mutation_strength=1.5, 
                                                     clamp_range=(embedding_range.minimum, embedding_range.maximum)) 
# Mutation settings for the pooled prompt embedding (lower mutation strength); clamp_range is set to the bounds of the pooled range
mutation_arguments_pooled = UniformGaussianMutatorArguments(mutation_rate=0.05, mutation_strength=0.4, 
                                                            clamp_range=(pooled_embedding_range.minimum, pooled_embedding_range.maximum))
mutator = PooledUniformGaussianMutator(mutation_arguments, mutation_arguments_pooled)
selector = RouletteWheelSelector()

# One GeneticAlgorithm per art epoch, collected in the order of art_epochs.
ga_instances = []

# Crossover that is only used to seed the initial populations.
init_crossover = PooledArithmeticCrossover(interpolation_weight=0.8, interpolation_weight_pooled=0.8)

for island_index, epoch_name in enumerate(art_epochs):
    # Prompt embedding that describes the style of this epoch.
    style_arg = creator.arguments_from_prompt(f"in {epoch_name} style")

    # Seed each individual by combining the style embedding with a freshly drawn random embedding
    # (presumably weighted towards the style embedding, which is passed first — TODO confirm).
    seeded_args = []
    for member_index in range(population_size):
        random_arg = PooledPromptEmbedData(embedding_range.random_tensor_in_range(),
                                           pooled_embedding_range.random_tensor_in_range())
        seeded_args.append(init_crossover.crossover(style_arg, random_arg))

    # The island index is the ident under which this island's images are saved.
    image_saver = SaveImagesPostEvaluation(island_index)

    island_ga = GeneticAlgorithm(
        population_size=population_size,
        num_generations=num_generations,
        solution_creator=creator,
        evaluator=evaluator,
        mutator=mutator,
        crossover=crossover,
        selector=selector,
        initial_arguments=seeded_args,
        crossover_rate=crossover_rate,
        mutation_rate=mutation_rate,
        elitism_count=elitism,
        post_evaluation_callback=image_saver,
    )
    ga_instances.append(island_ga)

In [None]:
# Combine the per-epoch GAs into a single island model.
# NOTE(review): migration_size / migration_interval presumably are the number of individuals exchanged
# and the number of generations between migrations — confirm against the IslandModel documentation.
island_model = IslandModel(
    ga_instances,
    migration_size=5,
    migration_interval=5,
)

In [None]:
# Run the island model; the result is indexed in parallel with art_epochs further down, so it presumably holds one best solution per island
best_solutions = island_model.run()

In [None]:
from diffusers.utils import make_image_grid

# Show the best solution of every island, labelled with the island's art epoch
for i, best_solution in enumerate(best_solutions):
    print(f"Best solution for epoch {art_epochs[i]}: {best_solution.fitness}")

# Collect every image of every best solution; the grid is sized from the actual number of images
best_images = [image for solution in best_solutions for image in solution.result.images]
# make_image_grid requires rows * cols == number of images. Prefer 4 rows (4x4 for the default 16 epochs)
# and fall back to fewer rows when the image count is not divisible by 4 instead of failing.
grid_rows = next(rows for rows in (4, 3, 2, 1) if len(best_images) % rows == 0)
make_image_grid(best_images, grid_rows, len(best_images) // grid_rows)

## Visualize the evolution

In [None]:
# Create one image grid per generation, grouped by island ident.
# art_epochs is passed as ident_mapper; since idents are indices into art_epochs, this presumably labels the groups with epoch names — TODO confirm.
for gen in range(num_generations):
    create_generation_image_grid(gen, images_per_row=4, max_images=16, label_fontsize=10, ident_mapper=art_epochs, group_by_ident=True)

In [None]:
# Create an animation across the generations and print the location returned for it.
# If ffmpeg is not installed, use the GIF fallback at the end of the notebook instead.
video_loc = create_animation_from_generations(num_generations)
print(video_loc)

## Plot statistics

In [None]:
# Plot the best, worst and average fitness from the island model's statistics
stats = island_model.statistics
plot_fitness_statistics(num_generations, stats.best_fitness, stats.worst_fitness, stats.avg_fitness)

In [None]:
# Plot the evaluation and creation times from the statistics (needs `stats` from the fitness plotting cell)
plot_time_statistics(stats.evaluation_time, stats.creation_time)

### Save the run to disk

In [None]:
import pickle
import os
from datetime import datetime

# Pickle the island model into ./saved_runs under a timestamped file name
run_directory = "saved_runs"
os.makedirs(run_directory, exist_ok=True)
timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
output_file = os.path.join(run_directory, f"island_model_goodepochs{timestamp}.pkl")
with open(output_file, "wb") as run_file:
    pickle.dump(island_model, run_file)
print(f"Run saved to {output_file}")
    

### Load the run from disk 
The notebook and library versions should match those that were used to create the saved run.

In [None]:
import pickle
import os

# SECURITY: unpickling can execute arbitrary code — only load run files that you created yourself or fully trust.
# Replace "insert_filename.pkl" with the name of a file inside saved_runs.
with open(os.path.join("saved_runs", "insert_filename.pkl"), "rb") as f:
    island_model = pickle.load(f)

## Fallback functions for when something goes wrong

### Access Best Solution from Disk

In [None]:
import os
import glob
import evolutionary_imaging.processing as ip
from PIL import Image

num_generations = 22  # Set this to the number of generations you ran (if you didn't finish)
generation_dir = os.path.join(ip.RESULTS_FOLDER, f"{num_generations}")
image_files = glob.glob(os.path.join(generation_dir, "*.png"))
if not image_files:
    # Fail with an actionable message instead of an opaque IndexError on image_files[0]
    raise FileNotFoundError(
        f"No .png images found in {generation_dir!r}; check num_generations and ip.RESULTS_FOLDER"
    )
# Sort in descending order of the fitness sorting key, so the first entry is presumably the fittest image
image_files.sort(key=ip.fitness_filename_sorting_key, reverse=True)
print(image_files[0])
Image.open(image_files[0])

### If ffmpeg is not installed, create a GIF instead

In [None]:
from evolutionary_imaging.processing import create_animation_from_generations_pil
# Alternative to create_animation_from_generations for systems without ffmpeg; prints the location it returns
video_loc = create_animation_from_generations_pil(num_generations)
print(video_loc)