In [1]:
%load_ext autoreload
%autoreload 2

from os.path import join
from os.path import abspath
import itertools

import numpy as np
import pandas as pd
from causalnex.structure.notears import from_pandas
from causalnex.plots import plot_structure
from causalnex.network import BayesianNetwork
from causalnex.evaluation import classification_report, roc_auc
from sklearn.model_selection import KFold

from data_utils import read_tweets, read_covid, reformat_dataframe, add_missing_countries
from sentiment import add_lang, add_sentiment
from train_utils import discretize_df
from train import train_bn
from feature_extraction import get_feature_matrix
from inference import marginal_probs
from eval_utils import save_logs
from configs import config as cf

## 1 - Read official COVID-19 data ([JHU CSSE repository](https://github.com/CSSEGISandData/COVID-19)) and the collected tweet data

In [2]:
# Load the infection and death tables the same way: read the raw file, then
# reformat it. The preview below shows one row per country and one column per date.
infected, deaths = (
    reformat_dataframe(read_covid(covid_path))
    for covid_path in (cf.INFECTED_PATH, cf.DEATHS_PATH)
)
infected.head(5)

Unnamed: 0_level_0,2020-01-22,2020-01-23,2020-01-24,2020-01-25,2020-01-26,2020-01-27,2020-01-28,2020-01-29,2020-01-30,2020-01-31,...,2020-03-09,2020-03-10,2020-03-11,2020-03-12,2020-03-13,2020-03-14,2020-03-15,2020-03-16,2020-03-17,2020-03-18
Country,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Afghanistan,0,0,0,0,0,0,0,0,0,0,...,3,4,5,6,7,8,11,16,20,22
Albania,0,0,0,0,0,0,0,0,0,0,...,1,4,8,15,23,31,38,44,49,55
Algeria,0,0,0,0,0,0,0,0,0,0,...,19,20,20,21,23,29,37,46,54,63
Andorra,0,0,0,0,0,0,0,0,0,0,...,1,1,1,1,1,1,1,1,14,27
Angola,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [3]:
# Load the collected tweets.
tweets = read_tweets(cf.TWEETS_PATH)

# Fill in missing country information; this call modifies `tweets` in place.
add_missing_countries(tweets)

# Enrich the tweets with language first, then sentiment.
tweets = add_sentiment(add_lang(tweets))

Tweets dataframe shape=(954902, 19)
439210 tweets do not have country information!
Loading formatted geocoded file...
7605 tweets that do not have country information will be discarded!


### Group tweets by day and country

In [4]:
# Aggregate the tweets per (date, country).
# The grouping is built once and only the columns that are actually needed are
# aggregated. Calling .count()/.mean() on the whole grouped frame aggregates
# every column, and .mean() raises a TypeError on non-numeric columns with
# pandas >= 2.0 (numeric_only no longer silently drops them).
tweets_by_day_country = tweets.groupby(['date', 'Location'])

# calculate daily counts (number of non-null tweet ids per date and country)
counts_per_country = tweets_by_day_country['id'].count().reset_index(name='twitter_activity')

# calculate daily average sentiment; groups without a valid mean are set to 0.
# Both aggregations come from the same groupby object, so their rows are in the
# same group order and can be combined positionally.
counts_per_country['sentiment'] = tweets_by_day_country['sentiment'].mean().fillna(0).values

# calculate features: the infection-based matrix is the base, and the three
# death-based columns are taken from a second matrix built with mode='deaths'
death_cols = ['deaths', 'deaths_new', 'deaths_perc_change']
features = get_feature_matrix(counts_per_country, infected, mode='infected', country_list=cf.countries)
features[death_cols] = get_feature_matrix(
    counts_per_country, deaths, mode='deaths', country_list=cf.countries
)[death_cols]

# the date column is not used as a model variable; reassign instead of
# mutating in place so the cell stays easy to re-run and reason about
features = features.drop(columns=['date'])
print('Feature matrix shape:{}'.format(features.shape))

100%|██████████| 12/12 [00:28<00:00,  2.36s/it]
100%|██████████| 12/12 [00:28<00:00,  2.35s/it]

Feature matrix shape:(684, 12)





In [5]:
# Preview the last rows of the feature matrix as a quick sanity check.
features.tail()

