In [None]:
import csv
from urllib.parse import quote

import requests

## Functions for getting data from GeoDeepDive
The `get_units` function takes a GeoDeepDive API URL (intended to be a query that searches for a term and returns the known stratigraphic names). It makes the call and pulls out the stratigraphic names. For each paper that references exactly one stratigraphic name, it then calls the Macrostrat API to get the location of that unit. If it can find the unit in Macrostrat, it adds the unit name and location to a list; otherwise the unit is added to a list of failures. If the GeoDeepDive API call includes the `full_results` option, the function follows the `next_page` link in each response until all pages have been visited.

In [None]:
def get_units(url, geo_units=None, failed_units=None):
    """Collect Macrostrat locations for units named in GeoDeepDive results.

    Fetches ``url``, keeps only the papers that have exactly one known term
    with exactly one stratigraphic name, and looks each of those names up in
    the Macrostrat ``units`` API. The first GeoJSON feature returned for a
    name is recorded as a dict with the keys ``unit_name``, ``unit_id``,
    ``strat_name_id``, ``lat`` and ``long``. Pages are followed through the
    response's ``success.next_page`` link until it is missing or empty.

    Parameters
    ----------
    url : str
        GeoDeepDive API URL of the first page to fetch.
    geo_units : list, optional
        List to append located units to. A new list is created when omitted.
    failed_units : list, optional
        List to append ``(unit_name, exception)`` tuples to for units that
        could not be located. A new list is created when omitted.

    Returns
    -------
    tuple
        ``(geo_units, failed_units)`` -- the same list objects that were
        passed in (or the newly created ones).

    Raises
    ------
    KeyError
        If a GeoDeepDive page lacks the expected ``success``/``data`` layout.
    """
    # A literal [] default is created once and shared by every call, so the
    # results of one search would silently leak into the next.
    if geo_units is None:
        geo_units = []
    if failed_units is None:
        failed_units = []

    macrostrat_units_url = "https://macrostrat.org/api/units"
    visited_pages = set()

    # Walk the pages in a loop rather than recursing once per page, so a long
    # result set cannot exhaust the recursion limit.
    while url and url not in visited_pages:
        visited_pages.add(url)
        print("so far, got %d units and failed %d times" % (len(geo_units), len(failed_units)))
        data = requests.get(url).json()
        papers = data['success']['data']
        one_unit_papers = [
            paper for paper in papers
            if len(paper['known_terms']) == 1
            and len(paper['known_terms'][0]['stratigraphic_names']) == 1
        ]
        units = [paper['known_terms'][0]['stratigraphic_names'][0] for paper in one_unit_papers]

        for unit in units:
            # Let requests build the query string so names containing
            # characters such as '&' or '#' are escaped correctly.
            unit_response = requests.get(
                macrostrat_units_url,
                params={'strat_name': unit, 'format': 'geojson'},
            )
            try:
                unit_data = unit_response.json()
                feature = unit_data['success']['data']['features'][0]
                geo_unit = {'unit_name'    : feature['properties']['unit_name'],
                            'unit_id'      : feature['properties']['unit_id'],
                            'strat_name_id': feature['properties']['strat_name_id'],
                            # GeoJSON coordinates are ordered [x, y], i.e. [long, lat].
                            'lat'          : feature['geometry']['coordinates'][1],
                            'long'         : feature['geometry']['coordinates'][0]}
                geo_units.append(geo_unit)
            except (IndexError, KeyError, TypeError, ValueError) as err:
                # ValueError covers a reply that is not valid JSON; one bad
                # lookup is recorded rather than aborting the whole crawl.
                failed_units.append((unit, err))

        # A missing, None or empty next_page all mean there is nothing more to fetch.
        url = data['success'].get('next_page')

    return (geo_units, failed_units)

In [2]:
#wrapper function to construct the API call
def get_keyword_units(keyword, geo_units=None, failed_units=None):
    """Search GeoDeepDive snippets for ``keyword`` and collect the matching units.

    Builds the GeoDeepDive ``snippets`` URL (filtered to the
    ``stratigraphic_names`` dictionary, with ``full_results`` enabled) and
    hands it to ``get_units``.

    Parameters
    ----------
    keyword : str
        Plain-text search term, e.g. ``"highly porous"``. It is URL-escaped
        here, so pass it unencoded.
    geo_units : list, optional
        List to append located units to. A new list is created when omitted.
    failed_units : list, optional
        List to append failed lookups to. A new list is created when omitted.

    Returns
    -------
    tuple
        Whatever ``get_units`` returns for the constructed URL.
    """
    # A literal [] default is shared between calls, which would mix the
    # results of different keywords; create fresh lists per call instead.
    if geo_units is None:
        geo_units = []
    if failed_units is None:
        failed_units = []
    # Escape the keyword so spaces and characters such as '&' cannot break
    # the query string.
    url = "https://geodeepdive.org/api/snippets?term=%s&dict_filter=stratigraphic_names&dict=stratigraphic_names&full_results" % quote(keyword, safe='')
    return get_units(url, geo_units, failed_units)

