In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt

In [None]:
from pathlib import Path
import pandas as pd

In [None]:
import sys

# Make the parent directory ("..") and its "source" subdirectory importable so
# the `datasets`, `transforms`, `gen_embed`, `utils` and `source.*` imports in
# the following cells resolve. The membership guard keeps the cell idempotent:
# re-running it no longer appends the same entries to sys.path again.
for _extra_path in ("..", "../source"):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

In [None]:
from datasets.ptz_dataset import PTZImageDataset, get_position_datetime_from_labels
from transforms import make_transforms
from gen_embed import generate_embedding
from utils.analysis_viz import read_train_loss, flatten, read_fname_embed_from_h5, sort_by_time_from_label

In [None]:
import os

# Root of the training workflow. The original location stays the default; set
# the PTZJEPA_WORKFLOW_DIR environment variable to run the notebook on another
# machine without editing this cell.
root_dir = Path(os.environ.get(
    "PTZJEPA_WORKFLOW_DIR",
    "/Users/yufengluo/Research/anl/su24/trainings/workflow",
))
# Images handed to PTZImageDataset and generate_embedding.
img_dir = root_dir / "collected_imgs"
# Model folder: jepa.csv, params-ijepa.yaml and jepa-latest.pt are read from it.
model_dir = root_dir / "world_models/model_43"
# Inference artefacts (rewards, labels, embeddings) are read back from here.
infer_dir = model_dir / "inference"

In [None]:
# Image dataset over img_dir, transformed with a 512 crop size.
# return_label=True asks the dataset to hand back a label with each sample
# (label format is defined in PTZImageDataset — not visible here).
dataset = PTZImageDataset(
    img_dir,
    transform=make_transforms(crop_size=512),
    return_label=True,
)

In [None]:
# Unshuffled loader, so batches follow the dataset's own ordering.
# NOTE(review): the batch size is hard-coded to 4 here and again in the later
# `batch_size = 4` cell; `dataloader` is not used by any later cell shown here.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=4,
    shuffle=False,
)

In [None]:
# Load the training log and build a fractional-epoch x axis:
#   x = epoch + itr / (largest itr logged for that epoch)
all_train, header = read_train_loss(model_dir / "jepa.csv")
df = pd.DataFrame(flatten(all_train), columns=header, dtype=float)
# One grouped pass gives every row its epoch's maximum itr. The previous
# per-row df.query(...) rescanned the whole frame for each row, which is
# quadratic in the length of the log.
itr_max_per_epoch = df.groupby("epoch")["itr"].transform("max")
x_axis = (df["epoch"] + df["itr"] / itr_max_per_epoch).to_numpy(dtype=float)

In [None]:
# Training loss against fractional epoch, with a dashed red marker at the last
# logged row of every restart (marker height is half the maximum loss).
plt.figure(figsize=(8, 6))
plt.plot(x_axis, df["loss"], "-")
marker_top = df["loss"].max() / 2
for restart_id in np.unique(df.restart):
    # last row index belonging to this restart
    last_row = df.index[df["restart"] == restart_id][-1]
    plt.vlines(x_axis[last_row], ymin=0, ymax=marker_top, color="r", linestyles="dashed")
plt.xlabel("epoch")
plt.ylabel("loss")

In [None]:
from source.prepare_dataset import verify_image

# Pass every JPEG directly inside img_dir (non-recursive glob) to verify_image.
# NOTE(review): what verify_image does with a bad file is not visible here — confirm.
for image_path in img_dir.glob("*.jpg"):
    verify_image(image_path)

In [None]:
# Number of JPEG files directly inside img_dir.
sum(1 for _ in img_dir.glob("*.jpg"))

In [None]:
# Run the saved model (config + latest checkpoint) over img_dir; the following
# cells read their inputs from infer_dir, so this presumably writes there.
# NOTE(review): device="mps" ties this cell to PyTorch builds with Apple MPS
# support; confirm which device strings generate_embedding accepts before
# changing it.
generate_embedding(
    model_dir / "params-ijepa.yaml",
    model_dir / "jepa-latest.pt",
    img_dir,
    infer_dir,
    world_model=True,
    device="mps",
)

In [None]:
# Display the inference directory path as a quick check.
infer_dir

