In [1]:
import os
import sys

import numpy as np
import pandas as pd
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error

# Project modules are imported from ./src, so extend the search path first.
sys.path.append('src')

from data.data_loader import DataLoader
from features.feature_engineering import FeatureEngineering
from models.model_evaluator import ModelEvaluator

## 1. 数据加载

加载Excel数据并划分训练集和测试集。

In [2]:
# Resolve the location of the training workbook.
# The MSW_TRAINING_DATA_PATH environment variable takes precedence so the notebook
# can be run on another machine without editing this cell; when it is unset or empty
# the original author-local path is used, so the default behaviour is unchanged.
file_path = os.environ.get("MSW_TRAINING_DATA_PATH") or (
    r"E:\博士\1-课题\0-固废产生的研究\2-数据整理结果\1-MSW_CW_IW_HIW_training_data.xlsx"
)

# Initialise the data loader with the resolved path
data_loader = DataLoader(file_path)

# Load the data, passing the sheet name, the target column and the feature columns to use
df = data_loader.load_data(
    sheet_name='msw_result',
    target_column='MSW',
    feature_columns=['Population', 'GDP PPP 2017', 'GDP PPP/capita 2017']
)

In [3]:
# Split off the external test set by country (test_size=0.15, fixed random_state)
train_data, country_test_data = data_loader.split_data_by_countries(
    df, test_size=0.15, random_state=123
)
# Split off the external test set by time from the remaining training data
train_data, time_test_data = data_loader.split_data_by_time(train_data)

# Data check: report on the full frame first, then on each split
for dataset in (df, train_data, country_test_data, time_test_data):
    data_loader.analyze_datasets(dataset)

# Write each split to CSV without the index column
for csv_path, split in (
    ('src/data/train_data.csv', train_data),
    ('src/data/country_test_data.csv', country_test_data),
    ('src/data/time_test_data.csv', time_test_data),
):
    split.to_csv(csv_path, index=False)



数据集统计信息:
总数据条数: 1982
国家总数: 69
包含的国家: Albania, Algeria, Argentina, Australia, Austria, Bangladesh, Belgium, Bosnia and Herzegovina, Brazil, Bulgaria, Canada, China, Colombia, Congo, Dem. Rep., Croatia, Cyprus, Czechia, Denmark, Egypt, Arab Rep., Estonia, Ethiopia, Finland, France, Germany, Greece, Hungary, Iceland, India, Indonesia, Iran, Islamic Rep., Iraq, Ireland, Italy, Japan, Korea, Rep., Latvia, Lithuania, Luxembourg, Malaysia, Malta, Mexico, Montenegro, Morocco, Netherlands, Nigeria, North Macedonia, Norway, Pakistan, Peru, Philippines, Poland, Portugal, Romania, Russian Federation, Saudi Arabia, Serbia, Slovak Republic, Slovenia, South Africa, Spain, Sweden, Switzerland, Tanzania, Thailand, Turkiye, Uganda, Ukraine, United Kingdom, United States

数据集统计信息:
总数据条数: 1476
国家总数: 59
包含的国家: Albania, Argentina, Australia, Austria, Bangladesh, Belgium, Bosnia and Herzegovina, Bulgaria, Canada, China, Colombia, Congo, Dem. Rep., Croatia, Cyprus, Denmark, Egypt, Arab Rep., Ethiopia, Finlan

## 2. 特征工程

创建时间特征和类别交互特征。

In [None]:
# Initialise the feature-engineering helper
feature_engineering = FeatureEngineering()

# Training split: fit_transform is called here; its second return value is kept as
# target_column and is used by later cells to select the target.
train_data_processed, target_column = feature_engineering.fit_transform(
    train_data, target_column='MSW', categorical_columns=['Region', 'Income Group']
)

# Country hold-out: transform only, second return value discarded
country_test_data_processed, _ = feature_engineering.transform(
    country_test_data, target_column='MSW'
)

