In [7]:
import glob
import os
import re

import cv2
import imageio
import numpy as np

def _natural_sort_key(path):
    """Sort key that orders embedded integers numerically, not lexicographically.

    With this key ``jiggle_epoch_2.png`` sorts before ``jiggle_epoch_10.png``;
    plain ``sorted`` would put ``10`` first unless the numbers are zero-padded.
    """
    name = os.path.basename(path)
    # re.split with a capturing group alternates text / digit chunks and always
    # starts with a (possibly empty) text chunk, so odd positions are digits.
    # Converting by position keeps two keys comparing str-to-str and
    # int-to-int element by element.
    parts = re.split(r"(\d+)", name)
    numeric = [int(tok) if i % 2 else tok for i, tok in enumerate(parts)]
    # The raw name breaks ties such as "a01" vs "a1" deterministically.
    return numeric, name


def _letterbox(img, target_w, target_h):
    """Fit ``img`` inside a white ``target_w`` x ``target_h`` canvas, centred.

    The aspect ratio is preserved and nothing is cropped. Assumes ``img`` is a
    3-channel uint8 array, as returned by ``cv2.imread`` in its default mode.
    """
    h, w = img.shape[:2]
    scale = min(target_w / w, target_h / h)
    # Clamp to 1 px so extremely thin images cannot produce a 0-sized resize.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    resized = cv2.resize(img, (new_w, new_h))

    canvas = np.full((target_h, target_w, 3), 255, dtype=np.uint8)
    x_offset = (target_w - new_w) // 2
    y_offset = (target_h - new_h) // 2
    canvas[y_offset:y_offset + new_h, x_offset:x_offset + new_w] = resized
    return canvas


def create_animation(log_dir, output_name="scatter_animation.mp4", fps=10,
                     pattern="jiggle_epoch_*.png", epoch_interval=10,
                     canvas_size=(1920, 1080)):
    """Create an animation (GIF or MP4) from saved scatter plots without cropping.

    Every image in ``log_dir`` matching ``pattern`` becomes one frame. Images
    are ordered by the numbers in their filenames (natural sort), letterboxed
    onto a white ``canvas_size`` canvas, and stamped with an "Epoch N" label.
    The animation is written into ``log_dir`` as well.

    Args:
        log_dir (str): Folder holding the frame images; the output goes here too
            (e.g. logs/multi_cmpnn/version_0/jiggle_monitor).
        output_name (str): Output filename; must end with .gif or .mp4
            (case-insensitive).
        fps (int): Frames per second.
        pattern (str): Glob pattern selecting frame images inside ``log_dir``.
        epoch_interval (int): Epoch label step between consecutive images; the
            k-th image (0-based) is labelled ``k * epoch_interval``.
        canvas_size (tuple[int, int]): Output frame size as (width, height).

    Returns:
        None. If no usable images are found, a message is printed and nothing
        is written.

    Raises:
        ValueError: ``output_name`` does not end with .gif or .mp4.
        OSError: OpenCV could not open the MP4 writer (codec or path problem).
    """
    # Validate before doing any expensive image work.
    ext = os.path.splitext(output_name)[1].lower()
    if ext not in (".mp4", ".gif"):
        raise ValueError("Output filename must end with .gif or .mp4")

    output_path = os.path.join(log_dir, output_name)
    scatter_files = sorted(glob.glob(os.path.join(log_dir, pattern)),
                           key=_natural_sort_key)
    if not scatter_files:
        print("No scatter images found!")
        return

    target_w, target_h = canvas_size
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1
    thickness = 2

    frames = []
    for idx, img_path in enumerate(scatter_files):
        img = cv2.imread(img_path)
        if img is None:
            # Unreadable file: skip it. idx still advances, so later labels
            # keep their position in the sorted file list.
            continue

        canvas = _letterbox(img, target_w, target_h)

        # NOTE(review): the label assumes one image per `epoch_interval`
        # epochs; the number embedded in the filename may be a more reliable
        # source. TODO confirm against the code that saves these images.
        text = f"Epoch {idx * epoch_interval}"
        (_, text_h), _ = cv2.getTextSize(text, font, font_scale, thickness)
        cv2.putText(canvas, text, (10, text_h + 10), font, font_scale,
                    (0, 0, 0), thickness, cv2.LINE_AA)

        frames.append(canvas)

    if not frames:
        print("No valid frames found!")
        return

    if ext == ".mp4":
        height, width = frames[0].shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        # VideoWriter fails silently (writes nothing) when the codec or path is
        # unusable; surface that instead of reporting a successful save.
        if not out.isOpened():
            raise OSError(f"Could not open video writer for: {output_path}")
        try:
            for frame in frames:
                out.write(frame)
        finally:
            out.release()
        print(f"Saved MP4 animation to: {output_path}")
    else:
        # OpenCV frames are BGR; imageio expects RGB.
        frames_rgb = [cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in frames]
        # NOTE(review): newer imageio releases prefer `duration` (ms per frame)
        # for GIFs; confirm the installed version honours `fps`.
        imageio.mimsave(output_path, frames_rgb, fps=fps)
        print(f"Saved GIF animation to: {output_path}")


In [8]:
# Root of the training logs. Set the CMPNN_LOG_ROOT environment variable to
# point elsewhere instead of editing this machine-specific default.
LOG_ROOT = os.environ.get("CMPNN_LOG_ROOT", "/home/calvin/code/cmpnn_revised/scripts/logs")
RUN_VERSION = "version_234"
JIGGLE_DIR = os.path.join(LOG_ROOT, "multi_cmpnn", RUN_VERSION, "jiggle_monitor")

create_animation(log_dir=JIGGLE_DIR, output_name="jiggle.gif", fps=5)

Saved GIF animation to: /home/calvin/code/cmpnn_revised/scripts/logs/multi_cmpnn/version_234/jiggle_monitor/jiggle.gif
