In [None]:
# Trying to implement GRACE using torchjpeg

import math
from pathlib import Path

import plotly.express as px
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import to_tensor, to_pil_image

import torchjpeg.codec
from gradients.grace import *
from lvc.lvc import *
from models.model import plot_channels
from transforms.dct import *
from transforms.metrics import mse_psnr
from utils import block_process, Size

import os

# Input image, encoded JPEG and decoded output. INP_FILE and RUN_DIR default to
# the original author's machine; set GRACE_INP_FILE / GRACE_RUN_DIR in the
# environment to run the notebook elsewhere without editing this cell.
INP_FILE = os.environ.get("GRACE_INP_FILE", "/home/jakub/data/kodim/raw/kodim23.png")
ENC_FILE = "kodim_out.jpg"
OUT_FILE = "kodim_out_dec.png"
# Previously used probe runs (kept for provenance):
# RUN_DIR = Path("/media/data1/jakub/lvc-for-cv/experiments/runs/run43_keep")  # full-frame DCT
# RUN_DIR = Path("/media/data1/jakub/lvc-for-cv/experiments/runs/run46_keep")  # full-frame DCT with 255 range
# RUN_DIR = Path("/media/data1/jakub/lvc-for-cv/experiments/runs/run44_keep")  # block-based DCT
# RUN_DIR = Path("/media/data1/jakub/lvc-for-cv/experiments/runs/run45_keep") # subtracted mean => less q dynamic range
# RUN_DIR = Path("/media/data1/jakub/lvc-for-cv/experiments/runs/run47_keep")  # all probe variants
# RUN_DIR = Path(
#     "/media/data1/jakub/lvc-for-cv/experiments/runs/run48_keep"
# )  # all probe variants
# RUN_DIR = Path("/media/data1/jakub/lvc-for-cv/experiments/runs/run254")  # all probe variants, probed with 1 image
# RUN_DIR = Path("/media/data1/jakub/lvc-for-cv/experiments/runs/run256")  # all probe variants, probed with 1 image, DCT 16x16
RUN_DIR = Path(
    os.environ.get(
        "GRACE_RUN_DIR",
        "/home/jakub/git/nn-spectral-sensitivity/experiments_tupu/runs/run10_keep",
    )
)  # all probe variants

# --- JPEG codec settings (passed to torchjpeg.codec.quantize_at_quality_custom) ---
QUALITY = 100
COLOR_SAMP_FACTOR_VERTICAL, COLOR_SAMP_FACTOR_HORIZONTAL = 1, 1

# Match the ITU-R BT.601 luma coefficients; not referenced elsewhere in this
# notebook (W from the probe results is used instead).
W_HUMAN = (0.299, 0.587, 0.114)

# --- Probe-result selection: these values form the probe result file name ---
model = "fastseg_small"
mode = 444
dist = "dist_abs"
sub = "submean"
dctsz = "ff"  # "ff" = full-frame DCT, "bb8x8" = block-based 8x8 DCT
sc = 255  # 1 or 255
# Frequency grid the gradient/DCT norms are pooled onto (asserted further below).
dct_size = Size(w=8, h=8)

# Probe results are generated by experiments.probe_only(). Depending on the run,
# the file name may or may not carry the distance / mean-subtraction / DCT-size
# suffixes, so each known naming scheme is tried in turn.
_probe_candidates = [
    RUN_DIR / f"probe_result_full_{model}_{mode}.pt",
    RUN_DIR / f"probe_result_full_{model}_{mode}_{dist}_{sub}.pt",
    RUN_DIR / f"probe_result_full_{model}_{mode}_{dist}_{sub}_{dctsz}_sc{sc}.pt",
]
for _probe_path in _probe_candidates:
    try:
        probe_results = torch.load(_probe_path)
        break
    except FileNotFoundError:
        continue
else:
    # Report every path that was tried, not only the last one, so a wrong
    # RUN_DIR or parameter combination is easy to spot.
    raise FileNotFoundError(
        "No probe results found; tried:\n"
        + "\n".join(str(path) for path in _probe_candidates)
    )

# Colour-transform weights used for RGB <-> YUV in the later cells.
W = probe_results["W"]
print(f"W: {W}")

