In [None]:
# Plotting distribution of the difference between input and reconstructed images

%load_ext autoreload 
%autoreload 2 

import math

import matplotlib.pyplot as plt
import numpy as np
import torch

from pathlib import Path
from PIL import Image
from scipy.optimize import curve_fit
from torchvision import transforms

from lvc import create_lvc
from utils import set_device

device = set_device()

# Directory holding the PNG test images (machine-specific; adjust as needed).
# img_dir = Path("/home/jakub/pictures/kodim/raw")
img_dir = Path("/home/kubouch/pictures/kodim/raw")

# glob() yields entries in filesystem order, which differs between machines
# and runs; sort so the batch order (and everything derived from it) is
# reproducible.
image_names = sorted(img_dir.glob("*.png"))
if not image_names:
    # torch.stack([]) would otherwise fail with an unrelated-looking error.
    raise FileNotFoundError(f"No PNG images found in {img_dir}")


def _load_rgb_tensor(image_name):
    """Load one image as an RGB tensor resized to 512x768.

    Previously tried alternatives: CenterCrop((288, 352)), Resize((288, 672)).
    """
    # The context manager closes the underlying file as soon as it is decoded.
    with Image.open(image_name) as img:
        tensor = transforms.ToTensor()(img.convert(mode="RGB"))
    return transforms.Resize((512, 768))(tensor)


# One batch containing every image found in img_dir.
images = torch.stack([_load_rgb_tensor(image_name) for image_name in image_names])
print("Images: ", images.shape)

# Parameters of the zero-forcing (ZF) LVC pipeline. Commented-out keys are
# optional settings that are currently left unset.
lvc_params_zf = {
    "csnr_db": 0,
    "cr": 1.0,
    "mode": 444,
    # "chunk_w": 32,
    # "chunk_h": 32,
    "nchunks": 64,
    "seed": 42,
    "estimator": "zf",  # llse
    "packet_loss": None, #0.0, 
    # "dct_w": 16,
    # "dct_h": 16,
    # "grouping": "vertical_uv",
}

# Separate result containers, one per pipeline (handed to create_lvc below).
results_zf = {"noise": []}
results_llse = {"noise": []}

# The LLSE pipeline uses the same settings with a different estimator.
# Build an independent dict: aliasing lvc_params_zf and mutating it would
# switch the ZF parameters to "llse" as well, so both pipelines would end up
# running the same estimator.
lvc_params_llse = {**lvc_params_zf, "estimator": "llse"}

# Build one LVC pipeline per estimator; each one receives its own results dict.
lvc_zf = create_lvc(lvc_params_zf, device, half=False, results=results_zf)
lvc_llse = create_lvc(lvc_params_llse, device, half=False, results=results_llse)

# Run the whole image batch through both pipelines.
res_zf = lvc_zf(images)
res_llse = lvc_llse(images)

# Stack the entries collected under results_zf["noise"] during the ZF run.
noise = torch.stack(results_zf["noise"])
print("noise mean: {}".format(noise.mean()))

# TODO: Do this per chunk, not on whole image
# (This line used to be bare text, which is a SyntaxError and stopped the
# whole cell from running.)
diff_zf = res_zf - images
diff_llse = res_llse - images

# Number of histogram bins used by the plots below.
nbins = 128
    
def plot_gaussian_histogram(ax, values: torch.Tensor, nbins: int):
    """Draw a histogram of ``values`` on ``ax`` and overlay a fitted Gaussian.

    Args:
        ax: Axes-like object providing ``hist``, ``plot`` and ``legend``.
        values: Tensor of any shape; it is flattened before binning. It may
            live on any device and may track gradients.
        nbins: Number of histogram bins.

    The fitted mean and standard deviation are shown in the legend.
    ``scipy.optimize.curve_fit`` raises ``RuntimeError`` if the fit does not
    converge.
    """
    def gaussian(x, mean, amplitude, stddev):
        # Scaled normal density: the area under the curve equals `amplitude`.
        return amplitude \
            / (stddev * np.sqrt(2 * np.pi)) \
            * np.exp(-0.5 * ((x - mean) / stddev)**2)

    # detach()/cpu() make the conversion work for CUDA tensors and for tensors
    # that require grad; .numpy() alone raises for both.
    samples = values.detach().cpu().reshape(-1).numpy()

    bin_heights, bin_borders, _ = ax.hist(samples, bins=nbins)
    bin_widths = np.diff(bin_borders)
    bin_centers = bin_borders[:-1] + bin_widths / 2

    # Start the optimiser from the sample statistics. A start at amplitude 0
    # and mean 0 has (numerically) zero gradient whenever the data is not
    # centred near zero, so the fit cannot move away from it.
    mean0 = float(samples.mean())
    stddev0 = float(samples.std())
    if not np.isfinite(stddev0) or stddev0 <= 0.0:
        stddev0 = 1.0
    amplitude0 = float(np.sum(bin_heights * bin_widths))  # histogram area
    if not np.isfinite(amplitude0) or amplitude0 <= 0.0:
        amplitude0 = 1.0

    popt_gaussian, _ = curve_fit(
        gaussian,
        bin_centers,
        bin_heights,
        p0=[mean0, amplitude0, stddev0],
    )
    fit_mean = popt_gaussian[0]
    # The model is unchanged when amplitude and stddev flip sign together, so
    # the optimiser may return a negative stddev; report its magnitude.
    fit_stddev = abs(popt_gaussian[2])

    x = np.linspace(bin_borders[0], bin_borders[-1], 10000)
    ax.plot(
        x,
        gaussian(x, *popt_gaussian),
        "r-",
        label="mean={:5.3f}, stddev={:5.3f}".format(fit_mean, fit_stddev),
    )
    ax.legend(loc='upper right')
    
