Skip to content

Commit

Permalink
Fix new sampling system for HDF5
Browse files Browse the repository at this point in the history
  • Loading branch information
calebwin committed Aug 18, 2022
1 parent 734f8b7 commit ebe4999
Show file tree
Hide file tree
Showing 22 changed files with 265 additions and 156 deletions.
2 changes: 1 addition & 1 deletion Banyan/src/Banyan.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ using AWSS3
global BANYAN_API_ENDPOINT

# Account management
export configure
export configure, get_organization_id

# Cluster management
export Cluster,
Expand Down
85 changes: 54 additions & 31 deletions Banyan/src/location.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ mutable struct Location
sample_invalid::Bool
end

LOCATION_PATH_KWARG_NAMES = ["add_channelview"]

struct LocationPath
original_path::String
path::String
Expand All @@ -22,47 +24,60 @@ struct LocationPath
format_name::String
format_version::String

function LocationPath(path, format_name, format_version)
function LocationPath(path::Any, format_name::String, format_version::String; kwargs...)
LocationPath("lang_jl_$(hash(path))", format_name, format_version; kwargs...)
end
function LocationPath(path::String, format_name::String, format_version::String; kwargs...)
# This function is responsible for "normalizing" the path.
# If there are multiple path strings that are technically equivalent,
# this function should map them to the same string.
path_hash = hash(path)

# Add the kwargs to the path
path_res = deepcopy(path)
for (kwarg_name, kwarg_value) in kwargs
if kwarg_name in LOCATION_PATH_KWARG_NAMES
path_res *= "_$kwarg_name=$kwarg_value"
end
end

# Return the LocationPath
path_hash = hash(path_res)
new(
path,
path,
path_res,
path_res,
path_hash,
string(path_hash),
format_name,
format_version
)
end

function LocationPath(p::String; kwargs...)::LocationPath
if isempty(p)
return NO_LOCATION_PATH
end

format_name = get(kwargs, :format, "jl")
is_sample_format_arrow = format_name == "arrow"
if is_sample_format_arrow
return LocationPath(p, "arrow", get(kwargs, :format_version, "2"); kwargs...)
else
for table_format in TABLE_FORMATS
if occursin(table_format, p) || format_name == p
return LocationPath(p, "arrow", "2"; kwargs...)
end
end
end
LocationPath(p, "jl", get_julia_version(); kwargs...)
end

LocationPath(path) = LocationPath(path, "jl", get_julia_version())``
# TODO: Maybe make
end

# Functions with `LocationPath`s

global TABLE_FORMATS = ["csv", "parquet", "arrow"]

"""
    get_location_path_with_format(p::String; kwargs...)::LocationPath

Build the `LocationPath` for path `p`, picking the sample format from `kwargs`.

- An empty `p` yields `NO_LOCATION_PATH`.
- `format="arrow"` yields an `"arrow"` location path whose version is the
  `format_version` keyword (default `"2"`).
- Otherwise, if `p` contains any entry of `TABLE_FORMATS` as a substring (or the
  `format` keyword equals `p`), an `"arrow"` location path with version `"2"` is used.
- In every other case the format is `"jl"` with version `get_julia_version()`.
"""
function get_location_path_with_format(p::String; kwargs...)::LocationPath
    isempty(p) && return NO_LOCATION_PATH

    requested_format = get(kwargs, :format, "jl")
    if requested_format == "arrow"
        return LocationPath(p, "arrow", get(kwargs, :format_version, "2"))
    end

    # NOTE(review): `requested_format == p` compares the requested format with the
    # path itself; this may have been meant to compare against `table_format` —
    # confirm before changing. Kept as-is to preserve behavior.
    looks_like_table = any(TABLE_FORMATS) do table_format
        occursin(table_format, p) || requested_format == p
    end
    if looks_like_table
        return LocationPath(p, "arrow", "2")
    end

    return LocationPath(p, "jl", get_julia_version())
end

function get_sample_path_prefix(lp::LocationPath)
format_name_sep = !isempty(lp.format_name) ? "_" : ""
lp.path_hash * "_" * lp.format_name * format_name_sep * lp.format_version
Expand All @@ -85,7 +100,7 @@ function set_sampling_configs(d::Dict{LocationPath,SamplingConfig})
session_sampling_configs[_get_session_id_no_error()] = d
end

