In [1]:
import json
import os

import pyrallis
import torch
from torch.utils.tensorboard import SummaryWriter
from clearml import Task

from HAND.HAND_train import load_original_model
from HAND.loss.attention_loss import AttentionLossFactory
from HAND.loss.distillation_loss import DistillationLossFactory
from HAND.loss.reconstruction_loss import ReconstructionLossFactory
from HAND.loss.task_loss import TaskLossFactory
from HAND.options import TrainConfig
from HAND.predictors.factory import HANDPredictorFactory
from HAND.tasks.model_factory import ModelFactory
import HAND.log_utils as log_utils
from HAND.trainer import Trainer
from HAND.eval_func import EvalFunction
from HAND.tasks.dataloader_factory import DataloaderFactory
from HAND.tasks.model_factory import ModelFactory

In [2]:
def init_predictor(cfg, predictor):
    """Apply the initialization scheme named by ``cfg.hand.init`` to ``predictor``.

    Supported values of ``cfg.hand.init``:

    * ``"fmod"`` - every parameter with two or more dimensions is replaced by
      its element-wise remainder modulo 2 (``torch.fmod``); parameters with
      fewer dimensions are left untouched.
    * ``"checkpoint"`` - ``predictor.load`` is called with
      ``cfg.hand.checkpoint_path``.
    * ``"default"`` - nothing is modified.

    Args:
        cfg: Configuration object exposing ``hand.init`` and, for the
            ``"checkpoint"`` scheme, ``hand.checkpoint_path``.
        predictor: Object to initialize; modified in place.

    Raises:
        ValueError: If ``cfg.hand.init`` is not one of the supported values.
    """
    scheme = cfg.hand.init

    if scheme == "fmod":
        print("Initializing using fmod")
        # Only tensors with at least two dimensions are wrapped.
        multi_dim_params = (w for w in predictor.parameters() if w.dim() >= 2)
        for w in multi_dim_params:
            w.data = torch.fmod(w.data, 2)
        return

    if scheme == "checkpoint":
        print(f"Loading pretrained weights from: {cfg.hand.checkpoint_path}")
        predictor.load(cfg.hand.checkpoint_path)
        return

    if scheme == "default":
        print("Using default torch initialization")
        return

    raise ValueError(f"Unsupported initialization method: {cfg.hand.init}")

In [3]:
# Parse the experiment YAML into a TrainConfig instance.
# The file is opened in a `with` block so the handle is closed as soon as
# parsing finishes instead of staying open for the lifetime of the kernel.
with open('experiments/resnet56/cifar10/debug_interpolation.yaml', "r") as config_file:
    cfg = pyrallis.load(TrainConfig, config_file)

In [4]:
# CUDA is used only when the config does not forbid it (cfg.no_cuda) and a
# CUDA device is actually available; otherwise everything runs on the CPU.
use_cuda = False if cfg.no_cuda else torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

In [5]:
# Obtain the (original, reconstructed) model pair from the project helper for
# this config and device. The output recorded below shows that this call tried
# to load precomputed embeddings, failed, and computed them layer by layer.
original_model, reconstructed_model = load_original_model(cfg, device)

Trying to load precomputed embeddings
Couldn't load precomputed embeddings, computing embeddings
Calculating layer 1/57 embeddings
Calculating layer 2/57 embeddings
Calculating layer 3/57 embeddings
Calculating layer 4/57 embeddings
Calculating layer 5/57 embeddings
Calculating layer 6/57 embeddings
Calculating layer 7/57 embeddings
Calculating layer 8/57 embeddings
Calculating layer 9/57 embeddings
Calculating layer 10/57 embeddings
Calculating layer 11/57 embeddings
Calculating layer 12/57 embeddings
Calculating layer 13/57 embeddings
Calculating layer 14/57 embeddings
Calculating layer 15/57 embeddings
Calculating layer 16/57 embeddings
Calculating layer 17/57 embeddings
Calculating layer 18/57 embeddings
Calculating layer 19/57 embeddings
Calculating layer 20/57 embeddings
Calculating layer 21/57 embeddings
Calculating layer 22/57 embeddings
Calculating layer 23/57 embeddings
Calculating layer 24/57 embeddings
Calculating layer 25/57 embeddings
Calculating layer 26/57 embeddings
Ca

In [7]:
from HAND.predictors.factory import HANDPredictorFactory, PredictorDataParallel

# Trained predictor checkpoint. This is an absolute path on a private NFS
# share; change it here when running the notebook on another machine.
PREDICTOR_CHECKPOINT_PATH = "/nfs/private/Maor/HANDCompress/HAND/outputs/resnet56_0_B_cifar10_ranger_nogc_350_epochs_1.1MB_base0.76_size40_1_31_08_2022_122955/hand_resnet56_0_B_cifar10_ranger_nogc_350_epochs_1.1MB_base0.76_size40_1_best.pth"

# Build the predictor with an input size equal to the reconstructed model's
# output_size, move it to the selected device and wrap it in
# PredictorDataParallel.
pos_embedding = reconstructed_model.output_size
predictor = HANDPredictorFactory(cfg, input_size=pos_embedding).get_predictor().to(device)
predictor = PredictorDataParallel(predictor)