# Time hold-out: transform only, second return value discarded
time_test_data_processed, _ = feature_engineering.transform(
    time_test_data, target_column='MSW'
)

## 3. 模型训练和评估

训练多个模型并评估它们的性能。

In [5]:
# Initialise the model evaluator
model_evaluator = ModelEvaluator()

# Columns passed on to the experiment: every processed column except the three below.
# NOTE(review): the setup output reports the target as 'MSW_None', so the processed
# target column presumably remains in feature_cols — confirm this is intended.
non_feature_cols = {'MSW', 'Country Name', 'Year'}
feature_cols = [
    col for col in train_data_processed.columns if col not in non_feature_cols
]

train = train_data_processed[feature_cols]

In [6]:
# Configure the experiment on the training frame with train_size=0.8.
# NOTE(review): no seed is passed here and the logged session id (866) looks arbitrary,
# so results may differ between runs — confirm whether setup_experiment accepts a seed.
model_evaluator.setup_experiment(
    train_data=train,
    target_column=target_column,
    categorical_features=['Region', 'Income Group'],
    train_size=0.8,
)

Unnamed: 0,Description,Value
0,Session id,866
1,Target,MSW_None
2,Target type,Regression
3,Original data shape,"(1476, 23)"
4,Transformed data shape,"(1476, 32)"
5,Transformed train set shape,"(1180, 32)"
6,Transformed test set shape,"(296, 32)"
7,Numeric features,20
8,Categorical features,2
9,Preprocess,True


In [7]:
# Train the models via the evaluator, requesting n_models=3; the result is kept in `models`
models = model_evaluator.train_top_models(n_models=3)

Unnamed: 0,Model,MAE,MSE,RMSE,R2,RMSLE,MAPE,TT (Sec)
et,Extra Trees Regressor,10433600.5889,812123826342702.8,20095664.4396,0.6668,0.8075,1.4835,0.034
gbr,Gradient Boosting Regressor,10210872.5063,768041760913811.0,20379628.3067,0.6182,0.6997,0.8086,0.042
catboost,CatBoost Regressor,10331231.2455,1027153043303803.6,22623511.5519,0.6114,0.6846,0.8071,0.33
ada,AdaBoost Regressor,11409574.533,750101538061141.8,20736692.3738,0.5955,1.1594,3.0127,0.022
rf,Random Forest Regressor,11162658.9722,796650210101648.1,21542654.3273,0.5353,0.7209,0.9613,0.046
dt,Decision Tree Regressor,11570815.0209,804359922549666.1,22195458.3902,0.4783,0.8343,1.3692,0.014
lightgbm,Light Gradient Boosting Machine,11077607.4802,888794293528288.1,23701961.588,0.4638,0.752,1.3391,0.086
xgboost,Extreme Gradient Boosting,12158289.65,1015842111855001.6,25139749.5,0.4305,0.8178,1.3229,0.058
ridge,Ridge Regression,13840612.2218,807854953513190.5,24402466.6836,0.3082,1.2726,3.8683,0.422
omp,Orthogonal Matching Pursuit,14691545.996,875536813651081.5,24689124.626,0.2357,1.5539,4.9995,0.354


In [8]:
# Prepare the country hold-out data.
# X_test is selected with the same feature_cols list that was used for the training frame.
X_test = country_test_data_processed[feature_cols]
y_test = country_test_data_processed[target_column]

# Evaluate all model combinations on the country hold-out
results_country = model_evaluator.evaluate_ensemble_combinations(X_test, y_test)

# Show the summary table as the cell result, mirroring the time hold-out cell, so the
# country metrics can be read without scanning the intermediate per-model output.
pd.DataFrame(results_country).T

Unnamed: 0,Model,MAE,MSE,RMSE,R2,RMSLE,MAPE
0,Extra Trees Regressor,5062041.919,64253999093205.805,8015859.2236,0.8118,0.4277,0.438


