In [1]:
# Some useful settings for interactive work
# autoreload: pick up edits to imported project modules without restarting the kernel
%load_ext autoreload
%autoreload 2

# Interactive (ipympl) matplotlib figures
%matplotlib widget

import torch
# Select the 'high' float32 matmul precision mode for all subsequent torch matmuls
torch.set_float32_matmul_precision('high')

In [2]:
# Import the relevant modules
import numpy as np
from PIL import Image
import torch.nn.functional as F
import matplotlib.pyplot as plt
import sousvide.synthesize.rollout_generator as rg
import sousvide.synthesize.observation_generator as og
import sousvide.control.networks.feature_extractors as fe
import sousvide.utilities.feature_utilities as fu
import figs.visualize.generate_videos as gv
import figs.utilities.transform_helper as th
from sklearn.decomposition import PCA

  check_for_updates()


In [None]:
# --- Experiment configuration -------------------------------------------
# Cohort whose rollouts are generated and then read back by the cells below.
cohort = "features"

# Data generation and evaluation currently share the same method name.
data_method = eval_method = "eval_single"

# Scene / course selection and the roster entry.
scene, courses = "mid_gate", ["traverse"]
roster = ["clanGhostBear"]

In [None]:
# Generate rollouts: one call per (course, scene) pair used later in the notebook
for course_name, scene_name in (("traverse", "mid_gate"), ("circuit", "backroom")):
    rg.generate_rollout_data(cohort, [course_name], scene_name, "eval_single")

In [None]:
# Rollout data per course, keyed by course name.
data = {}

# NOTE: Tro/Xro/Uro/Iro stay bound to the LAST course ("circuit") after the
# loop, exactly as in the unrolled version; the final cell reads `Iro`.
for course_name in ("traverse", "circuit"):
    Tro, Xro, Uro, Iro = fu.extract_rollout_data(cohort, course_name)
    data[course_name] = dict(Tro=Tro, Xro=Xro, Uro=Uro, Iro=Iro)

# image = Image.open('elephant1.jpeg').convert('RGB')  # ensure RGB
# image_np = np.expand_dims(np.array(image),axis=0)  # shape: (H, W, 3), dtype=uint8
# data["elephant"] = {"Tro":None,"Xro":None,"Uro":None,"Iro":image_np}

# images = []
# for i in range(1,6):
#     image = Image.open(f'ladder{i}.jpg').convert('RGB')  # ensure RGB
#     image_np =np.array(image)
#     images.append(image_np)
# data["ladder"] = {"Tro":None,"Xro":None,"Uro":None,"Iro":images}

# images = []
# for i in range(1,5):
#     image = Image.open(f'test{i}.png').convert('RGB')  # ensure RGB
#     image_np =np.array(image)
#     images.append(image_np)
# data["test"] = {"Tro":None,"Xro":None,"Uro":None,"Iro":images}

# images = []
# for i in range(1,4):
#     image = Image.open(f'cabinet{i}.jpg').convert('RGB')  # ensure RGB
#     image_np =np.array(image)
#     images.append(image_np)
# data["cabinet"] = {"Tro":None,"Xro":None,"Uro":None,"Iro":images}

In [None]:
# Feature extractors (suffix _rg -> VitB16, suffix _dn -> DINOv2) and the
# 3-component PCA object that every visualisation cell re-fits.
vit_rg, vit_dn = fe.VitB16(), fe.DINOv2()
pca = PCA(n_components=3)

In [None]:
# Choose the image sequence to analyse (swap the active line to change set)
# imgs = data["traverse"]["Iro"]
imgs = data["circuit"]["Iro"]
# imgs = data["elephant"]["Iro"]
# imgs = data["test"]["Iro"]
# imgs = data["ladder"]["Iro"]

# Per-frame patch features (Ynn_*) and the second output of each extractor
# (Cls_*), collected for both backbones.
N = len(imgs)
Ynn_rg, Cls_rg = [], []
Ynn_dn, Cls_dn = [], []

# Inference only: a single no_grad scope wraps the whole pass.
# The loop variable is deliberately still called `img`; later cells read it.
with torch.no_grad():
    for img in imgs:
        img_in = fu.process_image(img).unsqueeze(0)

        ynn_rg, cls_rg = vit_rg(img_in)
        ynn_dn, cls_dn = vit_dn(img_in)

        Ynn_rg.append(ynn_rg)
        Ynn_dn.append(ynn_dn)
        Cls_rg.append(cls_rg)
        Cls_dn.append(cls_dn)

