In [92]:
from google.cloud import storage
from tqdm import tqdm
import pandas as pd
from geopy import distance 
import plotly.express as px
import os
import pyproj
from pyproj import CRS
from shapely.geometry import shape
from shapely.geometry.polygon import Polygon
import shapely.ops as ops
from shapely.ops import transform
from functools import partial
import numpy as np
pd.set_option('display.max_rows', None)

# Test Set Generation

**Author:** Madhava Paliyam (madhavapaliyam@gmail.com)

**Description:** Selects images for test set randomly



**Inputs**: Number of points in the test set, minimum seperation distance for points.

**Outputs**: Test set uploaded onto dvc 

#### Method for Selecting Images for Test Set: 
- Filter by points in each admin2 zone 
- Draw bounding box around the points 
- Calculate area of bounding box 
- Sample points in each admin2 zone according to the proportion area that each bounding box covers in comparision to the entire dataset: 


$$npts_{zone} = npts_{total} \times \frac{area_{zone}}{area_{total}}$$

$area_{total}$ is the sum of the areas of the bounding boxes: $$area_{total} = \sum^{admin2\space zones}{area_{zone}}$$

$area_{zone}$ is the area of the bounding box in the admin2 zone calculated using the minimum and maximum latitudes and longitudes. 

$npts_{total}$ is the total number of points needed in the test set sample. 

Lastly, any points are too close together are resampled. 

##### Get database info csv from google cloud 

In [88]:
!dvc pull -q -f 


##### SET SIZE FOR DATASET #####
DATASET_SIZE = 1000 

##### SET MINIMUM SEPERATION #####
MIN_SEPERATION = 200 # m 


# dictionary to choose admin zone based on country 
admin_zone_to_use = {'KE': 'admin1', 'UG': 'admin2', 'US': 'admin2'}

to_drop = {'US': ['City of Baltimore'], 'KE': [np.nan], 'UG': [np.nan]}

