In [1]:
import torch
import torch.nn
from odak.learn.tools import convolve2d,zero_pad,crop_center,load_image,save_image
from tqdm import tqdm,trange

# Run on the GPU when one is present, otherwise fall back to the CPU so that
# `device` is always defined for the cells below (previously a CPU-only
# machine left it unset and later cells failed with a NameError).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
class optimizer:
    """
    Recovers a scene from a sensor image by gradient descent.

    The forward model convolves a scene estimate with a fixed point spread
    function (PSF); `solve` adjusts the scene estimate with Adam until the
    modelled sensor image matches the measured one in the mean-squared sense.
    """

    def __init__(self, psf):
        """
        Parameters
        ----------
        psf : torch.Tensor
              Point spread function. It is stored after being passed through
              `zero_pad` (presumably a 2D tensor that gets padded to a larger
              size -- confirm against odak's documentation).
        """
        self.psf  = zero_pad(psf)
        # Kept as a list so further loss terms can be appended later.
        self.loss_func = [
                          torch.nn.MSELoss(),
                         ]

    def evaluate(self,scene_estimate,image_sensor_estimate,image_sensor_ground_truth,w=[1.,0.01]):
        """
        Computes the weighted data-fidelity loss.

        Parameters
        ----------
        scene_estimate            : torch.Tensor
                                    Current scene estimate (accepted but not used by the loss).
        image_sensor_estimate     : torch.Tensor
                                    Sensor image produced by the forward model.
        image_sensor_ground_truth : torch.Tensor
                                    Measured sensor image.
        w                         : list
                                    Loss weights. Only `w[0]` is applied; `w[1]` is unused.
                                    The default list is never mutated.

        Returns
        -------
        loss                      : torch.Tensor
                                    `w[0]` times the mean squared error between the two sensor images.
        """
        loss        = w[0]*self.loss_func[0](image_sensor_estimate,image_sensor_ground_truth)
        return loss

    def forward(self,scene):
        """
        Forward model: convolves the scene with the stored PSF and passes the
        result through `crop_center`.

        Parameters
        ----------
        scene : torch.Tensor
                Scene estimate with the same shape as the stored (padded) PSF.

        Returns
        -------
        torch.Tensor
                Modelled sensor image.
        """
        image_sensor_padded = convolve2d(scene,self.psf)
        return crop_center(image_sensor_padded)

    def solve(self,image_sensor,n_iterations,device):
        """
        Estimates the scene that explains `image_sensor` under the forward model.

        Parameters
        ----------
        image_sensor : torch.Tensor
                       Measured sensor image to be explained.
        n_iterations : int
                       Number of Adam steps to take.
        device       : torch.device or str
                       Device on which the scene estimate is allocated.

        Returns
        -------
        torch.Tensor
                       Detached scene estimate after `crop_center`.
        """
        # Start from an all-zero scene created directly on the target device.
        scene    = torch.zeros(self.psf.shape, device=device, requires_grad=True)
        # Named `adam` so the local does not shadow this class's own name.
        adam     = torch.optim.Adam([scene,], lr=0.001)
        progress = tqdm(range(n_iterations),leave=True)
        for i in progress:
            adam.zero_grad()
            reconstruction = self.forward(scene)
            loss           = self.evaluate(scene,reconstruction,image_sensor)
            # The graph is rebuilt by `forward` on every iteration, so there is
            # nothing to retain across backward calls.
            loss.backward()
            adam.step()
            # Reports the loss measured before this iteration's update.
            progress.set_description("Iteration:{}, Loss:{:.4f}".format(i,loss.item()))
        # `detach()` already yields a tensor outside the autograd graph.
        return crop_center(scene.detach())

In [3]:
def resize(image,mul=0.5):
    """
    Rescales a channels-last image with bilinear interpolation.

    Parameters
    ----------
    image : torch.Tensor
            Floating point image laid out as (height, width, channels).
            Any number of channels is accepted.
    mul   : float
            Scale factor applied to both height and width.

    Returns
    -------
    torch.Tensor
            Rescaled image laid out as (height, width, channels), on the same
            device and with the same dtype as `image`.
    """
    scale   = torch.nn.Upsample(scale_factor=mul, mode='bilinear')
    # Upsample works on (batch, channels, height, width), so move the channel
    # axis to the front and add a singleton batch axis; every channel is then
    # interpolated independently in a single call.
    batched = image.permute(2,0,1).unsqueeze(0)
    resized = scale(batched).squeeze(0).permute(1,2,0)
    # Return a contiguous tensor in the original channels-last layout.
    return resized.contiguous()

In [4]:
# Reconstruction settings: number of Adam steps per channel and the factor by
# which both input images are rescaled before solving.
n_iterations = 300
mul = 0.5

# Read both inputs and move them to the compute device as float tensors.
psf = load_image('images/psf.jpg').to(device).float()
image_sensor = load_image('images/hnd.jpg').to(device).float()

# Point spread function: rescale, min-max normalise, exclude from autograd.
psf = resize(psf, mul=mul)
psf = (psf - psf.min()) / (psf.max() - psf.min())
psf.requires_grad_(False)

# Sensor image: same treatment as the point spread function.
image_sensor = resize(image_sensor, mul=mul)
image_sensor = (image_sensor - image_sensor.min()) / (image_sensor.max() - image_sensor.min())
image_sensor.requires_grad_(False)

# Solve each of the three colour channels independently and collect the
# per-channel reconstructions in `y`.
y = torch.zeros((image_sensor.shape[0], image_sensor.shape[1], 3))
for i in range(3):
    solver = optimizer(psf[:, :, i])
    y[:, :, i] = solver.solve(image_sensor[:, :, i], n_iterations, device)

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Iteration:299, Loss:0.0001: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 300/300 [00:06<00:00, 46.52it/s]
Iteration:299, Loss:0.0001: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 300/300 [00:06<00:00, 47.31it/s]
Iteration:299, Loss:0.0001: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 300/300 [00:06<00:00, 47.28it/s]


In [5]:
# Show the raw value range of the reconstruction before it is rescaled.
print(y.max(),y.min())

# Work on a detached copy so `y` itself is left untouched.
z = y.detach().clone()
# Min-max rescale to 0..255 and bring the result to the CPU for saving.
m = ((z-z.min())/(z.max()-z.min())*255.).detach().cpu()

# Kept as the last expression so the cell displays save_image's return value.
save_image('result.png',m)

tensor(5.1396e-05) tensor(-1.9580e-05)


True