Unnamed: 0,infected,twitter_activity,sentiment,infected_new,infected_perc_change,restriction,over_65,twitter_usage,single_household,deaths,deaths_new,deaths_perc_change
52,130.28413,54.054054,5.970217,22.522523,20.900322,0,19.813,10.7,43.9,0.0,0.0,0.0
53,145.183645,56.652807,8.62061,14.899515,11.43617,0,19.813,10.7,43.9,0.17325,0.17325,100.0
54,152.633403,68.780319,-15.013248,7.449757,5.131265,0,19.813,10.7,43.9,0.3465,0.17325,100.0
55,163.548164,66.181566,-4.621465,10.914761,7.150965,0,19.813,10.7,43.9,0.519751,0.17325,50.0
56,177.581428,65.835066,8.323642,14.033264,8.580508,0,19.813,10.7,43.9,0.693001,0.17325,33.333333


## 2 - Structure Learning with NOTEARS

In [14]:
# Learn the network structure from the feature matrix with NOTEARS.
# All structural constraints (tabu nodes/edges) and the edge-weight threshold
# come from the project config; only the iteration cap is set here.
notears_options = {
    'max_iter': 200,
    'tabu_child_nodes': cf.tabu_child_nodes,
    'tabu_parent_nodes': cf.tabu_parent_nodes,
    'tabu_edges': cf.tabu_edges,
    'w_threshold': cf.edge_threshold,
}
graph = from_pandas(features, **notears_options)

# Render the learned structure and write it to the logs directory.
structure_png = abspath(join(cf.LOGS_DIR, 'structure_graph.png'))
plot = plot_structure(graph)
plot.draw(structure_png)

## 3 - Causal Inference

### Leave-One-Country-Out Cross-Validation - Fitting the Conditional Distributions of the Bayesian Network

In [17]:
# Map the features to discrete states (e.g. "high" or "low")
fit_feats = discretize_df(features)

# Leave-one-country-out (LOCO) evaluation: KFold with one split per country
# holds out exactly one country in every iteration.
auc_scores = []
countries = []
loco_folds = KFold(n_splits=cf.n_countries).split(range(cf.n_countries))
for train_country_ind, test_country_ind in loco_folds:
    # cf.splits[...] selects one group of row positions per chosen country;
    # flatten the groups into a single list of row positions
    train_indices = [row for country_rows in cf.splits[train_country_ind] for row in country_rows]
    test_indices = [row for country_rows in cf.splits[test_country_ind] for row in country_rows]
    held_out_country = cf.countries[test_country_ind[0]]

    # fit on the remaining countries, score on the held-out one
    bn = train_bn(fit_feats.iloc[train_indices], graph)
    _, auc = roc_auc(bn, fit_feats.iloc[test_indices], 'twitter_activity')

    auc_scores.append(auc)
    countries.append(held_out_country)
    print('Country as test = {}, AUC = {:.3f}'.format(held_out_country, auc))
print('\tMean AUC = {:.3f}'.format(np.mean(auc_scores)))
save_logs(countries, auc_scores, 'aucs')

Country as test = Italy, AUC = 0.976
Country as test = Spain, AUC = 0.766
Country as test = Germany, AUC = 0.992
Country as test = France, AUC = 0.776
Country as test = Switzerland, AUC = 0.789
Country as test = United Kingdom, AUC = 0.684
Country as test = Netherlands, AUC = 0.746
Country as test = Norway, AUC = 0.907
Country as test = Austria, AUC = 0.798
Country as test = Belgium, AUC = 0.728
Country as test = Sweden, AUC = 0.998
Country as test = Denmark, AUC = 0.831
	Mean AUC = 0.833


### Inspect Marginal Likelihoods After Observations

In [19]:
# Fit the final network on all countries: register the node states, then
# estimate the CPDs with a Bayesian estimator using a K2 prior.
bn = (
    BayesianNetwork(graph)
    .fit_node_states(fit_feats)
    .fit_cpds(fit_feats, method='BayesianEstimator', bayes_prior='K2')
)

# Inspect the learned conditional probability table of the target node.
bn.cpds['twitter_activity']