get_sampling_config(path=""; kwargs...) = get_sampling_config(get_location_path_with_format(path; kwargs...))
get_sampling_config(path=""; kwargs...) = get_sampling_config(LocationPath(path; kwargs...))
function get_sampling_configs()
global session_sampling_configs
session_sampling_configs[_get_session_id_no_error()]
Expand All @@ -98,7 +113,7 @@ get_sampling_config(l_path::LocationPath)::SamplingConfig =
# Getting sample rate

get_sample_rate(p::String=""; kwargs...) =
get_sample_rate(get_location_path_with_format(p; kwargs...))
get_sample_rate(LocationPath(p; kwargs...))
"""
    parse_sample_rate(object_key)

Return the sample rate encoded in `object_key`: the last component of the path
(as produced by `splitpath`) parsed as an `Int64`.
"""
function parse_sample_rate(object_key)
    path_components = splitpath(object_key)
    return parse(Int64, last(path_components))
end
Expand Down Expand Up @@ -139,14 +154,14 @@ end
# Checking for having metadata, samples

has_metadata(p::String=""; kwargs...) =
has_metadata(get_location_path_with_format(p; kwargs...))
has_metadata(LocationPath(p; kwargs...))
"""
    has_metadata(l_path::LocationPath)::Bool

Return whether a metadata file exists for `l_path` in the metadata bucket.

The object checked is
`s3://<banyan_metadata_bucket_name()>/<get_metadata_path(l_path)>`.
"""
function has_metadata(l_path::LocationPath)::Bool
    metadata_bucket_name = banyan_metadata_bucket_name()
    metadata_path = get_metadata_path(l_path)
    # Emit diagnostics through the logging system (hidden unless debug logging is
    # enabled) rather than printing unconditionally to stdout from library code.
    @debug "In has_metadata" metadata_path metadata_bucket_name
    return isfile(S3Path("s3://$(metadata_bucket_name)/$(metadata_path)"))
end

has_sample(p::String=""; kwargs...) =
has_sample(get_location_path_with_format(p; kwargs...))
has_sample(LocationPath(p; kwargs...))
function has_sample(l_path:: LocationPath)::Bool
sc = get_sampling_config(l_path)
banyan_sample_dir = S3Path("s3://$(banyan_samples_bucket_name())/$(get_sample_path_prefix(l_path))")
Expand Down Expand Up @@ -200,7 +215,7 @@ function get_metadata_local_path()
end

function get_samples_local_path()
p = joinpath(homedir(), ".banyan", "metadata")
p = joinpath(homedir(), ".banyan", "samples")
if !isdir(p)
mkpath(p)
end
Expand Down Expand Up @@ -306,6 +321,8 @@ function get_location_source(lp::LocationPath)::Tuple{Location,String,String}
"$(dayabbr(lm)), $(twodigit(day(lm))) $(monthabbr(lm)) $(year(lm)) $(twodigit(hour(lm))):$(twodigit(minute(lm))):$(twodigit(second(lm))) GMT"
sample_s3_path = "/$(banyan_samples_bucket_name())/$sample_path_prefix/$sample_rate"
try
@show sample_local_path
@show sample_s3_path
blob = s3("GET", sample_s3_path, Dict("headers" => Dict("If-Modified-Since" => if_modified_since_string)))
write(sample_local_path, seekstart(blob.io)) # This overwrites the existing file
final_local_sample_path = sample_local_path
Expand All @@ -330,10 +347,11 @@ function get_location_source(lp::LocationPath)::Tuple{Location,String,String}
end

# If no such sample is found, search the S3 bucket
banyan_samples_bucket = S3Path("s3://$(banyan_samples_bucket_name())")
banyan_samples_object_dir = joinpath(banyan_samples_bucket, sample_path_prefix)
if isempty(final_local_sample_path)
banyan_samples_bucket = S3Path("s3://$(banyan_samples_bucket_name())")
final_sample_rate = -1
banyan_samples_object_dir = joinpath(banyan_samples_bucket, sample_path_prefix)
@show readdir_no_error(banyan_samples_object_dir)
for object_key in readdir_no_error(banyan_samples_object_dir)
object_sample_rate = parse(Int64, object_key)
object_sample_rate_diff = abs(object_sample_rate - desired_sample_rate)
Expand All @@ -353,6 +371,7 @@ function get_location_source(lp::LocationPath)::Tuple{Location,String,String}
Path(final_local_sample_path)
)
end
@show readdir_no_error(banyan_samples_object_dir)
end

