# Downstream Exploitation of Space Data
## Session 6: Supervised Machine Learning

### Learning Objectives

You will: 
* know the types of problems that supervised machine learning solves and see some examples
* be able to fit a linear regression to a (toy) dataset
* be able to classify objects using a Random Forest classifier
* get familiar with how to analyze the performance of a classifier
* get familiar with the variable star classification problem

### Regression

Regression is a type of supervised learning in machine learning where the goal is to model the relationship between a dependent variable (target) and one or more independent variables (features). The objective is to predict a continuous outcome based on the input data.

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.linear_model import LinearRegression
from sklearn.datasets import make_regression # we will use this to generate a dataset for our regression

Let's generate our data:

In [None]:
X, y = make_regression(n_samples=100, n_features=1, noise=10, random_state=42)

We can get some information on our dataset:

In [None]:
data = pd.DataFrame({'Feature': X.flatten(), 'Target': y})

In [None]:
print('Feature (x):')
print(data['Feature'].describe())
print('======')
print('Target (y):')
print(data['Target'].describe())

We now split our dataset into a training set (80%) and a test set (20%):

In [None]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

**Discuss with your neighbour:** why is it important to have a separate test set?

We will now fit a linear regression model to this dataset, i.e. a straight line $y = mx + b$, where $m$ is the slope (the coefficient) and $b$ is the intercept:

In [None]:
model = LinearRegression()
model.fit(X_train, y_train)
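
`LinearRegression` fits the line by ordinary least squares, i.e. it chooses $m$ and $b$ to minimize the sum of squared residuals $\sum_i (y_i - m x_i - b)^2$. With a single feature this has a closed-form solution, $m = \sum_i (x_i - \bar{x})(y_i - \bar{y}) / \sum_i (x_i - \bar{x})^2$ and $b = \bar{y} - m\bar{x}$. The minimal sketch below checks this on the same kind of toy data. For simplicity it fits on the full generated dataset rather than the training split, and it uses its own variable names (`X_demo`, `y_demo`, `ols_model`) so it does not overwrite the notebook's variables.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression

# Same generator settings as above; fitted on all 100 points for simplicity
X_demo, y_demo = make_regression(n_samples=100, n_features=1, noise=10, random_state=42)
x = X_demo.ravel()

# Closed-form least-squares slope and intercept for a single feature
m = np.sum((x - x.mean()) * (y_demo - y_demo.mean())) / np.sum((x - x.mean()) ** 2)
b = y_demo.mean() - m * x.mean()

ols_model = LinearRegression().fit(X_demo, y_demo)
print(f'closed form:      m = {m:.4f}, b = {b:.4f}')
print(f'LinearRegression: m = {ols_model.coef_[0]:.4f}, b = {ols_model.intercept_:.4f}')
```

Both lines should print the same slope and intercept up to floating-point rounding.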

In [None]:
y_pred = model.predict(X_test) # making predictions on the test set

We can now visualize our dataset and the fitted line:

In [None]:
plt.figure(figsize=(8, 5))
plt.scatter(X_train, y_train, color='blue', label='Training data', s=5) # training set
plt.scatter(X_test, y_test, color='green', label='Test data', s=5) # test set
plt.plot(X_test, y_pred, color='pink', label='Fitted line') # the line we have fitted
plt.xlabel('Feature')
plt.ylabel('Target')
plt.title('Linear regression')
plt.legend()
plt.show()

Let's print the regression parameters:

In [None]:
print(f'Coefficient: {model.coef_[0]}')
print(f'Intercept: {model.intercept_}')

Therefore, our model is described by the following equation:

In [None]:
print(f' y = {model.coef_[0]:.2f}x + {model.intercept_:.2f}')

We can predict the value of the target variable for a new data point (not in our dataset) using our model:

In [None]:
new_data = np.array([[3.7]])  # change the number in [[ ]] to predict a new y value

predicted_y = model.predict(new_data)
print(f'y({new_data[0][0]}) = {predicted_y[0]}')

**To do:** Try a few different x values to predict y values for them.

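We can evaluate the model on the test set with two common metrics: the mean squared error (MSE), the average squared difference between the predicted and true values (lower is better), and the coefficient of determination R², the fraction of the target's variance explained by the model (1 is a perfect fit):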

In [None]:
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

In [None]:
print(f'Mean Squared Error (MSE): {mse:.2f}')
print(f'R-squared (R²): {r2:.2f}')
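
To make the two metrics concrete, the sketch below computes them by hand, $\mathrm{MSE} = \frac{1}{n}\sum_i (y_i - \hat{y}_i)^2$ and $R^2 = 1 - \sum_i (y_i - \hat{y}_i)^2 / \sum_i (y_i - \bar{y})^2$, on four made-up values (not from our dataset) and compares them with scikit-learn. An $R^2$ of 1 means perfect predictions, while 0 means the model does no better than always predicting the mean of the targets.

```python
import numpy as np
from sklearn.metrics import mean_squared_error, r2_score

# Four made-up true values and predictions (illustration only)
y_true_demo = np.array([3.0, -0.5, 2.0, 7.0])
y_hat_demo = np.array([2.5, 0.0, 2.0, 8.0])

residuals = y_true_demo - y_hat_demo
mse_manual = np.mean(residuals ** 2)                        # 1.5 / 4 = 0.375
ss_res = np.sum(residuals ** 2)                             # residual sum of squares
ss_tot = np.sum((y_true_demo - y_true_demo.mean()) ** 2)    # total sum of squares
r2_manual = 1 - ss_res / ss_tot

print(f'MSE: {mse_manual:.3f} (sklearn: {mean_squared_error(y_true_demo, y_hat_demo):.3f})')
print(f'R²:  {r2_manual:.3f} (sklearn: {r2_score(y_true_demo, y_hat_demo):.3f})')
```

Because the MSE is in squared units of the target, its square root (RMSE) is often easier to compare with the noise level used to generate the data (`noise=10` above).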

**Discuss with your neighbour:** do you think the model performs well?

### Classification

Classification is a type of supervised learning in machine learning where the goal is to assign a label or category to a given input based on its features. The model is trained on a labeled dataset to predict the label of new, unseen data.

#### Toy dataset

In [None]:
from sklearn.datasets import load_iris # we will use this to get our dataset
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import seaborn as sns

Let's load our dataset:

In [None]:
iris = load_iris()
X = iris.data  # features (sepal length, sepal width, petal length, petal width)
y = iris.target  # labels (0 = setosa, 1 = versicolor, 2 = virginica)

Let's do some data exploration:

In [None]:
df = pd.DataFrame(data=iris.data, columns=iris.feature_names)
df['target'] = iris.target
df['target'] = df['target'].map({0: 'setosa', 1: 'versicolor', 2: 'virginica'})

In [None]:
df.head() # shows the first n rows of the dataset (default n = 5)

In [None]:
df.describe() # summary statistics for each column

In [None]:
df.info() # some useful general info

We then split the data into a training (80%) and test (20%) set:

In [None]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

Let's train our classifier:

In [None]:
clf = RandomForestClassifier(n_estimators=5, random_state=42)
clf.fit(X_train, y_train)

Once it is trained, we predict the labels (y) of the test set that we set aside:

In [None]:
y_pred = clf.predict(X_test)

A classification report summarizes how well our classifier performed on each class:

In [None]:
print(classification_report(y_test, y_pred, target_names=iris.target_names))
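
Precision, recall and F1 in the report all follow from the confusion matrix. For class $k$, precision is the fraction of objects *predicted* as $k$ that really are $k$, recall is the fraction of objects that *truly* are $k$ that were found, and F1 is their harmonic mean. The sketch below uses a small, made-up set of 3-class labels (not the iris results) and checks the hand calculation against scikit-learn's `precision_recall_fscore_support`.

```python
import numpy as np
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support

# Made-up labels for a 3-class problem (illustration only)
y_true_demo = np.array([0, 0, 0, 1, 1, 1, 1, 2, 2, 2])
y_hat_demo = np.array([0, 0, 1, 1, 1, 1, 2, 2, 2, 0])

cm_demo = confusion_matrix(y_true_demo, y_hat_demo)  # rows = true class, columns = predicted class
tp = np.diag(cm_demo)                                # correctly classified objects per class
precision = tp / cm_demo.sum(axis=0)                 # divide by column totals (all predicted as k)
recall = tp / cm_demo.sum(axis=1)                    # divide by row totals (all truly k)
f1 = 2 * precision * recall / (precision + recall)   # harmonic mean of precision and recall

p_sk, r_sk, f_sk, support = precision_recall_fscore_support(y_true_demo, y_hat_demo)
print('precision:', precision.round(3), 'recall:', recall.round(3), 'f1:', f1.round(3))
```

The `support` column of the report is simply the row total of the confusion matrix, i.e. the number of true objects in each class.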

**Discuss with your neighbour:** what do you think about the performance of this classifier?

It is also very useful to look at the confusion matrix:

In [None]:
cm = confusion_matrix(y_test, y_pred)

In [None]:
plt.figure(figsize=(6, 4))
sns.heatmap(cm, annot=True, cmap='Blues', fmt='d', xticklabels=iris.target_names, yticklabels=iris.target_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()

**To do:** Change the number of estimators (trees) and see how it changes the classification report and the confusion matrix.

**Discuss with your neighbour:** What can you conclude from this? 

We can also look at how important each feature was for the classifier:

In [None]:
importances = clf.feature_importances_
feature_names = iris.feature_names

In [None]:
plt.figure(figsize=(8, 5))
plt.barh(feature_names, importances)
plt.xlabel('Importance')
plt.title('Feature importance')
plt.show()
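
The values plotted above are scikit-learn's impurity-based (mean decrease in impurity) importances: one non-negative value per feature, normalized so that they sum to 1. A sorted `pandas.Series` is often easier to read than the bar chart. The sketch below refits a forest on the full iris data under its own names (`iris_demo`, `clf_demo`), with 100 trees as an arbitrary choice, so its numbers will differ somewhat from the 5-tree classifier above.

```python
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

iris_demo = load_iris()
# 100 trees on all 150 samples (an arbitrary choice for this illustration)
clf_demo = RandomForestClassifier(n_estimators=100, random_state=42).fit(iris_demo.data, iris_demo.target)

# One importance per feature, sorted from most to least important
importances_demo = pd.Series(clf_demo.feature_importances_, index=iris_demo.feature_names)
importances_demo = importances_demo.sort_values(ascending=False)
print(importances_demo)
print('sum of importances:', round(importances_demo.sum(), 6))
```

Impurity-based importances are computed on the training data; `sklearn.inspection.permutation_importance` evaluated on the test set is a common cross-check.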

**Discuss with your neighbour:** What conclusions can you draw from the feature importances? In your opinion, would the classifier have performed equally well if certain features were removed?

#### Variable star dataset

##### Helper functions

Let's first define some functions that will be useful for us:

In [None]:
from scipy.ndimage import gaussian_filter1d

import lightkurve as lk
from astropy.timeseries import LombScargle

In [None]:
def smooth_time_series(time_series, sigma):
    '''
    This function applies Gaussian smoothing to a time series.
    Input: time_series (a series of data points in time order), sigma (width of the filter).
    Output: smoothed time series.
    '''
    smoothed_series = gaussian_filter1d(time_series, sigma=sigma, mode='nearest')
    
    return smoothed_series

In [None]:
def get_random_row(df):
    '''
    This gets a random row from a dataframe and returns TIC, Sector, and Class column content.
    Input: df (dataframe with extracted features).
    Output: TIC (TESS id of a star), Sector (TESS sector), Class (label).
    '''
    random_row = df.sample(n=1).iloc[0]

    return random_row['TIC'], random_row['Sector'], random_row['Class']

In [None]:
def extract_light_curve(tic_id, sector, science_product='QLP'):
    '''
    This function searches for a light curve (QLP by default) of a star in a specific TESS sector.
    Input: tic_id (TESS id of a star), sector (TESS sector), science_product (light curve pipeline).
    Output: a search result containing the matching light curve(s).
    '''
    # query the sector directly: comparing the 'mission' strings instead (e.g. 'TESS Sector 40')
    # can miss matches if lightkurve formats the sector number differently (e.g. zero-padded)
    search = lk.search_lightcurve(f'TIC {tic_id}', author=science_product, sector=int(sector))

    return search

In [None]:
def preprocess_light_curve(tic_id, sector, science_product='QLP', sigma=61):
    '''
    This function preprocesses a light curve to improve the quality of the data.
    Input: tic_id (TESS id of a star), sector (TESS sector), science_product (light curve pipeline), sigma (width of the filter).
    Output: time and flux after applying TESS quality flags, removing outliers, and smoothing with a Gaussian filter.
    '''
    lc = extract_light_curve(tic_id, sector, science_product=science_product)
    
    if lc is None or len(lc) == 0:
        print(f'No light curve available for TIC {tic_id} in sector {sector}.')
        return None, None 

    light_curve = lc[0].download()
    
    if light_curve is None:
        print(f'Download failed for TIC {tic_id}, sector {sector}.')
        return None, None

    time = light_curve['time'].value
    flux = light_curve['sap_flux'].value

    # step one: applying TESS quality flags -> low-quality data points are removed
    quality_mask = light_curve.quality
    good_quality_mask = (quality_mask == 0)
    
    time = time[good_quality_mask]
    flux = flux[good_quality_mask]

    # step two: removing outliers more than 10 standard deviations from the mean;
    # nanmean/nanstd ignore NaN fluxes, and NaN points fail the comparison, so they are dropped as well
    flux_mean = np.nanmean(flux)
    flux_std = np.nanstd(flux)
    keep = np.abs(flux - flux_mean) < 10 * flux_std

    time_no_outl = time[keep]
    flux_no_outl = flux[keep]

    # step three: subtracting a Gaussian smoothed time series to remove long-period instrumental trends
    # (sigma is measured in samples, not days), then adding the mean flux level back
    smoothed = smooth_time_series(flux_no_outl, sigma=sigma)
    smoothed_flux = flux_no_outl - smoothed + np.mean(flux_no_outl)
    smoothed_time = time_no_outl

    return smoothed_time, smoothed_flux
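
To see what steps two and three do without downloading any data, the sketch below applies the same operations to a synthetic light curve. Everything about the data is an assumption made for illustration: 20 days at a 30-minute cadence, a 0.5-day sinusoidal pulsation, a slow quadratic drift standing in for an instrumental trend, Gaussian noise and one strong outlier. Note that `sigma` counts samples, not days: `sigma=61` at a 30-minute cadence is about 1.3 days. Because subtracting the smoothed curve acts as a high-pass filter, genuine stellar variability on timescales much longer than the kernel (periods of about a week or more in this setup) is substantially attenuated together with the trend.

```python
import numpy as np
from scipy.ndimage import gaussian_filter1d

rng = np.random.default_rng(0)

# Synthetic light curve; every number here is an illustrative assumption
t_demo = np.arange(960) / 48.0                          # 20 days at a 30-min cadence [d]
pulsation = 0.01 * np.sin(2 * np.pi * t_demo / 0.5)     # 0.5-day "stellar" signal
drift = 0.05 * (t_demo / 20.0) ** 2                     # slow, trend-like drift
f_demo = 1.0 + pulsation + drift + rng.normal(0, 0.001, t_demo.size)
f_demo[100] += 1.0                                      # one strong outlier

# Step two: keep points within 10 standard deviations of the mean
keep = np.abs(f_demo - np.mean(f_demo)) < 10 * np.std(f_demo)
t_demo, f_demo, pulsation = t_demo[keep], f_demo[keep], pulsation[keep]

# Step three: subtract a Gaussian-smoothed copy (sigma in samples), then add back the mean
smoothed_demo = gaussian_filter1d(f_demo, sigma=61, mode='nearest')
detrended_demo = f_demo - smoothed_demo + np.mean(f_demo)

# Compare with the injected pulsation more than 4 sigma (the default kernel truncation) from the edges
inner = slice(250, -250)
corr_before = np.corrcoef(f_demo[inner], pulsation[inner])[0, 1]
corr_after = np.corrcoef(detrended_demo[inner], pulsation[inner])[0, 1]
print('points removed as outliers:', np.sum(~keep))
print(f'correlation with pulsation before: {corr_before:.3f}, after: {corr_after:.3f}')
```

Only the injected outlier should be removed, and the detrended flux should track the 0.5-day pulsation much more closely than the raw flux does.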

In [None]:
def plot_light_curve(tic_id, sector, science_product='QLP', sigma=61):
    '''
    This function plots a preprocessed light curve.
    Input: tic_id (TESS id of a star), sector (TESS sector), science_product (light curve pipeline), sigma (width of the filter).
    Output: none (this function is a process, does not return anything).
    '''
    time, flux = preprocess_light_curve(tic_id, sector, science_product=science_product, sigma=sigma)
    if time is None:  # no light curve could be retrieved
        return

    fig, ax = plt.subplots(figsize=(18, 4))

    ax.scatter(time, flux, color='black', s=2, label='Preprocessed light curve')
    ax.set_xlabel('Time [d]', fontsize=12)
    ax.set_ylabel('Detrended Flux', fontsize=12)
    ax.legend(fontsize=12, loc='best') 

    fig.suptitle(f'Light Curve for TIC {tic_id}, Sector {sector}', fontsize=14)
    
    plt.tight_layout()
    plt.show()

In [None]:
def compute_periodogram(tic_id, sector, science_product='QLP', sigma=61):
    '''
    This function computes a Lomb-Scargle periodogram for a light curve.
    Input: tic_id (TESS id of a star), sector (TESS sector), science_product (light curve pipeline), sigma (width of the filter).
    Output: Lomb-Scargle periodogram (amplitude in ppm), or None if no light curve is available.
    '''
    smoothed_time, smoothed_flux = preprocess_light_curve(tic_id, sector, science_product=science_product, sigma=sigma)
    if smoothed_time is None:  # no light curve could be retrieved
        return None

    light_curve = lk.LightCurve(time=smoothed_time, flux=smoothed_flux)
    periodogram = light_curve.normalize(unit='ppm').to_periodogram(method='lombscargle', normalization='amplitude')

    return periodogram

In [None]:
def plot_periodogram(tic_id, sector, science_product='QLP', sigma=61, features=None):
    '''
    This function plots a periodogram with (optionally) extracted frequencies.
    Input: tic_id (TESS id of a star), sector (TESS sector), science_product (light curve pipeline), sigma (width of the filter),
        features (dataframe with extracted features).
    Output: none (this function is a process, does not return anything).
    '''
    periodogram = compute_periodogram(tic_id, sector, science_product=science_product, sigma=sigma)
    if periodogram is None:  # no light curve could be retrieved
        return

    fig, ax = plt.subplots(figsize=(18, 4))
    periodogram.plot(ax=ax, color='blue', label='Periodogram')

    ax.set_xlabel('Frequency [1/d]', fontsize=12)
    ax.set_ylabel('Amplitude [ppm]', fontsize=12)

    if features is not None and not features.empty:
        p = 0
        g = 0

        filtered_row = features[(features["TIC"] == tic_id) & (features["Sector"] == sector)]
    
        if not filtered_row.empty:
            periods = filtered_row[['PeriodLS', 'PeriodLS2', 'PeriodLS3']].values.flatten().tolist()
        else:
            periods = []

        for period in periods:
            if period > 0:
                frequency = 1 / period
                if period <= 0.3:
                    if p == 0:
                        ax.axvline(x=frequency, linestyle='--', color='green', alpha=0.3, label='Extracted peak ≤ 0.3 [d]')
                    else:
                        ax.axvline(x=frequency, linestyle='--', color='green', alpha=0.3)
                    p += 1
                else:
                    if g == 0:
                        ax.axvline(x=frequency, linestyle='--', color='purple', alpha=0.3, label='Extracted peak > 0.3 [d]')
                    else:
                        ax.axvline(x=frequency, linestyle='--', color='purple', alpha=0.3)
                    g += 1 

    fig.suptitle(f'Periodogram for TIC {tic_id}, Sector {sector}', fontsize=14)

    ax.set_xlim(0, 70)
    ax.legend(fontsize=12, loc='best') 

    plt.tight_layout()
    plt.show()

In [None]:
def plot_lc_and_pd(tic_id, sector, science_product='QLP', sigma=61, zoom=False):
    '''
    This function plots a light curve, a periodogram, and (optionally) a periodogram zoomed in on the low-frequency regime.
    Input: tic_id (TESS id of a star), sector (TESS sector), science_product (light curve pipeline), sigma (width of the filter),
        zoom (True -> plot zoomed-in version, False -> do not plot).
    Output: none (this function is a process, does not return anything).

    Note that this function only marks the dominant peak of the periodogram, not the 3 extracted periods from the dataframe.
    '''
    smoothed_time, smoothed_flux = preprocess_light_curve(tic_id, sector, science_product=science_product, sigma=sigma)
    if smoothed_time is None:  # no light curve could be retrieved
        return
    periodogram = compute_periodogram(tic_id, sector, science_product=science_product, sigma=sigma)
    max_frequency = periodogram.frequency[np.argmax(periodogram.power)]

    if zoom:
        fig, axs = plt.subplots(3, 1, figsize=(18, 12))
    else:
        fig, axs = plt.subplots(2, 1, figsize=(18, 8))

    axs[0].scatter(smoothed_time, smoothed_flux, color='black', s=2)
    axs[0].set_xlabel('Time [d]', fontsize=12)
    axs[0].set_ylabel('Detrended Flux', fontsize=12)
    axs[0].set_title(f'Light curve for TIC {tic_id}, Sector {sector}', fontsize=14)

    periodogram.plot(ax=axs[1], color='blue')
    axs[1].set_title(f'Periodogram (full) | f1: {max_frequency:.4f} | p1: {1/max_frequency:.4f}', fontsize=14)
    axs[1].axvline(x=max_frequency.value, linestyle='--', color='orange', alpha=0.5)
    axs[1].set_xlim(0, 70)

    if zoom:
        periodogram.plot(ax=axs[2], color='blue')
        axs[2].set_title('Periodogram (zoomed)', fontsize=14)
        axs[2].axvline(x=max_frequency.value, linestyle='--', color='orange', alpha=0.5)
        axs[2].set_xlim(0, 5)

    plt.tight_layout()
    plt.show()

##### Exploring the dataset

In [None]:
df = pd.read_csv('session6_tda.csv')
df.head()

In [None]:
df.describe()

In [None]:
df.info()

Let's get a random row from the dataframe:

In [None]:
tic, sector, label = get_random_row(df)

In [None]:
tic, sector, label

We can now take a look at the light curve of this object:

In [None]:
plot_light_curve(tic, sector)

We can of course also use this function for a non-random object if we specify the TIC and Sector manually:

In [None]:
plot_light_curve(121788685, 40)

The Lomb-Scargle periodogram (a Fourier-type amplitude spectrum suited to unevenly sampled data) of our random object looks like this:

In [None]:
plot_periodogram(tic, sector, features=df) # remove features=df if you don't want to see extracted periods

If you want to look at both the light curve and periodogram, you can do it like this:

In [None]:
plot_lc_and_pd(tic, sector)

It is also possible to zoom in on the low-frequency regime; this is very useful for some classes but less so for others:

In [None]:
plot_lc_and_pd(tic, sector, zoom=True)

**To do:** Run these cells a few times with different random objects to look at different stars.

**Discuss with your neighbour:** Have you noticed for which classes the zoom-in is more useful?

##### Classification

Let's classify our dataset with a Random Forest (RF) classifier:

In [None]:
X = df.drop(columns=['TIC', 'Sector', 'Class'])
y = df['Class']

In [None]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [None]:
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)

In [None]:
y_pred = clf.predict(X_test)

In [None]:
print(classification_report(y_test, y_pred))

In [None]:
cm = confusion_matrix(y_test, y_pred)

In [None]:
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, cmap='Blues', fmt='d', xticklabels=clf.classes_, yticklabels=clf.classes_)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()

This might be difficult to interpret because the classes contain different numbers of objects. Let's look at the normalized version of the confusion matrix, in which each row is divided by the number of objects in that true class:

In [None]:
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
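
Dividing each row by its sum turns counts into fractions of each *true* class, so the diagonal of the normalized matrix is the recall of each class. Recent scikit-learn versions (0.22 and later) can do this directly through the `normalize='true'` argument of `confusion_matrix`; the sketch below checks that both routes agree on a few made-up labels (placeholder classes `a`, `b`, `c`).

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Made-up, imbalanced labels (illustration only)
y_true_demo = ['a', 'a', 'a', 'a', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'c']
y_hat_demo = ['a', 'a', 'a', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'a', 'c']

cm_demo = confusion_matrix(y_true_demo, y_hat_demo)
cm_manual = cm_demo.astype(float) / cm_demo.sum(axis=1)[:, np.newaxis]   # divide each row by its total
cm_builtin = confusion_matrix(y_true_demo, y_hat_demo, normalize='true')  # same, done by scikit-learn
print(np.round(cm_manual, 2))
```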

In [None]:
plt.figure(figsize=(8, 6))
sns.heatmap(cm_normalized, annot=True, cmap='Blues', fmt='.2f', xticklabels=clf.classes_, yticklabels=clf.classes_)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Normalized Confusion Matrix')
# plt.savefig('confusion_matrix.pdf')
plt.show()

**Put in the report:** This confusion matrix.

**Discuss with your neighbour (put in the report):** Why do you think the rrlyr_cepheid class is not recovered well?

**Discuss with your neighbour (put in the report):** Apart from rrlyr_cepheid (!), which two classes are most heavily confused with each other?

In [None]:
importances = clf.feature_importances_
feature_names = X.columns

In [None]:
plt.figure(figsize=(8, 13))
plt.barh(feature_names, importances)
plt.xlabel('Importance')
plt.title('Feature importance')
plt.show()

**Discuss with your neighbour:** What are the most important features? Do you think it makes sense given that it is a dataset of variable stars? The description of features here can help you answer this question: https://feets.readthedocs.io/en/latest/tutorial.html#The-Features

**Discuss with your neighbour:** What do you think about the performance of the classifier as a whole?

**To do for the report:** Plot the distribution of PeriodLS, PeriodLS2, and PeriodLS3 features for all classes (hint: an overlapping histogram with transparent fill works best for this) and discuss. For better readability, you can put different classes on different subplots.

**To do for the report:** Randomly sample 5 objects from each of these two classes, plot their light curves and periodograms and discuss.
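
One way to draw the random sample is `DataFrame.groupby(...).sample(...)`, available in pandas 1.1 and later. The sketch below runs it on a small hypothetical stand-in table with placeholder class names (`class_a`, `class_b`, `class_c`); with the real data you would use `df` and the two class names you identified, then pass each sampled `TIC`/`Sector` pair to `plot_lc_and_pd`.

```python
import pandas as pd

# Hypothetical stand-in for the feature table: only the columns used here, placeholder labels
toy_df = pd.DataFrame({
    'TIC': range(20),
    'Sector': [40] * 20,
    'Class': ['class_a'] * 8 + ['class_b'] * 7 + ['class_c'] * 5,
})

chosen_classes = ['class_a', 'class_b']   # placeholder: the two classes you want to inspect
sample_df = (
    toy_df[toy_df['Class'].isin(chosen_classes)]
    .groupby('Class')
    .sample(n=5, random_state=0)          # 5 random rows per class, reproducible
)
print(sample_df)

# With the real data, one could then loop over the rows, e.g.:
# for tic, sector in zip(sample_df['TIC'], sample_df['Sector']):
#     plot_lc_and_pd(tic, sector)
```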