In [None]:
import cv2
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm
from pathlib import Path
from torchvision.io import read_image
from eops.deconv import fft_admm_tv
from elayers.admmdeconv import ADMMDeconv

In [None]:
def torch_abs2(x: torch.Tensor) -> torch.Tensor:
    """Elementwise squared magnitude |x|**2 (real-valued even for complex x)."""
    magnitude = x.abs()
    return magnitude.pow(2)


def hard_thresh(x: torch.Tensor, tau: float) -> torch.Tensor:
    """Hard thresholding: keep entries with |x| > tau and zero the rest.

    Entries with |x| exactly equal to tau are zeroed (strict inequality).
    """
    keep = x.abs().gt(tau)
    return x * keep


def soft_thresh(x: torch.Tensor, tau: float) -> torch.Tensor:
    """Soft thresholding (shrinkage): sign(x) * max(|x| - tau, 0), elementwise.

    This is the proximal operator of ``tau * ||x||_1``.

    Args:
        x: input tensor of any shape, dtype and device.
        tau: threshold; a float or a tensor broadcastable against ``x``.

    Returns:
        Tensor with the same shape as ``x`` (for a float ``tau``).
    """
    # clamp(min=0) keeps x's dtype, device and shape. The previous
    # torch.maximum(..., torch.tensor([0])) built an int64 CPU tensor of shape
    # (1,), which fails for CUDA inputs and turns a 0-dim input into shape (1,).
    return torch.sign(x) * torch.clamp(torch.abs(x) - tau, min=0)


def block_thresh(x: torch.Tensor, tau: torch.Tensor) -> torch.Tensor:
    """Group (block) soft thresholding.

    Each group, i.e. the entries of ``x`` reduced over dims 0 and 1 by
    ``pixelnorm``, is scaled by ``max(1 - tau / ||group||_2, 0)``: groups with
    norm <= tau are zeroed, the rest are shrunk towards zero.

    Args:
        x: input tensor with at least two dims.
        tau: threshold; a float or a tensor broadcastable against the norm.

    Returns:
        The shrunk tensor, broadcast against ``x``.
    """
    norm = pixelnorm(x)
    # An all-zero group has norm 0. With tau == 0 that gives 0 / 0 = NaN,
    # which would propagate into the output. Those entries of x are exactly
    # zero, so any finite scale is correct there; divide by 1 instead.
    safe_norm = torch.where(norm > 0, norm, torch.ones_like(norm))
    # clamp(min=0) keeps dtype/device. A literal torch.tensor([0]) is an int64
    # CPU tensor and breaks CUDA inputs.
    return torch.clamp(1 - tau / safe_norm, min=0) * x


def pixelnorm(x: torch.Tensor) -> torch.Tensor:
    """Euclidean (L2) norm of ``x`` taken jointly over dims 0 and 1.

    The result has the remaining trailing dims of ``x``, so it broadcasts back
    against ``x``.
    """
    return torch.sqrt(torch.sum(torch.pow(x, 2), (0, 1)))


def identity(x: torch.Tensor) -> torch.Tensor:
    """No-op operator: hand back the very same tensor object, unchanged."""
    result = x
    return result

In [None]:
def get_im_hash(img: np.ndarray) -> str:
    """Return the OpenCV perceptual hash (pHash) of ``img`` as a hex string.

    The 8-byte digest is read as a big-endian unsigned integer and formatted
    with ``hex()``, so leading zero bytes are not padded (the string length
    can vary between images).
    """
    digest = cv2.img_hash.pHash(img).tobytes()
    return hex(int.from_bytes(digest, 'big'))

