In [9]:
import torch
from torchinfo import summary
import numpy as np
from PIL import Image

from src.forward_operator.operators import cfa_operator
from src.layers import U_PDGH

In [20]:
# --- Experiment configuration -------------------------------------------------
CFA = 'kodak'  # CFA identifier, passed to both cfa_operator and U_PDGH
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
NB_STAGES = 1  # first argument of U_PDGH
IMAGE_PATH = 'images/01690.png'
DOWNSCALE_FACTOR = 12  # integer factor by which both image dimensions are divided
# Presumably the R, G, B centre wavelengths in nm -- TODO confirm against cfa_operator.
RGB_SPECTRAL_STENCIL = np.array([650, 525, 480])

# --- Load and downscale the reference image -----------------------------------
# The context manager releases the file handle as soon as the pixels are copied.
# convert('RGB') guarantees exactly three 8-bit channels even when the file is
# stored as RGBA, palette or greyscale (a no-op copy for files already in RGB),
# so the array always has one channel per stencil entry.
with Image.open(IMAGE_PATH) as image_file:
    width, height = image_file.size
    image_small = image_file.convert('RGB').resize(
        (width // DOWNSCALE_FACTOR, height // DOWNSCALE_FACTOR)
    )

# Rescale the 8-bit channel values to floats in [0, 1].
x = np.array(image_small) / 255

# Forward operator built for this image shape.
OP = cfa_operator(CFA, x.shape, RGB_SPECTRAL_STENCIL, 'dirac')

# Observation OP.direct(x) with a leading batch axis, as float32 on DEVICE.
y = torch.tensor(OP.direct(x)[None, ...], dtype=torch.float).to(DEVICE)

In [21]:
# Build the network on the selected device.
model = U_PDGH(NB_STAGES, CFA, RGB_SPECTRAL_STENCIL, 3).to(DEVICE)

# Summarise using the shape of the observation that is actually fed to the model,
# rather than a hard-coded size that goes stale when the image or the downscale
# factor changes.
summary(model, input_size=tuple(y.shape))

Layer (type:depth-idx)                        Output Shape              Param #
U_PDGH                                        [1, 26, 40, 3]            --
├─Sequential: 1-1                             [1, 3120]                 --
│    └─primal_layer: 2-1                      [1, 3120]                 1
│    └─dual_layer: 2-2                        [1, 3120]                 1
│    │    └─Conv2d: 3-1                       [1, 16, 26, 40]           448
│    │    └─conv_block: 3-2                   [1, 16, 26, 40]           4,640
│    │    └─down_block: 3-3                   [1, 32, 26, 40]           4,640
│    │    └─conv_block: 3-4                   [1, 32, 26, 40]           18,496
│    │    └─down_block: 3-5                   [1, 64, 26, 40]           18,496
│    │    └─conv_block: 3-6                   [1, 64, 26, 40]           73,856
│    │    └─down_block: 3-7                   [1, 128, 26, 40]          73,856
│    │    └─conv_block: 3-8                   [1, 128, 26, 40]          29

In [22]:
# Put the model in evaluation mode and run a single forward pass on the
# observation with gradient tracking disabled.
model.eval()
with torch.no_grad():
    res = model(y)

# Sanity check, one value per line: the observation at index [0, 0, 0], the
# adjoint of the forward operator applied to the observation at index [0, 0],
# and the model output at index [0, 0, 0].
print(
    y[0, 0, 0],
    OP.adjoint(OP.direct(x))[0, 0],
    res[0, 0, 0],
    sep='\n',
)

tensor(0.8170, device='cuda:0')
[0.27233115 0.27233115 0.27233115]
tensor([0.2899, 0.2899, 0.2899], device='cuda:0')
