In [2]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV, KFold
from joblib import dump, load
import geopandas as gpd

import numpy as np

In [3]:
pipe = load('./models/cv_rf_2.joblib') 

In [4]:
train = gpd.read_file( "./data/train_data_final.geojson")
test =  gpd.read_file("./data/test_data_final.geojson")

In [5]:
X = train.loc[:,'NDVI_2000':'NDVI_2019']
y = train['label_0']
X_test = test.loc[:,'NDVI_2000':'NDVI_2019']

In [6]:
pipe

GridSearchCV(cv=KFold(n_splits=5, random_state=123, shuffle=True),
             estimator=Pipeline(steps=[('rf', RandomForestClassifier())]),
             n_jobs=16,
             param_grid={'rf__max_depth': [1, 10, 25, 30, 50, 75],
                         'rf__max_features': array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
       18, 19], dtype=int32),
                         'rf__n_estimators': [100, 200, 300, 400, 500]},
             scoring='f1_macro', verbose=1)

In [7]:
best_model = pipe.best_estimator_

In [8]:
print(classification_report(test.label_0,best_model.predict(X_test)))


              precision    recall  f1-score   support

           0       0.72      0.59      0.65       877
           1       0.87      0.90      0.89       900
           2       0.70      0.81      0.75       859

    accuracy                           0.77      2636
   macro avg       0.76      0.77      0.76      2636
weighted avg       0.77      0.77      0.76      2636



In [10]:
test['predictions'] = best_model.predict(X_test)

In [11]:
test

Unnamed: 0,id_0,label_0,NDVI_2000,NDVI_2001,NDVI_2002,NDVI_2004,NDVI_2005,NDVI_2006,NDVI_2007,NDVI_2008,...,NDVI_2012,NDVI_2013,NDVI_2014,NDVI_2015,NDVI_2016,NDVI_2017,NDVI_2018,NDVI_2019,geometry,predictions
0,0_0,1,0.000023,0.000036,0.000042,0.000045,0.000046,0.000045,0.000038,0.000044,...,0.000044,0.000044,0.000028,0.000025,0.000039,0.000037,0.000036,0.000040,POINT (-63.00671 -25.26624),1
1,1_0,1,0.000026,0.000042,0.000047,0.000048,0.000041,0.000045,0.000044,0.000050,...,0.000036,0.000043,0.000033,0.000029,0.000042,0.000036,0.000039,0.000038,POINT (-62.83379 -24.92263),1
2,2_0,1,0.000048,0.000045,0.000050,0.000056,0.000053,0.000053,0.000050,0.000052,...,0.000055,0.000048,0.000043,0.000040,0.000054,0.000051,0.000046,0.000037,POINT (-60.88444 -24.20174),1
3,3_0,1,0.000025,0.000038,0.000043,0.000044,0.000041,0.000050,0.000043,0.000045,...,0.000048,0.000044,0.000043,0.000026,0.000043,0.000043,0.000043,0.000040,POINT (-63.72985 -27.49855),1
4,4_0,1,0.000037,0.000039,0.000043,0.000040,0.000037,0.000035,0.000034,0.000034,...,0.000041,0.000040,0.000045,0.000037,0.000038,0.000038,0.000035,0.000041,POINT (-62.48569 -28.26212),2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2631,2695_0,2,0.000019,0.000030,0.000036,0.000034,0.000035,0.000034,0.000032,0.000031,...,0.000036,0.000034,0.000022,0.000026,0.000036,0.000036,0.000025,0.000032,POINT (-62.74395 -22.75545),0
2632,2696_0,2,0.000027,0.000024,0.000032,0.000033,0.000032,0.000033,0.000030,0.000031,...,0.000032,0.000031,0.000035,0.000027,0.000025,0.000034,0.000030,0.000034,POINT (-65.92174 -28.27560),2
2633,2697_0,2,0.000036,0.000036,0.000037,0.000036,0.000036,0.000038,0.000041,0.000041,...,0.000041,0.000041,0.000038,0.000038,0.000040,0.000049,0.000040,0.000037,POINT (-66.76841 -33.28820),0
2634,2698_0,2,0.000037,0.000037,0.000041,0.000031,0.000034,0.000029,0.000026,0.000024,...,0.000028,0.000025,0.000031,0.000035,0.000035,0.000032,0.000034,0.000037,POINT (-62.35319 -28.85950),2
