Skip to content

Commit

Permalink
Merge pull request #124 from bartvm/more_fixes
Browse files Browse the repository at this point in the history
Make Cost and CostMatrix abstract classses
  • Loading branch information
bartvm committed Jan 21, 2015
2 parents c152a8f + 181774e commit d7aa7cd
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 9 deletions.
16 changes: 14 additions & 2 deletions blocks/bricks/cost.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
from abc import ABCMeta, abstractmethod

import theano
from theano import tensor
from six import add_metaclass

from blocks.bricks import application, Brick


floatX = theano.config.floatX


@add_metaclass(ABCMeta)
class Cost(Brick):
pass
@abstractmethod
@application
def apply(self, y, y_hat):
pass


@add_metaclass(ABCMeta)
class CostMatrix(Cost):
"""Base class for costs which can be calculated element-wise.
Expand All @@ -22,6 +29,11 @@ def apply(self, y, y_hat):
return self.cost_matrix.application_method(
self, y, y_hat).sum(axis=1).mean()

@abstractmethod
@application
def cost_matrix(self, y, y_hat):
pass


class BinaryCrossEntropy(CostMatrix):
@application
Expand Down
2 changes: 1 addition & 1 deletion blocks/bricks/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def __init__(self, prototype, **kwargs):
super(Bidirectional, self).__init__(**kwargs)
self.prototype = prototype

self.children = [copy.deepcopy(prototype) for i in range(2)]
self.children = [copy.deepcopy(prototype) for _ in range(2)]
self.children[0].name = 'forward'
self.children[1].name = 'backward'

Expand Down
2 changes: 1 addition & 1 deletion blocks/bricks/sequence_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ def _push_allocation_config(self):
self.lookup.dim = self.feedback_dim

@application
def feedback(self, outputs, **kwargs):
def feedback(self, outputs):
assert self.output_dim == 0
return self.lookup.lookup(outputs)

Expand Down
2 changes: 2 additions & 0 deletions blocks/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class Dataset(object):
simultaneously.
"""
sources = None

def __init__(self, sources=None):
if sources is not None:
if not all(source in self.sources for source in sources):
Expand Down
2 changes: 1 addition & 1 deletion blocks/datasets/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def read_mnist_labels(filename):
"""
with open(filename, 'rb') as f:
magic, number = struct.unpack('>ii', f.read(8))
magic, _ = struct.unpack('>ii', f.read(8))
if magic != MNIST_LABEL_MAGIC:
raise ValueError("Wrong magic number reading MNIST label file")
array = numpy.fromfile(f, dtype='uint8')
Expand Down
6 changes: 3 additions & 3 deletions blocks/groundhog.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ def inputs(self):

@property
def valid_costs(self):
return ['cost'] + [name for name, var in self.properties]
return ['cost'] + [name for name, _ in self.properties]

def validate(self, data):
valid_names = self.valid_costs
valid_vars = [self.train_cost] + [var for name, var in self.properties]
valid_vars = [self.train_cost] + [var for _, var in self.properties]

sums = numpy.zeros((len(valid_vars),))
num_batches = 0
Expand Down Expand Up @@ -115,7 +115,7 @@ def load(self, path):

class GroundhogState(object):
"""Good default values for groundhog state."""
def __init__(self, prefix, batch_size, learning_rate, **kwargs):
def __init__(self, prefix, batch_size, learning_rate):
self.prefix = prefix
self.bs = batch_size
self.lr = learning_rate
Expand Down
2 changes: 1 addition & 1 deletion blocks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def repr_attrs(instance, *attrs):
orig_repr_template += '>'
try:
return repr_template.format(instance, id(instance))
except:
except Exception:
return orig_repr_template.format(instance, id(instance))


Expand Down

0 comments on commit d7aa7cd

Please sign in to comment.