In [1]:
import numpy as np
import math
from copy import deepcopy

### Thumbnail Inputs

In [2]:
# Available sample codestreams; the index below selects the one to process.
j2k_candidates = ("beach_4x4.j2k","sunrise_6x6.j2k")
j2k_file = j2k_candidates[1]

# Target thumbnail dimension; compared against Xsiz/Ysiz further down to
# decide how many resolution levels to drop.
tgt_thumb_size = 1024

### Algorithm Classes and Functions

In [3]:
def read_marker(advance=True):
    """Return the next two bytes of the global codestream buffer ``im_cs``.

    Parameters
    ----------
    advance : bool, optional
        When true (the default) the two bytes are consumed, i.e. dropped from
        the front of ``im_cs``. When false the buffer is left untouched, which
        makes the call a peek.

    Returns
    -------
    The leading two bytes of ``im_cs`` (fewer if the buffer is shorter).
    """
    global im_cs
    marker = im_cs[0:2]
    if not advance:
        return marker
    im_cs = im_cs[2:]
    return marker

def read_int(nbytes,advance=True):
    """Decode the next ``nbytes`` bytes of ``im_cs`` as a big-endian integer.

    Parameters
    ----------
    nbytes : int
        Number of bytes to decode from the front of the global buffer.
    advance : bool, optional
        When true (the default) the decoded bytes are removed from ``im_cs``;
        when false the buffer is left as it was.

    Returns
    -------
    int
        The unsigned big-endian value of the bytes that were available.
    """
    global im_cs
    value = int.from_bytes(im_cs[0:nbytes], byteorder="big")
    if not advance:
        return value
    im_cs = im_cs[nbytes:]
    return value

def read_bytes(nbytes,advance=True):
    """Return the next ``nbytes`` raw bytes of the global buffer ``im_cs``.

    Parameters
    ----------
    nbytes : int
        Number of bytes to take from the front of the buffer.
    advance : bool, optional
        When true (the default) the returned bytes are removed from
        ``im_cs``; when false the buffer is left as it was.

    Returns
    -------
    The slice ``im_cs[:nbytes]`` as it was before any advance.
    """
    global im_cs
    chunk = im_cs[0:nbytes]
    if not advance:
        return chunk
    im_cs = im_cs[nbytes:]
    return chunk

def dump_marker(m):
    """Serialize a marker code to its two-byte big-endian form.

    Parameters
    ----------
    m : int or bytes-like
        Either the numeric marker code (e.g. ``0xFF51``) or the marker
        already expressed as two bytes, which is how the marker constants in
        this notebook (``SIZ``, ``COD``, ...) are stored.

    Returns
    -------
    bytes
        The two-byte marker.

    Raises
    ------
    ValueError
        If a bytes-like marker is not exactly two bytes long.
    OverflowError
        If an integer marker does not fit in two bytes.
    """
    if isinstance(m, (bytes, bytearray)):
        # Marker constants are kept as bytes; pass them through unchanged.
        if len(m) != 2:
            raise ValueError(f"marker must be exactly 2 bytes, got {len(m)}")
        return bytes(m)
    return m.to_bytes(2,"big")

def dump_int(v,n):
    """Encode the integer ``v`` as exactly ``n`` big-endian bytes.

    Parameters
    ----------
    v : int
        Value to encode.
    n : int
        Width of the encoded field in bytes.

    Returns
    -------
    bytes
        ``v`` in big-endian byte order, ``n`` bytes long.
    """
    encoded = v.to_bytes(length=n, byteorder="big")
    return encoded

In [4]:
SIZ = b'\xff\x51'