# The checkpoint stores a whole pickled object (hence the `.state_dict()` call
# on the loaded value) rather than a bare state dict.
# map_location=device remaps tensors that were saved on another device, so the
# load also succeeds on a machine without the GPU layout used at save time.
# NOTE(review): torch.load unpickles arbitrary objects - only load checkpoints
# from a trusted source.
trained_state = torch.load(PREDICTOR_CHECKPOINT_PATH, map_location=device).state_dict()
predictor.load_state_dict(trained_state)

# NOTE(review): init_predictor runs AFTER the trained weights were loaded, so
# with cfg.hand.init == "fmod" the loaded weights themselves are passed through
# fmod(., 2) - confirm this order is intended.
init_predictor(cfg, predictor)

# Pick the logging backend from the config:
#   - logging disabled -> logger is None
#   - TensorBoard      -> SummaryWriter under <log_dir>/tb_logs/<exp_name>
#   - otherwise        -> ClearML task logger
if cfg.logging.disable_logging:
    logger = None
elif cfg.logging.use_tensorboard:
    logger = SummaryWriter(
        log_dir=os.path.join(cfg.logging.log_dir, "tb_logs", cfg.logging.exp_name)
    )
    # Store the full config as pretty-printed JSON alongside the run.
    logger.add_text("TrainConfig", json.dumps(pyrallis.encode(cfg), indent=4))
else:
    clearml_task = Task.init(
        project_name='HAND_compression',
        task_name=cfg.logging.exp_name,
        deferred_init=True,
    )
    clearml_task.connect(log_utils.flatten(pyrallis.encode(cfg)))  # Flatten because of clearml bug
    logger = clearml_task.get_logger()

# Report parameter counts (in thousands) and sizes for the predictor and the
# original model. The `* 4 / 1024 / 1024` factor converts a parameter count
# into the printed size, i.e. it counts 4 bytes per parameter.
num_predictor_params = sum(p.numel() for p in predictor.parameters())
print(
    f"Predictor:"
    f"\t-> Number of parameters: {num_predictor_params / 1000}K"
    f"\t-> Size: {num_predictor_params * 4 / 1024 / 1024:.2f}Mb"
)

# "Learnable" weights are those returned by get_learnable_weights(); the total
# additionally covers everything in original_model.parameters().
num_predicted_params = sum(w.numel() for w in original_model.get_learnable_weights())
num_total_params = sum(w.numel() for w in original_model.parameters())
print(
    f"\nOriginal Model:"
    f"\t-> Number of learnable parameters: {num_predicted_params / 1000}K"
    f"\t-> Size of learnable parameters: {num_predicted_params * 4 / 1024 / 1024:.2f}Mb",
    f"\n\t-> Total model size: {num_total_params * 4 / 1024 / 1024:.2f}Mb",
)

# Build the data loaders for the configured task. Later cells index the result
# as dataloaders[1] for evaluation.
# NOTE(review): the output recorded under this cell ends with
# "ValueError: Unsupported task", presumably raised by this call - confirm that
# cfg.task is supported before relying on `dataloaders` downstream.
dataloaders = DataloaderFactory.get(
    cfg.task,
    batch_size=cfg.batch_size,
    num_workers=cfg.num_workers,
)


Initializing using fmod
Predictor:	-> Number of parameters: 321.705K	-> Size: 1.23Mb

Original Model:	-> Number of learnable parameters: 850.864K	-> Size of learnable parameters: 3.25Mb 
	-> Total model size: 3.26Mb


ValueError: Unsupported task

In [73]:
from HAND.predictors.predictor import HANDPredictorBase
# Collect the inputs handed to predict_all below: the positional embeddings
# and learnable-weight shapes come from reconstructed_model, the reference
# weights from original_model. `indices` is not used again in the visible cells.
indices, positional_embeddings = reconstructed_model.get_indices_and_positional_embeddings()
original_weights = original_model.get_learnable_weights()
learnable_weights_shapes = reconstructed_model.get_learnable_weights_shapes()

# Call predict_all with the loaded predictor and pass its result to
# reconstructed_model.update_weights.
# NOTE(review): presumably this replaces the reconstructed model's weights with
# the predictor's output - confirm against HANDPredictorBase.predict_all.
reconstructed_weights = HANDPredictorBase.predict_all(predictor, positional_embeddings,
                                                          original_weights,
                                                          learnable_weights_shapes)
reconstructed_model.update_weights(reconstructed_weights)

In [74]:
from HAND.eval_func import EvalFunction
# Build the evaluator from cfg and evaluate the original model on
# dataloaders[1]; the recorded output identifies this as the test set.
# NOTE(review): the meaning of the two trailing None arguments is not visible
# here - check EvalFunction.eval's signature.
# NOTE(review): the recorded average loss is negative (-17.2330); confirm how
# EvalFunction computes its loss before interpreting that number.
eval_fn = EvalFunction(cfg)
eval_fn.eval(original_model,dataloaders[1], None, None)


 Starting eval on test set.

Test set: Average loss: -17.2330, Accuracy: 9352/10000 (94%)