In [None]:
import h5py
import os

In [None]:
# List the entries currently present in the inference directory.
os.listdir(infer_dir)

In [None]:
# Group size used below to reshape the rewards and to index the embeddings.
# NOTE(review): the DataLoader cell above hard-codes the same value (4).
batch_size = 4

In [None]:
# Load the predictor rewards, average over axis 1, then regroup them as
# (n_groups, batch_size, 1) so they can be indexed as reward[i, j].
# NOTE(review): allow_pickle=True lets np.load unpickle objects, which can run
# arbitrary code — only load files produced by this pipeline.
reward = (
    np.load(infer_dir / "rewards_predictor.npy", allow_pickle=True)
    .mean(axis=1)
    .reshape(-1, batch_size, 1)
)

In [None]:
# Read labels.txt, keeping only non-blank lines with surrounding whitespace removed.
with open(infer_dir / "labels.txt") as label_file:
    labels = [text for text in map(str.strip, label_file) if text]

In [None]:
# Load (file names, embeddings) for the predictor, the context encoder and the
# target encoder, in that order.
(pred_fnames, pred_embeds), (contx_fnames, contx_embeds), (target_fnames, target_embeds) = (
    read_fname_embed_from_h5(infer_dir / h5_name)
    for h5_name in (
        "embeds_predictor.h5",
        "embeds_contx_encoder.h5",
        "embeds_target_encoder.h5",
    )
)

In [None]:
# Peek at the shape of one predictor embedding entry (index 3 is arbitrary).
pred_embeds[3].shape

In [None]:
# Average each embedding stack over axis 2 (presumably the patch/token axis —
# TODO confirm against read_fname_embed_from_h5).
# NOTE(review): this overwrites its own inputs, so running the cell a second
# time reduces axis 2 again; re-run the h5 loading cell first.
pred_embeds, contx_embeds, target_embeds = (
    np.mean(stack, axis=2)
    for stack in (pred_embeds, contx_embeds, target_embeds)
)

In [None]:
# Sort each embedding set with sort_by_time_from_label. The first returned item
# is kept as the sorted embeddings; the remaining items are collected into a
# *_meta list (the DataFrame cell below labels them "fnames", "pos", "time").
pred_embed_sort, *pred_meta = sort_by_time_from_label(pred_embeds, pred_fnames)
contx_embed_sort, *contx_meta = sort_by_time_from_label(contx_embeds, contx_fnames)
target_embed_sort, *target_meta = sort_by_time_from_label(target_embeds, target_fnames)

In [None]:
import pandas as pd

In [None]:
# One metadata table per embedding set: rows are samples, columns are
# "fnames", "pos" and "time" (built row-wise, then transposed).
df_pred, df_contx, df_target = (
    pd.DataFrame(meta, index=["fnames", "pos", "time"]).transpose()
    for meta in (pred_meta, contx_meta, target_meta)
)

In [None]:
# The embeddings were already averaged before sorting, so the *_avg names are
# plain aliases of the sorted arrays (no copy is made).
pred_embed_avg, contx_embed_avg, target_embed_avg = (
    pred_embed_sort,
    contx_embed_sort,
    target_embed_sort,
)

In [None]:
# pred_embed_avg = np.mean(pred_embed_sort, axis=2)
# contx_embed_avg = np.mean(contx_embed_sort, axis=2)
# target_embed_avg = np.mean(target_embed_sort, axis=2)

In [None]:
# Shapes of the predictor, context and target embedding arrays, in that order.
tuple(arr.shape for arr in (pred_embed_avg, contx_embed_avg, target_embed_avg))

# Demonstrate the prediction results

The context and predictor images are ordered 0000 1111 2222 3333 ...

The order of target images is 0123 0123 0123 ....

So the predictor outputs should closely resemble the target embeddings,
but not match them exactly.

A norm of zero below means that the target and context encoders produce the
same output for the same image. This is only a sanity check that the
encoders are working correctly.

In [None]:
from numpy import linalg

In [None]:
# Distance between the first and last column (axis 1) of the context embeddings.
np.linalg.norm(contx_embed_avg[:, 0] - contx_embed_avg[:, batch_size - 1])