class SIZMarker:
    """SIZ marker segment parsed from the global codestream buffer.

    The attributes carry the same names as the fields read in ``__init__``.
    The 2- and 4-byte fields are stored as ints; the three per-component
    one-byte fields (``Ssiz``, ``XRsiz``, ``YRsiz``) are stored as lists with
    one entry per component (``Csiz`` entries).
    """

    def __init__(self):
        """Consume one SIZ marker segment via ``read_marker``/``read_int``.

        Raises
        ------
        Exception
            If the next two bytes of the codestream are not the SIZ marker.
        """
        self.SIZ = read_marker()
        if self.SIZ != SIZ:
            raise Exception("ERROR... Codestream missing SIZ")

        self.Lsiz = read_int(2)
        self.Rsiz = read_int(2)
        self.Xsiz = read_int(4)
        self.Ysiz = read_int(4)
        self.XOsiz = read_int(4)
        self.YOsiz = read_int(4)
        self.XTsiz = read_int(4)
        self.YTsiz = read_int(4)
        self.XTOsiz = read_int(4)
        self.YTOsiz = read_int(4)
        self.Csiz = read_int(2)
        self.Ssiz = []
        self.XRsiz = []
        self.YRsiz = []
        for _ in range(self.Csiz):
            # One (Ssiz, XRsiz, YRsiz) byte triple per component, in that order.
            self.Ssiz.append(read_int(1))
            self.XRsiz.append(read_int(1))
            self.YRsiz.append(read_int(1))

    def clone(self):
        """Return an independent deep copy of this marker."""
        return deepcopy(self)

    def set_thumb(self,step_size):
        """Scale every component's sub-sampling factors by ``step_size``.

        The factors are multiplied rather than overwritten so that components
        which already carry a sub-sampling factor other than 1 keep their
        ratio relative to the other components. For components whose factor
        is 1 the result is simply ``step_size``.

        Parameters
        ----------
        step_size : int
            Additional sub-sampling step to apply in both directions.

        Raises
        ------
        ValueError
            If a resulting factor exceeds 255 and therefore cannot be written
            by ``to_bytes``, which emits each factor as a single byte. The
            marker is left unmodified in that case.
        """
        max_factor = 255  # XRsiz/YRsiz are serialized as one byte each
        scaled_x = [xr * step_size for xr in self.XRsiz]
        scaled_y = [yr * step_size for yr in self.YRsiz]
        if any(factor > max_factor for factor in scaled_x + scaled_y):
            raise ValueError(
                f"step_size={step_size} gives a sub-sampling factor above {max_factor}"
            )
        self.XRsiz = scaled_x
        self.YRsiz = scaled_y

    def to_bytes(self):
        """Serialize the segment (marker included) in the order it was parsed."""
        return ( SIZ
            + dump_int(self.Lsiz,2)
            + dump_int(self.Rsiz,2)
            + dump_int(self.Xsiz,4)
            + dump_int(self.Ysiz,4)
            + dump_int(self.XOsiz,4)
            + dump_int(self.YOsiz,4)
            + dump_int(self.XTsiz,4)
            + dump_int(self.YTsiz,4)
            + dump_int(self.XTOsiz,4)
            + dump_int(self.YTOsiz,4)
            + dump_int(self.Csiz,2)
            + b"".join([dump_int(s,1)+dump_int(xr,1)+dump_int(yr,1)
                        for s,xr,yr in zip(self.Ssiz,self.XRsiz,self.YRsiz)])
        )