# One row of three panels: ZF difference, LLSE difference, collected noise.
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
fig.suptitle("Difference distributions")

panels = (
    ("Difference Histogram (ZF)", diff_zf),
    ("Difference Histogram (LLSE)", diff_llse),
    ("Noise histogram", noise),
)
for ax, (title, values) in zip(axs, panels):
    ax.set_title(title)
    plot_gaussian_histogram(ax, values, nbins)

plt.tight_layout()
plt.show()

In [None]:
# Estimating the power savings
#
# We're minimizing the power given a constraint on the total distortion C

%load_ext autoreload 
%autoreload 2 

import math
from pathlib import Path

import matplotlib.pyplot as plt
import torch
from PIL import Image
from torch.nn import NLLLoss, LogSoftmax
from torchvision import transforms

from lvc import Pad, LvcEncode, RgbToYcbcrMetadata, Downsample, SubtractMean, \
    ChunkSelect, ChunkSplit, DCT, DctBlock, Metadata, YuvImage
from probe import probe_images
from test_nn import WEIGHTS, Dct, Idct, get_imagenet_labels
from utils import set_device

device = set_device()

# ---- General settings -------------------------------------------------------
w_human = torch.Tensor([0.299, 0.587, 0.114])  # passed to RgbToYcbcrMetadata as `w`
norm = "ortho"
power = 1.0
do_print = False
half = False
references = None
results = None

# ---- LVC settings -----------------------------------------------------------
csnr_db = 10
cr = 0.5
yuv_mode = 444
nchunks = 64
seed = 42
estimator = "zf"  # llse
packet_loss = None #0.0 
chunk_size = nchunks
dct_size = None
# Currently unused alternatives:
#   chunk_w = 32, chunk_h = 32, dct_w = 16, dct_h = 16, grouping = vertical_uv

# Pick the transform pair: block-wise DCT when a block size is configured,
# otherwise a DCT layer followed by the chunk split.
if dct_size is not None:
    grouping = "vertical_uv"
    dct_layers = (
        ChunkSplit(dct_size, do_print=do_print),
        DctBlock(
            dct_size, yuv_mode, grouping, is_half=half, norm=norm, do_print=do_print
        ),
    )
else:
    dct_layers = (DCT(half, norm), ChunkSplit(dct_size, do_print=do_print))

# Encoder layers in the order they are handed to LvcEncode.
encoder_layers = [
    Pad(yuv_mode, dct_size, do_print=do_print),
    RgbToYcbcrMetadata(w=w_human),
    Downsample(mode=yuv_mode),
    SubtractMean(device),
    *dct_layers,
    ChunkSelect(cr, do_print=do_print),
]

# The third argument (results) is deliberately None here.
lvc_encoder = LvcEncode(encoder_layers, references, None, chunk_size)
lvc_encoder = lvc_encoder.to(device, non_blocking=True)

img_size = (3, 256, 256)  # size of resized images
inp_size = (3, 224, 224)  # size of cropped images fed into NN
nimages = 10
# img_dir = Path("/home/jakub/pictures/kodim/raw")
# image_names = list(img_dir.glob("*23.png"))
# img_dir = Path("/mnt/1tb_storage/data/imagenet_subsets/subset_10class_1000perclass/val")
img_dir = Path("/home/kubouch/data/imagenet_subsets/subset_10class_1000perclass/val")

