Skip to content

Commit

Permalink
refactor: only return one target_df. it is either sdf or gdf, as appropriate
Browse files Browse the repository at this point in the history
  • Loading branch information
korenmiklos committed Jun 30, 2024
1 parent 89e12af commit 877e5cd
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 32 deletions.
2 changes: 1 addition & 1 deletion src/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ function generate_command(command::Command; options=[], allowed=[])
x
end
end)
GeneratedCommand(dfname, df2, sdf, gdf, Expr(:block, setup...), tdfunction, collect(process.(command.arguments)), collect(command.options))
GeneratedCommand(dfname, df2, target_df, Expr(:block, setup...), tdfunction, collect(process.(command.arguments)), collect(command.options))
end

get_by(command::Command) = get_option(command, :by)
Expand Down
48 changes: 19 additions & 29 deletions src/rewrites.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,27 @@ rewrite(command::Command) = rewrite(Val(command.command), command)

"""
    rewrite(::Val{:tabulate}, command::Command)

Build the escaped expression that implements the `tabulate` command.

The generated code first runs the `setup` expression produced by
`generate_command`, then calls `Kezdi.tabulate` on the single target data
frame, and finally pipes the result into the generated `teardown` symbol.
"""
function rewrite(::Val{:tabulate}, command::Command)
    gc = generate_command(command; options=[:variables, :ifable, :nofunction])
    # `GeneratedCommand` exposes one `target_df` instead of separate `sdf`/`gdf`
    # fields, so only that name (plus setup/teardown) is destructured here.
    (; target_df, setup, teardown) = gc
    # One entry per argument: the first element of each
    # `extract_variable_references` result.
    columns = [x[1] for x in extract_variable_references.(command.arguments)]
    quote
        $setup
        Kezdi.tabulate($target_df, $columns) |> $teardown
    end |> esc
end

"""
    rewrite(::Val{:summarize}, command::Command)

Build the escaped expression that implements the `summarize` command.

The generated code runs the `setup` expression produced by `generate_command`,
calls `Kezdi.summarize` on the single target data frame for the first variable
reference found in `command.arguments[1]`, and pipes the result into the
generated `teardown` symbol.
"""
function rewrite(::Val{:summarize}, command::Command)
    gc = generate_command(
        command;
        options=[:variables, :ifable, :replace_variables, :single_argument, :nofunction],
    )
    # `GeneratedCommand` exposes one `target_df` instead of separate `sdf`/`gdf`
    # fields, so only that name (plus setup/teardown) is destructured here.
    (; target_df, setup, teardown) = gc
    column = extract_variable_references(command.arguments[1])
    quote
        $setup
        # `$column[1]` splices the whole collection into the expression and
        # indexes it when the generated code runs.
        Kezdi.summarize($target_df, $column[1]) |> $teardown
    end |> esc
end

function rewrite(::Val{:regress}, command::Command)
gc = generate_command(command; options=[:variables, :ifable], allowed=[:robust, :cluster])
(; df, local_copy, sdf, gdf, setup, teardown, arguments, options) = gc
(; df, local_copy, target_df, setup, teardown, arguments, options) = gc
if :robust in get_top_symbol.(options)
vcov = :(Vcov.robust())
elseif :cluster in get_top_symbol.(options)
Expand All @@ -36,16 +36,16 @@ function rewrite(::Val{:regress}, command::Command)
quote
$setup
if length($(arguments[2:end])) == 1
reg($sdf, @formula($(arguments[1]) ~ $(arguments[2])), $vcov) |> $teardown
reg($target_df, @formula($(arguments[1]) ~ $(arguments[2])), $vcov) |> $teardown
else
reg($sdf, @formula($(arguments[1]) ~ $(Expr(:call, :+, arguments[2:end]...))), $vcov) |> $teardown
reg($target_df, @formula($(arguments[1]) ~ $(Expr(:call, :+, arguments[2:end]...))), $vcov) |> $teardown
end
end |> esc
end

function rewrite(::Val{:generate}, command::Command)
gc = generate_command(command; options=[:single_argument, :variables, :ifable, :replace_variables, :vectorize, :assignment], allowed=[:by])
(; df, local_copy, sdf, gdf, setup, teardown, arguments, options) = gc
(; df, local_copy, target_df, setup, teardown, arguments, options) = gc
target_column = get_LHS(command.arguments[1])
LHS, RHS = split_assignment(arguments[1])
quote
Expand All @@ -54,15 +54,15 @@ function rewrite(::Val{:generate}, command::Command)
else
$setup
$local_copy[!, $target_column] .= missing
$sdf[!, $target_column] .= $RHS
$target_df[!, $target_column] .= $RHS
$local_copy |> $teardown
end
end |> esc
end

