Skip to content

Commit

Permalink
Optimize never-mutated ref cell
Browse files Browse the repository at this point in the history
  • Loading branch information
minoki committed May 10, 2024
1 parent a5ad04a commit e8d2a09
Showing 1 changed file with 83 additions and 27 deletions.
110 changes: 83 additions & 27 deletions src/cps.sml
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,8 @@ structure CpsUsageAnalysis :> sig
datatype frequency = NEVER | ONCE | MANY
type usage = { call : frequency
, project : frequency
, ref_read : frequency
, ref_write : frequency
, other : frequency
, returnConts : CSyntax.CVarSet.set
, labels : (string option) Syntax.LabelMap.map
Expand All @@ -684,26 +686,35 @@ fun oneMore NEVER = ONCE
| oneMore (many as MANY) = many
type usage = { call : frequency
, project : frequency
, ref_read : frequency
, ref_write : frequency
, other : frequency
, returnConts : CSyntax.CVarSet.set
, labels : (string option) Syntax.LabelMap.map
}
type cont_usage = { direct : frequency, indirect : frequency }
val neverUsed : usage = { call = NEVER, project = NEVER, other = NEVER, returnConts = CSyntax.CVarSet.empty, labels = Syntax.LabelMap.empty }
val neverUsed : usage = { call = NEVER
, project = NEVER
, ref_read = NEVER
, ref_write = NEVER
, other = NEVER
, returnConts = CSyntax.CVarSet.empty
, labels = Syntax.LabelMap.empty
}
val neverUsedCont : cont_usage = { direct = NEVER, indirect = NEVER }
type usage_table = (usage ref) TypedSyntax.VIdTable.hash_table
type cont_usage_table = (cont_usage ref) CSyntax.CVarTable.hash_table
fun getValueUsage (table : usage_table, v)
= case TypedSyntax.VIdTable.find table v of
SOME r => !r
| NONE => { call = MANY, project = MANY, other = MANY, returnConts = CSyntax.CVarSet.empty, labels = Syntax.LabelMap.empty } (* unknown *)
| NONE => { call = MANY, project = MANY, ref_read = MANY, ref_write = MANY, other = MANY, returnConts = CSyntax.CVarSet.empty, labels = Syntax.LabelMap.empty } (* unknown *)
fun getContUsage (table : cont_usage_table, c)
= case CSyntax.CVarTable.find table c of
SOME r => !r
| NONE => { direct = MANY, indirect = MANY } (* unknown *)
fun useValue env (C.Var v) = (case TypedSyntax.VIdTable.find env v of
SOME r => let val { call, project, other, returnConts, labels } = !r
in r := { call = call, project = project, other = oneMore other, returnConts = returnConts, labels = labels }
SOME r => let val { call, project, ref_read, ref_write, other, returnConts, labels } = !r
in r := { call = call, project = project, ref_read = ref_read, ref_write = ref_write, other = oneMore other, returnConts = returnConts, labels = labels }
end
| NONE => ()
)
Expand All @@ -718,8 +729,8 @@ fun useValue env (C.Var v) = (case TypedSyntax.VIdTable.find env v of
| useValue _ (C.String16Const _) = ()
fun useValueAsCallee (env, cont, C.Var v)
= (case TypedSyntax.VIdTable.find env v of
SOME r => let val { call, project, other, returnConts, labels } = !r
in r := { call = oneMore call, project = project, other = other, returnConts = C.CVarSet.add (returnConts, cont), labels = labels }
SOME r => let val { call, project, ref_read, ref_write, other, returnConts, labels } = !r
in r := { call = oneMore call, project = project, ref_read = ref_read, ref_write = ref_write, other = other, returnConts = C.CVarSet.add (returnConts, cont), labels = labels }
end
| NONE => ()
)
Expand All @@ -734,13 +745,13 @@ fun useValueAsCallee (env, cont, C.Var v)
| useValueAsCallee (_, _, C.String16Const _) = ()
fun useValueAsRecord (env, label, result, C.Var v)
= (case TypedSyntax.VIdTable.find env v of
SOME r => let val { call, project, other, returnConts, labels } = !r
SOME r => let val { call, project, ref_read, ref_write, other, returnConts, labels } = !r
val result' = case result of
SOME (TypedSyntax.MkVId (name, _)) => SOME name
| NONE => NONE
fun mergeOption (x as SOME _, _) = x
| mergeOption (NONE, y) = y
in r := { call = call, project = oneMore project, other = other, returnConts = returnConts, labels = Syntax.LabelMap.insertWith mergeOption (labels, label, result') }
in r := { call = call, project = oneMore project, ref_read = ref_read, ref_write = ref_write, other = other, returnConts = returnConts, labels = Syntax.LabelMap.insertWith mergeOption (labels, label, result') }
end
| NONE => ()
)
Expand All @@ -753,6 +764,38 @@ fun useValueAsRecord (env, label, result, C.Var v)
| useValueAsRecord (_, _, _, C.Char16Const _) = ()
| useValueAsRecord (_, _, _, C.StringConst _) = ()
| useValueAsRecord (_, _, _, C.String16Const _) = ()
fun useValueAsRefRead (env, C.Var v)
= (case TypedSyntax.VIdTable.find env v of
SOME r => let val { call, project, ref_read, ref_write, other, returnConts, labels } = !r
in r := { call = call, project = project, ref_read = oneMore ref_read, ref_write = ref_write, other = other, returnConts = returnConts, labels = labels }
end
| NONE => ()
)
| useValueAsRefRead (_, C.Unit) = ()
| useValueAsRefRead (_, C.Nil) = ()
| useValueAsRefRead (_, C.BoolConst _) = ()
| useValueAsRefRead (_, C.IntConst _) = ()
| useValueAsRefRead (_, C.WordConst _) = ()
| useValueAsRefRead (_, C.CharConst _) = ()
| useValueAsRefRead (_, C.Char16Const _) = ()
| useValueAsRefRead (_, C.StringConst _) = ()
| useValueAsRefRead (_, C.String16Const _) = ()
fun useValueAsRefWrite (env, C.Var v)
= (case TypedSyntax.VIdTable.find env v of
SOME r => let val { call, project, ref_read, ref_write, other, returnConts, labels } = !r
in r := { call = call, project = project, ref_read = ref_read, ref_write = oneMore ref_write, other = other, returnConts = returnConts, labels = labels }
end
| NONE => ()
)
| useValueAsRefWrite (_, C.Unit) = ()
| useValueAsRefWrite (_, C.Nil) = ()
| useValueAsRefWrite (_, C.BoolConst _) = ()
| useValueAsRefWrite (_, C.IntConst _) = ()
| useValueAsRefWrite (_, C.WordConst _) = ()
| useValueAsRefWrite (_, C.CharConst _) = ()
| useValueAsRefWrite (_, C.Char16Const _) = ()
| useValueAsRefWrite (_, C.StringConst _) = ()
| useValueAsRefWrite (_, C.String16Const _) = ()
fun useContVarIndirect cenv (v : C.CVar) = (case C.CVarTable.find cenv v of
SOME r => let val { direct, indirect } = !r
in r := { direct = direct, indirect = oneMore indirect }
Expand All @@ -775,7 +818,9 @@ local
else
C.CVarTable.insert cenv (v, ref neverUsedCont)
in
fun goSimpleExp (env, _, _, _, _, C.PrimOp { primOp = _, tyargs = _, args }) = List.app (useValue env) args
fun goSimpleExp (env, _, _, _, _, C.PrimOp { primOp = FSyntax.PrimCall Primitives.Ref_set, tyargs = _, args = [r, v] }) = (useValueAsRefWrite (env, r); useValue env v)
| goSimpleExp (env, _, _, _, _, C.PrimOp { primOp = FSyntax.PrimCall Primitives.Ref_read, tyargs = _, args = [r] }) = useValueAsRefRead (env, r)
| goSimpleExp (env, _, _, _, _, C.PrimOp { primOp = _, tyargs = _, args }) = List.app (useValue env) args
| goSimpleExp (env, _, _, _, _, C.Record fields) = Syntax.LabelMap.app (useValue env) fields
| goSimpleExp (_, _, _, _, _, C.ExnTag { name = _, payloadTy = _ }) = ()
| goSimpleExp (env, _, _, _, result, C.Projection { label, record, fieldTypes = _ }) = useValueAsRecord (env, label, result, record)
Expand Down Expand Up @@ -1105,8 +1150,8 @@ and isDiscardableExp (env : value_info TypedSyntax.VIdMap.map, C.Let { decs, con
| isDiscardableExp (_, C.Unreachable) = false
datatype param_transform = KEEP | ELIMINATE | UNPACK of (C.Var * Syntax.Label) list
fun tryUnpackParam (ctx, usage) param = case CpsUsageAnalysis.getValueUsage (usage, param) of
{ call = NEVER, project = NEVER, other = NEVER, ... } => ELIMINATE
| { call = NEVER, project = _, other = NEVER, labels, ... } =>
{ call = NEVER, project = NEVER, ref_read = NEVER, ref_write = NEVER, other = NEVER, ... } => ELIMINATE
| { call = NEVER, project = _, ref_read = NEVER, ref_write = NEVER, other = NEVER, labels, ... } =>
UNPACK (Syntax.LabelMap.foldri (fn (label, optName, acc) =>
let val name = case optName of
SOME name => name
Expand All @@ -1119,8 +1164,8 @@ fun tryUnpackParam (ctx, usage) param = case CpsUsageAnalysis.getValueUsage (usa
| _ => KEEP
fun tryUnpackContParam (ctx, usage) (SOME param)
= (case CpsUsageAnalysis.getValueUsage (usage, param) of
{ call = NEVER, project = NEVER, other = NEVER, ... } => ELIMINATE
| { call = NEVER, project = _, other = NEVER, labels, ... } =>
{ call = NEVER, project = NEVER, ref_read = NEVER, ref_write = NEVER, other = NEVER, ... } => ELIMINATE
| { call = NEVER, project = _, ref_read = NEVER, ref_write = NEVER, other = NEVER, labels, ... } =>
UNPACK (Syntax.LabelMap.foldri (fn (label, optName, acc) =>
let val name = case optName of
SOME name => name
Expand Down Expand Up @@ -1148,8 +1193,8 @@ fun wraparound (P.INT, x : IntInf.int) = x
| wraparound (P.INT_INF, x) = x
fun min3 (x, y, z) = IntInf.min (x, IntInf.min (y, z))
fun max3 (x, y, z) = IntInf.max (x, IntInf.max (y, z))
fun simplifySimpleExp (_ : value_info TypedSyntax.VIdMap.map, C.Record _) = NOT_SIMPLIFIED
| simplifySimpleExp (env, C.PrimOp { primOp, tyargs = _, args })
fun simplifySimpleExp (usage, _ : value_info TypedSyntax.VIdMap.map, C.Record _) = NOT_SIMPLIFIED
| simplifySimpleExp (usage, env, C.PrimOp { primOp, tyargs = _, args })
= (case (primOp, args) of
(F.ListOp, []) => VALUE C.Nil (* empty list *)
| (F.PrimCall P.JavaScript_call, [f, C.Var args]) =>
Expand Down Expand Up @@ -1410,10 +1455,21 @@ fun simplifySimpleExp (_ : value_info TypedSyntax.VIdMap.map, C.Record _) = NOT_
VALUE (C.Char16Const (Int.fromLarge c))
else
NOT_SIMPLIFIED
| (F.PrimCall P.Ref_read, [C.Var v]) =>
let val u = CpsUsageAnalysis.getValueUsage (usage, v)
in case (#ref_write u, #other u) of
(CpsUsageAnalysis.NEVER, CpsUsageAnalysis.NEVER) =>
(case TypedSyntax.VIdMap.find (env, v) of
SOME { exp = SOME (C.PrimOp { primOp = F.PrimCall P.Ref_ref, tyargs = _, args = [initialValue] }), ... } =>
VALUE initialValue
| _ => NOT_SIMPLIFIED
)
| _ => NOT_SIMPLIFIED
end
| _ => NOT_SIMPLIFIED
)
| simplifySimpleExp (_, C.ExnTag _) = NOT_SIMPLIFIED
| simplifySimpleExp (env, C.Projection { label, record, fieldTypes = _ })
| simplifySimpleExp (usage, _, C.ExnTag _) = NOT_SIMPLIFIED
| simplifySimpleExp (usage, env, C.Projection { label, record, fieldTypes = _ })
= (case record of
C.Var v => (case TypedSyntax.VIdMap.find (env, v) of
SOME { exp = SOME (C.Record fields), ... } => (case Syntax.LabelMap.find (fields, label) of
Expand All @@ -1424,7 +1480,7 @@ fun simplifySimpleExp (_ : value_info TypedSyntax.VIdMap.map, C.Record _) = NOT_
)
| _ => NOT_SIMPLIFIED
)
| simplifySimpleExp (_, C.Abs { contParam = _, params = _, body = _ }) = NOT_SIMPLIFIED (* TODO: Try eta conversion *)
| simplifySimpleExp (usage, _, C.Abs { contParam = _, params = _, body = _ }) = NOT_SIMPLIFIED (* TODO: Try eta conversion *)
and simplifyDec (ctx : Context, usage : { usage : CpsUsageAnalysis.usage_table, rec_usage : CpsUsageAnalysis.usage_table, cont_usage : CpsUsageAnalysis.cont_usage_table, cont_rec_usage : CpsUsageAnalysis.cont_usage_table, dead_code_analysis : CpsDeadCodeAnalysis.usage }, appliedCont : C.CVar option) (dec, (env, cenv, subst, csubst, acc : C.Dec list))
= case dec of
C.ValDec { exp, result } =>
Expand All @@ -1435,7 +1491,7 @@ and simplifyDec (ctx : Context, usage : { usage : CpsUsageAnalysis.usage_table,
else
NONE
| NONE => NONE
in case simplifySimpleExp (env, exp) of
in case simplifySimpleExp (#usage usage, env, exp) of
VALUE v => let val () = #simplificationOccurred ctx := true
val subst = case result of
SOME result => TypedSyntax.VIdMap.insert (subst, result, v)
Expand All @@ -1453,8 +1509,8 @@ and simplifyDec (ctx : Context, usage : { usage : CpsUsageAnalysis.usage_table,
in case (exp, result) of
(C.Abs { contParam, params, body }, SOME result) =>
(case CpsUsageAnalysis.getValueUsage (#usage usage, result) of
{ call = NEVER, project = NEVER, other = NEVER, returnConts = _, labels = _ } => (env, cenv, subst, csubst, acc)
| { call = ONCE, project = NEVER, other = NEVER, returnConts = _, labels = _ } =>
{ call = NEVER, project = NEVER, ref_read = NEVER, ref_write = NEVER, other = NEVER, returnConts = _, labels = _ } => (env, cenv, subst, csubst, acc)
| { call = ONCE, project = NEVER, ref_read = NEVER, ref_write = NEVER, other = NEVER, returnConts = _, labels = _ } =>
let val body = simplifyCExp (ctx, env, cenv, subst, csubst, usage, body)
val env = TypedSyntax.VIdMap.insert (env, result, { exp = SOME (C.Abs { contParam = contParam, params = params, body = body }), isDiscardableFunction = isDiscardableExp (env, body) })
in (env, cenv, subst, csubst, acc)
Expand All @@ -1464,7 +1520,7 @@ and simplifyDec (ctx : Context, usage : { usage : CpsUsageAnalysis.usage_table,
NONE
else
case u of
{ call = _, project = NEVER, other = NEVER, ... } =>
{ call = _, project = NEVER, ref_read = NEVER, ref_write = NEVER, other = NEVER, ... } =>
let val t = List.map (tryUnpackParam (ctx, #usage usage)) params
in if List.exists (fn ELIMINATE => true | UNPACK _ => true | KEEP => false) t then
SOME t
Expand Down Expand Up @@ -1548,9 +1604,9 @@ and simplifyDec (ctx : Context, usage : { usage : CpsUsageAnalysis.usage_table,
if List.exists (fn (f, _, _, _) => CpsDeadCodeAnalysis.isUsed (#dead_code_analysis usage, f)) defs then
let fun transform ((f, k, params, body), (env, acc))
= let val shouldTransformParams = case CpsUsageAnalysis.getValueUsage (#usage usage, f) of
{ call = _, project = NEVER, other = NEVER, ... } =>
{ call = _, project = NEVER, ref_read = NEVER, ref_write = NEVER, other = NEVER, ... } =>
(case CpsUsageAnalysis.getValueUsage (#rec_usage usage, f) of
{ call = _, project = NEVER, other = NEVER, ... } =>
{ call = _, project = NEVER, ref_read = NEVER, ref_write = NEVER, other = NEVER, ... } =>
let val t = List.map (tryUnpackParam (ctx, #usage usage)) params
in if List.exists (fn ELIMINATE => true | UNPACK _ => true | KEEP => false) t then
SOME t
Expand Down Expand Up @@ -1719,7 +1775,7 @@ and simplifyDec (ctx : Context, usage : { usage : CpsUsageAnalysis.usage_table,
| NONE =>
let val body = simplifyCExp (ctx, env, cenv, subst, csubst, usage, body)
val params = List.map (fn SOME p => (case CpsUsageAnalysis.getValueUsage (#usage usage, p) of
{ call = NEVER, project = NEVER, other = NEVER, ... } => NONE
{ call = NEVER, project = NEVER, ref_read = NEVER, ref_write = NEVER, other = NEVER, ... } => NONE
| _ => SOME p
)
| NONE => NONE
Expand Down Expand Up @@ -1827,9 +1883,9 @@ and simplifyCExp (ctx : Context, env : value_info TypedSyntax.VIdMap.map, cenv :
val subst = ListPair.foldlEq (fn (p, a, subst) => TypedSyntax.VIdMap.insert (subst, p, a)) subst (params, args)
val csubst = C.CVarMap.insert (csubst, contParam, cont)
val canOmitAlphaConversion = case CpsUsageAnalysis.getValueUsage (#usage usage, applied) of
{ call = ONCE, project = NEVER, other = NEVER, returnConts = _, labels = _ } =>
{ call = ONCE, project = NEVER, ref_read = NEVER, ref_write = NEVER, other = NEVER, returnConts = _, labels = _ } =>
(case CpsUsageAnalysis.getValueUsage (#rec_usage usage, applied) of
{ call = NEVER, project = NEVER, other = NEVER, returnConts = _, labels = _ } => true
{ call = NEVER, project = NEVER, ref_read = NEVER, ref_write = NEVER, other = NEVER, returnConts = _, labels = _ } => true
| _ => false
)
| _ => false
Expand Down

0 comments on commit e8d2a09

Please sign in to comment.