In [5]:
import tempfile
from pathlib import Path

import numpy as np
import plotly.graph_objects as go
from PIL import Image

from ABC.abc import ArtificialBeeColony

In [6]:
def rastrigin(point):
    """Evaluate the two-dimensional Rastrigin function.

    Computes ``20 + (x1**2 - 10*cos(2*pi*x1)) + (x2**2 - 10*cos(2*pi*x2))``,
    which is 0 at the origin.

    Args:
        point: A pair ``(x1, x2)``. Each coordinate may be a scalar or a
            NumPy array; arrays are evaluated element-wise.

    Returns:
        The function value, with the same shape as the coordinates.
    """
    x1, x2 = point
    # One term per coordinate; they are added to the offset left-to-right.
    term_x1 = x1**2 - 10*np.cos(2*np.pi*x1)
    term_x2 = x2**2 - 10*np.cos(2*np.pi*x2)
    return 20 + term_x1 + term_x2

# Optimiser settings.
n_bees = 50
max_iters = 25
# NOTE(review): presumably the abandonment limit of the algorithm (half the
# colony times the 2 problem dimensions) — confirm against ArtificialBeeColony.
limit = 2 * (n_bees // 2)

# NOTE(review): no random seed is set in this cell; if ArtificialBeeColony
# draws random numbers, results will differ between runs — consider seeding.
ABC = ArtificialBeeColony(
    n_bees=n_bees,
    limit=limit,
    max_iters=max_iters,
    lower_bound=[-5.12, -5.12],
    upper_bound=[5.12, 5.12],
    function=rastrigin,
)
ABC.optimize()

Running Optimization: 100%|██████████| 25/25 [00:00<00:00, 541.56it/s]


In [7]:
# Evaluate the objective on a 100x100 grid over [-5.12, 5.12] in both
# coordinates; this surface is the background of every frame.
x = np.linspace(-5.12, 5.12, 100)
y = np.linspace(-5.12, 5.12, 100)
X, Y = np.meshgrid(x, y)
Z = rastrigin((X, Y))

# Load the marker once rather than once per iteration. Image.open is lazy and
# keeps the file handle open, so take an in-memory copy and close the file.
with Image.open("assets/BeeMarker.png") as marker_file:
    bee_marker = marker_file.copy()

plots = []

# One figure per recorded entry of the optimiser history.
# NOTE(review): assumes each entry of ABC.optimal_source_history unpacks into
# exactly two coordinates — confirm against ArtificialBeeColony.
for iteration, (xx, yy) in enumerate(ABC.optimal_source_history):

    # Background: greyscale contour plot of the objective.
    fig = go.Figure(data=go.Contour(
        z=Z,
        x=x,  # x-axis values
        y=y,  # y-axis values
        colorscale='Greys',  # Colormap
        opacity=0.6,
        contours=dict(
            showlabels=False,  # Contour labels are disabled
            # labelfont only takes effect if showlabels is switched on.
            labelfont=dict(
                size=12,
                color='white'
            )
        )
    ))
    fig.update_layout(
        title=f'Iteration {iteration+1}/ {ABC.max_iters}',
        xaxis_title='X1',
        yaxis_title='X2',
        width=700,
        height=700,
    )

    # Foreground: bee marker centred on this iteration's recorded point,
    # positioned and sized in data coordinates.
    fig.add_layout_image(
        dict(
            source=bee_marker,
            xref="x",
            yref="y",
            xanchor="center",
            yanchor="middle",
            x=xx,
            y=yy,
            sizex=0.5,
            sizey=0.5,
            sizing="contain",
            opacity=1,
            layer="above"
        )
    )
    plots.append(fig)

In [8]:
# Render every figure to a PNG in a temporary directory, then assemble the
# frames into a looping animated GIF.
if not plots:
    # Fail with a clear message instead of an IndexError on images[0] below.
    raise ValueError("No figures to animate: `plots` is empty.")

with tempfile.TemporaryDirectory() as tmpdirname:
    image_files = []

    # Save each figure as a separate image file
    for i, fig in enumerate(plots):
        file_path = f"{tmpdirname}/frame_{i}.png"
        fig.write_image(file_path, format="png", scale=2)  # Adjust 'scale' for image resolution
        image_files.append(file_path)

    # Image.open is lazy and keeps the file open. Copy each frame into memory
    # and close the file so that no handle outlives the temporary directory
    # (open handles can make its cleanup fail on Windows).
    images = []
    for file in image_files:
        with Image.open(file) as frame:
            images.append(frame.copy())

# The frames are fully in memory, so the GIF can be written after cleanup.
gif_path = "images/rastrigin_animated_optimization.gif"
# Create the output directory if it is missing (e.g. on a fresh checkout).
Path(gif_path).parent.mkdir(parents=True, exist_ok=True)
# duration is the per-frame display time in milliseconds; loop=0 repeats forever.
images[0].save(gif_path, save_all=True, append_images=images[1:], duration=500, loop=0)

print(f"GIF created and saved at {gif_path}")

GIF created and saved at images/rastrigin_animated_optimization.gif
