In [None]:
!pip install kornia
import kornia

In [None]:
import torch
import matplotlib.pyplot as plt
from torch import nn, Tensor
import torch.nn.functional as F
import torch.optim as optim
import tqdm
import numpy as np
from urllib.request import urlopen
import cv2

In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def imshow_tensor(img, title='', cmap=None):
  """Show channel 0 of the first item of an image batch with matplotlib.

  Args:
    img: tensor indexed as ``img[0, 0, :, :]``; it may live on any device
      and may still be attached to an autograd graph.
    title: text placed above the plot.
    cmap: optional matplotlib colormap. ``None`` (the default) keeps
      matplotlib's default colormap, i.e. the rendering this helper
      always produced.
  """
  # matplotlib converts its input to NumPy, which raises for tensors that
  # require grad, so detach and convert explicitly before plotting.
  plane = img[0, 0, :, :].detach().cpu().numpy()
  plt.imshow(plane, cmap=cmap, vmin=0, vmax=1)
  plt.title(title)
  plt.axis(False)
  plt.show()


In [None]:
def url_to_tensor(url, readFlag=cv2.IMREAD_COLOR):
    """Download an image and return it as a float tensor scaled by 1/255.

    Args:
        url: address of the encoded image file.
        readFlag: OpenCV decode flag forwarded to ``cv2.imdecode``.

    Returns:
        A ``1 x C x H x W`` float tensor on the notebook-level ``device``.
        Multi-channel decodes are converted from OpenCV's BGR order to RGB;
        single-channel decodes are returned as they are.

    Raises:
        ValueError: if OpenCV cannot decode the downloaded bytes.
    """
    # Close the HTTP response deterministically instead of leaking it.
    with urlopen(url) as resp:
        encoded = np.frombuffer(resp.read(), dtype=np.uint8)

    decoded = cv2.imdecode(encoded, readFlag)
    if decoded is None:
        # imdecode signals failure by returning None rather than raising.
        raise ValueError(f"could not decode an image from {url!r}")

    image: torch.Tensor = kornia.utils.image_to_tensor(decoded)  # CxHxW
    image = image[None, ...].float() / 255.
    # A single-channel (grayscale) decode has no BGR order to swap, and
    # bgr_to_rgb would reject it; every other layout is handled as before.
    if image.shape[1] != 1:
        image = kornia.color.bgr_to_rgb(image)
    return image.to(device)

In [None]:
def draw_cricle(size, wh):
  """Rasterise a centred, axis-aligned ellipse as a 0/1 mask.

  Args:
    size: indexable shape; only ``size[2]`` and ``size[3]`` are read and
      they give the two dimensions of the returned mask.
    wh: pair of axis lengths; ``wh[0]`` scales the ``size[2]`` direction
      and ``wh[1]`` the ``size[3]`` direction.

  Returns:
    A float32 tensor of shape ``(size[2], size[3])`` that is 1 where
    ``1 - 2 * (dx**2 / wh[0]**2 + dy**2 / wh[1]**2) > 0`` and 0 elsewhere.
  """
  # Follow the device of `wh` when it is a tensor so the arithmetic below
  # never mixes devices; otherwise use the notebook-level `device`.
  grid_device = wh.device if isinstance(wh, torch.Tensor) else device
  rows = torch.arange(0, size[2], dtype=torch.float32, device=grid_device)
  cols = torch.arange(0, size[3], dtype=torch.float32, device=grid_device)
  # Explicit 'ij' indexing keeps the historical layout (first output varies
  # along dim 0) and avoids the deprecation warning of the implicit default.
  x, y = torch.meshgrid(rows, cols, indexing='ij')
  circle = 1 - 2*(((x - size[2]/2)**2) / wh[0]**2 + ((y - size[3]/2)**2 / wh[1]**2))
  # Binarise: strictly positive -> 1, negative -> 0 (exact zeros stay 0).
  circle[circle>0] = 1
  circle[circle<0] = 0
  return circle

def translate_shape_img(img, dx, dy, rotation, scale=None):
  """Rotate and scale `img` about its centre, then shift it by (dx, dy).

  Args:
    img: 4-D image batch (its shape is unpacked as ``_, _, h, w``).
    dx, dy: tensors holding the shift along x (width) and y (height);
      they are stacked into the translation handed to
      ``kornia.geometry.translate``.
    rotation: angle forwarded to ``kornia.geometry.get_rotation_matrix2d``
      (degrees, per kornia's documentation); a number or a tensor.
    scale: optional ``1 x 2`` scale factors; defaults to no scaling.

  Returns:
    The transformed image, with the same spatial size as `img`.
  """
  _, _, h, w = img.shape
  # Create helper tensors next to the image instead of on the notebook-wide
  # `device`, so the function also works for images stored elsewhere.
  target = img.device
  angle: torch.Tensor = torch.ones(1, device=target) * rotation
  center: torch.Tensor = torch.tensor([[w / 2, h / 2]], dtype=torch.float32, device=target)  # (x, y)
  if scale is None:
    scale = torch.ones(1, 2, dtype=torch.float32, device=target)
  M: torch.Tensor = kornia.geometry.get_rotation_matrix2d(center, angle, scale)  # 1x2x3
  x_rotated: torch.Tensor = kornia.geometry.warp_affine(img, M, dsize=(h, w))
  # Stacking keeps dx and dy attached to the autograd graph, so both can be
  # optimised by the caller.
  translation = torch.stack((dx, dy)).unsqueeze(0)
  return kornia.geometry.translate(x_rotated, translation)