# glob() returns entries in filesystem order, so slicing the raw result picks
# an arbitrary subset. Sort first so the same `nimages` files are used on
# every run and every machine.
image_names = sorted(img_dir.glob("*/*.jpg"))[:nimages]
if not image_names:
    # torch.stack([]) below would otherwise fail with an unrelated-looking error.
    raise FileNotFoundError(f"No JPEG images found under {img_dir}")

# Resize, then centre-crop to the network input size.
preprocess = transforms.Compose(
    [
        transforms.Resize(img_size[1]),
        transforms.CenterCrop(inp_size[1]),
        transforms.ToTensor(),
        # RgbToYcbcr(W),
        #  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)


def _load_preprocessed(image_name):
    """Open one image, convert it to RGB and apply ``preprocess``."""
    # The context manager closes the file handle once the pixels are decoded.
    with Image.open(image_name) as img:
        return preprocess(img.convert(mode="RGB")).to(device)


images = torch.stack([_load_preprocessed(image_name) for image_name in image_names])
print("Images: ", images.shape)

# Each encoder output is unpacked as a pair; only the second element
# (the per-image metadata) is kept.
results = lvc_encoder(images)
metadatas = [m for _, m in results]

show_gradients = False

# (A former create_dct((dct_size, dct_size), dct_size * dct_size, w_human,
#  device, False, do_print=False) encoder/decoder pair is no longer used.)

# Pretrained classifier from torch hub with every parameter tracking gradients.
model_name = "alexnet"
model = torch.hub.load(
    "pytorch/vision:v0.13.1", model_name, weights=WEIGHTS[model_name]
)
model.requires_grad_(True)
model = model.to(device).eval()

# Log-softmax followed by NLL loss.
activ = LogSoftmax(dim=1)
loss_func = NLLLoss()

labels = get_imagenet_labels(image_names)  # np.zeros(len(pics))

# Probe the model with the selected images, using Dct()/Idct() as the transform pair.
grads = probe_images(
    model, image_names, labels, preprocess, Dct(), Idct(), activ, loss_func,
    device=device,
)
print("Probed gradients:", grads.shape)

# Index of the image that is analysed in the rest of this cell.
report_n = 0
print("Image: ", "/".join(Path(image_names[report_n]).parts[-2:]))
metadata = metadatas[report_n]
var_all = metadata.all_variances
g = grads[report_n]

# Split the probed gradients with the same chunk layout as the image metadata.
g_split = ChunkSplit(dct_size)
g_metadata = Metadata(metadata.image_size, metadata.chunk_size)
g_yuv = YuvImage(g[0], g[1], g[2])
g_chunks, _ = g_split(g_yuv, g_metadata)
g_chunks_mean = g_chunks.abs().mean(dim=(1, 2))  # mean |gradient| per chunk

# Selected-by-bitmap (t) vs. not selected (d) chunks.
var_t, var_d = var_all[metadata.bitmap], var_all[~metadata.bitmap]
g_t, g_d = g_chunks_mean[metadata.bitmap], g_chunks_mean[~metadata.bitmap]
# K = len(var_t)
sigma2 = power / math.pow(10, csnr_db / 10) # noise power
C = torch.logspace(start=-3.8, end=-2, steps=100).to(device)

# Theorem 3
# NOTE(review): the denominator is squared, so values of C smaller than the
# not-selected-chunk term still give a positive P — confirm those points are
# meant to be plotted.
weighted_sum = (g_t.square() * var_t).pow(1/3).sum()
residual = C - (g_d * var_d.sqrt()).sum()
P_zf = sigma2 * weighted_sum.pow(3) / residual.square()

fig, axs = plt.subplots(1, 1, figsize=(7, 7))
axs.plot(C.cpu(), P_zf.cpu())
axs.set_xscale("log")
for which in ("major", "minor"):
    axs.grid(which=which)
axs.set_xlabel("C")
axs.set_ylabel("P")
plt.show()

# Theorem 1
import numpy as np

def pert(g, var_all, csnr_db, k, P=1.0):
    """Total first-order perturbation when only the top-``k`` chunks are sent.

    Args:
        g: 1-D tensor of per-chunk gradient magnitudes.
        var_all: 1-D tensor of per-chunk variances, same length as ``g``.
        csnr_db: Channel SNR in dB; the noise power is ``P / 10**(csnr_db/10)``.
        k: Number of chunks (those with the largest variance) that are sent.
        P: Total transmit power shared by the sent chunks.

    Returns:
        0-d tensor: perturbation of the sent chunks (scaled channel noise)
        plus perturbation of the dropped chunks.
    """
    top_chunks_indices = var_all.topk(k).indices

    bitmap = torch.zeros_like(var_all, dtype=torch.bool)
    bitmap[top_chunks_indices] = True

    var_t, var_d = var_all[bitmap], var_all[~bitmap]
    g_t, g_d = g[bitmap], g[~bitmap]

    # Scaling factors beta_i proportional to (g_i / var_i)^(1/3), normalised so
    # the power constraint sum_i beta_i^2 * var_i == P holds. That needs the
    # sum of (g_i^2 * var_i)^(1/3); an exponent of 2/3 here breaks the
    # constraint and disagrees with the Theorem 3 expression in this notebook.
    norm_sum = (g_t.square() * var_t).pow(1 / 3).sum()
    beta = torch.sqrt(P / norm_sum * (g_t / var_t).pow(2 / 3))

    sigma2 = P / math.pow(10, csnr_db / 10) # noise power
    # Sent chunks: noise standard deviation seen after undoing the scaling.
    pert_t = (g_t * math.sqrt(sigma2) / beta).sum()
    # Dropped chunks: weighted by their standard deviation, matching pert_t
    # (a standard deviation too) and the dropped-chunk term of Theorem 3.
    pert_d = (g_d * var_d.sqrt()).sum()

    return pert_t + pert_d

# Sweep the number of sent chunks K for several channel SNRs and plot the
# resulting total perturbation.
fig, axs = plt.subplots(1, 1, figsize=(10, 10))
ks = np.arange(1, len(var_all) + 1)
# A dedicated loop variable keeps the `csnr_db` setting defined at the top of
# the cell from being silently overwritten by the sweep.
for csnr_db_sweep in [0, 5, 10, 20, 30]:
    # .item() yields plain Python floats, which NumPy and matplotlib handle
    # regardless of the tensor's device or autograd state.
    perts = [
        pert(g_chunks_mean, var_all, csnr_db_sweep, int(k)).item() for k in ks
    ]
    min_k = ks[np.argmin(perts)]
    print(f"CSNR {csnr_db_sweep:2d} dB: min. pert. at K = {min_k}")
    axs.plot(ks, perts, label=f"{csnr_db_sweep} dB")

axs.legend()
axs.grid(which="major")
axs.grid(which="minor")
axs.set_xlabel("K")
axs.set_ylabel("perturbation")
plt.show()
    

In [None]:
# Trying optimal power allocation

%load_ext autoreload 
%autoreload 2 

import math
from pathlib import Path

import matplotlib.pyplot as plt
import torch
from PIL import Image
from torch.nn import NLLLoss, LogSoftmax, CrossEntropyLoss
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms, datasets
from torchsummary import summary

from lvc import Pad, LvcEncode, RgbToYcbcrMetadata, Downsample, SubtractMean, \
    ChunkSelect, ChunkSplit, DCT, DctBlock, Metadata, YuvImage
from test_nn import WEIGHTS, Dct, Idct, get_imagenet_labels
from utils import set_device

device = set_device()

img_size = (3, 256, 256)  # size of resized images
inp_size = (3, 224, 224)  # size of cropped images fed into NN
nimages = 10

# img_dir = Path("/home/jakub/pictures/kodim/raw")
# image_names = list(img_dir.glob("*23.png"))
# img_dir = Path("/mnt/1tb_storage/data/imagenet_subsets/subset_10class_1000perclass/val")
img_dir = Path("/home/kubouch/data/imagenet_subsets/subset_10class_1000perclass/val")
# glob() returns entries in filesystem order; sort before slicing so a fixed
# index into image_names always refers to the same file.
image_names = sorted(img_dir.glob("*/*.jpg"))[:nimages]

# Pretrained ShuffleNet V2 (x1.0) together with its matching preprocessing.
weights = models.ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1
# `weights` is passed by keyword: torchvision deprecates passing it positionally.
model = models.shufflenet_v2_x1_0(weights=weights)
# model = model.to(device)
model.eval()
preprocess = weights.transforms()
from torchvision.io import read_image
img = read_image(str(image_names[5]))
batch = preprocess(img).unsqueeze(0)
# Pure inference: skip autograd bookkeeping.
with torch.no_grad():
    prediction = model(batch).squeeze(0).softmax(0)
# Report the most likely class and its probability.
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score:.1f}%")
plt.figure()
plt.imshow(transforms.ToPILImage()(img))
# transforms.Compose(
#     [
#         transforms.Resize(img_size[1]),
#         transforms.CenterCrop(inp_size[1]),
#         transforms.ToTensor(),
#         # RgbToYcbcr(W),
#         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#     ]
# )

# batch_size = min(len(image_names), 32)
# labels = get_imagenet_labels(image_names)  # np.zeros(len(pics))
# dataset = LocalDataset(image_names, labels=labels, transform=preprocess)
# dataset = datasets.ImageNet(
#     "/home/kubouch/data/imagenet_subsets/subset_10class_1000perclass",
#     split="val",
#     transform=preprocess,
# )
# dataloader = DataLoader(dataset, batch_size=batch_size, pin_memory=True)
# summary(model, inp_size)
# activ = LogSoftmax(dim=1)
# activ = lambda x: x
# loss_func = NLLLoss()
# loss_func = CrossEntropyLoss()

# loss, acc, class_correct, class_total = val(
#     model, device, dataloader, activ, loss_func, 10)
# print(f"Loss: {loss}, Acc: {acc}")

In [None]:
# (FastSeg testing old vs new results)

%load_ext autoreload 
%autoreload 2 
%reset -f

from pathlib import Path

import torch

from models import fastseg, get_model
from models.model import plot_channels 
from utils import set_device

# Fix the torch RNG seed for this cell.
SEED = 42
torch.manual_seed(SEED)

color_space = "yuv"

# Batch settings for the evaluation runs and for the probing runs.
num_batches, batch_size = 8, 8
num_batches_probe, batch_size_probe = 1, 1

# Output switches.
do_print, show_plots = False, False

device = set_device("cuda", do_print=True)

# Panel titles shared by both gradient plots.
grad_titles = (
    "block_norm DCT gradient (Y)",
    "block_norm DCT gradient (U)",
    "block_norm DCT gradient (V)",
)

for variant in ["small", "large"]:
    # ---- "old" entry point: fastseg.main ------------------------------------
    shared_kwargs = dict(
        device=device,
        num_workers=0,
        color_space=color_space,
        do_print=do_print,
        variant=variant,
    )
    result_orig = fastseg.main(
        mode="eval",
        num_batches=num_batches,
        batch_size=batch_size,
        **shared_kwargs,
    )
    result_probe = fastseg.main(
        mode="probe",
        num_batches=num_batches_probe,
        batch_size=batch_size_probe,
        show_plots=show_plots,
        **shared_kwargs,
    )

    plot_channels(
        result_probe["grads_norm"].cpu().numpy(),
        list(grad_titles),
        0,
        show=True,
        save=f"fastseg{variant}_grads_norm_test_old.png",
        tight_layout=False,
    )

    # ---- "new" entry point: get_model ---------------------------------------
    snapshot = (
        Path.home() / "data/models/fastseg/raw/large/best_checkpoint_ep172.pth"
        if variant == "large"
        else Path.home() / "data/models/fastseg/raw/small/best_checkpoint_ep171.pth"
    )
    config = {
        "name": "fastseg",
        "variant": variant,
        "snapshot": snapshot,
    }

    model = get_model(config, device, num_batches, batch_size, color_space=color_space)
    # NOTE(review): the value returned by eval() is printed below as the
    # evaluation result — confirm against the project's model wrapper.
    result_new_orig = model.eval()

    model_probe = get_model(config, device, num_batches_probe, batch_size_probe, color_space=color_space)
    result_new_probe = model_probe.run_probe(chunk_size=(128, 256))

    plot_channels(
        result_new_probe["grads_norm"].cpu().numpy(),
        list(grad_titles),
        0,
        show=True,
        save=f"fastseg{variant}_grads_norm_test_new.png",
        tight_layout=False,
    )

    # Print old and new results next to each other for comparison.
    print(f"{variant}::      orig results:", result_orig)
    print(f"{variant}::     probe results:", list(result_probe.keys()), result_probe["W"])
    print(f"{variant}:: new  orig results:", result_new_orig)
    print(f"{variant}:: new probe results:", list(result_new_probe.keys()), result_new_probe["W"])


In [None]:
# Write result to JSON
import json

# Copy `res` and replace every value with its .tolist() form so the copy can
# be serialised as JSON.
res2 = res.copy()
for k, v in res.items():
    res2[k] = v.tolist()

# Writing is currently disabled:
# with open("res_yuv_fixed.json", "w") as wf:
#     json.dump(res2, wf)


In [None]:
# Read from JSON

import json

# Load previously stored results from res_yuv.json into `res`.
with open("res_yuv.json") as results_file:
    res = json.loads(results_file.read())


In [None]:
# Plot result

import numpy as np
import matplotlib.pyplot as plt

# Scores in percent. The first entry of each list is skipped so that entry i
# lines up with K = i.
ious = np.array(res["ious"][1:]) * 100
ious_g = np.array(res["ious_g"][1:]) * 100

# Derive the K axis from the data instead of hard-coding 64 points, so the
# plot keeps working when the number of stored results changes.
K = np.arange(1, len(ious) + 1)

# Absolute scores with the reference score as a horizontal line.
plt.figure()
plt.plot(K, ious, label="baseline")
plt.plot(K, ious_g, label="gradients")
plt.plot(K, [res["iou_orig"][0] * 100] * len(K), label="reference")
plt.xlabel("K")
plt.ylabel("mIoU (%)")
plt.legend()

# Difference between the gradient-based and the baseline scores.
diff = ious_g - ious

plt.figure()
plt.plot(K, [0.0] * len(K), label="zero")
plt.plot(K, diff, label="diff")
plt.xlabel("K")
plt.ylabel("mIoU_grad - mIoU_base (pp)")
plt.legend()
plt.show()


In [None]:
# Show the current contents of `res` as the cell output.
res


In [None]:
# Probe run with num_batches=None and printing enabled; every other setting is
# taken from variables defined in earlier cells.
probe_kwargs = {
    "mode": "probe",
    "device": device,
    "num_batches": None,
    "batch_size": batch_size_probe,
    "num_workers": 0,
    "color_space": color_space,
    "do_print": True,
    "show_plots": show_plots,
}
grad_norms = fastseg.main(**probe_kwargs)


In [None]:
# Estimating bitmap based on gradients
import torch

# Hard-coded gradient norms, shape (3, 8, 8): three 8x8 grids. The first grid
# is used as the Y channel below.
grad_norms_yuv = torch.tensor(
    [
        # grid 0
        [
            [6.8752e-04, 1.0005e-03, 8.8672e-04, 8.1265e-04, 7.7627e-04, 5.3197e-04, 2.9156e-04, 1.4191e-04],
            [7.6486e-04, 7.0540e-04, 7.1268e-04, 6.7172e-04, 5.6181e-04, 3.7691e-04, 2.3673e-04, 1.5010e-04],
            [8.0659e-04, 7.2506e-04, 7.9096e-04, 6.9572e-04, 5.0154e-04, 3.1885e-04, 2.1427e-04, 1.6550e-04],
            [7.9612e-04, 6.7133e-04, 6.5953e-04, 5.8854e-04, 4.3853e-04, 3.0538e-04, 2.1175e-04, 1.4944e-04],
            [7.4625e-04, 5.3806e-04, 4.3351e-04, 4.0188e-04, 3.5539e-04, 2.7602e-04, 1.9974e-04, 1.4987e-04],
            [5.0461e-04, 3.5121e-04, 2.7269e-04, 2.6673e-04, 2.4662e-04, 2.0697e-04, 1.6101e-04, 1.5607e-04],
            [2.7449e-04, 2.0482e-04, 1.9560e-04, 1.9627e-04, 1.7156e-04, 1.3907e-04, 1.2345e-04, 1.6210e-04],
            [1.3655e-04, 1.2796e-04, 1.3620e-04, 1.4445e-04, 1.5163e-04, 1.3441e-04, 1.0513e-04, 1.0732e-04],
        ],
        # grid 1
        [
            [3.6655e-04, 2.3057e-04, 1.7865e-04, 1.8027e-04, 1.6247e-04, 1.1040e-04, 6.7999e-05, 7.4597e-05],
            [3.8346e-04, 2.2649e-04, 1.6651e-04, 1.3307e-04, 1.0197e-04, 6.6842e-05, 5.2780e-05, 8.2484e-05],
            [2.7195e-04, 1.7968e-04, 1.3611e-04, 9.7587e-05, 6.4875e-05, 4.3859e-05, 4.0964e-05, 6.0591e-05],
            [1.6147e-04, 1.1255e-04, 8.4427e-05, 5.8382e-05, 3.6385e-05, 2.4421e-05, 2.7155e-05, 3.5520e-05],
            [1.1307e-04, 8.1603e-05, 5.9668e-05, 3.9579e-05, 2.4929e-05, 1.7698e-05, 2.1803e-05, 2.5623e-05],
            [1.0168e-04, 6.4970e-05, 4.5855e-05, 3.1443e-05, 1.9750e-05, 1.8425e-05, 2.3671e-05, 2.8941e-05],
            [9.7706e-05, 5.5146e-05, 3.9245e-05, 3.0614e-05, 2.4767e-05, 2.0179e-05, 2.4544e-05, 3.2968e-05],
            [9.7746e-05, 6.4203e-05, 4.7085e-05, 4.9695e-05, 4.7705e-05, 3.3300e-05, 3.0521e-05, 2.8349e-05],
        ],
        # grid 2
        [
            [3.3019e-04, 2.4871e-04, 1.8138e-04, 1.3581e-04, 9.3264e-05, 7.3249e-05, 8.1976e-05, 1.0235e-04],
            [3.8119e-04, 2.0982e-04, 1.2904e-04, 8.5882e-05, 5.9820e-05, 4.8262e-05, 6.1512e-05, 1.1707e-04],
            [3.0196e-04, 1.7015e-04, 1.1067e-04, 6.9676e-05, 4.1495e-05, 3.7797e-05, 4.9818e-05, 9.3241e-05],
            [1.9895e-04, 1.3512e-04, 1.0428e-04, 6.9022e-05, 3.6996e-05, 3.4375e-05, 3.8280e-05, 5.0030e-05],
            [1.6139e-04, 1.1244e-04, 8.7535e-05, 6.2750e-05, 4.0694e-05, 3.2123e-05, 3.2318e-05, 3.3000e-05],
            [1.2854e-04, 9.0579e-05, 6.8180e-05, 5.0702e-05, 3.7586e-05, 2.6411e-05, 3.1091e-05, 3.9480e-05],
            [1.3134e-04, 8.2332e-05, 5.7584e-05, 3.9073e-05, 2.8568e-05, 2.1731e-05, 3.1127e-05, 5.1141e-05],
            [1.1424e-04, 8.4455e-05, 6.8010e-05, 4.6614e-05, 2.6990e-05, 2.1787e-05, 3.3424e-05, 4.4609e-05],
        ],
    ]
)

# Number of chunks to select out of the 64 entries of the first (Y) grid.
nsend = 8
print("K: ", nsend, ", CR: ", nsend / 64)

# Flatten the Y grid to a vector of 64 values.
grad_norms_y = grad_norms_yuv[0].flatten()

# Indices of the `nsend` largest values, first in topk order, then ascending.
top_chunks_indices = grad_norms_y.topk(nsend).indices
print(top_chunks_indices)
top_chunks_indices = top_chunks_indices.sort().values
print(top_chunks_indices)

# 8x8 integer mask with ones at the selected positions.
bitmap = torch.zeros_like(grad_norms_y, dtype=torch.int)
bitmap[top_chunks_indices] = 1
bitmap = bitmap.reshape(8, 8)

print(bitmap)


In [None]:
%matplotlib inline

from models.model import plot_channels
# Elsewhere in this notebook the result of fastseg.main(mode="probe") is a
# dict whose "grads_norm" entry holds the tensor to plot. Accept both that dict
# and a bare tensor, so the call does not fail on `dict.cpu()`.
grads_norm_to_plot = (
    grad_norms["grads_norm"] if isinstance(grad_norms, dict) else grad_norms
)

plot_channels(
    grads_norm_to_plot.cpu().numpy(), 
    [
        "norm DCT gradient (Y)",
        "norm DCT gradient (U)",
        "norm DCT gradient (V)", 
    ], 
    0, 
    show=True,
    save="test.png"
)

In [None]:
# Gradient-optimized chunk selection 

%load_ext autoreload 
%autoreload 2 
%reset -f
%matplotlib inline

from multiprocessing import Pool, cpu_count
from pathlib import Path
# Force all new tensors to be created on cuda device
# torch.set_default_tensor_type('torch.cuda.FloatTensor')

import torch
import numpy as np

from models import run_model
from models.model import plot_channels


# ---- Model configurations ---------------------------------------------------
config_yolov8 = dict(
    name="yolov8",
    variant="n",
    task="detect",
    snapshot="yolov8n.pt",
    unit="mAP_50_95",
)

# Alternative FastSeg setup:
#   "variant": "large",
#   "snapshot": Path.home() / "data/models/fastseg/raw/large/best_checkpoint_ep172.pth"
config_fastseg = dict(
    name="fastseg",
    variant="small",
    snapshot=Path.home() / "data/models/fastseg/raw/small/best_checkpoint_ep171.pth",
    unit="mean_iu",
)

# Configuration that is actually run.
config = config_fastseg

# ---- LVC parameter sweep ----------------------------------------------------
color_space = "yuv"
K = [2, 4, 8, 16, 32, 64, 128, 192] #range(8, 9)

# Settings shared by every sweep point; "cr" is filled in per K below.
_lvc_params_template = {
    "packet_loss": None,
    "seed": 42,
    "mode": 444,
    "cr": None,
    "csnr_db": "inf",
    "estimator": "zf",
    "nchunks": 64,
    "color_space": color_space,
}
# One independent parameter dict per K, with cr = K / (64 * 3).
lvc_params = [{**_lvc_params_template, "cr": k / (64 * 3)} for k in K]

# ---- Run settings -----------------------------------------------------------
# device = set_device("cuda:0")
num_batches = None
batch_size = 16
num_batches_probe = 8
batch_size_probe = 8

do_print = False
show_plots = False

# Common stem of all output file names of this experiment.
file_name = "{}{}_{}_probe{}nb{}bs".format(
    config["name"],
    config["variant"],
    color_space,
    num_batches_probe,
    batch_size_probe,
)

# Setup for parallel processing: one worker process per GPU, never more
# workers than there are parameter sets.
ngpus = min(len(lvc_params), 1)
ncpus = cpu_count()

# Split the CPU ids, the K values and the parameter sets into ngpus groups.
cpus = np.array_split(range(ncpus), ngpus)
K_groups = np.array_split(K, ngpus)
lvc_params_groups = np.array_split(lvc_params, ngpus)

# Per-worker argument lists; every worker receives the same settings.
ranks = range(ngpus)
devices = [f"cuda:{rank}" for rank in ranks]
configs = [config for _ in ranks]
color_spaces = [color_space for _ in ranks]
num_workers_groups = [0 for _ in ranks]
num_batches_groups = [num_batches for _ in ranks]
batch_size_groups = [batch_size for _ in ranks]
num_batches_probe_groups = [num_batches_probe for _ in ranks]
batch_size_probe_groups = [batch_size_probe for _ in ranks]
do_print_groups = [do_print for _ in ranks]
show_plots_groups = [show_plots for _ in ranks]

results_groups = []

# One positional-argument tuple per worker, in the order passed to run_model.
worker_args = zip(
    configs,
    devices,
    lvc_params_groups,
    color_spaces,
    ranks,
    cpus,
    num_workers_groups,
    num_batches_groups,
    batch_size_groups,
    num_batches_probe_groups,
    batch_size_probe_groups,
    do_print_groups,
    show_plots_groups,
)

with Pool(processes=ngpus) as pool:
    results_groups = pool.starmap(run_model, worker_args)

# Persist the raw results, then plot the gradient norms of the first group.
torch.save(results_groups, f"{file_name}.pt")

channel_titles = [
    f"norm DCT gradient ({color_space[i].upper()})" for i in range(3)
]
plot_channels(
    results_groups[0]["grads_norm"],
    channel_titles,
    0,
    show=True,
    save=f"{file_name}_grads_norm.png"
)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Collect the scores of every parameter set, keyed by the number of sent
# chunks K.
res = {}
score_orig = []
unit = config["unit"]

# The parameter sets are built with cr = K / (nchunks * 3), so recovering K
# needs the factor 3 as well. round() (instead of int()) also protects against
# floating-point results such as 127.999... being truncated to 127.
n_channels = 3

for results_group in results_groups:
    score_orig.append(results_group["orig"][unit])

    for params, results_lvc, results_lvc_g in zip(
        results_group["lvc_params"],
        results_group["lvc"],
        results_group["lvc_g"],
    ):
        k = round(params["cr"] * params["nchunks"] * n_channels)
        res[k] = {
            "score": results_lvc[unit],
            "score_g": results_lvc_g[unit],
        }

# Order by K and convert the scores to percent.
res = {k: res[k] for k in sorted(res.keys())}
ks = list(res.keys())
scores = [val["score"] * 100 for val in res.values()]
scores_g = [val["score_g"] * 100 for val in res.values()]

print(score_orig)

# Top panel: absolute scores with the reference score as a horizontal line.
plt.figure()
plt.subplot(211)
plt.plot(ks, scores, label="base")
plt.plot(ks, scores_g, label="grad")
plt.plot(ks, [score_orig[0] * 100] * len(res), label="reference")
plt.xlabel("K")
plt.ylabel(config["unit"])
plt.legend()

# Bottom panel: difference between the gradient-based and the base scores.
plt.subplot(212)
plt.plot(ks, [0.0] * len(res), label="zero")
plt.plot(ks, np.array(scores_g) - np.array(scores), label="grad - base")
plt.xlabel("K")
plt.ylabel(f"{config['unit']}_grad - {config['unit']}_base")
plt.legend()  # the labels above were previously never shown

plt.savefig(f"{file_name}_score.png")
plt.show()

In [None]:
# Reference score (scaled to percent here) next to the last entries of
# `scores` and `scores_g`, which are already in percent.
print(score_orig[0] * 100, scores[-1], scores_g[-1])


In [None]:
import torch

# Save `results_groups` under the experiment's file name (same target as the
# save at the end of the chunk-selection cell).
# torch.save(results_groups, "res2.pt")
torch.save(results_groups, f"{file_name}.pt")