# Multi-class Classification of Spectral Data



In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV, cross_val_predict
 
from sklearn.metrics import classification_report, confusion_matrix, roc_curve 
from sklearn.metrics import RocCurveDisplay, roc_auc_score

In [None]:
data = pd.read_csv('3-Sample_data_Set1_O2varyCO2varyCH4_looped_work.csv')

# Column layout as used below (TODO confirm against the CSV):
#   column 1  -> `lab` (not referenced by the cells shown here)
#   column 3  -> class label used as the classification target
#   column 4+ -> spectrum values
values = data.values
lab = values[:, 1].astype('uint8')
spectra = values[:, 4:]
label = values[:, 3]
y = label.astype('int')

# Savitzky-Golay filtering along each row (axis=-1, scipy's default):
# X is the cubic-smoothed spectrum (window 9), X1 its first derivative (window 11).
X = savgol_filter(spectra, window_length=9, polyorder=3)
X1 = savgol_filter(spectra, window_length=11, polyorder=3, deriv=1)

### Preliminary Check for Trends (PCA)

In [None]:

# Standardize each column of the first-derivative spectra, then project onto
# the first three principal components for a quick look at class structure.
pca = PCA(n_components=3)
Xpca = pca.fit_transform(
    StandardScaler().fit_transform(X1)
)

### Multi-class Classification Pipeline

In [None]:

# Standardize -> PCA -> logistic regression. The regularization strength C and
# the number of retained PCA components are tuned jointly by 2-fold CV on accuracy.
pipe = Pipeline(steps=[
    ('scaler', StandardScaler()),
    ('pca', PCA()),
    ('logit', LogisticRegression(max_iter=100000)),
])
parameters = {
    'logit__C': np.logspace(-3, 0, num=4),                  # 1e-3, 1e-2, 1e-1, 1e0
    'pca__n_components': np.arange(1, 11).astype('uint8'),  # 1 .. 10 components
}
gs = GridSearchCV(
    pipe,
    param_grid=parameters,
    scoring='accuracy',
    cv=2,
    n_jobs=8,
    verbose=0,
)
gs.fit(X, y)
print(gs.best_estimator_['logit'])

In [None]:
# Out-of-fold class predictions for the classification report / confusion matrix.
# Cross-validate the GridSearchCV object itself rather than gs.best_estimator_.
# Its hyper-parameters were selected using all of X, y, so evaluating it by CV
# would let the scored folds influence model selection and bias the scores upward.
# Passing `gs` re-runs the hyper-parameter search inside every outer training
# fold (nested cross-validation).
y_cv = cross_val_predict(gs, X, y, cv=2, n_jobs=8)
print(classification_report(y, y_cv))
print(confusion_matrix(y, y_cv))

In [None]:
# Save the per-class precision / recall / F1 / support table (plus the averages)
# as CSV, one row per class or average.
report = classification_report(y, y_cv, output_dict=True)
report_df = pd.DataFrame(report).transpose()
report_df.to_csv('classification_report.csv', index=True)

In [None]:
# Confusion matrix of the out-of-fold predictions (sklearn: rows = true, columns = predicted).
report_confusion_matrix = confusion_matrix(y_true=y, y_pred=y_cv)

In [None]:
import matplotlib as mpl
import seaborn as sns

mpl.rc('axes', labelcolor='black')
mpl.rcParams['text.usetex'] = True  # text is rendered through LaTeX

# Called without `labels=`, confusion_matrix orders its rows and columns by the
# sorted union of the labels present in y and y_cv. Derive the tick labels the
# same way so they always match the matrix, instead of hard-coding 1..8.
labels = np.unique(np.concatenate((y, y_cv)))
confusion_matrix_df = pd.DataFrame(report_confusion_matrix, index=labels,
                                   columns=labels)

plt.figure(figsize=(4, 3))
sns.heatmap(confusion_matrix_df, annot=True, cmap="RdPu", linewidths=1, linecolor='black')
# Rows (y-axis) are the true labels; columns (x-axis) are the predicted labels.
plt.xlabel('Predicted Class Label')
plt.ylabel('True Class Label')
plt.title("Confusion Matrix", fontsize=12)
plt.savefig('heatmap.svg', bbox_inches="tight", dpi=400)

### ROC curves

In [None]:
# Out-of-fold class probabilities for the ROC curves. As for the hard predictions,
# the whole GridSearchCV is cross-validated (nested CV), so hyper-parameter
# tuning never sees the fold being scored. Columns follow the sorted class labels.
y_score = cross_val_predict(gs, X, y, cv=2, n_jobs=8, method='predict_proba')

