In [None]:
# Enable autoreload so edits to lvc.py are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Execute lvc.py as a script inside this notebook (IPython %run).
%run lvc.py

In [None]:
%load_ext autoreload
%autoreload 2

from lvc import allowed_image_sizes

allowed_image_sizes((44, 36), mode=420, cr=0.25, min_dim=200, max_dim=400)

In [None]:
%load_ext autoreload
%autoreload 2

import scipy.io as sio
import torch
from lvc import YuvImage, run_yuv420

import matplotlib.pyplot as plt

frame_mats = [
    sio.loadmat("reference/kodim23_cif_frame01.mat"),
    sio.loadmat("reference/husky_cif_frame01.mat"),
]

yuv_images_inp = [
    YuvImage(
        torch.Tensor(frame_mat["first_Y"]) / 255.0,
        torch.Tensor(frame_mat["first_U"]) / 255.0,
        torch.Tensor(frame_mat["first_V"]) / 255.0,
    )
    for frame_mat in frame_mats
]

crs = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
csnr_dbs = [-5, 0, 5, 10, 15, 20, 25, 30, 35, "inf"]
chunk_size = (36, 44)

res_zf = run_yuv420(crs, csnr_dbs, yuv_images_inp, chunk_size, "zf")
res_llse = run_yuv420(crs, csnr_dbs, yuv_images_inp, chunk_size, "llse")

report_n = 0

plot_style_zf = {
    "marker" : "+",
    "markersize" : 15,
    "label": "ZF",
}
plot_style_llse = {
    "marker" : "x",
    "markersize" : 15,
    "label": "LLSE",
}

psnrs_zf = [
    [res_zf[(cr, "inf")][report_n][0] for cr in crs],
    [res_zf[(cr, "inf")][report_n][1] for cr in crs],
    [res_zf[(cr, "inf")][report_n][2] for cr in crs],
]
psnrs_llse = [
    [res_llse[(cr, "inf")][report_n][0] for cr in crs],
    [res_llse[(cr, "inf")][report_n][1] for cr in crs],
    [res_llse[(cr, "inf")][report_n][2] for cr in crs],
]

fig, axs = plt.subplots(2, 2, figsize=(20, 20))
fig.suptitle("PSNR vs CR")
axs = axs.flatten()
for i, title in enumerate(["Y", "U", "V"]):
    axs[i].plot(csnr_dbs, psnrs_zf[i], **plot_style_zf),
    axs[i].plot(csnr_dbs, psnrs_llse[i], **plot_style_llse),
    axs[i].set_title(title)
    axs[i].legend()
    axs[i].set_ylim(ymax=70)
    axs[i].set_xlabel("CSNR (dB)")
    axs[i].set_ylabel("PSNR (dB)")
plt.tight_layout()
plt.show()

psnrs_zf = [
    [res_zf[(1.0, csnr_db)][report_n][0] for csnr_db in csnr_dbs],
    [res_zf[(1.0, csnr_db)][report_n][1] for csnr_db in csnr_dbs],
    [res_zf[(1.0, csnr_db)][report_n][2] for csnr_db in csnr_dbs],
]
psnrs_llse = [
    [res_llse[(1.0, csnr_db)][report_n][0] for csnr_db in csnr_dbs],
    [res_llse[(1.0, csnr_db)][report_n][1] for csnr_db in csnr_dbs],
    [res_llse[(1.0, csnr_db)][report_n][2] for csnr_db in csnr_dbs],
]

fig, axs = plt.subplots(2, 2, figsize=(20, 20))
fig.suptitle("PSNR vs CSNR")
axs = axs.flatten()
for i, title in enumerate(["Y", "U", "V"]):
    axs[i].plot(csnr_dbs, psnrs_zf[i], **plot_style_zf),
    axs[i].plot(csnr_dbs, psnrs_llse[i], **plot_style_llse),
    axs[i].set_title(title)
    axs[i].legend()
    axs[i].set_ylim(ymax=70)
    axs[i].set_xlabel("CSNR (dB)")
    axs[i].set_ylabel("PSNR (dB)")
