# Import Modules

In [None]:
import fiona 
import geopandas as gpd
import pandas as pd
import numpy as np
pd.set_option('display.max_columns', None)

import os

import matplotlib.pyplot as plt
import seaborn as sns

import datetime
import tqdm
import time

from collections import defaultdict, Counter

# Define classes for node and graph

In [None]:
class Node:
    '''
    One manhole of the sewer graph, identified by its coordinate.
    '''

    def __init__(self, coor):
        # Both neighbour sets start empty and are filled by whoever builds the graph.
        self.father, self.sons = set(), set()  # upstream, downstream
        self.coor = coor  # (long, lat)

In [None]:
class waste_water_processor:
    '''
    Directed graph of sewer manholes, built from CSV files saved in a base directory.

    Attributes set while loading:
        all_nodes     : dict, coordinate (long, lat) -> Node
        coor_to_name  : dict, coordinate -> names stored for that manhole
        name_to_coor  : dict, manhole name -> coordinate
        cb_counter    : dict, census block -> number of manholes in it
        all_cb        : sorted list of the `block_fips` values of 'us2019_yolo.csv'
        locations     : pd.DataFrame read from 'COD sampling MHs.csv'
    '''

    def __init__(self, base_dir):
        '''
        Params:
        =======
        base_dir: string, the directory that holds the saved input files
        '''
        self.base = base_dir  # where to read saved files
        if not self.base.endswith('/'):
            self.base += '/'
        self.init()  # empty data structures, then load everything

    def init(self):
        '''
        Reset the lookup tables and rebuild the graph from the saved files.
        '''
        self.coor_to_name = defaultdict(list)
        self.all_nodes = {}
        self.name_to_coor = {}
        self.build_graph()

    def build_graph(self):
        '''
        Load the saved files and build the whole graph.

        Reads 'connections.csv', 'us2019_yolo.csv' and 'COD sampling MHs.csv'
        from the base directory.
        '''
        data = pd.read_csv(self.base + 'connections.csv', dtype={'cbg': str, 'cb': str})
        # NOTE(review): eval() executes whatever expression is stored in these
        # CSV cells. Only load connection files from a trusted source.
        for column in ('name', 'down_stream', 'up_stream'):
            data[column] = data[column].apply(eval)
        total = data.shape[0]

        # initialize basic nodes
        print("Initializing Manholes")
        rows = zip(data['long'], data['lat'], data['name'], data['cbg'], data['cb'])
        for long, lat, names, cbg, cb in tqdm.tqdm(rows, total=total):
            coor = (long, lat)
            node = Node(coor)
            node.cb = cb
            node.cbg = cbg
            self.all_nodes[coor] = node
            self.coor_to_name[coor] = names
            for name in names:
                self.name_to_coor[name] = coor

        # connect upstreams and downstreams; every referenced coordinate is
        # snapped to an existing manhole first
        print("Connecting Manholes")
        links = zip(data['long'], data['lat'], data['down_stream'], data['up_stream'])
        for long, lat, downstream, upstream in tqdm.tqdm(links, total=total):
            node = self.all_nodes[(long, lat)]
            node.sons.update(self.locate_coor(next_coor) for next_coor in downstream)
            node.father.update(self.locate_coor(prev_coor) for prev_coor in upstream)

        # check how many manholes are there in each census block
        self.cb_counter = defaultdict(int)
        for node in self.all_nodes.values():
            self.cb_counter[node.cb] += 1

        population = pd.read_csv(self.base + 'us2019_yolo.csv', dtype={'block_fips': str})
        self.all_cb = sorted(population.block_fips.values.tolist())

        # load sampling locations
        self.locations = pd.read_csv(self.base + 'COD sampling MHs.csv')

    def locate_coor(self, coor):
        '''
        Read a coordinate (format long, lat) and return the coordinate of the
        closest manhole. A coordinate that already is a manhole is returned as is.

        The closest manhole is the one with the smallest squared difference in
        (long, lat); the first one wins on ties. There is no distance cut-off.

        Raises:
        =======
        ValueError, if no manhole has been loaded
        '''
        if coor in self.all_nodes:
            return coor

        def squared_distance(candidate):
            return (candidate[0] - coor[0]) ** 2 + (candidate[1] - coor[1]) ** 2

        nearest = min(self.all_nodes, key=squared_distance, default=None)
        if nearest is None:
            raise ValueError("cannot locate {}: no manholes are loaded".format(coor))
        return nearest

    def dfs(self, coor, visited, direction):
        '''
        This is an inside utility function. Do not use.
        Depth-first search from a manhole through its upstreams/downstreams,
        depending on the value of direction. Every reached coordinate is marked
        with 1 in `visited` (a mapping that yields a falsy default for missing
        keys, e.g. defaultdict(int)).

        Implemented with an explicit stack so that long pipe chains cannot hit
        Python's recursion limit; the visiting order matches a recursive search.
        '''
        stack = [coor]
        while stack:
            current = stack.pop()
            if visited[current]:
                continue
            visited[current] = 1
            node = self.all_nodes[current]
            next_list = node.father if direction == 'upstream' else node.sons
            # reversed, so that the first neighbour is popped (and explored) first
            stack.extend(reversed(list(next_list)))

    def find_connection(self, coor, direction):
        '''
        Given a coordinate (long, lat), find its nearest manhole and search for
        all of its upstreams/downstreams.

        Params:
        =======
        coor: tuple, (long, lat)
        direction: string, 'upstream' or 'downstream' (case insensitive)

        Returns:
        ========
        (all_x, all_y): two lists with the long and lat values of every reached
        manhole, the starting manhole included
        '''
        direction = direction.lower()
        assert direction in ['downstream', 'upstream']

        coor = self.locate_coor(coor)
        visited = defaultdict(int)
        self.dfs(coor, visited, direction)

        reached = [new_coor for new_coor, flag in visited.items() if flag]
        all_x = [new_coor[0] for new_coor in reached]
        all_y = [new_coor[1] for new_coor in reached]
        return all_x, all_y

    def plot_connection(self, coor, direction):
        '''
        Given a coordinate (long, lat), find its nearest manhole and plot all of
        its upstreams/downstreams on top of the full set of manholes.

        Params:
        =======
        coor: tuple, (long, lat)
        direction: string, 'upstream' or 'downstream' (case insensitive)
        '''
        direction = direction.lower()
        assert direction in ['downstream', 'upstream']

        plt.figure(figsize=(17, 8))

        # background layer: every known manhole
        all_x = [key[0] for key in self.coor_to_name]
        all_y = [key[1] for key in self.coor_to_name]
        plt.scatter(all_x, all_y, label='normal')

        # highlighted layer: manholes connected in the requested direction
        coor = self.locate_coor(coor)
        all_x, all_y = self.find_connection(coor, direction)
        plt.scatter(all_x, all_y, label=direction)

        plt.scatter(coor[0], coor[1], label='source', color='red', marker='s')

        plt.legend()
        plt.title("{} nodes for node {}".format(direction, list(self.coor_to_name[coor])[0]))
        plt.show()

    def get_population_composition(self, node_name):
        '''
        Inside utility function, do not use.
        Get the population composition at a given manhole.

        The suffix '-1' is appended to `node_name` before the lookup. The
        `population` attribute is only assigned by `find_collection_points`,
        so that method has to run first.
        '''
        node_name += '-1'
        node = self.all_nodes[self.name_to_coor[node_name]]
        return node.population

    def _load_HDT_counts(self, path, count_all_tests):
        '''
        Inside utility function shared by the two `process_HDT_data*` methods.

        Params:
        =======
        path: string, location of the HDT file
        count_all_tests: bool, count every test (True) or only the rows whose
            Result equals 'Detected' (False)
        '''
        data = pd.read_csv(path, dtype={'CensusGEOID': str})
        # copy, so that the column assignments below never write into a slice
        data = data[['ResultDate', 'Result', 'CensusGEOID']].dropna(how='any').copy()
        # ids that are not 15 characters long get a '0' put in front
        data['CensusGEOID'] = data['CensusGEOID'].apply(lambda x: x if len(x) == 15 else '0' + x)

        def get_date(x):
            # 'm/d/y <anything>' -> datetime.date
            m, d, y = x.split(' ')[0].split('/')
            return datetime.date(int(y), int(m), int(d))

        data['date'] = data['ResultDate'].apply(get_date)
        if count_all_tests:
            data['Result'] = 1
        else:
            data['Result'] = (data['Result'] == 'Detected').astype(int)

        # Sum the counts only; summing the whole frame would also "sum" the
        # ResultDate strings on pandas versions that do not drop such columns.
        data = data.groupby(['date', 'CensusGEOID'])['Result'].sum().reset_index()
        data.rename(columns={'CensusGEOID': 'census_block', 'Result': 'positive'}, inplace=True)
        data = data[data['census_block'].str.startswith('06113')]
        return data.reset_index(drop=True)

    def process_HDT_data(self, path):
        '''
        Process HDT data (v3) so that it can be fed to `find_collection_points` method for further analysis.
        Counts the rows whose Result is 'Detected', per date and census block.

        Params:
        =======
        path: string, the directory for the target HDT file

        Returns:
        ========
        pd.DataFrame with columns 'date', 'census_block' and 'positive',
        restricted to census blocks starting with '06113'
        '''
        return self._load_HDT_counts(path, count_all_tests=False)

    def process_HDT_data_all_tests(self, path):
        '''
        Process HDT data (v3) so that it can be fed to `find_collection_points` method for further analysis.
        Counts every test regardless of its result, per date and census block;
        the count is still stored in the column 'positive'.

        Params:
        =======
        path: string, the directory for the target HDT file

        Returns:
        ========
        pd.DataFrame with columns 'date', 'census_block' and 'positive',
        restricted to census blocks starting with '06113'
        '''
        return self._load_HDT_counts(path, count_all_tests=True)

    def _find_collection_points_by_date(self, source_locations):
        '''
        Inside utility function. Runs `find_collection_points` once per day
        between the first and the last date and stacks the rows that have
        `total_infection` > 0, tagged with their date.
        '''
        if source_locations.empty:
            # nothing to iterate over: hand back an empty frame with the usual columns
            empty = self.find_collection_points(source_locations[['census_block', 'positive']])
            empty = empty.iloc[0:0].copy()
            empty['date'] = []
            return empty

        start_date = source_locations['date'].min()
        end_date = source_locations['date'].max()
        days = (end_date - start_date).days
        daily_results = []
        for i in tqdm.tqdm(range(days + 1)):
            date = start_date + datetime.timedelta(days=i)
            same_day = source_locations['date'] == date
            temp = source_locations.loc[same_day, ['census_block', 'positive']]
            if temp.empty:
                continue  # a day without records contributes no rows
            temp = self.find_collection_points(temp.reset_index(drop=True))
            temp = temp[temp['total_infection'] > 0].copy()
            temp['date'] = date
            daily_results.append(temp)
        # DataFrame.append no longer exists in pandas 2; concat does the same job
        return pd.concat(daily_results).reset_index(drop=True)

    def find_collection_points(self, source_locations):
        '''
        For a given DataFrame of Infection locations, find the expectation of infection at each collection point

        Params:
        source_locations : pd.DataFrame, should have at least two columns.
            column 'census_block', shows which census blocks have infections
            column 'positive', shows positive counts for each corresponding census block
            (optional) column 'date', indicating the date when the record is collected

        Returns:
        pd.DataFrame with one row per census block of `self.all_cb`, the columns
        'census_block' and 'total_infection', and one column per 'MH ID' of
        `self.locations`. With a 'date' column in the input, only rows with
        total_infection > 0 are kept and a 'date' column is added.

        Notes:
        If a census block appears more than once, its last count is used.
        Manholes that lie on a cycle never reach in-degree 0 and therefore do
        not pass anything downstream.
        '''
        # check if there are multiple dates
        if 'date' in source_locations.columns:
            return self._find_collection_points_by_date(source_locations)

        # check initial values
        infection = defaultdict(lambda: 0)
        for cb, count in zip(source_locations.census_block, source_locations.positive):
            infection[cb] = count

        # topological sort: every manhole starts with an equal share of its own
        # census block, and in_order counts the incoming pipes
        in_order = defaultdict(int)
        for node in self.all_nodes.values():
            cb = node.cb
            node.population = defaultdict(float)
            node.population[cb] = infection[cb] / self.cb_counter[cb]
            for next_coor in node.sons:
                in_order[next_coor] += 1

        queue = [coor for coor in self.all_nodes if in_order[coor] == 0]

        while queue:
            next_queue = []
            for coor in queue:
                node = self.all_nodes[coor]
                N = len(node.sons)
                for next_coor in node.sons:
                    next_node = self.all_nodes[next_coor]
                    # the content of a manhole is split evenly over its outlets
                    for key, value in node.population.items():
                        next_node.population[key] += value / N
                    in_order[next_coor] -= 1
                    if in_order[next_coor] == 0:
                        next_queue.append(next_coor)
            queue = next_queue

        MH_to_cb = {
            'census_block': self.all_cb,
            'total_infection': [infection[cb] for cb in self.all_cb],
        }
        for MH in self.locations['MH ID']:
            composition = self.get_population_composition(MH)
            MH_to_cb[MH] = [composition[cb] for cb in self.all_cb]

        return pd.DataFrame(MH_to_cb)