In [None]:
# Distance from target row 0 to rows batch_size-2, batch_size-1 and batch_size.
for other_row in (batch_size - 2, batch_size - 1, batch_size):
    print(linalg.norm(target_embed_avg[0, :] - target_embed_avg[other_row, :]))

In [None]:
# Predictor embeddings: distance between columns 0 and 3 (axis 1), then
# between rows 0 and 3 (axis 0).
print(
    linalg.norm(pred_embed_avg[:, 0] - pred_embed_avg[:, 3]),
    linalg.norm(pred_embed_avg[0, :] - pred_embed_avg[3, :]),
    sep="\n",
)

In [None]:
# Downsample according to the ordering described above: keep every
# batch_size-th target row and flatten it to (rows, features); for the context
# embeddings keep column 0 of axis 1.
samp_target_embed = target_embed_avg[::batch_size].reshape(-1, target_embed_avg.shape[-1])
samp_contx_embed = contx_embed_avg[:, 0]

In [None]:
# Overall distance between the downsampled target and context embeddings.
np.linalg.norm(samp_target_embed - samp_contx_embed)

In [None]:
from sklearn.decomposition import PCA, KernelPCA
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler, RobustScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn import metrics
from sklearn.feature_selection import VarianceThreshold
import joblib

In [None]:
def scale_pca_tsne_transform(embeds, pca_components=50, perplexity=50, random_state=0):
    """Standardize embeddings, reduce them with PCA, then project to 2-D with t-SNE.

    Parameters
    ----------
    embeds : array-like of shape (n_samples, n_features)
        One embedding vector per row.
    pca_components : int, default 50
        Requested number of PCA components. An integer request is capped at
        min(n_samples, n_features), the most PCA can produce, so small inputs
        no longer raise. Non-integer values are passed to PCA unchanged.
    perplexity : float, default 50
        Requested t-SNE perplexity. Capped at n_samples - 1 because t-SNE
        requires perplexity < n_samples.
    random_state : int, default 0
        Seed given to both PCA and t-SNE so repeated runs give the same layout
        (PCA with svd_solver='auto' may pick a randomized solver).

    Returns
    -------
    embeds_tsne : ndarray of shape (n_samples, 2)
        The t-SNE projection.
    embeds_pca : ndarray of shape (n_samples, n_components_used)
        The PCA-reduced, standardized embeddings fed to t-SNE.

    Raises
    ------
    ValueError
        If `embeds` is not 2-D or has fewer than two rows.
    """
    shape = np.shape(embeds)
    if len(shape) != 2:
        raise ValueError(f"embeds must be 2-D (n_samples, n_features), got shape {shape}")
    n_samples, n_features = shape
    if n_samples < 2:
        raise ValueError(f"need at least 2 samples for PCA + t-SNE, got {n_samples}")

    n_components = pca_components
    if isinstance(pca_components, (int, np.integer)):
        n_components = min(pca_components, n_samples, n_features)
    tsne_perplexity = min(perplexity, n_samples - 1)

    embeds_scaled = StandardScaler().fit_transform(embeds)
    embeds_pca = PCA(n_components=n_components, svd_solver='auto',
                     random_state=random_state).fit_transform(embeds_scaled)
    embeds_tsne = TSNE(n_components=2, learning_rate='auto', init='pca',
                       perplexity=tsne_perplexity, n_jobs=-1,
                       random_state=random_state).fit_transform(embeds_pca)
    return embeds_tsne, embeds_pca

In [None]:
# Project the downsampled context and target embeddings; each result is the
# (tsne, pca) pair returned by scale_pca_tsne_transform.
trans_contx_embeds, trans_target_embeds = map(
    scale_pca_tsne_transform,
    (samp_contx_embed, samp_target_embed),
)

In [None]:
def _select_step(values, step, batch_size):
    """Return values[i, i % batch_size + step] for every row i where that column exists.

    Rows keep their original order, so each full batch contributes
    `batch_size - step` consecutive rows to the result.
    """
    rows = np.arange(values.shape[0])
    cols = rows % batch_size + step
    keep = cols < batch_size
    # float64 output matches the np.zeros buffers this cell used to fill row by row.
    return np.asarray(values[rows[keep], cols[keep]], dtype=float)


