# Stellar Classification Dataset - SDSS17

## Qual o problema a ser resolvido?
Na astronomia, o esquema de classificação de galáxias, quasares e estrelas
é um dos mais fundamentais. O dataset Stellar Classification Dataset - SDSS17
contém dados espectrais de diferentes corpos estelares visando classificar estrelas,
galáxias e quasares com base nessas características.

### Descrição do dataset
Esse dataset consiste em 100.000 observações (linhas) do espaço feitas pelo
SDSS (Sloan Digital Sky Survey), onde cada observação é descrita por 17 colunas
(colunas) de recursos e 1 coluna (label) de classe que a identifica as observações
como uma estrela, galáxia ou quasar.

O resultado que desejamos alcançar é a classificação em uma das três
categorias (estrela, galáxia e quasar) com base nos dados analisados.

### Tipo do problema
O tipo do nosso problema envolve aprendizagem supervisionada, mais
especificamente, um problema de classificação.

## Análise exploratória dos dados

### Importações

In [1]:
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import preprocessing

### Carregando os dados

In [2]:
dados = pd.read_csv('./data/star_classification.csv')
dados.shape

(100000, 18)

### Análise dos tipos de dados

In [3]:
dados.dtypes

obj_ID         float64
alpha          float64
delta          float64
u              float64
g              float64
r              float64
i              float64
z              float64
run_ID           int64
rerun_ID         int64
cam_col          int64
field_ID         int64
spec_obj_ID    float64
class           object
redshift       float64
plate            int64
MJD              int64
fiber_ID         int64
dtype: object

### Análise de valores faltantes

In [4]:
dados.isnull().sum()

obj_ID         0
alpha          0
delta          0
u              0
g              0
r              0
i              0
z              0
run_ID         0
rerun_ID       0
cam_col        0
field_ID       0
spec_obj_ID    0
class          0
redshift       0
plate          0
MJD            0
fiber_ID       0
dtype: int64

### Descrição dos dados

In [5]:
dados.head()

Unnamed: 0,obj_ID,alpha,delta,u,g,r,i,z,run_ID,rerun_ID,cam_col,field_ID,spec_obj_ID,class,redshift,plate,MJD,fiber_ID
0,1.237661e+18,135.689107,32.494632,23.87882,22.2753,20.39501,19.16573,18.79371,3606,301,2,79,6.543777e+18,GALAXY,0.634794,5812,56354,171
1,1.237665e+18,144.826101,31.274185,24.77759,22.83188,22.58444,21.16812,21.61427,4518,301,5,119,1.176014e+19,GALAXY,0.779136,10445,58158,427
2,1.237661e+18,142.18879,35.582444,25.26307,22.66389,20.60976,19.34857,18.94827,3606,301,2,120,5.1522e+18,GALAXY,0.644195,4576,55592,299
3,1.237663e+18,338.741038,-0.402828,22.13682,23.77656,21.61162,20.50454,19.2501,4192,301,3,214,1.030107e+19,GALAXY,0.932346,9149,58039,775
4,1.23768e+18,345.282593,21.183866,19.43718,17.58028,16.49747,15.97711,15.54461,8102,301,3,137,6.891865e+18,GALAXY,0.116123,6121,56187,842


In [6]:
dados.describe()

Unnamed: 0,obj_ID,alpha,delta,u,g,r,i,z,run_ID,rerun_ID,cam_col,field_ID,spec_obj_ID,redshift,plate,MJD,fiber_ID
count,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0,100000.0
mean,1.237665e+18,177.629117,24.135305,21.980468,20.531387,19.645762,19.084854,18.66881,4481.36606,301.0,3.51161,186.13052,5.783882e+18,0.576661,5137.00966,55588.6475,449.31274
std,8438560000000.0,96.502241,19.644665,31.769291,31.750292,1.85476,1.757895,31.728152,1964.764593,0.0,1.586912,149.011073,3.324016e+18,0.730707,2952.303351,1808.484233,272.498404
min,1.237646e+18,0.005528,-18.785328,-9999.0,-9999.0,9.82207,9.469903,-9999.0,109.0,301.0,1.0,11.0,2.995191e+17,-0.009971,266.0,51608.0,1.0
25%,1.237659e+18,127.518222,5.146771,20.352353,18.96523,18.135828,17.732285,17.460677,3187.0,301.0,2.0,82.0,2.844138e+18,0.054517,2526.0,54234.0,221.0
50%,1.237663e+18,180.9007,23.645922,22.179135,21.099835,20.12529,19.405145,19.004595,4188.0,301.0,4.0,146.0,5.614883e+18,0.424173,4987.0,55868.5,433.0
75%,1.237668e+18,233.895005,39.90155,23.68744,22.123767,21.044785,20.396495,19.92112,5326.0,301.0,5.0,241.0,8.332144e+18,0.704154,7400.25,56777.0,645.0
max,1.237681e+18,359.99981,83.000519,32.78139,31.60224,29.57186,32.14147,29.38374,8162.0,301.0,6.0,989.0,1.412694e+19,7.011245,12547.0,58932.0,1000.0


