In [1]:
import sys

import numpy as np
import torch

# Make the data-generation helpers importable *before* the project modules are
# loaded, in case `bvae`/`dataset` need them at import time. Appending (rather
# than prepending) means modules already resolvable from the CWD still win.
sys.path.append('../data_gen_scripts')

from bvae import BetaVAE
from dataset import VOGMaze2dOfflineRLDataset

In [2]:
# Experiment configuration: checkpoint to evaluate, offline dataset, maze layout.
path_to_model, path_to_dataset, maze_type = (
    'loss0.0567590706050396_changed_inpud_decoder_1e-10_lr_0.001_last.pth',  # BetaVAE state_dict
    './dataset/visual-pointmaze-medium-navigate-v0.npz',
    'medium',  # one of: 'medium', 'large', 'giant', 'teleport'
)

## 1. Load Model and Dataset

In [None]:
# Load the trained BetaVAE weights (a state_dict) onto the GPU.
device = torch.device('cuda:0')
model = BetaVAE().to(device)
# map_location lets a checkpoint saved on another device (a different GPU index
# or the CPU) load onto `device`.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (consider weights_only=True on torch >= 1.13).
state_dict = torch.load(path_to_model, map_location=device)
model.load_state_dict(state_dict)
model.eval()  # inference mode: the cells below only encode/reconstruct
print('Model loaded')

# Load the offline RL dataset -- this will take a while.
data = VOGMaze2dOfflineRLDataset(dataset_url=path_to_dataset)
print('Dataset loaded')

## 2. Visualize Map

In [None]:
def get_2d_colors(points, min_point, max_point, alpha=0.8):
    """Get RGBA colors corresponding to 2-D points.

    Each coordinate is min-max normalised with ``min_point``/``max_point`` and
    used as the first two color channels. The third channel is
    ``(2 - c0 - c1) / 2``, so ``min_point`` maps to blue and ``max_point`` to
    yellow. RGB channels are clipped to [0, 1] and a constant alpha is appended.

    Args:
        points: Array-like of 2-D points, shape (N, 2). A single point ``[x, y]``
            or an empty list is also accepted.
        min_point: Array-like of shape (2,), mapped to channel value 0.
        max_point: Array-like of shape (2,), mapped to channel value 1.
        alpha: Alpha value given to every color (default 0.8).

    Returns:
        ``np.ndarray`` of shape (N, 4) with RGBA values.
    """
    min_point = np.asarray(min_point, dtype=float)
    max_point = np.asarray(max_point, dtype=float)
    # Reshape so a single point or an empty list still yields an (N, D) array.
    points = np.asarray(points, dtype=float).reshape(-1, min_point.size)

    span = max_point - min_point
    # Avoid dividing by zero on degenerate axes (min == max): those axes are left
    # unscaled instead of producing NaN/inf.
    safe_span = np.where(span == 0, 1.0, span)
    colors = (points - min_point) / safe_span
    colors = np.hstack((colors, (2 - np.sum(colors, axis=1, keepdims=True)) / 2))
    colors = np.clip(colors, 0, 1)
    colors = np.c_[colors, np.full(len(colors), alpha)]

    return colors
# Grid layouts for each supported maze: 1 = wall (drawn white below), 0 = free cell.
_MAZE_MAPS = {
    'medium': [
        [1, 1, 1, 1, 1, 1, 1, 1],
        [1, 0, 0, 1, 1, 0, 0, 1],
        [1, 0, 0, 1, 0, 0, 0, 1],
        [1, 1, 0, 0, 0, 1, 1, 1],
        [1, 0, 0, 1, 0, 0, 0, 1],
        [1, 0, 1, 0, 0, 1, 0, 1],
        [1, 0, 0, 0, 1, 0, 0, 1],
        [1, 1, 1, 1, 1, 1, 1, 1],
    ],
    'large': [
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1],
        [1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1],
        [1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1],
        [1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1],
        [1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1],
        [1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1],
        [1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    ],
    'giant': [
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1],
        [1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1],
        [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1],
        [1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1],
        [1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1],
        [1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1],
        [1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1],
        [1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1],
        [1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1],
        [1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    ],
    'teleport': [
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1],
        [1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1],
        [1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1],
        [1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1],
        [1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1],
        [1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1],
        [1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    ],
}

# Fail loudly here instead of with a NameError on `maze_map` further down.
if maze_type not in _MAZE_MAPS:
    raise ValueError(
        f"Unknown maze_type {maze_type!r}; expected one of {sorted(_MAZE_MAPS)}"
    )
maze_map = _MAZE_MAPS[maze_type]

# Render the maze as an (H, W, 3) RGB image: wall cells (1) are white, free
# cells are colored by their (row, col) position via get_2d_colors.
# NOTE: `map` shadows the builtin, but later cells read it under this name.
height, width = len(maze_map), len(maze_map[0])
print(height, width)

# Every (row, col) index pair in row-major order, colored in a single call.
cell_ij = np.indices((height, width)).reshape(2, -1).T
cell_rgb = get_2d_colors(cell_ij, [0, 0], [height - 1, width - 1])[:, :3]
is_wall = np.asarray(maze_map) == 1
map = np.where(is_wall[..., None], 1.0, cell_rgb.reshape(height, width, 3))

