In [None]:
!pip install piqa



In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import linalg

import torch.nn as nn
import torchvision
import torch

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

device

device(type='cuda')

In [None]:
# Mount Google Drive at /content/drive; the dataset, checkpoint and example
# image paths used later in this notebook all live under that mount point.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# DCT Helper Function(s)


def dct(x, norm=None):
    """
    Type-II discrete cosine transform along the last dimension of ``x``.

    The transform is computed with a single complex FFT of a re-ordered copy
    of the signal; the input tensor is not modified.

    :param x: input tensor; every leading dimension is treated as a batch
    :param norm: ``None`` for the unnormalised transform, or ``'ortho'`` for
        the orthonormal scaling
    :return: tensor of DCT coefficients with the same shape as ``x``
    """
    original_shape = x.shape
    length = original_shape[-1]
    rows = x.contiguous().view(-1, length)

    # Even-indexed samples first, then the odd-indexed samples reversed.
    reordered = torch.cat([rows[:, ::2], rows[:, 1::2].flip([1])], dim=1)
    spectrum = torch.view_as_real(torch.fft.fft(reordered, dim=1))
    real_part, imag_part = spectrum[:, :, 0], spectrum[:, :, 1]

    # Twiddle factors exp(-i*pi*k / (2N)), applied to the FFT output.
    phase = - torch.arange(length, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * length)
    coeffs = real_part * torch.cos(phase) - imag_part * torch.sin(phase)

    if norm == 'ortho':
        coeffs[:, 0] /= np.sqrt(length) * 2
        coeffs[:, 1:] /= np.sqrt(length / 2) * 2

    return 2 * coeffs.view(*original_shape)


