# Eve: Making Learning Interesting

## Load a checkpoint from PyTorch into Eve

We provide a helper that converts a checkpoint of a PyTorch network into a checkpoint for the corresponding Eve network in a few steps.

In [1]:
import torch
import torch.nn as nn
import eve
import eve.cores
import eve.app
import eve.utils

# import a evenet
from eve.app import EveCifar10Vgg


# Build the Eve counterpart of the legacy PyTorch CIFAR-10 VGG model.
net = EveCifar10Vgg()

# First attempt: convert without a key map. With `key_map=None` the helper prints
# guidance on building a key map together with the checkpoint keys it found, and
# then raises a ValueError (see the output below). The error is the point of this
# step, so it is caught and displayed instead of aborting "Restart & Run All".
# NOTE: the absolute paths below are specific to the author's machine; adjust them.
try:
    eve.utils.load_weight_from_legacy_checkpoint(
        net,
        # the downloaded PyTorch (legacy) checkpoint to convert
        legacy_checkpoint=
        "/media/densechen/data/code/eve-mli/examples/checkpoint/cifar10-vggsmall-zxd-93.4-8943fa3.pth",
        # where the converted Eve checkpoint will be saved
        eve_checkpoint=
        "/media/densechen/data/code/eve-mli/examples/checkpoint/eve-cifar10-vggsmall-zxd-93.4-8943fa3.pth",
        # key_map maps Eve weight names to legacy PyTorch weight names.
        # It is left as None on this first call because it has not been built yet.
        key_map=None,
    )
except ValueError as err:
    print(f"{type(err).__name__}: {err}")


