Skip to content

Commit

Permalink
fix: the default value for structure fields may now depend on the str…
Browse files Browse the repository at this point in the history
…ucture parameters
  • Loading branch information
leodemoura committed Apr 22, 2022
1 parent 09dfd97 commit 66c82da
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 26 deletions.
10 changes: 10 additions & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
Unreleased
---------

* (Fix) the default value for structure fields may now depend on the structure parameters. Example:
```lean
structure Something (i: Nat) where
n1: Nat := 1
n2: Nat := 1 + i
def s : Something 10 := {}
example : s.n2 = 11 := rfl
```

* Apply `rfl` theorems at the `dsimp` auxiliary method used by `simp`. `dsimp` can be used anywhere in an expression
because it preserves definitional equality.

Expand Down
69 changes: 43 additions & 26 deletions src/Lean/Elab/StructInst.lean
Original file line number Diff line number Diff line change
Expand Up @@ -255,22 +255,26 @@ def Field.isSimple {σ} : Field σ → Bool
| _ => false

inductive Struct where
| mk (ref : Syntax) (structName : Name) (fields : List (Field Struct)) (source : Source)
/- Remark: the field `params` is use for default value propagation. It is initially empty, and then set at `elabStruct`. -/
| mk (ref : Syntax) (structName : Name) (params : Array (Name × Expr)) (fields : List (Field Struct)) (source : Source)
deriving Inhabited

abbrev Fields := List (Field Struct)

def Struct.ref : Struct → Syntax
| ⟨ref, _, _, _⟩ => ref
| ⟨ref, _, _, _, _⟩ => ref

def Struct.structName : Struct → Name
| ⟨_, structName, _, _⟩ => structName
| ⟨_, structName, _, _, _⟩ => structName

def Struct.params : Struct → Array (Name × Expr)
| ⟨_, _, params, _, _⟩ => params

def Struct.fields : Struct → Fields
| ⟨_, _, fields, _⟩ => fields
| ⟨_, _, _, fields, _⟩ => fields

def Struct.source : Struct → Source
| ⟨_, _, _, s⟩ => s
| ⟨_, _, _, _, s⟩ => s

/-- `true` iff all fields of the given structure are marked as `default` -/
partial def Struct.allDefault (s : Struct) : Bool :=
Expand All @@ -287,7 +291,7 @@ def formatField (formatStruct : Struct → Format) (field : Field Struct) : Form
| FieldVal.default => "<default>"

partial def formatStruct : Struct → Format
| ⟨_, structName, fields, source⟩ =>
| ⟨_, structName, _, fields, source⟩ =>
let fieldsFmt := Format.joinSep (fields.map (formatField formatStruct)) ", "
match source with
| Source.none => "{" ++ fieldsFmt ++ "}"
Expand Down Expand Up @@ -356,18 +360,22 @@ private def mkStructView (stx : Syntax) (structName : Name) (source : Source) :
let first ← toFieldLHS fieldStx[0][0]
let rest ← fieldStx[0][1].getArgs.toList.mapM toFieldLHS
return { ref := fieldStx, lhs := first :: rest, val := FieldVal.term val : Field Struct }
return ⟨stx, structName, fields, source⟩
return ⟨stx, structName, #[], fields, source⟩

def Struct.modifyFieldsM {m : TypeType} [Monad m] (s : Struct) (f : Fields → m Fields) : m Struct :=
match s with
| ⟨ref, structName, fields, source⟩ => return ⟨ref, structName, (← f fields), source⟩
| ⟨ref, structName, params, fields, source⟩ => return ⟨ref, structName, params, (← f fields), source⟩

def Struct.modifyFields (s : Struct) (f : Fields → Fields) : Struct :=
Id.run <| s.modifyFieldsM f

def Struct.setFields (s : Struct) (fields : Fields) : Struct :=
s.modifyFields fun _ => fields

def Struct.setParams (s : Struct) (ps : Array (Name × Expr)) : Struct :=
match s with
| ⟨ref, structName, _, fields, source⟩ => ⟨ref, structName, ps, fields, source⟩

private def expandCompositeFields (s : Struct) : Struct :=
s.modifyFields fun fields => fields.map fun field => match field with
| { lhs := FieldLHS.fieldName ref (Name.str Name.anonymous _ _) :: rest, .. } => field
Expand Down Expand Up @@ -476,7 +484,7 @@ mutual
let field := fields.head!
match Lean.isSubobjectField? env s.structName fieldName with
| some substructName =>
let substruct := Struct.mk s.ref substructName substructFields s.source
let substruct := Struct.mk s.ref substructName #[] substructFields s.source
let substruct ← expandStruct substruct
pure { field with lhs := [field.lhs.head!], val := FieldVal.nested substruct }
| none => do
Expand Down Expand Up @@ -511,7 +519,7 @@ mutual
match Lean.isSubobjectField? env s.structName fieldName with
| some substructName => do
let addSubstruct : TermElabM Fields := do
let substruct := Struct.mk ref substructName [] s.source
let substruct := Struct.mk ref substructName #[] [] s.source
let substruct ← expandStruct substruct
addField (FieldVal.nested substruct)
match s.source with
Expand Down Expand Up @@ -546,21 +554,22 @@ end
structure CtorHeaderResult where
ctorFn : Expr
ctorFnType : Expr
instMVars : Array MVarId := #[]
instMVars : Array MVarId
params : Array (Name × Expr)