# Initialize Graph from existing Files

In [None]:
# Build the manhole graph from the saved CSV files in the current directory.
graph = waste_water_processor('./')

# Example of plotting Upstream/Downstream

Use `graph.plot_connection` method, pass in a coordinate and the direction

In [None]:
# Everything that drains into manhole M16-011-1.
graph.plot_connection(graph.name_to_coor['M16-011-1'], 'upstream')

In [None]:
# Everything that manhole M16-011-1 drains into.
graph.plot_connection(graph.name_to_coor['M16-011-1'], 'downstream')

In [None]:
# Plot the upstream network of every sampling location, one figure each.
for long, lat in zip(graph.locations['Long'], graph.locations['Lat']):
    graph.plot_connection((long, lat), 'upstream')

# Composition of Waste Water Source

Here we assume that each infected person produces the same amount of waste water each day. This amount is called **a unit**. We further assume that, for each census block, all manholes in that census block have the same probability of collecting the waste water produced by an infected person in that block.

Here we check the **expected value of the units of waste water produced by infected people at each collection manhole**.

## Example for a single date

To check how the waste water produced by infected people is collected, call the `graph.find_collection_points` method and pass a `pd.DataFrame` object as input. The dataframe should have two columns, `'census_block'` and `'positive'`, recording the number of positive cases in each census block.