# Rewards and predictor embeddings must describe the same rows.
assert reward.shape[0] == pred_embed_avg.shape[0], "reward / embedding row count mismatch"

# Row i sits at position i % batch_size inside its batch. "self" pairs it with
# its own position; "one"/"two"/"three" pair it with the position 1/2/3 places
# later in the same batch. Everything is derived from batch_size instead of the
# former hard-coded 3/4, 1/2, 1/4 sizes and `== 3` / `arr[3]` indices, and a
# partial last batch no longer raises IndexError.
pred_self = _select_step(pred_embed_avg, 0, batch_size)
reward_self = _select_step(reward, 0, batch_size)
# batch_size - 1 rows per batch
pred_dist_one = _select_step(pred_embed_avg, 1, batch_size)
reward_one = _select_step(reward, 1, batch_size)
# batch_size - 2 rows per batch
pred_dist_two = _select_step(pred_embed_avg, 2, batch_size)
reward_two = _select_step(reward, 2, batch_size)
# batch_size - 3 rows per batch
pred_dist_three = _select_step(pred_embed_avg, 3, batch_size)
reward_three = _select_step(reward, 3, batch_size)

In [None]:
# Project the 1-, 2- and 3-step predictor embeddings; each result is the
# (tsne, pca) pair returned by scale_pca_tsne_transform.
trans_pred_one, trans_pred_two, trans_pred_three = map(
    scale_pca_tsne_transform,
    (pred_dist_one, pred_dist_two, pred_dist_three),
)

In [None]:
# Project the "self" predictor embeddings; the result is a (tsne, pca) pair.
trans_pred_self = scale_pca_tsne_transform(pred_self)

In [None]:
def _scatter_rewards(axis, layers, title, vmin, vmax, lim=80):
    """Draw one or more reward-coloured t-SNE scatters on `axis`.

    layers : iterable of (tsne_xy, rewards, label) triples; every triple is
        drawn on the same axes with the shared colour range [vmin, vmax].
    lim : half-width of the square view, applied to both x and y.
    """
    for xy, rewards, label in layers:
        axis.scatter(xy[:, 0], xy[:, 1], s=2, label=label,
                     c=rewards.flatten(), vmin=vmin, vmax=vmax)
    axis.set_title(title)
    axis.legend()
    axis.set_aspect("equal")
    axis.set_xlim([-lim, lim])
    axis.set_ylim([-lim, lim])


# Shared colour range for the rewards (also reused by the histogram cell below).
vmin = -0.05
vmax = 0.05

fig, ax = plt.subplots(3, 2, figsize=(12, 15))
plt.set_cmap("viridis")
ax = ax.ravel()

# Panel 0: context vs. target encoder embeddings (no reward colouring).
ax[0].scatter(trans_contx_embeds[0][:, 0], trans_contx_embeds[0][:, 1], s=2, label="context")
ax[0].scatter(trans_target_embeds[0][:, 0], trans_target_embeds[0][:, 1], s=2, alpha=0.5, label="target")
ax[0].set_title("context vs. target encoder")
ax[0].legend()
ax[0].set_aspect("equal")

step_layers = [
    (trans_pred_one[0], reward_one, "1 step"),
    (trans_pred_two[0], reward_two, "2 step"),
    (trans_pred_three[0], reward_three, "3 step"),
]

# Panel 1: predictor applied to the image itself.
_scatter_rewards(ax[1], [(trans_pred_self[0], reward_self, "self")], "predictor: self", vmin, vmax)

# Panel 2: all step distances overlaid; the colour bar is taken from the first scatter.
_scatter_rewards(ax[2], step_layers, "predictor: 1-3 steps", vmin, vmax)
fig.colorbar(ax[2].collections[0], ax=ax[2])

# Panels 3-5: one step distance each.
for axis, layer in zip(ax[3:], step_layers):
    _scatter_rewards(axis, [layer], f"predictor: {layer[2]}", vmin, vmax)