Unnamed: 0,Model,MAE,MSE,RMSE,R2,RMSLE,MAPE
0,Gradient Boosting Regressor,3442371.3735,33468606910407.176,5785205.8659,0.902,0.3274,0.3162


Unnamed: 0,Model,MAE,MSE,RMSE,R2,RMSLE,MAPE
0,CatBoost Regressor,4613232.0859,60109691459546.97,7753044.0125,0.8239,0.4574,0.4627


Unnamed: 0_level_0,MAE,MSE,RMSE,R2,RMSLE,MAPE
Fold,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
0,21300693.8622,2670682209187896.0,51678643.6469,-0.0485,1.1743,1.089
1,15356214.0066,984396030609005.0,31375086.145,0.7477,0.6219,0.6507
2,2898821.0493,18997611881735.117,4358624.999,0.9454,0.2403,0.1867
3,2807112.5457,15523403636233.367,3939975.0807,0.8903,1.2025,3.2951
4,3136282.9548,21591925295527.555,4646711.2344,0.8656,0.32,0.2802
Mean,9099824.8837,742238236122079.4,19199808.2212,0.6801,0.7118,1.1003
Std,7766828.4151,1034220393175367.2,19328879.957,0.37,0.4095,1.1425


Unnamed: 0,Model,MAE,MSE,RMSE,R2,RMSLE,MAPE
0,Voting Regressor,4200865.3757,46641206781928.26,6829436.7837,0.8634,0.3686,0.3612


Unnamed: 0_level_0,MAE,MSE,RMSE,R2,RMSLE,MAPE
Fold,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
0,21428492.3678,2667034886403157.0,51643343.0986,-0.047,1.2226,1.3373
1,21987818.129,1779550303000219.0,42184716.4622,0.5438,0.6137,0.4758
2,2469013.443,13146756853574.889,3625845.6743,0.9622,0.2138,0.1672
3,2830992.7074,17827657759569.605,4222281.1085,0.8741,1.2195,3.29
4,2097440.1453,8581268030250.1455,2929380.1444,0.9466,0.2959,0.2398
Mean,10162751.3585,897228174409354.1,20921113.2976,0.6559,0.7131,1.102
Std,9431295.3258,1118512108310672.6,21436771.9771,0.3828,0.4357,1.1707


Unnamed: 0,Model,MAE,MSE,RMSE,R2,RMSLE,MAPE
0,Voting Regressor,4825501.4476,61415154757088.0,7836782.1685,0.8201,0.4347,0.4447


Unnamed: 0_level_0,MAE,MSE,RMSE,R2,RMSLE,MAPE
Fold,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
0,21438915.1254,2674848260035854.5,51718935.2175,-0.0501,1.1914,1.0955
1,17694410.2459,1533732465517348.5,39162896.5415,0.6069,0.4991,0.4225
2,3039412.5922,18611533926425.363,4314108.7059,0.9465,0.2758,0.221
3,3058804.076,20696674227907.848,4549359.7602,0.8538,1.0011,1.7582
4,2664809.0358,15845881552538.115,3980688.5777,0.9014,0.279,0.2555
Mean,9579270.2151,852746963052014.8,20745197.7606,0.6517,0.6493,0.7505
Std,8241391.4471,1083723895511025.2,20551976.3752,0.37,0.3786,0.5946


Unnamed: 0,Model,MAE,MSE,RMSE,R2,RMSLE,MAPE
0,Voting Regressor,4003858.4855,44946312599332.805,6704201.1157,0.8684,0.3831,0.378