function rewrite(::Val{:replace}, command::Command)
gc = generate_command(command; options=[:single_argument, :variables, :ifable, :replace_variables, :vectorize, :assignment])
(; df, local_copy, sdf, gdf, setup, teardown, arguments, options) = gc
(; df, local_copy, target_df, setup, teardown, arguments, options) = gc
target_column = get_LHS(command.arguments[1])
LHS, RHS = split_assignment(arguments[1])
third_vector = gensym()
Expand All @@ -72,13 +72,13 @@ function rewrite(::Val{:replace}, command::Command)
ArgumentError("Column \"$($target_column)\" does not exist in $(names($df))") |> throw
else
$setup
if eltype($RHS) != eltype($sdf[!, $target_column])
if eltype($RHS) != eltype($target_df[!, $target_column])
local $third_vector = Vector{eltype($RHS)}(undef, nrow($local_copy))
$third_vector[$bitmask] .= $RHS
$third_vector[.!$bitmask] .= $local_copy[!, $target_column][.!$bitmask]
$local_copy[!, $target_column] = $third_vector
else
$sdf[!, $target_column] .= $RHS
$target_df[!, $target_column] .= $RHS
end
$local_copy |> $teardown
end
Expand All @@ -87,16 +87,16 @@ end

"""
    rewrite(::Val{:keep}, command::Command)

Build the escaped expression that implements the `keep` command.

The generated code runs the `setup` expression produced by `generate_command`,
indexes the single target data frame with the columns listed in
`command.arguments` (or with all columns when no arguments were given), and
pipes the result into the generated `teardown` symbol.
"""
function rewrite(::Val{:keep}, command::Command)
    gc = generate_command(command; options=[:variables, :ifable, :nofunction])
    # `GeneratedCommand` exposes one `target_df` instead of separate `sdf`/`gdf`
    # fields, so only that name (plus setup/teardown) is destructured here.
    (; target_df, setup, teardown) = gc
    quote
        $setup
        # `Colon()` is the value of a bare `:` index; using it directly avoids a
        # runtime `eval` just to obtain the "all columns" selector.
        $target_df[
            !,
            isempty($(command.arguments)) ? Colon() : collect($(command.arguments)),
        ] |> $teardown
    end |> esc
end

function rewrite(::Val{:drop}, command::Command)
gc = generate_command(command; options=[:variables, :ifable, :nofunction])
(; df, local_copy, sdf, gdf, setup, teardown, arguments, options) = gc
(; df, local_copy, target_df, setup, teardown, arguments, options) = gc
if isnothing(command.condition)
return quote
$setup
Expand All @@ -112,13 +112,8 @@ end

function rewrite(::Val{:collapse}, command::Command)
gc = generate_command(command; options=[:variables, :ifable, :replace_variables, :vectorize, :assignment], allowed=[:by])
(; df, local_copy, sdf, gdf, setup, teardown, arguments, options) = gc
by_cols = get_by(command)
if isnothing(by_cols)
combine_epxression = Expr(:call, :combine, sdf, build_assignment_formula.(command.arguments)...)
else
combine_epxression = Expr(:call, :combine, gdf, build_assignment_formula.(command.arguments)...)
end
(; df, local_copy, target_df, setup, teardown, arguments, options) = gc
combine_epxression = Expr(:call, :combine, target_df, build_assignment_formula.(command.arguments)...)
quote
$setup
$combine_epxression |> $teardown
Expand All @@ -127,14 +122,9 @@ end

function rewrite(::Val{:egen}, command::Command)
gc = generate_command(command; options=[:variables, :ifable, :replace_variables, :vectorize, :assignment], allowed=[:by])
(; df, local_copy, sdf, gdf, setup, teardown, arguments, options) = gc
(; df, local_copy, target_df, setup, teardown, arguments, options) = gc
target_column = get_LHS(command.arguments[1])
by_cols = get_by(command)
if isnothing(by_cols)
transform_expression = Expr(:call, :transform!, sdf, build_assignment_formula.(command.arguments)...)
else
transform_expression = Expr(:call, :transform!, gdf, build_assignment_formula.(command.arguments)...)
end
transform_expression = Expr(:call, :transform!, target_df, build_assignment_formula.(command.arguments)...)
quote
if ($target_column in names($df))
ArgumentError("Column \"$($target_column)\" already exists in $(names($df))") |> throw
Expand Down
3 changes: 1 addition & 2 deletions src/structs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ end
struct GeneratedCommand
df::Any
local_copy::Symbol
sdf::Union{Symbol, Nothing}
gdf::Union{Symbol, Nothing}
target_df::Union{Symbol, Nothing}
setup::Expr
teardown::Symbol
arguments::Vector{Any}
Expand Down

0 comments on commit 877e5cd

Please sign in to comment.