In [1]:
import torch
from dataset import *
import torchvision.transforms as transforms
from utils import *
from train import evaluate, evaluate_plus
from utils_compress import *

In [2]:
import random, numpy, os
# Seed every random number generator this notebook touches (PyTorch, the
# stdlib `random` module, NumPy). Each library keeps its own fixed seed value.
for _seed_fn, _seed_value in (
    (torch.manual_seed, 808),
    (random.seed, 909),
    (numpy.random.seed, 303),
):
    _seed_fn(_seed_value)

# Ask cuDNN for deterministic kernels and turn off its autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
import argparse

# Command-line style configuration, reused inside the notebook.
# The parser is evaluated with an explicit empty argument list at the bottom of
# this cell, so `args` always holds the defaults declared here; later cells
# override individual attributes directly.
parser = argparse.ArgumentParser()
# Model and Dataset configuration
parser.add_argument('--dataset', type=str, default='UVG', help='dataset')
parser.add_argument('--model_type', type=str, default='D-NeRV', choices=['NeRV', 'D-NeRV', 'HDNeRV','HDNeRV2', 'HDNeRV3','RAFT', 'RAFT_t'])
parser.add_argument('--model_size', type=str, default='S', choices=['XS', 'S', 'M', 'L', 'XL'])
parser.add_argument('--embed', type=str, default='1.25_240', help='base value/embed length for position encoding')
parser.add_argument('--spatial_size_h', type=int, default=256)
parser.add_argument('--spatial_size_w', type=int, default=320)
parser.add_argument('--keyframe_quality', type=int, default=3, help='keyframe quality, control flag used for keyframe image compression')
parser.add_argument('--clip_size', type=int, default=8, help='clip_size to sample at a single time')
parser.add_argument('--fc_hw', type=str, default='4_5', help='out hxw size for mlp')
parser.add_argument('--fc_dim', type=str, default='100', help='out channel size for mlp')
parser.add_argument('--enc_dim', type=int, nargs='+', default=[80, 160, 320, 640], help='enc latent dim and embedding ratio')
parser.add_argument('--enc_block', type=int, nargs='+', default=[3, 3, 9, 3, 3], help='blocks list')
parser.add_argument('--expansion', type=float, default=2, help='channel expansion from fc to conv')
parser.add_argument('--strides', type=int, nargs='+', default=[4, 2, 2, 2, 2], help='strides list')
parser.add_argument('--lower_width', type=int, default=32, help='lowest channel width for output feature maps')
# `--ver` defaults to True, so on its own the store_true flag could never be
# switched off; `--no-ver` writes False to the same destination.
parser.add_argument('--ver', action='store_true', default=True, help='ConvNeXt Version')
parser.add_argument('--no-ver', dest='ver', action='store_false', help='disable ConvNeXt Version')
parser.add_argument('--ignore', action='store_true', default=False, help='Ignore image')

# General training setups
parser.add_argument('-j', '--workers', type=int, help='number of data loading workers', default=16)
parser.add_argument('-b', '--batchSize', type=int, default=8, help='input batch size')
parser.add_argument('-e', '--epochs', type=int, default=400, help='number of epochs to train for')
parser.add_argument('--warmup', type=float, default=0.2, help='warmup epoch ratio compared to the epochs, default=0.2')
# Help strings below now state the defaults that are actually configured.
parser.add_argument('--lr', type=float, default=5e-4, help='learning rate, default=5e-4')
parser.add_argument('--lr_type', type=str, default='cos', help='learning rate type, default=cos')
parser.add_argument('--loss_type', type=str, default='Fusion6', help='loss type, default=Fusion6')
parser.add_argument('--start_epoch', type=int, default=0, help='starting epoch')

