forked from tbepler/harness
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
45 lines (39 loc) · 1.5 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os
import math
import cPickle as pickle
import numpy as np
class Loader(object):
    """Load a model either from Python source or from a pickle file.

    Register the ``--model-inputs`` / ``--model-outputs`` options on an
    argument parser with :meth:`argparse`, then call the instance with a
    model path to obtain ``(model, name, epoch)``.
    """

    def __init__(self):
        pass

    def argparse(self, parser):
        """Add the model input/output size options to *parser*.

        Both options default to 0. A value of 0 or less means the size passed
        to :meth:`__call__` as ``n_in`` / ``n_out`` is used unchanged.
        """
        parser.add_argument('--model-inputs', dest='model_inputs', type=int, default=0, help='number of inputs to the model, 0 or less mean infer from data (default: 0)')
        parser.add_argument('--model-outputs', dest='model_outputs', type=int, default=0, help='number of outputs from the model, 0 or less mean infer from data (default: 0)')

    def __call__(self, path, args, n_in, n_out):
        """Load the model stored at *path*.

        Args:
            path: a ``.py`` file (extension matched case-insensitively) that
                is handed to ``load_from_source``; any other extension is
                treated as a pickled model and unpickled.
            args: parsed namespace providing ``model_inputs`` and
                ``model_outputs`` (as registered by :meth:`argparse`).
            n_in: default number of model inputs; replaced by
                ``args.model_inputs`` when that is positive.
            n_out: default number of model outputs; replaced by
                ``args.model_outputs`` when that is positive.
                Both sizes are only used on the ``.py`` path; a pickled
                model is returned as stored.

        Returns:
            ``(model, name, epoch)``, where *name* and *epoch* are parsed
            from the file name by ``model_name_epoch``.
        """
        ext = os.path.splitext(path)[-1].lower()
        if args.model_inputs > 0:
            n_in = args.model_inputs
        if args.model_outputs > 0:
            n_out = args.model_outputs
        if ext == '.py':  # need to build the model from a python file
            m = load_from_source(path, n_in, n_out)
        else:  # unpickle model from binary file
            # Pickles are binary data: text mode corrupts them on Windows
            # (Py2) and is rejected outright by pickle.load on Py3.
            # NOTE: unpickling can execute arbitrary code -- only load model
            # files from trusted sources.
            with open(path, 'rb') as f:
                m = pickle.load(f)
        name, epoch = model_name_epoch(path)
        return m, name, epoch
def load_from_source(path, n_in, n_out):
    """Build a model by executing the Python source file at *path*.

    The file is loaded as a module named ``"model"`` (registered in
    ``sys.modules`` under that name, so it may shadow a module of the same
    name). It must define a callable ``model(n_in, n_out)``, whose return
    value is returned.

    ``imp.load_source`` is used when available. On interpreters where the
    ``imp`` module no longer exists (Python 3.12+), the equivalent
    ``importlib`` source-file recipe is used instead.
    """
    try:
        import imp
    except ImportError:  # `imp` was removed in Python 3.12
        imp = None
    if imp is not None:
        module = imp.load_source("model", path)
    else:
        import importlib.machinery
        import importlib.util
        import sys
        # An explicit SourceFileLoader mirrors imp.load_source: the file is
        # treated as Python source regardless of its extension's case.
        loader = importlib.machinery.SourceFileLoader("model", path)
        spec = importlib.util.spec_from_file_location("model", path, loader=loader)
        module = importlib.util.module_from_spec(spec)
        sys.modules["model"] = module
        loader.exec_module(module)
    return module.model(n_in, n_out)
def model_name_epoch(path):
    """Split a model file path into ``(name, epoch)``.

    The file's base name (directory and extension stripped) is split on
    ``'epoch'``. When an epoch tag is present, e.g. ``mlp_epoch12_loss0.3.pkl``,
    the text before the first ``'epoch'`` minus its last character (the
    separator) is the name. The ``'_'``-delimited token right after it is the
    epoch: ``('mlp', 12)``.

    Names without ``'epoch'``, or where the token after it is not an integer
    (e.g. ``epochnet.py``, ``my_epoch_model.py``), are returned whole with
    epoch 0 instead of raising ``ValueError``.
    """
    base = os.path.splitext(os.path.basename(path))[0]
    parts = base.split('epoch')
    if len(parts) > 1:
        try:
            epoch = int(parts[1].split('_')[0])
        except ValueError:
            # 'epoch' is part of the model's name rather than an epoch tag.
            return base, 0
        return parts[0][:-1], epoch
    return base, 0