# 1. 데이터 불러오기 

In [1]:
import pandas as pd

train = pd.read_csv('./data/train.csv')
test = pd.read_csv('./data/test.csv')
submission = pd.read_csv('./data/submission.csv')

In [2]:
train.columns

Index(['PassengerId', 'Survived', 'Pclass', 'Name', 'Sex', 'Age', 'SibSp',
       'Parch', 'Ticket', 'Fare', 'Cabin', 'Embarked'],
      dtype='object')

In [3]:
# 7개의 독립변수만 사용
columns = ['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked']
train = train[columns + ['Survived']]
test = test[columns]

# 2. 전처리

### 2-1. 전처리

In [4]:
train.isnull().sum()

Pclass        0
Sex           0
Age         177
SibSp         0
Parch         0
Fare          0
Embarked      2
Survived      0
dtype: int64

In [5]:
test.isnull().sum()

Pclass       0
Sex          0
Age         86
SibSp        0
Parch        0
Fare         1
Embarked     0
dtype: int64

In [6]:
# Null 처리
mean_age = train['Age'].mean()
mean_fare = train['Fare'].mean()

train['Age'] = train['Age'].fillna(mean_age)
test.loc[:,'Age'] = test['Age'].fillna(mean_age)
train['Fare'] = train['Fare'].fillna(mean_fare)
test['Fare'] = test['Fare'].fillna(mean_fare)

### 2-2. Object 데이터 수치화

In [7]:
train.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 8 columns):
 #   Column    Non-Null Count  Dtype  
---  ------    --------------  -----  
 0   Pclass    891 non-null    int64  
 1   Sex       891 non-null    object 
 2   Age       891 non-null    float64
 3   SibSp     891 non-null    int64  
 4   Parch     891 non-null    int64  
 5   Fare      891 non-null    float64
 6   Embarked  889 non-null    object 
 7   Survived  891 non-null    int64  
dtypes: float64(2), int64(4), object(2)
memory usage: 55.8+ KB


In [8]:
train = pd.get_dummies(train,columns=['Sex','Embarked','Pclass'])
test = pd.get_dummies(test,columns=['Sex','Embarked','Pclass'])
train.head()

Unnamed: 0,Age,SibSp,Parch,Fare,Survived,Sex_female,Sex_male,Embarked_C,Embarked_Q,Embarked_S,Pclass_1,Pclass_2,Pclass_3
0,22.0,1,0,7.25,0,False,True,False,False,True,False,False,True
1,38.0,1,0,71.2833,1,True,False,True,False,False,True,False,False
2,26.0,0,0,7.925,1,True,False,False,False,True,False,False,True
3,35.0,1,0,53.1,1,True,False,False,False,True,True,False,False
4,35.0,0,0,8.05,0,False,True,False,False,True,False,False,True


### 2-3. 데이터 분할

In [9]:
train_x = train.drop(columns='Survived')
train_y = train['Survived']

from sklearn.model_selection import train_test_split
train_x, val_x, train_y, val_y  = train_test_split(train_x, train_y, test_size=0.2, random_state=0)

# 3. 데이터 불균형 해소를 위한 데이터 증강 (SMOTE)

SMOTE는 불균형한 데이터에서 소수 클래스의 샘플을 증강시켜 데이터를 균형있게 만드는 오버 샘플링 방법 중 하나이다. 소수 클래스 샘플과 가장 가까운 이웃들을 찾아 새로운 데이터를 생성한다.

In [10]:
from imblearn.over_sampling import SMOTE

smote = SMOTE(random_state=0)
X_resampled, y_resampled = smote.fit_resample(train_x,list(train_y))

X_resampled['Survived'] = y_resampled
train_dataset = X_resampled

# 4. 모델 학습

In [11]:
from sklearn.tree import DecisionTreeClassifier

model = DecisionTreeClassifier(max_depth=6, random_state=0)
model.fit(train_dataset.drop(columns='Survived'),train_dataset['Survived'])

모델 평가

In [12]:
val_pred = model.predict(val_x)

from sklearn.metrics import confusion_matrix
print(confusion_matrix(val_y, val_pred))

from sklearn.metrics import classification_report
print(classification_report(val_y, val_pred))

[[94 16]
 [18 51]]
              precision    recall  f1-score   support

           0       0.84      0.85      0.85       110
           1       0.76      0.74      0.75        69

    accuracy                           0.81       179
   macro avg       0.80      0.80      0.80       179
weighted avg       0.81      0.81      0.81       179



In [13]:
# 0과 1로 분류하지 않은 결과
from sklearn.metrics import roc_auc_score
print(roc_auc_score(val_y,val_pred))

0.7968379446640316


예측

In [14]:
y_pred = model.predict(test)  

sub = submission.copy()
sub['Survived'] = y_pred
sub.head()

Unnamed: 0,PassengerId,Survived
0,892,0
1,893,1
2,894,0
3,895,0
4,896,1


In [15]:
sub.to_csv('./sub/sub_smote_DecisionTree.csv', index=False)
# 0.7742210321