import matplotlib.pyplot as plt
plt.imshow(map)


In [5]:
def modify_pos(pos, scale=4, offset=1):
    """Convert a dataset position into the coordinate frame of the maze image.

    Computes ``pos / scale + offset`` element-wise. With the defaults (4 and 1)
    this is presumably the maze cell size and the border-wall offset of the
    pointmaze environment -- NOTE(review): confirm against the environment's
    maze unit / offset settings.

    Args:
        pos: Position array or tensor (e.g. ``(x, y)``); anything supporting
            ``/`` and ``+`` with numbers.
        scale: Divisor applied to the position (default 4).
        offset: Value added after scaling (default 1).

    Returns:
        The transformed position, in the same container type as ``pos``.
    """
    return pos / scale + offset

## 3. Plot Observations, Reconstructions, and Latent Distances

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cv2
import imageio

# Update font size
plt.rcParams.update({'font.size': 24})

# Mean / std (in 0-255 pixel units) used to undo the image normalisation;
# presumably the statistics the dataset normalised with -- NOTE(review): confirm.
IMG_MEAN = 141.785487953533
IMG_STD = 71.0288272312382
NUM_STEPS = 1001   # frames cover dataset indices [base_idx, base_idx + NUM_STEPS)
FRAME_STEP = 500   # index step between consecutive frames
SAVE_EVERY = 500   # also write a standalone PNG when (i - base_idx) % SAVE_EVERY == 0
GIF_FPS = 5


def l2_distance(a, b):
    """Return the Euclidean (L2) norm of ``a - b``."""
    return np.linalg.norm(a - b)


def _to_display_image(img_chw):
    """Un-normalise a (C, H, W) tensor into an (H, W, C) float image in [0, 1]."""
    img = (img_chw * IMG_STD + IMG_MEAN) / 255.0
    img = img.permute(1, 2, 0).cpu().numpy()
    # imshow clips out-of-range floats anyway (with a warning); do it explicitly.
    img = np.clip(img, 0.0, 1.0)
    # NOTE(review): assumes stored images are BGR -- if they are RGB, drop this
    # conversion or red/blue will be swapped in the panels.
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


# RGB frames for the GIF.
frames = []

# Reference sample: every frame's distances are measured against it.
device = next(model.parameters()).device
base_idx = 1001
base_obs, base_pos = data[base_idx]
base_obs = base_obs.unsqueeze(0)  # add batch dim -> (1, C, H, W)
with torch.no_grad():
    base_emb = model.encode(base_obs.to(device))[0].cpu().numpy()
base_pos_x, base_pos_y = modify_pos(base_pos)

for i in range(base_idx, base_idx + NUM_STEPS, FRAME_STEP):
    step = i - base_idx
    fig, axs = plt.subplots(1, 4, figsize=(24, 9))

    # 1) Entire map with the base (red) and current (blue) positions.
    #    The markers only line up if modify_pos maps dataset positions into the
    #    map image's (x, y) pixel coordinates.
    obs, pos = data[i]
    pos_x, pos_y = modify_pos(pos)
    axs[0].imshow(map)
    axs[0].set_title('Entire Map')
    axs[0].scatter(base_pos_x, base_pos_y, color='red', s=200, marker='o', label="Base Position")
    axs[0].scatter(pos_x, pos_y, color='blue', s=200, marker='o', label="Current Position")
    axs[0].invert_yaxis()
    axs[0].legend(loc='upper right', fontsize=12)  # make the marker labels visible

    # 2) Observation image.
    axs[1].imshow(_to_display_image(obs))
    axs[1].set_title(f'Observation {step}')

    # 3) Reconstruction; the forward pass's `mu` is reused for the latent distance.
    with torch.no_grad():
        recon, _, mu, _ = model(obs.unsqueeze(0).to(device))
    axs[2].imshow(_to_display_image(recon.squeeze(0)))
    axs[2].set_title(f'Reconstruction {step}')

    # 4) Position-space vs latent-space L2 distance to the base sample.
    pos_l2_dist = l2_distance(pos, base_pos)
    # NOTE(review): assumes model.encode(x)[0] and the forward pass's mu are the
    # same latent mean with matching shapes -- confirm in bvae.py.
    latent_l2_dist = l2_distance(mu.cpu().numpy(), base_emb)
    axs[3].bar(['Position', 'Latent'],
               [pos_l2_dist, latent_l2_dist],
               color=['blue', 'red'])
    axs[3].set_ylim(0, 20)
    axs[3].set_title('L2 Distances')

    fig.savefig('temp_plot.png')
    plt.close(fig)

    # cv2.imread returns BGR. cv2.imwrite expects BGR, so write it unchanged;
    # imageio expects RGB, so convert before collecting the GIF frame.
    frame_bgr = cv2.imread('temp_plot.png')
    if frame_bgr is None:
        raise RuntimeError("Could not read back 'temp_plot.png'")
    if step % SAVE_EVERY == 0:
        cv2.imwrite(f'frame_{step}.png', frame_bgr)
    frames.append(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))

imageio.mimsave(f'{path_to_model}.gif', frames, fps=GIF_FPS)