In [None]:
# Histogram of all rewards over [vmin, vmax] in 0.01-wide bins.
# np.arange(vmin, vmax, 0.01) stops before vmax, so the top bin
# (vmax - 0.01, vmax] was missing; np.linspace includes both ends and avoids
# float-step drift. Rewards outside [vmin, vmax] are still left out, as before.
bin_width = 0.01
n_bins = int(round((vmax - vmin) / bin_width))
plt.hist(reward.flatten(), bins=np.linspace(vmin, vmax, n_bins + 1))
plt.ylabel("count")
plt.xlabel("reward")

In [None]:
# Summary statistics over every reward value.
pd.DataFrame({"reward": reward.ravel()}).describe()

In [None]:
# Summary statistics of the 1-step rewards.
pd.DataFrame(reward_one).describe()

In [None]:
# Summary statistics of the 2-step rewards.
pd.DataFrame(reward_two).describe()

In [None]:
# Summary statistics of the 3-step rewards.
pd.DataFrame(reward_three).describe()

In [None]:
# Summary statistics of the "self" rewards.
pd.DataFrame(reward_self).describe()

In [None]:
# Convert to an ndarray so the cells below can index it with integer arrays.
pred_fnames = np.array(pred_fnames)

In [None]:
# Show the first few (context, target) image pairs behind the 1-step rewards.
# reward_one stacks, batch after batch, the rows whose in-batch position is
# below batch_size - 1, so its row k belongs to image
#   (k // per_batch) * batch_size + k % per_batch.
# The previous lookup dropped the `k % per_batch` term and therefore always
# showed the first two images of the batch.
step = 1
per_batch = batch_size - step
idx = np.where(reward_one < 10)[0]  # threshold 10 kept from the original
ctx_idx = idx // per_batch * batch_size + idx % per_batch
ori_fnames = pred_fnames[ctx_idx]
comp_fnames = pred_fnames[ctx_idx + step]  # the image `step` places later in the same batch
for pick_idx in range(min(5, len(idx))):  # fewer than 5 matches no longer raises IndexError
    fig, ax = plt.subplots(2, 1, figsize=(6, 8))
    ax[0].imshow(plt.imread(img_dir / (ori_fnames[pick_idx] + ".jpg")))
    ax[0].set_title(ori_fnames[pick_idx])
    ax[1].imshow(plt.imread(img_dir / (comp_fnames[pick_idx] + ".jpg")))
    ax[1].set_title(comp_fnames[pick_idx])
    plt.show()

In [None]:
# Show the first few (context, target) image pairs behind the 2-step rewards.
# reward_two stacks, batch after batch, the rows whose in-batch position is
# below batch_size - 2, so its row k belongs to image
#   (k // per_batch) * batch_size + k % per_batch.
# The previous lookup dropped the `k % per_batch` term and therefore always
# showed the first and third image of the batch.
step = 2
per_batch = batch_size - step
idx = np.where(reward_two < 10)[0]  # threshold 10 kept from the original
ctx_idx = idx // per_batch * batch_size + idx % per_batch
ori_fnames = pred_fnames[ctx_idx]
comp_fnames = pred_fnames[ctx_idx + step]  # the image `step` places later in the same batch
for pick_idx in range(min(10, len(idx))):  # fewer than 10 matches no longer raises IndexError
    fig, ax = plt.subplots(2, 1, figsize=(6, 8))
    ax[0].imshow(plt.imread(img_dir / (ori_fnames[pick_idx] + ".jpg")))
    ax[0].set_title(ori_fnames[pick_idx])
    ax[1].imshow(plt.imread(img_dir / (comp_fnames[pick_idx] + ".jpg")))
    ax[1].set_title(comp_fnames[pick_idx])
    plt.show()

In [None]:
# Show every (first, fourth) image pair behind the 3-step rewards: row k of
# reward_three maps to image k * batch_size and the image 3 places after it.
idx = np.where(reward_three < 10)[0]
ori_fnames = pred_fnames[idx * batch_size]
comp_fnames = pred_fnames[idx * batch_size + 3]
for ori_name, comp_name in zip(ori_fnames, comp_fnames):
    fig, ax = plt.subplots(2, 1, figsize=(6, 8))
    ax[0].imshow(plt.imread(img_dir / (ori_name + ".jpg")))
    ax[0].set_title(ori_name)
    ax[1].imshow(plt.imread(img_dir / (comp_name + ".jpg")))
    ax[1].set_title(comp_name)
    plt.show()