def idct(X, norm=None):
    """
    Inverse of ``dct`` along the last dimension (a scaled DCT-III), arranged
    so that ``idct(dct(x, norm), norm)`` reproduces ``x``.

    The input tensor is not modified.

    :param X: tensor of DCT coefficients; leading dimensions are batch dims
    :param norm: ``None`` or ``'ortho'``; must match the value given to ``dct``
    :return: reconstructed signal with the same shape as ``X``
    """
    original_shape = X.shape
    length = original_shape[-1]

    # The division creates a new tensor, so the in-place scaling below is safe.
    halved = X.contiguous().view(-1, original_shape[-1]) / 2

    if norm == 'ortho':
        halved[:, 0] *= np.sqrt(length) * 2
        halved[:, 1:] *= np.sqrt(length / 2) * 2

    phase = torch.arange(original_shape[-1], dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * length)
    cos_w, sin_w = torch.cos(phase), torch.sin(phase)

    real_in = halved
    imag_in = torch.cat([halved[:, :1] * 0, -halved.flip([1])[:, :-1]], dim=1)

    real_out = real_in * cos_w - imag_in * sin_w
    imag_out = real_in * sin_w + imag_in * cos_w

    # Pack as (..., 2) so the pair can be reinterpreted as complex numbers.
    packed = torch.stack([real_out, imag_out], dim=2)
    signal = torch.fft.irfft(torch.view_as_complex(packed), n=packed.shape[1], dim=1)

    # Undo the even/odd re-ordering performed by the forward transform.
    restored = signal.new_zeros(signal.shape)
    restored[:, ::2] += signal[:, :length - (length // 2)]
    restored[:, 1::2] += signal.flip([1])[:, :length // 2]

    return restored.view(*original_shape)


def dct2(x):
    """Orthonormal 2-D DCT over the last two dimensions of ``x``."""
    along_last = dct(x, norm='ortho')
    along_second_last = dct(torch.transpose(along_last, -1, -2), norm='ortho')
    return torch.transpose(along_second_last, -1, -2)

def idct2(X):
    """Orthonormal 2-D inverse DCT over the last two dimensions of ``X``."""
    along_last = idct(X, norm='ortho')
    along_second_last = idct(torch.transpose(along_last, -1, -2), norm='ortho')
    return torch.transpose(along_second_last, -1, -2)

In [None]:
# More Helper Function(s)

# split a channels-first image into a grid of square blocks
def partition_inplace(img, block_size=8, crop=True):
    """
    Tile a ``(c, x, y)`` NumPy image into ``block_size`` x ``block_size`` blocks.

    :param img: array of shape ``(c, x, y)``
    :param block_size: side length of each block
    :param crop: when True, trailing rows/columns that do not fill a whole
        block are dropped; when False, both spatial sizes must already be
        multiples of ``block_size`` (checked with ``assert``)
    :return: a new array of shape
        ``(x // block_size, y // block_size, c, block_size, block_size)``
        with the dtype of ``img``
    """
    if crop:
        _, height, width = img.shape
        img = img[..., :height - height % block_size, :width - width % block_size]
    else:
        assert img.shape[-1] % block_size == 0
        assert img.shape[-2] % block_size == 0
    channels, height, width = img.shape
    grid_rows, grid_cols = height // block_size, width // block_size
    # (c, rows, b, cols, b) -> (rows, cols, c, b, b); copy() returns fresh storage.
    tiled = img.reshape(channels, grid_rows, block_size, grid_cols, block_size)
    return tiled.transpose(1, 3, 0, 2, 4).copy()

# stitch a grid of blocks back into full images
def reduce_inplace(img_partitions):
    """
    Reassemble a batch of block grids into images.

    :param img_partitions: tensor of shape ``(n, xb, yb, c, b, b)``
    :return: a new tensor of shape ``(n, c, xb * b, yb * b)`` in which block
        ``(i, j)`` occupies rows ``i*b:(i+1)*b`` and columns ``j*b:(j+1)*b``
    """
    n, xb, yb, c, b, _ = img_partitions.shape
    # (n, xb, yb, c, b, b) -> (n, c, xb, b, yb, b) so each block row/column
    # becomes adjacent to its position in the grid.
    interleaved = img_partitions.permute(0, 3, 1, 4, 2, 5)
    merged = interleaved.clone(memory_format=torch.contiguous_format)
    return merged.view(n, c, xb * b, yb * b)


# list the blocks of an image in row-major order
def partition(img, block_size=8, crop=True):
    """
    Cut a ``(c, x, y)`` image into a flat list of blocks.

    :param img: array or tensor of shape ``(c, x, y)``
    :param block_size: side length of each block
    :param crop: when True, rows/columns beyond the last whole block are dropped
    :return: list of slices of ``img``, each ``(c, block_size, block_size)``
        when the (possibly cropped) image divides evenly, ordered row by row
    """
    if crop:
        _, rows, cols = img.shape
        img = img[..., :rows - rows % block_size, :cols - cols % block_size]
    _, rows, cols = img.shape
    blocks = []
    for top in range(0, rows, block_size):
        for left in range(0, cols, block_size):
            blocks.append(img[..., top:top + block_size, left:left + block_size])
    return blocks


# converts rgb to ycbcr colorspace
def rgb_ycbcr(im):
    """
    Convert an RGB image to YCbCr with the chroma channels offset by 128.

    :param im: array-like whose last axis holds the (R, G, B) values
    :return: float64 array of the same shape, rounded to whole numbers
    """
    # The original called im.astype(np.double) and threw the result away;
    # keep the float copy so the conversion never depends on the input dtype.
    rgb = np.asarray(im).astype(np.double)
    xform = np.array([[.299, .587, .114], [-.1687, -.3313, .5], [.5, -.4187, -.0813]])
    ycbcr = rgb.dot(xform.T)
    # Centre Cb and Cr on 128; Y is left as computed.
    ycbcr[..., [1, 2]] += 128
    return np.round(ycbcr)


# converts the ycbcr colorspace back to rgb
def ycbcr_rgb(im):
    """
    Convert a YCbCr image (chroma offset by 128) back to RGB.

    :param im: NumPy array of shape ``(x, y, 3)`` holding (Y, Cb, Cr)
    :return: float64 array of the same shape, clipped to [0, 255] and rounded
    """
    shifted = im.astype(np.double)
    # Remove the 128 offset from the two chroma channels before mixing.
    shifted[:, :, [1, 2]] -= 128
    to_rgb = np.array([[1, 0, 1.402], [1, -0.34414, -.71414], [1, 1.772, 0]])
    rgb = np.clip(shifted.dot(to_rgb.T), 0, 255)
    return np.round(rgb)

# zig-zag encoder
def zz_encode(block):
    """
    Flatten square blocks into zig-zag scan order.

    The scan order is derived from the block size, so blocks other than 8x8
    are handled too; for 8x8 blocks it is the usual JPEG zig-zag sequence.

    :param block: tensor of shape ``(n, c, d, d)``
    :return: new tensor of shape ``(n, c, d * d)`` with the same dtype and
        device, where entry ``k`` is the ``k``-th coefficient of the scan
    """
    n, c, d, _ = block.shape
    # Visit anti-diagonals (row + col constant) in turn. Even diagonals are
    # walked with the column increasing, odd ones with the row increasing.
    scan = sorted(
        ((row, col) for row in range(d) for col in range(d)),
        key=lambda rc: (rc[0] + rc[1], rc[0] if (rc[0] + rc[1]) % 2 else rc[1]),
    )
    rows = [row for row, _ in scan]
    cols = [col for _, col in scan]
    encoded = torch.zeros((n, c, d * d), dtype=block.dtype, device=block.device)
    encoded[...] = block[..., rows, cols]
    return encoded

# zig-zag decoder
def zz_decode(encoded):
    """
    Rebuild square blocks from zig-zag ordered coefficient vectors.

    This is the inverse of ``zz_encode``. The scan order is derived from the
    vector length, so lengths other than 64 work as long as they are perfect
    squares.

    :param encoded: tensor of shape ``(n, c, d * d)``
    :return: new tensor of shape ``(n, c, d, d)`` with the same dtype and device
    :raises ValueError: if the last dimension is not a perfect square
    """
    n, c, length = encoded.shape
    d = int(round(length ** 0.5))
    if d * d != length:
        raise ValueError(f'zig-zag vector length must be a perfect square, got {length}')
    # Visit anti-diagonals (row + col constant) in turn. Even diagonals are
    # walked with the column increasing, odd ones with the row increasing.
    scan = sorted(
        ((row, col) for row in range(d) for col in range(d)),
        key=lambda rc: (rc[0] + rc[1], rc[0] if (rc[0] + rc[1]) % 2 else rc[1]),
    )
    rows = [row for row, _ in scan]
    cols = [col for _, col in scan]
    decoded = torch.zeros((n, c, d, d), dtype=encoded.dtype, device=encoded.device)
    decoded[..., rows, cols] = encoded
    return decoded

In [None]:
# LOSS FUNCTION(S)

from piqa import SSIM, MS_SSIM
from piqa.utils.functional import gaussian_kernel

interval = 2

def zz_batch_encode(tensor_input):
    """
    Zig-zag encode every block of a batch of block grids.

    :param tensor_input: tensor of shape ``(n, xb, yb, c, b, b)``
    :return: tensor of shape ``(n, xb * yb, c, b * b)``; the blocks are
        numbered row by row, i.e. block ``(i, j)`` lands at index ``i * yb + j``
    """
    n, xb, yb, c, b, _ = tensor_input.shape
    encoded = torch.zeros((n, xb * yb, c, b * b), dtype=tensor_input.dtype, device=tensor_input.device)
    for flat_index in range(xb * yb):
        grid_row, grid_col = divmod(flat_index, yb)
        encoded[:, flat_index, ...] = zz_encode(tensor_input[:, grid_row, grid_col, ...])
    return encoded

def reconstruct_img(x, table):
    """
    Quantise block DCT coefficients with ``table`` and rebuild the image.

    :param x: block coefficients, shape ``(n, xb, yb, c, b, b)``
    :param table: quantisation tables, shape ``(n, t, b, b)``; when ``t == 2``
        the last table is repeated so that three channels are covered
    :return: ``(zz_quantized, x_reconstruct)``: the zig-zag encoded,
        *un-rounded* quotients ``x / table``, and the image rebuilt from the
        rounded quotients, shifted by +128 and limited to [0, 255]
    """
    # batch, channels, x, y
    if table.shape[1] == 2:
        table = torch.cat((table, table[:, -1:, :, :]), dim=1)
    # Insert two singleton grid axes so the table broadcasts over all blocks.
    table = table[:, None, None]
    scaled = x / table
    dequantized = torch.round(scaled) * table
    pixels = torch.clamp(idct2(dequantized) + 128, 0.0, 255.0)
    x_reconstruct = reduce_inplace(pixels)
    zz_quantized = zz_batch_encode(scaled)
    return zz_quantized, x_reconstruct


class RateLoss(nn.Module):
    """
    Sparsity-style rate penalty on zig-zag ordered coefficients.

    The 64 zig-zag positions are split into a low (0-5), middle (6-27) and
    high (28-63) frequency band. Within a band every anti-diagonal of the
    8x8 block gets its own scale (1, 1 + interval, 1 + 2*interval, ...), and
    each band is multiplied by one of the three factors passed to the
    constructor. The loss is the mean L2 norm of the weighted vectors.
    """

    # Per-diagonal scales; built once when the class body runs, from the
    # file-level `interval`, and placed on the file-level `device`.
    band_scales = [i for i in range(1, 16*interval, interval)]
    lbf_scale = torch.tensor(np.array([band_scales[0]]+
                                      [band_scales[1]]*2+
                                      [band_scales[2]]*3, 
                                      dtype=np.float64), 
                             device=device)
    mbf_scale = torch.tensor(np.array([band_scales[3]]*4+
                                      [band_scales[4]]*5+
                                      [band_scales[5]]*6+
                                      [band_scales[6]]*7, 
                                      dtype=np.float64),
                             device=device)
    hbf_scale = torch.tensor(np.array([band_scales[7]]*8+
                                      [band_scales[8]]*7+
                                      [band_scales[9]]*6+
                                      [band_scales[10]]*5+
                                      [band_scales[11]]*4+
                                      [band_scales[12]]*3+
                                      [band_scales[13]]*2+
                                      [band_scales[14]], 
                                      dtype=np.float64),
                             device=device)
    
    def __init__(self, band_scales):
        """
        :param band_scales: sequence of three multipliers applied to the low,
            middle and high frequency bands respectively
        """
        super(RateLoss, self).__init__()
        # Instance attribute; it shadows the class-level list of the same name.
        self.band_scales = band_scales

    def forward(self, z):
        """
        :param z: tensor whose last dimension holds 64 zig-zag ordered values
        :return: scalar tensor, the mean over all leading dimensions of the
            L2 norm of the band-weighted vectors
        """
        # Build one 64-entry weight vector and multiply out-of-place. The
        # original scaled `z` with `*=`, which silently modified the tensor
        # owned by the caller.
        weights = torch.cat([
            self.band_scales[0] * self.lbf_scale,
            self.band_scales[1] * self.mbf_scale,
            self.band_scales[2] * self.hbf_scale,
        ]).to(z.device)
        weighted = z * weights
        # enforce_sparsity = torch.mean(torch.linalg.norm(z, ord=2, dim=1))
        enforce_sparsity = torch.mean(torch.linalg.norm(weighted, ord=2, dim=-1))
        return enforce_sparsity


class DistortionLoss(nn.Module):
    """
    Image-quality loss mixing a log MS-SSIM term with an L1 term.

    ``loss = alpha * (-ssim_scale * log(ms_ssim)) + (1 - alpha) * L1``, where
    the L1 term is computed on blurred images when ``blur_fn`` is given.
    """

    def __init__(self, win_size=3, blur_kernel=13, ssim_scale=1e6, alpha=0.84, n_channels=1, blur_fn=None):
        """
        :param win_size: window size handed to ``MS_SSIM``
        :param blur_kernel: accepted for compatibility; not used in this class
        :param ssim_scale: multiplier of the ``-log(ms_ssim)`` term
        :param alpha: weight of the MS-SSIM term; the L1 term gets ``1 - alpha``
        :param n_channels: channel count handed to ``MS_SSIM``
        :param blur_fn: optional callable applied to both images before L1
        """
        super().__init__()
        self.alpha = alpha
        self.ssim_scale = ssim_scale
        self.blur_fn = blur_fn
        self.l1loss = nn.L1Loss()
        ms_ssim = MS_SSIM(
            window_size=win_size,
            sigma=1e5,
            n_channels=n_channels,
            padding=True,
            value_range=255.0,
            reduction='mean',
        )
        self.criterion = ms_ssim.double()

    def forward(self, x, y):
        """
        :param x: reconstructed image batch
        :param y: reference image batch
        :return: ``(ssim, img_quality)``: the raw MS-SSIM value and the
            combined loss
        """
        ssim = self.criterion(x, y)
        ssim_loss = -1 * self.ssim_scale * torch.log(ssim)
        if self.blur_fn is None:
            l1_loss = self.l1loss(x, y)
        else:
            # Blur x first, then y, as separate calls.
            blurred_x = self.blur_fn(x)
            blurred_y = self.blur_fn(y)
            l1_loss = self.l1loss(blurred_x, blurred_y)
        return ssim, self.alpha * ssim_loss + (1 - self.alpha) * l1_loss


class QuantizationLoss(nn.Module):
    """Weighted sum of a rate term and a distortion term."""

    def __init__(self, rate_weight, distortion_weight):
        """
        :param rate_weight: multiplier of the rate term
        :param distortion_weight: multiplier of the distortion term
        """
        super().__init__()
        self.rate_weight = rate_weight
        self.distortion_weight = distortion_weight

    def forward(self, rate, distortion):
        """
        Print both weighted terms, then return their sum.

        :param rate: scalar tensor
        :param distortion: scalar tensor
        :return: ``rate_weight * rate + distortion_weight * distortion``
        """
        weighted_rate_value = self.rate_weight * rate.item()
        print('rate:', weighted_rate_value)
        weighted_distortion_value = self.distortion_weight * distortion.item()
        print('distortion:', weighted_distortion_value)
        total = self.rate_weight * rate + self.distortion_weight * distortion
        return total


In [None]:
# OPTIMIZERS

def temp_update(t, beta):
    """Cooling step: return the temperature ``t`` scaled by ``beta``."""
    return t * beta


class AnnealingOptimizer:
    """
    Heuristic scheduler for the rate and distortion loss weights.

    After each epoch ``forward`` compares the current SSIM and entropy
    estimate with the values stored from the previous call, adjusts the two
    weights, and lowers the temperature ``t`` the first time SSIM rises
    above 0.98 after having been at or below it.
    """

    def __init__(self, rate_weight, distortion_weight, t=1e6, beta=0.98):
        """
        :param rate_weight: initial rate weight
        :param distortion_weight: initial distortion weight
        :param t: initial temperature, scaling every weight increase
        :param beta: factor applied to ``t`` on each cooling step
        """
        self.rate_weight = rate_weight
        self.distortion_weight = distortion_weight
        self.beta = beta
        self.t = t
        self.inflection = True

    def set_original_entropy(self, rate):
        """Store ``rate.item()`` as the reference entropy for later updates."""
        self.original_entropy = rate.item()

    def forward(self, ssim, entropy_estimate, epoch):
        """
        Update and return the loss weights.

        ``set_original_entropy`` must have been called, and the first call
        must use ``epoch == 0`` so the previous-step values get initialised.

        :param ssim: scalar tensor with the current SSIM
        :param entropy_estimate: scalar tensor with the current rate estimate
        :param epoch: epoch index
        :return: ``(rate_weight, distortion_weight)``
        """
        rate_weight_on_entry = self.rate_weight

        if epoch == 0:
            self.prev_ssim = ssim
            self.prev_entropy_estimate = entropy_estimate

        # Quality floor: below 0.97 the distortion weight is reset to 1e5.
        if ssim < 0.97:
            self.distortion_weight = 1e5

        # Quality dropped since the last call: push the distortion weight up.
        if ssim - self.prev_ssim < 0:
            self.distortion_weight += (1 - ssim.item()) * self.t

        normalized_entropy = entropy_estimate / self.rate_weight
        if normalized_entropy - self.prev_entropy_estimate > 0:
            # Rate got worse: raise the rate weight relative to the reference.
            self.rate_weight += normalized_entropy.item() / self.original_entropy * self.t
        else:
            # Otherwise shrink it tenfold, but never below 1.
            self.rate_weight = max(self.rate_weight * 1e-1, 1)

        # sufficient condition to decrease temperature
        if ssim > 0.98:
            self.distortion_weight = 1
            if self.inflection:
                self.t = temp_update(self.t, self.beta)
                self.inflection = False
        else:
            self.inflection = True

        self.prev_ssim = ssim
        self.prev_entropy_estimate = entropy_estimate / rate_weight_on_entry

        return self.rate_weight, self.distortion_weight

In [None]:
# x = torch.randint(256, (1, 1, 8, 8), dtype=torch.float)
# y = torch.randint(256, (1, 1, 8, 8), dtype=torch.float)
# plt.figure()
# plt.imshow(np.squeeze(x.numpy(), axis=(0, 1)))
# plt.figure()
# plt.imshow(np.squeeze(y.numpy(), axis=(0, 1)))
# weights = torch.rand(5)
# criterion = MS_SSIM(window_size=3, sigma=1e10, n_channels=1, padding=True, value_range=255.0)
# criterion(x, y)

In [None]:
# # # LOSS FUNCTION TEST BLOCK

# from scipy import fftpack

# def idct2d(a):
#     # https://inst.eecs.berkeley.edu/~ee123/sp16/Sections/JPEG_DCT_Demo.html
#     return fftpack.idct(fftpack.idct(a, axis=0, norm='ortho'), axis=1, norm='ortho')
# def dct2d(a):
#     # https://inst.eecs.berkeley.edu/~ee123/sp16/Sections/JPEG_DCT_Demo.html
#     return fftpack.dct(fftpack.dct(a, axis=0, norm='ortho'), axis=1, norm='ortho')

# block1 = np.random.randint(0, 256, size=(8, 8))
# qtable1 = np.ones((8, 8)) * 
# qtable1 = np.expand_dims(np.expand_dims(np.expand_dims(qtable1, axis=0), axis=0), axis=0)
# block1 = np.expand_dims(np.expand_dims(np.expand_dims(block1, axis=0), axis=0), axis=0)
# print(block1.shape, qtable1.shape)
# transform = torch.tensor(fftpack.dct(fftpack.dct(block1, axis=-1), axis=-2), dtype=torch.double)
# qtable1 = torch.tensor(qtable1)
# x_input = reconstruct_img(transform, qtable1)
# x_input[1], block1

In [None]:
# MODEL

def zz_encode_model(block):
    """
    Zig-zag flatten blocks for the model input.

    :param block: tensor of shape ``(n, s, c, d, d)`` (the table below is
        written for ``d == 8``)
    :return: new tensor of shape ``(n, s, c, d * d)`` where the value at
        row-major position ``k`` of a block is stored at index ``zz[k]``
    """
    # zz[k] is the zig-zag index of the coefficient at row-major position k.
    zz = [0,  1,  5,  6,  14, 15, 27, 28, 2,  4,  7,  13, 16, 26, 29, 42, 3,  8,  
          12, 17, 25, 30, 41, 43, 9,  11, 18, 24, 31, 40, 44, 53, 10, 19, 23, 32, 
          39, 45, 52, 54, 20, 22, 33, 38, 46, 51, 55, 60, 21, 34, 37, 47, 50, 56, 
          59, 61, 35, 36, 48, 49, 57, 58, 62, 63]
    n, s, channels, d, _ = block.shape
    encoded = torch.zeros((n, s, channels, d * d), dtype=block.dtype, device=block.device)
    row_major = block[..., :d, :d].reshape(n, s, channels, d * d)
    encoded[..., zz[:d * d]] = row_major
    return encoded


def zz_decode_model(encoded):
    """
    Rebuild square blocks from zig-zag ordered vectors (inverse of
    ``zz_encode_model``).

    :param encoded: tensor of shape ``(n, s, c, d * d)`` (the table below is
        written for ``d == 8``)
    :return: new tensor of shape ``(n, s, c, d, d)`` where the value at
        row-major position ``k`` of a block is read from index ``zz[k]``
    """
    # zz[k] is the zig-zag index of the coefficient at row-major position k.
    zz = [0,  1,  5,  6,  14, 15, 27, 28, 2,  4,  7,  13, 16, 26, 29, 42, 3,  8,  
          12, 17, 25, 30, 41, 43, 9,  11, 18, 24, 31, 40, 44, 53, 10, 19, 23, 32, 
          39, 45, 52, 54, 20, 22, 33, 38, 46, 51, 55, 60, 21, 34, 37, 47, 50, 56, 
          59, 61, 35, 36, 48, 49, 57, 58, 62, 63]
    n, s, channels, length = encoded.shape
    d = int(length ** 0.5)
    # Advanced indexing gathers into a fresh tensor; reshape only re-views it.
    gathered = encoded[..., zz[:d * d]]
    return gathered.reshape(n, s, channels, d, d)

class QTableOptimizer(nn.Module):
    """
    Network that maps sampled 8x8 DCT blocks to quantisation tables.

    Input of shape ``(b, samples, input_channels, 8, 8)`` is zig-zag
    flattened, embedded, collapsed over the sample axis with 1x1
    convolutions, reshaped to 8x8 and squashed into ``(0, max_q)``.
    """

    def __init__(self, max_q, input_channels=1, n_qtables=1, samples=32):
        """
        :param max_q: upper bound of the produced table entries
        :param input_channels: number of image channels in the input blocks
        :param n_qtables: number of tables to produce; the channel-mixing
            layer is only applied when this is greater than 1
        :param samples: number of sampled blocks per image
        """
        super().__init__()
        self.output_activation = nn.Sigmoid()
        self.input_activation = nn.Tanh()
        self.max_q = max_q
        self.qtables_out = n_qtables

        half_samples = int(samples//2)
        # Treats the sample axis as conv channels and reduces it to 1. The
        # (1, 2) strides halve the last axis twice (256 -> 128 -> 64).
        self.sample_learning = nn.Sequential(
            nn.Conv2d(samples, half_samples, kernel_size=1, stride=(1, 2)),
            self.input_activation,
            nn.BatchNorm2d(half_samples),
            nn.Conv2d(half_samples, 1, kernel_size=1, stride=(1, 2)),
            self.output_activation,
            nn.BatchNorm2d(1),
        )
        # Mixes image channels into the requested number of tables.
        self.channel_embedding = nn.Sequential(
            nn.Conv2d(input_channels, n_qtables, kernel_size=1, stride=1),
            self.output_activation,
            nn.BatchNorm2d(n_qtables),
        )
        # Expands each 64-entry zig-zag vector to 256 features.
        self.embedding_layer = nn.Sequential(
            nn.Linear(64, 256),
            self.output_activation,
            nn.BatchNorm2d(samples),
        )

    def forward(self, x):
        """
        :param x: tensor of shape ``(b, samples, input_channels, 8, 8)``
        :return: tensor of shape ``(b, n_qtables, 8, 8)`` when
            ``n_qtables > 1``, otherwise ``(b, input_channels, 8, 8)``
        """
        flat = zz_encode_model(x)                # (b, s, c, 64)
        embedded = self.embedding_layer(flat)    # (b, s, c, 256)
        pooled = self.sample_learning(embedded)  # (b, 1, c, 64)
        blocks = zz_decode_model(pooled)         # (b, 1, c, 8, 8)
        tables = torch.squeeze(blocks, dim=1)    # (b, c, 8, 8)
        if self.qtables_out > 1:
            tables = self.channel_embedding(tables)  # (b, n_qtables, 8, 8)
        return self.output_activation(tables) * self.max_q

# class QTableOptimizer(nn.Module):
#     def __init__(self, max_q, input_channels=1, n_qtables=1, samples=32):
#         super(QTableOptimizer, self).__init__()
#         self.output_activation = nn.Sigmoid()
#         self.max_q = max_q
#         self.qtables_out = n_qtables
#         self.embedding_layer = nn.Sequential(
#             nn.Linear(64, 256),
#             nn.ReLU(),
#             nn.BatchNorm2d(input_channels),
#             nn.Dropout(p=0.5),
#             nn.Linear(256, 1024),
#             nn.ReLU(),
#             nn.BatchNorm2d(input_channels),
#             nn.Dropout(p=0.5),
#             nn.Linear(1024, 1024),
#             nn.ReLU(),
#             nn.BatchNorm2d(input_channels),
#             nn.Dropout(p=0.5),
#             # nn.Linear(1024, 4096),
#             # nn.ReLU(),
#             # nn.BatchNorm2d(input_channels),
#             # nn.Linear(4096, 16384),
#             # nn.ReLU(),
#             # nn.BatchNorm2d(input_channels),
#         )
#         self.pixel_learning = nn.Sequential(
#             # nn.Linear(16384, 4096),
#             # nn.ReLU(),
#             # nn.BatchNorm2d(input_channels),
#             # nn.Linear(4096, 1024),
#             # nn.ReLU(),
#             # nn.BatchNorm2d(input_channels),
#             # nn.Linear(256, 256),
#             # nn.ReLU(),
#             # nn.BatchNorm2d(input_channels),
#             nn.Linear(1024, 1024),
#             nn.ReLU(),
#             nn.BatchNorm2d(input_channels),
#             nn.Dropout(p=0.5),
#             nn.Linear(1024, 256),
#             nn.ReLU(),
#             nn.BatchNorm2d(input_channels),
#             nn.Dropout(p=0.5),
#             nn.Linear(256, 64),
#             nn.ReLU(),
#             nn.BatchNorm2d(input_channels),
#             nn.Dropout(p=0.5),
#         )
#         self.sample_learning = nn.Sequential(
#             nn.Linear(samples, samples*4),
#             nn.ReLU(),
#             nn.BatchNorm2d(input_channels),
#             nn.Dropout(p=0.5),
#             nn.Linear(samples*4, 1),
#             nn.ReLU(),
#             nn.BatchNorm2d(input_channels),
#             nn.Dropout(p=0.5),
#         )
#         self.channel_learning = nn.Sequential(
#             nn.Linear(input_channels, n_qtables*4),
#             nn.ReLU(),
#             nn.BatchNorm2d(1), 
#             nn.Dropout(p=0.5),
#             nn.Linear(n_qtables*4, n_qtables),
#             nn.ReLU(),
#             nn.BatchNorm2d(1),
#             nn.Dropout(p=0.5),
#         )
        
#     def forward(self, x):
#         # embed dct input (-1, 1)
#         x = self.embedding_layer(x)
#         # learn q table (0, 1)
#         x = self.pixel_learning(x)
#         # combine q tables across image samples (0, 1)
#         x = self.sample_learning(torch.transpose(x, -1, -2)) # b, c, p, s
#         if self.qtables_out > 1:
#             # learn across image channels
#             y = self.channel_learning(torch.permute(x, (0, 3, 2, 1))) # b, s, p, c
#             y = torch.permute(y, (0, 3, 1, 2)) # b, c, s, p
#         else:
#             y = torch.transpose(x, -1, -2) # b, c, s, p

#         y = self.output_activation(y) * self.max_q

#         return y 

In [None]:
# # MODEL TEST BLOCK

# model = QTableOptimizer()

# x = torch.randn(1, 3, 8, 8)
# y = model(x)
# y

In [None]:
#  DATASET LOADER

from torch.utils.data import Dataset
import os, cv2, random, heapq
from PIL import Image 


def sample_img(img, n_samples=32):
    """
    Draw a random set of 8x8 blocks, biased towards high-variance blocks.

    The blocks are ranked by variance; three quarters of the samples are
    drawn from the higher-variance half and one quarter from the lower half,
    then the selection is shuffled.

    :param img: channels-first image accepted by ``partition``
    :param n_samples: number of blocks to draw; must be a multiple of 4
    :return: double tensor stacking the chosen blocks along a new first axis
    """
    assert n_samples % 4 == 0
    ranked = sorted(partition(img), key=np.var, reverse=True)
    midpoint = len(ranked) // 2
    quarter = n_samples // 4
    # Draw from the high-variance half first, then from the low-variance half.
    chosen = random.sample(ranked[:midpoint], 3 * quarter) + random.sample(ranked[midpoint:], quarter)
    random.shuffle(chosen)
    return torch.tensor(np.array(chosen), dtype=torch.double)

def read_img(file_path, convert=False):
    """
    Load an image as a channels-first float64 array centred on zero.

    :param file_path: path handed to ``cv2.imread``
    :param convert: when True, colour images are converted from RGB to YCbCr
        with ``rgb_ycbcr``; single-channel images are left as they are
    :return: array of shape ``(c, x, y)`` with 128 subtracted from every value
    :raises OSError: if OpenCV could not load the file
    """
    img = cv2.imread(file_path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise OSError(f'cv2.imread could not load image: {file_path!r}')
    if len(img.shape) > 2:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if convert:
            # rgb_ycbcr mixes along the LAST axis, so it must run while the
            # array is still channels-last; converting after moveaxis (as the
            # original did) fails with a shape error in dot().
            img = rgb_ycbcr(img)
        # channels first
        img = np.moveaxis(img, -1, 0)
    else:
        img = np.expand_dims(img, axis=0)
    img = img.astype(np.float64)
    img -= 128
    return img

def normalize(tensor):
    """Standardise ``tensor`` using the mean and std of all its elements."""
    centered = tensor - tensor.mean()
    return centered / tensor.std()

class ImageCompressionDataset(Dataset):
    """
    Single-image dataset yielding sampled DCT blocks plus a random crop.

    Every access re-reads the image, draws a fresh set of sample blocks and
    a fresh crop position.
    """

    def __init__(self, img_path, crop=(344, 344), samples=64):
        """
        :param img_path: path of the image to load
        :param crop: ``(rows, cols)`` size of the random crop
        :param samples: number of 8x8 blocks drawn by ``sample_img``
        """
        self.samples = samples
        self.img_path = img_path
        self.crop = crop

    def __len__(self):
        """The dataset always holds exactly one item."""
        return 1

    def __getitem__(self, idx):
        """
        :param idx: ignored; there is only one image
        :return: ``(freq_samples, partition_freq, img, img_path)``: the 2-D
            DCT of the sampled blocks, the 2-D DCT of the 8x8 partition of
            the crop (created on the file-level ``device``), the cropped
            image shifted back by +128, and the image path
        """
        img = read_img(self.img_path, False)
        c, x, y = img.shape
        assert x >= self.crop[0] and y >= self.crop[1]
        assert x * y / (8 * 8) >= self.samples
        spatial_samples = sample_img(img, self.samples)
        freq_samples = dct2(spatial_samples)
        cx, cy = self.crop
        # random.randint includes both end points, so the largest valid
        # offset is x - cx. The original upper bound of x - cx - 1 raised
        # ValueError for an image exactly the crop size and could never
        # select the last position.
        x1, y1 = random.randint(0, x - cx), random.randint(0, y - cy)
        x2, y2 = x1 + cx, y1 + cy
        img = img[:,x1:x2,y1:y2]
        partition_freq = dct2(torch.tensor(partition_inplace(img), 
                                           dtype=torch.double, 
                                           device=device))
        return freq_samples, partition_freq, img + 128, self.img_path

def save_test_image_color(qtable, path, n):
    """
    Re-save the image at ``path`` as a JPEG using two learned tables.

    :param qtable: tensor accepted by ``zz_encode``; the first batch entry
        must hold exactly two tables
    :param path: path of the source image
    :param n: epoch number used in the output file name
    """
    file_save = f'/content/drive/My Drive/Research/CompAlgo/training_examples_color/epoch_{n}.jpg'
    zigzag_tables = zz_encode(qtable).detach().cpu().numpy()
    # Round to integers and keep only the first batch entry.
    table1, table2 = np.round(zigzag_tables).astype(int)[0]
    source = Image.fromarray(plt.imread(path))
    source.save(file_save, qtables={0: table1, 1: table2}, optimize=False)

def save_test_image(qtable, path, n):
    """
    Re-save the image at ``path`` as a JPEG using one learned table.

    :param qtable: tensor accepted by ``zz_encode`` holding a single table
    :param path: path of the source image
    :param n: epoch number used in the output file name
    """
    file_save = f'/content/drive/My Drive/Research/CompAlgo/training_examples_single/epoch_{n}.jpg'
    zigzag_table = torch.squeeze(zz_encode(qtable)).detach().cpu().numpy()
    test_table = np.round(zigzag_table).astype(int)
    source = Image.fromarray(plt.imread(path))
    source.save(file_save, qtables={0: test_table}, optimize=False)

In [None]:
# # TRAINING LOOP FOR COLOR IMAGES


from torch.utils.data import DataLoader
import torchvision.transforms as T
import torch.optim as optim

# Source image for the colour run; the commented lines are alternatives.
IMG_PATH = '/content/drive/My Drive/Projects/Compression Algorithm/Dataset/Color/mandril_color.tif'
# IMG_PATH = '/content/drive/My Drive/Projects/Compression Algorithm/Dataset/Color/lena.jpg'
# IMG_PATH = '/content/drive/My Drive/Projects/Compression Algorithm/Dataset/Color/chalk-RGB.tif'
# IMG_PATH = '/content/drive/My Drive/Research/CompAlgo/img_tests/salad.jpeg'


SAVE_PATH = '/content/drive/My Drive/Research/CompAlgo/qoptim_models_single/qoptimizer'
SAMPLES = 256

dataset = ImageCompressionDataset(IMG_PATH, crop=(384, 384), samples=SAMPLES)
dataset = DataLoader(dataset, batch_size=1, shuffle=False)

# Three input channels, two output tables.
model = QTableOptimizer(170, input_channels=3, n_qtables=2, samples=SAMPLES)
model = model.double()
model = model.to(device)

blur_fn = torchvision.transforms.GaussianBlur(13, sigma=(0.1, 2.0))
criterion = QuantizationLoss(rate_weight=1, distortion_weight=1).to(device)
rate_criterion = RateLoss(band_scales=(1e-4, 1e2, 1e3)).to(device)
distortion_criterion = DistortionLoss(win_size=3, 
                                      blur_kernel=21, 
                                      ssim_scale=1e6, 
                                      alpha=0.85,
                                      n_channels=3,
                                      blur_fn=blur_fn).to(device)
annealer = AnnealingOptimizer(1, 1, t=1e5, beta=0.65)

# optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.09)
# optimizer = optim.Adadelta(model.parameters(), lr=1.0, rho=0.9, eps=1e-06, weight_decay=0)
# optimizer = optim.Adagrad(model.parameters(), lr=1e-3, lr_decay=1e-4, weight_decay=1e-8, initial_accumulator_value=0, eps=1e-10)
optimizer = optim.Adam(model.parameters(), lr=1, betas=(0.9, 0.999), weight_decay=0, amsgrad=True)
epochs = 70

# NOTE(review): `[[]*10]` evaluates to `[[]]` (one empty list), and the name is
# not read anywhere in the visible notebook. Left unchanged.
best_in_bin = [[]*10]

best_loss = float('inf')
for epoch in range(epochs):  # loop over the dataset multiple times
    print('EPOCH: %d\n%s' % (epoch + 1, '-'*40))

    # The dataset has a single item; each fetch re-samples blocks and crop.
    frequency_data, freq_data_partition, spatial_data, image_path = next(iter(dataset))
    frequency_data = frequency_data.to(device)
    freq_data_partition = freq_data_partition.to(device)
    spatial_data = spatial_data.to(device)
    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    frequency_data_input = normalize(frequency_data)
    qtables = model(frequency_data_input)
    # print(qtables)

    zz_quantized, reconstruction = reconstruct_img(freq_data_partition, qtables)
    rate_loss = rate_criterion(zz_quantized)
    ssim_value, distortion_loss = distortion_criterion(reconstruction, spatial_data)
    loss = criterion(rate_loss, distortion_loss)
    # Unweighted sum, used only for reporting and best-model tracking. It is
    # detached so that `best_loss` does not keep an epoch's autograd graph
    # alive after the epoch has finished.
    measure_loss = (rate_loss + distortion_loss).detach()

    if epoch == 0:
        annealer.set_original_entropy(rate_loss)

    print('ssim:', ssim_value.item())

    loss.backward()
    optimizer.step()
    
    # Re-balance the loss weights for the next epoch.
    rate_update, distortion_update = annealer.forward(ssim_value, rate_loss, epoch)
    criterion.rate_weight = rate_update
    criterion.distortion_weight = distortion_update

    print('rate weight:', criterion.rate_weight)
    print('distortion weight:', criterion.distortion_weight)
    
    # print statistics
    save_test_image_color(qtables, image_path[0], epoch)      
    print('epoch %3d loss: %.3f' % (epoch + 1, abs(measure_loss)))
    if abs(best_loss) > abs(measure_loss):
        print("Loss improved, saving model")
        torch.save(model.state_dict(), f'{SAVE_PATH}_best_{epoch+1}')
        best_loss = measure_loss
    print()

print('Finished Training')
torch.save(model.state_dict(), f'{SAVE_PATH}_{epoch+1}')

EPOCH: 1
----------------------------------------
rate: 29314.37159032062
distortion: 90963.91097932929
ssim: 0.8985115189632233
rate weight: 1
distortion weight: 100000.0
epoch   1 loss: 120278.283
Loss improved, saving model

EPOCH: 2
----------------------------------------
rate: 12436.295288242818
distortion: 14564985731.801388
ssim: 0.8425249928299898
rate weight: 1
distortion weight: 115747.50071700102
epoch   2 loss: 158086.153

EPOCH: 3
----------------------------------------
rate: 18092.230287509665
distortion: 12912501200.057652
ssim: 0.8770055936467109
rate weight: 61718.95370665076
distortion weight: 100000.0
epoch   3 loss: 129649.725

EPOCH: 4
----------------------------------------
rate: 1603515182.9380345
distortion: 5912865985.4348135
ssim: 0.9328020125885679
rate weight: 6171.895370665076
distortion weight: 100000.0
epoch   4 loss: 85109.580
Loss improved, saving model

EPOCH: 5
----------------------------------------
rate: 300100617.3333917
distortion: 2491218334.

In [None]:
# TRAINING LOOP FOR GRAYSCALE IMAGES


from torch.utils.data import DataLoader
import torchvision.transforms as T
import torch.optim as optim

# Source image for the grayscale run; the commented line is an alternative.
IMG_PATH = '/content/drive/My Drive/Projects/Compression Algorithm/Dataset/Grayscale/rose512.jpg'
# IMG_PATH = '/content/drive/My Drive/Projects/Compression Algorithm/Dataset/Grayscale/mandril_gray.jpg'
SAVE_PATH = '/content/drive/My Drive/Research/CompAlgo/qoptim_models_single/qoptimizer'
SAMPLES = 384

dataset = ImageCompressionDataset(IMG_PATH, samples=SAMPLES)
dataset = DataLoader(dataset, batch_size=1, shuffle=False)

# One input channel, one output table.
model = QTableOptimizer(180, input_channels=1, n_qtables=1, samples=SAMPLES)
model = model.double()
model = model.to(device)

blur_fn = torchvision.transforms.GaussianBlur(13, sigma=(0.1, 2.0))
criterion = QuantizationLoss(rate_weight=1, distortion_weight=1).to(device)
rate_criterion = RateLoss(band_scales=(1e-4, 1e2, 1e3)).to(device)
distortion_criterion = DistortionLoss(win_size=3, 
                                      blur_kernel=21, 
                                      ssim_scale=1e6,
                                      alpha=0.85,
                                      n_channels=1,
                                      blur_fn=blur_fn).to(device)
annealer = AnnealingOptimizer(1, 1, t=1e5, beta=0.65)

# optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.09)
# optimizer = optim.Adadelta(model.parameters(), lr=1.0, rho=0.9, eps=1e-06, weight_decay=0)
# optimizer = optim.Adagrad(model.parameters(), lr=1e-3, lr_decay=1e-4, weight_decay=1e-8, initial_accumulator_value=0, eps=1e-10)
optimizer = optim.Adam(model.parameters(), lr=1, betas=(0.9, 0.999), weight_decay=0, amsgrad=True)
epochs = 70

best_loss = float('inf')
for epoch in range(epochs):  # loop over the dataset multiple times
    print('EPOCH: %d\n%s' % (epoch + 1, '-'*40))

    # The dataset has a single item; each fetch re-samples blocks and crop.
    frequency_data, freq_data_partition, spatial_data, image_path = next(iter(dataset))
    frequency_data = frequency_data.to(device)
    freq_data_partition = freq_data_partition.to(device)
    spatial_data = spatial_data.to(device)
    # zero the parameter gradients
    optimizer.zero_grad()

    # forward + backward + optimize
    frequency_data_input = normalize(frequency_data)
    qtables = model(frequency_data_input)
    # print(qtables)

    zz_quantized, reconstruction = reconstruct_img(freq_data_partition, qtables)
    rate_loss = rate_criterion(zz_quantized)
    ssim_value, distortion_loss = distortion_criterion(reconstruction, spatial_data)
    loss = criterion(rate_loss, distortion_loss)
    # Unweighted sum, used only for reporting and best-model tracking. It is
    # detached so that `best_loss` does not keep an epoch's autograd graph
    # alive after the epoch has finished.
    measure_loss = (rate_loss + distortion_loss).detach()
    
    if epoch == 0:
        annealer.set_original_entropy(rate_loss)

    print('ssim:', ssim_value.item())

    loss.backward()
    optimizer.step()
    
    # Re-balance the loss weights for the next epoch.
    rate_update, distortion_update = annealer.forward(ssim_value, rate_loss, epoch)
    criterion.rate_weight = rate_update
    criterion.distortion_weight = distortion_update

    print('rate weight:', criterion.rate_weight)
    print('distortion weight:', criterion.distortion_weight)
    
    # print statistics
    save_test_image(qtables, image_path[0], epoch)      
    print('epoch %3d loss: %.3f' % (epoch + 1, abs(measure_loss)))
    if abs(best_loss) > abs(measure_loss):
        print("Loss improved, saving model")
        torch.save(model.state_dict(), f'{SAVE_PATH}_best_{epoch+1}')
        best_loss = measure_loss
    print()

print('Finished Training')
torch.save(model.state_dict(), f'{SAVE_PATH}_{epoch+1}')

EPOCH: 1
----------------------------------------
rate: 10405.915984410743
distortion: 73779.92989581393
ssim: 0.9168610947704263
rate weight: 1
distortion weight: 100000.0
epoch   1 loss: 84185.846
Loss improved, saving model

EPOCH: 2
----------------------------------------
rate: 6389.299581132414
distortion: 5247201622.762029
ssim: 0.9401358067330224
rate weight: 1
distortion weight: 100000.0
epoch   2 loss: 58861.316
Loss improved, saving model

EPOCH: 3
----------------------------------------
rate: 15810.060689404092
distortion: 2660518619.832072
ssim: 0.9691851201655513
rate weight: 151934.3878256309
distortion weight: 100000.0
epoch   3 loss: 42415.247
Loss improved, saving model

EPOCH: 4
----------------------------------------
rate: 6016446957.796045
distortion: 2015036763.7851734
ssim: 0.976572920323821
rate weight: 15193.438782563091
distortion weight: 100000.0
epoch   4 loss: 59749.349

EPOCH: 5
----------------------------------------
rate: 317677688.7548428
distortion: