In [None]:
# 選択したcheckpointから、種ごとに代表個体と情報を表示する

import sys
import statistics
import os
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import re
import ipywidgets as widgets
from IPython.display import display, Image
from PIL import Image as PILImage
# The NEAT-Python library imports
import modneat
from modneat import visualize
# import task
current_dir = str(Path().resolve())
sys.path.append(os.path.join(current_dir, '..'))
import settings.task as task

# Load the experiment settings from the shared command file.
experiment_settings = modneat.report_funcs.load_experiment_settings(path_of_command_file="../settings/command")

# List the checkpoint directory. Only entries whose name ends in "-<integer>"
# are offered, so stray files such as .DS_Store or .ipynb_checkpoints cannot
# break the numeric sort. The trailing number is used as the sort key, which
# matches how the generation number is parsed from the checkpoint path later on.
checkpoints_path = '../checkpoints/'
checkpoints = os.listdir(checkpoints_path)
sorted_checkpoints = sorted(
    (name for name in checkpoints if name.rsplit('-', 1)[-1].isdecimal()),
    key=lambda name: int(name.rsplit('-', 1)[-1]),
)
if not sorted_checkpoints:
    # Fail with a clear message instead of an IndexError further down.
    raise FileNotFoundError(f'No checkpoint files found in {checkpoints_path}')

# Dropdown widget listing the checkpoints in generation order.
dropdown = widgets.Dropdown(
    options=sorted_checkpoints,
    description='Select a checkpoint:',
    disabled=False,
)

# The dropdown starts on its first option, so mirror that as the initial selection.
selected_checkpoint = sorted_checkpoints[0]

# 選択メニューの値が変更されたときに実行される関数
def on_change(change):
    """Dropdown observer: remember the newly chosen checkpoint and echo it.

    Only notifications whose type is 'change' and whose trait name is 'value'
    are acted on; every other notification is ignored.

    Args:
        change: traitlets change notification dict (reads keys 'type',
            'name' and, for value changes, 'new').
    """
    global selected_checkpoint
    is_value_change = change['type'] == 'change' and change['name'] == 'value'
    if not is_value_change:
        return
    selected_checkpoint = change['new']
    print(f"Selected checkpoint: {selected_checkpoint}")

# 選択メニューの値が変更されたときに on_change 関数を呼び出す
# Subscribe only to changes of the 'value' trait. Without names=, the handler
# would also be invoked for every other trait change (index, label, ...),
# which on_change would then have to filter out.
dropdown.observe(on_change, names='value')

# Render the widget, then echo the initial selection.
display(dropdown)
print(f"Selected checkpoint: {selected_checkpoint}")

In [None]:
def _representative_members(member_list):
    """Pick five representative genomes from a list sorted best-first by fitness.

    Returns (label, genome) pairs in display order: the best member, the
    members at the 25 %, 50 % and 75 % rank positions, and the worst member.
    The labels double as the image file stems.
    """
    count = len(member_list)
    return [
        ('best', member_list[0]),
        ('top25', member_list[int(count * 0.25)]),
        ('top50', member_list[int(count * 0.5)]),
        ('top75', member_list[int(count * 0.75)]),
        ('worst', member_list[-1]),
    ]


def _compile_images(image_paths, output_path):
    """Paste the images at image_paths left-to-right onto one RGB canvas and save it.

    The canvas is as wide as the sum of the image widths and as tall as the
    tallest image. Every opened image file is closed, even if an error occurs.
    """
    images = []
    try:
        for path in image_paths:
            images.append(PILImage.open(path))
        widths, heights = zip(*(im.size for im in images))
        canvas = PILImage.new('RGB', (sum(widths), max(heights)))
        x_offset = 0
        for im in images:
            canvas.paste(im, (x_offset, 0))
            x_offset += im.size[0]
        canvas.save(output_path)
    finally:
        for im in images:
            im.close()


def investigate_checkpoint(checkpoint_path, is_display=False):
    """Print per-species statistics for a checkpoint and render representative genomes.

    For every species in the restored population this prints the species id,
    its creation and last-improvement values, its fitnesses (descending) and
    summary statistics. It then draws five representative members (best,
    25/50/75 % rank positions, worst) with ``visualize.draw_net``. The five
    images are stitched side by side into
    ``./cache/g<generation>/s<species id>/compiled.png``.

    Args:
        checkpoint_path: Path to a checkpoint file whose name ends in
            ``-<integer>``.
        is_display: If True, show each species' compiled image inline.
    """
    print(checkpoint_path)
    # NOTE(review): the reported generation is the file suffix + 1. Presumably
    # the checkpoint stores the population produced for the next generation;
    # confirm against modneat.Checkpointer.
    generation = int(checkpoint_path.split('-')[-1]) + 1
    print('============================ information of ', generation, 'th generation ============================')
    populations = modneat.Checkpointer.restore_checkpoint(checkpoint_path)
    species_by_id = populations.species.species
    print(species_by_id.keys())

    for sid, species in species_by_id.items():
        # --- per-species summary ---
        print(' >> species id: ', sid)
        print(' created: ', species.created)  # when the species was created
        print(' last improved: ', species.last_improved)
        fitness_list = species.get_fitnesses()
        print(sorted(fitness_list, reverse=True))
        member_list = sorted(species.members.values(), key=lambda genome: genome.fitness, reverse=True)
        if not member_list:
            # Nothing to summarize or draw; indexing below would raise IndexError.
            print('member_num: ', 0)
            continue

        print('member_num: ', len(member_list))
        print('best_fitness: ', member_list[0].fitness)
        print('median_fitness: ', member_list[int(len(member_list) * 0.5)].fitness)
        print('worst_fitness: ', member_list[-1].fitness)
        # statistics.stdev() needs at least two data points and raises
        # StatisticsError otherwise, so report NaN for single-member species.
        fitness_stdev = statistics.stdev(fitness_list) if len(fitness_list) > 1 else float('nan')
        print('fitness_stddv: ', fitness_stdev)
        print(' number: ', len(member_list))

        # --- draw each representative genome once ---
        tmp_img_dir = './cache/g' + str(generation) + '/s' + str(sid) + '/'
        os.makedirs(tmp_img_dir, exist_ok=True)
        image_paths = []
        for label, genome in _representative_members(member_list):
            # NOTE(review): assumes modneat's draw_net writes '<filename>.png',
            # since the compile step reads those files. Confirm.
            visualize.draw_net(config=populations.config, genome=genome, view=False, filename=tmp_img_dir + label)
            image_paths.append(tmp_img_dir + label + '.png')

        # --- stitch the five drawings into a single image ---
        compiled_path = tmp_img_dir + 'compiled.png'
        _compile_images(image_paths, compiled_path)

        if is_display:
            display(Image(compiled_path))
            
# Report on the checkpoint chosen in the dropdown. checkpoints_path is reused
# from the selection cell instead of repeating the directory literal.
investigate_checkpoint(os.path.join(checkpoints_path, selected_checkpoint), is_display=True)