In [None]:
# Static Image: PCA fitted on the patch features of one single frame
img_idx = 3
frame = imgs[img_idx]

# The shared `pca` object is re-fitted for each backbone; the projected
# patches are folded back onto their grid (14x14 for _rg, 16x16 for _dn).
pca_img_rg = pca.fit_transform(Ynn_rg[img_idx]).reshape(14,14,3)
pca_img_dn = pca.fit_transform(Ynn_dn[img_idx]).reshape(16,16,3)

# One figure per principal component: _rg on the left, _dn on the right
for i in range(3):
    left = fu.heatmap_overlay(pca_img_rg[:,:,i],frame)
    right = fu.heatmap_overlay(pca_img_dn[:,:,i],frame)
    fig, axs = plt.subplots(1, 2, figsize=(8, 3))
    for ax, overlay in zip(axs, (left, right)):
        ax.imshow(overlay)

# Final figure: all three components rendered together by fu.pca_overlay
left = fu.pca_overlay(pca_img_rg)
right = fu.pca_overlay(pca_img_dn)
fig, axs = plt.subplots(1, 2, figsize=(8, 3))
for ax, overlay in zip(axs, (left, right)):
    ax.imshow(overlay)

In [None]:
# Static Images: one PCA fitted jointly on the patch features of ALL frames,
# visualised on the last frame of the sequence.
Ynn_rg_out = torch.cat(Ynn_rg,dim=0)
Ynn_dn_out = torch.cat(Ynn_dn,dim=0)

pca_img_rg = pca.fit_transform(Ynn_rg_out)
pca_img_dn = pca.fit_transform(Ynn_dn_out)

# Fold the stacked patches back into (frame, grid_row, grid_col, component)
pca_img_rg = pca_img_rg.reshape(len(imgs),14,14,3)
pca_img_dn = pca_img_dn.reshape(len(imgs),16,16,3)

# Index the background frame explicitly instead of relying on the loop
# variable `img` leaked from the feature-extraction cell: that name is
# rebound by the reference-image cell, which silently paired the last frame's
# heatmap with an unrelated image when this cell was re-run.
last_frame = imgs[-1]

# One figure per principal component: _rg on the left, _dn on the right
for i in range(3):
    overlay1 = fu.heatmap_overlay(pca_img_rg[-1,:,:,i],last_frame)
    overlay2 = fu.heatmap_overlay(pca_img_dn[-1,:,:,i],last_frame)
    fig, axs = plt.subplots(1, 2, figsize=(8, 3))
    axs[0].imshow(overlay1)
    axs[1].imshow(overlay2)

# All three components of the last frame rendered together
overlay1 = fu.pca_overlay(pca_img_rg[-1,:,:,:])
overlay2 = fu.pca_overlay(pca_img_dn[-1,:,:,:])
fig, axs = plt.subplots(1, 2, figsize=(8, 3))
axs[0].imshow(overlay1)
axs[1].imshow(overlay2)

In [None]:
# Live Video: sliding-window PCA. For every frame i >= n a PCA is fitted on
# the patch features of the n most recent frames (ending AT frame i) and the
# projection of frame i is overlaid on frame i. Frames 0..n-1 stay black.
Ynn_rg_out = Ynn_rg
Ynn_dn_out = Ynn_dn

n = 20            # window length in frames
threshold = 0.6

# Size the output from the sequence itself rather than from the loop variable
# `img` leaked by another cell (consistent with the "Static Target" cells).
frame_h, frame_w = imgs[0].shape[0], imgs[0].shape[1]
Hmap_rg = np.zeros((N,frame_h,frame_w,3),dtype=np.uint8)
Hmap_dn = np.zeros((N,frame_h,frame_w,3),dtype=np.uint8)
muHmap_rg_out = np.zeros((N,frame_h,frame_w,3),dtype=np.uint8)
muHmap_dn_out = np.zeros((N,frame_h,frame_w,3),dtype=np.uint8)