In [5]:
COD = b'\xff\x52'
class CODMarker():
    """COD marker segment parsed from the global codestream buffer.

    Only the fixed 12-byte form (``Lcod == 12``) is supported. The wavelet
    byte is kept as the string ``"NL"`` when it equals 1 and ``"VL"``
    otherwise.
    """

    def __init__(self):
        """Consume one COD marker segment via ``read_marker``/``read_int``.

        Raises
        ------
        Exception
            If the next marker is not COD, or if ``Lcod`` is not 12.
        """
        marker = read_marker()
        if marker != COD:
            raise Exception(f"ERROR... incorrect marker. Found {marker.hex()}. Expected COD")

        self.Lcod = read_int(2)
        if self.Lcod != 12:
            raise Exception(f"Sorry... need to expand COD, Cannot handle Lcod={self.Lcod}")

        self.Scod = read_int(1)
        self.prog_order = read_int(1)
        self.nlayers = read_int(2)
        self.mcomp = read_int(1)
        self.nlevels = read_int(1)
        self.cbwidth = read_int(1)
        self.cbheight = read_int(1)
        self.cbstyle = read_int(1)

        # Keep the transform byte as a symbolic name instead of the raw code.
        wavelet_code = read_int(1)
        self.wavelet = "NL" if wavelet_code == 1 else "VL"

    def clone(self):
        """Return an independent deep copy of this marker."""
        return deepcopy(self)

    def set_thumb(self,ndecomp):
        """Adapt the marker for a thumbnail that drops ``ndecomp`` levels.

        Sets ``prog_order`` to 1 and lowers ``nlevels`` by ``ndecomp``.
        """
        self.prog_order = 1
        self.nlevels = self.nlevels - ndecomp

    def to_bytes(self):
        """Serialize the segment (marker included) in the order it was parsed."""
        if self.wavelet == "NL":
            wavelet_code = 1
        else:
            wavelet_code = 0
        fields = [
            COD,
            dump_int(self.Lcod,2),
            dump_int(self.Scod,1),
            dump_int(self.prog_order,1),
            dump_int(self.nlayers,2),
            dump_int(self.mcomp,1),
            dump_int(self.nlevels,1),
            dump_int(self.cbwidth,1),
            dump_int(self.cbheight,1),
            dump_int(self.cbstyle,1),
            dump_int(wavelet_code,1),
        ]
        return b"".join(fields)

In [6]:
QCD = b'\xff\x5c'
class QCDMarker():
    """QCD marker segment kept as an opaque payload.

    Only the length field is decoded; the remaining ``Lqcd - 2`` bytes are
    stored untouched in ``data`` and written back verbatim.
    """

    def __init__(self):
        """Consume one QCD marker segment from the global codestream buffer.

        Raises
        ------
        Exception
            If the next marker is not QCD.
        """
        marker = read_marker()
        if marker != QCD:
            raise Exception(f"ERROR... incorrect marker. Found {marker.hex()}. Expected QCD")
        self.Lqcd = read_int(2)
        # Lqcd counts its own two bytes, so the payload is two bytes shorter.
        payload_len = self.Lqcd - 2
        self.data = read_bytes(payload_len)

    def to_bytes(self):
        """Serialize the segment: marker, length, then the stored payload."""
        parts = (QCD, dump_int(self.Lqcd,2), self.data)
        return b"".join(parts)

In [7]:
COM = b'\xff\x64'
class COMMarker():
    """COM marker segment: length, the two-byte ``Rcom`` field and a payload.

    The ``Lcom - 4`` payload bytes are stored untouched in ``data`` and
    written back verbatim.
    """

    def __init__(self):
        """Consume one COM marker segment from the global codestream buffer.

        Raises
        ------
        Exception
            If the next marker is not COM.
        """
        marker = read_marker()
        if marker != COM:
            raise Exception(f"ERROR... incorrect marker. Found {marker.hex()}. Expected COM")
        self.Lcom = read_int(2)
        self.Rcom = read_int(2)
        # Lcom covers itself (2 bytes) and Rcom (2 bytes) besides the payload.
        payload_len = self.Lcom - 4
        self.data = read_bytes(payload_len)

    def to_bytes(self):
        """Serialize the segment: marker, Lcom, Rcom, then the stored payload."""
        parts = (COM, dump_int(self.Lcom,2), dump_int(self.Rcom,2), self.data)
        return b"".join(parts)

