# 09_02: Fitting models

In [1]:
import math
import collections
import dataclasses
import datetime

import numpy as np
import pandas as pd
import matplotlib.pyplot as pp

In [2]:
# Load the Gapminder country-year table with pyarrow-backed column dtypes.
gm = pd.read_csv(
    'gapminder.csv',
    dtype_backend='pyarrow',
)

In [None]:
# Print a summary of the loaded table: columns, dtypes and non-null counts.
gm.info()

In [4]:
# Keep only the 1985 rows; copy so later column edits don't touch `gm`.
gdata = gm[gm['year'] == 1985].copy(deep=True)

In [5]:
# Treat region as a categorical variable (used for grouping and as a formula factor).
gdata['region'] = gdata.region.astype('category')

In [None]:
# Confirm the conversion to a categorical dtype.
gdata['region'].dtype

In [7]:
# One marker color per country, looked up from its region; any region not
# listed in this dict maps to NaN.
continent = gdata['region'].map({
    'Africa': 'skyblue',
    'Europe': 'gold',
    'America': 'palegreen',
    'Asia': 'coral',
    'Oceania': 'teal',
})
# Marker sizes: population scaled by 1e-6 (presumably people -> millions).
population = gdata['population'] * 1e-6

def plotbabies():
    """Scatter babies_per_woman against age5_surviving, one marker per country.

    Reads the module-level ``gdata`` (data), ``continent`` (marker colors) and
    ``population`` (marker sizes). Because no ``ax`` is passed, pandas draws
    on a new figure, and later ``pp.scatter`` calls in the same cell overlay
    onto it.
    """
    gdata.plot.scatter('age5_surviving', 'babies_per_woman', c=continent, s=population,
                       linewidths=0.5, edgecolor='black', alpha=0.6, figsize=(5,3.5))  # alpha adds some transparency

In [None]:
# Baseline scatter of the 1985 data, before any model is fitted.
plotbabies();

In [9]:
import statsmodels
import statsmodels.formula.api as smf

In [10]:
# Intercept-only OLS model: its single coefficient is the sample mean.
constantmodel = smf.ols(
    'babies_per_woman ~ 1',
    data=gdata,
)

In [11]:
# Estimate the model (pseudo-inverse least squares, statsmodels' default method).
constantfit = constantmodel.fit(method='pinv')

In [None]:
# Fitted coefficients; the formula '~ 1' has only an Intercept term.
constantfit.params

In [None]:
# Cross-check: the Intercept above should equal this plain mean.
gdata['babies_per_woman'].mean()

In [None]:
# The constant model predicts the same value for every country.
constantfit.predict(exog=gdata)

In [None]:
plotbabies()
# Overlay the constant model's prediction for each country: small point
# markers ('.') with a thin black edge, colored by region.
pp.scatter(gdata.age5_surviving, constantfit.predict(gdata), color=continent,
           s=50, marker='.', ec='k', lw=0.5);

In [16]:
# Intercept plus region dummies. Under treatment coding, one region is the
# baseline and the other region coefficients are offsets from it.
groupfit = (
    smf.ols('babies_per_woman ~ 1 + region', data=gdata)
    .fit()
)

In [None]:
# Intercept (baseline region) plus per-region offsets.
groupfit.params

In [18]:
# '-1' drops the intercept, so there is one coefficient per region
# (each equals that region's mean).
groupfit2 = (
    smf.ols('babies_per_woman ~ -1 + region', data=gdata)
    .fit()
)

In [None]:
# One coefficient per region level (no intercept in this model).
groupfit2.params

In [None]:
# Cross-check against groupfit2.params: per-region sample means
# (observed=True skips category levels with no rows).
gdata.groupby('region', observed=True)['babies_per_woman'].mean()

In [None]:
plotbabies()
# Overlay the per-region predictions: constant within each region.
pp.scatter(
    gdata['age5_surviving'],
    groupfit2.predict(gdata),
    color=continent,
    s=50, marker='.', ec='k', lw=0.5,
);

In [22]:
# Region-specific intercepts with one slope in age5_surviving shared by all regions.
survivingfit = (
    smf.ols('babies_per_woman ~ -1 + region + age5_surviving', data=gdata)
    .fit()
)

In [None]:
# Per-region intercepts plus the shared age5_surviving slope.
survivingfit.params

In [None]:
plotbabies()
# Overlay predictions: parallel lines, one per region.
pp.scatter(
    gdata['age5_surviving'],
    survivingfit.predict(gdata),
    color=continent,
    s=50, marker='.', ec='k', lw=0.5,
);

In [25]:
# Region-specific intercepts AND region-specific slopes (age5_surviving:region).
survivingfit2 = (
    smf.ols('babies_per_woman ~ -1 + region + age5_surviving:region', data=gdata)
    .fit()
)

In [None]:
# Per-region intercepts followed by per-region age5_surviving slopes.
survivingfit2.params

In [None]:
plotbabies()
# Overlay predictions: one line per region, each with its own slope.
pp.scatter(
    gdata['age5_surviving'],
    survivingfit2.predict(gdata),
    color=continent,
    s=50, marker='.', ec='k', lw=0.5,
);

In [28]:
# Adds population as a second explanatory variable on top of survivingfit2.
# NOTE(review): formula names are looked up in `data` first, so `population`
# here should be gdata's raw column, not the 1e-6-scaled global of the same
# name; expect a very small coefficient. Confirm against the params.
twovariablefit = (
    smf.ols('babies_per_woman ~ -1 + region + age5_surviving:region + population', data=gdata)
    .fit()
)

In [None]:
# Per-region intercepts and slopes, plus the population coefficient.
twovariablefit.params

In [None]:
plotbabies()
# Overlay predictions of the two-variable model. Population also enters, so
# points need not lie on a single line per region.
pp.scatter(
    gdata['age5_surviving'],
    twovariablefit.predict(gdata),
    color=continent,
    s=50, marker='.', ec='k', lw=0.5,
);