for i in range(n,N):
    # Window covers frames i-n+1 .. i inclusive, so its last slice really is
    # frame i. (The previous [i-n:i] window ended at frame i-1 while the
    # result was drawn on imgs[i] -- a one-frame mismatch.)
    window = slice(i-n+1,i+1)
    Ynns_rg = torch.cat(Ynn_rg_out[window],dim=0)
    Ynns_dn = torch.cat(Ynn_dn_out[window],dim=0)

    pca_img_rg = pca.fit_transform(Ynns_rg).reshape(n,14,14,3)
    pca_img_dn = pca.fit_transform(Ynns_dn).reshape(n,16,16,3)

    # Mean over the three components, needed for the current frame only
    mu_pca_img_rg = np.mean(pca_img_rg[-1],axis=2)
    mu_pca_img_dn = np.mean(pca_img_dn[-1],axis=2)

    Hmap_rg[i,:,:,:] = fu.heatmap_overlay(pca_img_rg[-1,:,:,0],imgs[i],threshold=threshold)
    Hmap_dn[i,:,:,:] = fu.heatmap_overlay(pca_img_dn[-1,:,:,0],imgs[i],threshold=threshold)
    muHmap_rg_out[i,:,:,:] = fu.heatmap_overlay(mu_pca_img_rg,imgs[i],threshold=threshold)
    muHmap_dn_out[i,:,:,:] = fu.heatmap_overlay(mu_pca_img_dn,imgs[i],threshold=threshold)

    # Preview every component once, for the first processed frame
    if i == n:
        for j in range(3):
            overlay1 = fu.heatmap_overlay(pca_img_rg[-1,:,:,j],imgs[i],threshold=threshold)
            overlay2 = fu.heatmap_overlay(pca_img_dn[-1,:,:,j],imgs[i],threshold=threshold)
            fig, axs = plt.subplots(1, 2, figsize=(8, 3))
            axs[0].imshow(overlay1)
            axs[1].imshow(overlay2)

gv.images_to_mp4(Hmap_rg, "hmap_rg.mp4", fps=20)
gv.images_to_mp4(Hmap_dn, "hmap_dn.mp4", fps=20)
gv.images_to_mp4(muHmap_rg_out, "mu_hmap_rg.mp4", fps=20)
gv.images_to_mp4(muHmap_dn_out, "mu_hmap_dn.mp4", fps=20)

In [None]:
# Live Single Frame: PCA fitted on ONE frame's patch features, one row of
# axes per principal component (_rg left, _dn right).
Ynn_rg_out = Ynn_rg
Ynn_dn_out = Ynn_dn

threshold = 0.6
frame_idx = 0     # frame to visualise

fig, axs = plt.subplots(3, 2, figsize=(8, 9))

# A single frame yields one row per patch, so the projection folds straight
# onto the patch grid. The earlier reshape(n,14,14,3) reused the window length
# `n` leaked from the video cell and could not succeed for a single frame.
pca_img_rg = pca.fit_transform(Ynn_rg_out[frame_idx]).reshape(14,14,3)
pca_img_dn = pca.fit_transform(Ynn_dn_out[frame_idx]).reshape(16,16,3)

for j in range(3):
    overlay1 = fu.heatmap_overlay(pca_img_rg[:,:,j],imgs[frame_idx],threshold=threshold)
    overlay2 = fu.heatmap_overlay(pca_img_dn[:,:,j],imgs[frame_idx],threshold=threshold)

    axs[j,0].imshow(overlay1)
    axs[j,1].imshow(overlay2)

plt.tight_layout()

In [None]:
# CLS dot product: relevance of every patch to the frame's own CLS output,
# rendered as a heatmap video for both backbones.
idx_cls = 200     # frame shown in the preview figure
Ynn_rg_out,Cls_rg_out = Ynn_rg,Cls_rg
Ynn_dn_out,Cls_dn_out = Ynn_dn,Cls_dn

# Size the output from the sequence itself rather than from the loop variable
# `img` leaked by another cell (consistent with the "Static Target" cells).
frame_h, frame_w = imgs[0].shape[0], imgs[0].shape[1]
Hmap_rg = np.zeros((N,frame_h,frame_w,3),dtype=np.uint8)
Hmap_dn = np.zeros((N,frame_h,frame_w,3),dtype=np.uint8)
for i in range(N):
    ynn_rg = Ynn_rg_out[i]
    ynn_dn = Ynn_dn_out[i]
    cls_rg = Cls_rg_out[i]
    cls_dn = Cls_dn_out[i]

    # Patch-to-CLS dot products folded onto the patch grid
    rel_rg = torch.matmul(ynn_rg,cls_rg).reshape(14,14)
    rel_dn = torch.matmul(ynn_dn,cls_dn).reshape(16,16)

    Hmap_rg[i,:,:,:] = fu.heatmap_overlay(rel_rg,imgs[i],threshold=0.8)
    Hmap_dn[i,:,:,:] = fu.heatmap_overlay(rel_dn,imgs[i],threshold=0.8)