def add_circle_on_image(image, circle, alpha):
  """Return `image` plus `circle` weighted by `alpha` (no clamping)."""
  weighted_shape = alpha*circle
  blended = image + weighted_shape
  return blended


In [None]:
# Ground-truth target: download the image and keep only its first channel,
# giving a tensor with a single channel in dimension 1.
gt_img = url_to_tensor(
    'https://cms.uni-konstanz.de/fileadmin/archive/informatik-saupe/fileadmin/informatik/ag-saupe/Webpages/lehre/dip_w0910/pictures/cameraman.tif',
    readFlag=cv2.IMREAD_COLOR,
)[:, :1, ...]


In [None]:
# Seed PyTorch's random number generator so that the random candidate shapes
# (and every later torch.rand draw) repeat on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

n_circles = 2000     # number of candidate shapes tried by the outer loop
n_inner_opts = 200   # optimizer steps spent on each candidate
# One pair of axis lengths per candidate, drawn uniformly from [0, 1).
circles = torch.rand(n_circles, 2).to(device)

In [None]:
# Mean squared error between the target image and a candidate reconstruction.
criterion = nn.MSELoss(reduction="mean")

In [None]:
# TODO: use a learning rate specific to each parameter.
# TODO: find out whether the inner optimization loop actually does anything.

In [None]:
# Greedy reconstruction: start from a black canvas, optimise the placement of
# one candidate shape at a time, and keep the shape only if it lowers the MSE
# against the ground-truth image.
image = torch.zeros_like(gt_img).detach()
best_error = criterion(gt_img, image).detach()
size = image.shape

for cid, wh in tqdm.tqdm(enumerate(circles), total=len(circles)):
  # Scale the unit-range axis lengths to pixels. The mask is rasterised once
  # and detached below, so `wh` is a constant and needs no gradient.
  wh = wh*torch.tensor(size[2:]).to(device)/1.5
  dx_, dy_, alpha_ = torch.rand(3).to(device)

  # translate_shape_img treats dx as the x (width, size[3]) shift and dy as
  # the y (height, size[2]) shift, so each start value spans its own axis.
  dx = (0.5-dx_)*size[3]
  dy = (0.5-dy_)*size[2]
  # Own storage for alpha: the unpacked value is a view of the random buffer
  # shared with dx_ and dy_, and the optimizer updates it in place.
  alpha = alpha_.clone()
  rotation = 180*(0.5 - torch.rand(1).to(device))
  scale = 2*(torch.rand(2).unsqueeze(0).to(device))
  for member in [dx, dy, rotation, alpha, scale]:
    member.requires_grad = True

  new_circle = draw_cricle(size, wh)[None,None,...].detach()
  optimizer = optim.Adam([dx, dy, alpha, rotation, scale], lr=0.5)

  for _ in range(n_inner_opts):
    optimizer.zero_grad()
    new_circle_trans = translate_shape_img(new_circle.clone(), dx, dy, rotation, scale)
    optional_image = torch.clip(add_circle_on_image(image, new_circle_trans, alpha), 0, 1)
    tmp_error = criterion(gt_img, optional_image)
    tmp_error.backward()
    optimizer.step()

  # Score the candidate with the parameters left by the final optimizer step,
  # so the accepted image, its error and the printed parameters all agree.
  # Running under no_grad also means nothing stored below keeps an autograd
  # graph alive between outer iterations.
  with torch.no_grad():
    new_circle_trans = translate_shape_img(new_circle, dx, dy, rotation, scale)
    optional_image = torch.clip(add_circle_on_image(image, new_circle_trans, alpha), 0, 1)
    tmp_error = criterion(gt_img, optional_image)

  if tmp_error < best_error:
    print(f'found improvement: {best_error} to {tmp_error}')
    print(f'dx: {dx.detach().cpu().numpy().ravel()[0]:.1f}, dy: {dy.detach().cpu().numpy().ravel()[0]:.1f}, wh: {wh.detach().cpu().numpy()}, alpha: {alpha.detach().cpu().numpy().ravel()[0]:.3f}')
    image = optional_image.detach()
    # Only plot accepted candidates whose index is a multiple of 10, to keep
    # the notebook output manageable.
    if cid%10==0:
      imshow_tensor(optional_image.detach(), 'improved_img')

    best_error = tmp_error

In [None]:
# Final comparison: reconstruction on the left, ground truth on the right.
plt.figure(figsize=(8, 8))
plt.imshow(
    torch.cat((image[0, 0], gt_img[0, 0]), dim=1).cpu().numpy(),
    cmap='gray',
    vmin=0,
    vmax=1,
)
plt.axis(False)
plt.show()
