In [1]:
import os
import pickle
import sys
import time

# Move into the project directory on the mounted Drive before importing the
# project package (presumably `framework` is resolved relative to the working
# directory -- keep the chdir ahead of those imports). The working directory
# is printed before and after the move.
print(os.getcwd())
os.chdir(os.path.join('/', 'content', 'drive', 'MyDrive', 'eqp_revised'))
print(os.getcwd())

import numpy as np
import torch
from matplotlib import pyplot as plt

from framework import eqp
from framework import datasets

# Output directory for this run; the log file below is written inside it.
filepath = os.path.join(os.getcwd(), 'results', 'mnist_verification_3layerSW')
# Create the directory on first use: without it the open() below raises
# FileNotFoundError on a fresh checkout.
os.makedirs(filepath, exist_ok=True)

# From here on print() output goes to the run log instead of the notebook.
# The original stream is kept in old_stdout so it can be restored later.
old_stdout = sys.stdout
log_file = open(os.path.join(filepath, 'log.txt'), 'w')
sys.stdout = log_file

/content
/content/drive/MyDrive/eqp_revised


In [None]:
# Settings describing the network layout. The first layer has 28**2 = 784
# units (presumably one per pixel of a flattened 28x28 MNIST image), followed
# by three layers of 500 units and a final layer of 10 units.
topology = {
    'layer sizes': [28 * 28] + [500] * 3 + [10],
    'network type': 'SW_intra',
    'bypass p': 0.0756,
    'bypass mag': 0.05,
}

# Settings for the training procedure.
hyperparameters = {
    'learning rate': 0.05,
    'epsilon': 0.5,
    'beta': 1.0,
    'free iterations': 500,
    'weakly clamped iterations': 8,
}

# Settings for the runtime environment: batch size, compute device, RNG seed.
configuration = {
    'batch size': 20,
    'device': 'cuda',
    'seed': 0,
}

# Histories collected over the run; each starts out as its own empty list.
training_errors, test_errors = [], []
per_layer_rates, correction_matrices = [], []

# Run length and sampling periods (in epochs) for the two diagnostics above.
n_epochs = 250
rate_period = 1
correction_period = 10

# Build the network from the three settings dicts and the MNIST dataset.
Network = eqp.Network(topology, hyperparameters, configuration, datasets.MNIST)

# NumPy snapshot of the weights before any training has happened.
initial_W = Network.W.clone().cpu().squeeze().numpy()
# NOTE(review): this "mask" is a second copy of Network.W, not of a separate
# mask tensor -- looks like a copy-paste slip; confirm which attribute was
# intended before relying on 'initial mask'.
initial_W_mask = Network.W.clone().cpu().squeeze().numpy()

# Persist the full run description alongside the initial weights.
with open(os.path.join(filepath, 'init.pickle'), 'wb') as F:
    pickle.dump(
        {
            'topology': topology,
            'hyperparameters': hyperparameters,
            'configuration': configuration,
            'dataset': Network.dataset.name,
            'training parameters': {
                'number of epochs': n_epochs,
                'rate period': rate_period,
                'correction_period': correction_period,
            },
            'initial weight': initial_W,
            'initial mask': initial_W_mask,
        },
        F,
    )

# Train for n_epochs epochs, logging progress and writing one pickle per
# epoch, then save the final weights. The whole run sits inside try/finally
# so that stdout is restored and the log file is closed even when training
# raises part-way through (otherwise later print() calls in the notebook
# would keep going to a dangling log handle).
try:
    for epoch_idx in range(n_epochs):
        print('Starting epoch %d.'%(epoch_idx+1))
        t0 = time.time()
        Network.train_epoch()
        Network.calculate_test_error()
        training_errors.append(Network.training_error)
        test_errors.append(Network.test_error)

        # Decide once per epoch which diagnostics are due.
        record_rates = (epoch_idx % rate_period == 0)
        record_correction = (epoch_idx % correction_period == 0)

        epoch_rates = None
        if record_rates:
            # One scalar per inter-layer connection: the norm of dW restricted
            # to that connection, scaled by 1/sqrt of the connection's L1 norm
            # (presumably conn is a 0/1 mask, making this an RMS-style
            # magnitude -- TODO confirm against the framework).
            epoch_rates = [
                float(torch.norm((Network.dW*conn)/torch.sqrt(torch.norm(conn, p=1))).cpu())
                for conn in Network.interlayer_connections
            ]
            per_layer_rates.append(epoch_rates)

        epoch_correction = None
        if record_correction:
            # NumPy snapshot of the current dW.
            epoch_correction = Network.dW.clone().cpu().squeeze().numpy()
            correction_matrices.append(epoch_correction)

        print('\tDone.')
        print('\tTime taken:', (time.time()-t0))
        print('\tTraining error:', Network.training_error)
        print('\tTest error:', Network.test_error)
        # Push the epoch's log lines to disk so progress survives a crash or
        # a disconnected runtime.
        sys.stdout.flush()

        # Per-epoch pickle; file names use the zero-based epoch index.
        # NOTE(review): W, B, s and persistent_particles are pickled as the
        # Network's live objects (no .cpu()/.numpy() conversion, unlike the
        # other snapshots) -- confirm the loader can handle that.
        with open(os.path.join(filepath, 'e%d.pickle'%(epoch_idx)), 'wb') as F:
            pickle.dump({
                    'training error': Network.training_error,
                    'test error': Network.test_error,
                    'per-layer rates': epoch_rates,
                    'correction matrix': epoch_correction,
                    'weight matrix': Network.W,
                    'biases': Network.B,
                    'states': Network.s,
                    'persistent particles': Network.persistent_particles}, F)

    # NumPy snapshots of the trained network.
    final_W = Network.W.clone().cpu().squeeze().numpy()
    # NOTE(review): as with the initial snapshot, this "mask" is a second copy
    # of Network.W rather than of a separate mask tensor -- confirm intent.
    final_W_mask = Network.W.clone().cpu().squeeze().numpy()
    mean_dW = Network.mean_dW.clone().cpu().squeeze().numpy()

    with open(os.path.join(filepath, 'final.pickle'), 'wb') as F:
        pickle.dump({
                'final weight': final_W,
                'final mask': final_W_mask,
                'mean dW': mean_dW}, F)
finally:
    # Always hand stdout back to the notebook and release the log file.
    sys.stdout = old_stdout
    log_file.close()
