In [None]:
# Single run of LVC + NN model used for testing

%load_ext autoreload
%autoreload 2
%reset -f
%matplotlib inline

import os
from pathlib import Path

# Index of the physical GPU this notebook is pinned to.
gpu_i = 7
# Expose only that GPU to CUDA; it is then addressed as "cuda:0" further down.
os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_i}"
# Location of the LLVM shared library, passed on through the DRJIT_LIBLLVM_PATH variable.
# NOTE(review): presumably consumed by Dr.Jit (Sionna backend) — confirm.
os.environ["DRJIT_LIBLLVM_PATH"] = "/usr/lib/llvm-14/lib/libLLVM.so"

import pandas as pd
import torch
# import tensorflow as tf

from lvc.lvc import create_lvc, GradConfig, JPEGConfig, SionnaConfig
from models import get_model, get_model_id
from models.model import plot_channels
from utils import set_device
from utils.q_search import fetch_param

device = set_device("cuda:0")

# Pin this process to the CPU slice that belongs to the selected GPU.
# NOTE(review): the constants presumably mean 96 cores shared evenly by 8 GPUs — confirm.
_first_cpu = gpu_i * 96 // 8
cpus = set(range(_first_cpu, _first_cpu + 96 // 8))
os.sched_setaffinity(0, cpus)
affinity = os.sched_getaffinity(0)
print(f"Running on CPUs: {affinity}")
# Use as many torch threads as there are pinned CPUs.
torch.set_num_threads(len(cpus))

# Fixed seed for the torch RNG.
SEED = 42
torch.manual_seed(SEED)
# tf.random.set_seed(SEED)

# Saved quality-parameter search results, looked up by model id further down.
Q_SEARCH_FILES = {
    "fastseg_small": "experiments_tupu/runs/run21_keep/q_search_fastseg_small_Noneimgs.pt",
    "fastseg_large": "experiments_tupu/runs/run22_keep/q_search_fastseg_large_Noneimgs.pt",
    "yolov8_n": "experiments_hupu/runs/run49_keep/q_search_yolov8_n_Noneimgs.pt",
    "yolov8_s": "experiments_hupu/runs/run50_keep/q_search_yolov8_s_Noneimgs.pt",
    "yolov8_l": "experiments_hupu/runs/run51_keep/q_search_yolov8_l_Noneimgs.pt",
}

# YOLOv8 detection setup; the snapshot file name is built from name + variant.
yolov8_name = "yolov8"
yolov8_variant = "n"
config_yolov8 = {
    "name": yolov8_name,
    "variant": yolov8_variant,
    "task": "detect",
    "snapshot": f"{yolov8_name}{yolov8_variant}.pt",
    "unit": "mAP_50_95",
    "data_file": "coco.yaml",
}

# FastSeg setup; swap the commented lines (variant AND snapshot) to use the
# large model instead of the small one.
config_fastseg = {
    "name": "fastseg",
    "variant": "small",
    # "variant": "large",
    "snapshot": (
        Path.home() / "data/models/fastseg/raw/small/best_checkpoint_ep171.pth"
        # Path.home() / "data/models/fastseg/raw/large/best_checkpoint_ep172.pth"
    ),
    "unit": "mean_iu",
}

# Model configuration used by the rest of this cell.
config = config_fastseg

model_id = get_model_id(config)
print(f"Model: {model_id}")

# LVC / channel settings of this run; they are copied into `lvc_params` below.
color_space = "yuv"
yuv_mode = 444
nchunks = 64
cr = 0.25
csnr_dbs = [15]
estimator = "zf"
do_print = True

# Codec selection; the code below distinguishes "jpeg", "grace" and None.
codec = "jpeg"
# sionna_config: SionnaConfig = { "nbits_per_sym": 2, "coderate": 1.0 }
sionna_config: SionnaConfig = { "nbits_per_sym": 2, "coderate": 0.5 }

# Look up the codec parameter from the saved search results of this model.
# NOTE(review): torch.load unpickles the file, so only point this at trusted files.
if codec is not None:
    param = fetch_param(
        pd.DataFrame(torch.load(Q_SEARCH_FILES[model_id])),
        model_id,
        codec,
        sionna_config["nbits_per_sym"],
        cr,
    )
else:
    # No codec selected, hence no codec parameter.
    param = None

print(f"Codec: {codec}, param: {param}")

# Per-codec (parameter name, parameter print format); every other field of the
# codec configuration is identical for the supported codecs.
_codec_param_spec = {
    "grace": ("B", "{:12.5e}"),
    "jpeg": ("Q", "{:3}"),
}

if codec in _codec_param_spec:
    _param_name, _param_fmt = _codec_param_spec[codec]
    jpeg_config: JPEGConfig = {
        "codec": codec,
        "param": param,
        "param_name": _param_name,
        "param_fmt": _param_fmt,
        "turbojpeg_enc": False,
        "turbojpeg_dec": False,
    }
else:
    # Unknown codec or no codec at all: run without a codec configuration.
    jpeg_config = None

# Batch counts/sizes for the evaluation runs and for the gradient probe.
num_batches = 8
batch_size = 1 # 8
num_batches_probe = 32
batch_size_probe = 1

grad_type = "dist_sq"

# The probe and the result keys do not depend on the CSNR, so they are
# computed once here instead of being repeated for every CSNR value.
model_probe = get_model(config, device, num_batches_probe, batch_size_probe, color_space=color_space, do_print=do_print)
res_probe = model_probe.run_probe(grad_type, [nchunks])
# print("=== orig probe:", res_probe)

# The probe result uses separate keys for the 4:2:0 and the full-resolution data.
grad_yuv_key = "grads_yuv_420" if yuv_mode == 420 else "grads_yuv"
grad_norm_key = "grads_norm_420" if yuv_mode == 420 else "grads_norm"

# plot_channels(res_probe[grad_yuv_key].numpy(), "grads YUV", False, save="grads_3.png", log=True)

for csnr_db in csnr_dbs:
    print(f"--- CSNR: {csnr_db} dB")

    lvc_params = {
        "packet_loss": None,
        "seed": SEED,
        "mode": yuv_mode,
        "cr": cr,
        "csnr_db": csnr_db,
        "estimator": estimator,
        "nchunks": nchunks,
        "color_space": color_space,
        # block-based DCT settings
        # "dct_w": int(math.sqrt(nchunks)),
        # "dct_h": int(math.sqrt(nchunks)),
        # "grouping": "vertical_uv",
    }

    # Benchmark / no-LVC reference runs (disabled):
    # lvc_chain = create_lvc(
    #     lvc_params,
    #     device,
    #     False,
    #     unsqueeze=True,
    #     do_print=do_print,
    #     grad_config=None,
    #     sionna_config=sionna_config,
    # )
    # model = get_model(config, device, num_batches, batch_size, lvc_chain=lvc_chain, color_space=color_space, do_print=do_print)
    # model.bench()
    #
    # model = get_model(config, device, num_batches, batch_size,color_space=color_space, do_print=do_print)
    # res = model.eval()
    # print("=== orig:", res)

    # Plain LVC chain (no gradient configuration); skipped for the "grace" codec.
    if jpeg_config is None or jpeg_config["codec"] != "grace":
        lvc_chain = create_lvc(
            lvc_params,
            device,
            half=False,
            unsqueeze=True,
            do_print=do_print,
            grad_config=None,
            sionna_config=sionna_config,
            jpeg_config=jpeg_config,
        )

        model_lvc = get_model(
            config,
            device,
            num_batches,
            batch_size,
            num_workers=0,
            lvc_chain=lvc_chain,
            color_space=color_space,
            do_print=do_print
        )
        res_lvc = model_lvc.eval()
        print("=== orig LVC:", res_lvc)

    # LVC chain configured with the probe results; skipped for the "jpeg"
    # codec, in which case `res_g` is not assigned at all.
    if jpeg_config is None or jpeg_config["codec"] != "jpeg":
        gconfig: GradConfig = {
            "type": grad_type,
            "g_yuv": True,
            "g_select": True,
            "g_allocate": True,
            "w": res_probe["W"],
            "grad_mean": res_probe[grad_yuv_key],
            "grad_norm": res_probe[grad_norm_key][nchunks],
        }

        lvc_chain_g = create_lvc(
            lvc_params,
            device,
            half=False,
            unsqueeze=True,
            do_print=do_print,
            grad_config=gconfig,
            sionna_config=sionna_config,
            jpeg_config=jpeg_config,
        )

        model_g = get_model(config, device, num_batches, batch_size, color_space=color_space,
                            lvc_chain=lvc_chain_g, do_print=do_print)
        res_g = model_g.eval()
        print(f"=== grad LVC: {res_g}")

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

# # config:
# color_space = "yuv"
# yuv_mode = 444
# nchunks = 256
# cr = 1.0 #0.0625
# csnr_dbs = [10, 20, 30] #10
# estimator = "zf"

# x-axis values [dB] shared by both curves.
csnr_db = [10, 20, 30]

# Recorded mIoU values, one per entry of `csnr_db`.
# NOTE(review): presumably the plain LVC run — confirm.
mean_iu = [
    0.04296540606173279,
    0.20398644323634071,
    0.5021734055980513,
]

# NOTE(review): presumably the gradient-weighted LVC run — confirm.
mean_iu_g = [
    0.04832407239593724,
    0.27049837130461907,
    0.548599180026712,
]

plt.title("FastSeg (small) mIoU over Sionna channel")
plt.xlabel("Eb/N0 [dB]")
plt.ylabel("mIoU [-]")
# One line per result series: (values, line/marker format, legend label).
for _values, _fmt, _label in (
    (mean_iu, "-+", "miou"),
    (mean_iu_g, "-x", "miou_g"),
):
    plt.plot(csnr_db, _values, _fmt, label=_label)
plt.legend()
plt.grid()
plt.show()

In [None]:
# Results notes

# === orig: {'mean_iu': 0.6118768412396708, 'val_loss_avg': 0.22068387269973755}

# === orig LVC: {'mean_iu': 0.5076176815789921, 'val_loss_avg': 0.09857379645109177}
# === grad w grad_norm_select grad_norm_allocate: {'mean_iu': 0.5043232512120162, 'val_loss_avg': 0.1010698452591896}
# === orig LVC: {'mean_iu': 0.5076176815789921, 'val_loss_avg': 0.09857379645109177}
# === grad w grad_norm_select grad_norm_allocate: {'mean_iu': 0.5042013203125951, 'val_loss_avg': 0.09951619058847427}

# === orig LVC: {'mean_iu': 0.5052913375529151, 'val_loss_avg': 0.09825678914785385}
# === grad w grad_norm_select grad_norm_allocate: {'mean_iu': 0.5052742168501532, 'val_loss_avg': 0.09825391322374344}
# === orig LVC: {'mean_iu': 0.5052913375529151, 'val_loss_avg': 0.09825678914785385}
# === grad w grad_norm_select grad_norm_allocate: {'mean_iu': 0.5052803922630612, 'val_loss_avg': 0.09825777262449265}

# tensor([-5.1371e-10, -3.4887e-10, -2.2321e-09]  # fastseg
# tensor([ 2.9117e-08, -5.7752e-07,  6.6200e-08]  # yolov8

# fseg small, 64, 420, 0.5, 10, zf, (all batches):
# bilinear:
# === orig LVC: {'mean_iu': 0.420771445917225, 'val_loss_avg': 0.5670122504234314}
# === grad w grad_norm_select grad_norm_allocate: {'mean_iu': 0.35746933501871525, 'val_loss_avg': 0.7364088892936707}
# nearest:
# === orig LVC: {'mean_iu': 0.420771445917225, 'val_loss_avg': 0.5670122504234314}
# === grad w grad_norm_select grad_norm_allocate: {'mean_iu': 0.3952893112886412, 'val_loss_avg': 0.6174167394638062}
# nearest-exact:
# === orig LVC: {'mean_iu': 0.420771445917225, 'val_loss_avg': 0.5670122504234314}
# === grad w grad_norm_select grad_norm_allocate: {'mean_iu': 0.39431301752844455, 'val_loss_avg': 0.6204946041107178}

# cr = 0.0625
# orig LVC:
# === orig LVC: {'mean_iu': 0.60256914101466, 'val_loss_avg': 0.2260408103466034, 'collected': []}
# new:
# === grad w grad_norm_select: {'mean_iu': 0.5322282588638082, 'val_loss_avg': 0.30318042635917664, 'collected': []}
# new mean abs grad:
# === grad w grad_norm_select: {'mean_iu': 0.5023454525660811, 'val_loss_avg': 0.3682572841644287, 'collected': []}
# 0.5980870998181766
# old:
# === grad w grad_norm_select: {'mean_iu': 0.5986947963113816, 'val_loss_avg': 0.23073004186153412, 'collected': []}
#
# power alloc 10 dB
# === orig LVC: {'mean_iu': 0.34010766954194704, 'val_loss_avg': 0.9895051717758179, 'collected': []}
# === grad w grad_norm_select grad_norm_allocate: {'mean_iu': 0.3421851648080159, 'val_loss_avg': 1.2698609828948975, 'collected': []}
#
# num_batches = 8; fastseg small:
#         === orig LVC: {'mean_iu': 0.28979775041114986, 'val_loss_avg': 0.9113653302192688, 'collected': []}
#      sq === grad type g_yuv g_select g_allocate w grad_mean grad_norm: {'mean_iu': 0.30887477886310216, 'val_loss_avg': 0.9509761929512024, 'collected': []}
#     abs === grad type g_yuv g_select g_allocate w grad_mean grad_norm: {'mean_iu': 0.2757566183811295, 'val_loss_avg': 1.4529390335083008, 'collected': []}
# precise === grad type g_yuv g_select g_allocate w grad_mean grad_norm: {'mean_iu': 0.3093266004310626, 'val_loss_avg': 0.96988445520401, 'collected': []}

In [None]:
# plot_channels(res_probe["grads_norm"][nchunks].numpy(), "grads norm YUV", show=True, save=None, log=False)
# plot_channels(res_probe["grads_yuv"].numpy(), "grads abs YUV", show=False, save=None, log=False)
# plot_channels(res_probe["grads_rgb"].numpy(), "grads abs YUV", show=False, save=None, log=False)
# plot_channels(res_probe["dct_yuv"].abs().numpy(), "dct abs YUV", show=False, save=None, log=True)
# plot_channels(res_probe["dct_rgb"].abs().numpy(), "dct abs YUV", show=False, save=None, log=True)
print(res_probe.keys())
try:
    print(res_g)
except NameError:
    # `res_g` is only assigned by the gradient-weighted branch of the first
    # cell, which is skipped when the codec is "jpeg".
    print("res_g is not defined (the gradient-weighted LVC run was skipped)")

# Loss values noted down from earlier runs:
# LVC G loss: 0.5724298357963562
# no LVC loss: 0.4200332760810852
loss_lvc_g = 0.5724298357963562
loss_no_lvc = 0.4200332760810852

# Squared difference between the two recorded losses.
distortion = (loss_lvc_g - loss_no_lvc) ** 2
print(distortion)

In [None]:
# testing selecting k and cr

import pandas as pd
import plotly.express as px

# Denominator of each ratio column.
# NOTE(review): presumably the total chunk count per image (nchunks * 3 planes
# for 4:4:4, nchunks * 3 / 2 for 4:2:0) — confirm.
_totals = {
    "cr_420_64": 64 * 3 / 2,
    "cr_444_64": 64 * 3,
    "cr_420_256": 256 * 3 / 2,
    "cr_444_256": 256 * 3,
}

# One row per k in 1..64 with the ratio k / total for every column.
rows = []
for k in range(1, 65):
    row = {"k": k}
    row.update({column: k / total for column, total in _totals.items()})
    rows.append(row)

print(pd.DataFrame(rows).to_string())

def _cr_row(ratio):
    """Build one table row for the requested ratio ``ratio``.

    For every (mode, chunk count) column pair the row holds ``k``, the
    requested share of the total truncated to an integer, and the ratio that
    this integer ``k`` actually corresponds to.
    """
    row = {"cr": ratio}
    for suffix, total in (
        ("420_64", 64 * 3 / 2),
        ("444_64", 64 * 3),
        ("420_256", 256 * 3 / 2),
        ("444_256", 256 * 3),
    ):
        k = int(total * ratio)
        row[f"k_{suffix}"] = k
        row[f"cr_{suffix}"] = k / total
    return row


# The ratios are iterated inside a comprehension so that no module-level loop
# variable named `cr` is left behind: that name is the run setting of the
# first cell and must not be overwritten by this table.
rows = [_cr_row(ratio) for ratio in [0.025, 0.05, 0.1, 0.2, 0.4, 0.8, 1.0]]

print(pd.DataFrame(rows).to_string())

def _k_cr_row(k):
    """Build one table row from ``k`` of the 4:2:0 / 64 column.

    The other columns use k * 2, k * 4 and k * 8 and each gets the matching
    ratio against its own total.
    """
    return dict(
        k_420_64=k,
        cr_420_64=k / (64 * 3 / 2),
        k_444_64=k * 2,
        cr_444_64=k * 2 / (64 * 3),
        k_420_256=k * 4,
        cr_420_256=k * 4 / (256 * 3 / 2),
        k_444_256=k * 8,
        cr_444_256=k * 8 / (256 * 3),
    )


# The same table and line plot for two different sequences of k.
for _k_values in ([2, 4, 8, 16, 32, 64, 96], [3, 6, 12, 24, 48, 96]):
    rows = [_k_cr_row(k) for k in _k_values]

    df = pd.DataFrame(rows)
    print(df.to_string())

    fig = px.line(df, x="k_420_64", y="cr_420_64", markers=True)
    fig.show()

In [None]:
import os

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from torch import Tensor

from lvc.lvc import (
    YuvImage,
    Metadata,
    RgbToYcbcrMetadata,
    YcbcrToRgbMetadata,
    TensorToImage,
    ImageToTensor,
    DownsampleTensor,
    UpsampleTensor,
)
from models.model import plot_channels
from transforms.color_transforms import RgbToYcbcr, YcbcrToRgb
from utils import set_device


def yuv_image_from_tensor(inp: Tensor, mode: int = 420) -> YuvImage:
    """Wrap the three planes of ``inp`` in a ``YuvImage``.

    Args:
        inp: Tensor holding the Y, U and V planes. In mode 420 the planes are
            taken along dimension -3; in mode 444 they are ``inp[0]``,
            ``inp[1]`` and ``inp[2]``.
        mode: Subsampling mode, either 444 or 420.

    Returns:
        ``YuvImage(y, u, v)``. In mode 420 every plane is viewed as
        ``(1, 1, H, W)`` and U/V are cut down to ``(1, 1, H // 2, W // 2)``;
        in mode 444 the planes are passed on unchanged.

    Raises:
        ValueError: If ``mode`` is neither 444 nor 420.
    """
    if mode not in (444, 420):
        raise ValueError(f"Invalid subsampling mode {mode} (choose 444 or 420)")

    if mode == 444:
        return YuvImage(*(inp[i, :, :] for i in range(3)))

    height, width = inp.shape[-2], inp.shape[-1]
    y, u, v = (plane.view((1, 1, height, width)) for plane in inp.unbind(-3))

    # Only the top-left (H // 2) x (W // 2) corner of each chroma plane is kept.
    # NOTE(review): this crops instead of decimating — presumably the input
    # already stores the subsampled chroma in that corner; confirm.
    chroma_h, chroma_w = height // 2, width // 2
    u, v = (chroma[:, :, :chroma_h, :chroma_w] for chroma in (u, v))

    return YuvImage(y, u, v)


# Same GPU selection as in the first cell.
gpu_i = 7
os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_i}"
device = set_device(f"cuda:0")
# NOTE(review): `cpus` is not used anywhere else in this cell.
cpus = set(range(84, 96))

# Channel weights handed to the RGB <-> YCbCr transforms below
# (these values are the BT.601 luma coefficients).
w_human = torch.Tensor([0.299, 0.587, 0.114])

# Test image loaded as a float tensor.
img3 = T.ToTensor()(Image.open("/home/jakub/data/kodim/raw/kodim23.png"))

mode = 420

# Forward path: colour conversion followed by downsampling for `mode`.
fwd = T.Compose([RgbToYcbcr(w=w_human), DownsampleTensor(mode)])

# Backward path: the inverse steps in reverse order.
bck = T.Compose([UpsampleTensor(mode), YcbcrToRgb(w=w_human)])

# Metadata built from the image's trailing dimensions; only printed in the visible code.
meta = Metadata(img3.shape[1:], (0, 0), (0, 0))
print(meta)

# yuv_img = img
# for layer in fwd:
#     print("fwd", layer)
#     yuv_img, meta = layer(yuv_img, meta)
#     print("yuv_img", type(yuv_img))
# print(yuv_img.shape)

# print("yuv_img", type(yuv_img))
# res_img = yuv_img

# for layer in bck:
#     print("bck", layer)
#     res_img, meta = layer(res_img, meta)
# print(res_img.shape)

# print(yuv_img.y[0, 0])
# yuv_img = torch.add(yuv_img, yuv_img)
# print(yuv_img.y[0, 0])

# Add a leading batch dimension of 1 and track gradients on the input image.
img3 = img3.view((1, *img3.shape[-3:]))
img3.requires_grad_(True).retain_grad()
n, c, h, w = img3.shape
print(f"img: {img3.shape}")
# yuv_img = DownsampleTensor(420)(img)

# Path 1 runs the full forward/backward chains (with down-/upsampling),
# path 2 applies the colour conversion alone.
yuv_img = fwd(img3)
yuv_img2 = RgbToYcbcr(w=w_human)(img3)
res_img = bck(yuv_img)
res_img2 = YcbcrToRgb(w=w_human)(yuv_img2)

# Keep the gradients of the intermediate tensors so they can be inspected
# after backward().
for _tensor in (yuv_img, yuv_img2, res_img, res_img2):
    _tensor.requires_grad_(True)
    _tensor.retain_grad()
print(f"res_img: {res_img.shape}")


# yuv_img.requires_grad_(True)
# yuv_img.retain_grad()
# yuv_img2 = yuv_img * 2.0
# yuv_img, meta = TensorToImage()(img, meta)
# yuv_img.retain_grad()

# res = yuv_img.mean().sum()
# res.requires_grad_(True)
# res.retain_grad()

# Integer class target with one label per pixel.
# NOTE(review): random_(0, 1) samples from [0, 1), i.e. every label is 0 —
# confirm whether random labels in [0, c) were intended.
tgt = torch.empty((n, h, w), dtype=torch.long).random_(0, 1)
# Log-softmax over the channel dimension followed by NLL loss; the
# reconstructed channels act as class scores here.
crit = nn.LogSoftmax(dim=1)
loss_fn = nn.NLLLoss()

sm = crit(res_img)
sm2 = crit(res_img2)
print(f"sm: {sm.shape}")
print(f"tgt: {tgt.shape}")
loss = loss_fn(sm, tgt)
loss2 = loss_fn(sm2, tgt)

# Back-propagate both paths so that the retained .grad fields are filled in.
loss.backward()
loss2.backward()

print(f"yuv_img: {yuv_img.shape}")
print(f"yuv_img grad: {yuv_img.grad.shape}")
print(f"yuv_img min/max: {yuv_img.grad.min()}/{yuv_img.grad.max()}")
# Same range, restricted to channel 0 (Y).
print(
    f"yuv_img Y min/max: {yuv_img.grad[:, 0, :, :].min()}/{yuv_img.grad[:, 0, :, :].max()}"
)


# print(loss.grad)
# print(res_img.grad)
# print(yuv_img.grad)

# yuv_img = yuv_image_from_tensor(yuv_img, mode=420)
# plot_channels([ch.numpy() for ch in yuv_img.grad.channels()], "YUV grad")
# plot_channels([ch.detach().numpy() for ch in yuv_img.channels()], "YUV")
# plot_channels(img.grad.squeeze().numpy(), "IMG")
# Gradient w.r.t. the output of the full forward chain.
plot_channels(yuv_img.grad.squeeze().numpy(), "YUV")
# Gradient of the conversion-only path, downsampled afterwards and scaled by 4.
# NOTE(review): the factor 4 presumably compensates for the 2x2 subsampling — confirm.
grad2 = DownsampleTensor(mode)(yuv_img2.grad) * 4
plot_channels(grad2.squeeze().numpy(), "YUV2")

# Element-wise exact comparison, then the mean absolute difference.
print(yuv_img.grad == grad2)
print((yuv_img.grad - grad2).abs().mean())

# plot_channels(torch.cat([
#     yuv_img.grad.squeeze()[0,:,:],
#     yuv_img.grad.squeeze()[1,:,:],
#     yuv_img.grad.squeeze()[2,:,:],

# ]).numpy(), "YUV")
# plot_channels(res_img.grad.squeeze().numpy(), "RES")

In [None]:
# Get number of parameters

from pathlib import Path

from models import get_model
from utils import set_device

# Device on which the models below are instantiated.
device = set_device(f"cuda:0")

# YOLOv8 detection configs; the variants differ only in the variant letter
# and the snapshot file derived from it.
_yolov8_configs = [
    {
        "name": "yolov8",
        "variant": variant,
        "task": "detect",
        "snapshot": f"yolov8{variant}.pt",
        "unit": "mAP_50_95",
        "data_file": "coco.yaml",
    }
    for variant in ("n", "s", "l")
]

# FastSeg configs; each variant has its own checkpoint file under the
# user's home directory.
_fastseg_checkpoints = {
    "small": "best_checkpoint_ep171.pth",
    "large": "best_checkpoint_ep172.pth",
}
_fastseg_configs = [
    {
        "name": "fastseg",
        "variant": variant,
        "snapshot": Path.home() / f"data/models/fastseg/raw/{variant}/{checkpoint}",
        "unit": "mean_iu",
    }
    for variant, checkpoint in _fastseg_checkpoints.items()
]

# All model configurations, in the order in which they are reported.
configs = _yolov8_configs + _fastseg_configs

print("Number of parameters:")
for config in configs:
    # Instantiate the model only to query its parameter count.
    model = get_model(config, device, num_batches=None, batch_size=8, color_space="yuv", do_print=False)
    params_in_millions = model.get_num_params() / 1e6

    print(config["name"], config["variant"], f"{params_in_millions:.1f}")

In [None]:
# Test JPEG encoding

import os
from pathlib import Path

import cv2 as cv
import numpy as np
import torch
from matplotlib import pyplot as plt
from PIL import Image
from turbojpeg import TurboJPEG, TJSAMP_420, TJPF_RGB, TJSAMP_444

from lvc.lvc import ChunkSplit, DctBlock, YuvImage, Metadata
from models import get_model, probe_models
from models.model import plot_channels
from transforms.dct import dct_2d, dct_2d_block
from transforms.metrics import mse_psnr
from utils import set_device

# Load the test image as an array; its first two dimensions are unpacked as height and width.
print("image:")
img = np.array(Image.open("digcom/kodim23.png"))
(img_h, img_w, _) = img.shape
print(img.shape)

# plt.figure(0)
# plt.imshow(img)
# plt.show()

# OpenCV-only round trip RGB -> YUV 4:2:0 -> RGB, used as a reference PSNR.
# NOTE(review): per the OpenCV docs YV12 stores the planes in Y, V, U order — confirm.
print("")
print("image yuv 420:")
img2_yuv420 = cv.cvtColor(img, cv.COLOR_RGB2YUV_YV12)
print(type(img2_yuv420))
print(img2_yuv420.shape)
img2 = cv.cvtColor(img2_yuv420, cv.COLOR_YUV2RGB_YV12)
print(img2.shape)

_, psnr = mse_psnr(torch.tensor(img) / 255.0, torch.tensor(img2) / 255.0)
print(psnr)

# The same round trip without chroma subsampling.
print("")
print("image yuv 444:")
img2_yuv444 = cv.cvtColor(img, cv.COLOR_RGB2YUV)
print(type(img2_yuv444))
print(img2_yuv444.shape)
img22 = cv.cvtColor(img2_yuv444, cv.COLOR_YUV2RGB)
print(img22.shape)

_, psnr = mse_psnr(torch.tensor(img) / 255.0, torch.tensor(img22) / 255.0)
print(psnr)

# plt.figure(1)
# plt.imshow(img2)
# plt.show()

# TurboJPEG wrapper bound to a locally built libturbojpeg.
jpeg = TurboJPEG(
    "/media/data1/jakub/git/cpc/aisa-demo/external/libjpeg-turbo/libturbojpeg/x86-64/lib/libturbojpeg.so"
)

#### This one:

q = 80

print("")
print("jpeg yuv:")
jpeg_bytes = jpeg.encode_from_yuv(img2_yuv420, img_h, img_w, q, TJSAMP_420)
img3_yuv420, img3_sizes = jpeg.decode_to_yuv(jpeg_bytes)
print(type(img3_yuv420))
print("  decoded:", img3_yuv420.shape)
print("  plane sizes:", img3_sizes)
# A planar 4:2:0 buffer holds img_h rows of luma followed by img_h / 2 rows
# worth of chroma, i.e. img_h * 3 // 2 rows of img_w samples. The previous
# (img_w, img_w) shape only worked for images with img_h * 3 / 2 == img_w.
# NOTE(review): assumes the decoded planes carry no row padding — confirm
# for widths that are not a multiple of the decoder's padding.
yuv420_shape = (img_h * 3 // 2, img_w)
img3_rgb = cv.cvtColor(img3_yuv420.reshape(yuv420_shape), cv.COLOR_YUV2RGB_YV12)
mse, psnr = mse_psnr(torch.tensor(img) / 255.0, torch.tensor(img3_rgb) / 255.0)
print(f"RGB          : MSE {mse:.3e}, PSNR {psnr:.3f} dB")
# TODO: These are not correct because opencv and turbojpeg YUV transforms are not
# compatible.
mse, psnr = mse_psnr(
    torch.tensor(img2_yuv420) / 255.0,
    torch.tensor(img3_yuv420.reshape(yuv420_shape)) / 255.0,
)
print(f"YUV          : MSE {mse:.3e}, PSNR {psnr:.3f} dB")
plt.imsave("img3_rgb.png", img3_rgb)

#### With no subsampling:
# The earlier distortion came from the memory layout: OpenCV returns an
# interleaved (H, W, 3) array, while the TurboJPEG YUV functions work on a
# planar buffer (all of Y, then U, then V). The data is therefore converted
# to planar before encoding and back to interleaved after decoding.
# NOTE(review): assumes the decoded planes carry no row padding — confirm.

q = 80

print("")
print("jpeg yuv 444:")
img2_yuv444_planar = np.ascontiguousarray(img2_yuv444.transpose(2, 0, 1))
# Kept in its own variable so that `jpeg_bytes` keeps referring to the 4:2:0
# stream, which the following sections decode as such.
jpeg_bytes_444 = jpeg.encode_from_yuv(img2_yuv444_planar, img_h, img_w, q, TJSAMP_444)
img3_yuv444, img33_sizes = jpeg.decode_to_yuv(jpeg_bytes_444)
print(type(img3_yuv444))
print("  decoded:", img3_yuv444.shape)
print("  plane sizes:", img33_sizes)
img3_yuv444_hwc = np.ascontiguousarray(
    img3_yuv444.reshape(3, img_h, img_w).transpose(1, 2, 0)
)
img33_rgb = cv.cvtColor(img3_yuv444_hwc, cv.COLOR_YUV2RGB)
print("  rgb size", img33_rgb.shape)
mse, psnr = mse_psnr(torch.tensor(img) / 255.0, torch.tensor(img33_rgb) / 255.0)
print(f"RGB          : MSE {mse:.3e}, PSNR {psnr:.3f} dB")
# TODO: These are not correct because opencv and turbojpeg YUV transforms are not
# compatible.
mse, psnr = mse_psnr(
    torch.tensor(img2_yuv444) / 255.0,
    torch.tensor(img3_yuv444_hwc) / 255.0,
)
print(f"YUV444       : MSE {mse:.3e}, PSNR {psnr:.3f} dB")
plt.imsave("img3_rgb_444.png", img33_rgb)

### This one somehow loses color?
# TurboJPEG.decode() returns BGR unless told otherwise, while everything it
# is compared with (and plt.imsave) uses RGB, so request RGB explicitly.
# NOTE(review): the stream was encoded from OpenCV YUV planes; any remaining
# colour error presumably comes from the plane order / colour matrix
# mismatch mentioned in the TODOs above — confirm.

print("")
print("jpeg rgb:")
img4_rgb = jpeg.decode(jpeg_bytes, pixel_format=TJPF_RGB)
print(type(img4_rgb))
print(img4_rgb.shape)
mse, psnr = mse_psnr(torch.tensor(img) / 255.0, torch.tensor(img4_rgb) / 255.0)
print(f"RGB          : MSE {mse:.3e}, PSNR {psnr:.3f} dB")
mse, psnr = mse_psnr(torch.tensor(img3_rgb) / 255.0, torch.tensor(img4_rgb) / 255.0)
print(f"RGB (vs YUV) : MSE {mse:.3e}, PSNR {psnr:.3f} dB")
plt.imsave("img4_rgb.png", img4_rgb)

# w_human_rgb = torch.Tensor((0.299, 0.587, 0.114))
# w_img3_yuv420 = torch.tensor(img3_yuv420) / 255.0
# w_img3_yuv420[:img_w * img_h] *= w_human[0]
# w_img3_yuv420[img_w * img_h : img_w * img_h + (img_w * img_h // 4)] *= w_human[1]
# w_img3_yuv420[img_w * img_h + (img_w * img_h // 4):] *= w_human[2]
# w_img3_yuv420 = w_img3_yuv420.reshape(img_w, img_w)
# mse, psnr = mse_psnr(
#     torch.tensor(img2_yuv420) / 255.0,
#     w_img3_yuv420,
# )
# print(f"YUV weighted : MSE {mse:.3e}, PSNR {psnr:.3f} dB")

# plt.figure(2)
# plt.imshow(img3_rgb)
# plt.show()


#####################

print("")
dct_size = (32, 32)


def _unpack_probe(probe):
    """Print the keys of a probe result and return the entries used below.

    Returns:
        ``(grad_yuv420_norm, grad_yuv420, grad_rgb, W)``.
    """
    print(list(probe.keys()))
    return (
        probe["grads_norm_420"][dct_size[0] * dct_size[1]].numpy(),
        probe["grads_yuv_420"].numpy(),
        probe["grads_rgb"],
        probe["W"],
    )


try:
    # Reuse a probe result that is still around from an earlier run of this cell.
    grad_yuv420_norm, grad_yuv420, grad_rgb, W = _unpack_probe(result_probe)
except (NameError, KeyError):
    # No probe result yet (NameError) or one without the required entries
    # (KeyError): run the probe now. Other errors are no longer swallowed.
    torch.manual_seed(42)
    gpu_i = 0
    os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_i}"
    # device = set_device(f"cuda:0")
    device = set_device(f"cpu")
    # cpus = set(range(84, 96))
    # os.sched_setaffinity(0, cpus)
    # affinity = os.sched_getaffinity(0)
    # print(f"Running on CPUs: {affinity}")
    # torch.set_num_threads(len(cpus))

    color_space = "yuv"
    modes = [444]  # , 420]
    num_batches_probe = 32
    batch_size_probe = 1

    config = {
        "name": "yolov8",
        "variant": "n",
        "snapshot": "yolov8n.pt",
        "unit": "mAP_50_95",
        "task": "detect",
        "data_file": "coco.yaml",
    }

    model_probe = get_model(
        config, device, num_batches_probe, batch_size_probe, color_space="yuv"
    )

    result_probe = model_probe.run_probe("dist_sq", [64, 256, 1024])
    grad_yuv420_norm, grad_yuv420, grad_rgb, W = _unpack_probe(result_probe)


# plot_channels(grad_yuv420_norm, "Norm sq grad YUV 420 block", show=True)

### YUV

print("")
# Decode the JPEG stream into its three planes and scale the samples by 1 / 255.
y3, u3, v3 = (
    torch.tensor(plane) / 255.0 for plane in jpeg.decode_to_yuv_planes(jpeg_bytes)
)

norm = "ortho"
# Block DCT of every plane; two leading singleton dimensions are added for
# the call and squeezed away again afterwards.
dct_y, dct_u, dct_v = (
    dct_2d_block(
        plane.unsqueeze(0).unsqueeze(0), norm, dct_size[0], dct_size[1]
    ).squeeze()
    for plane in (y3, u3, v3)
)

print("grads norm yuv420: ", grad_yuv420_norm.shape)
print("grads yuv420: ", grad_yuv420.shape)
print("grads rgb: ", grad_rgb.shape)
print("W:", W)
print("DCT Y: ", dct_y.shape)
print("DCT U: ", dct_u.shape)
print("DCT V: ", dct_v.shape)

# Save log10 of the squared DCT coefficients averaged over dimension 0 ...
for _name, _dct in (("y", dct_y), ("u", dct_u), ("v", dct_v)):
    plt.imsave(f"dct_{_name}.png", _dct.square().mean(dim=0).log10().numpy())
# ... and the three entries of the probe's gradient norm.
for _index, _name in enumerate("yuv"):
    plt.imsave(f"grad_norm_{_name}.png", grad_yuv420_norm[_index])

### RGB

def _block_dct(plane):
    """Return the block DCT of a single 2-D plane (uses `norm` and `dct_size`)."""
    return dct_2d_block(
        plane.unsqueeze(0).unsqueeze(0), norm, dct_size[0], dct_size[1]
    ).squeeze()


print("")
metadata = Metadata(
    (grad_rgb.shape[1], grad_rgb.shape[2]),
    dct_size,
    (grad_rgb.shape[1] // dct_size[0], grad_rgb.shape[2] // dct_size[1]),
)
# The three RGB gradient planes are passed through the YuvImage container so
# that ChunkSplit can be applied to them.
data = YuvImage(grad_rgb[0, :, :], grad_rgb[1, :, :], grad_rgb[2, :, :])
grad_rgb_chunks, _ = ChunkSplit(dct_size)(data, metadata)
print("grad_rgb_chunks: ", grad_rgb_chunks.shape)
# NOTE: this reuses (and overwrites) the notebook-level name `nchunks`.
nchunks = grad_rgb_chunks.shape[0]
# Split the chunks into three equal thirds, one per channel.
# NOTE(review): assumes ChunkSplit stacks the chunks channel after channel — confirm.
grad_rgb_chunks_r = grad_rgb_chunks[: nchunks // 3]
grad_rgb_chunks_g = grad_rgb_chunks[nchunks // 3 : 2 * nchunks // 3]
# The last third starts at 2/3 of the chunks; the previous start index
# 3 * nchunks // 3 == nchunks always produced an empty slice.
grad_rgb_chunks_b = grad_rgb_chunks[2 * nchunks // 3 :]

# Per-channel mean over the chunks.
grad_rgb_norm = torch.stack(
    [
        grad_rgb_chunks_r.mean(dim=0),
        grad_rgb_chunks_g.mean(dim=0),
        grad_rgb_chunks_b.mean(dim=0),
    ]
)
print("grad_rgb_norm: ", grad_rgb_norm.shape)

grad_norm_r = grad_rgb_norm[0, :, :]
grad_norm_g = grad_rgb_norm[1, :, :]
grad_norm_b = grad_rgb_norm[2, :, :]
print("grad norm 1ch: ", grad_norm_r.shape, grad_norm_g.shape, grad_norm_b.shape)

# Block DCT of the JPEG round-trip image ...
rgb = (torch.tensor(img3_rgb) / 255.0).permute((2, 0, 1))

r = rgb[0, :, :]
g = rgb[1, :, :]
b = rgb[2, :, :]

dct_r = _block_dct(r)
dct_g = _block_dct(g)
dct_b = _block_dct(b)
print("dct_rgb: ", dct_r.shape, dct_g.shape, dct_b.shape)

# ... and of the original image as reference.
ref_rgb = (torch.tensor(img) / 255.0).permute((2, 0, 1))

ref_r = ref_rgb[0, :, :]
ref_g = ref_rgb[1, :, :]
ref_b = ref_rgb[2, :, :]

ref_dct_r = _block_dct(ref_r)
ref_dct_g = _block_dct(ref_g)
ref_dct_b = _block_dct(ref_b)
print("ref_dct_rgb: ", ref_dct_r.shape, ref_dct_g.shape, ref_dct_b.shape)

mse, psnr = mse_psnr(ref_rgb, rgb)
print(f"DCT RGB      : MSE {mse:.3e}, PSNR {psnr:.3f} dB")

# Squared DCT error per channel, weighted by that channel's gradient norm.
er = ((dct_r - ref_dct_r).square() * grad_norm_r).mean()
eg = ((dct_g - ref_dct_g).square() * grad_norm_g).mean()
eb = ((dct_b - ref_dct_b).square() * grad_norm_b).mean()

print("err grad dct rgb: ", er, eg, eb)


# TODO: compare each plane separately and weigh them according to W
# mse, psnr = mse_psnr( ...

# yuv3_420 = YuvImage(y3, u3, v3)

# split = ChunkSplit(dct_size)
# metadata = Metadata((img_w, img_h), dct_size, 0)

# blocks, _ = split(yuv3_420, metadata)
# print("blocks: ", blocks.shape)

# dct_blocks = dct_2d(blocks)
# print(dct_blocks.shape)

# mode = 420
# dct_block = DctBlock(
#     dct_size,
#     mode,
#     ???
# )

In [None]:
import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from scipy.stats import ortho_group, special_ortho_group
from scipy.io import savemat
from PIL import Image

from transforms.dct import dct_2d, idct_2d

seed = 42
np.random.seed(seed)
torch.manual_seed(seed)


def _show_image(title, data, **imshow_kwargs):
    """Show ``data`` as an image in a new, titled figure."""
    plt.figure()
    plt.title(title)
    plt.imshow(data, **imshow_kwargs)
    plt.show()


def _show_hist(title, data):
    """Show a histogram of all values of ``data`` using ``nbins`` bins."""
    plt.figure()
    plt.title(title)
    plt.hist(data.reshape(-1), bins=nbins)
    plt.show()


def _show_energy(title, chunks):
    """Show the mean squared value of each of the ``N`` chunks (log scale)."""
    plt.figure()
    plt.title(title)
    plt.bar(np.arange(N), chunks.square().mean((1, 2)))
    plt.yscale("log")
    plt.show()


def _mix_chunks(matrix):
    """Multiply the flattened chunks by ``matrix`` and restore the chunk shape."""
    return torch.matmul(matrix, Xch.reshape(N, -1)).reshape(Xch.shape)


# Grayscale test image scaled by 1 / 255, with its mean removed.
_gray = np.asarray(Image.open("backup/kodim23.png").convert("L"), dtype=float)
x = torch.tensor(_gray) / 255.0
x = x - x.mean()

_show_image("x", x, cmap="gray")

# The DCT of the image is cut into an n x n grid of chunks.
n = 8
(h, w) = x.shape

if w % n or h % n:
    raise ValueError("Invalid n")

wch = w // n
hch = h // n
N = n * n
nbins = N

_show_hist("x hist", x)

X = dct_2d(x)

Xch = X.unfold(0, hch, hch).unfold(1, wch, wch).reshape((-1, hch, wch))
print("Xch", Xch.shape)
savemat("Xch.mat", {"Xch": Xch.reshape(N, -1).detach().cpu().numpy()})

_show_energy("X energy", Xch)
# plt.hist(X.reshape(-1), bins=nbins)

# H: matrix filled by torch's orthogonal initialiser.
H = torch.empty(N, N).type_as(Xch)
nn.init.orthogonal_(H)
# H = (H - H.mean()) / H.std()
print("H mean", H.mean(), "std", H.std())

_show_image("H", H)
_show_hist("H hist", H)

Y = _mix_chunks(H)

_show_energy("Y energy", Y)
_show_hist("Y hist", Y)

# trying to convert normal to uniform (CDF of normal is uniform)
J = (1 / 2) * (1 + torch.erf((H - H.mean()) / (H.std() * math.sqrt(2))))

_show_image("J", J)
_show_hist("J hist", J)

# Z = J * X
Z = _mix_chunks(J)

_show_energy("Z energy", Z)
_show_hist("Z hist", Z)

# The inverse DCT is applied to the (N, hch * wch) matrix of flattened chunks
# as a whole, not to each chunk separately (see the commented alternative).
W = idct_2d(Xch.reshape((N, -1))).reshape(Xch.shape)
# W = torch.empty(Xch.shape)
# for i in range(W.shape[0]):
#     W[i, :, :] = idct_2d(Xch[i])
print("W:", W.shape)

_show_energy("W energy", W)
_show_hist("W hist", W)

# O: random orthogonal matrix drawn by scipy.
O = torch.tensor(ortho_group.rvs(N))
print(np.dot(O, O.T))

_show_image("O", O)
_show_hist("O hist", O)

# Q = torch.tensor(O) * X
Q = _mix_chunks(O)

_show_energy("Q energy", Q)
_show_hist("Q hist", Q)

In [None]:
import torch
import matplotlib.pyplot as plt

torch.random.manual_seed(42)


def rand_ortho(n: int) -> torch.Tensor:
    """Return a random ``n x n`` orthogonal matrix of dtype float64.

    A full-rank matrix with entries drawn uniformly from {+1, -1} is built row
    by row: a candidate row is kept only if it raises the rank of the matrix,
    otherwise it is redrawn. The matrix of left singular vectors of the result
    is returned, with every row scaled to unit Euclidean norm.

    Args:
        n: Number of rows and columns.

    Returns:
        Tensor of shape ``(n, n)`` whose rows are orthonormal.
    """
    A = torch.zeros(n, n, dtype=torch.float64)
    vals = torch.tensor([1, -1], dtype=A.dtype)

    i = 0
    while i < n:
        # Draw a candidate row of random signs ...
        A[i] = vals[torch.randint(0, 2, (n,))]

        # ... and accept it only if it is linearly independent of the rows
        # accepted so far (the remaining rows of A are still zero).
        if torch.linalg.matrix_rank(A) == i + 1:
            i += 1

    U, _, _ = torch.linalg.svd(A)
    # keepdim=True keeps the row norms as an (n, 1) column, so each row is
    # divided by its own norm. Without it the (n,) vector broadcasts along the
    # last dimension and column j gets divided by the norm of row j.
    U = U / U.square().sum(dim=1, keepdim=True).sqrt()

    return U


# For comparison with MATLAB version:
# n = 8

# A = torch.zeros(n, n)

# vals = torch.tensor([1, -1])
# row = torch.zeros(1, n)

# A_rows = torch.tensor([
#    [-1,    1,   -1,    1,   -1,    1,    1,   -1 ],
#    [-1,    1,    1,    1,    1,   -1,   -1,   -1 ],
#    [ 1,    1,   -1,   -1,   -1,   -1,    1,    1 ],
#    [ 1,   -1,   -1,    1,    1,    1,    1,   -1 ],
#    [ 1,   -1,   -1,   -1,   -1,    1,   -1,    1 ],
#    [ 1,   -1,   -1,    1,   -1,   -1,   -1,   -1 ],
#    [ 1,    1,    1,    1,    1,   -1,    1,    1 ],
#    [ 1,    1,   -1,   -1,   -1,   -1,    1,    1 ],
#    [-1,    1,   -1,   -1,   -1,   -1,    1,   -1 ],
#    [ 1,    1,   -1,    1,    1,   -1,   -1,   -1 ],
#    [-1,   -1,    1,    1,   -1,   -1,    1,   -1 ],
#    [-1,   -1,    1,   -1,    1,    1,   -1,    1 ],
# ])
# iA = 0

# i = 0
# while i < n:
#     idx = torch.randint(0, 2, row.shape)
#     # print(idx)
#     # A[i] = vals[idx]
#     A[i] = A_rows[iA]
#     iA += 1

#     rank = torch.linalg.matrix_rank(A)
#     print(rank)

#     if rank == i + 1:
#         i += 1


# Aref = torch.tensor([
#     -1,    1,   -1,    1,   -1,    1,    1,   -1,
#     -1,    1,    1,    1,    1,   -1,   -1,   -1,
#      1,    1,   -1,   -1,   -1,   -1,    1,    1,
#      1,   -1,   -1,    1,    1,    1,    1,   -1,
#      1,   -1,   -1,   -1,   -1,    1,   -1,    1,
#      1,   -1,   -1,    1,   -1,   -1,   -1,   -1,
#      1,    1,    1,    1,    1,   -1,    1,    1,
#     -1,   -1,    1,   -1,    1,    1,   -1,    1,
# ], dtype=float).reshape(n, n)

# print("A:", A)
# print("A == Aref: ", (A == Aref).all().item())

# U, S, Vh = torch.linalg.svd(A)

def _new_figure(title):
    """Open a new figure and set its title."""
    plt.figure()
    plt.title(title)


# Random orthogonal matrix from rand_ortho(); the printed product with its
# transpose serves as a visual orthogonality check.
Osvd = rand_ortho(N)
print(np.dot(Osvd, Osvd.T))

_new_figure("Osvd")
plt.imshow(Osvd)
plt.show()

_new_figure("Osvd hist")
plt.hist(Osvd.reshape(-1), bins=nbins)
plt.show()

# Mix the flattened DCT chunks with the matrix and restore the chunk shape.
print(Xch.dtype)
_flat_chunks = Xch.reshape(N, -1)
Osvd_res = torch.matmul(Osvd, _flat_chunks).reshape(Xch.shape)

_new_figure("Osvd result energy")
plt.bar(np.arange(N), Osvd_res.square().mean((1, 2)))
plt.yscale("log")
plt.show()

_new_figure("Osvd result hist")
plt.hist(Osvd_res.reshape(-1), bins=nbins)
plt.show()

In [None]:
import scipy.io as sio

# Print the variable "H" stored in the reference MATLAB file.
print(sio.loadmat("reference/H48.mat")["H"])