In [None]:
import os
import sys

# Walk two directory levels up from the notebook's working directory and make that
# directory importable, so the `DeepSparseCoding` package imports below resolve.
root_dir = os.getcwd()
for _level in range(2):
    root_dir = os.path.dirname(root_dir)
if root_dir not in sys.path:
    sys.path.append(root_dir)

import numpy as np
import torch

import DeepSparseCoding.utils.loaders as loaders
import DeepSparseCoding.utils.run_utils as run_utils
import DeepSparseCoding.utils.dataset_utils as dataset_utils

In [None]:
# Standalone LCA model: its params and data loaders are the reference configuration
# that the ensemble in the next cell is made to match.
# NOTE(review): the params directory is hard-coded under the user's home directory.
params_dir = os.path.expanduser("~") + "/Work/DeepSparseCoding/params"
lca_param_file = params_dir + "/lca_mnist_params.py"
# Log of an earlier run, kept for reference only (not read by this notebook):
# ~/Work/Torch_projects/lca_768_mnist/logfiles/lca_768_mnist_v0.log

lca_params = loaders.load_params(lca_param_file)
# Overrides applied before the loaders are built; `shuffle_data=False` presumably gives a
# fixed batch order so both models below train on the same sequence of batches.
lca_overrides = {"train_logs_per_epoch": None, "shuffle_data": False}
for override_name, override_value in lca_overrides.items():
    setattr(lca_params, override_name, override_value)

# `load_dataset` also returns the params object (possibly updated with dataset info).
train_loader, val_loader, test_loader, lca_params = dataset_utils.load_dataset(lca_params)

lca_model = loaders.load_model(lca_params.model_type, root_dir=lca_params.lib_root_dir)
lca_model.setup(lca_params)
lca_model.to(lca_params.device)

In [None]:
# Ensemble model (presumably LCA followed by an MLP, per the param-file name), configured
# for the same data as `train_loader` by copying dataset-dependent settings from `lca_params`.
# NOTE(review): the params directory is hard-coded under the user's home directory.
ensemble_params_dir = os.path.expanduser("~") + "/Work/DeepSparseCoding/params"
ensemble_param_file = ensemble_params_dir + "/lca_mlp_mnist_params.py"
# Log of an earlier run, kept for reference only (not read by this notebook):
# ~/Work/Torch_projects/lca_768_mlp_mnist/logfiles/lca_768_mlp_mnist_v0.log

ensemble_params = loaders.load_params(ensemble_param_file)

# Settings the ensemble must share with the standalone LCA (copied in this order).
shared_param_names = (
    "epoch_size",
    "num_val_images",
    "num_test_images",
    "data_shape",
    "train_logs_per_epoch",
    "shuffle_data",
)
for shared_name in shared_param_names:
    setattr(ensemble_params, shared_name, getattr(lca_params, shared_name))

ensemble_model = loaders.load_model(ensemble_params.model_type, root_dir=ensemble_params.lib_root_dir)
ensemble_model.setup(ensemble_params)
ensemble_model.to(ensemble_params.device)

In [None]:
# Give the ensemble's LCA submodel exactly the standalone model's weights so the two can
# be compared from an identical starting point.
# NOTE(review): assumes the LCA submodel's weight is registered under the key 'lca.w';
# a wrong key would make the strict `load_state_dict` below raise.
ensemble_state_dict = ensemble_model.state_dict()
# Update the returned OrderedDict in place (rather than copying into a plain dict) so its
# `_metadata` attribute is still there for `load_state_dict`.
ensemble_state_dict.update({'lca.w': lca_model.w.clone()})
ensemble_model.load_state_dict(ensemble_state_dict)

In [None]:
# Take the first training batch and put it on the LCA model's device; the data also goes
# through the model's own preprocessing.
batch_device = lca_model.params.device
data, target = next(iter(train_loader))
train_data_batch = lca_model.preprocess_data(data.to(batch_device))
train_target_batch = target.to(batch_device)

In [None]:
# Single-batch equivalence check: the standalone LCA and the ensemble's first submodel
# should produce the same loss and the same weight gradient on the same batch.

# Clear every optimizer's gradients so the compared grads come from this batch only.
lca_model.optimizer.zero_grad()
for submodel in ensemble_model:
    submodel.optimizer.zero_grad()

# Chain encodings through the ensemble: only the first submodel sees the raw batch, and
# each later submodel gets the previous submodel's encodings. `.detach()` stops gradients
# from crossing submodel boundaries, so each submodel's grads come only from its own loss.
inputs = [train_data_batch]
for submodel in ensemble_model:
    inputs.append(submodel.get_encodings(inputs[-1]).detach())

lca_loss = lca_model.get_total_loss((train_data_batch, train_target_batch))
# One loss per submodel, each on that submodel's own input. `inputs` holds one entry per
# submodel plus the final encoding, so this covers ensembles of any depth instead of
# hard-coding submodels 0 and 1.
ensemble_losses = [
    ensemble_model.get_total_loss((inputs[submodel_index], train_target_batch), submodel_index)
    for submodel_index in range(len(inputs) - 1)]

lca_loss.backward()
for ensemble_loss in ensemble_losses:
    ensemble_loss.backward()

# Exact (bitwise) comparisons on purpose: both models should run identical computations.
print("lca losses are equal: ", lca_loss.cpu().detach().numpy() == ensemble_losses[0].cpu().detach().numpy())
print("lca grads are equal: ", np.all(lca_model.w.grad.cpu().numpy() == ensemble_model[0].w.grad.cpu().numpy()))

In [None]:
# One-epoch equivalence check: snapshot the LCA weights `w` of both models, train each for
# one epoch on the same loader, then confirm both moved and still agree exactly.
lca_pre_train_w = lca_model.w.cpu().detach().numpy().copy()
ensemble_pre_train_w = ensemble_model[0].w.cpu().detach().numpy().copy()

for model_to_train in (lca_model, ensemble_model):
    run_utils.train_epoch(1, model_to_train, train_loader)

lca_w = lca_model.w.cpu().detach().numpy().copy()
ensemble_w = ensemble_model[0].w.cpu().detach().numpy().copy()

# (label, result) pairs, printed in order.
weight_checks = [
    ("lca & ensemble weights equal before one epoch of training", np.all(lca_pre_train_w == ensemble_pre_train_w)),
    ("lca weights different from init after one epoch of training", not np.all(lca_pre_train_w == lca_w)),
    ("ensemble weights different from init after one epoch of training", not np.all(ensemble_pre_train_w == ensemble_w)),
    ("lca & ensemble weights equal after one epoch of training", np.all(lca_w == ensemble_w)),
]
for check_label, check_result in weight_checks:
    print(check_label, check_result)