deaths_new,high,high,high,high,high,high,high,high,low,low,low,low,low,low,low,low
infected_new,high,high,high,high,low,low,low,low,high,high,high,high,low,low,low,low
restriction,0,0,1,1,0,0,1,1,0,0,1,1,0,0,1,1
twitter_usage,high,low,high,low,high,low,high,low,high,low,high,low,high,low,high,low
twitter_activity,Unnamed: 1_level_4,Unnamed: 2_level_4,Unnamed: 3_level_4,Unnamed: 4_level_4,Unnamed: 5_level_4,Unnamed: 6_level_4,Unnamed: 7_level_4,Unnamed: 8_level_4,Unnamed: 9_level_4,Unnamed: 10_level_4,Unnamed: 11_level_4,Unnamed: 12_level_4,Unnamed: 13_level_4,Unnamed: 14_level_4,Unnamed: 15_level_4,Unnamed: 16_level_4
high,0.857143,0.125,0.8,0.444444,0.5,0.25,0.5,0.5,0.894737,0.388889,0.818182,0.833333,0.225806,0.119601,0.5,0.5
low,0.142857,0.875,0.2,0.555556,0.5,0.75,0.5,0.5,0.105263,0.611111,0.181818,0.166667,0.774194,0.880399,0.5,0.5


#### Expected Marginal Likelihoods

##### Lower daily infections/deaths increase the probability of the percentage change being low. This is expected simply from the mathematical definition of a percentage change.

In [20]:
# Compare `deaths_perc_change` when `deaths_new` is observed as high vs. low.
for deaths_new_level in ('high', 'low'):
    _ = marginal_probs(bn, 'deaths_perc_change', {'deaths_new': deaths_new_level})

Marginal probabilities of "deaths_perc_change" | {'deaths_new': 'high'} = {'high': 0.4446071052335914, 'low': 0.5553928947664085}
Marginal probabilities of "deaths_perc_change" | {'deaths_new': 'low'} = {'high': 0.05965902319764994, 'low': 0.9403409768023501}


#### Insightful Marginal Likelihoods That Support Domain/Expert Knowledge and Literature

In [21]:
# Compare `infected` when `over_65` is observed as high vs. low.
for over_65_level in ('high', 'low'):
    _ = marginal_probs(bn, 'infected', {'over_65': over_65_level})

Marginal probabilities of "infected" | {'over_65': 'high'} = {'high': 0.1882654368732316, 'low': 0.8117345631267684}
Marginal probabilities of "infected" | {'over_65': 'low'} = {'high': 0.15213984204244294, 'low': 0.8478601579575571}


In [22]:
# Compare `deaths_perc_change` when `over_65` is observed as high vs. low.
for over_65_level in ('high', 'low'):
    _ = marginal_probs(bn, 'deaths_perc_change', {'over_65': over_65_level})

Marginal probabilities of "deaths_perc_change" | {'over_65': 'high'} = {'high': 0.08110649658975573, 'low': 0.9188935034102442}
Marginal probabilities of "deaths_perc_change" | {'over_65': 'low'} = {'high': 0.07511992628516961, 'low': 0.9248800737148305}


In [23]:
# Compare `deaths_perc_change` when `single_household` is observed as high vs. low.
for single_household_level in ('high', 'low'):
    _ = marginal_probs(bn, 'deaths_perc_change', {'single_household': single_household_level})

Marginal probabilities of "deaths_perc_change" | {'single_household': 'high'} = {'high': 0.06778127725480429, 'low': 0.9322187227451956}
Marginal probabilities of "deaths_perc_change" | {'single_household': 'low'} = {'high': 0.09137034138746945, 'low': 0.9086296586125305}


In [24]:
# Compare `infected` under two opposite household/age observations.
for evidence in (
    {'single_household': 'high', 'over_65': 'low'},
    {'single_household': 'low', 'over_65': 'high'},
):
    _ = marginal_probs(bn, 'infected', evidence)

Marginal probabilities of "infected" | {'single_household': 'high', 'over_65': 'low'} = {'high': 0.17826086956521736, 'low': 0.8217391304347825}
Marginal probabilities of "infected" | {'single_household': 'low', 'over_65': 'high'} = {'high': 0.2413793103448276, 'low': 0.7586206896551724}


#### Insightful Marginal Likelihoods About Twitter Activity

In [25]:
# Compare `twitter_activity` when new deaths and new infections are both high vs. both low.
for level in ('high', 'low'):
    _ = marginal_probs(bn, 'twitter_activity', {'deaths_new': level, 'infected_new': level})

