### 预测数据处理

* 将数据的列调整为与 preprocessed.csv 一致（删除其中不存在的多余列）
* 填补空缺项（分类列填missing，数值列填nan）
* 剩余的分类列做one-hot编码
* 对数据进行标准化处理

In [2]:
import numpy as np
import pandas as pd
import joblib
import pickle
from xgboost import XGBClassifier

In [10]:
# Load the fitted preprocessing artifacts (imputer, scaler, and the list of
# columns the scaler is applied to).
# NOTE: joblib.load / pickle.load can execute arbitrary code while
# deserializing; only load artifact files that come from a trusted source.
imputer = joblib.load('./model/imputer.joblib')
scaler = joblib.load('./model/scaler.joblib')
with open('./model/NumericCols.pkl', 'rb') as f:
    NumericCols = pickle.load(f)

# Read the raw rows to predict on.
input_data = pd.read_excel('./predict/predict_data.xlsx')
print(input_data)
# Only the header row is read (nrows=0): it lists the columns to keep.
preprocessed_data = pd.read_csv('./predict/preprocessed.csv', nrows=0)

# Drop every column that is absent from the preprocessed header in a single
# call instead of rebuilding the DataFrame once per dropped column.
extra_cols = [
    col for col in input_data.columns
    if col not in preprocessed_data.columns
]
input_data = input_data.drop(columns=extra_cols)

# Add the habitability column filled with NaN, then index the rows by rowid.
input_data['habitable'] = np.nan
input_data = input_data.set_index("rowid")

# Fill missing values in the numeric columns with the fitted imputer.
# NOTE(review): _get_numeric_data() is a private pandas API; it is kept here
# so the selected column set stays exactly as before. The selection must match
# the columns the imputer was fitted on -- confirm against the training code.
numeric_cols = input_data._get_numeric_data().columns
input_data[numeric_cols] = imputer.transform(input_data[numeric_cols])

# Standardize the columns listed in NumericCols with the fitted scaler.
input_data[NumericCols] = scaler.transform(input_data[NumericCols])

# to_csv opens the target in write mode and overwrites any existing file, so
# no separate truncation step is needed beforehand.
input_data.to_csv('./predict/predict_data_clearify.csv')

   rowid pl_hostname  pl_letter  pl_name  pl_discmethod  pl_controvflag  \
0      1        liqi        NaN      NaN            NaN             NaN   

   pl_pnum  pl_orbper  pl_orbpererr1  pl_orbpererr2  ...  st_bmy  st_bmyerr  \
0      NaN        NaN            NaN            NaN  ...     NaN        NaN   

   st_bmylim  st_m1  st_m1err  st_m1lim  st_c1  st_c1err  st_c1lim  st_colorn  
0        NaN    NaN       NaN       NaN    NaN       NaN       NaN        NaN  

[1 rows x 356 columns]
Index(['rowid', 'pl_pnum', 'pl_orbper', 'pl_orbpererr1', 'pl_orbpererr2',
       'pl_orbperlim', 'pl_orbsmaxlim', 'pl_radj', 'pl_radjerr1',
       'pl_radjerr2', 'pl_radjlim', 'ra', 'dec', 'st_dist', 'st_disterr1',
       'st_disterr2', 'st_optmag', 'gaia_gmag', 'st_teff', 'st_tefferr1',
       'st_tefferr2', 'st_mass', 'st_masserr1', 'st_rad', 'st_raderr1',
       'st_raderr2', 'pl_tranflag', 'pl_rvflag', 'pl_radelim', 'pl_radslim',
       'pl_trandur', 'pl_trandurerr1', 'pl_trandurerr2', 'pl_tranmid

### 预测

In [3]:
# Create the classifier and restore its trained state from disk.
model = XGBClassifier()
model.load_model('./model/xgb.model')

# Load the cleaned rows and discard the columns that are not passed to the
# model.
remove = ['rowid', 'habitable']
test = pd.read_csv('./predict/predict_data_clearify.csv').drop(columns=remove)

# Predict one row at a time; each row is reshaped to a single-sample 2-D
# array, and every prediction is printed on the same output line.
for row in test.to_numpy():
    test_x = row.reshape(1, -1)
    y_preds = model.predict(test_x)
    print(y_preds, end=' ')
print()

[0] 
