In [None]:
# Standard library
import json
import os

# Third-party
import pyrallis
import torch
from torch.utils.tensorboard import SummaryWriter
from clearml import Task

# Project (HAND)
from HAND.HAND_train import load_original_model
from HAND.loss.attention_loss import AttentionLossFactory
from HAND.loss.distillation_loss import DistillationLossFactory
from HAND.loss.reconstruction_loss import ReconstructionLossFactory
from HAND.loss.task_loss import TaskLossFactory
from HAND.options import TrainConfig
from HAND.predictors.factory import HANDPredictorFactory
from HAND.tasks.model_factory import ModelFactory
import HAND.log_utils as log_utils
from HAND.trainer import Trainer
from HAND.eval_func import EvalFunction
# NOTE: a stray `'num_workers': cfg.num_workers})` fragment used to trail this line and made
# the whole cell a SyntaxError; a duplicate `ModelFactory` import was also removed.
from HAND.tasks.dataloader_factory import DataloaderFactory

In [None]:
def init_predictor(cfg, predictor):
    """Apply the weight-initialization scheme selected by ``cfg.hand.init``.

    Supported schemes:
      * ``"default"``    -- leave the predictor with torch's own initialization.
      * ``"checkpoint"`` -- call ``predictor.load(cfg.hand.checkpoint_path)``.
      * ``"fmod"``       -- replace every parameter with two or more dimensions by
                            ``torch.fmod(w, 2)``, wrapping its values into (-2, 2);
                            1-D parameters (e.g. biases) are left untouched.

    Args:
        cfg: config object exposing ``cfg.hand.init`` (and ``cfg.hand.checkpoint_path``
            when the checkpoint scheme is selected).
        predictor: module whose parameters are modified in place.

    Raises:
        ValueError: if ``cfg.hand.init`` is not one of the schemes above.
    """
    method = cfg.hand.init

    if method == "default":
        print("Using default torch initialization")
        return

    if method == "checkpoint":
        ckpt_path = cfg.hand.checkpoint_path
        print(f"Loading pretrained weights from: {ckpt_path}")
        predictor.load(ckpt_path)
        return

    if method != "fmod":
        raise ValueError(f"Unsupported initialization method: {method}")

    print("Initializing using fmod")
    # Only matrices / higher-rank tensors are wrapped; vectors keep their values.
    weight_tensors = (param for param in predictor.parameters() if param.dim() >= 2)
    for weight in weight_tensors:
        weight.data = torch.fmod(weight.data, 2)

# Experiment config (ResNet-56 / CIFAR-10). The path is relative, so it resolves against the
# notebook's working directory -- NOTE(review): confirm the notebook is launched from repo root.
CONFIG_PATH = 'experiments/resnet56/cifar10/resnet56_0_B_cifar10_ranger_nogc_350_epochs_1.5MB_base0.76.yaml'
cfg = pyrallis.parse(config_class=TrainConfig, config_path=CONFIG_PATH)

# Use the GPU only when the config does not opt out (`no_cuda`) and CUDA is actually present.
use_cuda = not cfg.no_cuda and torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

original_model, reconstructed_model = load_original_model(cfg, device)

# The predictor's input size comes from the reconstructed model's `output_size`.
pos_embedding = reconstructed_model.output_size
predictor_factory = HANDPredictorFactory(cfg, input_size=pos_embedding)
predictor = predictor_factory.get_predictor().to(device)

init_predictor(cfg, predictor)

# Pick the experiment logger: none, TensorBoard, or ClearML (the fallback).
if cfg.logging.disable_logging:
    logger = None
elif cfg.logging.use_tensorboard:
    tb_log_dir = os.path.join(cfg.logging.log_dir, "tb_logs", cfg.logging.exp_name)
    logger = SummaryWriter(log_dir=tb_log_dir)
    # Record the full config alongside the run for reproducibility.
    logger.add_text("TrainConfig", json.dumps(pyrallis.encode(cfg), indent=4))
else:
    clearml_task = Task.init(project_name='HAND_compression',
                             task_name=cfg.logging.exp_name,
                             deferred_init=True)
    # The config is flattened before connecting to work around a ClearML bug.
    clearml_task.connect(log_utils.flatten(pyrallis.encode(cfg)))
    logger = clearml_task.get_logger()

# Report parameter counts and approximate storage footprints.
# Sizes assume 4 bytes per parameter (float32) -- NOTE(review): confirm no half-precision or
# quantized tensors are involved. 1 MB is taken as 1024 * 1024 bytes here.
BYTES_PER_PARAM = 4
BYTES_PER_MB = 1024 * 1024

num_predictor_params = sum(p.numel() for p in predictor.parameters())
print("Predictor:"
      f"\n\t-> Number of parameters: {num_predictor_params / 1000}K"
      f"\n\t-> Size: {num_predictor_params * BYTES_PER_PARAM / BYTES_PER_MB:.2f}MB")

# Weights returned by `get_learnable_weights()` (presumably the ones the predictor is trained
# to produce -- confirm in the model class) versus every parameter of the original model.
num_predicted_params = sum(p.numel() for p in original_model.get_learnable_weights())
num_total_params = sum(p.numel() for p in original_model.parameters())
print("\nOriginal Model:"
      f"\n\t-> Number of learnable parameters: {num_predicted_params / 1000}K"
      f"\n\t-> Size of learnable parameters: {num_predicted_params * BYTES_PER_PARAM / BYTES_PER_MB:.2f}MB"
      f"\n\t-> Total model size: {num_total_params * BYTES_PER_PARAM / BYTES_PER_MB:.2f}MB")

# Build the task's dataloaders using the batch size and worker count from the config.
dataloaders = DataloaderFactory.get(
    cfg.task,
    batch_size=cfg.batch_size,
    num_workers=cfg.num_workers,
)
