From b5ce7c467cd1d120536afd178d63c9e5c825f489 Mon Sep 17 00:00:00 2001 From: Michael Hobbs Date: Wed, 18 Aug 2021 15:46:44 -0700 Subject: [PATCH] ensure against NPE in tests --- conn_test.go | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/conn_test.go b/conn_test.go index 3e2f07b5..4ac3d2b8 100644 --- a/conn_test.go +++ b/conn_test.go @@ -1849,8 +1849,11 @@ func TestConnPrepareContext(t *testing.T) { defer cancel() } _, err := db.PrepareContext(ctx, tt.sql) - if ((err != nil) != (tt.err != nil)) && (err.Error() != tt.err.Error()) { - t.Errorf("conn.PrepareContext() error = %v, expectedErr != %v", err, tt.err) + switch { + case (err != nil) != (tt.err != nil): + t.Fatalf("conn.PrepareContext() unexpected nil err got = %v, expected = %v", err, tt.err) + case (err != nil && tt.err != nil) && (err.Error() != tt.err.Error()): + t.Errorf("conn.PrepareContext() got = %v, expected = %v", err.Error(), tt.err.Error()) } }) } @@ -1902,8 +1905,11 @@ func TestStmtQueryContext(t *testing.T) { t.Fatal(err) } _, err = stmt.QueryContext(ctx) - if ((err != nil) != (tt.err != nil)) && (err.Error() != tt.err.Error()) { - t.Errorf("stmt.QueryContext() error = %v, expectedErr != %v", err, tt.err) + switch { + case (err != nil) != (tt.err != nil): + t.Fatalf("stmt.QueryContext() unexpected nil err got = %v, expected = %v", err, tt.err) + case (err != nil && tt.err != nil) && (err.Error() != tt.err.Error()): + t.Errorf("stmt.QueryContext() got = %v, expected = %v", err.Error(), tt.err.Error()) } }) } @@ -1955,8 +1961,11 @@ func TestStmtExecContext(t *testing.T) { t.Fatal(err) } _, err = stmt.ExecContext(ctx) - if ((err != nil) != (tt.err != nil)) && (err.Error() != tt.err.Error()) { - t.Errorf("stmt.ExecContext() error = %v, expectedErr != %v", err, tt.err) + switch { + case (err != nil) != (tt.err != nil): + t.Fatalf("stmt.ExecContext() unexpected nil err got = %v, expected = %v", err, tt.err) + case (err != nil && tt.err != nil) && (err.Error() != tt.err.Error()): + t.Errorf("stmt.ExecContext() got = %v, expected = %v", err.Error(), tt.err.Error()) } }) }