# evaluation parameters
# NOTE: the default for --weight is the string 'None', not the None object.
parser.add_argument('--weight', default='None', type=str, help='pretrained weights for initialization')
parser.add_argument('--eval_only', action='store_true', default=False, help='do evaluation only')
parser.add_argument('--quant_model', action='store_true', default=False, help='apply model quantization from torch.float32 to torch.int8')
parser.add_argument('--quant_model_bit', type=int, default=8, help='bit length for model quantization, default int8')
parser.add_argument('--quant_axis', type=int, default=1, help='quantization axis (1 for D-NeRV, 0 for NeRV)')
parser.add_argument('--dump_images', action='store_true', default=False, help='dump the prediction images')

# distribute learning parameters
parser.add_argument('--seed', type=int, default=1, help='manual seed')
parser.add_argument('--init_method', default='tcp://127.0.0.1:9888', type=str, help='url used to set up distributed training')
parser.add_argument('-d', '--distributed', action='store_true', default=False, help='distributed training')

parser.add_argument('-p', '--print-freq', default=500, type=int,)

# Parse an explicit empty argument list so the kernel's own sys.argv is never
# consulted. The previous '' only worked because an empty string iterates to
# nothing.
args = parser.parse_args([])

In [4]:
# Run on the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [5]:
# Replace the parsed defaults with the configuration used for the HDNeRV3
# evaluation in this notebook. Several entries repeat the parser defaults;
# they are kept so the whole configuration is visible in one place.
vars(args).update(
    dataset_mean=[0.4519, 0.4505, 0.4519],
    dataset_std=[0.2434, 0.2547, 0.2958],
    clip_size=8,
    strides=[4, 2, 2, 2, 2],
    lower_width=32,
    fc_hw='4_5',
    enc_dim=[80, 80, 80, 80, 40],
    enc_block=[3, 3, 9, 3, 3],
    expansion=2,
    outf='out_compress/hdnerv3',
    model_type='HDNeRV3',
)
# fc_dim follows the last encoder dimension (an int here, although the parser
# declares --fc_dim as a string).
args.fc_dim = args.enc_dim[-1]

# NOTE(review): PE is built from the same literal as the --embed default and is
# not referenced again in the visible cells — confirm whether it is still needed.
PE = PositionalEncoding('1.25_240')


In [6]:
from model import HDNeRV3
# Instantiate HDNeRV3 from the overridden `args` values and move it to `device`.
model = HDNeRV3(
    fc_hw=args.fc_hw,
    enc_dim=args.enc_dim,
    enc_block=args.enc_block,
    fc_dim=args.fc_dim,
    expansion=args.expansion,
    stride_list=args.strides,
    lower_width=args.lower_width,
    clip_size=args.clip_size,
    device=device,
    dataset_mean=args.dataset_mean,
    dataset_std=args.dataset_std,
    ver=args.ver,
).to(device)

# RGB frames are only converted to tensors; keyframes are additionally
# normalised with the dataset mean/std.
transform_rgb = transforms.Compose([
    transforms.ToTensor(),
])
transform_keyframe = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=args.dataset_mean, std=args.dataset_std),
])

# Evaluation data: fixed order (no shuffling) and no dropped final batch.
# NOTE(review): Dataset_DNeRV_UVG, worker_init_fn and my_collate_fn presumably
# come from the star imports at the top of the notebook — confirm their origin.
val_dataset = Dataset_DNeRV_UVG(args, transform_rgb, transform_keyframe)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=args.batchSize,
    shuffle=False,
    num_workers=args.workers,
    pin_memory=True,
    drop_last=False,
    worker_init_fn=worker_init_fn,
    collate_fn=my_collate_fn,
)

In [7]:
import yaml

# Collect the list of keys from the model's state dict (iteration order kept).
model_keys = [key for key in model.state_dict()]

# Name of the YAML file to create.
yaml_file = 'hdnerv3.yaml'

# Write the key list to the YAML file.
with open(yaml_file, mode='w') as key_dump:
    yaml.dump(model_keys, stream=key_dump)


In [7]:
# video, input_index, keyframe, backward_distance, frame_mask = next(iter(val_dataloader))


