# PEEWEE ORM

Here I will demonstrate how to use the Peewee ORM to load World Bank data into a database.

In [1]:
from peewee import * 
from playhouse.postgres_ext import PostgresqlDatabase

import pandas as pd

from tqdm import tqdm
from os.path import join 

import pandas as pd 

from time import sleep

from IPython import display
import matplotlib.pyplot as pl 
import seaborn as sns

% matplotlib inline

In [2]:
# Handle to the PostgreSQL database named 'ads'; every model below binds to it.
db = PostgresqlDatabase('ads')

class BaseModel(Model):
    """Common base class that attaches each model in this notebook to `db`."""
    class Meta:
        database = db

class Country(BaseModel): 
    """A named country; the indicator tables reference it by foreign key."""
    # Integer primary key column.
    id = PrimaryKeyField()
    
    # Country name as a character column.
    name = CharField()
    
class LifeExpectancy(BaseModel):
    """Life-expectancy value recorded for one country in one year."""

    id = PrimaryKeyField()

    # Owning country. Deleting the country cascades to these rows, and a
    # Country instance exposes them through the `life_expectancy` accessor.
    country = ForeignKeyField(
        Country,
        related_name='life_expectancy',
        on_delete='cascade',
        null=False,
        index=True,
    )

    year = IntegerField(null=False)
    # The value column is nullable; year and country are not.
    value = FloatField(null=True)
    
class Population(BaseModel):
    """Population value recorded for one country in one year."""

    id = PrimaryKeyField()

    # Required, indexed link to the owning Country; rows are deleted together
    # with their country and are reachable from it as `population`.
    country = ForeignKeyField(
        Country, related_name='population', on_delete='cascade',
        index=True, null=False,
    )

    year = IntegerField(null=False)
    # The value column is nullable; year and country are not.
    value = FloatField(null=True)
    
class GDP(BaseModel):
    """GDP value recorded for one country in one year."""

    id = PrimaryKeyField()

    # Required, indexed link to the owning Country (reverse accessor: `gdp`).
    # ON DELETE cascade removes these rows when the country is deleted.
    country = ForeignKeyField(
        Country,
        on_delete='cascade',
        related_name='gdp',
        index=True,
        null=False,
    )

    year = IntegerField(null=False)
    # The value column is nullable; year and country are not.
    value = FloatField(null=True)
    
# Every model defined in this notebook, with Country (the table the others
# reference) listed first.
models = [Country, LifeExpectancy, Population, GDP]

In [5]:
# Show the SQL statements generated for the Population table and its index.
Population.sqlall()

['CREATE TABLE "population" ("id" SERIAL NOT NULL PRIMARY KEY, "country_id" INTEGER NOT NULL, "year" INTEGER NOT NULL, "value" REAL, FOREIGN KEY ("country_id") REFERENCES "country" ("id") ON DELETE cascade)',
 'CREATE INDEX "population_country_id" ON "population" ("country_id")']

In [None]:
def rearrange_dataframe(df, indicator_name):
    """Reshape a wide, one-column-per-year table into long form.

    Args:
        df: DataFrame with a 'Country Name' column plus one column per year.
            A column counts as a year when its label starts with '1' or '2'.
        indicator_name: Name given to the value column of the result.

    Returns:
        A new DataFrame with the columns 'Country Name', 'year' and
        ``indicator_name``, holding one row per (country, year) pair.
        ``df`` itself is not modified.
    """
    country = 'Country Name'
    # str(...).startswith tolerates empty or non-string labels (for example a
    # blank trailing column), for which indexing ``c[0]`` would raise.
    years = [c for c in df.columns if str(c).startswith(('1', '2'))]
    melted = pd.melt(df[[country] + years], id_vars=country, var_name='year')
    return melted.rename(columns={'value': indicator_name})

# Rebuild the schema so the load below starts from empty tables.
db.drop_tables(models, safe=True, cascade=True)
db.create_tables(models, safe=True)

# One entry per dataset: (directory and file stem, value column name, model).
sources = [
    ('API_SP.DYN.LE00.IN_DS2_en_csv_v2', 'Life expectency at birth', LifeExpectancy),
    ('API_SP.POP.TOTL_DS2_en_csv_v2',    'Total population',         Population    ),
    ('API_NY.GDP.PCAP.CD_DS2_en_csv_v2', 'GDP per capita',           GDP           ),
]

# Country rows already fetched or created, keyed by name, so each country
# costs one get_or_create call instead of one per data row.
country_cache = {}