/media/densechen/data/code/eve-mli/examples/checkpoint/cifar10-vggsmall-zxd-93.4-8943fa3.pth does not contains a 'state_dict' keytry to take the whole checkpoint as state_dict.
Please specify a kep map between eve and legacy.
You should pick up the paired key in eve and legacy and build a dict like: {'eve_key': 'legacy_key'}, then directly skip the unpaired one.
In most cases, the key order will not be changed, it is not a heavy work to do this.
key of /media/densechen/data/code/eve-mli/examples/checkpoint/eve-cifar10-vggsmall-zxd-93.4-8943fa3.pth
['classifier.bias',
 'classifier.weight',
 'features.0.weight',
 'features.1.bias',
 'features.1.num_batches_tracked',
 'features.1.running_mean',
 'features.1.running_var',
 'features.1.weight',
 'features.10.weight',
 'features.11.bias',
 'features.11.num_batches_tracked',
 'features.11.running_mean',
 'features.11.running_var',
 'features.11.weight',
 'features.14.weight',
 'features.15.bias',
 'features.15.num_batches_tracked',
 'features

ValueError: Invalid key map of NoneType. Follow the introduction above to generate a valid key map first.

Because no `key_map` has been defined yet, calling `load_weight_from_legacy_checkpoint` raises an error.
Following the guidance in the output above, we can easily build a `key_map` for this model.

In [4]:
# key_map: Eve parameter name (key) -> legacy PyTorch parameter name (value).
#
# In the legacy `features` Sequential every conv layer (weight only) is directly
# followed by its BatchNorm layer (bias / running stats / weight), so Eve conv block
# `k` maps to features[i] for the conv and features[i + 1] for the BatchNorm.
# Eve conv block number -> index of its conv layer inside legacy `features`.
conv_feature_index = {1: 0, 2: 3, 3: 7, 4: 10, 5: 14, 6: 17}
bn_param_names = ("bias", "num_batches_tracked", "running_mean", "running_var", "weight")

key_map = {}
for block_no, conv_idx in conv_feature_index.items():
    key_map[f"task_module.conv{block_no}.0.weight"] = f"features.{conv_idx}.weight"
    key_map.update({
        f"task_module.conv{block_no}.1.{param}": f"features.{conv_idx + 1}.{param}"
        for param in bn_param_names
    })

# NOTE: "clssifier" (sic) is kept verbatim: presumably it mirrors the attribute name
# used by the Eve model, and with it every key matches when the checkpoint is loaded.
key_map["task_module.clssifier.bias"] = "classifier.bias"
key_map["task_module.clssifier.weight"] = "classifier.weight"

# Pre-built key maps are also provided for the implemented models, e.g.:
# key_map = eve.app.imagenet.alexnet.key_map

Now, call `load_weight_from_legacy_checkpoint` again and pass the `key_map` to it.

In [5]:
# Second attempt: the same conversion, now supplying the key map built above.
# On success the converted weights are saved to the `eve_checkpoint` path.
eve.utils.load_weight_from_legacy_checkpoint(
    net,
    key_map=key_map,  # Eve weight name -> legacy PyTorch weight name
    # source: the downloaded PyTorch (legacy) checkpoint
    legacy_checkpoint="/media/densechen/data/code/eve-mli/examples/checkpoint/cifar10-vggsmall-zxd-93.4-8943fa3.pth",
    # destination: the converted Eve checkpoint
    eve_checkpoint="/media/densechen/data/code/eve-mli/examples/checkpoint/eve-cifar10-vggsmall-zxd-93.4-8943fa3.pth",
)

/media/densechen/data/code/eve-mli/examples/checkpoint/cifar10-vggsmall-zxd-93.4-8943fa3.pth does not contains a 'state_dict' keytry to take the whole checkpoint as state_dict.
new checkpoint has been saved in /media/densechen/data/code/eve-mli/examples/checkpoint/eve-cifar10-vggsmall-zxd-93.4-8943fa3.pth.


Load the converted Eve checkpoint directly.

In [6]:
# Load the converted Eve checkpoint straight into `net`; the weights live under the
# "state_dict" key of the saved dict.
# map_location="cpu" lets the file load on machines lacking the device it was saved
# from; load_state_dict then copies the tensors into net's own parameters.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
eve_checkpoint = torch.load(
    "/media/densechen/data/code/eve-mli/examples/checkpoint/eve-cifar10-vggsmall-zxd-93.4-8943fa3.pth",
    map_location="cpu",
)
net.load_state_dict(eve_checkpoint["state_dict"])

<All keys matched successfully>

## Build a trainer with `make` and test its accuracy

In [1]:
import torch
import torch.nn as nn
import eve
import eve.cores
import eve.app
import eve.utils

# import a evenet
from eve.app import make

# Build the registered CIFAR-10 VGG trainer from the converted Eve checkpoint; while
# being built it reports the accuracy of the loaded weights (see the output below).
# NOTE: the absolute paths below are specific to the author's machine; adjust them.
trainer = make(
    id="trainer_cifar10_vgg",
    checkpoint_path="/media/densechen/data/code/eve-mli/examples/checkpoint/eve-cifar10-vggsmall-zxd-93.4-8943fa3.pth",
    max_timesteps=1,
    net_arch_kwargs={
        "node": "IfNode",
        "node_kwargs": {
            "neuron_wise": False,
        },
        "quan": "SteQuan",
        "quan_kwargs": {
            "neuron_wise": False,
            "upgradable": False,
            "max_bit_width": 8,  # NOTE: overridden by the value stored in the loaded checkpoint.
        },
        "encoder": "RateEncoder",
        "encoder_kwargs": {}
    },
    optimizer_kwargs={
        "optimizer": "Adam",  # which optimizer to use; SGD and Adam are currently supported.
        "lr": 0.001,  # learning rate
        "betas": [0.99, 0.999],  # Adam betas
        "eps": 1e-8,
        "weight_decay": 1e-5,
        "amsgrad": False,
        # presumably only used when "optimizer" is "SGD" -- TODO confirm
        "momentum": 0.9,
        "nesterov": False,
    },
    data_kwargs={
        "root": "/media/densechen/data/dataset",
        "batch_size": 128,
        "num_workers": 4,
    },
    upgrader_kwargs={},
    kwargs={
        # Use the first GPU when available; fall back to CPU instead of crashing.
        "device": "cuda:0" if torch.cuda.is_available() else "cpu",
    }
)

("making new trainer: trainer_cifar10_vgg ({'checkpoint_path': "
 "'/media/densechen/data/code/eve-mli/examples/checkpoint/eve-cifar10-vggsmall-zxd-93.4-8943fa3.pth', "
 "'max_timesteps': 1, 'net_arch_kwargs': {'node': 'IfNode', 'node_kwargs': "
 "{'neuron_wise': False}, 'quan': 'SteQuan', 'quan_kwargs': {'neuron_wise': "
 "False, 'upgradable': False, 'max_bit_width': 8}, 'encoder': 'RateEncoder', "
 "'encoder_kwargs': {}}, 'optimizer_kwargs': {'optimizer': 'Adam', 'lr': "
 "0.001, 'betas': [0.99, 0.999], 'eps': 1e-08, 'weight_decay': 1e-05, "
 "'amsgrad': False, 'momentum': 0.9, 'nesterov': False}, 'data_kwargs': "
 "{'root': '/media/densechen/data/dataset', 'batch_size': 128, 'num_workers': "
 "4}, 'upgrader_kwargs': {}, 'kwargs': {'device': 'cuda:0'}})")
Files already downloaded and verified
Files already downloaded and verified
original accuracy: 0.9302808544303798
no upgrader needed
