Skip to content

Commit

Permalink
refactor agen to use plugins
Browse files Browse the repository at this point in the history
  • Loading branch information
raggledodo committed Dec 14, 2018
1 parent 19da982 commit 16b3c90
Show file tree
Hide file tree
Showing 14 changed files with 464 additions and 521 deletions.
2 changes: 0 additions & 2 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@ workspace(name = "com_github_mingkaic_tenncor")
# local dependencies

load("//:tenncor.bzl", "dependencies")

dependencies()

# test dependencies

load("@cppkg//:gtest.bzl", "gtest_repository")

gtest_repository(name = "gtest")
20 changes: 16 additions & 4 deletions age/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,18 @@ py_library(
srcs = glob(["templates/*.py"]),
)

py_library(
name = "age_generator",
srcs = glob(["generator/*.py"]),
deps = ["//age:age_tmpl"],
)

######### GENERATOR #########

py_binary(
name = "agen",
srcs = ["agen.py"],
deps = ["//age:age_tmpl"],
deps = ["//age:age_generator"],
visibility = ["//visibility:public"],
)

Expand All @@ -46,6 +52,12 @@ py_test(
deps = ["//age:age_tmpl"],
)

py_binary(
name = "cagen",
srcs = ["test/cagen.py", "test/capi_plugin.py"],
deps = ["//age:age_generator"],
)

genrule(
name = "generated_mock",
srcs = ["test/mock.json"],
Expand All @@ -61,10 +73,10 @@ genrule(
"generated/opmap.hpp",
"generated/opmap.cpp",
],
tools = ["//age:agen"],
cmd = "$(location //age:agen) " +
tools = ["//age:cagen"],
cmd = "$(location //age:cagen) " +
"--cfg $(location :test/mock.json) --out $(@D)/generated " +
"--strip_prefix=$$(dirname $(@D)) --gen_capi",
"--strip_prefix=$$(dirname $(@D))",
)