In [8]:
# Defined next to the class, like the other marker constants, so this cell
# does not depend on a later cell having been run first.
SOT = b'\xff\x90'
class SOTMarker():
    """SOT marker segment, either parsed from the codestream or built anew.

    Parameters
    ----------
    Isot : int, optional
        Tile index. When omitted (``None``) the marker is parsed from the
        global codestream buffer and the other arguments are ignored.
    Psot : int, optional
        Tile-part length field, written as four bytes.
    TPsot : int, optional
        Tile-part index field, written as one byte.
    TNsot : int, optional
        Number-of-tile-parts field, written as one byte (default 0).

    Raises
    ------
    Exception
        In parse mode, if the next marker is not SOT.
    """

    def __init__(self,*,Isot=None,Psot=None,TPsot=None,TNsot=0):
        if Isot is None:
            # Parse mode: read every field from the codestream.
            m = read_marker()
            if m != SOT:
                raise Exception(f"ERROR... incorrect marker. Found {m.hex()}. Expected SOT")
            self.Lsot = read_int(2)
            self.Isot = read_int(2)
            self.Psot = read_int(4)
            self.TPsot = read_int(1)
            self.TNsot = read_int(1)
        else:
            # Build mode: the segment always has the fixed length 10
            # (2 + 2 + 4 + 1 + 1 bytes, as written by ``to_bytes``).
            self.Lsot = 10
            self.Isot = Isot
            self.Psot = Psot
            self.TPsot = TPsot
            self.TNsot = TNsot

    def to_bytes(self):
        """Serialize the segment (marker included)."""
        return (
            SOT
            + dump_int(self.Lsot,2)
            + dump_int(self.Isot,2)
            + dump_int(self.Psot,4)
            + dump_int(self.TPsot,1)
            + dump_int(self.TNsot,1)
        )
    

In [9]:
PLT = b'\xff\x58'
class PLTMarker():
    """PLT marker segment: a list of lengths in a variable-length encoding.

    Each length is stored as a sequence of 7-bit groups, most significant
    group first; the top bit of a byte is set on every group except the last
    one of a length.

    Parameters
    ----------
    iplt : iterable of int, optional
        When omitted (``None``) the marker is parsed from the global
        codestream buffer. Otherwise a new marker with ``Zplt = 0`` is built
        that encodes the given non-negative lengths.

    Attributes
    ----------
    Lplt : int
        Segment length: the encoded payload plus 3 (two length bytes + Zplt).
    Zplt : int
        The one-byte Zplt field.
    Iplt : list of int
        Decoded lengths (parse mode) or the object passed as ``iplt``.
    Iplt_raw : bytes
        The encoded payload.

    Raises
    ------
    Exception
        In parse mode, if the next marker is not PLT or ``Lplt`` is below 3.
    ValueError
        In build mode, if a length is negative.
    """

    def __init__(self,*,iplt=None):
        if iplt is None:
            m = read_marker()
            if m != PLT:
                raise Exception(f"ERROR... incorrect marker. Found {m.hex()}. Expected PLT")
            self.Lplt = read_int(2)
            if self.Lplt < 3:
                # A shorter segment would turn the payload size negative.
                raise Exception(f"ERROR... invalid PLT length Lplt={self.Lplt}")
            self.Zplt = read_int(1)
            # Take the payload with one slice and decode it locally instead of
            # re-slicing the whole remaining codestream for every byte.
            self.Iplt_raw = read_bytes(self.Lplt-3)
            self.Iplt = []
            length = 0
            for byte in self.Iplt_raw:
                more, low7 = divmod(byte, 128)
                length = 128*length + low7
                if not more:
                    # Top bit clear: this byte completes the current length.
                    self.Iplt.append(length)
                    length = 0
        else:
            self.Zplt = 0
            self.Iplt = iplt
            raw = bytearray()
            for x in iplt:
                raw += self._encode_length(x)
            self.Iplt_raw = bytes(raw)
            self.Lplt = len(self.Iplt_raw) + 3

    @staticmethod
    def _encode_length(x):
        """Encode one non-negative length as 7-bit groups, MSB group first."""
        if x < 0:
            raise ValueError(f"packet length must be non-negative, got {x}")
        # The final (least significant) group has its top bit clear and is
        # always emitted, so a length of 0 encodes as the single byte 0x00.
        groups = [x % 128]
        x //= 128
        while x > 0:
            groups.append(128 + x % 128)
            x //= 128
        groups.reverse()
        return bytes(groups)

    def to_bytes(self):
        """Serialize the segment: marker, Lplt, Zplt, then the raw payload."""
        return (
            PLT
            + dump_int(self.Lplt,2)
            + dump_int(self.Zplt,1)
            + self.Iplt_raw
        )

### Import J2K Image

