In [1]:
import os
import sys
import pickle
# import torch.multiprocessing as mp
# from functools import partial

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.models as models
from PIL import Image
from compressai.models import ScaleHyperprior
from compressai.zoo import bmshj2018_hyperprior
from torch import optim, nn, utils
from torchvision import transforms
from torchmetrics import Accuracy

import wandb
from pytorch_lightning.loggers import WandbLogger
from compressai.losses import RateDistortionLoss
from compressai.models import ScaleHyperprior
from compressai.zoo import bmshj2018_hyperprior

import math

# mp.set_start_method('spawn', force=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE=32


In [2]:
class CustomLoss(nn.Module):
    def __init__(self):
        super(CustomLoss, self).__init__()

    def forward(self, target, x, predict, decode):
        loss1 = F.cross_entropy(predict, target)
        # loss2 = F.mse_loss(x, x_hat)
        # print(x.size())
        N, _, H, W = x.size()
        num_pixels = N * H * W
        bpp_loss = torch.log(decode["likelihoods"]["y"]).sum() / (-math.log(2) * num_pixels)

        # mean square error
        mse_loss = F.mse_loss(x, decode["x_hat"])

        # final loss term
        loss2 = mse_loss + 1e-2 * bpp_loss
        # loss2 = mse_loss
        
        max_classification_loss = torch.max(loss1)
        max_reconstruction_loss = torch.max(loss2)
        
        normalized_loss1 = loss1 / max_classification_loss
        normalized_loss2 = loss2 / max_reconstruction_loss
        
        total_loss = normalized_loss1 + normalized_loss2
        
        normalized_loss1 = normalized_loss1.clone().detach()
        normalized_loss2 = normalized_loss2.clone().detach()

       
        print("Task loss ", loss1, "  -- Reconstruction loss ", loss2)

        return total_loss

In [3]:


def conv3x3(in_channels, out_channels, stride=1):
    return nn.Conv2d(in_channels, out_channels, kernel_size=3,
                     stride=stride, padding=1, bias=False)


# 1x1 convolution
def conv1x1(in_channels, out_channels, stride=1):
    return nn.Conv2d(in_channels, out_channels, kernel_size=1,
                     stride=stride, bias=False)


# Residual block
class ResidualBlock(nn.Module):

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = conv1x1(in_channels, in_channels, stride)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv2 = conv3x3(in_channels, in_channels)
        self.bn2 = nn.BatchNorm2d(in_channels)
        self.conv3 = conv1x1(in_channels, out_channels, stride)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out


