In [7]:
%%writefile mouseworld/weight_saver.py

from nengo.solvers import Solver
import nengo
import numpy as np

# loads a decoder from a file, defaulting to zero if it doesn't exist
class LoadFrom(Solver):
    """Solver that returns a matrix read from a ``.npy`` file instead of solving.

    If the file cannot be opened, an all-zero matrix of the expected shape is
    returned instead, so a model can be built before anything has been saved.

    Args:
        filename: Path handed to ``np.load``.
        weights: Forwarded unchanged to the base ``Solver`` constructor. When
            ``self.weights`` is truthy the expected number of columns is taken
            from ``E`` rather than from ``Y``.
    """

    def __init__(self, filename, weights=False):
        super(LoadFrom, self).__init__(weights=weights)
        self.filename = filename

    def __call__(self, A, Y, rng=None, E=None):
        """Return ``(matrix, info)`` where ``info`` is always an empty dict.

        Args:
            A: 2-D array; its second dimension gives the number of rows expected.
            Y: 2-D array; its second dimension gives the number of columns
                expected when ``self.weights`` is falsy.
            rng: Accepted for interface compatibility; not used.
            E: 2-D array; its second dimension gives the number of columns
                expected when ``self.weights`` is truthy.

        Raises:
            ValueError: If ``self.weights`` is truthy but ``E`` is ``None``, or
                if the stored matrix does not have the expected shape.
        """
        if self.weights:
            if E is None:
                # Fail with a readable message instead of an AttributeError
                # on ``None.shape``.
                raise ValueError(
                    "LoadFrom(weights=True) needs the matrix E to determine "
                    "the expected shape, but E is None")
            shape = (A.shape[1], E.shape[1])
        else:
            shape = (A.shape[1], Y.shape[1])

        try:
            value = np.load(self.filename)
        except IOError:
            # Nothing readable at that path: fall back to zeros.
            return np.zeros(shape), {}

        # Explicit check rather than ``assert`` so it is not stripped when
        # Python runs with -O (which would silently return a wrong-shaped matrix).
        if value.shape != shape:
            raise ValueError(
                "matrix stored in %r has shape %r but %r was expected"
                % (self.filename, value.shape, shape))
        return value, {}

# helper to create the LoadFrom solver and the needed probe and do the saving
class WeightSaver(object):
    """Attach a 'weights' probe to a connection and dump its last sample to disk.

    Args:
        connection: Connection to probe; its ``pre`` must be a ``nengo.Ensemble``.
        filename: Target file name; ``'.npy'`` is appended when it is missing.
        sample_every: Passed straight through to ``nengo.Probe``.
        weights: Currently unused, because the line that would install a
            ``LoadFrom`` solver on the connection is commented out below.
    """

    def __init__(self, connection, filename, sample_every=None, weights=False):
        assert isinstance(connection.pre, nengo.Ensemble)
        # Normalise the extension up front so ``self.filename`` is exactly the
        # name that ``save`` writes to.
        self.filename = filename if filename.endswith('.npy') else filename + '.npy'
        # Disabled: installing the file-backed solver on the connection.
        #connection.solver = LoadFrom(self.filename, weights=weights)
        self.probe = nengo.Probe(connection, 'weights', sample_every=sample_every)
        self.connection = connection

    def save(self, sim):
        """Write the transpose of the most recent probed sample to ``self.filename``.

        Args:
            sim: Object whose ``data`` mapping is indexed by ``self.probe``.
        """
        latest_sample = sim.data[self.probe][-1]
        np.save(self.filename, latest_sample.T)

Overwriting mouseworld/weight_saver.py


In [1]:
%%writefile mouseworld/input_manager.py

class Input_manager(object):
    """Small mutable holder for a state label and a numeric value.

    A class is used simply because it is the easiest way in Python to keep
    state between calls.
    """

    def __init__(self):
        # Initial label is 'wait' and the initial value is 0.
        self.state, self.value = 'wait', 0

    # Inactive draft kept for reference: maps a state name to a 3-element list.
#     def set_state(x):
#             if x == 'wait' :
#                 current_state = [-1,-1,-1]
#             if x == 'search' :
#                 current_state = [0,-1,-1]
#             if x == 'approach' :
#                 current_state = [-1,0,-1]
#             if x == 'avoid' :
#                 current_state = [-1,-1,0]
#             return current_state

    def modify_value(self, modifyer):
        """Add ``modifyer`` to the stored value in place.

        The ``value`` attribute may also be overwritten directly.
        """
        self.value += modifyer

    def return_state(self, t):
        """Return the current state label; ``t`` is accepted but ignored."""
        return self.state

    def return_value(self, t):
        """Return the current value; ``t`` is accepted but ignored."""
        return self.value
    
        

Overwriting mouseworld/input_manager.py


In [2]:
%%writefile mouseworld/mydatacollector.py

from mesa.datacollection import DataCollector

class MyDataCollector(DataCollector):
    ## Subclass of DataCollector whose collect() takes the schedule explicitly,
    ## so the set of agents that get recorded can be restricted.
    ## As written, every agent in `schedule.agents` is recorded; to report only
    ## some of them (e.g. agents whose `alive` attribute is False), add a
    ## condition at the marked place in collect().
    def __init__(self, model_reporters=None, agent_reporters=None, tables=None):
        """ Instantiate a DataCollector with lists of model and agent reporters.
        Both model_reporters and agent_reporters accept a dictionary mapping a
        variable name to a method used to collect it.
        For example, if there was only one model-level reporter for number of
        agents, it might look like:
            {"agent_count": lambda m: m.schedule.get_agent_count() }
        If there was only one agent-level reporter (e.g. the agent's energy),
        it might look like this:
            {"energy": lambda a: a.energy}
        The tables arg accepts a dictionary mapping names of tables to lists of
        columns. For example, if we want to allow agents to write their age
        when they are destroyed (to keep track of lifespans), it might look
        like:
            {"Lifespan": ["unique_id", "age"]}
        Args:
            model_reporters: Dictionary of reporter names and functions.
            agent_reporters: Dictionary of reporter names and functions.
            tables: Dictionary of table names and lists of column names.
        Omitting an argument, or passing None, is the same as passing an
        empty dictionary.
        """
        self.model_reporters = {}
        self.agent_reporters = {}

        self.model_vars = {}
        self.agent_vars = {}
        self.tables = {}

        # None sentinels replace the former mutable `{}` defaults, which were
        # shared between every call of this constructor.
        for name, func in (model_reporters or {}).items():
            self._new_model_reporter(name, func)

        for name, func in (agent_reporters or {}).items():
            self._new_agent_reporter(name, func)

        for name, columns in (tables or {}).items():
            self._new_table(name, columns)


    def collect(self, model, schedule):
        """ Collect all the data for the given model object.

        Args:
            model: Object passed to every model-level reporter.
            schedule: Object whose `agents` attribute is iterated; each agent
                must expose `unique_id`.
        """
        # Looping over an empty dict is a no-op, so no emptiness guard is needed.
        for var, reporter in self.model_reporters.items():
            self.model_vars[var].append(reporter(model))

        for var, reporter in self.agent_reporters.items():
            # To record only agents meeting a certain condition, append an
            # `if <condition>` clause to this comprehension.
            agent_records = [(agent.unique_id, reporter(agent))
                             for agent in schedule.agents]
            self.agent_vars[var].append(agent_records)
                
## When I define the datacollector for my model, I use MyDataCollector rather than the default DataCollector

Overwriting mouseworld/mydatacollector.py