**For examples for data with multiple dates, please check the end of this notebook**

In [None]:
# Example input: positive counts for four census blocks.
source = pd.DataFrame(dict(
    census_block=['061130106021000', '061130106021001', '061130106021008', '061130106021009'],
    positive=[20, 10, 5, 8],
))
source

In [None]:
# Expected units of infected waste water per census block (rows) and sampling manhole (columns).
collection_points = graph.find_collection_points(source)
collection_points.head()

In [None]:
# We can see that some census blocks are not collected.
# The slice drops the first column only, so the row sum still includes
# 'total_infection': every block with infections is listed, whether or not
# any sampling manhole receives its waste water.
values = collection_points.to_numpy()[:, 1:]
collection_points[values.sum(axis=1) > 0]

## How many units of waste water are collected at each collection point?

Do some collection points cover no population?

In [None]:
# Load the population table and rename its columns to the names that
# `find_collection_points` expects ('census_block' and 'positive').
population = pd.read_csv('./us2019_yolo.csv', dtype={'block_fips':str}).rename(
    columns={'block_fips':'census_block', 'pop2019':'positive'}
)
population.head()

In [None]:
# Treat the whole population as "positive" and keep only the per-manhole columns.
collection = graph.find_collection_points(population).drop(
    columns=['census_block', 'total_infection']
)
collection

