diff --git a/README.md b/README.md index 65deb13..1fa5d17 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,20 @@ +# shortener + +Веб-сервис сокращения URL + +| Путь | Метод | Описание | +|--------------------|--------|--------------------------------------| +| / | POST | Добавление ссылки | +| /:id | GET | Получение ссылки | +| /api/shorten | POST | Добавление ссылки в JSON | +| /api/shorten/batch | POST | Пакетное добавление ссылок | +| /api/user/urls | GET | Получение списка ссылок пользователя | +| /api/user/urls | DELETE | Удаление ссылок пользователя | +| /ping | GET | Проверка соединения с БД | + + + + # go-musthave-shortener-tpl Шаблон репозитория для практического трека «Go в веб-разработке». diff --git a/cmd/shortener/README.md b/cmd/shortener/README.md deleted file mode 100644 index d5b535d..0000000 --- a/cmd/shortener/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# cmd/shortener - -В данной директории будет содержаться код, который скомпилируется в бинарное приложение \ No newline at end of file diff --git a/cmd/shortener/shortener.go b/cmd/shortener/shortener.go index aef193f..ec8e072 100644 --- a/cmd/shortener/shortener.go +++ b/cmd/shortener/shortener.go @@ -11,10 +11,12 @@ func main() { var serverAddr string var baseURL string var dbFile string + var dbCredentials string flag.StringVar(&serverAddr, "a", os.Getenv("SERVER_ADDRESS"), "server address") flag.StringVar(&baseURL, "b", os.Getenv("BASE_URL"), "base URL") flag.StringVar(&dbFile, "f", os.Getenv("FILE_STORAGE_PATH"), "file storage path") + flag.StringVar(&dbCredentials, "d", os.Getenv("DATABASE_DSN"), "database credentials") flag.Parse() if serverAddr == "" { @@ -25,7 +27,7 @@ func main() { baseURL = "http://" + serverAddr } - err := app.Start(serverAddr, baseURL, dbFile) + err := app.Start(serverAddr, baseURL, dbFile, dbCredentials) if err != nil { panic(err) } diff --git a/go.mod b/go.mod index 45105cc..d3565e9 100644 --- a/go.mod +++ b/go.mod @@ -15,8 +15,11 @@ require ( ) require ( + github.com/google/uuid v1.3.0 + github.com/jmoiron/sqlx v1.3.5 github.com/labstack/echo v3.3.10+incompatible github.com/labstack/gommon v0.3.1 // indirect + github.com/lib/pq v1.10.7 github.com/mattn/go-colorable v0.1.11 // indirect github.com/mattn/go-isatty v0.0.14 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect diff --git a/go.sum b/go.sum index c855a40..c3b428e 100644 --- a/go.sum +++ b/go.sum @@ -16,11 +16,16 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= +github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= +github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= github.com/jrick/logrotate v1.0.0/go.mod h1:LNinyqDIJnpAur+b8yyulnQw/wDuN1+BYKlTRt3OuAQ= github.com/kkdai/bstream v0.0.0-20161212061736-f391b8402d23/go.mod h1:J+Gs4SYgM6CZQHDETBtE9HaSEkGmuNXF86RwHhHUvq4= github.com/labstack/echo v3.3.10+incompatible h1:pGRcYk231ExFAyoAjAfD85kQzRJCRI8bbnE7CX5OEgg= @@ -29,10 +34,14 @@ github.com/labstack/echo/v4 v4.8.0 h1:wdc6yKVaHxkNOEdz4cRZs1pQkwSXPiRjq69yWP4QQS github.com/labstack/echo/v4 v4.8.0/go.mod h1:xkCDAdFCIf8jsFQ5NnbK7oqaF/yU1A1X20Ltm0OvSks= github.com/labstack/gommon v0.3.1 h1:OomWaJXm7xR6L1HmEtGyQf26TEn7V6X88mktX9kee9o= github.com/labstack/gommon v0.3.1/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3MFxTMTM= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= +github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-colorable v0.1.11 h1:nQ+aFkoE2TMGc0b68U2OKSexC+eq46+XwZzWXHRmPYs= github.com/mattn/go-colorable v0.1.11/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= diff --git a/internal/app/README.md b/internal/app/README.md deleted file mode 100644 index ba14e13..0000000 --- a/internal/app/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# internal/app - -В данной директории будет содержаться имплементация вашего сервиса \ No newline at end of file diff --git a/internal/app/app.go b/internal/app/app.go index 65e5d2c..8b443a6 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -9,10 +9,10 @@ import ( "github.com/labstack/echo/v4/middleware" ) -func Start(serverAddr, baseURL, dbFile string) error { - h, err := handler.New(serverAddr, baseURL, dbFile) +func Start(serverAddr, baseURL, dbFile, dbCredentials string) error { + h, err := handler.New(serverAddr, baseURL, dbFile, dbCredentials) if err != nil { - return fmt.Errorf("handler: %v", err) + return fmt.Errorf("handler: %w", err) } e := echo.New() @@ -22,6 +22,10 @@ func Start(serverAddr, baseURL, dbFile string) error { e.POST("/", h.CreateURL) e.GET("/:id", h.RetrieveURL) e.POST("/api/shorten", h.CreateURLInJSON) + e.POST("/api/shorten/batch", h.CreateBatchURL) + e.GET("/api/user/urls", h.ListURL) + e.DELETE("/api/user/urls", h.DeleteURL) + e.GET("/ping", h.Ping) e.Logger.Fatal(e.Start(serverAddr)) diff --git a/internal/handler/handler.go b/internal/handler/handler.go index b76106b..2bfa39b 100644 --- a/internal/handler/handler.go +++ b/internal/handler/handler.go @@ -2,6 +2,9 @@ package handler import ( "crypto" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -12,57 +15,82 @@ import ( "github.com/combodga/Project/internal/storage" "github.com/btcsuite/btcutil/base58" + "github.com/google/uuid" "github.com/labstack/echo/v4" ) type Handler struct { - ServerAddr string - BaseURL string - Storage *storage.Storage + ServerAddr string + BaseURL string + Storage *storage.Storage + DBCredentials string + Key string } -func New(serverAddr, baseURL, dbFile string) (*Handler, error) { - s, err := storage.New(dbFile) +type Link struct { + Result string `json:"result"` +} + +type LinkJSON struct { + CorrelationID string `json:"correlation_id"` + OriginalURL string `json:"original_url"` +} + +type BatchLink struct { + CorrelationID string `json:"correlation_id"` + ShortURL string `json:"short_url"` +} + +type Element struct { + ShortURL string `json:"short_url"` + OriginalURL string `json:"original_url"` +} + +func New(serverAddr, baseURL, dbFile, dbCredentials string) (*Handler, error) { + s, err := storage.New(dbFile, dbCredentials) if err != nil { - err = fmt.Errorf("storage: %v", err) + err = fmt.Errorf("new storage: %w", err) } return &Handler{ - ServerAddr: serverAddr, - BaseURL: baseURL, - Storage: s, + ServerAddr: serverAddr, + BaseURL: baseURL, + Storage: s, + DBCredentials: dbCredentials, + Key: "b8ffa0f4-3f11-44b1-b0bf-9109f47e468b", }, err } -type Link struct { - Result string `json:"result"` -} - func (h *Handler) CreateURL(c echo.Context) error { + user := getUser(c, h.Key) body, err := io.ReadAll(c.Request().Body) if err != nil { - return err + return fmt.Errorf("read request body: %w", err) } link := string(body) - id, err := h.fetchID(c, link) + id, err := h.fetchID(c, user, link) + if errors.Is(err, h.Storage.ErrDupKey) { + return c.String(http.StatusConflict, h.BaseURL+"/"+id) + } if err != nil { - return fmt.Errorf("fetch id: %v", err) + return fmt.Errorf("fetch id: %w", err) } return c.String(http.StatusCreated, h.BaseURL+"/"+id) } func (h *Handler) CreateURLInJSON(c echo.Context) error { + user := getUser(c, h.Key) body, err := io.ReadAll(c.Request().Body) if err != nil { - return err + return fmt.Errorf("read request body: %w", err) } data := make(map[string]string) err = json.Unmarshal(body, &data) if err != nil { - return err + return fmt.Errorf("json unmarshal: %w", err) } link, ok := data["url"] @@ -70,29 +98,140 @@ func (h *Handler) CreateURLInJSON(c echo.Context) error { return errors.New("error reading json") } - id, err := h.fetchID(c, link) + uniqueErr := false + id, err := h.fetchID(c, user, link) + if errors.Is(err, h.Storage.ErrDupKey) { + err = nil + uniqueErr = true + } if err != nil { - return fmt.Errorf("fetchID: %v", err) + return fmt.Errorf("fetch id: %w", err) } l := &Link{ Result: h.BaseURL + "/" + id, } + + if uniqueErr { + return c.JSON(http.StatusConflict, l) + } + return c.JSON(http.StatusCreated, l) } +func (h *Handler) CreateBatchURL(c echo.Context) error { + user := getUser(c, h.Key) + body, err := io.ReadAll(c.Request().Body) + if err != nil { + return fmt.Errorf("read request body: %w", err) + } + + var l []LinkJSON + err = json.Unmarshal(body, &l) + if err != nil { + return fmt.Errorf("json unmarshal: %w", err) + } + + var bl []BatchLink + var uniqueErr bool + for result := range l { + link := l[result] + + var id string + id, err = h.fetchID(c, user, link.OriginalURL) + if errors.Is(err, h.Storage.ErrDupKey) { + err = nil + uniqueErr = true + } + if err != nil { + return fmt.Errorf("fetch id: %w", err) + } + + bl = append(bl, BatchLink{ + CorrelationID: link.CorrelationID, + ShortURL: h.BaseURL + "/" + id, + }) + } + + if uniqueErr { + return c.JSON(http.StatusConflict, bl) + } + + return c.JSON(http.StatusCreated, bl) +} + func (h *Handler) RetrieveURL(c echo.Context) error { + user := getUser(c, h.Key) id := c.Param("id") - url, ok := h.Storage.GetURL(id) - if !ok { + url, status := h.Storage.GetURL(user, id) + if status == 0 { return c.String(http.StatusNotFound, "error, there is no such link") + } else if status == 2 { + return c.String(http.StatusGone, "error, link was deleted") } return c.Redirect(http.StatusTemporaryRedirect, url) } -func (h *Handler) fetchID(c echo.Context, link string) (string, error) { +func (h *Handler) ListURL(c echo.Context) error { + user := getUser(c, h.Key) + list, ok := h.Storage.ListURL(user) + if !ok { + return c.String(http.StatusNoContent, "error, you haven't any saved links") + } + + var arr []*Element + for shortURL, originalURL := range list { + arr = append(arr, &Element{ + ShortURL: h.BaseURL + "/" + shortURL, + OriginalURL: originalURL, + }) + } + + return c.JSON(http.StatusOK, arr) +} + +func (h *Handler) DeleteURL(c echo.Context) error { + user := getUser(c, h.Key) + list, ok := h.Storage.ListURL(user) + if !ok { + return c.String(http.StatusBadRequest, "error, you haven't any saved links") + } + + body, err := io.ReadAll(c.Request().Body) + if err != nil { + return fmt.Errorf("read request body: %w", err) + } + + var l []string + err = json.Unmarshal(body, &l) + if err != nil { + return fmt.Errorf("json unmarshal: %w", err) + } + + for _, linkToDelete := range l { + for savedLink := range list { + if savedLink == linkToDelete { + go h.Storage.UpdateURL(user, linkToDelete, true) + break + } + } + } + + return c.String(http.StatusAccepted, "URLs deleted") +} + +func (h *Handler) Ping(c echo.Context) error { + ok := h.Storage.Ping() + if !ok { + return c.String(http.StatusInternalServerError, "error, no connection to db") + } + + return c.String(http.StatusOK, "db connected") +} + +func (h *Handler) fetchID(c echo.Context, user, link string) (string, error) { if len(link) > 2048 { return "", c.String(http.StatusBadRequest, "error, the link cannot be longer than 2048 characters") } @@ -102,26 +241,30 @@ func (h *Handler) fetchID(c echo.Context, link string) (string, error) { return "", c.String(http.StatusBadRequest, "error, the link is invalid") } - id, ok := h.Storage.GetURL(link) - if !ok { + id, status := h.Storage.GetURL(user, link) + if status == 0 { id, err = shortener(link) if err != nil { return "", c.String(http.StatusBadRequest, "error, failed to create a shortened URL") } } - err = h.Storage.SetURL(id, link) + err = h.Storage.SetURL(user, id, link) + if errors.Is(err, h.Storage.ErrDupKey) { + return id, err + } + if err != nil { return "", c.String(http.StatusInternalServerError, "error, failed to store a shortened URL") } - return id, nil + return id, err } func shortener(s string) (string, error) { h := crypto.MD5.New() if _, err := h.Write([]byte(s)); err != nil { - return "", fmt.Errorf("abbreviation error URL: %v", err) + return "", fmt.Errorf("abbreviation error URL: %w", err) } hash := string(h.Sum([]byte{})) @@ -130,3 +273,43 @@ func shortener(s string) (string, error) { return id, nil } + +func getUser(c echo.Context, key string) string { + user, err1 := readCookie(c, "user") + sign, err2 := readCookie(c, "sign") + if err1 == nil && err2 == nil && sign == getSign(user, key) { + return user + } + + user = randUser() + writeCookie(c, "user", user) + writeCookie(c, "sign", getSign(user, key)) + return user +} + +func randUser() string { + uuidWithHyphen := uuid.New() + return uuidWithHyphen.String() +} + +func getSign(user, key string) string { + h := hmac.New(sha256.New, []byte(key)) + h.Write([]byte(user)) + dst := h.Sum(nil) + return hex.EncodeToString(dst)[:32] +} + +func writeCookie(c echo.Context, name, value string) { + cookie := new(http.Cookie) + cookie.Name = name + cookie.Value = value + c.SetCookie(cookie) +} + +func readCookie(c echo.Context, name string) (string, error) { + cookie, err := c.Cookie(name) + if err != nil { + return "", err + } + return cookie.Value, nil +} diff --git a/internal/handler/handler_test.go b/internal/handler/handler_test.go index c889d06..284b84c 100644 --- a/internal/handler/handler_test.go +++ b/internal/handler/handler_test.go @@ -27,7 +27,7 @@ var ( func TestInit(t *testing.T) { var err error - H, err = New("localhost:8080", "http://localhost:8080", "") + H, err = New("localhost:8080", "http://localhost:8080", "", "") if err != nil { t.Fatal("can't start test") } diff --git a/internal/storage/storage.go b/internal/storage/storage.go index 8974201..2222dce 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -1,82 +1,235 @@ package storage import ( + "database/sql" "encoding/json" "errors" + "fmt" "os" "sync" + + "github.com/jmoiron/sqlx" + "github.com/lib/pq" ) type Storage struct { - DBFile string - Pairs map[string]string - Mutex *sync.RWMutex + DBFile string + DBCredentials string + Pairs map[string]map[string]string + HiddenPairs map[string]map[string]string + Mutex *sync.RWMutex + ErrDupKey error } -func New(dbFile string) (*Storage, error) { - s := &Storage{ - DBFile: dbFile, - Pairs: make(map[string]string), - Mutex: &sync.RWMutex{}, - } +type Link struct { + User string `db:"usr"` + ID string `db:"short"` + Link string `db:"long"` + IsHidden bool `db:"hidden"` +} - if dbFile == "" { - return s, nil +func New(dbFile, dbCredentials string) (*Storage, error) { + s := &Storage{ + DBFile: dbFile, + DBCredentials: dbCredentials, + Pairs: make(map[string]map[string]string), + HiddenPairs: make(map[string]map[string]string), + Mutex: &sync.RWMutex{}, + ErrDupKey: fmt.Errorf("duplicate key"), } s.Mutex.Lock() defer s.Mutex.Unlock() - pairsStr, err := os.ReadFile(dbFile) - if errors.Is(err, os.ErrNotExist) { + if dbCredentials != "" { + db, err := sqlx.Connect("postgres", s.DBCredentials) + if err != nil { + return s, fmt.Errorf("db connect: %w", err) + } + defer db.Close() + + db.MustExec(` + CREATE TABLE IF NOT EXISTS shortener ( + usr text, + short text unique, + long text, + hidden bool + ); + `) + + link := Link{} + rows, err := db.Queryx("SELECT * FROM shortener") + if err != nil { + return s, fmt.Errorf("read rows: %w", err) + } + defer rows.Close() + for rows.Next() { + err := rows.StructScan(&link) + if err != nil { + return s, fmt.Errorf("rows struct scan: %w", err) + } + if link.IsHidden { + if len(s.HiddenPairs[link.User]) == 0 { + s.HiddenPairs[link.User] = make(map[string]string) + } + s.HiddenPairs[link.User][link.ID] = link.Link + } else { + if len(s.Pairs[link.User]) == 0 { + s.Pairs[link.User] = make(map[string]string) + } + s.Pairs[link.User][link.ID] = link.Link + } + } + err = rows.Err() + if err != nil { + return s, fmt.Errorf("rows error: %w", err) + } + return s, nil } - if err != nil { - return s, err - } - err = json.Unmarshal(pairsStr, &s.Pairs) - if err != nil { - return s, err + if dbFile != "" { + pairsStr, err := os.ReadFile(dbFile) + if errors.Is(err, os.ErrNotExist) { + return s, nil + } + if err != nil { + return s, fmt.Errorf("read file: %w", err) + } + + err = json.Unmarshal(pairsStr, &s.Pairs) + if err != nil { + return s, fmt.Errorf("json unmarshal: %w", err) + } } return s, nil } -func (s *Storage) GetURL(id string) (string, bool) { +func (s *Storage) GetURL(user, id string) (string, int) { if len(id) <= 0 { - return "", false + return "", 0 } s.Mutex.Lock() - url, ok := s.Pairs[id] - s.Mutex.Unlock() - if !ok { - return "", false + defer s.Mutex.Unlock() + + for user := range s.Pairs { + url, ok := s.Pairs[user][id] + if ok { + return url, 1 + } + } + + for user := range s.HiddenPairs { + url, ok := s.HiddenPairs[user][id] + if ok { + return url, 2 + } + } + + return "", 0 +} + +func (s *Storage) SetURL(user, id, link string) error { + s.Mutex.Lock() + defer s.Mutex.Unlock() + + if len(s.Pairs[user]) == 0 { + s.Pairs[user] = make(map[string]string) + } + s.Pairs[user][id] = link + + if s.DBCredentials != "" { + db, err := sqlx.Connect("postgres", s.DBCredentials) + if err != nil { + return fmt.Errorf("sql connect: %w", err) + } + defer db.Close() + + _, err = db.Exec("INSERT INTO shortener VALUES ($1, $2, $3, FALSE)", user, id, link) + if err != nil { + if err, ok := err.(*pq.Error); ok { + if err.Code == "23505" { + return s.ErrDupKey + } + } + } + + if err != nil { + return fmt.Errorf("db error: %w", err) + } + return nil } - return url, true + if s.DBFile != "" { + jsonStr, err := json.Marshal(s.Pairs) + if err != nil { + return fmt.Errorf("json marshal: %w", err) + } + + err = os.WriteFile(s.DBFile, []byte(jsonStr), 0777) + if err != nil { + return fmt.Errorf("write file: %w", err) + } + } + + return nil } -func (s *Storage) SetURL(id, link string) error { +func (s *Storage) UpdateURL(user, id string, isHidden bool) error { s.Mutex.Lock() defer s.Mutex.Unlock() - s.Pairs[id] = link + link, ok := s.Pairs[user][id] + if !ok { + return fmt.Errorf("error deleting URL") + } + delete(s.Pairs[user], id) - if s.DBFile == "" { + if len(s.HiddenPairs[user]) == 0 { + s.HiddenPairs[user] = make(map[string]string) + } + s.HiddenPairs[user][id] = link + + if s.DBCredentials != "" { + db, err := sqlx.Connect("postgres", s.DBCredentials) + if err != nil { + return fmt.Errorf("sql connect: %w", err) + } + defer db.Close() + + _, err = db.Exec("UPDATE shortener SET hidden = TRUE WHERE usr = '$1' AND short = '$2'", user, id) + if err != nil { + return fmt.Errorf("db error: %w", err) + } return nil } - jsonStr, err := json.Marshal(s.Pairs) - if err != nil { - return err + return nil +} + +func (s *Storage) ListURL(user string) (map[string]string, bool) { + s.Mutex.Lock() + list, ok := s.Pairs[user] + s.Mutex.Unlock() + if !ok || len(list) == 0 { + return list, false } - err = os.WriteFile(s.DBFile, []byte(jsonStr), 0777) + return list, true +} + +func (s *Storage) Ping() bool { + db, err := sql.Open("postgres", s.DBCredentials) if err != nil { - return err + return false } + defer db.Close() - return nil + if err = db.Ping(); err != nil { + return false + } + + return true } diff --git a/internal/storage/storage_test.go b/internal/storage/storage_test.go index d26a5e2..a4f5e83 100644 --- a/internal/storage/storage_test.go +++ b/internal/storage/storage_test.go @@ -6,18 +6,19 @@ import ( var ( tests = []struct { + user string key string value string }{ - {key: "key", value: "value"}, - {key: "a", value: "b"}, + {user: "test", key: "key", value: "value"}, + {user: "test", key: "a", value: "b"}, } S *Storage ) func TestInit(t *testing.T) { var err error - S, err = New("") + S, err = New("", "") if err != nil { t.Fatal("can't start test") } @@ -25,7 +26,7 @@ func TestInit(t *testing.T) { func TestSetURL(t *testing.T) { for _, testCase := range tests { - err := S.SetURL(testCase.key, testCase.value) + err := S.SetURL(testCase.user, testCase.key, testCase.value) if err != nil { t.Fatalf("can't save value %v for key %v", testCase.value, testCase.key) } @@ -33,19 +34,19 @@ func TestSetURL(t *testing.T) { } func TestGetURL(t *testing.T) { - _, ok := S.GetURL("non-existant-key") - if ok { + _, status := S.GetURL("test", "non-existant-key") + if status != 0 { t.Fatal("got value for non existant key") } - _, ok = S.GetURL("") - if ok { + _, status = S.GetURL("test", "") + if status != 0 { t.Fatal("got value for empty key") } for _, testCase := range tests { - val, ok := S.GetURL(testCase.key) - if !ok { + val, status := S.GetURL(testCase.user, testCase.key) + if status == 0 { t.Fatalf("can't get value for key %v", testCase.key) } if val != testCase.value {