In [None]:
import wandb
import os

api = wandb.Api()

entity = "jduque"
project = "MeltingPot"
run_id = "ir1x5pcr"

run = api.run(f"{entity}/{project}/{run_id}")
target_step = 20000
# Local directory that receives the hash-stripped media files for this run.
download_root = f"wandb_videos/{run_id}/"
os.makedirs(download_root, exist_ok=True)

files = run.files()

num_downloaded = 0
for file in files:
    # Only logged media under media/video*/ is of interest.
    if "media/video" not in file.name:
        continue

    file_orig_clean_name = os.path.basename(file.name)  # Get only the file name
    stem, ext = os.path.splitext(file_orig_clean_name)
    # Expected layout: <prefix>_<step>_<hash><ext>, e.g.
    # video_commons_harvest__open_0_20000_<hash>.gif
    parts = stem.split("_")
    if len(parts) < 3 or not parts[-2].isdigit():
        print(f"Skipping {file.name}: unexpected file name format")
        continue
    step = int(parts[-2])
    if step != target_step:
        continue

    # Drop the trailing "_<hash>" suffix but keep the real extension
    # (splitting on ".gif" only produced extension-less names for e.g. .mp4).
    file_clean_name = "_".join(parts[:-1]) + ext
    print("file_orig_clean_name", file_orig_clean_name)
    download_path = os.path.join(download_root, file_clean_name)
    print(f"Downloading {file_clean_name} to {download_path}")
    file.download(root=download_root, replace=True)

    # download() recreates the run-relative path (file.name) under root; move the
    # file up to download_root whatever sub-folder (eval/, train/, ...) it came from.
    # os.replace also overwrites a copy left by a previous run on every platform.
    downloaded_path = os.path.join(download_root, file.name)
    os.replace(downloaded_path, download_path)
    num_downloaded += 1

    # Best-effort cleanup of the now-empty media/... directories; os.removedirs
    # stops at the first non-empty ancestor (the run directory holds the moved file).
    try:
        os.removedirs(os.path.dirname(downloaded_path))
    except OSError:
        pass  # leaf directory not empty: leave it in place

if num_downloaded == 0:
    print(f"No media/video files found for step {target_step} in run {run_id}")

In [None]:
import matplotlib.pyplot as plt
from PIL import Image

# Hash-stripped GIF for this run/step; run_id and target_step are set in the download cell.
gif_file = f"wandb_videos/{run_id}/video_commons_harvest__open_0_{target_step}.gif"

# The handle stays open: individual frames are read later via seek().
gif = Image.open(gif_file)
num_frames = gif.n_frames
print(f'The GIF has {num_frames} frames.')

def extract_frames(gif, frame_indices):
    """Return independent copies of selected frames of an animated image.

    Args:
        gif: Image object supporting ``seek(index)`` and ``copy()``
            (here a ``PIL.Image`` opened on a GIF).
        frame_indices: Iterable of frame numbers, in the order wanted.

    Returns:
        List with one copied frame per requested index, in the given order.
        Afterwards ``gif`` is left positioned on the last requested frame.
    """
    def _grab(index):
        # seek() moves the shared cursor; copy() detaches the current frame from it.
        gif.seek(index)
        return gif.copy()

    return [_grab(index) for index in frame_indices]

frames_to_plot = [2, 4, 6]
selected_frames = extract_frames(gif, frames_to_plot)

# squeeze=False keeps `axes` a 2-D array even for a single frame; with the default
# squeeze, plt.subplots(1, 1) returns a bare Axes and indexing it fails.
fig, axes = plt.subplots(1, len(selected_frames), figsize=(15, 5), squeeze=False)
for ax, frame in zip(axes[0], selected_frames):
    ax.imshow(frame)
    ax.axis('off')

# Saved under its own name: the video-comparison cell further down writes
# "frames.pdf", which would otherwise overwrite this figure on Restart & Run All.
plt.savefig("gif_frames.pdf", format="pdf", bbox_inches='tight', dpi=300)
plt.show()

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os

# Root directory holding one sub-directory per agent variant; change it here
# (rather than in every path) to run the notebook on another machine.
VIDEO_ROOT = '/home/mila/j/juan.duque/projects/advantage-alignment/notebooks/wandb_videos'
VIDEO_BASENAME = 'commons_harvest__open_0.mp4'
AGENT_DIRS = ['agent_best_aa', 'agent_best_naive', 'agent_best_sum_rewards']

# List of video files
video_files = [os.path.join(VIDEO_ROOT, agent_dir, VIDEO_BASENAME) for agent_dir in AGENT_DIRS]
# Row labels: all files share the same basename, so os.path.basename(f) would give
# identical labels; use the parent directory (the agent variant) instead.
video_names = [os.path.basename(os.path.dirname(f)) for f in video_files]

# Time steps at which to extract frames (in seconds)
time_steps = [1.0, 2.0, 3.0]  # Adjust these times as needed

def extract_frames_from_video(video_file, times):
    """Grab one RGB frame per requested timestamp from a video file.

    Args:
        video_file: Path to a video readable by OpenCV.
        times: Iterable of timestamps in seconds.

    Returns:
        A list with one entry per timestamp: the frame converted from OpenCV's
        BGR channel order to RGB, or ``None`` if that frame could not be read.
        If the file cannot be opened or reports no FPS, every entry is ``None``.
    """
    times = list(times)  # allow generators; len() and iteration are both needed
    cap = cv2.VideoCapture(video_file)
    try:
        if not cap.isOpened():
            print(f"Could not open {video_file}")
            return [None] * len(times)
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps:
            print(f"Could not get FPS for {video_file}")
            return [None] * len(times)
        frames = []
        for t in times:
            frame_no = int(round(t * fps))
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_no)
            ret, frame = cap.read()
            if ret:
                # Convert the frame from BGR (OpenCV format) to RGB (matplotlib format)
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            else:
                print(f"Could not read frame at time {t} from {video_file}")
                frames.append(None)
        return frames
    finally:
        # Release on every path, including the early returns above.
        cap.release()

# Extract frames from each video at the specified time steps
all_frames = [extract_frames_from_video(video_file, time_steps) for video_file in video_files]

# Set up the plotting grid
num_videos = len(video_files)
num_times = len(time_steps)

# squeeze=False always yields a (num_videos, num_times) array of Axes. The previous
# np.atleast_2d(axes) turned a single-column result of shape (num_videos,) into
# (1, num_videos), so axes[i, 0] failed when only one time step was requested.
fig, axes = plt.subplots(nrows=num_videos, ncols=num_times,
                         figsize=(5 * num_times, 5 * num_videos), squeeze=False)

# Plot the frames
for i, frames in enumerate(all_frames):
    for j, frame in enumerate(frames):
        ax = axes[i, j]
        if frame is not None:
            ax.imshow(frame)
        else:
            # Axes-relative coordinates keep the placeholder centred regardless of data limits.
            ax.text(0.5, 0.5, 'No Frame', ha='center', va='center', fontsize=12,
                    transform=ax.transAxes)
        # Hide ticks and border instead of ax.axis('off'): turning the axis off
        # also suppresses the y-label, so the row labels below would never render.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
    # Label each row with the video name
    axes[i, 0].set_ylabel(video_names[i], rotation=90, fontsize=16, labelpad=20)

# Label the columns with the time steps
for j, t in enumerate(time_steps):
    axes[0, j].set_title(f'Time {t} s', fontsize=16)

plt.tight_layout()
plt.savefig('frames.pdf', format='pdf', bbox_inches='tight', dpi=300)
plt.show()