# NNCodec Demo @ ICML 2023 Neural Compression Workshop  

## Compressed DeepLabV3 solving the Pascal VOC semantic segmentation task

### Setup

#### Imports

In [None]:
import nnc
import torch, torchvision

#### Reproducibility

In [None]:
import os
import random

import numpy

# Seed every RNG the notebook touches so that repeated runs give the same results.
torch.manual_seed(808)
random.seed(909)
numpy.random.seed(303)

# Force deterministic cuDNN kernels and disable cuDNN autotuning, which may select
# different (non-deterministic) kernels from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Deterministic cuBLAS on CUDA >= 10.2 requires a fixed workspace configuration.
# torch.version.cuda is None on CPU-only builds (the notebook supports CPU via its
# device fallback), so guard before parsing it instead of crashing on `.split`.
_cuda_version = torch.version.cuda
if _cuda_version is not None:
    _cuda_major, _cuda_minor = (int(part) for part in _cuda_version.split(".")[:2])
    if (_cuda_major, _cuda_minor) >= (10, 2):
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

#### Environment, data and evaluation function initialization

In [None]:
from framework.pytorch_model import __initialize_data_functions, np_to_torch
from framework.use_case_init import use_cases
from framework.applications.utils.evaluation import evaluate_classification_model

# Location of the Pascal VOC demo data used throughout this notebook.
dataset_path = "./example/VOC_demo"

# Use the first CUDA device when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Pretrained DeepLabV3 with a ResNet-50 backbone (torchvision default weights).
model = torchvision.models.get_model("deeplabv3_resnet50", weights="DEFAULT")

# NOTE(review): Pascal VOC masks commonly mark object boundaries with label 255; unless the
# NNR_PYT_VOC data handler remaps those pixels, the loss may need ignore_index=255 — confirm.
criterion = torch.nn.CrossEntropyLoss()

# Data pipelines for the NNR_PYT_VOC use case; batch size 1, num_workers=0
# (presumably forwarded to the underlying PyTorch DataLoaders — verify in the handler).
test_set, test_loader, val_set, val_loader, train_loader = __initialize_data_functions(
    handler=use_cases['NNR_PYT_VOC'],
    dataset_path=dataset_path,
    batch_size=1,
    num_workers=0,
)
# evaluation function
def eval_compressed_model(bitstream, prefix, verbose=False, uncompressed_iou=None):
    """Decode an NNC bitstream into the module-level ``model`` and evaluate it.

    Args:
        bitstream: Encoded model, e.g. the value returned by
            ``nnc.compress_model(..., return_bitstream=True)``.
        prefix: Forwarded to ``evaluate_classification_model`` as ``prefix``
            (presumably used to label the plotted segmentation masks).
        verbose: Forwarded to both ``nnc.decompress`` and the evaluation.
        uncompressed_iou: Reference IoU of the uncompressed model, forwarded as
            ``orig_iou`` for comparison.

    Returns:
        The result of ``evaluate_classification_model``. The original discarded it;
        it is now returned so callers can record the metrics for each setting.
        Elsewhere in this notebook that result is unpacked into three values, the
        first being the IoU.
    """
    # This replaces the weights of the shared, module-level ``model`` in place.
    decoded_params = nnc.decompress(bitstream, verbose=verbose)
    model.load_state_dict(np_to_torch(decoded_params))
    return evaluate_classification_model(model, criterion, test_loader, test_set,
                                         device=device, verbose=verbose,
                                         plot_segmentation_masks=True, prefix=prefix,
                                         orig_iou=uncompressed_iou)

### Test the uncompressed model, transparently compress it and verify the decoded model

In [None]:
# Baseline: evaluate the pretrained, uncompressed model and keep its IoU as the reference.
test_model = torchvision.models.get_model("deeplabv3_resnet50", weights="DEFAULT")
sIoU, _, _ = evaluate_classification_model(
    test_model,
    criterion,
    test_loader,
    test_set,
    device=device,
    verbose=False,
    plot_segmentation_masks=True,
)

# Encode the model with NNCodec. With return_bitstream=True the bitstream is returned
# to the caller; bitstream_path names the output file.
bs = nnc.compress_model(
    test_model,
    use_case='NNR_PYT_VOC',
    dataset_path=dataset_path,
    bitstream_path='bitstream.nnc',
    return_bitstream=True,
    qp=-46,
    opt_qp=True,
    use_dq=True,
    batch_size=1,
    num_workers=0,
)

# Decode the bitstream and load the reconstructed parameters back into the same model.
rec_mdl_params = nnc.decompress(bs)
test_model.load_state_dict(np_to_torch(rec_mdl_params))

# Re-evaluate the reconstructed model, comparing against the uncompressed IoU.
_ = evaluate_classification_model(
    test_model,
    criterion,
    test_loader,
    test_set,
    device=device,
    verbose=False,
    plot_segmentation_masks=True,
    prefix="_compressed",
    orig_iou=sIoU,
)

## Now let's compress the model more aggressively until the segmentation masks become too faulty

In [None]:
def nnc_compress_model(qp):
    """Compress a freshly loaded pretrained DeepLabV3 with NNCodec at the given QP.

    Args:
        qp: Quantization parameter forwarded to ``nnc.compress_model``
            (with ``opt_qp=True`` and ``use_dq=True``).

    Returns:
        The encoded bitstream, because ``return_bitstream=True``. The same
        ``bitstream_path='bitstream.nnc'`` is used on every call, so any file
        written there is presumably overwritten each time.
    """
    # Load a new pretrained instance per call so each compression starts from the
    # original weights, independent of any earlier call.
    fresh_model = torchvision.models.get_model("deeplabv3_resnet50", weights="DEFAULT")
    return nnc.compress_model(fresh_model,
                              bitstream_path='bitstream.nnc',
                              use_case='NNR_PYT_VOC',
                              # Use the notebook-wide dataset location rather than a
                              # duplicated literal that could silently diverge from it.
                              dataset_path=dataset_path,
                              qp=qp,
                              opt_qp=True,
                              use_dq=True,
                              return_bitstream=True,
                              num_workers=0,
                              batch_size=1,
                              verbose=False)

In [None]:
# Sweep the QP upward (towards stronger compression), re-encoding the pretrained model
# at each step and evaluating the decoded result against the uncompressed IoU.
for qp in (-44, -42, -40, -38, -36, -34, -33, -32, -30):
    print(f"QP: {qp}")
    bitstream = nnc_compress_model(qp=qp)
    eval_compressed_model(
        bitstream,
        prefix=f"_compressed_qp_{qp}",
        uncompressed_iou=sIoU,
    )