### Correlação

#### Descrição da coluna 'class'

In [7]:
dados['class'].value_counts()

class
GALAXY    59445
STAR      21594
QSO       18961
Name: count, dtype: int64

In [8]:
labelEncoder = preprocessing.LabelEncoder()
dados['class'] = labelEncoder.fit_transform(dados['class'])
dados['class'].value_counts()

class
0    59445
2    21594
1    18961
Name: count, dtype: int64

In [9]:
dados.corr()

Unnamed: 0,obj_ID,alpha,delta,u,g,r,i,z,run_ID,rerun_ID,cam_col,field_ID,spec_obj_ID,class,redshift,plate,MJD,fiber_ID
obj_ID,1.0,-0.013735,-0.301237,0.01531,0.01571,0.153891,0.14767,0.013811,1.0,,-0.046997,0.031498,0.239461,-0.036012,0.0654,0.23946,0.262687,0.067178
alpha,-0.013735,1.0,0.138691,-0.001532,-0.002423,-0.022083,-0.02358,-0.002918,-0.013737,,0.019582,-0.165577,-0.002553,-0.011756,0.001667,-0.002554,0.019943,0.030464
delta,-0.301237,0.138691,1.0,0.002074,0.003523,-0.006835,-0.00448,0.00363,-0.301238,,0.032565,-0.173416,0.112329,0.014452,0.031638,0.112329,0.107333,0.02825
u,0.01531,-0.001532,0.002074,1.0,0.999311,0.054149,0.04573,0.998093,0.015309,,0.003548,-0.008374,0.029997,-0.024645,0.014309,0.029997,0.031997,0.016305
g,0.01571,-0.002423,0.003523,0.999311,1.0,0.062387,0.056271,0.999161,0.01571,,0.003508,-0.008852,0.039443,-0.020066,0.022954,0.039443,0.040274,0.01747
r,0.153891,-0.022083,-0.006835,0.054149,0.062387,1.0,0.962868,0.053677,0.153889,,0.00848,-0.026423,0.655245,-0.076766,0.433241,0.655243,0.67118,0.223106
i,0.14767,-0.02358,-0.00448,0.04573,0.056271,0.962868,1.0,0.055994,0.147668,,0.007615,-0.026679,0.661641,0.015028,0.492383,0.66164,0.672523,0.214787
z,0.013811,-0.002918,0.00363,0.998093,0.999161,0.053677,0.055994,1.0,0.013811,,0.003365,-0.008903,0.037813,-0.001614,0.03038,0.037813,0.037469,0.014668
run_ID,1.0,-0.013737,-0.301238,0.015309,0.01571,0.153889,0.147668,0.013811,1.0,,-0.047098,0.031498,0.23946,-0.036014,0.0654,0.239459,0.262687,0.067165
rerun_ID,,,,,,,,,,,,,,,,,,


In [10]:
X = dados.drop(columns=['class'])
y = dados['class']

#### Divisão dos dados entre Treino e Teste

In [11]:
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y, random_state=42)

#### Normalizando os dados

In [12]:
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
scaler.fit(X_train)

X_train_scaled = scaler.transform(X_train)
X_test_scaled = scaler.transform(X_test)

#### Aplicando árvore de decisão sem normalização dos dados

In [13]:
# Utilizando GridSearchCV para automatizar o teste do modelo com vários parâmetros
from sklearn.model_selection import GridSearchCV

classifier = DecisionTreeClassifier()

param_grid = {
    "criterion": ["gini", "entropy", "log_loss"],
    "max_depth": [None, 5, 10],
    "min_samples_split": [2, 5, 10]
}

# Cria a grid utilizando o classificador, os parâmetros selecionados, e cross-validation com 5 folds
grid_search = GridSearchCV(classifier, param_grid, cv=5)
grid_search.fit(X_train, y_train)

# Melhores parâmetros encontrados
print("Melhores parâmetros: ", grid_search.best_params_)
print("Melhor Score: ", grid_search.best_score_)

Melhores parâmetros:  {'criterion': 'log_loss', 'max_depth': 10, 'min_samples_split': 10}
Melhor Score:  0.9741875


In [14]:
best_params = grid_search.best_params_  
classifier = DecisionTreeClassifier(**best_params)  
classifier.fit(X_train, y_train)

In [15]:
y_pred = classifier.predict(X_test) 
accuracy = accuracy_score(y_test, y_pred) 
print("Acurácia:", accuracy)

Acurácia: 0.9768
