Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ workspace(name = "io_bazel_rules_scala")
load("//scala:scala.bzl", "scala_repositories", "scala_mvn_artifact")
scala_repositories()

load("//twitter_scrooge:twitter_scrooge.bzl", "twitter_scrooge", "scrooge_scala_library")
twitter_scrooge()

# test adding a scala jar:
maven_jar(
name = "com_twitter__scalding_date",
artifact = scala_mvn_artifact("com.twitter:scalding-date:0.16.0-RC4")
)
)
105 changes: 57 additions & 48 deletions scala/scala.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ touch -t 198001010000 {manifest}
progress_message="scala %s" % ctx.label,
arguments=[])

def _compile(ctx, jars, dep_srcjars, buildijar):
def _compile(ctx, _jars, dep_srcjars, buildijar):
jars = _jars
res_cmd = _add_resources_cmd(ctx)
ijar_cmd = ""
if buildijar:
Expand All @@ -75,14 +76,9 @@ def _compile(ctx, jars, dep_srcjars, buildijar):
out=ctx.outputs.jar.path,
ijar_out=ctx.outputs.ijar.path)

sources = []
srcjars = []
for f in ctx.files.srcs:
#TODO this is gross but we aren't given a good "filterNot"
if len(_srcjar_filetype.filter([f])) == 0:
sources.append(f)
else:
srcjars.append(f)
sources = _scala_filetype.filter(ctx.files.srcs)
srcjars = _srcjar_filetype.filter(ctx.files.srcs)
all_srcjars = set(srcjars + list(dep_srcjars))

# Set up the args to pass to scalac because they can be too long for bash
scalac_args_file = ctx.new_file(ctx.outputs.jar, ctx.outputs.jar.short_path + "scalac_args")
Expand All @@ -95,29 +91,22 @@ def _compile(ctx, jars, dep_srcjars, buildijar):
)
ctx.file_action(output = scalac_args_file, content = scalac_args)

all_srcjars = srcjars + list(dep_srcjars)
srcjar_cmd = ""
if len(all_srcjars) > 0:
srcjar_cmd = "\nmkdir -p {out}_tmp_expand_srcjars\n"
for srcjar in all_srcjars:
# Note: this is double escaped because we need to do one format call
# per each srcjar, but then we are going to include this in the bigger format
# call that is done to generate the full command

# Note: unzip has -o set (overriding files), and all of the files are unzipped into the same directory.
# I feel this is ok because everything should be from the same source tree, so it should have to be consistent
# (the whole point of bazel is resolving diamonds). That said, a good TODO might be to ensure that
# if there are duplicate files, they are identical (and error otherwise)

#TODO would like to be able to switch >/dev/null, -v, etc based on the user's settings
srcjar_cmd += """
rm -rf {{out}}_tmp_expand_srcjars
mkdir -p {{out}}_tmp_expand_srcjars
unzip -o {srcjar} -d {{out}}_tmp_expand_srcjars >/dev/null
echo " " >> {{out}}_args/files_from_jar
find {{out}}_tmp_expand_srcjars -type f -name "*.scala" >> {{out}}_args/files_from_jar
""".format(srcjar = srcjar.path)
srcjar_cmd += """find {out}_tmp_expand_srcjars -type f -name "*.scala" > {out}_args/files_from_jar\n"""

cmd = """
rm -rf {out}_tmp_expand_srcjars
rm -rf {out}_tmp
set -e
rm -rf {out}_args
Expand Down Expand Up @@ -148,6 +137,8 @@ rm -rf {out}_tmp
ctx.action(
inputs=list(jars) +
list(dep_srcjars) +
list(srcjars) +
list(sources) +
ctx.files.srcs +
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this be sources now? You partitioned that into sources and srcjars.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it should be all 3. ctx.files.srcs goes to sources and srcjars, and dep_srcjars came from ctx.attr.deps

ctx.files.resources +
ctx.files._jdk +
Expand Down Expand Up @@ -185,15 +176,16 @@ def write_manifest(ctx):
content = manifest)

def _write_launcher(ctx, jars):
classpath = ':'.join(["$0.runfiles/" + f.short_path for f in jars])
content = """#!/bin/bash
cd $0.runfiles
{java} -cp {cp} {name} "$@"
"""
content = content.format(
java=ctx.file._java.path,
name=ctx.attr.main_class,
deploy_jar=ctx.outputs.jar.path,
cp=":".join([j.short_path for j in jars]))
export CLASSPATH={classpath}
$0.runfiles/{java} {name} "$@"
""".format(
java=ctx.file._java.path,
name=ctx.attr.main_class,
deploy_jar=ctx.outputs.jar.path,
classpath=classpath,
)
ctx.file_action(
output=ctx.outputs.executable,
content=content)
Expand All @@ -214,11 +206,11 @@ def _write_test_launcher(ctx, jars):
output=ctx.outputs.executable,
content=content)

def _collect_srcjars(targets):
def collect_srcjars(targets):
srcjars = set()
for target in targets:
if hasattr(target, "srcjar"):
srcjars += [target.srcjar]
if hasattr(target, "srcjars"):
srcjars += [target.srcjars.srcjar]
return srcjars

