In [19]:
## for storing the dataset
# from dataset import create_small_dataset
# !mkdir ./data/sm0llest
# create_small_dataset("./data/train/", 100, "./data/sm0llest/")
# import wandb
# wandb.init(project="wall_jepa")
# artifact = wandb.Artifact("sm0llest", type="dataset")
# artifact.add_dir("./data/sm0llest")
# wandb.log_artifact(artifact)

In [20]:
import sys
import os

# Make the parent of the kernel's working directory importable so the local
# `dataset` module resolves. Guarded so re-running this cell does not keep
# appending the same entry to sys.path.
_REPO_ROOT = os.path.abspath('..')
if _REPO_ROOT not in sys.path:
    sys.path.append(_REPO_ROOT)

from dataset import create_wall_dataloader
import wandb

# Fetch the small debugging dataset that was logged as the W&B artifact
# "sm0llest" and build a training dataloader over the downloaded directory.
run = wandb.init(project="wall_jepa")
artifact = run.use_artifact('sm0llest:latest')
artifact_dir = artifact.download()

dl = create_wall_dataloader(artifact_dir, batch_size=64, train=True)

[34m[1mwandb[0m: Downloading large artifact sm0llest:latest, 54.81MB. 2 files... 
[34m[1mwandb[0m:   2 of 2 files downloaded.  
Done. 0:0:0.2


Loading data from /data/DL_24/notebooks/artifacts/sm0llest:v0 ...
Dataset size: 100
States shape: torch.Size([100, 17, 2, 65, 65])
Actions shape: torch.Size([100, 16, 2])


In [21]:
import math

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML
from matplotlib.patches import Arrow

from dataset import apply_augmentations

def create_subplot(ax, initial_state, title):
    """Draw the first frame of a trajectory on ``ax`` and return its artists.

    Channel 1 of ``initial_state`` is drawn in gray (the wall image) and
    channel 0 is overlaid in 'jet' at 50% alpha (the agent image). A
    zero-length red placeholder arrow is added so that ``update_subplot``
    always has an arrow patch to replace.

    Args:
        ax: matplotlib Axes to draw into.
        initial_state: Indexable with at least two 2-D channels
            (index 0 -> agent image, index 1 -> wall image).
        title: Subplot title.

    Returns:
        Tuple ``(wall_img, agent_img, arrow)`` of the created artists.
    """
    height, width = initial_state[0].shape[-2:]
    # Pixel-centred extent (matplotlib's default layout for origin='upper'):
    # pixel (row, col) is centred on data coordinates (x=col, y=row).
    # update_subplot anchors the action arrow at the agent's argmax (row, col),
    # so this keeps the arrow on the agent pixel for any frame size instead of
    # assuming 65x65 frames squeezed into [0, 64].
    extent = [-0.5, width - 0.5, height - 0.5, -0.5]
    wall_img = ax.imshow(initial_state[1], cmap='gray', extent=extent)
    agent_img = ax.imshow(initial_state[0], cmap='jet', alpha=0.5, extent=extent)
    ax.set_title(title)
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    # Zero-length placeholder; replaced on the first animation update.
    arrow = ax.add_patch(Arrow(0, 0, 0, 0, width=0.5, color='r'))
    return wall_img, agent_img, arrow

def update_subplot(wall_img, agent_img, arrow, state, action, ax,
                   arrow_scale=5, arrow_width=2):
    """Show ``state`` in one subplot and return the arrow artist now in use.

    Args:
        wall_img: AxesImage that receives ``state[1]``.
        agent_img: AxesImage that receives ``state[0]``.
        arrow: Arrow patch currently drawn on ``ax``.
        state: Frame to display; ``state[0]`` must be a 2-D array/tensor.
        action: Two-component action drawn as an arrow starting at the argmax
            pixel of ``state[0]``, or ``None`` for a frame with no outgoing
            action (the final state of a trajectory).
        ax: Axes that owns ``arrow``.
        arrow_scale: Multiplier applied to the action components for display.
        arrow_width: Width of the drawn arrow patch.

    Returns:
        The arrow artist associated with this subplot after the update
        (a new patch when ``action`` is given, otherwise ``arrow`` itself).
    """
    wall_img.set_array(state[1])
    agent_img.set_array(state[0])

    if action is None:
        # No action follows this state: hide the previous frame's arrow rather
        # than leaving it anchored at the agent's previous position.
        arrow.set_visible(False)
        return arrow

    # argmax of the agent channel -> (row, col); imshow plots col on x, row on y.
    # int()/float() also accept 0-d torch tensors and numpy scalars.
    row, col = np.unravel_index(int(state[0].argmax()), state[0].shape)
    # The previous patch is replaced rather than mutated.
    arrow.remove()
    return ax.add_patch(Arrow(col, row,
                              float(action[0]) * arrow_scale,
                              float(action[1]) * arrow_scale,
                              width=arrow_width, color='r'))

def visualize_augmentation_comparison(dl, index=1, n_cols=3):
    """Animate one trajectory side by side with several augmented copies.

    Each panel shows channel 1 of every state in gray with channel 0 overlaid,
    plus a red arrow for the action taken from the displayed frame.

    Args:
        dl: Dataloader whose ``dl.dataset[index]`` has ``.states`` and
            ``.actions`` tensors.
        index: Which trajectory of ``dl.dataset`` to visualize.
        n_cols: Number of subplot columns in the grid (must be >= 1).

    Returns:
        A ``matplotlib.animation.FuncAnimation``; render it inline with
        ``HTML(anim.to_jshtml())``. The figure is closed before returning so
        Jupyter does not also display it as a static image.

    Raises:
        ValueError: If ``n_cols`` is less than 1.
    """
    if n_cols < 1:
        raise ValueError(f"n_cols must be >= 1, got {n_cols}")

    print(f"Visualizing original and augmented trajectory for index {index}")

    sample = dl.dataset[index]
    # These are torch tensors (not numpy arrays); .cpu() is a no-op on CPU data.
    states = sample.states.cpu()
    actions = sample.actions.cpu()

    def augment(**probs):
        # Give apply_augmentations its own copies: whether it mutates its inputs
        # is not visible here, and mutating `states`/`actions` would corrupt the
        # "Original" panel (and possibly the dataset's stored tensors).
        return apply_augmentations(states.clone(), actions.clone(), **probs)

    augmentations = {
        "Original": (states, actions),
        "Horizontal Flip": augment(p_aug=1, p_hflip=1, p_vflip=0, p_rot90=0, p_noise=0),
        "Vertical Flip": augment(p_aug=1, p_hflip=0, p_vflip=1, p_rot90=0, p_noise=0),
        "90° Rotation": augment(p_aug=1, p_hflip=0, p_vflip=0, p_rot90=1, p_noise=0),
        "Added Noise": augment(p_aug=1, p_hflip=0, p_vflip=0, p_rot90=0, p_noise=1),
        "All-nonoise": augment(p_aug=1, p_hflip=1, p_vflip=1, p_rot90=1, p_noise=0),
        # Remaining probabilities fall back to apply_augmentations' defaults.
        # NOTE(review): if those are < 1 this panel differs between runs; no seed is set.
        "All": augment(p_aug=1),
    }

    n_panels = len(augmentations)
    n_rows = math.ceil(n_panels / n_cols)

    # squeeze=False keeps `axes` a 2-D array even for 1x1 / 1xN grids.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 6 * n_rows),
                             squeeze=False)
    fig.suptitle("Original vs Different Augmentations")
    flat_axes = axes.ravel()

    subplots = {
        title: create_subplot(ax, traj_states[0], title)
        for (title, (traj_states, _)), ax in zip(augmentations.items(), flat_axes)
    }

    # Hide grid cells that have no panel.
    for ax in flat_axes[n_panels:]:
        ax.set_visible(False)

    # A trajectory has one more state than actions (the last state has no
    # outgoing action). Use the shortest panel so no panel is indexed past its end.
    n_frames = min(min(len(traj_states), len(traj_actions) + 1)
                   for traj_states, traj_actions in augmentations.values())

    def update(frame):
        updated = []
        for title, (traj_states, traj_actions) in augmentations.items():
            wall_img, agent_img, arrow = subplots[title]
            action = traj_actions[frame] if frame < len(traj_actions) else None
            arrow = update_subplot(wall_img, agent_img, arrow,
                                   traj_states[frame], action, wall_img.axes)
            subplots[title] = (wall_img, agent_img, arrow)
            updated.extend([wall_img, agent_img, arrow])
        return updated

    anim = animation.FuncAnimation(fig, update, frames=n_frames,
                                   interval=200, blit=True)

    plt.close(fig)
    return anim

# Build the comparison animation for trajectory 15 and render it inline as an
# interactive JS player (the cell's last expression is what Jupyter displays).
anim = visualize_augmentation_comparison(dl, index=15)
HTML(anim.to_jshtml())

Visualizing original and augmented trajectory for index 15


In [22]:
# Comparison animation for trajectory 12, rendered inline as a JS player.
anim = visualize_augmentation_comparison(dl, index=12)
HTML(anim.to_jshtml())

Visualizing original and augmented trajectory for index 12