In [3]:
#writes the data to CSVs for future plotting
def write_units(unit_pair, file_name_term):
    """Write located and unfindable units to two CSV files.

    Parameters
    ----------
    unit_pair : tuple
        ``(units, failures)``: ``units`` is a list of dicts that share the
        same keys, and ``failures`` is a list of two-item rows
        ``(unit, failure reason)``.
    file_name_term : str
        Prefix for the output files ``<term>_units.csv`` and
        ``<term>_units_unfindable.csv``.

    Returns
    -------
    None
    """
    units, failures = unit_pair[0], unit_pair[1]
    data_file_name = "%s_units.csv" % file_name_term
    failed_file_name = "%s_units_unfindable.csv" % file_name_term

    # With no units there is no first row to take the header from; fall back
    # to the column names this notebook uses so a header-only file is written.
    default_fields = ['unit_name', 'unit_id', 'strat_name_id', 'lat', 'long']
    fieldnames = list(units[0].keys()) if units else default_fields

    # newline='' is what the csv module expects (it avoids blank rows on
    # Windows); the with-blocks guarantee both files are flushed and closed.
    with open(data_file_name, 'w', newline='', encoding='utf-8') as units_file:
        writer = csv.DictWriter(units_file, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(units)

    with open(failed_file_name, 'w', newline='', encoding='utf-8') as failed_file:
        writer = csv.writer(failed_file)
        writer.writerow(["unit", "failure reason"])
        writer.writerows(failures)

## Functions for plotting data from GeoDeepDive
This assumes you have two tables of stratigraphic units, each representing one endmember of a range of values. The two tables are merged: units from the table representing the "high" endmember get a value of 1, units from the table representing the "low" endmember get a value of -1, and units that appear in both tables get a value of 0.

In [None]:
def merge_tables(low_table, high_table, value_key, unit_key='unit_name'):
    """Stack two unit tables and score each unit by which table(s) it is in.

    The tables are concatenated (low first, then high) and rows that are
    identical in every column are dropped. A new ``value_key`` column starts
    at 0, gains 1 for units whose ``unit_key`` appears in ``high_table`` and
    loses 1 for units whose ``unit_key`` appears in ``low_table``, so a unit
    found in both ends up at 0.

    Parameters
    ----------
    low_table, high_table : pandas.DataFrame
        Tables for the "low" and "high" endmembers; both need ``unit_key``.
    value_key : str
        Name of the score column to create.
    unit_key : str, optional
        Column that identifies a unit (default ``'unit_name'``).

    Returns
    -------
    pandas.DataFrame
        The merged table with the original row labels retained.
    """
    stacked = pd.concat([low_table, high_table])
    merged = stacked.drop_duplicates()
    merged[value_key] = 0

    # Apply the "high" adjustment before computing the "low" membership mask,
    # matching the order in which the columns are read and written.
    in_high = merged[unit_key].isin(high_table[unit_key])
    merged.loc[in_high, value_key] += 1

    in_low = merged[unit_key].isin(low_table[unit_key])
    merged.loc[in_low, value_key] += -1

    return merged

def plot_units(unit_table, basemap, value_key, x_key='long', y_key='lat'):
    """Plot units as points over a basemap, coloured by ``value_key``.

    Builds one point per row from the ``x_key`` / ``y_key`` columns, tags the
    points as EPSG:4326, clips them to ``basemap`` and draws them on top of
    the basemap in a new 12 x 8 figure.

    Parameters
    ----------
    unit_table : table indexable by column name
        Must provide the ``x_key``, ``y_key`` and ``value_key`` columns.
    basemap : object with ``.plot(ax=...)``
        Used both as the clip mask and as the background layer
        (presumably a GeoDataFrame -- confirm against the caller).
    value_key : str
        Column used to colour the points.
    x_key, y_key : str, optional
        Columns holding the x (longitude) and y (latitude) coordinates.

    Returns
    -------
    None
    """
    # Honour the x_key / y_key arguments; they were previously accepted but
    # ignored in favour of hard-coded 'long' / 'lat'.
    unit_geometry = [Point(xy) for xy in zip(unit_table[x_key], unit_table[y_key])]
    unit_geotable = GeoDataFrame(unit_table, geometry=unit_geometry)
    unit_geotable = unit_geotable.set_crs("EPSG:4326").clip(basemap)
    fig, ax = plt.subplots(figsize=(12, 8))
    basemap.plot(ax=ax)
    unit_geotable.plot(ax=ax, marker='o', column=value_key, markersize=15, legend=True)

### Get and write (assumed to be) porous units

In [None]:
# Search GeoDeepDive for "highly porous" and look up the matching units.
# These lists are passed in explicitly; the located units and the failed
# lookups are appended to them in place.
porous_units = []
pfailed_units = []
porous_unit_info = get_keyword_units("highly porous", porous_units, pfailed_units)
# Writes porous_units.csv and porous_units_unfindable.csv.
write_units(porous_unit_info, "porous")

### Get and write (assumed to be) impervious units

In [None]:
# Search GeoDeepDive for "impermeable" and look up the matching units.
# These lists are passed in explicitly; the located units and the failed
# lookups are appended to them in place.
imperm_units = []
ipfailed_units = []
imperm_unit_info = get_keyword_units("impermeable", imperm_units, ipfailed_units)
# Writes impermeable_units.csv and impermeable_units_unfindable.csv.
write_units(imperm_unit_info, "impermeable")

### Plot porous/impervious units