In [10]:
# Load the entire codestream into the global buffer that the read_* helpers
# consume from the front.
with open(j2k_file, mode="rb") as j2k_stream:
    im_cs = j2k_stream.read()

# Remember the untouched size; it is compared with the thumbnail size at the end.
orig_j2k_len = len(im_cs)

FileNotFoundError: [Errno 2] No such file or directory: 'sunrise_6x6.j2k'

### Read Main Header

In [None]:
SOC = b'\xff\x4f'

soc = read_marker()
if soc != SOC:
    raise Exception("ERROR... Codestream doesn't start with SOC")

siz = SIZMarker()

SOT = b'\xff\x90'

# Start from a clean slate so that markers parsed from a previously loaded
# file cannot leak into this run when the cell is re-executed.
cod = qcd = com = None

# Walk the remaining main-header segments until the first tile-part starts.
marker = read_marker(advance=False)
while marker != SOT:
    if marker == COD:
        cod = CODMarker()
    elif marker == QCD:
        qcd = QCDMarker()
    elif marker == COM:
        com = COMMarker()
    else:
        # The marker was only peeked (advance=False), so the buffer is still
        # positioned at the unsupported segment; nothing needs restoring.
        raise Exception(f"Need to add marker {marker.hex()}")
    marker = read_marker(advance=False)

# The rest of the notebook reads fields of both ``cod`` and ``qcd``.
if cod is None or qcd is None:
    raise Exception("ERROR... Main header is missing COD and/or QCD")

### Loop Over Tiles

In [None]:
EOC = b'\xff\xd9'
SOD = b'\xff\x93'
tile_parts = []

while marker == SOT:
    sot = SOTMarker()
    plt_list = []
    # Reset per tile-part so a tile-part without SOD cannot silently reuse
    # the payload of the previous one.
    data = None

    # Psot is counted from the start of the SOT marker; remove the SOT
    # segment itself (2-byte marker + Lsot bytes).
    len_data = sot.Psot - (sot.Lsot + 2)

    marker = read_marker(advance=False)
    while marker not in (EOC,SOT):
        if marker == PLT:
            plt = PLTMarker()
            plt_list.append(plt)
            len_data -= plt.Lplt + 2
        elif marker == SOD:
            if sot.Psot == 0:
                # Psot == 0 means the tile-part extends up to the EOC marker,
                # which is expected as the last two bytes of the buffer.
                if im_cs[-2:] != EOC:
                    raise Exception("Codestream doesn't end with EOC")
                data = im_cs[2:-2]
                im_cs = im_cs[-2:]
            else:
                len_data -= 2
                data = im_cs[2:2+len_data]
                im_cs = im_cs[2+len_data:]
        else:
            raise Exception(f"Need to add handler for {marker.hex()}")
        marker = read_marker(advance=False)

    if data is None:
        raise Exception(f"ERROR... tile-part of tile {sot.Isot} has no SOD marker")

    # Flatten the lengths of all PLT segments of this tile-part, in order.
    iplt = [x for xx in plt_list for x in xx.Iplt]

    tile_parts.append({
        'sot':sot,
        'iplt':iplt,
        'data':data,
    })

if marker != EOC:
    raise Exception("Codestream doesn't end with EOC")

### Extract Packets

In [None]:
packets = dict()
ncomp = siz.Csiz
nlayers = cod.nlayers
nlevels = cod.nlevels+1

# The packet index below walks layer -> resolution -> component (LRCP,
# progression order 0). With a single layer, orders 1 and 2 visit the packets
# in the same resolution -> component sequence, so they are accepted as well.
if not (cod.prog_order == 0 or (nlayers == 1 and cod.prog_order in (1, 2))):
    raise Exception(
        f"Cannot extract packets for progression order {cod.prog_order} "
        f"with {nlayers} layers; layer-resolution-component order is required"
    )

# One packet per (layer, resolution, component) is assumed for every tile-part.
npackets = nlayers * nlevels * ncomp

