Solver for continued learning between runs? #1143

Open
studywolf opened this issue Aug 11, 2016 · 0 comments

@studywolf
Collaborator

So, this is super hacky, but something along these lines would be nice to have, at least as an example somewhere (after some cleanup), so you can keep learning between runs without much overhead:

```python
import os
import time

import numpy as np
import matplotlib.pyplot as plt

import nengo
from nengo.solvers import Lstsq

# one filename for both loading (solver) and saving (end of script)
WEIGHTS_FILE = 'weights.npz'


class KeepLearningSolver(Lstsq):
    """Loads decoders from a file if it exists,
    otherwise returns decoders from the Lstsq solver."""

    def __init__(self, filename, weights=False):
        super(KeepLearningSolver, self).__init__(weights=weights)
        self.filename = filename

    def __call__(self, A, Y, rng=None, E=None):
        if os.path.isfile(self.filename):
            print('Loading weights from %s' % self.filename)
            tstart = time.time()
            # the probe stores decoders as (n_samples, dimensions, n_neurons);
            # take the last sample and transpose to (n_neurons, dimensions)
            with np.load(self.filename) as data:
                weights = data['weights'][-1].T
            if weights.shape != (A.shape[1], Y.shape[1]):
                raise ValueError(
                    'Stored weights are not the correct shape for this connection.')
            # report how well the loaded decoders fit the current eval points
            rmses = np.sqrt(np.mean((Y - np.dot(A, weights)) ** 2, axis=0))
            info = {'rmses': rmses, 'time': time.time() - tstart}
        else:
            print('No weights file found, generating with Lstsq solver')
            # pass rng and E through so the fallback behaves like plain Lstsq
            weights, info = super(KeepLearningSolver, self).__call__(
                A, Y, rng=rng, E=E)

        return weights, info


model = nengo.Network(seed=1)
with model:
    # the argument to a Node function is the simulation time t
    stim = nengo.Node(lambda t: np.cos(t * 2))

    a = nengo.Ensemble(n_neurons=500, dimensions=1)
    err = nengo.Ensemble(n_neurons=1, dimensions=1,
                         neuron_type=nengo.Direct())
    output = nengo.Ensemble(n_neurons=10, dimensions=1,
                            neuron_type=nengo.Direct())

    nengo.Connection(stim, a)
    # error = output - stim**2, so PES drives the learned connection toward x**2
    nengo.Connection(stim, err, function=lambda x: x ** 2, transform=-1)
    nengo.Connection(output, err)

    learn_conn = nengo.Connection(
        a, output,
        learning_rule_type=nengo.PES(learning_rate=1e-5),
        solver=KeepLearningSolver(WEIGHTS_FILE))
    nengo.Connection(err, learn_conn.learning_rule)

    probe_input = nengo.Probe(stim)
    probe_output = nengo.Probe(output, synapse=.01)
    probe_weights = nengo.Probe(learn_conn, 'weights',
                                sample_every=5)  # in seconds

sim = nengo.Simulator(model)
sim.run(11)

# save the learned decoders so the next run starts from them
np.savez_compressed(WEIGHTS_FILE, weights=sim.data[probe_weights])

plt.plot(sim.trange(), sim.data[probe_input])
plt.plot(sim.trange(), sim.data[probe_output])
plt.legend(['Input', 'Output'])
plt.show()
```
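To check the save/load round-trip without needing nengo, here is a minimal numpy-only sketch of the same load-or-solve logic. It assumes, as the snippet above does, that the probed `'weights'` array has shape `(n_samples, dimensions, n_neurons)` and that a solver returns decoders as `(n_neurons, dimensions)`. The helper `load_or_solve`, the temporary file, and the random activity matrix are hypothetical stand-ins; plain `np.linalg.lstsq` replaces nengo's `Lstsq` as the fallback.

```python
import os
import tempfile

import numpy as np


def load_or_solve(filename, A, Y):
    """Hypothetical helper: return decoders of shape (n_neurons, dimensions).

    If `filename` exists, take the last probed sample (stored as
    (n_samples, dimensions, n_neurons)) and transpose it; otherwise
    solve A @ X = Y by ordinary least squares.
    """
    if os.path.isfile(filename):
        with np.load(filename) as data:
            X = data['weights'][-1].T
        if X.shape != (A.shape[1], Y.shape[1]):
            raise ValueError('stored weights do not match this connection')
        return X, 'loaded'
    X = np.linalg.lstsq(A, Y, rcond=None)[0]
    return X, 'solved'


# Stand-in data (assumption): 200 eval points, 50 "neurons", 1 dimension.
rng = np.random.RandomState(0)
A = rng.rand(200, 50)              # activities, (n_eval_points, n_neurons)
Y = rng.uniform(-1, 1, (200, 1))   # targets, (n_eval_points, dimensions)

filename = os.path.join(tempfile.mkdtemp(), 'weights.npz')

# First run: no file yet, so decoders come from least squares.
X1, source1 = load_or_solve(filename, A, Y)

# Pretend learning ran and a probe recorded two samples,
# each in the (dimensions, n_neurons) orientation.
probed = np.stack([X1.T, X1.T + 0.01])
np.savez_compressed(filename, weights=probed)

# Second run: the last probed sample is loaded and transposed back.
X2, source2 = load_or_solve(filename, A, Y)
print(source1, source2, X2.shape)
```

The transpose is the only non-obvious step: the stored sample is `(dimensions, n_neurons)`, while the solver has to hand back `(n_neurons, dimensions)`. The shape check catches the common mistake of reusing a weights file after changing the number of neurons.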