In [1]:
from datasets import load_dataset
from torch.utils.data import DataLoader
from fastprogress.fastprogress import master_bar, progress_bar
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from compressai.models import CompressionModel
from compressai.entropy_models import EntropyBottleneck
from compressai.layers import GDN

In [None]:
# Load the training split of the hub dataset and request torch-formatted
# columns so batches can be fed to the model directly.
imagenet_train = load_dataset(
    "danjacobellis/imagenet_dino",
    split='train',
).with_format("torch")

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

In [None]:
def conv(in_channels, out_channels, kernel_size=5, stride=2, groups=512):
    """Build a grouped, strided 2-D convolution layer.

    Args:
        in_channels: Number of input channels; must be divisible by ``groups``.
        out_channels: Number of output channels; must be divisible by ``groups``.
        kernel_size: Side length of the square kernel.
        stride: Stride applied along both spatial axes.
        groups: Number of channel groups. Defaults to 512, the value this
            helper previously hard-coded, so existing calls are unaffected.

    Returns:
        An ``nn.Conv2d`` padded by ``kernel_size // 2`` on each side.
    """
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=kernel_size // 2,
        groups=groups,
    )

In [None]:
def deconv(in_channels, out_channels, kernel_size=5, stride=2, groups=512):
    """Build a grouped, strided 2-D transposed-convolution layer.

    Args:
        in_channels: Number of input channels; must be divisible by ``groups``.
        out_channels: Number of output channels; must be divisible by ``groups``.
        kernel_size: Side length of the square kernel.
        stride: Upsampling stride applied along both spatial axes.
        groups: Number of channel groups. Defaults to 512, the value this
            helper previously hard-coded, so existing calls are unaffected.

    Returns:
        An ``nn.ConvTranspose2d`` with ``padding = kernel_size // 2`` and
        ``output_padding = stride - 1``.
    """
    return nn.ConvTranspose2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        output_padding=stride - 1,
        padding=kernel_size // 2,
        groups=groups,
    )

In [None]:
class RateDistortionAutoEncoder(CompressionModel):
    """Convolutional autoencoder with an entropy bottleneck on the latent.

    The encoder is three ``conv`` layers with GDN layers between them; the
    decoder mirrors it with three ``deconv`` layers and inverse GDN layers.
    The latent passes through an ``EntropyBottleneck`` that returns both the
    latent handed to the decoder and its likelihoods.

    Args:
        N: Number of channels in the hidden layers and in the latent.
        in_channels: Number of channels of the input, and of the
            reconstruction. Defaults to 1536, the value previously
            hard-coded. Must be compatible with the channel grouping that
            ``conv``/``deconv`` apply.
    """

    def __init__(self, N=4096, in_channels=1536):
        super().__init__()
        self.entropy_bottleneck = EntropyBottleneck(N)

        # Analysis transform: input channels -> N-channel latent.
        self.encode = nn.Sequential(
            conv(in_channels, N),
            GDN(N),
            conv(N, N),
            GDN(N),
            conv(N, N),
        )

        # Synthesis transform: N-channel latent -> input channels.
        self.decode = nn.Sequential(
            deconv(N, N),
            GDN(N, inverse=True),
            deconv(N, N),
            GDN(N, inverse=True),
            deconv(N, in_channels),
        )

    def forward(self, x):
        """Encode ``x``, pass the latent through the bottleneck, and decode.

        Returns:
            A tuple ``(x_hat, y_likelihoods)``: the reconstruction and the
            likelihoods the entropy bottleneck reports for the latent.
        """
        y = self.encode(x)
        y_hat, y_likelihoods = self.entropy_bottleneck(y)
        x_hat = self.decode(y_hat)
        return x_hat, y_likelihoods

In [None]:
# Instantiate the model on the GPU.
net = RateDistortionAutoEncoder().to("cuda")

# Split the parameters into two groups: the entropy-bottleneck quantiles are
# trained by their own optimizer, everything else by the main one.
# Lists (not sets) are used so the parameter order follows
# `named_parameters()` and is identical on every run; iteration order of a set
# of tensors is not stable, which would scramble the index-based optimizer
# state when a checkpoint is saved in one run and loaded in another.
parameters = [p for n, p in net.named_parameters() if not n.endswith(".quantiles")]
aux_parameters = [p for n, p in net.named_parameters() if n.endswith(".quantiles")]
optimizer = optim.AdamW(parameters, lr=1e-4)
aux_optimizer = optim.Adam(aux_parameters, lr=1e-3)

# Weight of the rate term in the training loss (loss = distortion + λ * rate).
λ = 0.01

In [None]:
# Per-step training history, plotted live (on a log scale) by the master bar.
rate = np.array([])
distortion = np.array([])
epochs = 100
mb = master_bar(range(1, epochs+1))
mb.names = ['rate', 'distortion']
train_loss, valid_loss = [], []  # not populated by this loop
for epoch in mb:

    dataloader = DataLoader(imagenet_train, batch_size=64, shuffle=True)
    for batch in progress_bar(dataloader, parent=mb):
        # Infer the batch dimension (-1) instead of hard-coding 64 so that a
        # final, smaller batch does not make the reshape fail.
        x = batch['patch_tokens'].reshape((-1,16,16,1536)).permute((0,3,1,2)).to("cuda")

        # Clear gradients left over from the previous step; without this,
        # every backward() call adds onto the gradients of all earlier steps.
        optimizer.zero_grad()
        aux_optimizer.zero_grad()

        x_hat, y_likelihoods = net(x)

        # Rate: total bits implied by the likelihoods, divided by the number
        # of input elements.
        N, C, H, W = x.size()
        num_samples = N * C * H * W
        bps_loss = torch.log(y_likelihoods).sum() / (-np.log(2) * num_samples)
        rate = np.append(rate,bps_loss.detach().cpu().numpy())

        # Distortion: mean squared error between input and reconstruction.
        mse_loss = F.mse_loss(x, x_hat)
        distortion = np.append(distortion,mse_loss.detach().cpu().numpy())

        # Main update on the rate-distortion objective.
        loss = mse_loss + λ * bps_loss
        loss.backward()
        optimizer.step()

        # Separate update for the entropy bottleneck's auxiliary loss.
        aux_loss = net.aux_loss()
        aux_loss.backward()
        aux_optimizer.step()

        graphs = [[range(len(rate)),np.log(rate)], [range(len(distortion)),np.log(distortion)]]
        mb.update_graph(graphs)