diff --git a/src/cmd/compile/internal/ssa/fuse.go b/src/cmd/compile/internal/ssa/fuse.go index f191c7f9fd912..2647b841d7b90 100644 --- a/src/cmd/compile/internal/ssa/fuse.go +++ b/src/cmd/compile/internal/ssa/fuse.go @@ -6,43 +6,131 @@ package ssa // fuse simplifies control flow by joining basic blocks. func fuse(f *Func) { - for _, b := range f.Blocks { - if b.Kind != BlockPlain { - continue - } - c := b.Succs[0] - if len(c.Preds) != 1 { - continue + for changed := true; changed; { + changed = false + for _, b := range f.Blocks { + changed = fuseBlockIf(b) || changed + changed = fuseBlockPlain(b) || changed } + } +} - // move all of b's values to c. - for _, v := range b.Values { - v.Block = c - c.Values = append(c.Values, v) - } +// fuseBlockIf handles the following cases where s0 and s1 are empty blocks. +// +// b b b +// / \ | \ / | +// s0 s1 | s1 s0 | +// \ / | / \ | +// ss ss ss +// +// If ss doesn't contain any Phi ops and s0 & s1 are empty then the branch +// can be dropped. +// TODO: If ss doesn't contain any Phi ops, are s0 and s1 dead code anyway? +func fuseBlockIf(b *Block) bool { + if b.Kind != BlockIf { + return false + } + + var ss0, ss1 *Block + s0 := b.Succs[0] + if s0.Kind != BlockPlain || len(s0.Preds) != 1 || len(s0.Values) != 0 { + s0, ss0 = nil, s0 + } else { + ss0 = s0.Succs[0] + } + s1 := b.Succs[1] + if s1.Kind != BlockPlain || len(s1.Preds) != 1 || len(s1.Values) != 0 { + s1, ss1 = nil, s1 + } else { + ss1 = s1.Succs[0] + } - // replace b->c edge with preds(b) -> c - c.predstorage[0] = nil - if len(b.Preds) > len(b.predstorage) { - c.Preds = b.Preds - } else { - c.Preds = append(c.predstorage[:0], b.Preds...) + if ss0 != ss1 { + return false + } + ss := ss0 + + // TODO: Handle OpPhi operations. We can still replace OpPhi if the + // slots corresponding to b, s0 and s1 point to the same variable. + for _, v := range ss.Values { + if v.Op == OpPhi { + return false } - for _, p := range c.Preds { - for i, q := range p.Succs { - if q == b { - p.Succs[i] = c - } + } + + // Now we have two following b->ss, b->s0->ss and b->s1->ss, + // with s0 and s1 empty if exist. + // We can replace it with b->ss without if ss has no phis + // which is checked above. + // No critical edge is introduced because b will have one successor. + if s0 != nil { + ss.removePred(s0) + } + if s1 != nil { + ss.removePred(s1) + } + if s0 != nil && s1 != nil { + // Add an edge if both edges are removed, otherwise b is no longer connected to ss. + ss.Preds = append(ss.Preds, b) + } + b.Kind = BlockPlain + b.Control = nil + b.Succs = append(b.Succs[:0], ss) + + // Trash the empty blocks s0 & s1. + if s0 != nil { + s0.Kind = BlockInvalid + s0.Values = nil + s0.Succs = nil + s0.Preds = nil + } + if s1 != nil { + s1.Kind = BlockInvalid + s1.Values = nil + s1.Succs = nil + s1.Preds = nil + } + return true +} + +func fuseBlockPlain(b *Block) bool { + if b.Kind != BlockPlain { + return false + } + + c := b.Succs[0] + if len(c.Preds) != 1 { + return false + } + + // move all of b'c values to c. + for _, v := range b.Values { + v.Block = c + c.Values = append(c.Values, v) + } + + // replace b->c edge with preds(b) -> c + c.predstorage[0] = nil + if len(b.Preds) > len(b.predstorage) { + c.Preds = b.Preds + } else { + c.Preds = append(c.predstorage[:0], b.Preds...) + } + for _, p := range c.Preds { + for i, q := range p.Succs { + if q == b { + p.Succs[i] = c } } - if f.Entry == b { - f.Entry = c - } - - // trash b, just in case - b.Kind = BlockInvalid - b.Values = nil - b.Preds = nil - b.Succs = nil } + if f := b.Func; f.Entry == b { + f.Entry = c + } + + // trash b, just in case + b.Kind = BlockInvalid + b.Values = nil + b.Preds = nil + b.Succs = nil + return true } diff --git a/src/cmd/compile/internal/ssa/fuse_test.go b/src/cmd/compile/internal/ssa/fuse_test.go new file mode 100644 index 0000000000000..b6f6b82c35561 --- /dev/null +++ b/src/cmd/compile/internal/ssa/fuse_test.go @@ -0,0 +1,95 @@ +package ssa + +import ( + "testing" +) + +func TestFuseEliminatesOneBranch(t *testing.T) { + ptrType := &TypeImpl{Size_: 8, Ptr: true, Name: "testptr"} // dummy for testing + c := NewConfig("amd64", DummyFrontend{t}, nil, true) + fun := Fun(c, "entry", + Bloc("entry", + Valu("mem", OpInitMem, TypeMem, 0, nil), + Valu("sb", OpSB, TypeInvalid, 0, nil), + Goto("checkPtr")), + Bloc("checkPtr", + Valu("ptr1", OpLoad, ptrType, 0, nil, "sb", "mem"), + Valu("nilptr", OpConstNil, ptrType, 0, nil, "sb"), + Valu("bool1", OpNeqPtr, TypeBool, 0, nil, "ptr1", "nilptr"), + If("bool1", "then", "exit")), + Bloc("then", + Goto("exit")), + Bloc("exit", + Exit("mem"))) + + CheckFunc(fun.f) + fuse(fun.f) + + for _, b := range fun.f.Blocks { + if b == fun.blocks["then"] && b.Kind != BlockInvalid { + t.Errorf("then was not eliminated, but should have") + } + } +} + +func TestFuseEliminatesBothBranches(t *testing.T) { + ptrType := &TypeImpl{Size_: 8, Ptr: true, Name: "testptr"} // dummy for testing + c := NewConfig("amd64", DummyFrontend{t}, nil, true) + fun := Fun(c, "entry", + Bloc("entry", + Valu("mem", OpInitMem, TypeMem, 0, nil), + Valu("sb", OpSB, TypeInvalid, 0, nil), + Goto("checkPtr")), + Bloc("checkPtr", + Valu("ptr1", OpLoad, ptrType, 0, nil, "sb", "mem"), + Valu("nilptr", OpConstNil, ptrType, 0, nil, "sb"), + Valu("bool1", OpNeqPtr, TypeBool, 0, nil, "ptr1", "nilptr"), + If("bool1", "then", "else")), + Bloc("then", + Goto("exit")), + Bloc("else", + Goto("exit")), + Bloc("exit", + Exit("mem"))) + + CheckFunc(fun.f) + fuse(fun.f) + + for _, b := range fun.f.Blocks { + if b == fun.blocks["then"] && b.Kind != BlockInvalid { + t.Errorf("then was not eliminated, but should have") + } + if b == fun.blocks["else"] && b.Kind != BlockInvalid { + t.Errorf("then was not eliminated, but should have") + } + } +} + +func TestFuseEliminatesEmptyBlocks(t *testing.T) { + c := NewConfig("amd64", DummyFrontend{t}, nil, true) + fun := Fun(c, "entry", + Bloc("entry", + Valu("mem", OpInitMem, TypeMem, 0, nil), + Valu("sb", OpSB, TypeInvalid, 0, nil), + Goto("z0")), + Bloc("z1", + Goto("z2")), + Bloc("z3", + Goto("exit")), + Bloc("z2", + Goto("z3")), + Bloc("z0", + Goto("z1")), + Bloc("exit", + Exit("mem"), + )) + + CheckFunc(fun.f) + fuse(fun.f) + + for k, b := range fun.blocks { + if k[:1] == "z" && b.Kind != BlockInvalid { + t.Errorf("%s was not eliminated, but should have", k) + } + } +} diff --git a/src/cmd/compile/internal/ssa/nilcheck_test.go b/src/cmd/compile/internal/ssa/nilcheck_test.go index 14955e77d87a4..b90d11e5403d4 100644 --- a/src/cmd/compile/internal/ssa/nilcheck_test.go +++ b/src/cmd/compile/internal/ssa/nilcheck_test.go @@ -403,8 +403,11 @@ func TestNilcheckBug(t *testing.T) { Valu("bool2", OpIsNonNil, TypeBool, 0, nil, "ptr1"), If("bool2", "extra", "exit")), Bloc("extra", + // prevent fuse from eliminating this block + Valu("store", OpStore, TypeMem, 8, nil, "ptr1", "nilptr", "mem"), Goto("exit")), Bloc("exit", + Valu("phi", OpPhi, TypeMem, 0, nil, "mem", "store"), Exit("mem"))) CheckFunc(fun.f)