Skip to content

Commit

Permalink
Refuctering for Circuit
Browse files Browse the repository at this point in the history
  • Loading branch information
iwawomaru committed Oct 5, 2016
1 parent df318ce commit 1851355
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 117 deletions.
6 changes: 6 additions & 0 deletions .idea/vcs.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

70 changes: 6 additions & 64 deletions examples/pong_planner_example.py
Original file line number Diff line number Diff line change
@@ -1,75 +1,17 @@
import sys,os
import sys, os
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')

from noh import Circuit
from noh.circuit import Planner, PropRule, TrainRule
from noh.components import Random, Const
from noh.environments import Pong

import numpy as np

n_stat = Pong.n_stat
n_act = Pong.n_act

component_set = []
component_set.append(Random(n_input=n_stat, n_output=n_act))
component_set.append(Const(n_input=n_stat, n_output=n_act, const_output=1))
component_set.append(Const(n_input=n_stat, n_output=n_act, const_output=2))
component_set.append(Const(n_input=n_stat, n_output=n_act, const_output=3))


class SimpleProp(PropRule):
component_id_list = range(4)
def __init__(self, components):
super(SimpleProp, self).__init__(components)
self.id = self.component_id_list.pop(0)

def __call__(self, data):
return self.components[self.id](data)

class EmplyProp(PropRule):
def __init__(self, components):
super(EmplyProp, self).__init__(components)
def __call__(self, **kwargs):
pass

class PFCPlanner(Planner):
def __init__(self, components, rule_dict={}, default_prop=None, default_train=None):
super(PFCPlanner, self).__init__(components, rule_dict, default_prop=None, default_train=None)
self.f_go = False
self.n_components = len(components)

def __call__(self, data):
if not self.f_go:
self.prop_rule = np.random.choice(self.rules.values())
self.f_go = True

""" kashikoku shitai here """
if np.random.rand() < 0.1:
self.stop()

return self.prop_rule(data)

def train(self, data=None, label=None, epoch=None):
pass

def stop(self):
self.f_go = False

def supervised_train(self, data=None, label=None, epochs=None, **kwargs): pass
def unsupervised_train(self, data=None, label=None, epochs=None, **kwargs): pass
def reinforcement_train(self, data=None, label=None, epochs=None, **kwargs): pass
from noh.components import SuppressionBoosting


if __name__ == "__main__":
prop_rules = {}
for i in xrange(4):
prop_rules["prop"+str(i)] = SimpleProp

n_stat = Pong.n_stat
n_act = Pong.n_act

model = Circuit(PFCPlanner, components=component_set, rule_dict=prop_rules,
default_prop=None, default_train=None)
model = SuppressionBoosting.create(n_stat, n_act, n_learner=4)

env = Pong(model, render=True)
while True:
env.execute()
env.execute()
2 changes: 0 additions & 2 deletions noh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
from noh import circuit

Circuit = circuit.Circuit
Planner = circuit.Planner
PropRule = circuit.PropRule
TrainRule = circuit.TrainRule

Component = component.Component
Environment = environment.Environment
76 changes: 28 additions & 48 deletions noh/circuit.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from noh.component import Component


class Collection(object):
keys = []
values = []
Expand Down Expand Up @@ -49,14 +50,14 @@ def __delitem__(self, key):
def __iter__(self):
return iter(self.keys)

def __getslice__(self, i, j):
raise NotImplementedError("To be implemented")
# def __getslice__(self, i, j):
# raise NotImplementedError("To be implemented")

def __setslice__(self, i, j, values):
raise NotImplementedError("To be implemented")
# def __setslice__(self, i, j, values):
# raise NotImplementedError("To be implemented")

def __delslice__(self, i, j):
raise NotImplementedError("To be implemented")
# def __delslice__(self, i, j):
# raise NotImplementedError("To be implemented")

def __getattr__(self, key):
return self.__getitem__(key)
Expand All @@ -71,55 +72,34 @@ def __call__(self, data):
raise NotImplementedError("`__call__` must be explicitly overridden")


class TrainRule(Collection):
def __init__(self, components):
super(TrainRule, self).__init__(components)
self.components = components

def __call__(self, data, label, epoch):
raise NotImplementedError("`__call__` must be explicitly overridden")


