Skip to content

Commit

Permalink
Merge pull request #107 from tomalrussell/feature/sector_model_access…
Browse files Browse the repository at this point in the history
…_timesteps

Give SectorModels access to timesteps
  • Loading branch information
tomalrussell authored Nov 7, 2017
2 parents 81cd8b6 + 08fba3a commit dca12b0
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 4 deletions.
3 changes: 2 additions & 1 deletion smif/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,8 @@ def get_model_run_definition(args):

sector_model_builder = SectorModelBuilder(sector_model_config['name'])
LOGGER.debug("Sector model config: %s", sector_model_config)
sector_model_builder.construct(sector_model_config)
sector_model_builder.construct(sector_model_config,
model_run_config['timesteps'])
sector_model_object = sector_model_builder.finish()

sector_model_objects.append(sector_model_object)
Expand Down
1 change: 1 addition & 0 deletions smif/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(self, name):

self.regions = get_region_register()
self.intervals = get_interval_register()
self.timesteps = []

self.logger = getLogger(__name__)

Expand Down
5 changes: 3 additions & 2 deletions smif/model/sector_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@
approaches
"""
import importlib
import logging
import os
from abc import ABCMeta, abstractmethod
from collections import defaultdict

import importlib
from smif import StateData
from smif.convert.area import get_register as get_region_register
from smif.convert.interval import get_register as get_interval_register
Expand Down Expand Up @@ -304,7 +304,7 @@ def __init__(self, name, sector_model=None):
self.region_register = get_region_register()
self.logger = logging.getLogger(__name__)

def construct(self, sector_model_config):
def construct(self, sector_model_config, timesteps):
"""Constructs the sector model
Arguments
Expand All @@ -315,6 +315,7 @@ def construct(self, sector_model_config):
self.load_model(sector_model_config['path'], sector_model_config['classname'])
self._sector_model.name = sector_model_config['name']
self._sector_model.description = sector_model_config['description']
self._sector_model.timesteps = timesteps

self.add_inputs(sector_model_config['inputs'])
self.add_outputs(sector_model_config['outputs'])
Expand Down
4 changes: 3 additions & 1 deletion tests/model/test_sector_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,11 @@ def test_path_not_found(self):
def test_build_from_config(self, get_sector_model_config):
config = get_sector_model_config
builder = SectorModelBuilder('test_sector_model')
builder.construct(config)
timesteps = [2015, 2020]
builder.construct(config, timesteps)
sector_model = builder.finish()
assert sector_model.name == 'water_supply'
assert sector_model.timesteps == timesteps

actual = sector_model.as_dict()

Expand Down

0 comments on commit dca12b0

Please sign in to comment.