Unnamed: 0_level_0,MAE,MSE,RMSE,R2,RMSLE,MAPE
Fold,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
0,21329396.321,2670414231704124.5,51676050.8524,-0.0484,1.191,1.166
1,17225506.3497,1375653136065606.2,37089798.2748,0.6474,0.5372,0.4805
2,2644056.2902,14911873529986.6,3861589.5082,0.9571,0.2395,0.1877
3,2779162.8292,16103005400747.879,4012855.0187,0.8862,1.1589,2.7715
4,2560008.4284,12659344284799.383,3557997.2295,0.9212,0.2886,0.2543
Mean,9307626.0437,817948318197053.0,20039658.1767,0.6727,0.683,0.972
Std,8243423.1534,1065737152697696.0,20404911.6234,0.3765,0.4143,0.9641


Unnamed: 0,Model,MAE,MSE,RMSE,R2,RMSLE,MAPE
0,Voting Regressor,4333235.0996,50464266565581.805,7103820.5612,0.8522,0.3921,0.391


In [9]:
# Prepare the time hold-out data: target column, then the input columns
y_test = time_test_data_processed[target_column]
X_test = time_test_data_processed[feature_cols]

# Evaluate all model combinations, then display the metrics with one row per model
results_time = model_evaluator.evaluate_ensemble_combinations(X_test, y_test)
pd.DataFrame(results_time).transpose()

Unnamed: 0,Model,MAE,MSE,RMSE,R2,RMSLE,MAPE
0,Extra Trees Regressor,2224323.1243,30522534510792.87,5524720.3106,0.9908,0.1419,0.0898


Unnamed: 0,Model,MAE,MSE,RMSE,R2,RMSLE,MAPE
0,Gradient Boosting Regressor,3821791.0971,68924402489883.34,8302072.1805,0.9791,0.3781,0.221


Unnamed: 0,Model,MAE,MSE,RMSE,R2,RMSLE,MAPE
0,CatBoost Regressor,4403129.68,111158503420705.56,10543173.3089,0.9664,0.5013,0.5022


Unnamed: 0_level_0,MAE,MSE,RMSE,R2,RMSLE,MAPE
Fold,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
0,21300693.8622,2670682209187896.0,51678643.6469,-0.0485,1.1743,1.089
1,15356214.0066,984396030609005.0,31375086.145,0.7477,0.6219,0.6507
2,2898821.0493,18997611881735.117,4358624.999,0.9454,0.2403,0.1867
3,2807112.5457,15523403636233.367,3939975.0807,0.8903,1.2025,3.2951
4,3136282.9548,21591925295527.555,4646711.2344,0.8656,0.32,0.2802
Mean,9099824.8837,742238236122079.4,19199808.2212,0.6801,0.7118,1.1003
Std,7766828.4151,1034220393175367.2,19328879.957,0.37,0.4095,1.1425


Unnamed: 0,Model,MAE,MSE,RMSE,R2,RMSLE,MAPE
0,Voting Regressor,2885849.5942,35488030839124.69,5957183.1296,0.9893,0.2048,0.1485


Unnamed: 0_level_0,MAE,MSE,RMSE,R2,RMSLE,MAPE
Fold,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
0,21428492.3678,2667034886403157.0,51643343.0986,-0.047,1.2226,1.3373
1,21987818.129,1779550303000219.0,42184716.4622,0.5438,0.6137,0.4758
2,2469013.443,13146756853574.889,3625845.6743,0.9622,0.2138,0.1672
3,2830992.7074,17827657759569.605,4222281.1085,0.8741,1.2195,3.29
4,2097440.1453,8581268030250.1455,2929380.1444,0.9466,0.2959,0.2398
Mean,10162751.3585,897228174409354.1,20921113.2976,0.6559,0.7131,1.102
Std,9431295.3258,1118512108310672.6,21436771.9771,0.3828,0.4357,1.1707


Unnamed: 0,Model,MAE,MSE,RMSE,R2,RMSLE,MAPE
0,Voting Regressor,3183994.6358,58795346774060.65,7667812.3852,0.9822,0.3486,0.2821


