In [1]:
import pandas as pd
import numpy as np
from catboost import CatBoostRegressor, Pool
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.model_selection import train_test_split

In [3]:
# Load the raw country table, then normalise every header to snake_case
# (trim surrounding whitespace -> spaces become underscores -> lowercase)
# so later cells can refer to columns with predictable names.
data = pd.read_csv("Data/world.csv")
data = data.rename(columns=lambda name: name.strip().replace(' ', '_').lower())

data

Unnamed: 0,country,region,population,area_(sq._mi.),pop._density_(per_sq._mi.),coastline_(coast/area_ratio),net_migration,infant_mortality_(per_1000_births),gdp_($_per_capita),literacy_(%),phones_(per_1000),arable_(%),birthrate,deathrate
0,St Pierre & Miquelon,NORTHERN AMERICA,7026,242,290,4959,-486,754,6900,990,6832,1304,1352,683
1,Saint Helena,SUB-SAHARAN AFRICA,7502,413,182,1453,0,19,2500,970,2933,129,1213,653
2,Montserrat,LATIN AMER. & CARIB,9439,102,925,3922,0,735,3400,970,,20,1759,71
3,Tuvalu,OCEANIA,11810,26,4542,9231,0,2003,1100,,593,0,2218,711
4,Nauru,OCEANIA,13287,21,6327,14286,0,995,5000,,1430,0,2476,67
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
221,Brazil,LATIN AMER. & CARIB,188078227,8511965,221,009,-003,2961,7600,864,2253,696,1656,617
222,Indonesia,ASIA (EX. NEAR EAST),245452739,1919440,1279,285,0,356,3200,879,520,1132,2034,625
223,United States,NORTHERN AMERICA,298444215,9631420,310,021,341,65,37800,970,8980,1913,1414,826
224,India,ASIA (EX. NEAR EAST),1095351995,3287590,3332,021,-007,5629,2900,595,454,544,2201,818


In [5]:
# The country name is a per-row identifier, not a predictive feature, so drop it.
# errors='ignore' keeps this cell idempotent: `data` is reassigned here, so a
# second execution of the original cell raised KeyError('country').
data = data.drop(columns=['country'], errors='ignore')
data

Unnamed: 0,region,population,area_(sq._mi.),pop._density_(per_sq._mi.),coastline_(coast/area_ratio),net_migration,infant_mortality_(per_1000_births),gdp_($_per_capita),literacy_(%),phones_(per_1000),arable_(%),birthrate,deathrate
0,NORTHERN AMERICA,7026,242,290,4959,-486,754,6900,990,6832,1304,1352,683
1,SUB-SAHARAN AFRICA,7502,413,182,1453,0,19,2500,970,2933,129,1213,653
2,LATIN AMER. & CARIB,9439,102,925,3922,0,735,3400,970,,20,1759,71
3,OCEANIA,11810,26,4542,9231,0,2003,1100,,593,0,2218,711
4,OCEANIA,13287,21,6327,14286,0,995,5000,,1430,0,2476,67
...,...,...,...,...,...,...,...,...,...,...,...,...,...
221,LATIN AMER. & CARIB,188078227,8511965,221,009,-003,2961,7600,864,2253,696,1656,617
222,ASIA (EX. NEAR EAST),245452739,1919440,1279,285,0,356,3200,879,520,1132,2034,625
223,NORTHERN AMERICA,298444215,9631420,310,021,341,65,37800,970,8980,1913,1414,826
224,ASIA (EX. NEAR EAST),1095351995,3287590,3332,021,-007,5629,2900,595,454,544,2201,818


In [6]:
# Preview the first rows via the notebook's rich display; print() renders the
# wide frame as wrapped plain-text chunks that are hard to read.
data.head()

                                region  population  area_(sq._mi.)  \