plt.tight_layout()
plt.show()

In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path

import torch
from lvc import run_rgb

import matplotlib.pyplot as plt

img_dir = Path("C:/kubouch/data/kodim/raw")

rgb_images_inp = list(img_dir.glob("*.png"))

crs = [0.1, 0.25, 0.5, 0.75, 1.0]
csnr_dbs = [0, 10, 20, 30, "inf"]
chunk_size = (36, 44)

res = run_rgb(crs, csnr_dbs, rgb_images_inp, chunk_size)
res

In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from lvc import create_lvc
from utils import set_device

device = set_device()

img_dir = Path("C:/kubouch/data/kodim/raw")
image_names = list(img_dir.glob("*.png"))
images = torch.stack([
    # transforms.CenterCrop((288, 352))(
    transforms.Resize((512, 768))(
    # transforms.Resize((288, 672))(
        transforms.ToTensor()(
            Image.open(image_name).convert(mode="RGB")
        )
    )
    for image_name in image_names
])
print("Images: ", images.shape)

lvc_params_zf = {
    "csnr_db": 30,
    "cr": 1.0,
    "mode": 420,
    # "chunk_w": 32,
    # "chunk_h": 32,
    "nchunks": 256,
    "seed": 42,
    "estimator": "zf",  # llse
    "packet_loss": None, #0.0,
    "dct_w": 16,
    "dct_h": 16,
    "grouping": "vertical_uv",
}

lvc_params_llse = lvc_params_zf
lvc_params_llse["estimator"] = "llse"

lvc_zf = create_lvc(lvc_params_zf, device, half=False)
lvc_llse = create_lvc(lvc_params_llse, device, half=False)

res_zf = lvc_zf(images)
res_llse = lvc_llse(images)
print("Result: ", res_zf.shape)

psnr = 10 * torch.log10(1.0 / (res_zf - images).square().mean())
print("ZF PSNR:", psnr)

mse = (res_llse - res_zf).square().mean()
print("MSE LSE vs ZF:", mse)

fig, axs = plt.subplots(1, 1, figsize=(20, 20))
fig.suptitle("Output image")
axs.imshow(transforms.ToPILImage()(res_zf[22]))
axs.set_title("RGB ZF")
plt.tight_layout()
plt.show()

In [None]:
%load_ext autoreload
%autoreload 2

# Scratch cell exploring nn.init.orthogonal_: profiling one call, checking that
# re-seeding reproduces the matrix, and applying it to stacked vectors.
# NOTE: every printed value depends on the order of RNG calls below.

import torch

from torch import nn
from torch.profiler import profile, record_function, ProfilerActivity

torch.manual_seed(7)

N = 4
w1 = torch.empty(N, N)
w2 = torch.empty(N, N)

# Profile (CPU only) a single orthogonal initialization of w1.
with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with record_function("init_orthogonal"):
        nn.init.orthogonal_(w1)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

# Also:
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.ortho_group.html?highlight=ortho_group#scipy.stats.ortho_group

# For an orthogonal w1 this is expected to print (approximately) the identity.
print(torch.matmul(w1, w1.T))

# Re-seeding with the same seed should reproduce w1, so diff is expected to be 0.
torch.manual_seed(7)
nn.init.orthogonal_(w2)
print("diff: ", (w1 - w2).abs().mean())

# For a 1-D x: w1 @ x applies w1, x @ w1 applies w1.T. The last two prints
# undo each with the transpose and are expected to reproduce x.
x = torch.rand(4)
print(w1 @ x)
print(x @ w1)
print(x)
print(w1.T @ (w1 @ x))
print((x @ w1) @ w1.T)