In [8]:
# video, input_index, keyframe, backward_distance, frame_mask = video.to(device), input_index.to(device), keyframe.to(device), backward_distance.to(device), frame_mask.to(device)

In [9]:
# Location of the trained HDNeRV3 checkpoint.
checkpoint_path = 'checkpoints/HDNeRV3/S.pth'

# Deserialize onto the CPU regardless of where the tensors were saved.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

# The weights are stored under the 'state_dict' entry of the checkpoint.
orig_ckt = checkpoint['state_dict']

In [10]:
# Copy the checkpoint weights into the model. Strict matching (the default) is
# spelled out: every key must match. The returned report is the cell's output.
model.load_state_dict(orig_ckt, strict=True)

<All keys matched successfully>

In [11]:
# Evaluate the restored model over the whole validation loader with
# method='normal'; the three returned values are kept as psnr1 / ms1 / bpp1.
# NOTE(review): the recorded output of this cell prints "Quantization",
# "Huffman Encoding" and "Dequantization", so 'normal' presumably selects that
# pipeline inside evaluate_plus — confirm in train.py.
psnr1, ms1, bpp1 = evaluate_plus(
    model,
    val_dataloader,
    local_rank=0,
    args=args,
    method='normal',
    length_dataset=len(val_dataset),
    frame_path_list=val_dataset.frame_path_list,
)

1464it [02:33,  9.52it/s]


Quantization
Huffman Encoding
Dequantization
BPP:  0.04133354028065999
Rank:0, Step [1/1464], PSNR: 33.2, MSSSIM: 0.8627
Rank:0, Step [501/1464], PSNR: 35.43, MSSSIM: 0.9406
Rank:0, Step [1001/1464], PSNR: 35.47, MSSSIM: 0.9538


In [12]:
# Repeat the evaluation with method='cabac'; results are kept as
# psnr2 / ms2 / bpp2 so they can be compared with the 'normal' run.
# NOTE(review): the recorded output of this cell shows two
# approximate/encode/decode passes and a higher BPP than the 'normal' run —
# confirm that both methods count the same payload before comparing BPP.
psnr2, ms2, bpp2 = evaluate_plus(
    model,
    val_dataloader,
    local_rank=0,
    args=args,
    method='cabac',
    length_dataset=len(val_dataset),
    frame_path_list=val_dataset.frame_path_list,
)

194it [00:31,  6.83it/s]

1464it [03:38,  6.69it/s]

INITIALIZE APPROXIMATOR AND ENCODER...DONE in 0.0063 s
APPROXIMATING WITH METHOD uniform...




DONE in 4.5990 s
ENCODING...DONE in 115.8408 s
COMPRESSED FROM 224870400 BYTES TO 64618126 BYTES (64618.13 KB, 64.62 MB, COMPRESSION RATIO: 28.74 %) in 120.4472 s
DECODING...DONE in 27.6450 s
RECONSTRUCTING...DONE in 0.2619 s
INITIALIZE APPROXIMATOR AND ENCODER...DONE in 0.0011 s
APPROXIMATING WITH METHOD uniform...DONE in 0.1391 s
ENCODING...DONE in 1.9404 s
COMPRESSED FROM 6691340 BYTES TO 1583435 BYTES (1583.43 KB, 1.58 MB, COMPRESSION RATIO: 23.66 %) in 2.0817 s
DECODING...DONE in 0.3788 s
RECONSTRUCTING...DONE in 0.0091 s
BPP:  0.06899970320404553
Rank:0, Step [1/1464], PSNR: 33.2, MSSSIM: 0.8627
Rank:0, Step [501/1464], PSNR: 35.41, MSSSIM: 0.9405
Rank:0, Step [1001/1464], PSNR: 35.58, MSSSIM: 0.9541


In [13]:
# -30 --> -46, step = 4
# bitstream = nnc.compress_model(model, qp=-44, return_bitstream=True)