0  NORTHERN AMERICA                           7026             242   
1  SUB-SAHARAN AFRICA                         7502             413   
2              LATIN AMER. & CARIB            9439             102   
3  OCEANIA                                   11810              26   
4  OCEANIA                                   13287              21   

  pop._density_(per_sq._mi.) coastline_(coast/area_ratio) net_migration  \
0                       29,0                        49,59         -4,86   
1                       18,2                        14,53             0   
2                       92,5                        39,22             0   
3                      454,2                        92,31             0   
4                      632,7                       142,86             0   

  infant_mortality_(per_1000_births)  gdp_($_per_capita) literacy_(%)  \
0                               7,54                690

In [7]:
# The CSV writes decimals with a comma (e.g. "29,0", visible in the preview above),
# so those columns were read as strings. Swap the decimal comma for a dot and convert
# to real numbers so the model and metrics receive numeric dtypes rather than text.
# Only object-dtype columns other than the categorical 'region' are touched; columns
# pandas already parsed as numbers are left as-is. (Equivalently, pass decimal=','
# to pd.read_csv.) pd.to_numeric raises on any value that still is not a number,
# so malformed data surfaces here instead of silently reaching the model.
comma_decimal_cols = [
    col for col in data.columns
    if col != 'region' and data[col].dtype == object
]
data = data.assign(**{
    col: pd.to_numeric(data[col].str.replace(',', '.', regex=False))
    for col in comma_decimal_cols
})
data

Unnamed: 0,region,population,area_(sq._mi.),pop._density_(per_sq._mi.),coastline_(coast/area_ratio),net_migration,infant_mortality_(per_1000_births),gdp_($_per_capita),literacy_(%),phones_(per_1000),arable_(%),birthrate,deathrate
0,NORTHERN AMERICA,7026,242,29.0,49.59,-4.86,7.54,6900,99.0,683.2,13.04,13.52,6.83
1,SUB-SAHARAN AFRICA,7502,413,18.2,14.53,0,19,2500,97.0,293.3,12.9,12.13,6.53
2,LATIN AMER. & CARIB,9439,102,92.5,39.22,0,7.35,3400,97.0,,20,17.59,7.1
3,OCEANIA,11810,26,454.2,92.31,0,20.03,1100,,59.3,0,22.18,7.11
4,OCEANIA,13287,21,632.7,142.86,0,9.95,5000,,143.0,0,24.76,6.7
...,...,...,...,...,...,...,...,...,...,...,...,...,...
221,LATIN AMER. & CARIB,188078227,8511965,22.1,0.09,-0.03,29.61,7600,86.4,225.3,6.96,16.56,6.17
222,ASIA (EX. NEAR EAST),245452739,1919440,127.9,2.85,0,35.6,3200,87.9,52.0,11.32,20.34,6.25
223,NORTHERN AMERICA,298444215,9631420,31.0,0.21,3.41,6.5,37800,97.0,898.0,19.13,14.14,8.26
224,ASIA (EX. NEAR EAST),1095351995,3287590,333.2,0.21,-0.07,56.29,2900,59.5,45.4,54.4,22.01,8.18


In [9]:
# Target is GDP per capita; every remaining column (including the categorical
# 'region') is used as a feature.
target_col = 'gdp_($_per_capita)'
y = data[target_col]
X = data.drop(columns=[target_col])

# Hold out 20% of rows for evaluation, with a fixed seed for reproducibility.
x_train, x_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)


In [11]:
# 'region' holds raw strings; declaring it categorical lets CatBoost encode it itself.
cat_cols = ['region']
train_pool = Pool(x_train, y_train, cat_features=cat_cols)
test_pool = Pool(x_test, y_test, cat_features=cat_cols)

model = CatBoostRegressor(
    iterations=2000,
    learning_rate=0.025,
    random_state=42,
    verbose=100,  # log the training loss every 100 iterations
)
model.fit(train_pool)

# Predictions for the held-out rows, scored in the next cell.
y_pred = model.predict(test_pool)

