In [1]:
import math
import os

import numpy as np
import PIL.Image as Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from compressai.entropy_models import EntropyBottleneck
from compressai.layers import GDN
from compressai.models import CompressionModel
from compressai.models.utils import conv, deconv
from datasets import load_dataset
from torchvision import transforms

In [2]:
class Network(CompressionModel):
    """Learned image codec assembled from CompressAI building blocks.

    `encode` maps a 3-channel image to an N-channel latent, `entropy_bottleneck`
    returns the (quantized) latent together with its likelihoods, and `decode`
    reconstructs a 3-channel image from that latent.

    Args:
        N: channel count of the hidden layers and of the latent (default 128).
    """

    def __init__(self, N=128):
        super().__init__()
        self.entropy_bottleneck = EntropyBottleneck(N)

        # Analysis transform: three `conv` stages separated by GDN.
        # Attribute names and layer order must not change -- they define the
        # state_dict keys stored in checkpoint.pth.
        self.encode = nn.Sequential(
            conv(3, N), GDN(N),
            conv(N, N), GDN(N),
            conv(N, N),
        )

        # Synthesis transform: mirror image of `encode`, using inverse GDN.
        self.decode = nn.Sequential(
            deconv(N, N), GDN(N, inverse=True),
            deconv(N, N), GDN(N, inverse=True),
            deconv(N, 3),
        )

    def forward(self, x):
        """Run the full encode -> entropy bottleneck -> decode pipeline.

        Returns:
            (x_hat, y_likelihoods): the reconstruction and the entropy
            bottleneck's likelihoods for the latent, used for the rate term.
        """
        latent = self.encode(x)
        latent_hat, latent_likelihoods = self.entropy_bottleneck(latent)
        return self.decode(latent_hat), latent_likelihoods

In [3]:
# Rate-distortion trade-off weight used by the training loop: loss = MSE + lmbda * bpp.
lmbda = 0.01

# Train split of the `danjacobellis/vimeo90k_triplet` dataset, yielded as torch tensors.
dataset = (
    load_dataset("danjacobellis/vimeo90k_triplet", split='train')
    .with_format("torch")
)

# Deliberately unseeded: each (re)run of the resumable training cell sees a new batch order.
dataloader = DataLoader(dataset, shuffle=True, batch_size=64)

Found cached dataset parquet (/home/server/.cache/huggingface/datasets/danjacobellis___parquet/danjacobellis--vimeo90k_triplet-de2b1cb7b7e1797e/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)


In [4]:
# Build the codec on the GPU. Resume from checkpoint.pth when a previous run left one;
# otherwise start from fresh weights, so the notebook also runs end-to-end from scratch.
net = Network().to("cuda")
if os.path.exists("checkpoint.pth"):
    checkpoint = torch.load("checkpoint.pth", map_location="cuda")
    net.load_state_dict(checkpoint['model_state_dict'])

# The entropy bottleneck's `.quantiles` are fitted by the auxiliary loss with their own
# optimizer; every other parameter is trained on the rate-distortion loss.
# Lists (not sets) keep the parameter order deterministic (named_parameters order), so
# the optimizer state saved in the checkpoint lines up with the same parameters later.
# NOTE(review): optimizer states are not restored here, so Adam moments restart on resume.
# Checkpoints written with the old set-based grouping cannot be safely reloaded anyway.
parameters = [p for n, p in net.named_parameters() if not n.endswith(".quantiles")]
aux_parameters = [p for n, p in net.named_parameters() if n.endswith(".quantiles")]
optimizer = optim.Adam(parameters, lr=1e-4)
aux_optimizer = optim.Adam(aux_parameters, lr=1e-3)

In [None]:
# Rate-distortion training over one pass of `dataloader`:
#   loss = MSE(x, x_hat) + lmbda * bpp
# Per-batch bpp and MSE are logged to bpp.npy / mse.npy, and the model plus both
# optimizers are checkpointed to checkpoint.pth.
CHECKPOINT_PATH = "checkpoint.pth"
CHECKPOINT_EVERY = 100  # batches between saves; a final save also runs when the loop ends or is interrupted


def save_progress(step, loss_value):
    """Write checkpoint.pth and the mse/bpp metric logs.

    The checkpoint is written to a temporary file and then renamed over
    CHECKPOINT_PATH with os.replace, so interrupting the cell mid-save cannot
    leave a truncated checkpoint that the resume cell would fail to load.
    """
    tmp_path = CHECKPOINT_PATH + ".tmp"
    torch.save({
        'epoch': step,  # NOTE: batch index within this pass; key name kept for compatibility
        'model_state_dict': net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'aux_optimizer_state_dict': aux_optimizer.state_dict(),
        'loss': loss_value,  # plain float rather than a graph-attached GPU tensor
    }, tmp_path)
    os.replace(tmp_path, CHECKPOINT_PATH)
    np.save('mse', np.asarray(mse))
    np.save('bpp', np.asarray(bpp))


# Plain lists give O(1) appends (np.append copies the whole array each batch) and carry
# no placeholder first entries into the saved logs.
bpp = []
mse = []
step, loss_value = -1, None
try:
    for step, batch in enumerate(dataloader):
        optimizer.zero_grad()
        aux_optimizer.zero_grad()

        # Scale pixel values to [-0.5, 0.5] (assumes raw values in [0, 255], given the /255),
        # then move channels to dim 1 -- assumes the dataset yields channels-last
        # (B, H, W, 3) tensors; TODO confirm.
        x = batch['image'].to("cuda").to(torch.float) / 255 - 0.5
        x = x.permute(0, 3, 1, 2)

        x_hat, y_likelihoods = net(x)

        # bitrate of the quantized latent: -sum(log2 p) / pixels in the batch
        batch_size, _, height, width = x.size()
        num_pixels = batch_size * height * width
        bpp_loss = torch.log(y_likelihoods).sum() / (-math.log(2) * num_pixels)
        bpp.append(bpp_loss.item())

        # mean square error
        mse_loss = F.mse_loss(x, x_hat)
        mse.append(mse_loss.item())

        # final loss term
        loss = mse_loss + lmbda * bpp_loss
        loss_value = loss.item()

        loss.backward()
        optimizer.step()

        # Auxiliary loss only updates the entropy bottleneck's quantiles.
        aux_loss = net.aux_loss()
        aux_loss.backward()
        aux_optimizer.step()

        if step % CHECKPOINT_EVERY == 0:
            save_progress(step, loss_value)
finally:
    # Persist progress even when the loop is interrupted (e.g. KeyboardInterrupt).
    if step >= 0:
        save_progress(step, loss_value)