Skip to content

Commit

Permalink
Change Trainer.train to be a generator; this lets clients of the trai…
Browse files Browse the repository at this point in the history
…ners modify the model if needed during training.
  • Loading branch information
Leif Johnson committed Feb 7, 2014
1 parent 3312424 commit c025646
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 34 deletions.
41 changes: 26 additions & 15 deletions theanets/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def _build_trainers(self, **kwargs):
for factory in self.args.optimize:
self.add_trainer(factory, **kwargs)

def add_trainer(self, factory, **kwargs):
def add_trainer(self, factory, *args, **kwargs):
'''Add a new trainer to this experiment.
Arguments
Expand All @@ -145,21 +145,25 @@ def add_trainer(self, factory, **kwargs):
A callable that creates a Trainer instance, or a string that maps to
a Trainer constructor.
Keyword arguments are passed directly to the trainer factory.
Remaining positional and keyword arguments are passed directly to the
trainer factory.
'''
args = (self.network, )
args = (self.network, ) + args
if isinstance(factory, str):
if factory.lower() in ('cg', 'bfgs', 'newton-cg'):
if factory.lower() in trainer.Scipy.METHODS:
args = (self.network, factory)
factory = trainer.Scipy
elif factory.lower().startswith('l'):
if len(args) == 1:
# use SGD trainer by default for individual layers
args += (trainer.SGD, )
factory = trainer.Layerwise
else:
factory = {
'hf': trainer.HF,
'layer': trainer.Layerwise,
'layerwise': trainer.Layerwise,
'sample': trainer.Sample,
'sgd': trainer.SGD,
}[factory.lower()]
factory = dict(
hf=trainer.HF,
sample=trainer.Sample,
sgd=trainer.SGD
)[factory.lower()]
kw = {}
kw.update(self.kwargs)
kw.update(kwargs)
Expand Down Expand Up @@ -195,7 +199,13 @@ def add_dataset(self, label, dataset, **kwargs):
self.datasets[label] = Dataset(*dataset, **kwargs)

def run(self, train=None, valid=None):
'''Run this experiment by training (and validating) a network.
'''Run this experiment by training and validating our network.
'''
for _ in self.train(train=train, valid=valid):
pass

def train(self, train=None, valid=None):
'''Train (and validate) our network.
Before calling this method, datasets will typically need to have been
added to the experiment by calling add_dataset(...). However, as a
Expand All @@ -214,9 +224,10 @@ class will contain the trained network parameters.
if valid is not None and 'valid' not in self.datasets:
self.add_dataset('valid', valid)
for trainer in self.trainers:
trainer.train(train_set=self.datasets['train'],
valid_set=self.datasets['valid'],
cg_set=self.datasets['cg'])
for _ in trainer.train(train_set=self.datasets['train'],
valid_set=self.datasets['valid'],
cg_set=self.datasets['cg']):
yield