0:	learn: 9850.5420663	total: 226ms	remaining: 7m 32s
100:	learn: 3652.3932476	total: 10s	remaining: 3m 8s
200:	learn: 2406.6675029	total: 20.1s	remaining: 3m
300:	learn: 1802.0198223	total: 30.3s	remaining: 2m 51s
400:	learn: 1394.1681182	total: 40.5s	remaining: 2m 41s
500:	learn: 1050.3166969	total: 50.7s	remaining: 2m 31s
600:	learn: 835.3112314	total: 1m	remaining: 2m 21s
700:	learn: 661.9526605	total: 1m 11s	remaining: 2m 11s
800:	learn: 532.4885564	total: 1m 21s	remaining: 2m 1s
900:	learn: 430.1334151	total: 1m 31s	remaining: 1m 51s
1000:	learn: 351.7193749	total: 1m 41s	remaining: 1m 41s
1100:	learn: 293.8739455	total: 1m 52s	remaining: 1m 31s
1200:	learn: 241.4350884	total: 2m 2s	remaining: 1m 21s
1300:	learn: 201.4382441	total: 2m 12s	remaining: 1m 11s
1400:	learn: 169.6043729	total: 2m 22s	remaining: 1m 1s
1500:	learn: 145.1210559	total: 2m 33s	remaining: 50.9s
1600:	learn: 124.5987355	total: 2m 43s	remaining: 40.7s
1700:	learn: 106.2852413	total: 2m 53s	remaining: 30.5s
180

In [12]:
# Score the held-out predictions. MSE is in squared target units, which is hard to
# read, so also report RMSE, which is in the target's own units (dollars per capita,
# per the column name).
r2 = r2_score(y_test, y_pred)
mse = mean_squared_error(y_test, y_pred)
rmse = np.sqrt(mse)

print(f"R2: {r2:.4f}, MSE: {mse:,.0f}, RMSE: {rmse:,.0f}")

R2: 0.7743115875565314, MSE : 22719341.0715578


In [13]:
# Final fit on every row (train + test) once the held-out score is known.
# NOTE: this refits `model` in place, so the R2/MSE reported earlier describe the
# train-only fit, not the model object after this cell.
full_pool = Pool(data=X, label=y, cat_features=['region'])
model.fit(full_pool)

0:	learn: 9874.8566958	total: 86.6ms	remaining: 2m 53s
100:	learn: 3690.8284857	total: 10.2s	remaining: 3m 11s
200:	learn: 2484.1206642	total: 20.5s	remaining: 3m 3s
300:	learn: 1902.2630972	total: 30.9s	remaining: 2m 54s
400:	learn: 1492.1672044	total: 41.3s	remaining: 2m 44s
500:	learn: 1151.2924616	total: 51.7s	remaining: 2m 34s
600:	learn: 951.7928894	total: 1m 1s	remaining: 2m 24s
700:	learn: 806.8990943	total: 1m 12s	remaining: 2m 13s
800:	learn: 672.3066419	total: 1m 22s	remaining: 2m 3s
900:	learn: 558.9786037	total: 1m 32s	remaining: 1m 53s
1000:	learn: 464.6079213	total: 1m 42s	remaining: 1m 42s
1100:	learn: 395.2091468	total: 1m 53s	remaining: 1m 32s
1200:	learn: 338.9382925	total: 2m 3s	remaining: 1m 22s
1300:	learn: 288.4771348	total: 2m 13s	remaining: 1m 11s
1400:	learn: 248.4014890	total: 2m 23s	remaining: 1m 1s
1500:	learn: 213.7462536	total: 2m 34s	remaining: 51.3s
1600:	learn: 184.9955834	total: 2m 44s	remaining: 41s
1700:	learn: 157.6275133	total: 2m 54s	remaining: 3

<catboost.core.CatBoostRegressor at 0x19b98a57770>

In [14]:
# Persist the current model to disk in CatBoost's native "cbm" format (the default, made explicit).
model.save_model("GDP_model.cbm", format="cbm")