In [None]:
def get_images(src_dir):
    """Load every ``*.png`` directly inside ``src_dir`` with OpenCV.

    Args:
        src_dir: directory to scan (``pathlib.Path`` or ``str``); not recursive.

    Returns:
        List of images as returned by ``cv2.imread(..., cv2.IMREAD_COLOR)``
        (3-channel BGR), in sorted filename order.

    Raises:
        IOError: if a matching file cannot be decoded by OpenCV.
    """
    # Path.glob() yields entries in filesystem-dependent order; sort so that
    # imgs[0], imgs[1], ... refer to the same files on every machine and run.
    paths = sorted(Path(src_dir).glob('*.png'))

    ims = []
    for path in paths:
        img = cv2.imread(str(path), cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising;
            # fail here rather than much later inside blur/noise code.
            raise IOError(f"cv2.imread could not read image: {path}")
        ims.append(img)

    return ims

In [None]:
def blur_gaussian(images, k_shape=(17, 17), std=2.4):
    """Gaussian-blur every image with kernel size ``k_shape`` and sigma ``std``.

    Returns a new list of blurred images in the same order as ``images``.
    """
    return [
        cv2.GaussianBlur(im, k_shape, std)
        for im in tqdm(images, total=len(images))
    ]

In [None]:
def add_gaussian_noise(images, mean=0, var=0.177):
    """Add i.i.d. Gaussian noise, drawn with ``np.random.normal``, to each image.

    NOTE: despite its name, ``var`` is passed to ``np.random.normal`` as
    ``scale``, i.e. it is the noise *standard deviation*, not the variance.
    It is kept that way for backward compatibility.

    Args:
        images: iterable of numpy image arrays.
        mean: noise mean, in the images' intensity units.
        var: noise standard deviation (see note above).

    Returns:
        List of noisy images clipped to [0, 255], with the same shape and
        dtype as the inputs. Integer images are rounded before casting back.
    """
    noisy = []

    for img in tqdm(images):
        noise = np.random.normal(loc=mean, scale=var, size=img.shape)
        # Noise overlaid over image (the sum is computed in float64)
        img_noisy = np.clip(img + noise, 0, 255)
        # Cast back to the input dtype. Returning float64 for uint8 input breaks
        # cv2.cvtColor (8U/16U/32F only) and differs from add_cv2randn, which
        # keeps the input dtype.
        if np.issubdtype(img.dtype, np.integer):
            img_noisy = np.rint(img_noisy)
        noisy.append(img_noisy.astype(img.dtype))

    return noisy

In [None]:
def add_cv2randn(images, mean=0, stdv=25):
    """Add Gaussian noise N(mean, stdv**2), drawn with ``cv2.randn``, to each image.

    Args:
        images: iterable of HxW or HxWxC numpy image arrays.
        mean: noise mean, in the images' intensity units (grey levels for uint8).
        stdv: noise standard deviation, same units as ``mean``.

    Returns:
        List of noisy images with the same shape and dtype as the inputs.
        Integer images are rounded and saturated to their dtype's range.
    """
    noisy = []

    for img in tqdm(images):
        n_ch = img.shape[2] if img.ndim == 3 else 1
        # Draw the noise into a signed float buffer. cv2.randn saturates to the
        # destination type, so drawing into a zeros_like(img) uint8 buffer
        # clipped every negative sample to 0. That gave one-sided,
        # brighten-only noise with a positive mean.
        noise = cv2.randn(np.empty(img.shape, dtype=np.float64),
                          (mean,) * n_ch, (stdv,) * n_ch)
        # Noise overlaid over image, then rounded/saturated back to the input dtype
        imgn = img.astype(np.float64) + noise
        if np.issubdtype(img.dtype, np.integer):
            info = np.iinfo(img.dtype)
            imgn = np.clip(np.rint(imgn), info.min, info.max)
        noisy.append(imgn.astype(img.dtype))

    return noisy

In [None]:
# Directory of clean test images, resolved relative to the kernel's working directory
ims_p = Path('test_imgs')
imgs = get_images(ims_p)

## Add blur and noise

In [None]:
blurs = blur_gaussian(imgs, k_shape=(7, 7), std=1.5)  # must match the 7x7, sigma=1.5 PSF used for deconvolution

In [None]:
noisy = add_cv2randn(blurs, mean=0, stdv=20)  # noise std of 20 grey levels

In [None]:
# Preview the first blurred + noisy image (OpenCV stores BGR; matplotlib expects RGB)
plt.imshow(cv2.cvtColor(noisy[0], cv2.COLOR_BGR2RGB))

In [None]:
mean = 0                 # Mean of the Gaussian noise, in [0, 1] intensity units
std_dev = 20 / 255       # Std-dev of the noise: 20 grey levels rescaled to [0, 1]

# Generate zero-mean Gaussian noise. Do NOT clamp the noise itself to [0, 1]:
# that discards every negative sample and turns it into one-sided
# (brighten-only) noise with a positive bias. Only the final image is clipped.
gaussian_noise = torch.randn(blurs[0].shape) * std_dev + mean

# Add noise to the original image (scaled to [0, 1])
noisy_image = (blurs[0] / 255) + gaussian_noise.numpy()

# Back to [0, 255], rounded and clipped so the result fits an 8-bit image
noisy_image = np.clip(np.rint(noisy_image * 255), 0, 255).astype(np.uint8)

In [None]:
plt.imshow(cv2.cvtColor(noisy_image, cv2.COLOR_BGR2RGB))  # BGR (OpenCV) -> RGB (matplotlib)

In [None]:
# First two noisy BGR images -> float32 tensors of shape (1, C, H, W), scaled by 1/255
xin1, xin2 = (
    torch.tensor(img / 255, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
    for img in (noisy[0], noisy[1])
)

# PSF for deconvolution: outer product of a 1-D Gaussian (size 7, sigma 1.5),
# the same size/sigma given to blur_gaussian. Shape (1, 1, 7, 7).
k_1d = cv2.getGaussianKernel(7, 1.5)
k = torch.tensor(k_1d @ k_1d.T, dtype=torch.float32)[None, None, ...]

# Presumably the TV regularization weight and the ADMM penalty parameter
# (passed positionally to fft_admm_tv) — TODO confirm
lmb = torch.tensor([0.02])
rho = torch.tensor([0.02])

# Stack both images into one batch along dim 0
xin = torch.cat((xin1, xin2), 0)

In [None]:
# Deconvolve the first image only. xin1 already has shape (1, C, H, W), so
# xin1[0][newaxis] is just xin1 again. Take a detached copy instead of wrapping
# an existing tensor in torch.tensor(), which PyTorch warns against.
# NOTE(review): the positional True / 300 are presumably the isotropic-TV flag
# and the iteration count (cf. ADMMDeconv(iso=..., max_iters=...)) — confirm
# against eops.deconv.fft_admm_tv.
r = fft_admm_tv(xin1.detach().clone(), lmb, rho, k, True, 300)

In [None]:
r.size()  # same torch.Size as r.shape; presumably (1, C, H, W) — TODO confirm

In [None]:
rrr = r[0]  # drop the batch dim; permute((0, 1, 2)) was an identity no-op. Layout presumably (C, H, W)

In [None]:
# Channels-last (H, W, C) view. Bound to a new name: rebinding `rrr` here
# (originally also missing a closing paren -> SyntaxError) made the cell
# non-idempotent and made the plotting cell permute a second time.
rrr_hwc = rrr.permute((1, 2, 0))

In [None]:
# Show the deconvolved image as RGB. The ADMM estimate can overshoot [0, 1],
# so clip before display (matplotlib would otherwise clip with a warning).
# .cpu() makes this work when r lives on a GPU.
rrr_bgr = np.clip(rrr.permute((1, 2, 0)).detach().cpu().numpy(), 0.0, 1.0)
plt.imshow(cv2.cvtColor(rrr_bgr, cv2.COLOR_BGR2RGB))

In [None]:
# Layer form of the ADMM deconvolution (elayers.admmdeconv.ADMMDeconv)
l = ADMMDeconv(
    (3, 3),        # NOTE(review): presumably a kernel size; the blur PSF above is 7x7 — confirm
    max_iters=150,
    lmbda=0.02,
    rho=0.04,      # NOTE(review): the fft_admm_tv run above used rho = 0.02
    iso=False,
)

In [None]:
a = l(xin)  # run the layer on the two-image batch `xin`