[0m

In [79]:
# Initialize connections to cloud storage
client = storage.Client()

# Read csv with database info 
db = pd.read_csv('../data/database-info.csv', index_col = 0)

# filter to get rid of images currently being labeled 
db = db[db['being_labeled'] == False]
print(len(db))

db.isna().sum() 



Columns (7) have mixed types.Specify dtype option on import or set low_memory=False.



1498498


input_img             0
latitude          20831
longitude         20831
being_labeled         0
country               6
admin1            24253
admin2           149995
cc                20831
location          20831
test_set              0
time                  0
focal_length          0
pixel_height          0
dtype: int64

In [80]:
db = db[db['latitude'].notna()]
db = db[db['longitude'].notna()]
len(db)

1477667

#### Calculates the area covered by the sampling in each admin2 zone

In [93]:
def area_for_admin2_km2(filtered_db):
    min_lon = filtered_db['longitude'].min()
    max_lon = filtered_db['longitude'].max()
    min_lat = filtered_db['latitude'].min()
    max_lat = filtered_db['latitude'].max()

    polygon = Polygon(
            [
                [min_lon, min_lat],
                [min_lon, max_lat],
                [max_lon, max_lat],
                [max_lon, min_lat],
            ]
        )


    polygon = ops.transform(
            partial(
                pyproj.transform,
                pyproj.Proj("EPSG:4326"),
                pyproj.Proj(
                    proj="aea", lat_1=polygon.bounds[1], lat_2=polygon.bounds[3]
                ),
            ),
            polygon,
        )
    
    # if polygon.area == np.nan: 
    #     geom = {'type': 'Polygon',
    #             'coordinates': [[[min_lon, min_lat],
    #                     [min_lon, max_lat],
    #                     [max_lon, max_lat],
    #                     [max_lon, min_lat],
    #                 ]]}

    #     s = shape(geom)
    #     proj = partial(pyproj.transform, pyproj.Proj(CRS("EPSG:4326")), pyproj.Proj(CRS("EPSG:3857")))

    #     s_new = transform(proj, s)
    #     return s_new.area * 1e-6 # Convert from m2 to km2

    # else: 
    return polygon.area * 1e-6


'''
Returns the number of images to sample from each admin zone based on the formula given above 
'''
def get_areas_km2(db):

    country_admin2_area_images = []
    for cc in tqdm(db['cc'].unique()): 
        print(f"Country Code: {cc}")
        admin_level = admin_zone_to_use[cc]
        print(f"Using admin zone level: {admin_level}")

        admin_zones = list(db[db['cc'] == cc][admin_level].unique())

        if cc in to_drop: 
            for zone_to_drop in to_drop[cc]: 
                admin_zones.remove(zone_to_drop) 
    
        print(f"For {admin_level} using zones: {admin_zones}")

        for admin_zone in admin_zones: 
            filtered_db = db[db[admin_level] == admin_zone]
            total_images_for_zone = len(filtered_db)
            area = area_for_admin2_km2(filtered_db) 
            t = (cc, admin_zone, area, total_images_for_zone)
            country_admin2_area_images.append(t)
        
    area_info = pd.DataFrame(country_admin2_area_images, columns=['cc', 'admin_zone', 'area', 'total_images'])

    area_info['num_images_to_sample'] = (area_info['area'] / area_info['area'].sum()) * DATASET_SIZE

    area_info['num_images_to_sample'] = area_info['num_images_to_sample'].round()

    area_info['num_images_to_sample'] = [int(n) if n > 5 else 0 for n in area_info['num_images_to_sample']]

    return area_info


area_info = get_areas_km2(db)  
print(area_info['num_images_to_sample'].sum())
print(area_info['total_images'].sum())
area_info


  0%|          | 0/3 [00:00<?, ?it/s]

Country Code: KE
Using admin zone level: admin1
For admin1 using zones: ['Laikipia', 'Nakuru', 'Nyandarua', 'Kericho', 'Bomet', 'Nyamira District', 'Kisii', 'Migori', 'Homa Bay', 'Narok', 'Uasin Gishu', 'Nandi', 'Bungoma', 'Trans Nzoia', 'West Pokot', 'Marakwet District', 'Kisumu', 'Kakamega', 'Siaya', 'Busia']


 33%|███▎      | 1/3 [00:02<00:04,  2.01s/it]

Country Code: UG
Using admin zone level: admin2
For admin2 using zones: ['Bukwa District', 'Jinja District', 'Mbale District', 'Mukono District', 'Buikwe District', 'Bududa District', 'Manafwa District', 'Kampala District', 'Wakiso District', 'Luwero District', 'Kayunga District', 'Nakasongola District', 'Nakaseke District', 'Mityana District', 'Mpigi District', 'Butambala District', 'Kasese District', 'Kabarole District', 'Kamwenge District', 'Rubirizi District', 'Kyenjojo District', 'Kibale District', 'Hoima District', 'Masindi District', 'Bulisa District', 'Nebbi District', 'Nwoya District']


 67%|██████▋   | 2/3 [00:05<00:02,  2.61s/it]

Country Code: US
Using admin zone level: admin2
For admin2 using zones: ['Defiance County', 'Paulding County', 'Allen County', 'Adams County', 'Huntington County', 'Blackford County', 'Delaware County', 'Madison County', 'Grant County', 'Miami County', 'Wabash County', 'Fulton County', 'Cass County', 'White County', 'Jasper County', 'Newton County', 'Iroquois County', 'Ford County', 'McLean County', 'Tazewell County', 'Mason County', 'Knox County', 'Henry County', 'Butler County', 'Floyd County', 'Mitchell County', 'Mower County', 'Dodge County']


100%|██████████| 3/3 [00:07<00:00,  2.59s/it]

966
1453195





Unnamed: 0,cc,admin_zone,area,total_images,num_images_to_sample
0,KE,Laikipia,43.415879,537,0
1,KE,Nakuru,5132.5611,8916,105
2,KE,Nyandarua,0.493049,351,0
3,KE,Kericho,3925.231032,7133,81
4,KE,Bomet,44.040203,1640,0
5,KE,Nyamira District,476.816181,1770,10
6,KE,Kisii,816.812433,8661,17
7,KE,Migori,245.384511,2263,0
8,KE,Homa Bay,1592.450481,5059,33
9,KE,Narok,1750.320469,5668,36


In [77]:
px.set_mapbox_access_token('pk.eyJ1IjoibWFuZ29tYWRoYXZhIiwiYSI6ImNrdWcyNHh2OTIwMmQzMW56eWFibjUwY3QifQ.aQydTOk0ne3KrV87Ib_TrQ')
fig = px.scatter_mapbox(db[db['admin1'] == 'Illinois'], lat='latitude', lon='longitude', size_max=15, zoom=10)
fig.show()

#### Remove any images that are too close together

In [9]:

# ensures that coordinate 1 (c1) is seperated from all the points in points by at least THRESHOLD
def is_far(c1, points):
    for point in points:
        c2 = (point['latitude'], point['longitude'])
        dist = distance.distance(c1, c2).m
        if dist < MIN_SEPERATION: 
            return False 
        
    return True 


# samples points based on country and ensures that they are seperated by at least THRESHOLD 
def sample_geographically_distributed_points(country, admin2, num_points):
    # add 5 images for each entry 
    num_points = num_points + 5

    if country == 'KE': 
        total_points = len(db[db['admin1'] == admin2])
        points = [db[db['admin1'] == admin2].sample(1).iloc[0].to_dict()]

    else: 
        total_points = len(db[db['admin2'] == admin2])
        points = [db[db['admin2'] == admin2].sample(1).iloc[0].to_dict()]
    points_left = num_points - 1

    print(total_points)
    num_times_resampled = 0
    with tqdm(total=num_points) as progress_bar:
        progress_bar.update(1)
        while points_left != 0: 
            if country == 'KE':
                random_sample = db[db['admin1'] == admin2].sample(1).iloc[0].to_dict()
            else: 
                random_sample = db[db['admin2'] == admin2].sample(1).iloc[0].to_dict()
            c1 = (random_sample['latitude'], random_sample['longitude'])
            if is_far(c1, points):
                points.append(random_sample)
                points_left = points_left - 1
                progress_bar.update(1)
                num_times_resampled = 0
            else: 
                num_times_resampled += 1
                if num_times_resampled % 50 == 0: 
                    print(f"Resampled {num_times_resampled} times.")
                if num_times_resampled >= total_points / 4: 
                    print(f"COULD NOT FIND ENOUGH POINTS FOR {admin2}")
                    break 

    return pd.DataFrame(points)

In [None]:
df_list = []
for r,entry in area_info.iterrows(): 
    if entry['num_images_to_sample'] > 0: 
        print(entry)
        admin2_df = sample_geographically_distributed_points(entry['cc'], entry['admin2'], entry['num_images_to_sample'])
        df_list.append(admin2_df)

# df_list


In [None]:
test_set_concat = pd.concat(df_list)
len(test_set_concat)


In [None]:
px.set_mapbox_access_token('pk.eyJ1IjoibWFuZ29tYWRoYXZhIiwiYSI6ImNrdWcyNHh2OTIwMmQzMW56eWFibjUwY3QifQ.aQydTOk0ne3KrV87Ib_TrQ')
fig = px.scatter_mapbox(test_set_concat, lat='latitude', lon='longitude', size_max=15, zoom=10)
fig.show()

#### Run the cell below if you want to download the images locally

In [None]:
gcloud_uploaded_bucket = client.bucket('street2sat-uploaded')

if not os.path.exists('test_set'):
    os.makedirs('test_set')

for i,point in test_set_concat.iterrows():
    blob = gcloud_uploaded_bucket.blob(point['input_img'].replace('gs://street2sat-uploaded/', ''))
    try: 
        blob.download_to_filename(f'test_set/{i}.jpg')
    except: 
        t = point['input_img'].replace('gs://street2sat-uploaded/', '')
        print(f'image not found: {t}')

#### Run the cell below if you want to update the test set csv with the generated test set 

In [None]:
test_set_concat.to_csv('../data/test.csv')

In [None]:
(test_set_concat['being_labeled'] == False).all()

In [None]:
# TODO: upload to dvc 