-
Notifications
You must be signed in to change notification settings - Fork 0
/
db.go
116 lines (99 loc) · 3.39 KB
/
db.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
package internal
import (
"errors"
"log"
"time"
"github.com/jtanza/post-pigeon/internal/model"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
)
// DB is the data-access type for this package. Every query method in this
// file is defined on it and runs through the wrapped GORM handle.
type DB struct {
	// db is the GORM connection handle; NewDB populates it.
	db *gorm.DB
}
// NewDB opens the SQLite database named by createDSN and returns a DB wrapping
// the resulting GORM handle.
//
// The connection is configured to log at Warn level and to use singular table
// names. If the database cannot be opened the error is passed to log.Fatal,
// which terminates the process, so NewDB never returns an unusable DB.
func NewDB() DB {
	config := &gorm.Config{
		Logger:         logger.Default.LogMode(logger.Warn),
		NamingStrategy: schema.NamingStrategy{SingularTable: true},
	}

	conn, err := gorm.Open(sqlite.Open(createDSN()), config)
	if err != nil {
		log.Fatal(err)
	}
	return DB{db: conn}
}
// PersistPost derives a model.Post and a model.PostContent from the provided
// request and writes both inside a single transaction.
//
// The post row carries postUUID, the request's public key, the fingerprint
// computed from that key by Fingerprint, and the optional expiration. The
// content row carries postUUID, the rendered html, and the request's body and
// title. The first error encountered (fingerprinting or either insert) is
// returned from the transaction callback and becomes PersistPost's result.
func (d DB) PersistPost(postUUID string, request model.PostRequest, html string, expiration *time.Time) error {
	persist := func(tx *gorm.DB) error {
		fingerprint, err := Fingerprint(request.PublicKey)
		if err != nil {
			return err
		}

		post := model.Post{
			UUID:        postUUID,
			Key:         request.PublicKey,
			Fingerprint: fingerprint,
			ExpiresAt:   expiration,
		}
		if err := tx.Create(&post).Error; err != nil {
			return err
		}

		content := model.PostContent{
			PostUUID: postUUID,
			HTML:     html,
			Message:  request.Body,
			Title:    request.Title,
		}
		return tx.Create(&content).Error
	}

	return d.db.Transaction(persist)
}
// DeletePost drops from the db the model.Post and model.PostContent associated
// with the postDeleteRequest.
//
// Both rows are matched by postDeleteRequest.UUID and removed inside one
// transaction, so either both deletes take effect or neither does. The first
// delete error is returned and causes the transaction to be rolled back.
// Deleting a UUID that matches no rows is not treated as an error.
func (d DB) DeletePost(postDeleteRequest model.PostDeleteRequest) error {
	return d.db.Transaction(func(tx *gorm.DB) error {
		// Every statement must go through tx. Anything issued on d.db runs
		// outside this transaction and would not be rolled back on failure.
		// Unscoped skips GORM's soft-delete handling, if the models use it.
		if err := tx.Unscoped().Where("uuid = ?", postDeleteRequest.UUID).Delete(&model.Post{}).Error; err != nil {
			return err
		}
		if err := tx.Unscoped().Where("post_uuid = ?", postDeleteRequest.UUID).Delete(&model.PostContent{}).Error; err != nil {
			return err
		}
		return nil
	})
}
// GetPostContent looks up the model.PostContent whose post_uuid equals
// postUUID.
//
// It returns (nil, nil) when no such row exists, so callers must check the
// pointer as well as the error. Any other query failure is returned as the
// error with a nil pointer.
func (d DB) GetPostContent(postUUID string) (*model.PostContent, error) {
	content := new(model.PostContent)
	err := d.db.Where("post_uuid = ?", postUUID).First(content).Error
	switch {
	case err == nil:
		return content, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		// A missing row is reported as "no result" rather than as a failure.
		return nil, nil
	default:
		return nil, err
	}
}
// GetPost looks up the model.Post whose uuid equals postUUID.
//
// It returns (nil, nil) when no such row exists, so callers must check the
// pointer as well as the error. Any other query failure is returned as the
// error with a nil pointer.
func (d DB) GetPost(postUUID string) (*model.Post, error) {
	post := new(model.Post)
	err := d.db.Where("uuid = ?", postUUID).First(post).Error
	switch {
	case err == nil:
		return post, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		// A missing row is reported as "no result" rather than as a failure.
		return nil, nil
	default:
		return nil, err
	}
}
// GetUserPosts returns all known posts published by the provided fingerprint.
//
// Each result combines columns from the post table with the matching
// post_content row. The two tables are LEFT JOINed on the post UUID, so a post
// that has no content row is still returned. On a query error the result
// slice is nil and the error is returned.
func (d DB) GetUserPosts(fingerprint string) ([]model.FullPost, error) {
	const (
		columns     = "post.UUID, post.Key, post.Fingerprint, post.created_at, post_content.Title, post_content.HTML, post_content.Message"
		contentJoin = "left join post_content on post.uuid = post_content.post_uuid"
	)

	query := d.db.Model(&model.Post{}).
		Select(columns).
		Joins(contentJoin).
		Where("post.fingerprint = ?", fingerprint)

	var posts []model.FullPost
	if err := query.Scan(&posts).Error; err != nil {
		return nil, err
	}
	return posts, nil
}
// DeleteExpiredPosts removes every model.Post whose expires_at is at or before
// the current time, as evaluated by SQLite's datetime('now'), together with
// the model.PostContent rows belonging to those posts.
//
// Both deletes run in one transaction so a failure leaves neither table
// partially cleaned. Removing the content matters because GetPostContent looks
// rows up by post_uuid alone; content left behind would remain retrievable
// after its post expired.
//
// It returns the number of model.Post rows deleted (content rows are not
// counted). On error the count is 0.
func (d DB) DeleteExpiredPosts() (int64, error) {
	const expired = "expires_at <= datetime('now')"

	var deleted int64
	err := d.db.Transaction(func(tx *gorm.DB) error {
		// Content must go first: it is located through the expired posts, so
		// once those rows are gone there is nothing left to select it by.
		expiredUUIDs := tx.Unscoped().Model(&model.Post{}).Select("uuid").Where(expired)
		if err := tx.Unscoped().Where("post_uuid IN (?)", expiredUUIDs).Delete(&model.PostContent{}).Error; err != nil {
			return err
		}

		posts := tx.Unscoped().Where(expired).Delete(&model.Post{})
		if posts.Error != nil {
			return posts.Error
		}
		deleted = posts.RowsAffected
		return nil
	})
	if err != nil {
		return 0, err
	}
	return deleted, nil
}
// createDSN returns the data source name handed to the SQLite driver. It is
// currently a fixed file DSN.
func createDSN() string {
	const dsn = "file:postpigeon.db"
	return dsn
}