In [None]:
# Save the per-manhole table to the working directory.
collection.to_csv('collection.csv')

In [None]:
# Column totals: units of waste water arriving at each sampling manhole.
collection.sum(axis=0)

In [None]:
# Upstream network of manhole N13-045-1.
graph.plot_connection(graph.name_to_coor['N13-045-1'], 'upstream')

In [None]:
# Upstream network of manhole O15-078-1.
graph.plot_connection(graph.name_to_coor['O15-078-1'], 'upstream')

In [None]:
# Upstream network of manhole P15-027-1.
graph.plot_connection(graph.name_to_coor['P15-027-1'], 'upstream')

In [None]:
# Upstream network of manhole O16-041-1.
graph.plot_connection(graph.name_to_coor['O16-041-1'], 'upstream')

## How is waste water from each census block collected?

We divide the census blocks into four types:
  
  * All waste water collected by one collection point
  * All waste water collected, but by more than one collection point
  * Not all waste water collected
  * No waste water collected

In [None]:
# Put exactly one unit into every census block, so each entry of `weights`
# is the amount of that block's unit arriving at a sampling manhole.
population['positive'] = 1
weights = (
    graph.find_collection_points(population)
    .set_index('census_block')
    .drop(columns = ['total_infection'])
)
values = weights.to_numpy()

