In [5]:
from pathlib import Path
import numpy as np
from skimage import io
from skimage.transform import rescale, resize
import torch
from transformers import Dinov2Backbone
from torchvision.models import resnet18
from scipy.ndimage import gaussian_filter
from scipy import ndimage

In [6]:
backbone = resnet18(pretrained=True)

In [7]:
children = list(backbone.children())
newmodel = torch.nn.Sequential(*children[0:6])
newmodel

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Con

In [48]:
# === Load last image before max
data_path = Path("data")
file_list = sorted(data_path.glob("last_*jpg"))
print(file_list[-1])
last_image = io.imread(file_list[-1])

imginfo = lambda img: print(type(img), img.dtype, img.shape, img.min(), img.max())

first_image = io.imread("data/first_image.jpg")
max_image = io.imread("data/max_image.jpg")

# === load first touch coors
import json
with open("data/single_log.json") as f:
    episode_log = json.load(f)
first_x, first_y, first_z = map(int, episode_log["first_touch"])
cropw = 200

def to_batch(image):
    downscaled = resize(image, (224, 224))
    batch = torch.from_numpy(downscaled)[None, :, :, :].to(torch.float32) / 255  # [h, w, c] -> [1, h, w, c]
    batch = batch.permute(0, 3, 1, 2)  # [1, h, w, c] -> [1, c, h, w]
    return batch

def crop(image):
    """
    :param image: [h, w, c]
    """
    return image[first_y-cropw:first_y+cropw, first_x-cropw:first_x+cropw]

io.imsave("data/result_crop_a.jpg", crop(first_image).astype(np.uint8))
io.imsave("data/result_crop_b.jpg", crop(max_image).astype(np.uint8))
io.imsave("data/result_crop_last.jpg", crop(last_image).astype(np.uint8))

# === Run model on crops
batch = to_batch(crop(first_image))
imginfo(batch)
dino1 = newmodel(batch).detach()
dino2 = newmodel(to_batch(crop(max_image))).detach()
dino_last = newmodel(to_batch(crop(last_image))).detach()

imginfo(dino1)
imginfo(dino2)

def scale(diff_image: np.array):
    low, high = np.quantile(diff_image, 0.1), np.quantile(diff_image, 0.99)
    return np.clip(diff_image / high * 255, 0, 255)

diff = dino1 - dino2
diff = (diff[0, :, :, :] ** 2).sum(axis=0).numpy()

diff_last_max = dino2 - dino_last
diff_last_max = (diff_last_max[0, :, :, :] ** 2).sum(axis=0).numpy()

sub = diff - gaussian_filter(diff_last_max, sigma=2) * 2

# sub = ndimage.median_filter(sub, size=2)
crop_shape = crop(first_image).shape
io.imsave("data/result_diff.jpg", resize(scale(sub), crop_shape[:2]).astype(np.uint8))
io.imsave("data/result_before_sub.jpg", resize(scale(diff), crop_shape[:2]).astype(np.uint8))


# === Simple difference
imginfo(first_image)
diff = crop(first_image.astype(np.float32)) - crop(max_image.astype(np.float32)) # careful, must be float!
diff = np.sum(diff ** 2, axis=2)
io.imsave("data/result_simple.jpg", resize(scale(diff), crop_shape[:2]).astype(np.uint8))

data/last_0150.jpg
<class 'torch.Tensor'> torch.float32 torch.Size([1, 3, 224, 224]) tensor(0.) tensor(0.0039)
<class 'torch.Tensor'> torch.float32 torch.Size([1, 128, 28, 28]) tensor(0.) tensor(4.3697)
<class 'torch.Tensor'> torch.float32 torch.Size([1, 128, 28, 28]) tensor(0.) tensor(4.1345)
<class 'numpy.ndarray'> uint8 (720, 1280, 3) 0 255


In [20]:
imginfo(scale(diff))

<class 'numpy.ndarray'> float32 (400, 400) 0.0 255.0


In [35]:
canvas = np.copy(first_image)

def crop_to_color(cr):
    cr = resize(scale(cr), crop_shape[:2]).astype(np.uint8)
    cr = np.tile(cr[:, :, None], (1, 1, 3))  # 1 channel to 3 channels
    cr[:, :, 1:] = 0
    return cr

def add_crop(image, cr):
    canvas = np.copy(image)
    canvcrop = canvas[first_y-cropw:first_y+cropw, first_x-cropw:first_x+cropw]
    canvas[first_y-cropw:first_y+cropw, first_x-cropw:first_x+cropw] = canvcrop * 0.4 + cr * 0.6 # use broadcasting
    return canvas

canvas = add_crop(first_image, diff)
io.imsave("data/result_overlay.jpg", scale(canvas).astype(np.uint8))