In [1]:
import os
import pickle
import src.main as lc
import torch

In [7]:
# --- Experiment configuration -------------------------------------------------
root = "example"
model_name = "redandblack-d8-w256"
target_frame_list = [*range(1450, 1455)]  # frames 1450..1454 inclusive

model_dir = f"{root}/{model_name}"

# Output folders for the compressed artifacts and for the restored checkpoints.
COMPRESSED_SAVELOC = f"{model_dir}/compressed-test"
DECOMPRESSED_SAVELOC = f"{model_dir}/decompressed-test"

# The PyTorch NeRF checkpoint stores three entries:
#   - network_fn_state_dict:   coarse network
#   - network_fine_state_dict: fine network
#   - optimizer_state_dict:    optimizer
# Only the fine network is compressed.
dict_name = "network_fine_state_dict"

# Base dictionary taken from the first frame, so the decompressor understands
# the structure of the model.
# NOTE: torch.load unpickles the checkpoint — only load files from trusted sources.
BASE_DICT = torch.load(
    f"{model_dir}/{target_frame_list[0]}.tar",
    map_location=torch.device("cpu"),
)[dict_name]

# Module-level settings read by the compression library.
lc.dict_name = dict_name
lc.is_float16 = False  # flag setting for float16

In [8]:
# --- Compression --------------------------------------------------------------
# Compress every frame checkpoint in `target_frame_list`. The captured output
# suggests the first frame is stored as-is and the later frames are
# delta-compressed.
num_bits = 2  # bit-width handed to lc.compress_set — presumably quantization bits per value; confirm
enc_model_list = [f"{frame}.tar" for frame in target_frame_list]

lc.compress_set(
    filename=model_dir,
    models=enc_model_list,
    saveloc=COMPRESSED_SAVELOC,
    num_bits=num_bits,
)

Delta Compression on: 1451.tar
Delta Compression on: 1452.tar
Delta Compression on: 1453.tar
Delta Compression on: 1454.tar
Saving Compressed Format: 1450.tar
Saving Compressed Format: 1451.tar
Saving Compressed Format: 1452.tar
Saving Compressed Format: 1453.tar
Saving Compressed Format: 1454.tar


In [9]:
# --- Decompression ------------------------------------------------------------
# Decompress every frame produced by the compression cell, then rebuild a full
# checkpoint per frame:
#   1. Read the decompressed `dict_name` (fine network) state dict.
#   2. Stitch it into the frame's original checkpoint, which also carries
#      network_fn_state_dict and optimizer_state_dict, and save the result.
dec_model_list = [f"compressed_{frame_no}.pt" for frame_no in target_frame_list]

lc.load_compressed_set(COMPRESSED_SAVELOC, dec_model_list, DECOMPRESSED_SAVELOC, BASE_DICT)

# Iterate the frames that were actually decompressed instead of scanning the
# output folder. os.listdir would also pick up stale .pt files from earlier
# runs, and joining every digit of a filename misparses names containing other
# digits.
for frame_no in target_frame_list:
    # Name matches the "Saving Decompressed Model at: decompressed_<frame>.pt" log.
    file_path = os.path.join(DECOMPRESSED_SAVELOC, f"decompressed_{frame_no}.pt")
    decompressed_state_dict = lc.read_decompressed_state_dict(file_path)
    # NOTE: torch.load unpickles the checkpoint — only load files from trusted sources.
    decompressed_checkpoint = torch.load(
        os.path.join(model_dir, f"{frame_no}.tar"),
        map_location=torch.device("cpu"),
    )
    # Use the configured key rather than repeating the literal.
    # NOTE(review): update() leaves any key missing from the decompressed dict at
    # its ORIGINAL (uncompressed) value — confirm the decompressed dict covers every key.
    decompressed_checkpoint[dict_name].update(decompressed_state_dict)
    torch.save(decompressed_checkpoint, os.path.join(DECOMPRESSED_SAVELOC, f"{frame_no}.tar"))


Decompressing for: compressed_1450.pt
Decompressing for: compressed_1451.pt
Decompressing for: compressed_1452.pt
Decompressing for: compressed_1453.pt
Decompressing for: compressed_1454.pt
Saving Decompressed Model at: decompressed_1450.pt
Saving Decompressed Model at: decompressed_1451.pt
Saving Decompressed Model at: decompressed_1452.pt
Saving Decompressed Model at: decompressed_1453.pt
Saving Decompressed Model at: decompressed_1454.pt