# Clamp the preview index so sequences shorter than idx_cls+1 frames still
# produce a figure instead of raising IndexError.
show_idx = min(idx_cls,N-1)
fig, axs = plt.subplots(1, 2, figsize=(8, 3))
axs[0].imshow(Hmap_rg[show_idx,:,:,:])
axs[1].imshow(Hmap_dn[show_idx,:,:,:])

gv.images_to_mp4(Hmap_rg, "cls_hmap_rg.mp4", fps=20)
gv.images_to_mp4(Hmap_dn, "cls_hmap_dn.mp4", fps=20)

In [None]:
# Reference images: extract features of the static target set.
# The "cabinet" loader in the data cell is commented out, so on a fresh
# Restart & Run All the key does not exist. Load the set here on demand
# (same files and schema as that loader) so this cell is self-sufficient.
if "cabinet" not in data:
    images = []
    for i in range(1,4):
        with Image.open(f'cabinet{i}.jpg') as image:
            images.append(np.array(image.convert('RGB')))  # ensure RGB
    data["cabinet"] = {"Tro":None,"Xro":None,"Uro":None,"Iro":images}

# imgs_rf = data["ladder"]["Iro"]
imgs_rf = data["cabinet"]["Iro"]

Nrf = len(imgs_rf)
Yrf_rg,Yrf_dn = [], []
Crf_rg,Crf_dn = [], []

# The loop variable is `img_rf`, not `img`: earlier cells read the global
# `img` left behind by the rollout feature loop and must not see it rebound
# to a reference image.
with torch.no_grad():
    for img_rf in imgs_rf:
        img_in = fu.process_image(img_rf).unsqueeze(0)

        yrf_rg,crf_rg = vit_rg(img_in)
        yrf_dn,crf_dn = vit_dn(img_in)
        Yrf_rg.append(yrf_rg)
        Yrf_dn.append(yrf_dn)
        Crf_rg.append(crf_rg)
        Crf_dn.append(crf_dn)

In [None]:
# Static Target on Live I
# Target = mean of the reference set's Crf_* outputs; every live frame's
# patches are scored against it by dot product.
threshold = 0.8
targ_rg = torch.stack(Crf_rg).mean(dim=0)
targ_dn = torch.stack(Crf_dn).mean(dim=0)

canvas_shape = (N,imgs[0].shape[0],imgs[0].shape[1],3)
Hmap_rg = np.zeros(canvas_shape,dtype=np.uint8)
Hmap_dn = np.zeros(canvas_shape,dtype=np.uint8)

for i in range(N):
    ynn_rg, ynn_dn = Ynn_rg[i], Ynn_dn[i]

    # Patch scores folded onto the patch grid (14x14 for _rg, 16x16 for _dn)
    rel_rg = (ynn_rg @ targ_rg).reshape(14,14)
    rel_dn = (ynn_dn @ targ_dn).reshape(16,16)

    Hmap_rg[i] = fu.heatmap_overlay(rel_rg,imgs[i],threshold=threshold)
    Hmap_dn[i] = fu.heatmap_overlay(rel_dn,imgs[i],threshold=threshold)

gv.images_to_mp4(Hmap_rg, "targ_rg.mp4", fps=20)
gv.images_to_mp4(Hmap_dn, "targ_dn.mp4", fps=20)

In [None]:
# Static Target on Live II
# Like variant I (DINOv2 only), but the target is refreshed from the live
# frame whenever more than 40 patches score above the threshold.
threshold = 0.7
alpha = 0.3

# Initial target and its working copy
targ_dn0 = torch.mean(torch.stack(Crf_dn),dim=0)
targ_dn = 1*targ_dn0

