# 1. Define genetic algorithm suite

In [None]:
import json
from typing import Literal

from src.prediction.cnn.cnn_trainer import CNNTrainer

from src.prediction.cnn.drag_model import DragModel
from src.prediction.cnn.avg_temp_model import AvgTempModel
from src.prediction.cnn.max_temp_model import MaxTempModel

from src.ga.chromosome.vent_hole import VentHole

from src.ga.gene.shape.shape_variations import (
    circle_params,
    arrow_params,
    parabola_x_right_inner_multi_params,
    parabola_x_right_inner_params,
    wing_params,
    rose_params,
    flower_params,
    diamond_params,
    hexagon_params,
    triangle_params,
    trapezoid_params,
    triple_rectangle_params,
)
from src.ga.gene.pattern.pattern_variations import (
    circular_params,
    corn_params,
    grid_params,
)

from src.ga.ga_pipeline import GAPipeline

from src.ga.p1_initialize.init_vent import VentInitializer
from src.ga.p2_fitness.vent_fitness_cnn import Criterion, VentFitnessCalculatorCNN
from src.ga.p3_select.behaviors import (
    TournamentSelectionFilter,
    ElitismSelectionFilter,
    RouletteWheelSelectionFilter,
)
from src.ga.p4_crossover.behaviors import (
    OnePointCrossover,
    TwoPointCrossover,
    UniformCrossover,
)

import torch

# ----------------- Define the MODEL configs -----------------
# Device configuration: use the Apple-silicon GPU backend (MPS) when PyTorch
# reports it as available, otherwise fall back to the CPU.
# NOTE(review): CUDA is not probed here, so NVIDIA GPUs are never selected —
# confirm this is intended.
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

# Per-model paths, keyed by "drag" / "avg_temp" / "max_temp"; each value is
# handed to a CNNTrainer as `model_load_path` further down.
# The file is read through a context manager so the handle is closed as soon
# as the JSON has been parsed instead of being left to the garbage collector.
with open("model_path.json") as model_path_file:
    MODEL_PATH: dict[Literal["drag", "avg_temp", "max_temp"], str] = json.load(
        model_path_file
    )

# ----------------- Define the GA CONSTANTS -----------------
# 1. Criteria weights (w1, w2, w3), passed to the fitness calculator as
#    `criteria_weight_list`.
# NOTE(review): presumably ordered (drag, avg_temp, max_temp) — confirm
# against VentFitnessCalculatorCNN.
CRITERIA_WEIGHT = (1.25, 1.25, 0.25)

# 2. Criteria as (direction, min, max) tuples.
# NOTE(review): the *_STD_CRITERION constants are defined here but are not
# passed to the pipeline definition in this notebook.
DRAG_CRITERION: Criterion = ("lower", 0.2, 0.3)  # Lower is better, range 0.2 to 0.3
DRAG_STD_CRITERION: Criterion = ("lower", 0, 0.05)  # Lower is better, range 0 to 0.05

# NOTE(review): an earlier comment gave this range as 250K to 400K; the value
# actually used is 300K to 400K — confirm which one is intended.
AVG_TEMP_CRITERION: Criterion = ("lower", 300, 400)  # Lower is better, range 300K to 400K
AVG_TEMP_STD_CRITERION: Criterion = ("lower", 0, 10)  # Lower is better, range 0K to 10K

# NOTE(review): an earlier comment gave this range as 300K to 500K; the value
# actually used is 400K to 475K — confirm which one is intended.
MAX_TEMP_CRITERION: Criterion = ("lower", 400, 475)  # Lower is better, range 400K to 475K
MAX_TEMP_STD_CRITERION: Criterion = ("lower", 0, 10)  # Lower is better, range 0K to 10K

# 3. Grid parameters
# Scaling factor
GRID_SCALE = 1

# Grid resolution (the visualization cell builds its Grid with k = 1 / GRID_RESOLUTION)
GRID_RESOLUTION = 2

# Grid width, ~= canvas size
GRID_WIDTH = 100

# Square (x, y) bounds centred on the origin: ((x_min, x_max), (y_min, y_max))
GRID_BOUND = (
    (-GRID_WIDTH / 2, GRID_WIDTH / 2),
    (-GRID_WIDTH / 2, GRID_WIDTH / 2),
)
# ----------------- Define the GA MODELS -----------------
# One network + CNNTrainer pair per predicted quantity. The three trainers
# differ only in model, display name and load path; the device and grid
# settings below are shared by all of them.
_COMMON_TRAINER_KWARGS = dict(
    model_device=DEVICE,
    grid_bound=GRID_BOUND,
    grid_bound_width=GRID_WIDTH,
    grid_resolution=GRID_RESOLUTION,
    grid_scale=GRID_SCALE,
)