93.52

In [75]:
# Evaluate reconstructed_model on the same loader so its score can be compared
# with the original model's (recorded: 87.51 vs 93.52).
eval_fn.eval(reconstructed_model,dataloaders[1], None, None)


 Starting eval on test set.

Test set: Average loss: -10.0224, Accuracy: 8751/10000 (88%)



87.51

In [121]:
from HAND.tasks.resnet56 import ResNet56X2, ReconstructedResNet56X2
# Instantiate a fresh ResNet56X2 on the selected device and wrap it in
# ReconstructedResNet56X2, using the sampling mode from the config.
# NOTE(review): presumably a double-width ResNet-56 - the shape checks in the
# following cells show learnable weight #4 as 32x32x3x3 here versus 16x16x3x3
# in the original model. No trained weights are loaded into it in this cell.
resnetx2_model = ResNet56X2().to(device)
resnetx2_recon = ReconstructedResNet56X2(resnetx2_model, cfg, device, sampling_mode=cfg.hand.sampling_mode)

Trying to load precomputed embeddings
Loaded positional embeddings for layer 1/55
Loaded positional embeddings for layer 2/55
Loaded positional embeddings for layer 3/55
Loaded positional embeddings for layer 4/55
Loaded positional embeddings for layer 5/55
Loaded positional embeddings for layer 6/55
Loaded positional embeddings for layer 7/55
Loaded positional embeddings for layer 8/55
Loaded positional embeddings for layer 9/55
Loaded positional embeddings for layer 10/55
Loaded positional embeddings for layer 11/55
Loaded positional embeddings for layer 12/55
Loaded positional embeddings for layer 13/55
Loaded positional embeddings for layer 14/55
Loaded positional embeddings for layer 15/55
Loaded positional embeddings for layer 16/55
Loaded positional embeddings for layer 17/55
Loaded positional embeddings for layer 18/55
Loaded positional embeddings for layer 19/55
Loaded positional embeddings for layer 20/55
Loaded positional embeddings for layer 21/55
Loaded positional embeddin

In [122]:
# Shape of learnable weight #4 in the X2 model (recorded: [32, 32, 3, 3]).
resnetx2_model.get_learnable_weights()[4].shape

torch.Size([32, 32, 3, 3])

In [123]:
# Shape of learnable weight #4 in the original model (recorded: [16, 16, 3, 3]).
original_model.get_learnable_weights()[4].shape

torch.Size([16, 16, 3, 3])

In [124]:
# Same three inputs as gathered for the original model, now taken from the X2
# model and its reconstructed wrapper. `indices_x2` is not used again in the
# visible cells.
indices_x2, positional_embeddings_x2 = resnetx2_recon.get_indices_and_positional_embeddings()
original_weights_x2 = resnetx2_model.get_learnable_weights()
learnable_weights_shapes_x2 = resnetx2_recon.get_learnable_weights_shapes()

In [125]:
# Shape of the first positional-embedding entry for the X2 model
# (recorded: [96, 240]).
positional_embeddings_x2[0].shape

torch.Size([96, 240])

In [126]:
# Shape of the first positional-embedding entry for the original model
# (recorded: [48, 240]) - half as many rows as the X2 model, same width.
positional_embeddings[0].shape

torch.Size([48, 240])

In [127]:
# Element-wise difference between the first row of the first positional
# embedding of the original model and of the X2 model. The visible part of the
# recorded output is all zeros, i.e. those rows match.
positional_embeddings[0][0]-positional_embeddings_x2[0][0]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [133]:
# Call predict_all with the SAME predictor that was loaded for the original
# model, but feed it the X2 model's positional embeddings, weights and shapes.
reconstructed_weights_x2 = HANDPredictorBase.predict_all(predictor, positional_embeddings_x2,
                                                          original_weights_x2,
                                                          learnable_weights_shapes_x2)


In [134]:
# Hand the predicted X2 weights to the reconstructed X2 wrapper.
resnetx2_recon.update_weights(reconstructed_weights_x2)

In [135]:
# Evaluate the reconstructed X2 model on the same loader.
# NOTE(review): the recorded output shows a NaN average loss and 10.76%
# accuracy, so the predicted X2 weights do not yield a usable model as-is -
# check the predicted weights for NaN/Inf values before drawing conclusions.
eval_fn.eval(resnetx2_recon,dataloaders[1], None, None)


 Starting eval on test set.

Test set: Average loss: nan, Accuracy: 1076/10000 (11%)



10.76

In [137]:
# Write the state_dict of resnetx2_recon.reconstructed_model to a file in the
# current working directory.
# NOTE(review): the output recorded under this cell is an AttributeError about
# a missing `module` attribute followed by an eval log, which does not appear
# to come from this line alone - the saved output may belong to an earlier
# version of the cell; re-run to confirm.
torch.save(resnetx2_recon.reconstructed_model.state_dict(), "./reconstructed_x2_weights.pth")

AttributeError: 'ResNet56X2' object has no attribute 'module'


 Starting eval on test set.

Test set: Average loss: 1.9539, Accuracy: 1000/10000 (10%)



10.0