# One-hot encode the true labels. The encoder's categories are sorted too, so
# column k of yenc corresponds to column k of y_score.
enc = OneHotEncoder()
yenc = enc.fit_transform(y.reshape(-1, 1)).toarray()

In [None]:
import matplotlib as mpl
from itertools import cycle

mpl.rc('axes', labelcolor='black')
mpl.rcParams['text.usetex'] = True

fig, ax = plt.subplots(figsize=(3.3, 3.3))

# One-vs-rest ROC curve per class. Column k of yenc / y_score belongs to the
# class label enc.categories_[0][k]. The class count comes from the encoded
# data rather than a hard-coded constant.
class_labels = enc.categories_[0]
n_classes = yenc.shape[1]
colors = cycle(["aqua", "darkorange", "cornflowerblue", "red", "green", "purple", "pink", "blue"])
for class_id, color in zip(range(n_classes), colors):
    RocCurveDisplay.from_predictions(
        yenc[:, class_id],    # 1 where the sample truly belongs to this class
        y_score[:, class_id], # predicted probability of this class
        name=f"Class {class_labels[class_id]}",
        color=color,
        ax=ax,
    )

plt.plot([0, 1], [0, 1], '--', color='gray')  # chance-level diagonal
plt.axis("square")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curves: One-vs-Rest")
plt.legend()
plt.savefig('roc_multiclass.svg', bbox_inches="tight", dpi=400)

### Plot the spectral data

In [None]:
# Transpose the spectral columns: each wavelength header becomes a row and each
# sample (original row) becomes a column. The header text is kept as its own column.
spectra_df = data.iloc[:, 4:].T
spectra_df['Wavelength_um'] = spectra_df.index

In [None]:
# Replace the header-text index with a plain 0..n-1 positional index;
# the header text is already preserved in the Wavelength_um column.
spectra_df.index = pd.RangeIndex(len(spectra_df))

In [None]:
# One example sample column per class, plus the trailing Wavelength_um column.
sample_positions = [0, 11, 21, 31, 41, 51, 61, 71]
plot_data = spectra_df.iloc[:, sample_positions + [-1]]


In [None]:
# Map the selected sample-column labels to readable class names.
# - The mapping is called `class_names` rather than `dict` so the builtin is not shadowed.
# - The result is reassigned rather than renamed in place: plot_data comes from an
#   .iloc selection, and in-place mutation of such a derived frame can trigger
#   pandas' SettingWithCopyWarning.
class_names = {
    0: 'Class_1',
    11: 'Class_2',
    21: 'Class_3',
    31: 'Class_4',
    41: 'Class_5',
    51: 'Class_6',
    61: 'Class_7',
    71: 'Class_8',
}

plot_data = plot_data.rename(columns=class_names)

In [None]:
# Notebook echo: the bare expression displays plot_data as this cell's output.
plot_data

In [None]:
import matplotlib as mpl

mpl.rc('axes', labelcolor='black')
mpl.rcParams['text.usetex'] = True

# First eight colours of matplotlib's default colour cycle, one per class panel.
class_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728',
                '#9467bd', '#8c564b', '#e377c2', '#7f7f7f']

fig, axs = plt.subplots(len(class_colors), 1, figsize=(7, 8), sharex=True)

# One stacked panel per class (Class_1 .. Class_8), each with its own legend.
for class_num, (ax, color) in enumerate(zip(axs, class_colors), start=1):
    ax.plot(plot_data['Wavelength_um'], plot_data[f'Class_{class_num}'],
            linewidth=0.5, color=color, label=f"Class {class_num}")
    ax.legend(bbox_to_anchor=(0.33, 1.0), loc='upper right')

# Wavelength_um holds the CSV header strings, so matplotlib treats the x-axis as
# categorical, with positions 0..n-1. Tick only the first and last wavelength; the
# last position is taken from the data so it always falls on the final row.
axs[-1].set_xticks([0, len(plot_data) - 1], minor=False)

# Label the bottom panel's x-axis and add a shared y-axis label, then save.
axs[-1].set_xlabel(r'\textbf{Wavelength (um)}', fontsize=12)
fig.supylabel(r'\textbf{Exoplanet Transit Depth (ppm)}', fontsize=12)

plt.savefig("multiple_spectra_sample.svg", bbox_inches="tight", dpi=400)