for source, key, model in sources: 
    df = rearrange_dataframe(pd.read_csv(join(source, '{}.csv'.format(source)), skiprows=4), key)

    # One transaction per dataset: the rows are committed together instead of
    # one commit per INSERT, and a failed load rolls the whole dataset back.
    with db.atomic():
        for _, row in tqdm(df.iterrows(), total=len(df)):
            name = row['Country Name']
            country = country_cache.get(name)
            if country is None:
                country, _ = Country.get_or_create(name=name)
                country_cache[name] = country

            # NOTE(review): missing observations presumably arrive as NaN and
            # are stored as NaN rather than NULL - confirm this is intended.
            model.create(
                country=country, 
                year=int(row['year']), 
                value=row[key],
            )

        # For faster insertion, the insert_many method may be used.

In [None]:
# Count the rows selected from the GDP table.
GDP.select().count()

In [None]:
# Build a query limited to ten GDP rows and print the query object itself.
q = GDP.select().limit(10)
print(q)

# Turn the query into a list of GDP instances and print that list.
lq = list(q)
print(lq)

# Look at the first instance, then follow its foreign key to the country.
first_row = lq[0]
print(first_row)
print(first_row.country.name)

In [None]:
# Fetch the single Country whose primary key is 110.
c1 = Country.get(Country.id == 110)

c1.name

In [None]:
# The reverse accessor takes its name from the foreign key's
# related_name ('gdp'), not from the model class, so `c1.GDP` does not exist.
list(c1.gdp.dicts())

In [None]:
# GDP and population for country 110, as tuples.
# Population is matched to GDP on both country and year; comparing
# Population.id with GDP.country_id would pair unrelated keys.
list(Country.select(
    Country.name, 
    GDP.year,
    GDP.value,
    Population.value,
).where(
    Country.id == 110
).join(
    GDP, 
    on=Country.id == GDP.country_id
).join(
    Population, 
    on=(Population.country_id == GDP.country_id) & (Population.year == GDP.year)
).tuples())

In [None]:
# GDP and population for country 110, as dicts with aliased value columns.
# Population is matched to GDP on both country and year; comparing
# Population.id with GDP.country_id would pair unrelated keys.
list(Country.select(
    Country.name, 
    GDP.year,
    GDP.value.alias('gdp'),
    Population.value.alias('pop'),
).where(
    Country.id == 110
).join(
    GDP, 
    on=Country.id == GDP.country_id
).join(
    Population, 
    on=(Population.country_id == GDP.country_id) & (Population.year == GDP.year)
).dicts())

In [None]:
# First ten rows of GDP, population and life expectancy for country 110.
# Each indicator table is matched to GDP on both country and year; comparing
# an indicator's own primary key with GDP.country_id would pair unrelated keys.
list(Country.select(
    Country.name, 
    GDP.year,
    GDP.value.alias('gdp'),
    Population.value.alias('pop'),
    LifeExpectancy.value.alias('le'),
).where(
    Country.id == 110
).join(
    GDP, 
    on=Country.id == GDP.country_id
).join(
    Population, 
    on=(Population.country_id == GDP.country_id) & (Population.year == GDP.year)
).join(
    LifeExpectancy, 
    on=(LifeExpectancy.country_id == GDP.country_id) & (LifeExpectancy.year == GDP.year)
).limit(10).dicts())

In [None]:
# GDP, population and life expectancy for country 110 as a DataFrame.
# Each indicator table is matched to GDP on both country and year; comparing
# an indicator's own primary key with GDP.country_id would pair unrelated keys.
df = pd.DataFrame(list(Country.select(
    Country.name, 
    GDP.year,
    GDP.value.alias('gdp'),
    Population.value.alias('pop'),
    LifeExpectancy.value.alias('le'),
).where(
    Country.id == 110
).join(
    GDP, 
    on=Country.id == GDP.country_id
).join(
    Population, 
    on=(Population.country_id == GDP.country_id) & (Population.year == GDP.year)
).join(
    LifeExpectancy, 
    on=(LifeExpectancy.country_id == GDP.country_id) & (LifeExpectancy.year == GDP.year)
).dicts()))
# Index by year and drop the constant country name so only the three
# indicator columns remain.
df.set_index('year', inplace=True)
del df['name']

In [None]:
import matplotlib.pyplot as pl 
% matplotlib inline

# Draw each column of df in its own subplot.
df.plot(subplots=True)

In [None]:
# Join predicates: an indicator row belongs to a GDP row when both the year
# and the country match.
pop_match = (GDP.year == Population.year) & (GDP.country_id == Population.country_id)
le_match = (GDP.year == LifeExpectancy.year) & (GDP.country_id == LifeExpectancy.country_id)

# One dict per year for the United Kingdom, with the three indicator values
# exposed under the short aliases 'gdp', 'pop' and 'le'.
uk_rows = (
    GDP
    .select(
        GDP.year,
        GDP.value.alias('gdp'),
        Population.value.alias('pop'),
        LifeExpectancy.value.alias('le'),
    )
    .join(Population, on=pop_match)
    .join(LifeExpectancy, on=le_match)
    .join(Country, on=GDP.country_id == Country.id)
    .where(Country.name == 'United Kingdom')
    .dicts()
)