def _collect_jars(targets):
Expand Down Expand Up @@ -251,7 +243,7 @@ def _collect_jars(targets):
def _lib(ctx, non_macro_lib):
# This will be used to pick up srcjars from non-scala library
# targets (like thrift code generation)
srcjars = _collect_srcjars(ctx.attr.deps)
srcjars = collect_srcjars(ctx.attr.deps)
jars = _collect_jars(ctx.attr.deps)
(cjars, rjars) = (jars.compiletime, jars.runtime)
write_manifest(ctx)
Expand All @@ -275,7 +267,25 @@ def _lib(ctx, non_macro_lib):
collect_data = True)
return struct(
scala = scalaattr,
runfiles=runfiles)
runfiles=runfiles,
# This is a free monoid given to the graph for the purpose of
# extensibility. This is necessary when one wants to create
# new targets which want to leverage a scala_library. For example,
# new_target1 -> scala_library -> new_target2. There might be
# information that new_target2 needs to get from new_target1,
# but we do not want to ohave to change scala_library to pass
# this information through. extra_information allows passing
# this information through, and it is up to the new_targets
# to filter and make sense of this information.
extra_information=_collect_extra_information(ctx.attr.deps),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to collect from exports as well? If not, can we comment why not?

)

def _collect_extra_information(targets):
r = []
for target in targets:
if hasattr(target, 'extra_information'):
r.extend(target.extra_information)
return r

def _scala_library_impl(ctx):
return _lib(ctx, True)
Expand Down Expand Up @@ -315,18 +325,17 @@ def _scala_test_impl(ctx):
_write_test_launcher(ctx, rjars)
return _scala_binary_common(ctx, cjars, rjars)

def implicit_deps():
return {
"_ijar": attr.label(executable=True, default=Label("@bazel_tools//tools/jdk:ijar"), single_file=True, allow_files=True),
"_scalac": attr.label(executable=True, default=Label("@scala//:bin/scalac"), single_file=True, allow_files=True),
"_scalalib": attr.label(default=Label("@scala//:lib/scala-library.jar"), single_file=True, allow_files=True),
"_scalaxml": attr.label(default=Label("@scala//:lib/scala-xml_2.11-1.0.4.jar"), single_file=True, allow_files=True),
"_scalasdk": attr.label(default=Label("@scala//:sdk"), allow_files=True),
"_scalareflect": attr.label(default=Label("@scala//:lib/scala-reflect.jar"), single_file=True, allow_files=True),
"_java": attr.label(executable=True, default=Label("@bazel_tools//tools/jdk:java"), single_file=True, allow_files=True),
"_jar": attr.label(executable=True, default=Label("@bazel_tools//tools/jdk:jar"), single_file=True, allow_files=True),
"_jdk": attr.label(default=Label("//tools/defaults:jdk"), allow_files=True),
}
_implicit_deps = {
"_ijar": attr.label(executable=True, default=Label("@bazel_tools//tools/jdk:ijar"), single_file=True, allow_files=True),
"_scalac": attr.label(executable=True, default=Label("@scala//:bin/scalac"), single_file=True, allow_files=True),
"_scalalib": attr.label(default=Label("@scala//:lib/scala-library.jar"), single_file=True, allow_files=True),
"_scalaxml": attr.label(default=Label("@scala//:lib/scala-xml_2.11-1.0.4.jar"), single_file=True, allow_files=True),
"_scalasdk": attr.label(default=Label("@scala//:sdk"), allow_files=True),
"_scalareflect": attr.label(default=Label("@scala//:lib/scala-reflect.jar"), single_file=True, allow_files=True),
"_java": attr.label(executable=True, default=Label("@bazel_tools//tools/jdk:java"), single_file=True, allow_files=True),
"_jar": attr.label(executable=True, default=Label("@bazel_tools//tools/jdk:jar"), single_file=True, allow_files=True),
"_jdk": attr.label(default=Label("//tools/defaults:jdk"), allow_files=True),
}

# Common attributes reused across multiple rules.
_common_attrs = {
Expand All @@ -345,7 +354,7 @@ scala_library = rule(
attrs={
"main_class": attr.string(),
"exports": attr.label_list(allow_files=False),
} + implicit_deps() + _common_attrs,
} + _implicit_deps + _common_attrs,
outputs={
"jar": "%{name}_deploy.jar",
"ijar": "%{name}_ijar.jar",
Expand All @@ -358,7 +367,7 @@ scala_macro_library = rule(
attrs={
"main_class": attr.string(),
"exports": attr.label_list(allow_files=False),
} + implicit_deps() + _common_attrs,
} + _implicit_deps + _common_attrs,
outputs={
"jar": "%{name}_deploy.jar",
"manifest": "%{name}_MANIFEST.MF",
Expand All @@ -369,7 +378,7 @@ scala_binary = rule(
implementation=_scala_binary_impl,
attrs={
"main_class": attr.string(mandatory=True),
} + implicit_deps() + _common_attrs,
} + _implicit_deps + _common_attrs,
outputs={
"jar": "%{name}_deploy.jar",
"manifest": "%{name}_MANIFEST.MF",
Expand All @@ -384,7 +393,7 @@ scala_test = rule(
"suites": attr.string_list(),
"_scalatest": attr.label(executable=True, default=Label("@scalatest//file"), single_file=True, allow_files=True),
"_scalatest_reporter": attr.label(default=Label("//scala/support:test_reporter")),
} + implicit_deps() + _common_attrs,
} + _implicit_deps + _common_attrs,
outputs={
"jar": "%{name}_deploy.jar",
"manifest": "%{name}_MANIFEST.MF",
Expand Down
19 changes: 19 additions & 0 deletions src/scala/scripts/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
load("//scala:scala.bzl", "scala_binary")

scala_binary(
name = "generator",
srcs = ["TwitterScroogeGenerator.scala"],
main_class = "scripts.ScroogeGenerator",
deps = [
"@scrooge_generator//jar",
"@util_core//jar",
"@util_logging//jar",
":scala_parsers",
],
visibility = ["//visibility:public"],
)

java_import(
name = "scala_parsers",
jars = ["@scala//:lib/scala-parser-combinators_2.11-1.0.4.jar"],
)
Loading