cc_test(
Expand Down
205 changes: 8 additions & 197 deletions age/agen.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,189 +5,16 @@
import os.path
import sys

import age.templates.api_tmpl as api
import age.templates.capi_tmpl as capi
import age.templates.codes_tmpl as codes
import age.templates.grader_tmpl as grader
import age.templates.opera_tmpl as opera
import age.templates.template as template
from age.generator.generate import generate

prog_description = 'Generate c++ glue layer mapping ADE and some data-processing library.'
hdr_postfix = ".hpp"
src_postfix = ".cpp"

api_filename = "api"
codes_filename = "codes"
grader_filename = "grader"
opera_filename = "opmap"
runtime_filename = "runtime"

class Fields:
def __init__(self, fields):
self.fields = fields

def unmarshal_json(self, jobj):
outs = {}
for field in self.fields:
outtype, outholder = self.fields[field]
if field not in jobj:
continue
entry = jobj[field]
gottype = type(entry)
if gottype != outtype:
raise Exception("cannot read {} of type {} as type {}".format(\
field, gottype.__name__, outtype.__name__))
if str == type(outholder):
outs[outholder] = entry
else:
outs.update(outholder.unmarshal_json(entry))
return outs

root = Fields({
"opcodes": (dict, "opcodes"),
"dtypes": (dict, "dtypes"),
"data": (dict, Fields({
"sum": (unicode, "sum"),
"prod": (unicode, "prod"),
"data_in": (unicode, "data_in"),
"data_out": (unicode, "data_out"),
"scalarize": (unicode, "scalarize"),
})),
"apis": (list, "apis")
})

def parse(cfg_str):
args = json.loads(cfg_str)
if type(args) != dict:
raise Exception("cannot parse non-root object {}".format(cfg_str))

if 'includes' in args:
includes = args['includes']
if dict != type(includes):
raise Exception(\
"cannot read include of type {} as type dict".format(\
type(includes).__name__))
else:
includes = {}
return root.unmarshal_json(args), includes

def format_include(includes):
return '\n'.join(["#include " + include for include in includes]) + '\n\n'

def make_dir(fields, includes, includepath, gen_capi):
opcodes = fields["opcodes"]

code_fields = {
"opcodes": opcodes.keys(),
"dtypes": fields["dtypes"]
}
codes_header = codes_filename + hdr_postfix
codes_source = codes_filename + src_postfix
codes_hdr_path = os.path.join(includepath, codes_header)

codes_header_include = ["<string>"]
codes_source_include = [
"<unordered_map>",
'"logs/logs.hpp"',
'"' + codes_hdr_path + '"',
]
if codes_header in includes:
codes_header_include += includes[codes_header]
if codes_source in includes:
codes_source_include += includes[codes_source]

api_header = api_filename + hdr_postfix
api_source = api_filename + src_postfix
api_hdr_path = os.path.join(includepath, api_header)

api_header_include = [
'"bwd/grader.hpp"',
]
api_source_include = [
'"' + codes_hdr_path + '"',
'"' + api_hdr_path + '"',
]
if api_header in includes:
api_header_include += includes[api_header]
if api_source in includes:
api_source_include += includes[api_source]

grader_fields = {
"sum": fields["sum"],
"prod": fields["prod"],
"scalarize": fields["scalarize"],
"grads": {code: opcodes[code]["derivative"] for code in opcodes}
}
grader_header = grader_filename + hdr_postfix
grader_source = grader_filename + src_postfix
grader_hdr_path = os.path.join(includepath, grader_header)

grader_header_include = [
'"bwd/grader.hpp"',
'"' + codes_hdr_path + '"',
]
grader_source_include = [
'"' + codes_hdr_path + '"',
'"' + api_hdr_path + '"',
'"' + grader_hdr_path + '"',
]
if grader_header in includes:
grader_header_include += includes[grader_header]
if grader_source in includes:
grader_source_include += includes[grader_source]

opera_fields = {
"data_out": fields["data_out"],
"data_in": fields["data_in"],
"types": fields["dtypes"],
"ops": {code: opcodes[code]["operation"] for code in opcodes}
}
opera_header = opera_filename + hdr_postfix
opera_source = opera_filename + src_postfix
opera_hdr_path = os.path.join(includepath, opera_header)

opera_header_include = [
'"ade/functor.hpp"',
'"' + codes_hdr_path + '"',
]
opera_source_include = ['"' + opera_hdr_path + '"']
if opera_header in includes:
opera_header_include += includes[opera_header]
if opera_source in includes:
opera_source_include += includes[opera_source]

out = [
(api_header,
format_include(api_header_include) + api.header.repr(fields)),
(api_source,
format_include(api_source_include) + api.source.repr(fields)),
(codes_header,
format_include(codes_header_include) + codes.header.repr(code_fields)),
(codes_source,
format_include(codes_source_include) + codes.source.repr(code_fields)),
(grader_header,
format_include(grader_header_include) + grader.header.repr(grader_fields)),
(grader_source,
format_include(grader_source_include) + grader.source.repr(grader_fields)),
(opera_header,
format_include(opera_header_include) + opera.header.repr(opera_fields)),
(opera_source,
format_include(opera_source_include) + opera.source.repr(opera_fields)),
]
if gen_capi:
capi_header = 'c' + api_header
capi_source = 'c' + api_source
capi_hdr_path = os.path.join(includepath, capi_header)
capi_source_include = [
'<algorithm>',
'<unordered_map>',
'"' + api_hdr_path + '"',
'"' + capi_hdr_path + '"',
]
out.append((capi_header, capi.header.repr(fields)))
out.append((capi_source,
format_include(capi_source_include) + capi.source.repr(fields)))

return out
raise Exception('cannot parse non-root object {}'.format(cfg_str))
return args

def str2bool(opt):
optstr = opt.lower()
Expand All @@ -207,18 +34,8 @@ def main(args):
help='Directory path to dump output files (default: write to stdin)')
parser.add_argument('--strip_prefix', dest='strip_prefix', nargs='?', default='',
help='Directory path to dump output files (default: write to stdin)')
parser.add_argument('--gen_capi', dest='gen_capi',
type=str2bool, nargs='?', const=True, default=False,
help='Whether to generate C api or not (default: False)')
args = parser.parse_args(args)

