In [23]:
import io
import math
import pickle
import re
from collections import Counter
from functools import lru_cache
from pathlib import Path

import imageio
import matplotlib.pyplot as plt
import numpy as np
import PIL
import scipy.stats as st
import torch
from ipywidgets import interact, widgets
from PIL import Image
from PIL import ImageChops
from pytorch_msssim import ms_ssim
from torchvision import transforms

import compressai
from compressai.dna_entropy_coding.coder import Coder
from compressai.zoo import bmshj2018_factorized

In [24]:
# Target device for models and tensors: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [25]:
def get_kodim_tensor(img_number): 
    img_name = '/kodim{:02d}.png'.format(img_number)
    img = Image.open('./assets'+ img_name).convert('RGB')
    x = transforms.ToTensor()(img).unsqueeze(0).to(device)
    return img, x

In [26]:
@lru_cache(maxsize=None)
def _load_net(quality, device_name):
    """Return the pretrained bmshj2018-factorized model for ``quality`` in eval mode on ``device_name``.

    Cached so each (quality, device) pair is instantiated only once per kernel session
    instead of once per image.
    """
    return bmshj2018_factorized(quality, pretrained=True).eval().to(device_name)


def compute_x_hat(x, quality):
    """Round-trip an image tensor through the learned codec using the DNA entropy coder.

    Pipeline: model analysis -> rounding of latents relative to the medians -> ``Coder.encode``
    to a DNA string -> ``Coder.decode`` back to integer latents -> model synthesis.

    Args:
        x: image tensor fed to the model; as produced by ``get_kodim_tensor`` it is a
            batched tensor on ``device`` (presumably shape (1, 3, H, W) -- confirm for other callers).
        quality: model quality level, passed both to ``bmshj2018_factorized`` and to the coder.

    Returns:
        ``(x_hat, dna)``: ``x_hat`` is the reconstructed image clamped to [0, 1];
        ``dna`` is whatever ``Coder.encode`` returns for the rounded latents.
    """
    net = _load_net(quality, device)

    # COMPRESS
    with torch.no_grad():
        latents = net.get_y_hat_and_medians(x)
    y_hat = latents['y_hat']
    # Leading dim added so the medians broadcast against y_hat's batch dim
    # (assumes medians lack the batch axis -- TODO confirm).
    medians = torch.unsqueeze(latents['medians'], dim=0)

    coder = Coder()  # Class defined in coder.py

    # The coder works on integer offsets from the medians.
    y_rounded = torch.round(y_hat - medians).int()

    # ENCODE
    dna = coder.encode(y_rounded, quality)

    # DECODE
    y_rounded_decoded = coder.decode(dna, quality, x.shape)
    # NOTE(review): the coder may return its tensor on the CPU while the model runs on CUDA;
    # as_tensor moves it next to `medians` (no-op if it is already there).
    y_hat_decoded = torch.as_tensor(y_rounded_decoded, device=medians.device) + medians

    # DECOMPRESS
    with torch.no_grad():
        x_hat = net.get_x_hat(y_hat_decoded)['x_hat'].clamp(0, 1)

    return x_hat, dna

In [28]:
# Sweep: every Kodak image id 1..24 (kodim01..kodim24) at every quality level 1..8.
# Each decoded PNG and its DNA string are written to OUTPUT_DIR as
# '{quality}kodim{id:02d}.png' and '{quality}kodim{id:02d}.fasta'.
KODAK_IMAGE_IDS = range(1, 25)
QUALITIES = range(1, 9)
OUTPUT_DIR = Path('assets/transcoder')

# Without this, the first write fails on a fresh checkout where the folder does not exist.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

for i in KODAK_IMAGE_IDS:
    # The source image does not depend on the quality level, so load it once per id.
    img, x = get_kodim_tensor(i)
    for q in QUALITIES:
        x_hat, dna = compute_x_hat(x, q)

        # Drop only the batch axis, then convert back to a PIL image on the CPU.
        img_decoded = transforms.ToPILImage()(x_hat.squeeze(0).cpu())
        img_decoded.save(OUTPUT_DIR / '{}kodim{:02d}.png'.format(q, i))

        (OUTPUT_DIR / '{}kodim{:02d}.fasta'.format(q, i)).write_text(dna)