Skip to content

Commit

Permalink
Use the MD5 checksum on the unzipped file as an update key
Browse files Browse the repository at this point in the history
  • Loading branch information
Ashley Martens committed Oct 28, 2016
1 parent 6a6cfcf commit bc2cd57
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 12 deletions.
34 changes: 22 additions & 12 deletions db.go
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ var (
// DB is the IP geolocation database. // DB is the IP geolocation database.
type DB struct { type DB struct {
file string // Database file name. file string // Database file name.
checksum string // MD5 of the unzipped database file
reader *maxminddb.Reader // Actual db object. reader *maxminddb.Reader // Actual db object.
notifyQuit chan struct{} // Stop auto-update and watch goroutines. notifyQuit chan struct{} // Stop auto-update and watch goroutines.
notifyOpen chan string // Notify when a db file is open. notifyOpen chan string // Notify when a db file is open.
Expand Down Expand Up @@ -165,37 +166,39 @@ func (db *DB) watchEvents(watcher *fsnotify.Watcher) {
} }


func (db *DB) openFile() error { func (db *DB) openFile() error {
reader, err := db.newReader(db.file) reader, checksum, err := db.newReader(db.file)
if err != nil { if err != nil {
return err return err
} }
stat, err := os.Stat(db.file) stat, err := os.Stat(db.file)
if err != nil { if err != nil {
return err return err
} }
db.setReader(reader, stat.ModTime()) db.setReader(reader, stat.ModTime(), checksum)
return nil return nil
} }


func (db *DB) newReader(dbfile string) (*maxminddb.Reader, error) { func (db *DB) newReader(dbfile string) (*maxminddb.Reader, string, error) {
f, err := os.Open(dbfile) f, err := os.Open(dbfile)
if err != nil { if err != nil {
return nil, err return nil, "", err
} }
defer f.Close() defer f.Close()
gzf, err := gzip.NewReader(f) gzf, err := gzip.NewReader(f)
if err != nil { if err != nil {
return nil, err return nil, "", err
} }
defer gzf.Close() defer gzf.Close()
b, err := ioutil.ReadAll(gzf) b, err := ioutil.ReadAll(gzf)
if err != nil { if err != nil {
return nil, err return nil, "", err
} }
return maxminddb.FromBytes(b) checksum := fmt.Sprintf("%x", md5.Sum(b))
mmdb, err := maxminddb.FromBytes(b)
return mmdb, checksum, err
} }


func (db *DB) setReader(reader *maxminddb.Reader, modtime time.Time) { func (db *DB) setReader(reader *maxminddb.Reader, modtime time.Time, checksum string) {
db.mu.Lock() db.mu.Lock()
defer db.mu.Unlock() defer db.mu.Unlock()
if db.closed { if db.closed {
Expand All @@ -207,6 +210,7 @@ func (db *DB) setReader(reader *maxminddb.Reader, modtime time.Time) {
} }
db.reader = reader db.reader = reader
db.lastUpdated = modtime.UTC() db.lastUpdated = modtime.UTC()
db.checksum = checksum
select { select {
case db.notifyOpen <- db.file: case db.notifyOpen <- db.file:
default: default:
Expand All @@ -216,6 +220,7 @@ func (db *DB) setReader(reader *maxminddb.Reader, modtime time.Time) {
func (db *DB) autoUpdate(url string) { func (db *DB) autoUpdate(url string) {
backoff := time.Second backoff := time.Second
for { for {
db.sendInfo("starting update")
err := db.runUpdate(url) err := db.runUpdate(url)
if err != nil { if err != nil {
bs := backoff.Seconds() bs := backoff.Seconds()
Expand All @@ -225,6 +230,7 @@ func (db *DB) autoUpdate(url string) {
} else { } else {
backoff = db.updateInterval backoff = db.updateInterval
} }
db.sendInfo("finished update")
select { select {
case <-db.notifyQuit: case <-db.notifyQuit:
return return
Expand All @@ -235,7 +241,6 @@ func (db *DB) autoUpdate(url string) {
} }


func (db *DB) runUpdate(url string) error { func (db *DB) runUpdate(url string) error {
db.sendInfo("starting update")
yes, err := db.needUpdate(url) yes, err := db.needUpdate(url)
if err != nil { if err != nil {
return err return err
Expand All @@ -252,7 +257,6 @@ func (db *DB) runUpdate(url string) error {
// Cleanup the tempfile if renaming failed. // Cleanup the tempfile if renaming failed.
os.RemoveAll(tmpfile) os.RemoveAll(tmpfile)
} }
db.sendInfo("finished update")
return err return err
} }


Expand All @@ -261,19 +265,26 @@ func (db *DB) needUpdate(url string) (bool, error) {
if err != nil { if err != nil {
return true, nil // Local db is missing, must be downloaded. return true, nil // Local db is missing, must be downloaded.
} }

resp, err := http.Head(url) resp, err := http.Head(url)
if err != nil { if err != nil {
return false, err return false, err
} }
defer resp.Body.Close() defer resp.Body.Close()

// Check X-Database-MD5 if it exists
headerMd5 := resp.Header.Get("X-Database-MD5")
if len(headerMd5) > 0 && db.checksum != headerMd5 {
return true, nil
}

if stat.Size() != resp.ContentLength { if stat.Size() != resp.ContentLength {
return true, nil return true, nil
} }
return false, nil return false, nil
} }


func (db *DB) download(url string) (tmpfile string, err error) { func (db *DB) download(url string) (tmpfile string, err error) {
db.sendInfo("starting download")
resp, err := http.Get(url) resp, err := http.Get(url)
if err != nil { if err != nil {
return "", err return "", err
Expand All @@ -290,7 +301,6 @@ func (db *DB) download(url string) (tmpfile string, err error) {
if err != nil { if err != nil {
return "", err return "", err
} }
db.sendInfo("finished download")
return tmpfile, nil return tmpfile, nil
} }


Expand Down
47 changes: 47 additions & 0 deletions db_test.go
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -84,6 +84,53 @@ func TestNeedUpdateSameFile(t *testing.T) {
} }
} }


func TestNeedUpdateSameMD5(t *testing.T) {
db := &DB{file: testFile}
_, checksum, err := db.newReader(db.file)
if err != nil {
t.Fatal(err)
}
db.checksum = checksum
mux := http.NewServeMux()
changeHeaderThenServe := func(h http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("X-Database-MD5", checksum)
h.ServeHTTP(w, r)
}
}
mux.Handle("/testdata/", changeHeaderThenServe(http.FileServer(http.Dir("."))))
srv := httptest.NewServer(mux)
defer srv.Close()
yes, err := db.needUpdate(srv.URL + "/" + testFile)
if err != nil {
t.Fatal(err)
}
if yes {
t.Fatal("Unexpected: db is not supposed to need an update")
}
}

func TestNeedUpdateMD5(t *testing.T) {
mux := http.NewServeMux()
changeHeaderThenServe := func(h http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("X-Database-MD5", "9823y5981y2398y1234")
h.ServeHTTP(w, r)
}
}
mux.Handle("/testdata/", changeHeaderThenServe(http.FileServer(http.Dir("."))))
srv := httptest.NewServer(mux)
defer srv.Close()
db := &DB{file: testFile}
yes, err := db.needUpdate(srv.URL + "/" + testFile)
if err != nil {
t.Fatal(err)
}
if !yes {
t.Fatal("Unexpected: db is supposed to need an update")
}
}

func TestNeedUpdate(t *testing.T) { func TestNeedUpdate(t *testing.T) {
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle("/testdata/", http.FileServer(http.Dir("."))) mux.Handle("/testdata/", http.FileServer(http.Dir(".")))
Expand Down

0 comments on commit bc2cd57

Please sign in to comment.