counter = 0       # number of frames that triggered a target update
Hmap_dn = np.zeros((N,imgs[0].shape[0],imgs[0].shape[1],3),dtype=np.uint8)
for i in range(N):
    ynn_dn = Ynn_dn[i]

    # Similarities, min-max normalised to [0, 1] within the frame
    rel_dn = torch.matmul(ynn_dn,targ_dn)
    rel_lo, rel_hi = torch.min(rel_dn), torch.max(rel_dn)
    rel_dn = (rel_dn - rel_lo) / (rel_hi - rel_lo)

    # Target update from the patches above threshold.
    # NOTE(review): the blend is anchored to the fixed initial target
    # `targ_dn0`, not to the running `targ_dn` -- confirm this is intended
    # given the "ema" output file name.
    indices = torch.nonzero(rel_dn > threshold,as_tuple=True)[0]
    if len(indices) > 40:
        counter += 1
        zupd = ynn_dn[indices,:].mean(dim=0)
        targ_dn = (1-alpha)*targ_dn0 + alpha*zupd

    # Heatmap on the 16x16 patch grid
    Hmap_dn[i] = fu.heatmap_overlay(rel_dn.reshape(16,16),imgs[i],threshold=threshold)
    rel_dn = rel_dn.reshape(16,16)
print(counter)
gv.images_to_mp4(Hmap_dn, "targ_ema_dn.mp4", fps=20)

In [None]:
# Static Target on Live III
# Target = the Crf_dn output of one single reference image (index 2).
threshold = 0.90
alpha = 0.2       # set for parity with the other variants; not read in this cell

# Initial target and its working copy
targ_dn0 = Crf_dn[2]
targ_dn = 1*targ_dn0

Hmap_dn = np.zeros((N,imgs[0].shape[0],imgs[0].shape[1],3),dtype=np.uint8)
for i in range(N):
    ynn_dn = Ynn_dn[i]

    # Similarities, min-max normalised to [0, 1] within the frame
    rel_dn = torch.matmul(ynn_dn,targ_dn)
    rel_lo, rel_hi = torch.min(rel_dn), torch.max(rel_dn)
    rel_dn = (rel_dn - rel_lo) / (rel_hi - rel_lo)

    # Heatmap on the 16x16 patch grid
    rel_dn = rel_dn.reshape(16,16)
    Hmap_dn[i] = fu.heatmap_overlay(rel_dn,imgs[i],threshold=threshold)

gv.images_to_mp4(Hmap_dn, "targ_single_dn.mp4", fps=20)

In [None]:
# Static Target on Live IV
# Targets = the DINOv2 patch features of hand-picked grid cells in reference
# frame `idx_rf`. A live patch is highlighted when its normalised similarity
# to ANY of those target patches exceeds the threshold.
idx_rf = 110
threshold = 0.99
alpha = 0.2       # not read in this cell

# Hand-picked (row, col) cells of the 16x16 patch grid
ref_cells = [
    (14, 4), (13, 4), (12, 4), (11, 4),
    (10, 4), (10, 5), ( 9, 5), ( 8, 5),
    (13, 7), (12, 7), (11, 7), (10, 7),
    ( 9, 6), ( 8, 6), ( 7, 6), ( 6, 6),
    ( 7, 5), ( 6, 5), ( 5, 5), ( 4, 5),
]
heat_ref = torch.zeros((16,16))
for row, col in ref_cells:
    heat_ref[row, col] = 1
heat_ref = heat_ref.flatten()

indices = (heat_ref > 0).nonzero(as_tuple=True)[0]
targ_dn = Ynn_dn[idx_rf][indices]
# Derive the target count from the selection instead of hard-coding 20, so
# editing `ref_cells` cannot desynchronise the normalisation and the reshape.
num_targets = targ_dn.shape[0]

Hmap_dn = np.zeros((N,imgs[0].shape[0],imgs[0].shape[1],3),dtype=np.uint8)
for i in range(N):
    ynn_dn = Ynn_dn[i]

    # Similarity of every target patch (rows) to every live patch (cols),
    # min-max normalised per target row in one vectorised step
    Rel_dn = torch.matmul(targ_dn,ynn_dn.T)
    row_min = Rel_dn.min(dim=1,keepdim=True).values
    row_max = Rel_dn.max(dim=1,keepdim=True).values
    Rel_dn = (Rel_dn - row_min) / (row_max - row_min)

    # Binary mask on the patch grid: any target above threshold
    Rel_dn = Rel_dn.reshape(num_targets,16,16)
    mask = (Rel_dn > threshold).any(dim=0).to(torch.float32)
    Hmap_dn[i,:,:,:] = fu.heatmap_overlay(mask,imgs[i],threshold=threshold)