private def mkCtorHeaderAux : Nat → Expr → Expr → Array MVarId → TermElabM CtorHeaderResult
| 0, type, ctorFn, instMVars => pure { ctorFn := ctorFn, ctorFnType := type, instMVars := instMVars }
| n+1, type, ctorFn, instMVars => do
private def mkCtorHeaderAux : Nat → Expr → Expr → Array MVarId → Array (Name × Expr) → TermElabM CtorHeaderResult
| 0, type, ctorFn, instMVars, params => pure { ctorFn , ctorFnType := type, instMVars, params }
| n+1, type, ctorFn, instMVars, params => do
let type ← whnfForall type
match type with
| Expr.forallE _ d b c =>
| Expr.forallE paramName d b c =>
match c.binderInfo with
| BinderInfo.instImplicit =>
let a ← mkFreshExprMVar d MetavarKind.synthetic
mkCtorHeaderAux n (b.instantiate1 a) (mkApp ctorFn a) (instMVars.push a.mvarId!)
mkCtorHeaderAux n (b.instantiate1 a) (mkApp ctorFn a) (instMVars.push a.mvarId!) (params.push (paramName, a))
| _ =>
let a ← mkFreshExprMVar d
mkCtorHeaderAux n (b.instantiate1 a) (mkApp ctorFn a) instMVars
mkCtorHeaderAux n (b.instantiate1 a) (mkApp ctorFn a) instMVars (params.push (paramName, a))
| _ => throwError "unexpected constructor type"

private partial def getForallBody : Nat → Expr → Option Expr
Expand All @@ -582,7 +591,7 @@ private def mkCtorHeader (ctorVal : ConstructorVal) (expectedType? : Option Expr
let us ← mkFreshLevelMVars ctorVal.levelParams.length
let val := Lean.mkConst ctorVal.name us
let type ← instantiateTypeLevelParams (ConstantInfo.ctorInfo ctorVal) us
let r ← mkCtorHeaderAux ctorVal.numParams type val #[]
let r ← mkCtorHeaderAux ctorVal.numParams type val #[] #[]
propagateExpectedType r.ctorFnType ctorVal.numFields expectedType?
synthesizeAppInstMVars r.instMVars r.ctorFn
pure r
Expand All @@ -605,7 +614,8 @@ def trySynthStructInstance? (s : Struct) (expectedType : Expr) : TermElabM (Opti
private partial def elabStruct (s : Struct) (expectedType? : Option Expr) : TermElabM (Expr × Struct) := withRef s.ref do
let env ← getEnv
let ctorVal := getStructureCtor env s.structName
let { ctorFn := ctorFn, ctorFnType := ctorFnType, .. } ← mkCtorHeader ctorVal expectedType?
-- We store the parameters at the resulting `Struct`. We use this information during default value propagation.
let { ctorFn, ctorFnType, params, .. } ← mkCtorHeader ctorVal expectedType?
let (e, _, fields) ← s.fields.foldlM (init := (ctorFn, ctorFnType, [])) fun (e, type, fields) field =>
match field.lhs with
| [FieldLHS.fieldName ref fieldName] => do
Expand Down Expand Up @@ -640,7 +650,7 @@ private partial def elabStruct (s : Struct) (expectedType? : Option Expr) : Term
cont (markDefaultMissing val) field
| _ => withRef field.ref <| throwFailedToElabField fieldName s.structName m!"unexpected constructor type{indentExpr type}"
| _ => throwErrorAt field.ref "unexpected unexpanded structure field"
pure (e, s.setFields fields.reverse)
return (e, s.setFields fields.reverse |>.setParams params)

namespace DefaultFields

Expand Down Expand Up @@ -724,21 +734,28 @@ partial def mkDefaultValueAux? (struct : Struct) : Expr → TermElabM (Option Ex
if c.binderInfo.isExplicit then
let fieldName := n
match getFieldValue? struct fieldName with
| none => pure none
| none => return none
| some val =>
let valType ← inferType val
if (← isDefEq valType d) then
mkDefaultValueAux? struct (b.instantiate1 val)
else
pure none
return none
else
let arg ← mkFreshExprMVar d
mkDefaultValueAux? struct (b.instantiate1 arg)
if let some (_, param) := struct.params.find? fun (paramName, param) => paramName == n then
-- Recall that we did not use to have support for parameter propagation here.
if (← isDefEq (← inferType param) d) then
mkDefaultValueAux? struct (b.instantiate1 param)
else
return none
else
let arg ← mkFreshExprMVar d
mkDefaultValueAux? struct (b.instantiate1 arg)
| e =>
if e.isAppOfArity ``id 2 then
pure (some e.appArg!)
return some e.appArg!
else
pure (some e)
return some e

def mkDefaultValue? (struct : Struct) (cinfo : ConstantInfo) : TermElabM (Option Expr) :=
withRef struct.ref do
Expand Down
7 changes: 7 additions & 0 deletions tests/lean/run/defaulValueParamIssue.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
structure Something (i: Nat) where
n1: Nat := 1
n2: Nat := 1 + i

def s : Something 10 := {}

example : s.n2 = 11 := rfl

0 comments on commit 66c82da

Please sign in to comment.