Marginal probabilities of "twitter_activity" | {'deaths_new': 'high', 'infected_new': 'high'} = {'high': 0.495659794530057, 'low': 0.5043402054699431}
Marginal probabilities of "twitter_activity" | {'deaths_new': 'low', 'infected_new': 'low'} = {'high': 0.1841544830189222, 'low': 0.8158455169810779}


In [26]:
# Compare `twitter_activity` with all four of its CPD parents observed:
# everything high with restrictions on vs. everything low with restrictions off.
for evidence in (
    {'deaths_new': 'high', 'infected_new': 'high', 'twitter_usage': 'high', 'restriction': 1},
    {'deaths_new': 'low', 'infected_new': 'low', 'twitter_usage': 'low', 'restriction': 0},
):
    _ = marginal_probs(bn, 'twitter_activity', evidence)

Marginal probabilities of "twitter_activity" | {'deaths_new': 'high', 'infected_new': 'high', 'twitter_usage': 'high', 'restriction': 1} = {'high': 0.8, 'low': 0.2}
Marginal probabilities of "twitter_activity" | {'deaths_new': 'low', 'infected_new': 'low', 'twitter_usage': 'low', 'restriction': 0} = {'high': 0.11960132890365449, 'low': 0.8803986710963455}


In [27]:
# Compare `twitter_activity` when `restriction` is observed as 1 vs. 0.
for restriction_flag in (1, 0):
    _ = marginal_probs(bn, 'twitter_activity', {'restriction': restriction_flag})

Marginal probabilities of "twitter_activity" | {'restriction': 1} = {'high': 0.535037907228066, 'low': 0.46496209277193384}
Marginal probabilities of "twitter_activity" | {'restriction': 0} = {'high': 0.23228723227692152, 'low': 0.7677127677230785}


In [28]:
# Compare `twitter_activity` when `twitter_usage` is observed as high vs. low.
for twitter_usage_level in ('high', 'low'):
    _ = marginal_probs(bn, 'twitter_activity', {'twitter_usage': twitter_usage_level})

Marginal probabilities of "twitter_activity" | {'twitter_usage': 'high'} = {'high': 0.3189333421537977, 'low': 0.6810666578462024}
Marginal probabilities of "twitter_activity" | {'twitter_usage': 'low'} = {'high': 0.16682484309633536, 'low': 0.8331751569036646}


#### Insightful Marginal Likelihoods About Overall Sentiment

In [29]:
# Compare `sentiment` when `deaths` is observed as high vs. low.
for deaths_level in ('high', 'low'):
    _ = marginal_probs(bn, 'sentiment', {'deaths': deaths_level})

Marginal probabilities of "sentiment" | {'deaths': 'high'} = {'neg': 0.34426542253093173, 'pos': 0.6557345774690683}
Marginal probabilities of "sentiment" | {'deaths': 'low'} = {'neg': 0.2904481188092652, 'pos': 0.7095518811907349}


In [31]:
# Compare `sentiment` when `deaths_new` is observed as high vs. low.
for deaths_new_level in ('high', 'low'):
    _ = marginal_probs(bn, 'sentiment', {'deaths_new': deaths_new_level})

Marginal probabilities of "sentiment" | {'deaths_new': 'high'} = {'neg': 0.6241951146107049, 'pos': 0.375804885389295}
Marginal probabilities of "sentiment" | {'deaths_new': 'low'} = {'neg': 0.27732727932837664, 'pos': 0.7226727206716234}


In [32]:
# Compare `sentiment` when new deaths and total deaths are both high vs. both low.
for level in ('high', 'low'):
    _ = marginal_probs(bn, 'sentiment', {'deaths_new': level, 'deaths': level})

Marginal probabilities of "sentiment" | {'deaths_new': 'high', 'deaths': 'high'} = {'neg': 0.3226579510355828, 'pos': 0.6773420489644172}
Marginal probabilities of "sentiment" | {'deaths_new': 'low', 'deaths': 'low'} = {'neg': 0.2732286391567292, 'pos': 0.7267713608432709}


In [33]:
# Compare `sentiment` when `restriction` is observed as 1 vs. 0.
for restriction_flag in (1, 0):
    _ = marginal_probs(bn, 'sentiment', {'restriction': restriction_flag})

Marginal probabilities of "sentiment" | {'restriction': 1} = {'neg': 0.5014243502052969, 'pos': 0.49857564979470315}
Marginal probabilities of "sentiment" | {'restriction': 0} = {'neg': 0.2859699394915012, 'pos': 0.7140300605084988}