# Drag predictor
drag_model = DragModel()
drag_model_trainer = CNNTrainer(
    model=drag_model,
    model_name="drag model",
    model_load_path=MODEL_PATH["drag"],
    **_COMMON_TRAINER_KWARGS,
)

# Average-temperature predictor
avg_temp_model = AvgTempModel()
avg_temp_model_trainer = CNNTrainer(
    model=avg_temp_model,
    model_name="avg_temp model",
    model_load_path=MODEL_PATH["avg_temp"],
    **_COMMON_TRAINER_KWARGS,
)

# Maximum-temperature predictor
max_temp_model = MaxTempModel()
max_temp_model_trainer = CNNTrainer(
    model=max_temp_model,
    model_name="max_temp model",
    model_load_path=MODEL_PATH["max_temp"],
    **_COMMON_TRAINER_KWARGS,
)


# ----------------- Define the GA PIPELINES -----------------
# The pipeline's collaborators are built as named objects first (crossover,
# selection, fitness, then population initializer) and wired into the
# GAPipeline afterwards.
crossover_behavior = UniformCrossover()
selector_behavior = TournamentSelectionFilter(tournament_size=4)

# Fitness is evaluated through the three CNN trainers, using the criteria
# weights and the (direction, min, max) criteria defined in the constants.
fitness_calculator = VentFitnessCalculatorCNN(
    grid_bound=GRID_BOUND,
    grid_resolution=GRID_RESOLUTION,
    model_trainer_tuple=(
        drag_model_trainer,
        avg_temp_model_trainer,
        max_temp_model_trainer,
    ),
    criteria_weight_list=CRITERIA_WEIGHT,
    drag_criterion=DRAG_CRITERION,
    avg_temp_criterion=AVG_TEMP_CRITERION,
    max_temp_criterion=MAX_TEMP_CRITERION,
)

# Gene pools handed to the population initializer.
pattern_gene_pool = [
    circular_params,
    corn_params,
    grid_params,
]
shape_gene_pool = [
    circle_params,
    arrow_params,
    parabola_x_right_inner_multi_params,
    parabola_x_right_inner_params,
    wing_params,
    rose_params,
    flower_params,
    diamond_params,
    hexagon_params,
    triangle_params,
    trapezoid_params,
    triple_rectangle_params,
]
population_initializer = VentInitializer(
    population_size=1000,
    grid_scale=GRID_SCALE,
    grid_resolution=GRID_RESOLUTION,
    pattern_bound=GRID_BOUND,
    pattern_gene_pool=pattern_gene_pool,
    shape_gene_pool=shape_gene_pool,
)

suite = GAPipeline[VentHole](
    suite_name="exp/tournament/config",
    suite_max_count=50,
    suite_min_population=20,
    suite_min_chromosome=40,
    crossover_behavior=crossover_behavior,
    selector_behavior=selector_behavior,
    fitness_calculator=fitness_calculator,
    # Exit condition holds once both of the first two entries of `x` reach 5.
    # NOTE(review): what those two entries represent is defined by GAPipeline — confirm.
    immediate_exit_condition=lambda x: x[0] >= 5 and x[1] >= 5,
    mutation_probability=0.01,  # 1%
    population_initializer=population_initializer,
)

# 2. Run genetic algorithm

매번 run을 실행할 때마다, 개체 수, 격자 해상도 등에 따라 다르지만, **수행 시간이 길 수 있습니다.**

# 2. Run genetic algorithm

매번 run을 실행할 때마다, 개체 수, 격자 해상도 등에 따라 다르지만, **수행 시간이 길 수 있습니다.**

In [None]:
# Run the GA pipeline configured in section 1. The following cells read
# suite.evolution_storage, suite.population and suite.unique_population.
suite.run()

# 3. Plot fitness result

- `biased_fitness`: 가중치가 적용된 fitness, $w_1 p_1 + w_2 p_2 + \dots$
- `fitness`: 가중치 없이 계산된 fitness, $p_1 + p_2 + \dots$

In [None]:
# Draw one fitness-vs-generation plot per stored series: the unweighted
# fitness first, then the weighted ("biased") fitness.
for storage_key, y_axis_label in (
    ("fitness", "fitness"),
    ("biased_fitness", "biased fitness"),
):
    suite.evolution_storage.plot_fitness(
        storage=storage_key,
        title="PLOTTING NAME",
        xlabel="generation",
        ylabel=y_axis_label,
    )