def norm_abs(tensor):
    """Return the mean absolute value over all elements of ``tensor`` (0-d tensor)."""
    magnitudes = tensor.abs()
    return magnitudes.mean()


# norm_sq = lambda tensor: tensor.square().sum().sqrt()

# Probed loss gradients and the matching DCT coefficients (YUV).
g, dct = probe_results["grads_yuv"], probe_results["dct_yuv"]
for _name, _tensor in (("g", g), ("dct", dct)):
    print(f"{_name}: {_tensor.shape}, min: {_tensor.min()}, max: {_tensor.max()}")

if dctsz != "ff":
    # Block DCT: g/dct are split into 3 equal chunks along dim 0 (presumably one
    # per Y/U/V channel -- TODO confirm) and the mean |value| over each chunk's
    # blocks gives one per-frequency map per channel.
    sz = math.ceil(g.shape[0] / 3)
    g_block_norm = torch.stack([chunk.abs().mean(dim=0) for chunk in g.split(sz)])
    dct_block_norm = torch.stack([chunk.abs().mean(dim=0) for chunk in dct.split(sz)])
else:
    # Full-frame DCT: pool each channel's H x W map in tiles of
    # (H // dct_size.h, W // dct_size.w) with norm_abs, which yields a
    # dct_size grid per channel (checked by the asserts further down).
    g_tile = (g.shape[1] // dct_size.h, g.shape[2] // dct_size.w)
    dct_tile = (dct.shape[1] // dct_size.h, dct.shape[2] // dct_size.w)
    g_block_norm = block_process(g, g_tile, norm_abs)
    dct_block_norm = block_process(dct, dct_tile, norm_abs)

# Global extrema across both norm maps; only used to report the dynamic range.
mx, mn = (
    torch.tensor([g_block_norm.max(), dct_block_norm.max()]).max(),
    torch.tensor([g_block_norm.min(), dct_block_norm.min()]).min(),
)

# Bring to 16 bits
# g_block_norm = (2**16 - 1) * (g_block_norm - mn) / (mx - mn) + 1
# dct_block_norm = (2**16 - 1) * (dct_block_norm - mn) / (mx - mn) + 1

for _label, _norm_map in (("g_block_norm", g_block_norm), ("dct_block_norm", dct_block_norm)):
    print(f"{_label}: {_norm_map.shape}")
print(f"total dynamic range: {mn} -- {mx}: {mx / mn}")

# Both maps must be on the dct_size frequency grid (dims 1 and 2).
for _norm_map in (g_block_norm, dct_block_norm):
    assert _norm_map.shape[1] == dct_size.h
    assert _norm_map.shape[2] == dct_size.w

plot_channels(g_block_norm.detach().cpu().numpy(), "g_block_norm", width=500, height=1000)
plot_channels(
    dct_block_norm.detach().cpu().numpy(),
    "dct_block_norm",
    log=True,
    width=500,
    height=1000,
)

approx_quant = False

# B is the budget handed to get_quant_table / get_quant_table_approx below; the
# search cell later scans B/255 .. B*255 around this starting value.
if not approx_quant:
    B = 5e-5
else:
    # Choose B so that the approximate quantization table peaks at q = 255,
    # which gives a reasonable initial ballpark value.
    B = 255 * g_block_norm.min().item() * dct_size.w * dct_size.h / 2

print(f"B: {B}")

### TEST: quantization table computed per DCT coefficient directly from the
### full-frame gradients, then pooled onto the dct_size grid for plotting.
q2 = torch.zeros_like(g)
d2 = torch.zeros_like(g)
bounded2 = torch.zeros_like(g)
max_loss_increase2 = torch.zeros_like(g)

# q2[i] etc. have shape g.shape[1:], so get_quant_table must receive the
# matching per-channel slices of g and dct. It presumably returns tensors shaped
# like its inputs (as in the 8x8 loop further down); passing the full g/dct
# yielded results shaped like all of g, which cannot be stored in q2[i].
for i, (g_ch, dct_ch) in enumerate(zip(g, dct)):
    print(i)
    q2[i], d2[i], bounded2[i], max_loss_increase2[i] = get_quant_table(
        g_ch, dct_ch, B, do_print=False
    )

q2_block_norm = block_process(
    q2, (q2.shape[1] // dct_size.h, q2.shape[2] // dct_size.w), norm_abs
)

plot_channels(
    q2_block_norm.cpu().detach().numpy(),
    f"Q2 table (Y/U/V), B = {B}",
    log=False,
    width=500,
    height=1000,
    # save="experiments/plots/grace_q_table.png",
    save=None,
)

### TEST


# One quantization table (and its diagnostics) per channel, on the dct_size grid.
q, d, bounded, max_loss_increase = (torch.zeros_like(g_block_norm) for _ in range(4))

# NOTE: the loop variables keep the names g_norm / dct_norm because the report
# at the end of this cell reads dct_norm after the loop.
for i, (g_norm, dct_norm) in enumerate(zip(g_block_norm, dct_block_norm)):
    if not approx_quant:
        q[i], d[i], bounded[i], max_loss_increase[i] = get_quant_table(
            g_norm, dct_norm, B, do_print=True
        )
        continue
    # The approximate table only needs the gradient norm; bounded and
    # max_loss_increase stay zero in this mode.
    q[i], d[i] = get_quant_table_approx(g_norm, B)

# Integer table for the JPEG encoder. The float -> int16 cast truncates toward
# zero (no rounding). NOTE(review): entries below 1 would become 0 -- confirm
# get_quant_table never produces them.
QUANT = q.type(torch.int16)

# Make sure the figure's target directory exists (harmless if plot_channels
# already creates it).
Path("experiments/plots").mkdir(parents=True, exist_ok=True)
plot_channels(
    q.cpu().detach().numpy(),
    f"Q table (Y/U/V), B = {B}",
    log=False,
    width=500,
    height=1000,
    save="experiments/plots/grace_q_table.png",
)
print("dct_block_norm:")
print(dct_block_norm)
print("g_block_norm:")
print(g_block_norm)
print("d:")
print(d)
print("q:")
print(q)
print(f"q dynamic range: {q.max() / q.min()}")

qdct_block_norm = (dct_block_norm / q).round()
print("quantized DCT block norm")
print(qdct_block_norm)
# plot_channels(qdct_block_norm, "quantized DCT block norm")

print("bounded:")
print(bounded)
print("q - 2 * dct_block_norm (should be negative or 0)")
print(q - 2 * dct_block_norm)
print("d <= max_loss_increase")
# Compare every channel against its own |g| * |DCT| bound. This previously used
# `dct_norm`, the loop variable left over from the last iteration above, which
# compared all channels against the last channel's DCT norms.
print(d <= (g_block_norm * dct_block_norm))

In [None]:
# Run JPEG encode + decode (requires first cell)

from transforms import Interpolate

# Read image
# INP_FILE = "probed_cityscapes_images_batchsize1/batch0_img0.png"
# Decode into a 3-channel float tensor in [0, 1]. Converting through PIL
# handles RGBA (alpha dropped, as before), grayscale and palette ("P") images
# uniformly. Slicing to_tensor(...)[:3] left 1-channel images untouched, and
# for palette images it produced raw palette indices instead of colours.
# The context manager closes the file handle once the pixels are loaded.
with Image.open(INP_FILE) as _img:
    img_rgb = to_tensor(_img.convert("RGB"))

# RGB -> YUV
rgb2yuv = RgbToYcbcr(W)
img_yuv = rgb2yuv(img_rgb)
print(f"img_yuv {img_yuv.shape}")

# encode: forward DCT + quantization of img_yuv with the GRACE table QUANT
# (QUALITY and the chroma sampling factors come from the first cell).
_encoded = torchjpeg.codec.quantize_at_quality_custom(
    img_yuv,
    QUALITY,
    QUANT,
    COLOR_SAMP_FACTOR_VERTICAL,
    COLOR_SAMP_FACTOR_HORIZONTAL,
)
dimensions, quantization, Y_coefficients, CbCr_coefficients, enc_data = _encoded

# compare JPEG DCT against our DCT
# View each plane's coefficients as a stack of 8x8 blocks. These come from
# quantize_at_quality_custom and are presumably already quantized, i.e. not on
# the same scale as dct_block_norm -- TODO confirm.
y_dct, cb_dct, cr_dct = (
    coeffs.reshape(-1, 8, 8)
    for coeffs in (Y_coefficients, CbCr_coefficients[0], CbCr_coefficients[1])
)

# Mean |coefficient| per frequency, one 8x8 map per plane.
dct_ycbcr = torch.stack(
    [plane.float().abs().mean(dim=0).squeeze() for plane in (y_dct, cb_dct, cr_dct)]
)

# Mask of all AC coefficients (DC at [0, 0] excluded). Together with
# minval/maxval it is only consumed by the commented-out Interpolate calls.
mask = torch.ones_like(dct_ycbcr).bool()
mask[:, 0, 0] = False  # ignore DC coeff.
minval = 1
maxval = 100

# dct_ycbcr = Interpolate(minval, maxval, mask)(dct_ycbcr)
# dct_ycbcr[:, 0, 0] = minval
plot_channels(
    dct_ycbcr.cpu().numpy(),
    "DCT after JPEG",
    log=False,
    width=500,
    height=1000,
)

dct_grace = dct_block_norm
# dct_grace = Interpolate(minval, maxval, mask)(dct_grace)
# dct_grace[:, 0, 0] = minval
plot_channels(
    dct_grace.cpu().detach().numpy(),
    "DCT from probe",
    log=False,
    width=500,
    height=1000,
)

# The probe tensors may live on another device (e.g. a GPU) and/or carry
# autograd history, while dct_ycbcr is built from the codec's coefficients.
# Align them before subtracting so the difference can be converted to NumPy.
diff = dct_ycbcr - dct_grace.detach().to(dct_ycbcr.device)
plot_channels(diff.cpu().numpy(), "Diff", width=500, height=1000)

# enc_data is the fourth encoder output; presumably the encoded byte stream,
# hence the "size" label -- TODO confirm.
print("Encoded JPEG size:", enc_data.shape, enc_data.dtype)

print("quantization:")
print(quantization)

# Store the quantized coefficients and quantization tables in ENC_FILE.
torchjpeg.codec.write_coefficients_custom(
    ENC_FILE, dimensions, quantization, Y_coefficients, CbCr_coefficients
)

# decode
# Read everything back from ENC_FILE. This overwrites the encoder outputs
# above, so the reconstruction below uses the values stored in the file.
dimensions, quantization, Y_coefficients, CbCr_coefficients = (
    torchjpeg.codec.read_coefficients(ENC_FILE)
)
print("Y", Y_coefficients.shape, "CbCr", CbCr_coefficients.shape)
# raw=True: output is treated as YUV below (converted with YcbcrToRgb(W)),
# i.e. presumably no colour conversion is done by torchjpeg -- TODO confirm.
out_yuv = torchjpeg.codec.reconstruct_full_image(
    Y_coefficients, quantization, CbCr_coefficients, dimensions, raw=True
)
print("out_yuv:", out_yuv.shape)

# YUV -> RGB
# Uses the same colour weights W as the forward transform; clamp to [0, 1] so
# to_pil_image receives a valid float image.
yuv2rgb = YcbcrToRgb(W)
out_rgb = yuv2rgb(out_yuv).clamp(0.0, 1.0)

to_pil_image(out_rgb).save(OUT_FILE)

In [None]:
# testing block DCT (requires first cell)
import matplotlib.pyplot as plt

# Block-DCT round-trip settings. Dedicated names are used so running this cell
# does not clobber globals from the first cell: `dct_size` (a Size whose .h/.w
# are read there), `dct` (the probed DCT tensor) and `mode` (part of the
# probe-result file name).
chunk_size = 64
block_mode = 444
block_size = (8, 8)
device = "cpu"

preprocess = nn.ModuleList(
    [
        WrapMetadata(chunk_size),
        Pad(block_mode, block_size),
        RgbToYcbcrMetadata(W),
        ChunkSplit(block_size),
    ]
)

postprocess = nn.ModuleList(
    [
        ChunkCombine(block_mode, device, block_size, is_half=False),
        YcbcrToRgbMetadata(W),
        Crop(),
        StripMetadata(),
    ]
)

dct_op = Dct(norm="ortho")
idct_op = Idct(norm="ortho")


def run_pipeline(layers, x):
    """Apply ``layers`` to ``x`` in order and return the final output.

    When a layer returns a plain tuple (e.g. ``(tensor, meta)``), it is
    splatted into the next layer's positional arguments; any other value is
    passed as a single argument. ``type(x) is tuple`` is used on purpose so
    tuple subclasses (such as namedtuples) are passed through unsplatted.
    """
    for layer in layers:
        x = layer(*x) if type(x) is tuple else layer(x)
    return x


print("yuv shape:", img_yuv.shape)
# img_yuv is fed in as the pipeline input and the result is compared against
# img_yuv below, so the round trip is presumably expected to reproduce it.
inp, meta = run_pipeline(preprocess, img_yuv)

print("before DCT:", inp.shape)
dct_block = dct_op(inp)
print("block DCT:", dct_block.shape)

# Visualisation only: reassemble the raw block-DCT coefficients into an image.
out_dct = run_pipeline(postprocess, (dct_block, meta))

idct_block = idct_op(dct_block)
print("block IDCT:", idct_block.shape)

out = run_pipeline(postprocess, (idct_block, meta))

print("out:", out.shape)
mse, psnr = mse_psnr(out, img_yuv)
print(f"psnr: {psnr:.3f} dB")

fig = px.imshow(out_dct.permute(1, 2, 0).cpu().detach().numpy(), title="dct")
fig.show()
fig = px.imshow(out.permute(1, 2, 0).cpu().detach().numpy(), title="out")
fig.show()

# print("trying block DCT")
# yuv_block2 = dct_2d_block(img_yuv.unsqueeze(0), "ortho", 8, 8)
# print("yuv block:", yuv_block2.shape)
# idct_block = idct_2d(yuv_block2)
# print("idct block:", idct_block.shape)

In [None]:
# Search for JPEG compression ratio
# Requires first cell to get initial B estimate and W

import itertools

import torch

from utils.q_search import find_val, ParamConfig

# cpus = [0]
# os.sched_setaffinity(0, set(cpus))
# affinity = os.sched_getaffinity(0)
# print(f"Running on CPUs: {affinity}")
# torch.set_num_threads(len(cpus))

nbits_per_sym = [2, 4, 6]  # 4-, 16- and 64-QAM
# target_crs = [0.03125, 0.06250, 0.12500, 0.25000, 0.50000, 1.00000]
target_crs = [0.50000, 1.00000]
nimages = 1
res = {}

# Search over the JPEG quality factor Q.
params_jpeg: ParamConfig = {
    "codec": "jpeg",
    "init": 50,
    "min": 1,
    "max": 100,
    "g_block_norm": None,
    "W": None,
    "valname": "Q",
    "fmt": "{:3}",
}

# Search over the GRACE parameter B, starting from the first cell's estimate.
params_grace: ParamConfig = {
    "codec": "grace",
    "init": B,
    "min": B / 255,
    "max": B * 255,
    "g_block_norm": g_block_norm,
    "W": W,
    "valname": "B",
    "fmt": "{:12.5e}",
}

params = params_grace

for nbits, target_cr in itertools.product(nbits_per_sym, target_crs):
    # Dedicated names: the first cell's `q` holds the quantization table and
    # must not be overwritten by the searched value.
    best_val, cr, search_psnr = find_val(
        target_cr,
        nbits,
        params,
        "coco",
        nimages=nimages,
        do_print=True,
    )
    res[(nbits, target_cr)] = (best_val, cr)
    valfmt = params["fmt"].format(best_val)
    # Label the value by what is actually searched ("Q" for JPEG, "B" for GRACE).
    print(
        f"{2**nbits:2}-QAM, target LCT CR: {target_cr:.5f}, actual LCT CR: {cr:.5f}, "
        f"{params['valname']}: {valfmt}, {search_psnr:7.3f} dB PSNR"
    )

print("Done")
torch.save(res, "q_search_jpeg_grace.pt")

In [None]:
from experiments import q_search

# Run experiments.q_search with a dedicated output directory, created first
# (parents included) if it does not exist yet.
outdir = Path("experiments") / "q_search_test"
outdir.mkdir(parents=True, exist_ok=True)
q_search(outdir)