class cResnet39(pl.LightningModule):
    def __init__(self, inchannels=192):
        super().__init__()

        # init a pretrained resnet
        backbone = models.resnet50(weights="DEFAULT")
        layers = nn.ModuleList(list(backbone.children())[5:-1])

        self.in_channels = inchannels
        self.layer1_y_hat = self.make_layer(ResidualBlock, 128, 1)

        self.in_channels = inchannels
        self.layer1_scales_hat = self.make_layer(ResidualBlock, 128, 1)

        self.feature_extractor = nn.Sequential(*layers)

        num_target_classes = 23
        # self.classifier = nn.Linear(128*2048, 23)
        self.classifier = nn.Linear(2048, 23)
        
        
        
        self.training_targets = []
        self.validation_targets = []
        
        self.training_predictions = []
        self.validation_predictions = []
        
        self.training_step_losses = []
        self.validation_step_losses = []
        
        self.top1_accuracy = Accuracy(task="multiclass", num_classes=num_target_classes)
        self.top5_accuracy = Accuracy(task="multiclass", num_classes=num_target_classes, top_k=5)
        
        self.model = bmshj2018_hyperprior(quality=8, pretrained=True).to(device)
        
        self.loss = CustomLoss()
        
        # save hyper-parameters to self.hparamsm auto-logged by wandb
        self.save_hyperparameters()
        

    def make_layer(self, block, out_channels, blocks, stride=1):

        downsample = None
        if (stride != 1) or (self.in_channels != out_channels):
            downsample = nn.Sequential(
                conv1x1(self.in_channels, out_channels, stride=stride),
                nn.BatchNorm2d(out_channels))
        layers = nn.ModuleList()
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels
        for i in range(1, blocks):
            layers.append(block(out_channels, out_channels))

        return nn.Sequential(*layers)

    def forward(self, x):
        y = self.model.g_a(x)
        z = self.model.h_a(torch.abs(y))
        z_hat, z_likelihoods = self.model.entropy_bottleneck(z)
        _scales_hat = self.model.h_s(z_hat)
        _y_hat, y_likelihoods = self.model.gaussian_conditional(y, _scales_hat) 
        with torch.no_grad():
            _x_hat = self.model.g_s(_y_hat)
            print(F.mse_loss(x, _x_hat))
        
        _y_hat, _scales_hat = torch.squeeze(_y_hat, 0).to(device), torch.squeeze(_scales_hat, 0).to(device)
        _y_hat = transforms.Resize(32)(_y_hat)
        _scales_hat = transforms.Resize(32)(_scales_hat)
        _y_hat = transforms.RandomCrop(28)(_y_hat)
        _scales_hat = transforms.RandomCrop(28)(_scales_hat)
        
        p = float(torch.randint(0, 2, (1, )).item())
        _y_hat = transforms.RandomHorizontalFlip(p=p)(_y_hat)
        _scales_hat = transforms.RandomHorizontalFlip(p=p)(_scales_hat)
        
        _y_hat = self.layer1_y_hat(_y_hat)
        _scales_hat = self.layer1_scales_hat(_scales_hat)
        x = torch.concat((_y_hat, _scales_hat), 1)
        

        representations = self.feature_extractor(x)

        representations = representations.view(representations.size(0), -1)
        
        output = self.classifier(representations)

        
        return output, _y_hat, _scales_hat, {
            "x_hat": _x_hat,
            "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
        }

    def training_step(self, batch, batch_idx):
        x, target = batch
        target = target.to(device)    
      
        predict, y_hat, scales_hat, decode = self.forward(x)
        
        # loss = self.loss(target, x, predict, decode)
        
        
#         reconstruction_loss = self.criterion(x, x_hat)
        
        batch_loss = F.cross_entropy(predict, target)
        
#         N, _, H, W = x.size()
#         num_pixels = N * H * W
#         bpp_loss = torch.log(decode["likelihoods"]["y"]).sum() / (-math.log(2) * num_pixels)

#         # mean square error
#         mse_loss = F.mse_loss(x, decode["x_hat"])

#         # final loss term
#         loss2 = mse_loss + 1e-2 * bpp_loss
#         print("Reconstruction Loss is ", loss2)

        
        self.training_step_losses.append(batch_loss)

        self.training_targets.append(target)
        self.training_predictions.append(predict)
        
        return batch_loss
    
    def validation_step(self, batch, batch_idx):
        x, target = batch
        
        target = target.to(device)

      
        with torch.no_grad():
            predict, y_hat, scales_hat, decode = self.forward(x)
            # loss = self.loss(target, x, predict, decode)
            batch_loss = F.cross_entropy(predict, target)
            
        self.validation_step_losses.append(batch_loss)

            
        self.validation_targets.append(target)
        self.validation_predictions.append(predict)
        
        return batch_loss
    
    def on_train_epoch_end(self):
        # print(torch.cat(self.training_targets).shape)
        # print(torch.cat(self.training_predictions).shape)
        
        loss = F.cross_entropy(torch.cat(self.training_predictions), torch.cat(self.training_targets))
        # loss1 = sum(self.training_step_losses) / len(self.training_step_losses)
        top1_accuracy = self.top1_accuracy(torch.cat(self.training_predictions), torch.cat(self.training_targets)) 
        top5_accuracy = self.top5_accuracy(torch.cat(self.training_predictions), torch.cat(self.training_targets)) 
        print("\nTrain loss:", loss)
        print("Train top-1 acc:", top1_accuracy)
        print("Train top-3 acc:", top5_accuracy)
        self.log("train loss", loss)
        self.log("train top-1 acc", top1_accuracy)
        self.log("train top-3 acc", top5_accuracy)
        self.training_targets.clear()
        self.training_predictions.clear()
        self.training_step_losses.clear()
        
    def on_validation_epoch_end(self):
        print(torch.cat(self.validation_targets).shape)
        print(torch.cat(self.validation_predictions).shape)
        
        loss = F.cross_entropy(torch.cat(self.validation_predictions), torch.cat(self.validation_targets))
        top1_accuracy = self.top1_accuracy(torch.cat(self.validation_predictions), torch.cat(self.validation_targets)) 
        top5_accuracy = self.top5_accuracy(torch.cat(self.validation_predictions), torch.cat(self.validation_targets)) 
        # loss = sum(self.validation_step_losses) / len(self.validation_step_losses)
        
        print("\nVal loss:", loss)
        print("Val top-1 acc:", top1_accuracy)
        print("Val top-5 acc:", top5_accuracy)
        self.log("val loss", loss)
        self.log("val top-1 acc", top1_accuracy)
        self.log("val top-5 acc", top5_accuracy)
        self.validation_targets.clear()
        self.validation_predictions.clear()
        self.validation_step_losses.clear()
    

    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=0.01, momentum=0.9, weight_decay=0.005)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}]
        
        # return optimizer