gv.images_to_mp4(Hmap_dn, "targ_direct_dn.mp4", fps=20)

In [None]:
# Static Target on Live V
# Same hand-picked reference cells as variant IV, but their features are
# averaged into ONE target vector before scoring the live frames.
idx_rf = 110
threshold = 0.90
alpha = 0.2       # not read in this cell

# Hand-picked (row, col) cells of the 16x16 patch grid
ref_cells = (
    (14, 4), (13, 4), (12, 4), (11, 4),
    (10, 4), (10, 5), ( 9, 5), ( 8, 5),
    (13, 7), (12, 7), (11, 7), (10, 7),
    ( 9, 6), ( 8, 6), ( 7, 6), ( 6, 6),
    ( 7, 5), ( 6, 5), ( 5, 5), ( 4, 5),
)
heat_ref = torch.zeros((16,16))
for row, col in ref_cells:
    heat_ref[row, col] = 1

# Show the selected cells on top of the reference frame
fig,axs = plt.subplots(1, 1, figsize=(4, 4))
overlay = fu.heatmap_overlay(heat_ref,imgs[idx_rf],threshold=threshold)
axs.imshow(overlay)
heat_ref = heat_ref.flatten()

# Mean feature of the selected reference patches
indices = (heat_ref > 0).nonzero(as_tuple=True)[0]
targ_dn = Ynn_dn[idx_rf][indices]
targ_dn = targ_dn.mean(dim=0)

Hmap_dn = np.zeros((N,imgs[0].shape[0],imgs[0].shape[1],3),dtype=np.uint8)
for i in range(N):
    ynn_dn = Ynn_dn[i]

    # Similarities, min-max normalised to [0, 1] within the frame
    rel_dn = torch.matmul(ynn_dn,targ_dn)
    rel_lo, rel_hi = torch.min(rel_dn), torch.max(rel_dn)
    rel_dn = (rel_dn - rel_lo) / (rel_hi - rel_lo)

    # Heatmap on the 16x16 patch grid
    rel_dn = rel_dn.reshape(16,16)
    Hmap_dn[i] = fu.heatmap_overlay(rel_dn,imgs[i],threshold=threshold)

gv.images_to_mp4(Hmap_dn, "targ_direct_mean_dn.mp4", fps=20)

In [None]:
# Per-frame cosine-similarity heatmap video over the rollout images in `Iro`.
# NOTE(review): this cell reads `vit`, `Np` and `patches`, none of which is
# defined in any cell visible in this notebook -- it presumably depends on
# state from a deleted cell and will raise NameError on a fresh kernel. TODO confirm.
# `Iro` is whatever the data-loading cell bound last.
Nro = Iro.shape[0]
height = Iro.shape[1]
width = Iro.shape[2]

# One uint8 RGB output frame per rollout image
Iout = np.zeros((Nro,height,width,3),dtype=np.uint8)
for p in range(Nro):
    # 1*Iro[p] yields a new array rather than a view into Iro
    img2 = 1*Iro[p]

    img_in2 = fu.process_image(img2).unsqueeze(0)        
    with torch.no_grad():
        # Extractor output viewed as a 16x16 grid of feature vectors
        ptk2 = vit(img_in2).squeeze(0).view(16,16,-1)

    # Cosine similarity between each of the Np entries of `patches` and every grid cell
    cos_sims = torch.zeros((Np,16,16))
    for i in range(Np):
        for j in range(16):
            for k in range(16):
                patch2 = ptk2[j,k,:]
                cos_sims[i,j,k] = F.cosine_similarity(patches[i],patch2,dim=0)

    imgs2_hots = np.zeros((16,16))
    for i in range(Np):
        # max_idx = torch.argmax(cos_sims[i,:,:])
        # row,col = divmod(max_idx.item(), 16)

        # imgs2_hots[row,col] = 1
        # NOTE(review): this assignment REPLACES imgs2_hots on every iteration,
        # so only the last entry's mask (i == Np-1) survives and the zeros
        # array above is discarded -- possibly a union over entries was meant.
        imgs2_hots = (cos_sims[i,:,:] > 0.6).float()

        
    Iout[p,:,:,:] = fu.overlay_heatmap_on_image(imgs2_hots,img2)

gv.images_to_mp4(Iout,"output4.mp4",fps=20)