class Planner(object):
def __init__(self, components, rule_dict={}, default_prop=None, default_train=None):
class Circuit(Collection, Component):
def __init__(self, components, RuleClassDict, default_prop_name=None, default_train_name=None):
super(Circuit, self).__init__(components)
self.components = components
self.rules = {}
if default_prop is not None:
self.rules["prop"] = default_prop(components)
self.prop_rule = self.rules['prop']
if default_train is not None:
self.rules["train"] = default_train(components)
self.train_rule = self.rules['train']

for name in rule_dict:
Rule = rule_dict[name]
self.rules[name] = Rule(components)
self.default_prop = None
self.default_train = None

def set_prop(self, name):
self.prop_rule = self.rules[name]
for name in RuleClassDict:
RuleClass = RuleClassDict[name]
self.rules[name] = RuleClass(components)

def set_train(self, name):
self.train_rule = self.rules[name]

def __call__(self, data):
return self.prop_rule(data)

def train(self, data, label, epoch):
return self.train_rule(data, label, epoch)


class Circuit(Collection, Component):
def __init__(self, PlannerClass, components, rule_dict,
default_prop=None, default_train=None):
super(Circuit, self).__init__(components)
self.planner = PlannerClass(components, rule_dict,
default_prop, default_train)
if default_prop_name is not None:
self.set_default_prop(default_prop_name)
if default_train_name is not None:
self.set_default_train(default_train_name)

def __call__(self, data, **kwargs):
return self.planner(data)
return self.default_prop(data)

def train(self, data, label, epochs):
return self.planner.train(data, label, epochs)
return self.default_train(data, label, epochs)

def set_default_prop(self, name):
self.default_prop = self.rules[name]

def set_default_train(self, name):
self.default_train = self.rules[name]

def __getattr__(self, key):
return self.planner.rules[key]
return self.rules[key]
3 changes: 2 additions & 1 deletion noh/components/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from noh.components.random_component import Random
from noh.components.const_component import Const
from noh.components.const_component import Const
from noh.components.suppression_boosting import SuppressionBoosting
65 changes: 63 additions & 2 deletions noh/components/suppression_boosting.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,67 @@
from noh import Circuit
from noh.circuit import PropRule
from noh.components import Random, Const
import numpy as np


class SimpleProp(PropRule):
component_id_list = range(100)

def __init__(self, components):
super(SimpleProp, self).__init__(components)
self.id = self.component_id_list.pop(0)

def __call__(self, data):
return self.components[self.id](data)


class LearnerSet(Circuit):

def __init__(self, components, RuleClassDict):
super(LearnerSet, self).__init__(components, RuleClassDict)
self.n_components = len(components)
self.f_go = False

@classmethod
def create(cls, n_stat, n_act, n_learner):
component_list = [Random(n_input=n_stat, n_output=n_act)] + \
[Const(n_input=n_stat, n_output=n_act, const_output=n) for n in xrange(1, n_learner)]

PropRulesDict = {"prop"+str(i): SimpleProp for i in xrange(n_learner)}

return LearnerSet(component_list, PropRulesDict)


class PropLearner(PropRule):

name_list = []
def __init__(self, components):
super(PropLearner, self).__init__(components)

def __call__(self, data):
if not self.components["learner_set"].f_go:
self.components["learner_set"].f_go = True
self.components["learner_set"].set_default_prop(name=np.random.choice(self.name_list))

res = self.components["learner_set"](data)

""" kashikoku shitai here """
if np.random.rand() < 0.1:
self.components["learner_set"].f_go = False
return res


class SuppressionBoosting(Circuit):
def __init__(self, PlannerClass, components):
super(SuppressionBoosting, self).__init__()
def __init__(self, components, RuleClassDict):
super(SuppressionBoosting, self).__init__(components, RuleClassDict, default_prop_name="prop_learner")
self.f_go = False

@classmethod
def create(cls, n_stat, n_act, n_learner):
components = {"learner_set": LearnerSet.create(n_stat, n_act, n_learner),
"suppressor": None}
PropLearner.name_list = components["learner_set"].rules.keys()
return SuppressionBoosting(components, {"prop_learner": PropLearner})

def stop(self):
self.f_go = False

0 comments on commit 1851355

Please sign in to comment.