# Apply w1 to every length-4 vector Y[:, i, j] at once: move the stacked axis
# last and add a trailing unit dim so matmul broadcasts w1 over the (2, 3) dims.
y = torch.rand(2, 3)
Y = torch.stack([y, y, y, y])
YY = Y.permute((1, 2, 0)).unsqueeze(3)
print(YY.shape)  # (2, 3, 4, 1)
YY = torch.matmul(w1, YY)
# Undo with w1.T and restore the (4, 2, 3) layout; YYY is expected to match Y.
YYY = torch.matmul(w1.T, YY).squeeze().permute(2, 0, 1)

# Integer variant of the same broadcasting experiment (disabled):
# z = torch.arange(6).reshape(2,3)
# z  = torch.stack([z, z, z, z]).permute((1,2,0)).unsqueeze(3)
# wz = torch.arange(16).reshape(4,4)
# Z = torch.matmul(wz, z)

In [None]:

%load_ext autoreload
%autoreload 2

import torch
import scipy.io as sio

from lvc import LlseEstimate, Metadata

noise_power = torch.tensor(sio.loadmat("reference/sigma_noise.mat")["sigma_bruit"]).square().squeeze()
variances = torch.tensor(sio.loadmat("reference/var.mat")["Lambdan"]).squeeze()
llse_in  = torch.tensor(sio.loadmat("reference/LLSE_in.mat")["Tn"]).reshape(-1, 36, 44)
llse_mat = torch.tensor(sio.loadmat("reference/LLSE_mat.mat")["Hn"])
llse_out = torch.tensor(sio.loadmat("reference/LLSE_out.mat")["X_hatNa"]).reshape(-1, 36, 44)

metadata = Metadata((288, 352), (36, 44))
metadata.set_variances(variances)
metadata.set_noise_power(noise_power)
power = 1.0

out_tensor, out_metadata = LlseEstimate(power)(llse_in, metadata)

print("LLSE estimation vs. Matlab MAE: ", (llse_out - out_tensor).abs().mean())

In [None]:
%load_ext autoreload
%autoreload 2

import torch
import scipy.io as sio

import torchvision

from torchvision import transforms

from PIL import Image

import matplotlib.pyplot as plt

from lvc import LvcEncode, RgbToYcbcr, Downsample, SubtractMean, DCT, DctBlock, ChunkSelect, LvcDecode, IDCT, IdctBlock, ChunkRestore, ChunkCombine, RestoreMean, ChunkSplit, Channel, Upsample, YcbcrToRgbMetadata, RgbToYcbcrMetadata, Pad, Crop, PowerAllocate, RandomOrthogonal, ZfEstimate, LlseEstimate

from utils import set_device

image_inp = transforms.ToTensor()(Image.open("C:/kubouch/data/mini/kodim23_48x48.png"))

references = None
results = None
norm = 'ortho'
do_print = True
device = set_device()

w_human = torch.Tensor([0.299, 0.587, 0.114])
yuv_mode = 420
dct_size = (8, 8)
chunk_size = 64
snr_db = 30
packet_loss = None #0.5
seed = 42
cr = 1.0 #0.25
# grouping = "vertical_uv"
grouping = "horizontal_uv"
power = 1.0

lvc_encoder = LvcEncode(
    [
        Pad(420, dct_size, do_print=do_print),
        RgbToYcbcrMetadata(w_human),
        Downsample(mode=yuv_mode),
        SubtractMean(device),
        # DCT(is_half=False, norm=norm),
        ChunkSplit(dct_size, do_print=do_print),
        DctBlock(dct_size, yuv_mode, grouping, is_half=False, norm=norm, do_print=do_print),
        ChunkSelect(cr, do_print=do_print),
        PowerAllocate(power, do_print=do_print),
        RandomOrthogonal(seed, device, invert=False),
    ],
    references,
    results,
    chunk_size,
    do_print=do_print,
)