outpath = args.outpath
strip_prefix = args.strip_prefix

includepath = outpath
if includepath and includepath.startswith(strip_prefix):
includepath = includepath[len(strip_prefix):].strip("/")

cfgpath = args.cfgpath
if cfgpath:
with open(str(cfgpath), 'r') as cfg:
Expand All @@ -227,18 +44,12 @@ def main(args):
raise Exception("cannot read from cfg file {}".format(cfgpath))
else:
cfg_str = sys.stdin.read()
fields, includes = parse(cfg_str)

directory = make_dir(fields, includes, includepath, args.gen_capi)
fields = parse(cfg_str)
outpath = args.outpath
strip_prefix = args.strip_prefix

if outpath:
for fname, content in directory:
with open(os.path.join(outpath, fname), 'w') as out:
out.write(content)
else:
for fname, content in directory:
print("============== %s ==============" % fname)
print(content)
generate(fields, outpath=outpath, strip_prefix=strip_prefix)

if '__main__' == __name__:
main(sys.argv[1:])
27 changes: 27 additions & 0 deletions age/generator/generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
''' Reusable generator using plugin pattern '''

import os

import age.templates.template as template
import age.generator.internal as internal_plugin

def generate(fields, outpath = '', strip_prefix = '', plugins = [internal_plugin]):
includepath = outpath
if includepath and includepath.startswith(strip_prefix):
includepath = includepath[len(strip_prefix):].strip("/")

directory = {}
for plugin in plugins:
assert('process' in dir(plugin))
directory = plugin.process(directory, includepath, fields)

for akey in directory:
afile = directory[akey]
assert(isinstance(afile, template.AGE_FILE))
if outpath:
print(os.path.join(outpath, afile.fpath))
with open(os.path.join(outpath, afile.fpath), 'w') as out:
out.write(str(afile))
else:
print("============== %s ==============" % afile.fpath)
print(str(afile))
70 changes: 70 additions & 0 deletions age/generator/internal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
''' Internal plugin process '''

import os

import age.templates.api_tmpl as api
import age.templates.codes_tmpl as codes
import age.templates.grader_tmpl as grader
import age.templates.opera_tmpl as opera

def process(directory, relpath, fields):

api_hdr_path = os.path.join(relpath, api.header.fpath)
codes_hdr_path = os.path.join(relpath, codes.header.fpath)
grader_hdr_path = os.path.join(relpath, grader.header.fpath)
opera_hdr_path = os.path.join(relpath, opera.header.fpath)

# manitory headers
api.header.includes = [
'"bwd/grader.hpp"'
]
api.source.includes = [
'"' + codes_hdr_path + '"',
'"' + api_hdr_path + '"',
]

codes.header.includes = [
'<string>'
]
codes.source.includes = [
'<unordered_map>',
'"logs/logs.hpp"',
'"' + codes_hdr_path + '"',
]

grader.header.includes = [
'"bwd/grader.hpp"',
'"' + codes_hdr_path + '"',
]
grader.source.includes = [
'"' + codes_hdr_path + '"',
'"' + api_hdr_path + '"',
'"' + grader_hdr_path + '"',
]

opera.header.includes = [
'"ade/functor.hpp"',
'"' + codes_hdr_path + '"',
]
opera.source.includes = [
'"' + opera_hdr_path + '"'
]

directory = {
'api_hpp': api.header,
'api_src': api.source,
'codes_hpp': codes.header,
'codes_src': codes.source,
'grader_hpp': grader.header,
'grader_src': grader.source,
'opera_hpp': opera.header,
'opera_src': opera.source,
}

for akey in directory:
afile = directory[akey]
if 'includes' in fields and afile.fpath in fields['includes']:
afile.includes += fields['includes'][afile.fpath]
afile.process(fields)

return directory
Loading

0 comments on commit 16b3c90

Please sign in to comment.