# 4. Analyze final population(unique)

1. `population`에는 최종 선택된 개체만 남습니다.
2. `population[i]`는 각각 하나의 VentHole 염색체(Chromosome)입니다.
3. `population[i].pattern.pattern_matrix`에는 패턴을 격자화한 point vector가 들어 있습니다.
4. `population[i].gene_tuple`을 통해 최종 유전자(shape, pattern)에 접근할 수 있습니다.


## 4.1 Counting chromosome distribution

In [None]:
import pprint

pp = pprint.PrettyPrinter(indent=4)

# Occurrence tallies keyed by lower-cased gene label:
# `pattern` counts pattern genes, `shapes` counts shape genes.
pattern: dict = {}
shapes: dict = {}

# Population to tally — pick one:
#   suite.population         -> whole population
#   suite.unique_population  -> unique population
for pop in suite.population:
    shape_gene, pattern_gene = pop.gene_tuple
    pattern_label = pattern_gene.param.label.lower()
    shape_label = shape_gene.param.label.lower()

    # A label seen for the first time starts from zero.
    pattern[pattern_label] = pattern.get(pattern_label, 0) + 1
    shapes[shape_label] = shapes.get(shape_label, 0) + 1

for heading, tally in (("> Pattern Counts", pattern), ("> Shape Counts", shapes)):
    print(heading)
    pp.pprint(tally)

## 4.2 Visualize top performance patterns

In [None]:
from src.prediction.cnn.to_image_matrix import to_image_matrix
from src.grid.grid import Grid
import matplotlib.pyplot as plt

# Unique chromosomes ranked best-first, i.e. by descending biased_fitness.
_ranked_population = list(suite.unique_population)
_ranked_population.sort(
    key=lambda chromosome: chromosome.biased_fitness, reverse=True
)
sorted_population = _ranked_population

# Full grid over GRID_BOUND, shared by every rendered chromosome image.
# NOTE(review): `scale` is hard-coded to 1 here rather than taken from
# GRID_SCALE (currently also 1) — confirm whether they should stay in sync.
_grid_step = 1 / GRID_RESOLUTION
full_grid_matrix = Grid(bound=GRID_BOUND, k=_grid_step).generate_grid(
    scale=1, x_major_iteration=True
)


def get_img_from_chromosome(chromosome: VentHole):
    """Build the image matrix for one chromosome via ``to_image_matrix``.

    The chromosome's pattern matrix is passed together with the module-level
    ``full_grid_matrix``. The grid resolution is the reciprocal of the
    pattern unit's grid ``k``, and the grid width is the x-extent
    (``p_bound_x_max - p_bound_x_min``) of the pattern transformation matrix.

    Args:
        chromosome: The VentHole whose pattern is rendered.

    Returns:
        Whatever ``to_image_matrix`` returns for these inputs.
    """
    vent_pattern = chromosome.pattern
    point_matrix = vent_pattern.pattern_matrix
    resolution = 1 / vent_pattern.pattern_unit.grid.k
    x_bounds = vent_pattern.pattern_transformation_matrix
    x_extent = x_bounds.p_bound_x_max - x_bounds.p_bound_x_min

    return to_image_matrix(
        full_grid_matrix=full_grid_matrix,
        pattern_matrix=point_matrix,
        grid_resolution=resolution,
        grid_width=x_extent,
    )


def draw_img(img, chromosome_name: str):
    """Show ``img`` as a 3x3-inch grayscale figure titled ``chromosome_name``.

    Args:
        img: Image data accepted by matplotlib's ``imshow``.
        chromosome_name: Text used as the figure title.
    """
    _, axes = plt.subplots(figsize=(3, 3))
    axes.set_title(chromosome_name)
    axes.imshow(img, cmap="gray")
    plt.show()


# Report every unique chromosome, best first: label, raw prediction, both
# fitness values, the shape/pattern design parameters, then its rendered image.
for pop in sorted_population:
    print("-" * 50)
    print(f"label: {pop.label}")
    print(f"predict: {pop.fitness_pure_result}")
    print(f"biased_fitness: {pop.biased_fitness}, fitness: {pop.fitness}")

    # Named *_gene (as in section 4.1) so the loop does not rebind the
    # module-level `pattern` counts dict created by the counting cell.
    shape_gene, pattern_gene = pop.gene_tuple
    print("Design spec: shape parameter\n")
    shape_gene.print_parameter_info()
    print("Design spec: pattern parameter\n")
    pattern_gene.print_parameter_info()

    img = get_img_from_chromosome(pop)
    draw_img(img, pop.label)