In [4]:
# We use the train-validation-test split
# 1 provided in the dataset, with 2125 training images, 125
# validation images and 250 testing images for each class.

class MINCDataset(data.Dataset):
    NUM_CLASS = 23

    def __init__(self, root='/work/pi_adrozdov_umass_edu/kgupta_umass_edu/learning_based_img_compression/',
                 train=True):
        split = 'train' if train == True else 'val'
        root = os.path.join(root, 'minc-mini')
        print(root)
        
        self.classes, self.class_to_idx = find_classes(root + '/images')
        if split == 'train':
            filename = os.path.join(root, 'labels/train1.txt') # 2125
        else:
            filename = os.path.join(root, 'labels/validate1.txt') # 125

        self.images, self.labels = make_dataset(filename, root, self.class_to_idx)
        
        assert (len(self.images) == len(self.labels))

    def __getitem__(self, index):
        _image = self.images[index]
        _img = Image.open(_image).convert('RGB')
        _label = self.labels[index]
        
        _img = transforms.ToTensor()(_img)
        _img = transforms.Resize(384)(_img)

        return _img, _label


    def __len__(self):
        return len(self.images)


def find_classes(dir):
    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx


def make_dataset(filename, datadir, class_to_idx):
    images = []
    labels = []
    crs = []
    
    i = 0
    with open(os.path.join(filename), "r") as lines:
        for line in lines:
            _image = os.path.join(datadir, line.rstrip('\n'))
            _dirname = os.path.split(os.path.dirname(_image))[1]
            # _compressed_rep = os.path.join(datadir, 'compressed_rep', f'bpp{quality}', _dirname, os.path.splitext(os.path.split(_image)[1])[0])
            assert os.path.isfile(_image)
            # assert os.path.isfile(_compressed_rep)
            label = class_to_idx[_dirname]
            images.append(_image)
            # crs.append(_compressed_rep)
            labels.append(label)
            
            i += 1
            if i % 1000 == 0: sys.stdout.write('\r'+str(i)+' items loaded')
            
    sys.stdout.write('\r'+str(i)+' items loaded')
                           
              
    return images, labels

In [5]:
train_minc = MINCDataset(train=True)
val_minc = MINCDataset(train=False)

/work/pi_adrozdov_umass_edu/kgupta_umass_edu/learning_based_img_compression/minc-mini
25 items loaded/work/pi_adrozdov_umass_edu/kgupta_umass_edu/learning_based_img_compression/minc-mini
0 items loaded

In [6]:
train_loader = utils.data.DataLoader(train_minc, batch_size=32, shuffle=True)
valdn_loader = utils.data.DataLoader(val_minc, batch_size=32)