for tp in tile_parts:
    tile_index = tp['sot'].Isot
    if len(tp['iplt']) != npackets:
        raise Exception(
            f"Tile {tile_index}: PLT lists {len(tp['iplt'])} packets, "
            f"expected {npackets}"
        )
    # Byte offsets of the packet boundaries inside the tile-part data.
    lims = [0] + np.cumsum(tp['iplt']).tolist()
    for L in range(nlayers):
        for R in range(nlevels):
            for C in range(ncomp):
                packet_index = ncomp*(nlevels * L + R) + C
                start = lims[packet_index]
                end = lims[packet_index+1]
                packets[(tile_index,L,R,C)] = tp['data'][start:end]

### Update Main Headers

In [None]:
# Determine decomp level
orig_size = (siz.Ysiz,siz.Xsiz)

# Every dropped level halves the image, so the number of levels to drop is a
# base-2 logarithm of the size ratio (e.g. 6144/1024 = 6 -> 2 levels, step 4).
# The result is capped at the levels present in the codestream and floored at
# 0 so an image already smaller than the target is never "upscaled".
ndecomp = max(0, min(nlevels-1, max(int(math.log2(x/tgt_thumb_size)) for x in orig_size)))
thumb_step = 2**ndecomp
thumb_size = tuple(x//thumb_step for x in orig_size)

(ndecomp,thumb_step,thumb_size)

In [None]:
# Work on copies so the parsed main-header markers stay untouched.
# The SIZ copy gets the thumbnail sub-sampling step applied to XRsiz/YRsiz.
thumb_siz = siz.clone()
thumb_siz.set_thumb(step_size=thumb_step)

# The COD copy gets its progression order set to 1 and its level count
# reduced by the number of dropped levels.
thumb_cod = cod.clone()
thumb_cod.set_thumb(ndecomp=ndecomp)


In [None]:
# Main header of the thumbnail codestream: SOC followed by the adapted SIZ
# and COD segments and the QCD and COM segments taken over unchanged.
thumb_j2k = b"".join([
    SOC,
    thumb_siz.to_bytes(),
    thumb_cod.to_bytes(),
    qcd.to_bytes(),
    com.to_bytes(),
])

### Build Tile Parts

In [None]:
# NOTE: this cell extends the ``thumb_j2k`` header built in the previous
# cell; re-run that cell first before re-running this one, otherwise the
# tile-parts are appended a second time.
ntiles = len(tile_parts)
TNsot = nlevels - ndecomp
thumb_tile_parts = []

# Collect the pieces and join once at the end; repeated ``bytes`` addition
# would copy the whole growing codestream for every tile-part.
thumb_chunks = [thumb_j2k]
for R in range(TNsot):
    for T in range(ntiles):
        # One tile-part per (resolution, tile): its packets ordered by layer,
        # then component.
        tile_packets = [packets[(T,L,R,C)]
                        for L in range(nlayers)
                        for C in range(ncomp)]
        iplt = [len(packet) for packet in tile_packets]
        data = b"".join(tile_packets)

        tp_plt = PLTMarker(iplt=iplt)
        # 14 = SOT segment (2-byte marker + 10) + 2-byte SOD marker; the PLT
        # segment adds its 2-byte marker plus Lplt.
        psot = 14 + len(data) + tp_plt.Lplt + 2
        tp_sot = SOTMarker(Isot=T,Psot=psot,TPsot=R,TNsot=TNsot)

        thumb_chunks += [tp_sot.to_bytes(), tp_plt.to_bytes(), SOD, data]

thumb_chunks.append(EOC)
thumb_j2k = b"".join(thumb_chunks)

### Export Thumb J2K

In [None]:
# Swap the last four characters of the input name (".j2k" for the sample
# files) for ".thumb.j2k" and write the thumbnail codestream there.
thumb_file = f"{j2k_file[:-4]}.thumb.j2k"
with open(thumb_file, mode="wb") as thumb_stream:
    thumb_stream.write(thumb_j2k)

In [None]:
# Thumbnail size, original size, and their ratio (all in bytes).
thumb_j2k_len = len(thumb_j2k)
(thumb_j2k_len, orig_j2k_len, thumb_j2k_len/orig_j2k_len)