lvc_decoder = LvcDecode(
    [
        RandomOrthogonal(seed, device, invert=True),
        ZfEstimate(power),
        # LlseEstimate(power),
        ChunkRestore(device=device, is_half=False, do_print=do_print),
        IdctBlock(dct_size, yuv_mode, grouping, is_half=False, norm=norm, do_print=do_print),
        ChunkCombine(
            yuv_mode,
            device=device,
            dct_size=dct_size,
            is_half=False,
            do_print=do_print,
        ),
        # IDCT(is_half=False, norm=norm),
        RestoreMean(),
        Upsample(mode=yuv_mode),
        YcbcrToRgbMetadata(w_human),
        Crop(),
    ]
)

image_enc, metadata = lvc_encoder([image_inp])[0]
print("Done encoding")
image_out = lvc_decoder([(image_enc, metadata)])[0]

mse = (image_out - image_inp).square().mean(dim=(1,2))
psnr = 10 * torch.log10(torch.div(1.0**2, mse))

print("MSE: ", mse)
print("PSNR: ", psnr)

fig, axs = plt.subplots(2, 2, figsize=(12, 12))
fig.suptitle("Input image")
axs = axs.flatten()
axs[0].imshow(image_inp[0], cmap="gray")
axs[0].set_title("Y")
axs[1].imshow(image_inp[1], cmap="gray")
axs[1].set_title("U")
axs[2].imshow(image_inp[2], cmap="gray")
axs[2].set_title("V")
axs[3].imshow(transforms.ToPILImage()(image_inp))
axs[3].set_title("RGB")
plt.tight_layout()

fig, axs = plt.subplots(2, 2, figsize=(12, 12))
fig.suptitle("Restored image")
axs = axs.flatten()
axs[0].imshow(image_out[0], cmap="gray")
axs[0].set_title("Y")
axs[1].imshow(image_out[1], cmap="gray")
axs[1].set_title("U")
axs[2].imshow(image_out[2], cmap="gray")
axs[2].set_title("V")
axs[3].imshow(transforms.ToPILImage()(image_out))
axs[3].set_title("RGB")
plt.tight_layout()

In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
from itertools import product
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from lvc import create_lvc
from utils import set_device

device = set_device()

img_dir = Path("C:/kubouch/data/kodim/raw")
image_names = list(img_dir.glob("*.png"))
images = torch.stack([
    # transforms.CenterCrop((288, 352))(
    transforms.Resize((288, 672))(
        transforms.ToTensor()(
            Image.open(image_name).convert(mode="RGB")
        )
    )
    for image_name in image_names
])

dct_sizes = [None, (8, "h"), (8, "v"), (16, "h"), (16, "v")]
packet_losses = [None, 0.0, 0.1, 0.25, 0.5, 0.75]
num_chunks = [ 64, 256 ]
estimators = ['zf', 'llse']
csnrs = [ 0, 5, 10, 20, 30, 'inf']
crs = [0.1, 0.25, 0.5, 0.75, 1.0]

for dct_size, packet_loss, nchunks, estimator, csnr, cr in product(dct_sizes, packet_losses, num_chunks, estimators, csnrs, crs):
    lvc_params = {
        "packet_loss": packet_loss,
        "seed": 42,
        "mode": 420,
        "cr": cr,
        "csnr_db": csnr,
        "estimator": estimator,
        "nchunks": nchunks
    }

    if dct_size is not None:
        lvc_params["dct_w"] = dct_size[0]
        lvc_params["dct_h"] = dct_size[0]
        lvc_params["grouping"] = dct_size[1]

    lvc = create_lvc(lvc_params, device, half=False)
    res = lvc(images)

    psnrs = []
    for inp, out in zip(images, res):
        mse = (out - inp).square().mean()
        psnrs.append(10 * torch.log10(1.0**2 / mse))

    psnr = float(torch.tensor(psnrs).mean())

    print("lvc_params: ", lvc_params, ", avg PSNR:", psnr)