In [7]:
wandb_logger = WandbLogger(project='696ds-learning-based-image-compression', log_model=True)

downstream_model = cResnet39(inchannels=192)
downstream_model = downstream_model.to(device)
#trainer = pl.Trainer(fast_dev_run = True, logger=wandb_logger)
# print(0.001*48875 , 0.015*2875)
trainer = pl.Trainer(max_epochs=30, logger=wandb_logger)
# trainer = pl.Trainer(max_epochs=300)
trainer.fit(model=downstream_model, train_dataloaders=train_loader, val_dataloaders=valdn_loader)
wandb.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkartikgt23[0m ([33mumass-iesl-is[0m). Use [1m`wandb login --relogin`[0m to force relogin


Downloading: "https://compressai.s3.amazonaws.com/models/v1/bmshj2018-hyperprior-8-a583f0cf.pth.tar" to /home/kgupta_umass_edu/.cache/torch/hub/checkpoints/bmshj2018-hyperprior-8-a583f0cf.pth.tar
100%|██████████| 46.0M/46.0M [00:14<00:00, 3.22MB/s]
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type               | Params
---------------------------------------------------------
0 | layer1_y_hat      | Sequential         | 419 K 
1 | layer1_scales_hat | Sequential         | 419 K 
2 | feature_extractor | Sequential         | 23.3 M
3 | classifier        | Linear             | 47.1 K
4 | top1_accuracy     | MulticlassAccuracy | 0     
5 | top5_accuracy     | MulticlassAccuracy | 0     
6 | model             | ScaleHyperprior    | 11.8 M
7 | loss              | CustomLoss         | 0

                                   

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Epoch 0:   0%|          | 0/1 [00:00<?, ?it/s] 



tensor(4.6151e-05, device='cuda:0')


RuntimeError: Given groups=1, weight of size [192, 192, 1, 1], expected input[25, 320, 28, 28] to have 192 channels, but got 320 channels instead

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# checkpoint_reference = "umass-iesl-is/696ds-learning-based-image-compression/model-5h3ayg6d"

# # download checkpoint locally (if not already cached)
# run = wandb.init(project='696ds-learning-based-image-compression')
# artifact = run.use_artifact(checkpoint_reference, type="model")
# artifact_dir = artifact.download()

# # load checkpoint
# downstream_model = cResnet39.load_from_checkpoint(os.path.join(artifact_dir,"model.ckpt"))

In [None]:
downstream_model.to(device).eval()

with torch.no_grad():
    for batch_idx, batch in enumerate(train_loader):
        x, target = batch
        
        target = target.to(device)
        x = x.to(device)


        predict, y_hat, scales_hat, decode = downstream_model(x)
        loss = F.cross_entropy(predict, target)
        print("top 1:", Accuracy(task="multiclass", num_classes=23).to(device)(predict, target).item())
        print("top 5:", Accuracy(task="multiclass", num_classes=23, top_k=3).to(device)(predict, target).item())
        print("loss :", loss.item())
        
        for idx, img in enumerate(x):
            print("Label", target[idx].item())
            
            transform = transforms.ToPILImage()
            img = transform(img)
            
            f, axarr = plt.subplots(1,3)
            axarr[0].axis('off')
            axarr[1].axis('off')
            axarr[2].axis('off')
            
            axarr[0].imshow(img)
            axarr[0].title.set_text('Image')
            
            axarr[1].imshow(y_hat[idx].sum(axis=0).cpu())
            axarr[1].title.set_text('latent space')
            
            axarr[2].imshow(scales_hat[idx].sum(axis=0).cpu())
            axarr[2].title.set_text('std. dev.')
            
            f.suptitle(f'Target: {train_minc.classes[target[idx].item()]}, Prediction: {train_minc.classes[torch.argmax(predict[idx])]}, Top 5: {[train_minc.classes[p] for p in torch.topk(predict[idx], 5).indices.tolist()]}', fontsize=10)
            f.tight_layout()
            f.subplots_adjust(top=1.3)
            plt.show()
            
            
wandb.finish()