def save(self, path):
'''Save the parameters in the network to a pickle file on disk.
Expand Down
48 changes: 29 additions & 19 deletions theanets/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __init__(self, network, **kwargs):
self.best_params = [p.get_value().copy() for p in self.params]

def flat_to_arrays(self, x):
x = x.astype(self.dtype)
start = 0
arrays = []
for shape, count in zip(self.shapes, self.counts):
Expand All @@ -94,9 +95,6 @@ def update_params(self, targets):
for param, target in zip(self.params, targets):
param.set_value(target)

def train(self, train_set, valid_set=None, **kwargs):
raise NotImplementedError

def evaluate(self, iteration, valid_set):
costs = np.mean([self.f_eval(*x) for x in valid_set], axis=0)
improvement = self.best_cost - costs[0] > self.best_cost * self.min_improvement
Expand All @@ -108,12 +106,15 @@ def evaluate(self, iteration, valid_set):
marker = ' *'
cost_desc = ' '.join(
'%s=%.4f' % el for el in zip(self.cost_names, costs))
logging.info('SGD %i -- valid %s%s', iteration + 1, cost_desc, marker)
logging.info('validation %i %s%s', iteration + 1, cost_desc, marker)
if iteration - self.best_iter > self.patience:
raise PatienceElapsedError
if not improvement:
raise NoImprovementError

def train(self, train_set, valid_set=None, **kwargs):
raise NotImplementedError


class SGD(Trainer):
'''Stochastic gradient descent network trainer.'''
Expand Down Expand Up @@ -149,9 +150,9 @@ def train(self, train_set, valid_set=None, **kwargs):
costs = []
grads = []
try:
for costs_, grads_ in learn(train_set, velocities):
costs.append(costs_)
grads.append(grads_)
for c, g in learn(train_set, velocities):
costs.append(c)
grads.append(g)
except KeyboardInterrupt:
logging.info('interrupted!')
break
Expand All @@ -160,12 +161,14 @@ def train(self, train_set, valid_set=None, **kwargs):
'%s=%.4f' % el for el in
zip(self.cost_names, np.mean(costs, axis=0)))
grad_desc = ' '.join(
'%s=%.4f' % (p.name, x) for p, x in
'%s=%.2f' % (p.name, x) for p, x in
zip(self.params, np.mean(grads, axis=0)))
logging.info('SGD %i/%i @%.2e -- train %s -- grad %s',
logging.info('SGD %i/%i @%.2e %s (grad %s)',
i + 1, self.iterations, self.learning_rate,
cost_desc, grad_desc)

yield

self.update_params(self.best_params)

def _nag(self, train_set, velocities):
Expand Down Expand Up @@ -271,6 +274,8 @@ def _apply_delta(self, param, delta):
class Scipy(Trainer):
'''General trainer for neural nets using `scipy.optimize.minimize`.'''

METHODS = ('bfgs', 'cg', 'dogleg', 'newton-cg', 'trust-ncg')

def __init__(self, network, method, **kwargs):
super(Scipy, self).__init__(network, **kwargs)

Expand Down Expand Up @@ -303,21 +308,26 @@ def train(self, train_set, valid_set=None, **kwargs):
break
except NoImprovementError:
pass

try:
res = scipy.optimize.minimize(
fun=self.function_at,
jac=self.gradient_at,
x0=self.arrays_to_flat(self.best_params),
args=(train_set, ),
method=self.method,
options=dict(maxiter=self.validation_frequency),
options=dict(maxiter=self.validation_frequency, disp=1),
)
except KeyboardInterrupt:
logging.info('interrupted!')
break

logging.info('scipy %s %i/%i J=%.4f', self.method, i + 1, self.iterations, res.fun)
for p, a in zip(self.params, self.flat_to_arrays(res.x)):
p.set_value(a)

yield

self.update_params(self.best_params)


Expand Down Expand Up @@ -373,6 +383,7 @@ def __init__(self, network, **kwargs):
def train(self, train_set, valid_set=None, **kwargs):
self.update_params(self.opt.train(
train_set, kwargs['cg_set'], validation=valid_set, **self.kwargs))
yield


class Sample(Trainer):
Expand Down Expand Up @@ -424,6 +435,8 @@ def train(self, train_set, valid_set=None, **kwargs):
w.set_value(arr)
samples = ifci(self.network.feed_forward(first(t))[i-1] for t in train_set)

yield


class Layerwise(Trainer):
'''This trainer adapts parameters using a variant of layerwise pretraining.
Expand All @@ -442,8 +455,10 @@ class Layerwise(Trainer):
Network instances.
'''

def __init__(self, network, **kwargs):
def __init__(self, network, factory, *args, **kwargs):
self.network = network
self.factory = factory
self.args = args
self.kwargs = kwargs

def train(self, train_set, valid_set=None, **kwargs):
Expand All @@ -460,14 +475,9 @@ def train(self, train_set, valid_set=None, **kwargs):
self.network.hiddens = hiddens[:i]
self.network.weights = weights[:i] + [W]
self.network.biases = biases[:i] + [b]
SGD(self.network, **self.kwargs).train(train_set, valid_set)
self.network.save('/tmp/layerwise-%s-h%f-n%f-d%f-w%f-%d.pkl.gz' % (
','.join(map(str, self.kwargs['layers'])),
self.kwargs['hidden_l1'],
self.kwargs['input_noise'],
self.kwargs['hidden_dropouts'],
self.kwargs['weight_l1'],
i))
trainer = self.factory(self.network, *self.args, **self.kwargs)
for _ in trainer.train(train_set, valid_set):
yield

self.network.y = y
self.network.hiddens = hiddens
Expand Down

0 comments on commit c025646

Please sign in to comment.