/
user.go
135 lines (108 loc) · 2.97 KB
/
user.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
package dal
import (
"database/sql"
"errors"
"fmt"
"github.com/jmoiron/sqlx"
"golang.org/x/crypto/bcrypt"
)
// NewUser returns a User data-access object bound to db. It targets the
// "users" table and has hasID set to true.
func NewUser(db *sqlx.DB) *User {
	u := new(User)
	u.db, u.table, u.hasID = db, "users", true
	return u
}
type (
	// UserRow is a single record of the user table. The db tags name the
	// columns that sqlx maps onto each field.
	UserRow struct {
		ID       int64  `db:"id"`
		Email    string `db:"email"`
		Password string `db:"password"` // bcrypt hash written by Signup, never plaintext
	}

	// User is the data-access object for user records. The db handle, table
	// name and hasID flag it is configured with are promoted from the embedded
	// Base (presumably along with InsertIntoTable/UpdateById; verify in Base).
	User struct {
		Base
	}
)
// userRowFromSqlResult re-reads the row produced by an INSERT, looking it up by
// the id the driver reports through sqlResult.LastInsertId.
func (u *User) userRowFromSqlResult(tx *sqlx.Tx, sqlResult sql.Result) (*UserRow, error) {
	insertedID, idErr := sqlResult.LastInsertId()
	if idErr == nil {
		return u.GetById(tx, insertedID)
	}
	return nil, idErr
}
// AllUsers returns every row of the user table.
//
// When tx is non-nil the query runs inside that transaction, so it sees the
// transaction's own uncommitted writes. Otherwise it runs directly against
// u.db. On error the returned slice is nil.
func (u *User) AllUsers(tx *sqlx.Tx) ([]*UserRow, error) {
	// Both *sqlx.DB and *sqlx.Tx satisfy sqlx.Queryer; tx is checked for nil
	// before it is stored, so q never holds a typed-nil pointer from tx.
	var q sqlx.Queryer = u.db
	if tx != nil {
		q = tx
	}

	// Start from an empty, non-nil slice so an empty table yields [] rather
	// than nil.
	users := []*UserRow{}
	query := fmt.Sprintf("SELECT * FROM %v", u.table)
	if err := sqlx.Select(q, &users, query); err != nil {
		return nil, err
	}
	return users, nil
}
// GetById returns the user row whose id column equals id.
//
// When tx is non-nil the lookup runs inside that transaction. This matters for
// Signup, which re-reads a row it has just inserted through tx: on a separate
// connection that uncommitted row may not be visible. When tx is nil the query
// runs directly against u.db. On error (including sql.ErrNoRows from
// sqlx.Get when no row matches) the returned row is nil.
func (u *User) GetById(tx *sqlx.Tx, id int64) (*UserRow, error) {
	var q sqlx.Queryer = u.db
	if tx != nil {
		q = tx
	}

	user := &UserRow{}
	query := fmt.Sprintf("SELECT * FROM %v WHERE id=$1", u.table)
	if err := sqlx.Get(q, user, query, id); err != nil {
		return nil, err
	}
	return user, nil
}
// GetByEmail returns the user row whose email column equals email.
//
// When tx is non-nil the lookup runs inside that transaction; otherwise it runs
// directly against u.db. On error (including sql.ErrNoRows from sqlx.Get when
// no row matches) the returned row is nil.
func (u *User) GetByEmail(tx *sqlx.Tx, email string) (*UserRow, error) {
	var q sqlx.Queryer = u.db
	if tx != nil {
		q = tx
	}

	user := &UserRow{}
	query := fmt.Sprintf("SELECT * FROM %v WHERE email=$1", u.table)
	if err := sqlx.Get(q, user, query, email); err != nil {
		return nil, err
	}
	return user, nil
}
// GetUserByEmailAndPassword returns the user with the given email, but only if
// password matches the bcrypt hash stored for that user.
//
// Lookup errors from GetByEmail are returned unchanged. A wrong password yields
// the error from bcrypt.CompareHashAndPassword (bcrypt.ErrMismatchedHashAndPassword).
// On any error the returned row is nil.
func (u *User) GetUserByEmailAndPassword(tx *sqlx.Tx, email, password string) (*UserRow, error) {
	user, err := u.GetByEmail(tx, email)
	if err != nil {
		return nil, err
	}
	if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil {
		return nil, err
	}
	return user, nil
}
// Signup creates a new user record and returns the stored row.
//
// Validation rules:
//   - email must be non-empty;
//   - password must be non-empty;
//   - password must equal passwordAgain.
//
// If any rule fails, an error is returned before anything is hashed or
// written. The password is stored as a bcrypt hash. The insert is issued
// through InsertIntoTable with tx, and the new row is then re-read by its
// LastInsertId.
func (u *User) Signup(tx *sqlx.Tx, email, password, passwordAgain string) (*UserRow, error) {
	if email == "" {
		return nil, errors.New("Email cannot be blank.")
	}
	if password == "" {
		return nil, errors.New("Password cannot be blank.")
	}
	if password != passwordAgain {
		return nil, errors.New("Password is invalid.")
	}

	// Use bcrypt.DefaultCost rather than the previous hard-coded 5, which is
	// barely above bcrypt.MinCost and far too cheap to brute-force. Hashes
	// created at the old cost still verify, because each hash records its own
	// cost.
	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return nil, err
	}

	data := map[string]interface{}{
		"email":    email,
		"password": hashedPassword,
	}
	sqlResult, err := u.InsertIntoTable(tx, data)
	if err != nil {
		return nil, err
	}
	return u.userRowFromSqlResult(tx, sqlResult)
}
// UpdateEmailAndPasswordById updates the email and/or password of the user with
// id userId and returns the row as re-read afterwards.
//
// An empty email leaves the email unchanged. Leaving both password and
// passwordAgain empty leaves the password unchanged. If either of them is
// supplied, they must match, otherwise an error is returned and nothing is
// written. Previously a mismatch was silently ignored and the call reported
// success without changing the password. A new password is stored as a bcrypt
// hash. When there is nothing to change, no UPDATE is issued and the current
// row is simply returned.
func (u *User) UpdateEmailAndPasswordById(tx *sqlx.Tx, userId int64, email, password, passwordAgain string) (*UserRow, error) {
	data := make(map[string]interface{})
	if email != "" {
		data["email"] = email
	}

	if password != "" || passwordAgain != "" {
		// Same validation message as Signup uses for a mismatched confirmation.
		if password != passwordAgain {
			return nil, errors.New("Password is invalid.")
		}
		hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
		if err != nil {
			return nil, err
		}
		data["password"] = hashedPassword
	}

	if len(data) > 0 {
		if _, err := u.UpdateById(tx, data, userId); err != nil {
			return nil, err
		}
	}
	return u.GetById(tx, userId)
}