In [None]:
import torch

# Select the compute backend: CUDA GPU if present, else Apple MPS (Metal), else CPU.
# The original used a non-existent `torch.cuda.cuda_is_available()` and assigned
# the misspelled name `deivce`, so `device` was never set on GPU machines.
if torch.cuda.is_available():
    device = 'cuda'
# getattr guard: older torch builds have no `torch.backends.mps` at all.
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

In [None]:
# Blur trajectory visualization

import torch.nn.functional as F
import matplotlib.pyplot as plt

from utils.synthesis import make_c2w
from utils.utils import visualize_blur_trajectories

def randn_vector_linspace(num_poses):
    """Linearly interpolate between two random 3-vectors.

    A start vector and then an end vector are drawn from a standard normal
    distribution (in that order, from the global torch RNG). Each of the three
    components is then interpolated independently with ``torch.linspace``.

    Returns a tensor of shape (num_poses, 3). Row 0 is the start vector, and
    for num_poses >= 2 the last row is the end vector.
    """
    start_vec = torch.randn(3)
    end_vec = torch.randn(3)
    per_axis = [
        torch.linspace(lo, hi, num_poses)
        for lo, hi in zip(start_vec, end_vec)
    ]
    # Each linspace is one column, so stack along dim 1 to get (num_poses, 3).
    return torch.stack(per_axis, dim=1)

# Sample a random, linearly interpolated sequence of poses, turn it into a
# sampling grid, and draw the resulting per-pixel blur trajectories.

SEED = None  # set to an int (e.g. 0) to make the sampled trajectory reproducible
if SEED is not None:
    torch.manual_seed(SEED)

num_poses = 8

# Per-pose rotation and translation parameters. The rotation samples are scaled
# down by 0.05; their exact parameterisation is defined by make_c2w
# (NOTE(review): presumably a rotation vector + translation — confirm).
r = randn_vector_linspace(num_poses) * 0.05
t = randn_vector_linspace(num_poses)

# F.affine_grid takes one 2x3 affine matrix per sample, so keep only the
# top-left 2x3 block of each pose matrix returned by make_c2w.
grid_2d_rigid = F.affine_grid(
    make_c2w(r, t)[:, :2, :3],
    [num_poses, 3, 256, 256],  # image must be square for circular motion pattern
    align_corners=True,
)
img = visualize_blur_trajectories(grid_2d_rigid, spacing=32, colormap=None)

fig, ax = plt.subplots()
ax.imshow(img, cmap='gray')
ax.set_title(f'Blur trajectories ({num_poses} poses)')
ax.axis('off')
plt.show()

In [None]:
# inference example

from config import TrainAugConfig as cfg
from models.nafnet import NAFNetGrid
from utils.synthesis import blur_synthesis
from utils.dataset import get_img
from utils.utils import load_ckpt

# Fill these in before running: a model checkpoint and a sharp/blurred image pair.
ckpt = ''
sharp_path = ''
blur_path = ''

if not (ckpt and sharp_path and blur_path):
    raise ValueError('Set ckpt, sharp_path and blur_path before running this cell.')

model = NAFNetGrid(**cfg().nafnet_grid_params)
load_ckpt(ckpt, model)
# Inference only: put layers such as dropout/normalisation into eval behaviour.
# (Assumes NAFNetGrid is a torch.nn.Module; it is loaded and called like one.)
model.eval()

sharp = get_img(sharp_path)
blur = get_img(blur_path)

# no_grad: avoid building an autograd graph during inference. This also keeps
# the outputs from requiring grad, so they can be converted to NumPy for plotting.
with torch.no_grad():
    results = blur_synthesis(model(blur), blur, sharp)
pred_blur = results['pred_blur']


def _to_displayable(image):
    """Return `image` in a form that plt.imshow accepts.

    Non-tensor inputs are returned unchanged. Tensors are detached and moved to
    the CPU. A leading batch dimension of size 1 is dropped, a channels-first
    (C, H, W) layout with C in {1, 3} is permuted to channels-last, and a
    trailing singleton channel is removed.
    """
    if not torch.is_tensor(image):
        return image
    image = image.detach().cpu()
    if image.dim() == 4 and image.shape[0] == 1:
        image = image[0]
    if image.dim() == 3 and image.shape[0] in (1, 3) and image.shape[-1] not in (1, 3, 4):
        image = image.permute(1, 2, 0)
    if image.dim() == 3 and image.shape[-1] == 1:
        image = image[..., 0]  # imshow wants (H, W), not (H, W, 1), for one channel
    return image.numpy()


fig, ax = plt.subplots()
ax.imshow(_to_displayable(pred_blur))
ax.set_title('Predicted blur')
ax.axis('off')
plt.show()