In [None]:
%load_ext autoreload 
%autoreload 2 

from pathlib import Path

import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.cm as cm

from torch.nn import NLLLoss, LogSoftmax
from torchvision import transforms
from PIL import Image

from lvc import create_dct
from test_nn import get_imagenet_labels, WEIGHTS, Dct, Idct
from probe import probe
from utils import set_device

# --- Experiment configuration -------------------------------------------------
device = set_device()  # device handle returned by the project helper

norm = "ortho"

# Feature switches (note: show_gradients is reassigned further down in this cell).
show_dct, show_gradients, do_dct_test = False, True, False

# Image geometry as 3-tuples; only index 1 is read by the preprocessing below.
img_size = (3, 256, 256)  # size of resized images
inp_size = (3, 224, 224)  # size of cropped images fed into NN

N = 1  # number of input images kept from the glob result
dct_size = 16  # only referenced by the commented-out create_dct(...) call

# Setup data
# NOTE(review): absolute, machine-specific path -- adjust when running elsewhere.
inp_dir = '/home/kubouch/data/imagenet_subsets/subset_10class_1000perclass/val'
# inp_dir = '/home/jakub/pictures/kodim/raw'
glob = '*/*.jpg'
# glob = '*23.png'

# Sort before slicing: Path.glob yields matches in arbitrary (filesystem) order,
# so an unsorted `[:N]` could select different images from one run to the next.
pics = [str(img_path) for img_path in sorted(Path(inp_dir).glob(glob))[:N]]
if not pics:
    # Fail here with a clear message instead of an IndexError at `pics[0]` later.
    raise FileNotFoundError(
        f"No images matching {glob!r} found under {inp_dir!r} (N={N})"
    )

labels = get_imagenet_labels(pics)  # np.zeros(len(pics))
batch_size = min(len(pics), 32)

# RGB luma weights (BT.601 coefficients); only consumed by the commented-out
# create_dct(...) call further down.
w_human = torch.tensor([0.299, 0.587, 0.114])

# Try block-based DCT
show_gradients = False  # replaces the True assigned in the configuration above

# Resize -> center crop -> tensor conversion. The colour-space conversion and the
# ImageNet mean/std normalisation entries are kept commented out.
preprocess = transforms.Compose([
    transforms.Resize(img_size[1]),
    transforms.CenterCrop(inp_size[1]),
    transforms.ToTensor(),
    # RgbToYcbcr(W),
    #  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the first selected picture as RGB and run it through the pipeline.
with Image.open(pics[0]) as first_pic:
    img = preprocess(first_pic.convert("RGB"))

# dct_enc, dct_dec = create_dct(
#     (dct_size, dct_size),
#     dct_size * dct_size,
#     w_human,
#     device,
#     False,
#     do_print=False,
# )

# Load the pretrained classifier that is handed to probe() below.
model_name = "alexnet"
model = torch.hub.load(
    "pytorch/vision:v0.13.1", model_name, weights=WEIGHTS[model_name]
)
# nn.Module instances start in training mode; switch to inference mode so that
# AlexNet's Dropout layers are disabled and repeated runs yield the same gradients.
# NOTE(review): redundant if probe() already calls model.eval() -- confirm.
model.eval()

# Activation and loss objects passed to probe(); LogSoftmax followed by NLLLoss
# is equivalent to cross-entropy on raw logits.
# NOTE(review): presumably probe() applies them in that order -- confirm.
activ = LogSoftmax(dim=1)
loss_func = NLLLoss()

# probe() is a project helper; its result is indexed as grads[0] when plotting.
grads = probe(
    model,
    pics,
    labels,
    preprocess,
    Dct(),
    Idct(),
    activ,
    loss_func,
    device=device,
)

# Round-trip the preprocessed image through Dct -> Idct and report the shapes.
print("Img size: ", img.shape)
dct = Dct()(img)
dct_blocks = dct[0][0]
# NOTE(review): assumes dim 0 of dct_blocks stacks the Y, U and V parts in that
# order, so chunk(3) separates them -- confirm against the Dct implementation.
dct_y, dct_u, dct_v = dct_blocks.chunk(3)
print("DCT shape: ", dct_blocks.shape)
print("DCT Y shape: ", dct_y.shape)
print("DCT U shape: ", dct_u.shape)
print("DCT V shape: ", dct_v.shape)
idct = Idct()(dct)
print("IDCT shape: ", idct.shape)

# Reconstruction error as mean squared error: square element-wise, then average.
# (`.mean().square()` would instead square the mean error, letting positive and
# negative deviations cancel and hiding a poor reconstruction.)
print("Reconstruction MSE: ", (idct - img).square().mean())

idct_img = transforms.ToPILImage()(idct)

# One titled figure per image so the three outputs can be told apart.
# NOTE(review): ToPILImage multiplies float tensors by 255 and casts to uint8,
# i.e. it expects values in [0, 1]. If probe() returns raw gradients they will
# not display meaningfully without rescaling -- confirm what grads[0] holds.
for title, picture in (
    ("IDCT reconstruction", idct_img),
    ("Preprocessed input", transforms.ToPILImage()(img)),
    ("probe() output: grads[0]", transforms.ToPILImage()(grads[0])),
):
    fig, ax = plt.subplots()
    ax.imshow(picture)
    ax.set_title(title)

# cmap = cm.jet
# gmin = g_yuv_mean_blocks.min()
# gmax = g_yuv_mean_blocks.max()

# dct_y_mean = dct_y.mean(dim=0)
# dct_u_mean = dct_u.mean(dim=0)
# dct_v_mean = dct_v.mean(dim=0)

# plt.figure()
# plt.subplot(311)
# plt.imshow(dct_y_mean, cmap=cmap)#, norm=plt.Normalize(gmin, gmax))
# plt.colorbar(fraction=0.045)
# plt.title("sensitivity map -- mean DCT gradient (Y)")

# plt.subplot(312)
# plt.imshow(dct_u_mean, cmap=cmap)#, norm=plt.Normalize(gmin, gmax))
# plt.colorbar(fraction=0.045)
# plt.title("sensitivity map -- mean DCT gradient (U)")

# plt.subplot(313)
# plt.imshow(dct_v_mean, cmap=cmap)#, norm=plt.Normalize(gmin, gmax))
# plt.colorbar(fraction=0.045)
# plt.title("sensitivity map -- mean DCT gradient (V)")

# plt.show()

In [None]:
import torch

# Autograd sanity check: does the gradient of a sum over `topk` values flow back
# to the source tensor?

# Leaf tensor that tracks gradients.
t = torch.rand(5).requires_grad_(True)
t.retain_grad()  # no effect on a leaf tensor; leaves always keep .grad
print(t)

# Take the three largest entries; the result is a non-leaf tensor in the graph.
ttop, ttop_idx = torch.topk(t, k=3)
ttop.requires_grad_(True)  # already True, inherited from `t`
ttop.retain_grad()  # required so the non-leaf's .grad is kept after backward()
print(ttop)

y = torch.sum(ttop)
print(y)

# Back-propagate, then show the gradient on the selection and on the source.
y.backward()
print(ttop.grad, t.grad, sep="\n")