Permalink
Browse files

Use the MD5 checksum on the unzipped file as an update key

  • Loading branch information...
1 parent 6a6cfcf commit bc2cd576bc81e4fd2c9062960e3ec0876caf2c1f Ashley Martens committed Oct 27, 2016
Showing with 69 additions and 12 deletions.
  1. +22 −12 db.go
  2. +47 −0 db_test.go
View
@@ -39,6 +39,7 @@ var (
// DB is the IP geolocation database.
type DB struct {
file string // Database file name.
+ checksum string // MD5 of the unzipped database file
reader *maxminddb.Reader // Actual db object.
notifyQuit chan struct{} // Stop auto-update and watch goroutines.
notifyOpen chan string // Notify when a db file is open.
@@ -165,37 +166,39 @@ func (db *DB) watchEvents(watcher *fsnotify.Watcher) {
}
func (db *DB) openFile() error {
- reader, err := db.newReader(db.file)
+ reader, checksum, err := db.newReader(db.file)
if err != nil {
return err
}
stat, err := os.Stat(db.file)
if err != nil {
return err
}
- db.setReader(reader, stat.ModTime())
+ db.setReader(reader, stat.ModTime(), checksum)
return nil
}
-func (db *DB) newReader(dbfile string) (*maxminddb.Reader, error) {
+func (db *DB) newReader(dbfile string) (*maxminddb.Reader, string, error) {
f, err := os.Open(dbfile)
if err != nil {
- return nil, err
+ return nil, "", err
}
defer f.Close()
gzf, err := gzip.NewReader(f)
if err != nil {
- return nil, err
+ return nil, "", err
}
defer gzf.Close()
b, err := ioutil.ReadAll(gzf)
if err != nil {
- return nil, err
+ return nil, "", err
}
- return maxminddb.FromBytes(b)
+ checksum := fmt.Sprintf("%x", md5.Sum(b))
+ mmdb, err := maxminddb.FromBytes(b)
+ return mmdb, checksum, err
}
-func (db *DB) setReader(reader *maxminddb.Reader, modtime time.Time) {
+func (db *DB) setReader(reader *maxminddb.Reader, modtime time.Time, checksum string) {
db.mu.Lock()
defer db.mu.Unlock()
if db.closed {
@@ -207,6 +210,7 @@ func (db *DB) setReader(reader *maxminddb.Reader, modtime time.Time) {
}
db.reader = reader
db.lastUpdated = modtime.UTC()
+ db.checksum = checksum
select {
case db.notifyOpen <- db.file:
default:
@@ -216,6 +220,7 @@ func (db *DB) setReader(reader *maxminddb.Reader, modtime time.Time) {
func (db *DB) autoUpdate(url string) {
backoff := time.Second
for {
+ db.sendInfo("starting update")
err := db.runUpdate(url)
if err != nil {
bs := backoff.Seconds()
@@ -225,6 +230,7 @@ func (db *DB) autoUpdate(url string) {
} else {
backoff = db.updateInterval
}
+ db.sendInfo("finished update")
select {
case <-db.notifyQuit:
return
@@ -235,7 +241,6 @@ func (db *DB) autoUpdate(url string) {
}
func (db *DB) runUpdate(url string) error {
- db.sendInfo("starting update")
yes, err := db.needUpdate(url)
if err != nil {
return err
@@ -252,7 +257,6 @@ func (db *DB) runUpdate(url string) error {
// Cleanup the tempfile if renaming failed.
os.RemoveAll(tmpfile)
}
- db.sendInfo("finished update")
return err
}
@@ -261,19 +265,26 @@ func (db *DB) needUpdate(url string) (bool, error) {
if err != nil {
return true, nil // Local db is missing, must be downloaded.
}
+
resp, err := http.Head(url)
if err != nil {
return false, err
}
defer resp.Body.Close()
+
+ // Check X-Database-MD5 if it exists
+ headerMd5 := resp.Header.Get("X-Database-MD5")
+ if len(headerMd5) > 0 && db.checksum != headerMd5 {
+ return true, nil
+ }
+
if stat.Size() != resp.ContentLength {
return true, nil
}
return false, nil
}
func (db *DB) download(url string) (tmpfile string, err error) {
- db.sendInfo("starting download")
resp, err := http.Get(url)
if err != nil {
return "", err
@@ -290,7 +301,6 @@ func (db *DB) download(url string) (tmpfile string, err error) {
if err != nil {
return "", err
}
- db.sendInfo("finished download")
return tmpfile, nil
}
View
@@ -84,6 +84,53 @@ func TestNeedUpdateSameFile(t *testing.T) {
}
}
+func TestNeedUpdateSameMD5(t *testing.T) {
+ db := &DB{file: testFile}
+ _, checksum, err := db.newReader(db.file)
+ if err != nil {
+ t.Fatal(err)
+ }
+ db.checksum = checksum
+ mux := http.NewServeMux()
+ changeHeaderThenServe := func(h http.Handler) http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Add("X-Database-MD5", checksum)
+ h.ServeHTTP(w, r)
+ }
+ }
+ mux.Handle("/testdata/", changeHeaderThenServe(http.FileServer(http.Dir("."))))
+ srv := httptest.NewServer(mux)
+ defer srv.Close()
+ yes, err := db.needUpdate(srv.URL + "/" + testFile)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if yes {
+ t.Fatal("Unexpected: db is not supposed to need an update")
+ }
+}
+
+func TestNeedUpdateMD5(t *testing.T) {
+ mux := http.NewServeMux()
+ changeHeaderThenServe := func(h http.Handler) http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Add("X-Database-MD5", "9823y5981y2398y1234")
+ h.ServeHTTP(w, r)
+ }
+ }
+ mux.Handle("/testdata/", changeHeaderThenServe(http.FileServer(http.Dir("."))))
+ srv := httptest.NewServer(mux)
+ defer srv.Close()
+ db := &DB{file: testFile}
+ yes, err := db.needUpdate(srv.URL + "/" + testFile)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !yes {
+ t.Fatal("Unexpected: db is supposed to need an update")
+ }
+}
+
func TestNeedUpdate(t *testing.T) {
mux := http.NewServeMux()
mux.Handle("/testdata/", http.FileServer(http.Dir(".")))

0 comments on commit bc2cd57

Please sign in to comment.