# Construct and return LocationSource
Expand All @@ -364,11 +383,15 @@ function get_location_source(lp::LocationPath)::Tuple{Location,String,String}
)
res_location.metadata_invalid = isempty(src_params)
res_location.sample_invalid = isempty(final_local_sample_path)
@show res_location
@show final_sample_rate
@show final_local_sample_path
final_sample_rate = isempty(final_local_sample_path) ? desired_sample_rate : final_sample_rate
@show desired_sample_rate
@show sample_local_dir
@show readdir(sample_local_dir)
println("At end of get_location_source with readdir_no_error(banyan_samples_object_dir)=$(readdir_no_error(banyan_samples_object_dir))")

(
res_location,
metadata_local_path,
Expand Down
30 changes: 23 additions & 7 deletions Banyan/src/locations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ getsamplenrows(totalnrows::Int64)::Int64 = begin
# eventually stored and updated in S3 on each write.

function invalidate_metadata(p; kwargs...)
lp = get_location_path_with_format(p; kwargs...)
lp = LocationPath(p; kwargs...)

# Delete locally
p = joinpath(homedir(), ".banyan", "metadata", get_metadata_path(lp))
Expand All @@ -310,7 +310,7 @@ function invalidate_metadata(p; kwargs...)
end
end
function invalidate_samples(p; kwargs...)
lp = get_location_path_with_format(p; kwargs...)
lp = LocationPath(p; kwargs...)

# Delete locally
samples_local_dir = joinpath(homedir(), ".banyan", "samples")
Expand Down Expand Up @@ -344,11 +344,8 @@ function partition(series, partition_size)
(series[i:min(i+(partition_size-1),end)] for i in 1:partition_size:length(series))
end
function invalidate_all_locations()
for subdir in ["samples", "metadata"]
local_dir = joinpath(homedir(), ".banyan", subdir)
if isdir(local_dir)
rm(local_dir; force=true, recursive=true)
end
for local_dir in [get_samples_local_path(), get_metadata_local_path()]
rm(local_dir; force=true, recursive=true)
end

# Delete from S3
Expand Down Expand Up @@ -435,9 +432,13 @@ function RemoteSource(
# Look at local and S3 caches of metadata and samples to attempt to
# construct a Location.
loc, local_metadata_path, local_sample_path = get_location_source(lp)
let banyan_samples_object_dir = S3Path("s3://banyan-samples-75c0f7151604587a83055278b28db83b/15117355623592221474_jl_1.8.0-beta3")
println("Before get_location_source with readdir_no_error(banyan_samples_object_dir)=$(readdir_no_error(banyan_samples_object_dir)) and loc.metadata_invalid=$(loc.metadata_invalid) and loc.sample_invalid=$(loc.sample_invalid)")
end
@show lp
@show get_sampling_configs()
@show local_sample_path
@show loc

res = if !loc.metadata_invalid && !loc.sample_invalid
# Case where both sample and parameters are valid
Expand All @@ -446,7 +447,19 @@ function RemoteSource(
loc
elseif loc.metadata_invalid && !loc.sample_invalid
# Case where parameters are invalid
let banyan_samples_object_dir = S3Path("s3://banyan-samples-75c0f7151604587a83055278b28db83b/15117355623592221474_jl_1.8.0-beta3")
println("Before offloaded with readdir_no_error(banyan_samples_object_dir)=$(readdir_no_error(banyan_samples_object_dir))")
end
let banyan_samples_bucket = S3Path("s3://banyan-samples-75c0f7151604587a83055278b28db83b")
println("Before offloaded with readdir_no_error(banyan_samples_bucket)=$(readdir_no_error(banyan_samples_bucket))")
end
new_loc = offloaded(_remote_source, lp, loc, args...; distributed=true)
let banyan_samples_object_dir = S3Path("s3://banyan-samples-75c0f7151604587a83055278b28db83b/15117355623592221474_jl_1.8.0-beta3")
println("After offloaded with readdir_no_error(banyan_samples_object_dir)=$(readdir_no_error(banyan_samples_object_dir))")
end
let banyan_samples_bucket = S3Path("s3://banyan-samples-75c0f7151604587a83055278b28db83b")
println("After offloaded with readdir_no_error(banyan_samples_bucket)=$(readdir_no_error(banyan_samples_bucket))")
end
Arrow.write(local_metadata_path, Arrow.Table(); metadata=new_loc.src_parameters)
@show new_loc
new_loc.sample.value = load_sample(local_sample_path)
Expand All @@ -471,5 +484,8 @@ function RemoteSource(

new_loc
end
let banyan_samples_object_dir = S3Path("s3://banyan-samples-75c0f7151604587a83055278b28db83b/15117355623592221474_jl_1.8.0-beta3")
println("At end of RemoteSource with readdir_no_error(banyan_samples_object_dir)=$(readdir_no_error(banyan_samples_object_dir))")
end
res
end
33 changes: 22 additions & 11 deletions Banyan/src/queues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ function get_next_message(
end
end
m_dict = m["ReceiveMessageResult"]["Message"]
@show m_dict["MessageId"]
@show m_dict["ReceiptHandle"]
if delete
SQS.delete_message(queue_url, m_dict["ReceiptHandle"]::String)
end
Expand Down Expand Up @@ -148,31 +150,40 @@ function send_to_client(value_id::ValueId, value, worker_memory_used = 0)
end
end

for (i, pm) in enumerate(message_ranges)
if i > 1
println("pm == partial_messages[i-1] = $(message[pm] == message[message_ranges[i-1]])")
end
end

# Launch asynchronous threads to send SQS messages
gather_q_url = gather_queue_url()
num_chunks = length(message_ranges)
@show num_chunks
if num_chunks > 1
@sync for i = 1:message_ranges
@sync for i = 1:num_chunks
@async begin
msg = Dict{String,Any}(
"kind" => "GATHER",
"value_id" => value_id,
"contents" => message[message_ranges[i]],
"worker_memory_used" => worker_memory_used,
"chunk_idx" => i,
"num_chunks" => num_chunks
)
msg_json = JSON.json(msg)
SQS.send_message(
msg_json,
JSON.json(
Dict{String,Any}(
"kind" => "GATHER",
"value_id" => value_id,
"contents" => message[message_ranges[i]],
"contents_length" => length(message[message_ranges[i]]),
"worker_memory_used" => worker_memory_used,
"chunk_idx" => i,
"num_chunks" => num_chunks
)
),
gather_q_url,
Dict(
"MessageGroupId" => string(i),
"MessageDeduplicationId" => generated_message_id * string(i)
)
)
@show i
@show message_ranges[i]
@show length(message[message_ranges[i]])
end
end
else
Expand Down
30 changes: 23 additions & 7 deletions Banyan/src/requests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,10 @@ function _partitioned_computation_concrete(fut::Future, destination::Location, n
partial_message, _ = sqs_receive_next_message(gather_queue, p, nothing, nothing)
chunk_idx = partial_message["chunk_idx"]
@show chunk_idx
partial_messages[chunk_idx] = message["contents"]
partial_messages[chunk_idx] = partial_message["contents"]
end
end
@show length.(partial_messages)
join(partial_messages)
else
message["contents"]
Expand Down Expand Up @@ -719,14 +720,29 @@ function offloaded(given_function::Function, args...; distributed::Bool = false)
@show num_chunks

whole_message_contents = if num_chunks > 1
partial_messages = Vector{String}(undef, num_chunks)
partial_messages = fill("", num_chunks)
partial_messages[message["chunk_idx"]] = message["contents"]
@sync for i = 1:num_remaining_chunks
@show message["chunk_idx"]
@sync for _ = 1:num_remaining_chunks
@async begin
partial_message, _ = sqs_receive_next_message(gather_queue, p, nothing, nothing)
chunk_idx = partial_message["chunk_idx"]
@show chunk_idx
partial_messages[chunk_idx] = message["contents"]
let partial_message = sqs_receive_next_message(gather_queue, p, nothing, nothing)[1]
chunk_idx = partial_message["chunk_idx"]
partial_messages[chunk_idx] = partial_message["contents"]
@show chunk_idx
@show length(partial_message["contents"])
@show partial_message["contents_length"]
@show length(partial_messages[chunk_idx])
@show last(partial_message["contents"], 20)
@show last(partial_messages[chunk_idx], 20)
@show length.(partial_messages)
end
end
end
# TODO: Fix this so that it gets the partial messages which are different lengths
@show length.(partial_messages)
for (i, pm) in enumerate(partial_messages)
if i > 1
println("pm == partial_messages[i-1] = $(pm == partial_messages[i-1])")
end
end
join(partial_messages)
Expand Down
2 changes: 1 addition & 1 deletion Banyan/src/samples.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ function configure_sampling(
)

session_id = _get_session_id_no_error()
lp = get_location_path_with_format(path; kwargs...)
lp = LocationPath(path; kwargs...)
sampling_configs = session_sampling_configs[session_id]
if for_all_locations
empty!(sampling_configs)
Expand Down
Loading

0 comments on commit ebe4999

Please sign in to comment.