diff --git a/freegeoip.go b/freegeoip.go index 9d7c9a2..a487746 100644 --- a/freegeoip.go +++ b/freegeoip.go @@ -169,19 +169,28 @@ func GeoipHandler() http.HandlerFunc { return } // GET continuing... - var ( - ip, ipkey string - err error - ) - if ip, _, err = net.SplitHostPort(r.RemoteAddr); err != nil { - ipkey = r.RemoteAddr // Support for XHeaders + var srcIP net.IP + if ip, _, err := net.SplitHostPort(r.RemoteAddr); err != nil { + srcIP = net.ParseIP(r.RemoteAddr) // Use X-Real-IP } else { - ipkey = ip + srcIP = net.ParseIP(ip) + } + if srcIP == nil { + http.Error(w, http.StatusText(400), 400) + return + } + nsrcIP, err := ip2int(srcIP) + if err != nil { + if conf.Debug { + log.Println(err) + } + http.Error(w, http.StatusText(400), 400) + return } // Check quota. if conf.Limit.MaxRequests > 0 { var ok bool - if ok, err = quota.Ok(ipkey); err != nil { + if ok, err = quota.Ok(nsrcIP); err != nil { if conf.Debug { log.Println(err) // redis error } @@ -193,6 +202,10 @@ func GeoipHandler() http.HandlerFunc { return } } + var ( + queryIP net.IP + nqueryIP uint32 + ) // Parse URL (e.g. /csv/ip, /xml/) a := strings.SplitN(r.URL.Path, "/", 3) if len(a) == 3 && a[2] != "" { @@ -202,12 +215,24 @@ func GeoipHandler() http.HandlerFunc { http.Error(w, http.StatusText(404), 404) return } - ip = addrs[0] + if queryIP = net.ParseIP(addrs[0]); queryIP == nil { + http.Error(w, http.StatusText(400), 400) + return + } + nqueryIP, err = ip2int(net.ParseIP(addrs[0])) + if err != nil { + if conf.Debug { + log.Println(err) + } + http.Error(w, http.StatusText(400), 400) + return + } } else { - ip = ipkey + queryIP = srcIP + nqueryIP = nsrcIP } // Query the db. - geoip, err := GeoipLookup(stmt, ip) + geoip, err := lookup(stmt, queryIP, nqueryIP) if err != nil { http.NotFound(w, r) return @@ -254,30 +279,23 @@ func GeoipHandler() http.HandlerFunc { } } -const query = `SELECT - city_location.country_code, - country_blocks.country_name, - city_location.region_code, - region_names.region_name, - city_location.city_name, - city_location.postal_code, - city_location.latitude, - city_location.longitude, - city_location.metro_code, - city_location.area_code -FROM city_blocks - NATURAL JOIN city_location - INNER JOIN country_blocks ON - city_location.country_code = country_blocks.country_code - LEFT OUTER JOIN region_names ON - city_location.country_code = region_names.country_code - AND - city_location.region_code = region_names.region_code -WHERE city_blocks.ip_start <= ? -ORDER BY city_blocks.ip_start DESC LIMIT 1` +func ip2int(ip net.IP) (uint32, error) { + var n uint32 + ipv4 := ip.To4() + if ipv4 == nil { + return 0, fmt.Errorf("IP %s is not IPv4", ip.String()) + } + if err := binary.Read( + bytes.NewBuffer(ipv4), + binary.BigEndian, + &n, + ); err != nil { + return 0, fmt.Errorf("IP conversion failed: %s", err.Error()) + } + return n, nil +} -func GeoipLookup(stmt *sql.Stmt, ip string) (*GeoIP, error) { - IP := net.ParseIP(ip) +func lookup(stmt *sql.Stmt, IP net.IP, nIP uint32) (*GeoIP, error) { var reserved bool for _, net := range reservedIPs { if net.Contains(IP) { @@ -285,15 +303,12 @@ func GeoipLookup(stmt *sql.Stmt, ip string) (*GeoIP, error) { break } } - geoip := GeoIP{Ip: ip} + geoip := GeoIP{Ip: IP.String()} if reserved { geoip.CountryCode = "RD" geoip.CountryName = "Reserved" } else { - var uintIP uint32 - binary.Read(bytes.NewBuffer(IP.To4()), - binary.BigEndian, &uintIP) - if err := stmt.QueryRow(uintIP).Scan( + if err := stmt.QueryRow(nIP).Scan( &geoip.CountryCode, &geoip.CountryName, &geoip.RegionCode, @@ -346,23 +361,46 @@ var reservedIPs = []net.IPNet{ {net.IPv4(255, 255, 255, 255), net.IPv4Mask(255, 255, 255, 255)}, } +// SQLite query. +const query = `SELECT + city_location.country_code, + country_blocks.country_name, + city_location.region_code, + region_names.region_name, + city_location.city_name, + city_location.postal_code, + city_location.latitude, + city_location.longitude, + city_location.metro_code, + city_location.area_code +FROM city_blocks + NATURAL JOIN city_location + INNER JOIN country_blocks ON + city_location.country_code = country_blocks.country_code + LEFT OUTER JOIN region_names ON + city_location.country_code = region_names.country_code + AND + city_location.region_code = region_names.region_code +WHERE city_blocks.ip_start <= ? +ORDER BY city_blocks.ip_start DESC LIMIT 1` + // Quota interface for limiting access to the API. type Quota interface { Setup(args ...string) // Initialize quota backend - Ok(ipkey string) (bool, error) // Returns true if under quota + Ok(ipkey uint32) (bool, error) // Returns true if under quota } // MapQuota implements the Quota interface using a map as the backend. type MapQuota struct { mu sync.Mutex - m map[string]int + m map[uint32]int } func (q *MapQuota) Setup(args ...string) { - q.m = make(map[string]int) + q.m = make(map[uint32]int) } -func (q *MapQuota) Ok(ipkey string) (bool, error) { +func (q *MapQuota) Ok(ipkey uint32) (bool, error) { q.mu.Lock() defer q.mu.Unlock() if n, ok := q.m[ipkey]; ok { @@ -392,16 +430,17 @@ func (q *RedisQuota) Setup(args ...string) { q.c.Timeout = time.Duration(800) * time.Millisecond } -func (q *RedisQuota) Ok(ipkey string) (bool, error) { - if ns, err := q.c.Get(ipkey); err != nil { +func (q *RedisQuota) Ok(ipkey uint32) (bool, error) { + k := fmt.Sprintf("%d", ipkey) // "numeric" key + if ns, err := q.c.Get(k); err != nil { return false, fmt.Errorf("redis: %s", err.Error()) } else if ns == "" { - if err := q.c.Set(ipkey, "1"); err != nil { + if err := q.c.Set(k, "1"); err != nil { return false, fmt.Errorf("redis: %s", err.Error()) } - q.c.Expire(ipkey, conf.Limit.Expire) // what if.. + q.c.Expire(k, conf.Limit.Expire) // what if.. } else if n, _ := strconv.Atoi(ns); n < conf.Limit.MaxRequests { - q.c.Incr(ipkey) + q.c.Incr(k) } else { return false, nil }