Skip to content

Commit

Permalink
fixes branches interacting with break, raise etc. in strictdefs (#22627)
Browse files Browse the repository at this point in the history
```nim
{.experimental: "strictdefs".}

type Test = object
  id: int

proc test(): Test =
  if true:
    return Test()
  else:
    return
echo test()
```

I will tackle #16735 and #21615 in
the following PR.


The old code just premises that in branches ended with returns, raise
statements etc. , all variables including the result variable are
initialized for that branch. It's true for noreturn statements. But it
is false for the result variable in a branch tailing with a return
statement, in which the result variable is not initialized. The solution
is not perfect for usages below branch statements with the result
variable uninitialized, but it should suffice for now, which gives a
proper warning.

It also fixes

```nim

{.experimental: "strictdefs".}

type Test = object
  id: int

proc foo {.noreturn.} = discard

proc test9(x: bool): Test =
  if x:
    foo()
  else:
    foo()
```
which gives a warning, but shouldn't
  • Loading branch information
ringabout committed Sep 4, 2023
1 parent c5495f4 commit d13aab5
Show file tree
Hide file tree
Showing 4 changed files with 226 additions and 20 deletions.
2 changes: 1 addition & 1 deletion compiler/lookups.nim
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ proc lookUp*(c: PContext, n: PNode): PSym =
if result == nil: result = errorUndeclaredIdentifierHint(c, n, ident)
else:
internalError(c.config, n.info, "lookUp")
return
return nil
if amb:
#contains(c.ambiguousSymbols, result.id):
result = errorUseQualifier(c, n.info, result, amb)
Expand Down
87 changes: 68 additions & 19 deletions compiler/sempass2.nim
Original file line number Diff line number Diff line change
Expand Up @@ -352,20 +352,25 @@ proc useVar(a: PEffects, n: PNode) =
a.init.add s.id
useVarNoInitCheck(a, n, s)

type
BreakState = enum
bsNone
bsBreakOrReturn
bsNoReturn

type
TIntersection = seq[tuple[id, count: int]] # a simple count table

proc addToIntersection(inter: var TIntersection, s: int, initOnly: bool) =
proc addToIntersection(inter: var TIntersection, s: int, state: BreakState) =
for j in 0..<inter.len:
if s == inter[j].id:
if not initOnly:
if state == bsNone:
inc inter[j].count
return
if initOnly:
inter.add((id: s, count: 0))
else:
if state == bsNone:
inter.add((id: s, count: 1))
else:
inter.add((id: s, count: 0))

proc throws(tracked, n, orig: PNode) =
if n.typ == nil or n.typ.kind != tyError:
Expand Down Expand Up @@ -469,7 +474,7 @@ proc trackTryStmt(tracked: PEffects, n: PNode) =
track(tracked, n[0])
dec tracked.inTryStmt
for i in oldState..<tracked.init.len:
addToIntersection(inter, tracked.init[i], false)
addToIntersection(inter, tracked.init[i], bsNone)

var branches = 1
var hasFinally = false
Expand Down Expand Up @@ -504,7 +509,7 @@ proc trackTryStmt(tracked: PEffects, n: PNode) =
tracked.init.add b[j][2].sym.id
track(tracked, b[^1])
for i in oldState..<tracked.init.len:
addToIntersection(inter, tracked.init[i], false)
addToIntersection(inter, tracked.init[i], bsNone)
else:
setLen(tracked.init, oldState)
track(tracked, b[^1])
Expand Down Expand Up @@ -673,15 +678,50 @@ proc trackOperandForIndirectCall(tracked: PEffects, n: PNode, formals: PType; ar
localError(tracked.config, n.info, $n & " is not GC safe")
notNilCheck(tracked, n, paramType)

proc breaksBlock(n: PNode): bool =

proc breaksBlock(n: PNode): BreakState =
# semantic check doesn't allow statements after raise, break, return or
# call to noreturn proc, so it is safe to check just the last statements
var it = n
while it.kind in {nkStmtList, nkStmtListExpr} and it.len > 0:
it = it.lastSon

result = it.kind in {nkBreakStmt, nkReturnStmt, nkRaiseStmt} or
it.kind in nkCallKinds and it[0].kind == nkSym and sfNoReturn in it[0].sym.flags
case it.kind
of nkBreakStmt, nkReturnStmt:
result = bsBreakOrReturn
of nkRaiseStmt:
result = bsNoReturn
of nkCallKinds:
if it[0].kind == nkSym and sfNoReturn in it[0].sym.flags:
result = bsNoReturn
else:
result = bsNone
else:
result = bsNone

proc addIdToIntersection(tracked: PEffects, inter: var TIntersection, resCounter: var int,
hasBreaksBlock: BreakState, oldState: int, resSym: PSym, hasResult: bool) =
if hasResult:
var alreadySatisfy = false

if hasBreaksBlock == bsNoReturn:
alreadySatisfy = true
inc resCounter

for i in oldState..<tracked.init.len:
if tracked.init[i] == resSym.id:
if not alreadySatisfy:
inc resCounter
alreadySatisfy = true
else:
addToIntersection(inter, tracked.init[i], hasBreaksBlock)
else:
for i in oldState..<tracked.init.len:
addToIntersection(inter, tracked.init[i], hasBreaksBlock)

template hasResultSym(s: PSym): bool =
s != nil and s.kind in {skProc, skFunc, skConverter, skMethod} and
not isEmptyType(s.typ[0])

proc trackCase(tracked: PEffects, n: PNode) =
track(tracked, n[0])
Expand All @@ -694,6 +734,10 @@ proc trackCase(tracked: PEffects, n: PNode) =
(tracked.config.hasWarn(warnProveField) or strictCaseObjects in tracked.c.features)
var inter: TIntersection = @[]
var toCover = 0
let hasResult = hasResultSym(tracked.owner)
let resSym = if hasResult: tracked.owner.ast[resultPos].sym else: nil
var resCounter = 0

for i in 1..<n.len:
let branch = n[i]
setLen(tracked.init, oldState)
Expand All @@ -703,13 +747,14 @@ proc trackCase(tracked: PEffects, n: PNode) =
for i in 0..<branch.len:
track(tracked, branch[i])
let hasBreaksBlock = breaksBlock(branch.lastSon)
if not hasBreaksBlock:
if hasBreaksBlock == bsNone:
inc toCover
for i in oldState..<tracked.init.len:
addToIntersection(inter, tracked.init[i], hasBreaksBlock)
addIdToIntersection(tracked, inter, resCounter, hasBreaksBlock, oldState, resSym, hasResult)

setLen(tracked.init, oldState)
if not stringCase or lastSon(n).kind == nkElse:
if hasResult and resCounter == n.len-1:
tracked.init.add resSym.id
for id, count in items(inter):
if count >= toCover: tracked.init.add id
# else we can't merge
Expand All @@ -723,14 +768,17 @@ proc trackIf(tracked: PEffects, n: PNode) =
addFact(tracked.guards, n[0][0])
let oldState = tracked.init.len

let hasResult = hasResultSym(tracked.owner)
let resSym = if hasResult: tracked.owner.ast[resultPos].sym else: nil
var resCounter = 0

var inter: TIntersection = @[]
var toCover = 0
track(tracked, n[0][1])
let hasBreaksBlock = breaksBlock(n[0][1])
if not hasBreaksBlock:
if hasBreaksBlock == bsNone:
inc toCover
for i in oldState..<tracked.init.len:
addToIntersection(inter, tracked.init[i], hasBreaksBlock)
addIdToIntersection(tracked, inter, resCounter, hasBreaksBlock, oldState, resSym, hasResult)

for i in 1..<n.len:
let branch = n[i]
Expand All @@ -743,13 +791,14 @@ proc trackIf(tracked: PEffects, n: PNode) =
for i in 0..<branch.len:
track(tracked, branch[i])
let hasBreaksBlock = breaksBlock(branch.lastSon)
if not hasBreaksBlock:
if hasBreaksBlock == bsNone:
inc toCover
for i in oldState..<tracked.init.len:
addToIntersection(inter, tracked.init[i], hasBreaksBlock)
addIdToIntersection(tracked, inter, resCounter, hasBreaksBlock, oldState, resSym, hasResult)

setLen(tracked.init, oldState)
if lastSon(n).len == 1:
if hasResult and resCounter == n.len:
tracked.init.add resSym.id
for id, count in items(inter):
if count >= toCover: tracked.init.add id
# else we can't merge as it is not exhaustive
Expand Down
64 changes: 64 additions & 0 deletions tests/init/tcompiles.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
discard """
matrix: "--warningAsError:ProveInit --warningAsError:Uninit"
"""

{.experimental: "strictdefs".}

type Test = object
id: int

proc foo {.noreturn.} = discard

block:
proc test(x: bool): Test =
if x:
foo()
else:
foo()

block:
proc test(x: bool): Test =
if x:
result = Test()
else:
foo()

discard test(true)

block:
proc test(x: bool): Test =
if x:
result = Test()
else:
return Test()

discard test(true)

block:
proc test(x: bool): Test =
if x:
return Test()
else:
return Test()

discard test(true)

block:
proc test(x: bool): Test =
if x:
result = Test()
else:
result = Test()
return

discard test(true)

block:
proc test(x: bool): Test =
if x:
result = Test()
return
else:
raise newException(ValueError, "unreachable")

discard test(true)
93 changes: 93 additions & 0 deletions tests/init/treturns.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
{.experimental: "strictdefs".}

type Test = object
id: int

proc foo {.noreturn.} = discard

proc test1(): Test =
if true: #[tt.Warning
^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
return Test()
else:
return

proc test0(): Test =
if true: #[tt.Warning
^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
return
else:
foo()

proc test2(): Test =
if true: #[tt.Warning
^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
return
else:
return

proc test3(): Test =
if true: #[tt.Warning
^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
return
else:
return Test()

proc test4(): Test =
if true: #[tt.Warning
^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
return
else:
result = Test()
return

proc test5(x: bool): Test =
case x: #[tt.Warning
^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
of true:
return
else:
return Test()

proc test6(x: bool): Test =
case x: #[tt.Warning
^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
of true:
return
else:
return

proc test7(x: bool): Test =
case x: #[tt.Warning
^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
of true:
return
else:
discard

proc test8(x: bool): Test =
case x: #[tt.Warning
^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
of true:
discard
else:
raise

proc hasImportStmt(): bool =
if false: #[tt.Warning
^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
return true
else:
discard

discard hasImportStmt()

block:
proc hasImportStmt(): bool =
if false: #[tt.Warning
^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
return true
else:
return

discard hasImportStmt()

0 comments on commit d13aab5

Please sign in to comment.