df = pd.DataFrame(list(uk_rows)).set_index('year')
df.head()

In [None]:
# Draw each column of df in its own subplot, with figsize=(10, 10).
df.plot(subplots=True, figsize=(10, 10))

In [None]:
# Join predicates: an indicator row belongs to a GDP row when both the year
# and the country match.
pop_match = (GDP.year == Population.year) & (GDP.country_id == Population.country_id)
le_match = (GDP.year == LifeExpectancy.year) & (GDP.country_id == LifeExpectancy.country_id)

# One dict per (country, year) across all countries, carrying the country
# name and the three indicator values under short aliases.
all_rows = (
    GDP
    .select(
        Country.name.alias('country'),
        GDP.year,
        GDP.value.alias('gdp'),
        Population.value.alias('pop'),
        LifeExpectancy.value.alias('le'),
    )
    .join(Population, on=pop_match)
    .join(LifeExpectancy, on=le_match)
    .join(Country, on=GDP.country_id == Country.id)
    .dicts()
)

df = pd.DataFrame(list(all_rows))

df.head()

In [None]:
# Countries to draw; every other group in df is skipped.
countries = {'United Kingdom', 'France', 'Germany'}

# Indicator columns, one subplot each in this order.
keys = ('pop', 'le', 'gdp')

fig, axes = pl.subplots(3, 1, figsize=(10, 10))

# One line per selected country on each indicator's subplot.
for country, group in df.groupby('country'):
    if country not in countries:
        continue
    for ax, key in zip(axes, keys):
        ax.plot(group.year, group[key], label=country)

# Title each subplot with its indicator key and add the country legend.
for ax, key in zip(axes, keys):
    ax.set_title(key)
    ax.legend()

In [None]:
def pw_scatter(df, year, x='gdp', y='le'):
    """Plot one bubble per selected country for a single year.

    Each bubble is placed at (GDP value, life-expectancy value) and sized by
    the population value, all read from the database for ``year``.

    Relies on two module-level names: ``countries`` (the country names to
    draw) and ``ax`` (the matplotlib axes that is drawn on).

    Args:
        df: DataFrame used only to size the axes; must contain the columns
            named by ``x`` and ``y``.
        year: Year whose observations are plotted; also used as the title.
        x: Column of ``df`` for the horizontal axis label and limit.
            Defaults to 'gdp', the quantity plotted horizontally.
        y: Column of ``df`` for the vertical axis label and limits.
            Defaults to 'le', the quantity plotted vertically.
    """
    current_palette = sns.color_palette()
    alpha = 0.25
    for i, c in enumerate(countries):
        country = Country.select().where(Country.name == c).get()
        gdp = GDP.select().where(GDP.country_id == country, GDP.year == year).get()
        pop = Population.select().where(Population.country_id == country, Population.year == year).get()
        le = LifeExpectancy.select().where(LifeExpectancy.country_id == country, LifeExpectancy.year == year).get()
        ax.plot(
            gdp.value,
            le.value,
            marker='o',
            linestyle='',
            # Wrap around so a list longer than the palette still gets colours.
            color=current_palette[i % len(current_palette)],
            # Scale raw head counts down to a usable marker size.
            ms=pop.value / 2.5e6,
            label=c,
            alpha=alpha,
        )
    ax.set_xlabel(x)
    ax.set_ylabel(y)
    ax.set_xlim([0, df[x].max()])
    # Pad the vertical range by 10% on each side.
    ax.set_ylim([df[y].min()*0.9, df[y].max()*1.1])
    ax.set_title(year)

# Animated version
fig, ax = pl.subplots(figsize=[15 ,10])
countries = ['United Kingdom', 'France', 'Germany']

# Rows for the selected countries only, so the axis limits fit what is drawn.
# Assumes df is the all-countries frame with a 'country' column - confirm
# that cell was the last one to assign df.
df_selected = df[df['country'].isin(countries)]

# The legend is built once, after the first frame has been drawn.
do_legend = True

for year in range(1960, 2017):
    pw_scatter(df_selected, year)

    if do_legend:
        lgnd = ax.legend()
        # The handle list is `legend_handles` on newer matplotlib and
        # `legendHandles` on older releases; accept either.
        handles = getattr(lgnd, 'legend_handles', None)
        if handles is None:
            handles = lgnd.legendHandles
        for handle in handles[:len(countries)]:
            # Older matplotlib keeps the legend marker on a private
            # `_legmarker` artist; otherwise style the handle itself.
            marker = getattr(handle, '_legmarker', handle)
            marker.set_markersize(20)
            marker.set_alpha(1.0)
        do_legend = False

    # Replace the previously displayed frame with the updated figure.
    display.clear_output(wait=True)
    display.display(pl.gcf())
    sleep(0.05)