In [1]:
import os
import shutil
import unittest
from copy import deepcopy

import torch
import torch.nn as nn
import numpy as np

from src.ai.base_net import ArchitectureConfig, BaseNet
from src.ai.utils import mlp_creator, generate_random_dataset_in_raw_data
from src.core.utils import get_to_root_dir, get_filename_without_extension, generate_random_image
from src.ai.architectures import *  # Do not remove

In [27]:
output_path = "dronet_sidetuned"
architecture_base_config = {
    "output_path": output_path,
    "architecture": "dronet",
}
network = eval(architecture_base_config['architecture']).Net(
            config=ArchitectureConfig().create(config_dict=architecture_base_config)
        )
checkpoint_file = os.path.join(os.environ['DATADIR'], 'dronet', 'torch_checkpoints', 'checkpoint_latest.ckpt')
checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu'))
network.load_checkpoint(checkpoint['net_ckpt'])
first_weight_checksum = network.conv2d_1.weight.data.sum().item()
feature_extraction = [ch for ch in network.children()][:-1]

dronet - INFO - Started.


In [29]:
from src.ai.architectures.dronet_sidetuned import Net

In [49]:
architecture_base_config['architecture'] = 'dronet_sidetuned'
sidetuned_network = Net(ArchitectureConfig().create(config_dict=architecture_base_config))
sidetuned_feature_extraction = [ch for ch in sidetuned_network.children()][:-1]

dronet_sidetuned - INFO - Started.
dronet_sidetuned - INFO - Started.


In [87]:
sidetuned_network.conv2d_1.weight.data.sum().item()

-5.570544719696045

In [88]:
# Initialize two feature extraction part 
for sidetuned_network_part in [sidetuned_feature_extraction[:17], 
                              sidetuned_feature_extraction[17:-1]]:
    for index, layer in enumerate(sidetuned_network_part):
        print(f'\n\n{layer}')
        for element in ['weight', 'bias']:
            if hasattr(layer, element):
                print(layer._parameters[element].shape)
                layer._parameters[element] = feature_extraction[index]._parameters[element].detach().clone()
        if hasattr(layer, 'running_mean'):
            layer.running_mean = feature_extraction[index].running_mean
        if hasattr(layer, 'running_var'):
            layer.running_var = feature_extraction[index].running_var
assert sidetuned_network.conv2d_1.weight.data.sum().item() == first_weight_checksum
assert sidetuned_network.sidetune_conv2d_1.weight.data.sum().item() == first_weight_checksum





Conv2d(1, 32, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1))
torch.Size([32, 1, 5, 5])
torch.Size([32])


MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)


BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
torch.Size([32])
torch.Size([32])


Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
torch.Size([32, 32, 3, 3])
torch.Size([32])


BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
torch.Size([32])
torch.Size([32])


Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
torch.Size([32, 32, 3, 3])
torch.Size([32])


Conv2d(32, 32, kernel_size=(1, 1), stride=(2, 2))
torch.Size([32, 32, 1, 1])
torch.Size([32])


BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
torch.Size([32])
torch.Size([32])


Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
torch.Size([64, 32, 3, 3])
torch.Size([64])


BatchNorm2d(64, eps=1e-05, mome

In [99]:
sidetune_checkpoint = {'net_ckpt': sidetuned_network.get_checkpoint()}

In [97]:
checkpoint_file

'dronet_sidetuned/torch_checkpoints/checkpoint_latest.ckpt'

In [100]:
checkpoint_file = os.path.join(os.environ['DATADIR'], architecture_base_config['output_path'], 'torch_checkpoints', 'checkpoint_latest.ckpt')
os.makedirs(os.path.dirname(checkpoint_file), exist_ok=True)
torch.save(sidetune_checkpoint, checkpoint_file)

In [96]:
# test loading of checkpoint
test_network = Net(ArchitectureConfig().create(config_dict=architecture_base_config))
test_checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu'))
test_network.load_checkpoint(test_checkpoint['net_ckpt'])
assert test_network.conv2d_1.weight.data.sum().item() == first_weight_checksum

dronet_sidetuned - INFO - Started.