Unnamed: 0_level_0,MAE,MSE,RMSE,R2,RMSLE,MAPE
Fold,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
0,21438915.1254,2674848260035854.5,51718935.2175,-0.0501,1.1914,1.0955
1,17694410.2459,1533732465517348.5,39162896.5415,0.6069,0.4991,0.4225
2,3039412.5922,18611533926425.363,4314108.7059,0.9465,0.2758,0.221
3,3058804.076,20696674227907.848,4549359.7602,0.8538,1.0011,1.7582
4,2664809.0358,15845881552538.115,3980688.5777,0.9014,0.279,0.2555
Mean,9579270.2151,852746963052014.8,20745197.7606,0.6517,0.6493,0.7505
Std,8241391.4471,1083723895511025.2,20551976.3752,0.37,0.3786,0.5946


Unnamed: 0,Model,MAE,MSE,RMSE,R2,RMSLE,MAPE
0,Voting Regressor,3982528.5093,79123608312584.62,8895145.2103,0.9761,0.37,0.324


Unnamed: 0_level_0,MAE,MSE,RMSE,R2,RMSLE,MAPE
Fold,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
0,21329396.321,2670414231704124.5,51676050.8524,-0.0484,1.191,1.166
1,17225506.3497,1375653136065606.2,37089798.2748,0.6474,0.5372,0.4805
2,2644056.2902,14911873529986.6,3861589.5082,0.9571,0.2395,0.1877
3,2779162.8292,16103005400747.879,4012855.0187,0.8862,1.1589,2.7715
4,2560008.4284,12659344284799.383,3557997.2295,0.9212,0.2886,0.2543
Mean,9307626.0437,817948318197053.0,20039658.1767,0.6727,0.683,0.972
Std,8243423.1534,1065737152697696.0,20404911.6234,0.3765,0.4143,0.9641


Unnamed: 0,Model,MAE,MSE,RMSE,R2,RMSLE,MAPE
0,Voting Regressor,3315148.9165,53669167031299.77,7325924.3124,0.9838,0.2953,0.2383


Unnamed: 0,r2,rmse,mae,mape
model_1,0.990761,5524720.0,2224323.0,0.089774
model_2,0.979138,8302072.0,3821791.0,0.220962
model_3,0.966354,10543170.0,4403130.0,0.502236
model_1_model_2_blend,0.989258,5957183.0,2885850.0,0.148481
model_1_model_3_blend,0.982204,7667812.0,3183995.0,0.282141
model_2_model_3_blend,0.976051,8895145.0,3982528.0,0.323972
model_1_model_2_model_3_blend,0.983755,7325924.0,3315149.0,0.238342


In [10]:
# Convert both result dicts to DataFrames before concatenating them, then average the
# rows that share an index label across the two hold-out sets and transpose for display.
(
    pd.concat([pd.DataFrame(results_country), pd.DataFrame(results_time)])
    .groupby(level=0)
    .mean()
    .T
)

Unnamed: 0,mae,mape,r2,rmse
model_1,3643182.0,0.263878,0.901286,6770290.0
model_2,3632081.0,0.268583,0.940557,7043639.0
model_3,4508181.0,0.482486,0.895152,9148109.0
model_1_model_2_blend,3543357.0,0.254861,0.926327,6393310.0
model_1_model_3_blend,4004748.0,0.363445,0.901165,7752297.0
model_2_model_3_blend,3993193.0,0.351008,0.922205,7799673.0
model_1_model_2_model_3_blend,3824192.0,0.314683,0.917977,7214872.0


In [11]:
# Save the selected models.
#   model_names: replace with the list of model names you want to save
#   save_dir:    the save path can be customised
model_evaluator.save_selected_models(model_names=['model_1', 'model_2', 'model_3'], save_dir='src/models')


Transformation Pipeline and Model Successfully Saved
模型 model_1 已保存到 src/models\model_1
Transformation Pipeline and Model Successfully Saved
模型 model_2 已保存到 src/models\model_2
Transformation Pipeline and Model Successfully Saved
模型 model_3 已保存到 src/models\model_3