In [None]:
# Units of each census block that reach any sampling manhole. The entries of
# `values` are floating-point numbers and the sum is compared against 1 further
# down, so it is rounded to strip accumulated noise (a sum that should be 1
# can otherwise come out as 0.9999999999999999).
total_collected = values.sum(axis=1).round(9)
# Number of sampling manholes that receive anything from each census block.
collected_by = (values > 0).sum(axis=1)

In [None]:
# Type 1. All waste water collected by one collection point
weights[np.logical_and(total_collected == 1, collected_by == 1)]

In [None]:
# Type 2. All waste water collected, but by more than one collection point
weights[np.logical_and(total_collected >= 1, collected_by > 1)]

In [None]:
# Type 3. Not all waste water collected (more than nothing, less than the full unit)
weights[np.logical_and(total_collected < 1, 0 < total_collected)]

In [None]:
# Type 4. No waste water collected
weights[total_collected == 0]

In [None]:
# Save the per-block weights to the working directory.
weights.to_csv('weights.csv')

## Example for multiple dates

In [None]:
# Rebuild the graph from the saved files in the current directory.
graph = waste_water_processor('./')

In [None]:
# You might need to change the location of the HDT file here.
# Counts the 'Detected' results per date and census block.
data = graph.process_HDT_data('./FILENAME.csv')

In [None]:
# `data` has a 'date' column, so the graph is evaluated once per day.
processed = graph.find_collection_points(data)

In [None]:
# Preview the first rows of the per-day result.
processed.head()

In [None]:
# (rows, columns) of the per-day result.
processed.shape

In [None]:
# Placeholder file name: replace it with the desired output location.
processed.to_csv('OUTPUT_FILENAME.csv')

In [None]:
# Same HDT file, but counting every test instead of only the 'Detected' ones.
data_all_tests = graph.process_HDT_data_all_tests('./FILENAME.csv')

In [None]:
# Per-day result for the all-tests counts.
processed_all_tests = graph.find_collection_points(data_all_tests)

In [None]:
# NOTE(review): this is the same placeholder name that the `processed.to_csv`
# cell writes to; use a different file name here, otherwise that file is overwritten.
processed_all_tests.to_csv('OUTPUT_FILENAME.csv')