In [1]:
import math
import io
import torch
from torchvision import transforms
import numpy as np
import pickle
import re
import scipy.stats as st
import PIL
from collections import Counter
from PIL import Image
from PIL import ImageChops
import imageio
import matplotlib.pyplot as plt
import compressai

from pytorch_msssim import ms_ssim
from compressai.zoo import bmshj2018_factorized
from compressai.dna_entropy_coding.coder import Coder
from ipywidgets import interact, widgets

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
JPEG_SIZES = [
    '1192x832', 
    '853x945', 
    '945x840', 
    '2000x2496', 
    '560x888',
    '2048x1536', 
    '1600x1200', 
    '1430x1834', 
    '2048x1536', 
    '2592x1946'  
]

In [4]:
def compute_x_hat(x, quality=1, include_rounded=False):
    """
    Runs the given image tensor for a given quality level through the forward part of the CompressAI autoencoder, 
    encodes the non trivial channels of the quantized latent space into DNA, decodes the DNA strand and runs the
    latent space representation though the inverse transformations of the CompressAI autoencoder. Results in the 
    reconstructed image tensor. This is the modified autoencoder.

    Parameters
    ----------
    x : tensor
        image tensor to run through the modified autoencoder
    quality : int, optional
        quality level at which to encode the image
        

    Returns
    -------
    The reconstructed image tensor and length of the encoded DNA strand
    """
    net = bmshj2018_factorized(quality = quality, pretrained=True).eval().to(device)

    with torch.no_grad():
        dictionary = net.get_y_hat_and_medians(x)
        y_hat = dictionary['y_hat']
        medians = dictionary['medians']

    coder = Coder() ## Class defined in coder.py
    rounded = y_hat-torch.unsqueeze(medians, dim=0)
    rounded = torch.round(rounded).int()

    dna = coder.encode(rounded, quality)                               
    rounded_decoded = coder.decode(dna, quality, x.shape)

    y_hat_decoded = rounded_decoded + torch.unsqueeze(medians, dim=0)

    with torch.no_grad():
        dict_dna = net.get_x_hat(y_hat_decoded)
        out_net_dna = {'x_hat': dict_dna['x_hat']}
        out_net_dna['x_hat'] = out_net_dna['x_hat'].clamp(0, 1)
    
    if include_rounded:
        return out_net_dna['x_hat'], len(dna), rounded
    return out_net_dna['x_hat'], dna

In [5]:
def get_kodim_tensor(img_number): 
    img_name = '/kodim{:02d}.png'.format(img_number)
    img = Image.open('./assets'+ img_name).convert('RGB')
    x = transforms.ToTensor()(img).unsqueeze(0).to(device)
    return img, x

In [13]:
'''
for i in range(1,25): 
    for q in range(1, 9): 
        img, x = get_kodim_tensor(i)
        x_hat, dna = compute_x_hat(x, q)
        
        img_decoded = transforms.ToPILImage()(x_hat.squeeze().cpu())
        path = 'assets/transcoder/'
        img_name = '{}kodim{:02d}.png'.format(q, i)
        img_path = path + img_name
        imageio.imwrite(img_path, img_decoded)
        
        dna_name = '{}kodim{:02d}.fasta'.format(q, i)
        dna_path = path + dna_name
        with open(dna_path, 'w') as f:
            f.write(dna)
'''




"\nfor i in range(1,25): \n    for q in range(1, 9): \n        img, x = get_kodim_tensor(i)\n        x_hat, dna = compute_x_hat(x, q)\n        \n        img_decoded = transforms.ToPILImage()(x_hat.squeeze().cpu())\n        path = 'assets/transcoder/'\n        img_name = '{}kodim{:02d}.png'.format(q, i)\n        img_path = path + img_name\n        imageio.imwrite(img_path, img_decoded)\n        \n        dna_name = '{}kodim{:02d}.fasta'.format(q, i)\n        dna_path = path + dna_name\n        with open(dna_path, 'w') as f:\n            f.write(dna)\n"

In [8]:
def get_original_tensor(data_set, img_id): 
    path = get_original_image_path(data_set, img_id)
    img = Image.open(path).convert('RGB')
    x = transforms.ToTensor()(img).unsqueeze(0).to(device)
    return img, x

In [9]:
def get_original_image_path(data_set, img_number): 
    img_path = 'assets/' + data_set
    if data_set == 'kodak':
        img_name = '/kodim{:02d}.png'.format(img_number)
        
    if data_set == 'jpeg_dna': 
        img_name = f"/{str(img_number).zfill(5)}_" + JPEG_SIZES[img_number-1]+'.png'
    return img_path + img_name

In [11]:
i = 10
q = 8
for i in range (1,11): 
    for q in range(1,9): 
        img, x = get_original_tensor('jpeg_dna', i)
        x_hat, dna = compute_x_hat(x, q)

        img_decoded = transforms.ToPILImage()(x_hat.squeeze().cpu())
        path = 'assets/jpeg_dna/learningbased/'
        img_name = f"/JPEG-1_{str(i).zfill(5)}" + '_' + JPEG_SIZES[i-1] + '_' + str(q) + '_decoded' + '.png'
        img_path = path + img_name
        imageio.imwrite(img_path, img_decoded)

        dna_name = f"/JPEG-1_{str(i).zfill(5)}" + '_' + JPEG_SIZES[i-1] + '_' + str(q) + '.fasta'
        dna_path = path + dna_name
        with open(dna_path, 'w') as f:
            f.write(dna)

In [1]:

for i in range(1, 25): 
    for q in range(1,8):
        img, x = get_original_tensor(i)
        x_hat, dna = compute_x_hat(x, q)

        img_decoded = transforms.ToPILImage()(x_hat.squeeze().cpu())
        path = 'assets/jpeg_dna/learningbased/'
        img_name = f"/JPEG-1_{str(i).zfill(5)}" + '_' + JPEG_SIZES[i-1] + '_' + str(q) + '_decoded' + '.png'
        img_path = path + img_name
        imageio.imwrite(img_path, img_decoded)

        dna_name = f"/JPEG-1_{str(i).zfill(5)}" + '_' + JPEG_SIZES[i-1] + '_' + str(q) + '.fasta'
        dna_path = path + dna_name
        with open(dna_path, 'w') as f:
            f.write(dna)

NameError: name 'get_original_tensor' is not defined