## Shortlisting algorithms

### Introduction
You cannot know which algorithm will work best on your dataset beforehand. You must use trial and error to discover a shortlist of algorithms that do well on your problem that you can then double down on and tune further. I call this process spot-checking.

The question is not:

  >*What algorithm should I use on my dataset?*

Instead it is:

  >*What algorithms should I spot-check on my dataset?*

You can guess at what algorithms might do well on your dataset, and this can be a good starting point. I recommend trying a mixture of algorithms to see which are good at picking out the structure in your data. Below are some suggestions when spot-checking algorithms on your dataset:

- Try a mixture of algorithm representations (e.g. instances and trees).
- Try a mixture of learning algorithms (e.g. different algorithms for learning the same type of representation).
- Try a mixture of modelling types (e.g. linear and nonlinear functions or parametric and non-parametric).

We are going to take a look at four classification algorithms that you can spot-check on your dataset.

- Logistic Regression
- k-Nearest Neighbors
- Classification and Regression Trees (CART)
- Support Vector Machines

We will then look at how you might go about tuning a model to find the best parameters. Finally, we will demonstrate how to save and load the best model for future use in production.

### Install Python libraries

In [None]:
!pip install pandas matplotlib seaborn scikit-learn

### Stellar Classification Dataset (SDSS17)
Astronomers use different colours of light to measure how bright stars and galaxies are, break their light apart to find out what they’re made of, and calculate how far and how fast they’re moving away. Altogether, this gives us a vast amount of information about the universe.

The dataset we will use was derived from the Sloan Digital Sky Survey (SDSS), a long-term astronomical survey that maps the sky using a dedicated 2.5-meter wide-angle optical telescope located at Apache Point Observatory in New Mexico, USA. The SDSS17 dataset comes from the 17th data release of this project.

Astronomers observe space using powerful telescopes and collect two main types of data:

*1.) Photometric observations (brightness in different colours)*
Imagine taking a photograph of the night sky, but instead of using just one colour, you take five different images using distinct colour filters. These filters are labelled:

- `u` (ultraviolet) – captures the shortest wavelengths (invisible to our eyes)
- `g` (green)
- `r` (red)
- `i` (near-infrared)
- `z` (even deeper infrared)

Each filter records how bright a star or galaxy appears in that portion of the light spectrum. This helps scientists determine:

- The object’s temperature
- Its composition (what it is made of)
- How old it might be

Think of it like looking at a fire through coloured glasses—you can learn a lot about the flame based on which glasses make it look brightest.

*2.) Spectroscopic Observations (splitting light like a rainbow)*

The second type of data is spectroscopic observations. This involves taking the light from a star or galaxy and passing it through a prism to separate it into its component colours. This “rainbow” of light is known as a spectrum. Why is this useful?

Different elements (such as hydrogen, helium, etc.) leave unique fingerprints in the spectrum. These fingerprints reveal which elements are present, how hot the object is, and whether it is moving.

The universe is expanding, and galaxies are moving away from us. When they do, their light becomes stretched—similar to how the pitch of a siren drops as an ambulance drives away.

This stretching is called redshift, as the light shifts towards the red end of the spectrum. A higher redshift means the object is moving away faster, and often indicates it is further away. This helps astronomers understand the structure and history of the universe—how it has evolved over time.
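To make the idea concrete, redshift is conventionally defined as the fractional change in wavelength, z = (observed − emitted) / emitted. The sketch below applies that definition to a single spectral line. The function name and the observed wavelength are made up for illustration and are not taken from the SDSS data.

```python
def redshift(observed_nm: float, emitted_nm: float) -> float:
    """Return z = (observed - emitted) / emitted for one spectral line."""
    return (observed_nm - emitted_nm) / emitted_nm

# Hydrogen-alpha is emitted at about 656.28 nm in the laboratory
H_ALPHA_NM = 656.28

# Hypothetical measurement: the same line observed at 721.91 nm
z = redshift(721.91, H_ALPHA_NM)
print(f"z = {z:.3f}")
```

A value of z close to 0.1 means the wavelength has been stretched by roughly 10%.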


### Load the data


In [None]:
import pandas as pd

# Load the dataset
df = pd.read_csv("https://raw.githubusercontent.com/martyn-harris-bbk/AppliedMachineLearning/refs/heads/main/data/star_classification.csv")

# Sample 1,000 rows per class from the 'class' column (some models, like SVM, take a long time to train)
# To use the full dataset, comment out the line below if you have the hardware and time:
df = df.groupby('class', group_keys=False).sample(n=1000, random_state=42)

# Preview the balanced sample
print(df['class'].value_counts())
df.head()


In [None]:
df.describe()

Let's look at the variables in the dataset that provide useful information about each object:

| Column Name     | Description                                                                 |
|------------------|-----------------------------------------------------------------------------|
| obj_ID           | Unique ID for each object                                                   |
| alpha            | Right ascension (RA) – celestial longitude in degrees                       |
| delta            | Declination – celestial latitude in degrees                                 |
| u                | Ultraviolet filter magnitude                                                |
| g                | Green filter magnitude                                                      |
| r                | Red filter magnitude                                                        |
| i                | Near-infrared filter magnitude                                              |
| z                | Infrared filter magnitude                                                   |
| run_ID           | Observation run ID                                                          |
| rerun_ID         | Data processing ID                                                          |
| cam_col          | Camera column where the object was observed                                 |
| field_ID         | ID of the field (region) of sky observed                                    |
| spec_obj_ID      | Spectroscopic object ID                                                     |
| class            | **Target column**: Object type – `GALAXY`, `STAR`, or `QSO`                |
| redshift         | Redshift value – used to estimate distance and motion of the object         |
| plate            | Spectroscopic plate ID                                                      |
| MJD              | Modified Julian Date of the observation                                     |
| fiber_ID         | Fiber ID used in spectroscopic observation                                  |
| spec_class       | Spectral class (`GALAXY`, `QSO`, `STAR`, etc.)                              |

### The target variable
The column `class` is the label we want to predict. It tells us the type of astronomical object we're looking at. The values in this column include:

- `GALAXY` – A massive collection of stars, gas, and dust bound together by gravity.
- `QSO` (Quasi-Stellar Object or quasar) – Extremely bright and distant objects powered by black holes at the centre of galaxies.
- `STAR` – A luminous sphere of plasma, like our Sun.

In a classification task, the goal is to use the other data (features) to correctly assign new objects to one of these three categories.

This means it's important to understand what you're trying to predict (the target) and which pieces of data might help you make that prediction.

We can then select these good features, and preprocess them if necessary, to make them suitable for our machine learning models.

#### Photometric magnitudes

The columns `u`, `g`, `r`, `i`, `z` represent how bright an object appears in different parts of the light spectrum, from ultraviolet to infrared.

Since galaxies, stars, and quasars have different light profiles, these magnitudes can be powerful indicators for classification. For example:

- *Galaxies* have more balanced light across *all* bands.
- *Quasars* tend to be brighter in `u` (ultraviolet).
- *Stars* might appear brighter in the `g` and `r` bands.

#### Redshift

The column `redshift` measures how much the light from an object has been stretched due to the expansion of the universe. A higher redshift generally means the object is further away and receding faster.

Typically, we tend to observe that:

- *Galaxies* have moderate redshifts.
- *Quasars* tend to have very high redshifts (they're far away).
- *Stars* have very low redshift values (they're within our own galaxy).

So redshift is often a strong hint about the object’s identity (our target label).

### Preprocessing

Now that we know which features we will use, let's preprocess the data to prepare it for feature selection. The following columns don’t help the model learn the true patterns in the data (i.e., the science), and some could cause the model to latch onto artefacts or noise, resulting in poor generalisation. This is because most of them are unique identifiers or observation metadata, so we will remove them.

The goal is to train a model that learns the true physical characteristics of stars, galaxies, and quasars, not how the data was indexed or how the telescope happened to record them:

| Column Name                | Reason for Removal                                                                                                                                              |
|---------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `obj_ID`                  | Just a unique identifier (like a serial number); it has no relationship to the object's physical properties. Keeping it could confuse the model or introduce overfitting. |
| `spec_obj_ID`             | Another unique ID, this time for the spectroscopic observation. Again, not useful for learning patterns.                                                        |
| `run_ID`, `rerun_ID`      | These are technical details about the observation process—not about the object itself.                                                                           |
| `field_ID`, `fiber_ID`, `plate` | These refer to the telescope hardware or sky region where the object was captured. While they might hold some observational bias, they don’t provide physical characteristics of the object. In general, we avoid training models on metadata unless there’s a strong justification. |
| `MJD` |  Modified Julian Date represents the date of observation. It might correlate with some observational quirks, but it’s not related to whether an object is a star, galaxy, or quasar. |

Let's start by removing them:

In [None]:
df = df.drop(columns=[
    'obj_ID', 'spec_obj_ID', 'run_ID', 'rerun_ID',
    'field_ID', 'fiber_ID', 'plate', 'MJD'
])


Let's also drop missing values:

In [None]:
df = df.dropna()


### Feature selection

We now prepare the dataset for a classification task by selecting relevant features and encoding the target labels. We first define the input features (`u`, `g`, `r`, `i`, `z`, `redshift`) and the target column (`class`), then extract these columns from the DataFrame into `X` (features) and `Y` (labels).

Since the target values are categorical (`'GALAXY'`, `'QSO'`, `'STAR'`), the code uses `LabelEncoder` from scikit-learn to convert these string labels into numeric values (e.g., 0, 1, 2), which are required for most machine learning models.

Additionally, we create a dictionary (`label_mapping`) that shows how the original class names were mapped to numbers. This is helpful for interpreting model predictions later, as it allows you to translate encoded labels back to their original form:

In [None]:
from sklearn.preprocessing import LabelEncoder
import pandas as pd

# Our features and target
features = ['u', 'g', 'r', 'i', 'z', 'redshift']
target = 'class'

# Use the full dataset
X = df[features]
Y = df[target]

# Encode the target labels (Galaxy, QSO, Star -> 0, 1, 2)
label_encoder = LabelEncoder()
Y = label_encoder.fit_transform(Y)

# Optional: View the mapping of class names to encoded values
label_mapping = dict(zip(label_encoder.classes_, label_encoder.transform(label_encoder.classes_)))

print("Label Mapping:", label_mapping)
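The mapping also works in reverse: `LabelEncoder.inverse_transform` turns encoded values back into class names, which is handy when reading model predictions. A minimal sketch on a small hand-written list of labels (not the SDSS data):

```python
from sklearn.preprocessing import LabelEncoder

# Small hand-written sample of labels, for illustration only
labels = ["STAR", "GALAXY", "QSO", "GALAXY"]

encoder = LabelEncoder()
encoded = encoder.fit_transform(labels)       # classes are sorted alphabetically
decoded = encoder.inverse_transform(encoded)  # back to the original strings

print(encoder.classes_.tolist())  # ['GALAXY', 'QSO', 'STAR']
print(encoded.tolist())           # [2, 0, 1, 0]
print(decoded.tolist())           # ['STAR', 'GALAXY', 'QSO', 'GALAXY']
```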

Many machine learning models, such as *K-Nearest Neighbours (KNN)*, *Support Vector Machines (SVM)*, and *Logistic Regression*, can be affected by the scale or range of the numbers in your dataset. For example, if one feature (like income) ranges from 0 to 100,000 and another (like age) ranges from 0 to 100, the model might give more weight to the feature with larger numbers, even if it is not more important.

To fix this, we use a *StandardScaler*, which makes all the features follow the same scale by removing the average (*mean*) from each feature and then scaling them so that they have a *standard deviation* of 1. In short, after this process, all features have a *mean of 0* and *equal spread*, so the model treats them fairly.
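As a quick check of what standardisation does, the sketch below compares a manual calculation with `StandardScaler` on a tiny made-up array. The income and age values are invented purely for illustration.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Two made-up features on very different scales (e.g. income and age)
X_demo = np.array([[20_000.0, 25.0],
                   [50_000.0, 40.0],
                   [80_000.0, 55.0]])

# Manual standardisation: subtract the column mean, divide by the column std
manual = (X_demo - X_demo.mean(axis=0)) / X_demo.std(axis=0)

# StandardScaler performs the same calculation
scaled = StandardScaler().fit_transform(X_demo)

print(np.allclose(manual, scaled))  # True
print(scaled.mean(axis=0))          # each column mean is (numerically) zero
print(scaled.std(axis=0))           # each column has a standard deviation of one
```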

Another helpful technique, especially when we're splitting data into *training* and *test* sets, is to use the *stratify* option, for example `stratify=Y`. This makes sure that the *class balance* (how many examples of each category you have, like "cat" and "dog" or "spam" and "not spam") is kept the same in both the training and test sets. This is important because if one set has mostly one class, the model may become biased or its evaluation misleading.

These steps help ensure that your model trains in a fair and balanced way, giving it a better chance of performing well on new data.
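The effect of `stratify` can be checked on a small made-up example. The labels below are deliberately imbalanced (80 of one class, 20 of the other) and are not related to the SDSS data.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Made-up imbalanced labels: 80 of class 0 and 20 of class 1
y_demo = np.array([0] * 80 + [1] * 20)
X_demo = np.arange(100).reshape(-1, 1)

_, _, y_train_s, y_test_s = train_test_split(
    X_demo, y_demo, test_size=0.2, random_state=7, stratify=y_demo)

# With stratify, both sets keep the original 80/20 class ratio
print(np.bincount(y_train_s))  # [64 16]
print(np.bincount(y_test_s))   # [16  4]
```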

In [None]:
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

# Keep the process reproducible
seed = 7

# Split the dataset into training (80%) and test sets (20%)
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y, test_size=0.2, random_state=seed, stratify=Y)

# Fit the scaler on the training set only, then apply it to both sets,
# so that no information from the test set leaks into training
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

print(X_train.shape)
print(X_test.shape)

### Logistic Regression classification
Imagine you're trying to figure out what type of space object something is—maybe a star, a galaxy, or a quasar—based on things like how bright it is and how far away it seems. You have a big list of known objects with this kind of information. Now, you want a method that looks at that data and learns patterns to help you make predictions for new objects.

Logistic regression is a simple and popular method used to predict categories, especially when there are just two options, like whether an email is spam or not spam, or whether an object is a galaxy or not. The name comes from its mathematical roots. It starts like a linear regression (which predicts numbers), but then it uses a special function called the logistic function to turn those numbers into probabilities between 0 and 1.

It works by looking at numerical input data (like brightness or redshift) and trying to find patterns that separate one category from the other. It does not require the inputs to follow a bell-curve (Gaussian) distribution; what it does assume is that the log-odds of each class change linearly with the features, and it often works well even when that is only roughly true.
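The logistic function itself is short enough to write out: it maps any real-valued score t to 1 / (1 + e^(−t)). In the sketch below the weights, intercept and input values are hypothetical, chosen only to show the two steps (linear score, then probability).

```python
import numpy as np

def logistic(t):
    """Logistic (sigmoid) function: maps any real number into (0, 1)."""
    return 1.0 / (1.0 + np.exp(-t))

# Hypothetical weights, intercept and inputs, chosen only for illustration
weights = np.array([0.8, -1.5])
intercept = 0.2
x = np.array([1.0, 0.4])

score = weights @ x + intercept  # linear part, as in linear regression
probability = logistic(score)    # squashed into a probability

print(f"score = {score:.2f}, probability = {probability:.3f}")
print(logistic(0.0))  # 0.5: a score of zero means both outcomes are equally likely
```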

While logistic regression is originally designed for binary classification (two classes), scikit-learn’s `LogisticRegression` automatically extends it to handle multiple classes. You can read the <a href="https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html">API Documentation</a> for more information:

In [None]:
from sklearn.model_selection import KFold
from sklearn.model_selection import cross_val_score
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report

num_folds = 10

# Set up K-Fold cross-validation:
# shuffle=True ensures the data is shuffled before splitting (important for randomness)
kfold = KFold(n_splits=num_folds, random_state=seed, shuffle=True)

# Define the logistic regression model
# solver='saga' supports the multinomial loss and works best on scaled features
# max_iter=2000 allows more iterations in case the model takes longer to converge

# If you get a convergence warning, the model still produces an accuracy, but it might not
# be as accurate or stable as it could be if it had fully converged.
model = LogisticRegression(solver='saga', max_iter=2000)

# Evaluate the model using cross-validation on the scaled training data
results = cross_val_score(model, X_train, Y_train, cv=kfold)

# Print the average accuracy across all 10 folds
print(f'k-Fold Accuracy: {results.mean()*100:.3f}% ({results.std()*100:.3f}%)')

# Train the model on the full training data
model.fit(X_train, Y_train)

# Make predictions on the test data
y_pred = model.predict(X_test)

# Generate and print the classification report
report = classification_report(Y_test, y_pred)
print(report)

### k-Nearest Neighbors (KNN)

KNN uses a distance metric to find the k most similar instances in the training data for a new instance. For classification, it predicts the most common class among those neighbours; for regression, it takes their mean outcome.

In other words, imagine you see a new object in the sky, and you want to know what type it is: a galaxy, a quasar, or a star. You don’t have a formula to figure it out, but you do have a big notebook of space objects you’ve seen before, each labelled with what it is.

What would you do? You might say "Let me find the objects in my notebook that look most similar to this new one—and go with whatever type they are.” That’s exactly what K-Nearest Neighbours (KNN) does.

When given a new object (like a star with certain brightness and redshift), KNN measures how similar it is to everything else in the training data. It looks for the K most similar examples (its "nearest neighbours"). Then it checks which class (e.g., STAR, GALAXY, QSO) is the most common among those neighbours. That’s the prediction!  If you're predicting numbers (like temperature), it might take the average instead.  
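The steps above can be sketched from scratch in a few lines. This is a simplified illustration that assumes Euclidean distance and a plain majority vote. The helper name `knn_predict` and the toy data are made up, and in practice you would use `KNeighborsClassifier`.

```python
from collections import Counter
import numpy as np

def knn_predict(X_known, y_known, x_new, k=3):
    """Predict the label of x_new by majority vote among its k nearest neighbours."""
    distances = np.linalg.norm(X_known - x_new, axis=1)  # Euclidean distance to every point
    nearest = np.argsort(distances)[:k]                   # indices of the k closest points
    votes = Counter(y_known[i] for i in nearest)          # count labels among the neighbours
    return votes.most_common(1)[0][0]

# Tiny made-up training set with two features per object
X_toy = np.array([[0.1, 0.0], [0.2, 0.1], [0.0, 0.2],
                  [2.0, 2.1], [2.2, 1.9], [1.9, 2.0]])
y_toy = np.array(["STAR", "STAR", "STAR", "QSO", "QSO", "QSO"])

print(knn_predict(X_toy, y_toy, np.array([0.15, 0.1])))  # STAR
print(knn_predict(X_toy, y_toy, np.array([2.1, 2.0])))   # QSO
```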

You can construct a KNN model using the `KNeighborsClassifier` class. See the <a href="https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html">API documentation</a> for more details.

In [None]:
from sklearn.model_selection import KFold
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report

num_folds = 10

kfold = KFold(n_splits=num_folds, shuffle=True, random_state=seed)

model = KNeighborsClassifier()

# Evaluate the model using cross-validation on the scaled training data
results = cross_val_score(model, X_train, Y_train, cv=kfold)

print(f'k-Fold Accuracy: {results.mean()*100:.3f}% ({results.std()*100:.3f}%)')

# Train the model on the full training data
model.fit(X_train, Y_train)

# Make predictions on the test data
y_pred = model.predict(X_test)

# Generate and print the classification report
report = classification_report(Y_test, y_pred)
print(report)

### Support Vector Machines (SVM)
Imagine you're looking at a scatterplot of stars, galaxies, and quasars, based on their brightness and redshift. Your goal is to draw a line (or curve) that clearly separates these objects into distinct groups. This is exactly what Support Vector Machines (SVMs) aim to do - they try to find the best possible boundary that separates one type of object from another.

More formally, SVMs are designed to separate two classes by finding a line (in 2D) or a hyperplane (in higher dimensions) that leaves the widest possible gap between the classes. The most important data points in this process are the ones closest to the boundary—these are called support vectors. They directly influence where the dividing line is drawn.  SVMs have also been extended to handle multiple classes, like those in the SDSS dataset.

A key strength of SVMs is their use of kernel functions, which allow them to draw curved boundaries instead of just straight lines. By default, SVM uses a kernel called the *Radial Basis Function (RBF)*, which works well for many real-world problems. *RBF* measures the similarity between points. It gives a high value when two points are close together and a low value when they are far apart. The idea is that nearby data points should have more influence on each other than distant ones.
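The RBF similarity between two points a and b is exp(−gamma · ‖a − b‖²). The sketch below computes it by hand and cross-checks the result against scikit-learn's `rbf_kernel`. The points and the value `gamma=0.5` are arbitrary illustrative choices, not the settings `SVC` would pick by default.

```python
import numpy as np
from sklearn.metrics.pairwise import rbf_kernel

def rbf(a, b, gamma=0.5):
    """RBF kernel: exp(-gamma * squared Euclidean distance)."""
    return np.exp(-gamma * np.sum((a - b) ** 2))

# Made-up points: `near` is close to `origin`, `far` is not
origin = np.array([0.0, 0.0])
near = np.array([0.1, 0.1])
far = np.array([3.0, 3.0])

print(f"near: {rbf(origin, near):.4f}")  # close to 1
print(f"far:  {rbf(origin, far):.6f}")   # close to 0

# Cross-check against scikit-learn's implementation
reference = rbf_kernel([origin], [near, far], gamma=0.5)
print(np.allclose(reference[0], [rbf(origin, near), rbf(origin, far)]))  # True
```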

You can build an SVM model in scikit-learn using the SVC class:

In [None]:
from sklearn.model_selection import KFold
from sklearn.model_selection import cross_val_score
from sklearn.svm import SVC
from sklearn.metrics import classification_report

kfold = KFold(n_splits=10, random_state=seed, shuffle=True)

model = SVC()

results = cross_val_score(model, X_train, Y_train, cv=kfold)

print(f'k-Fold Accuracy: {results.mean()*100:.3f}% ({results.std()*100:.3f}%)')

# Train the model on the full training data
model.fit(X_train, Y_train)

# Make predictions on the test data
y_pred = model.predict(X_test)

# Generate and print the classification report
report = classification_report(Y_test, y_pred)
print(report)

### Classification and Regression Trees (CART or just decision trees)

Decision trees work a bit like playing 20 questions—they try to ask the best possible question at each step to help sort data into the right categories. At every point (called a *node*), the algorithm chooses a feature (like brightness or redshift) and decides how to split the data based on that feature.

Choosing the right feature to split on is *crucial*. The goal is to pick the one that does the best job of dividing the data into clean groups. To do this, the algorithm evaluates candidate splits on every feature and chooses the one that yields the most *information gain*, that is, the one that reduces uncertainty or "impurity" the most.

The metric used depends on the type of decision tree:
- *Gini impurity* (used in CART)
- *Entropy/information gain* (used in ID3 and C4.5)
- *Variance reduction* (for regression trees)

For categorical features, the tree splits the data based on the different category values. For numerical features, it looks for the best cutoff value (like "is redshift > 0.5?") to divide the data.
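To see what the impurity measures listed above actually compute, here is a minimal sketch of Gini impurity and entropy for the labels inside a single node. The helper names and the node contents are made up for illustration.

```python
import numpy as np

def gini(labels):
    """Gini impurity: 1 - sum(p_k^2) over the class proportions p_k."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def entropy(labels):
    """Shannon entropy in bits: sum(p_k * log2(1 / p_k))."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return np.sum(p * np.log2(1.0 / p))

# Made-up node contents
pure = ["STAR"] * 6
mixed = ["STAR"] * 3 + ["GALAXY"] * 3

print(gini(pure), entropy(pure))    # 0.0 0.0 -> the node is pure
print(gini(mixed), entropy(mixed))  # 0.5 1.0 -> a 50/50 node is maximally mixed
```

A good split is one whose child nodes have lower impurity, weighted by their size, than the parent node.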

This process continues until the tree reaches a stopping point, such as:
- All remaining data belongs to one class (a "pure" leaf),
- The tree has reached a *maximum depth* (to avoid getting too complicated),
- Or when there are not enough samples left to make another meaningful split.

After the tree is built, we can make it simpler and more reliable by using *pruning*. This means trimming off parts of the tree that do not help much with predictions. Techniques like cost-complexity pruning or *reduced-error* pruning help prevent the model from overfitting the training data—making it better at predicting new, unseen examples.
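In scikit-learn, cost-complexity pruning is available through the `ccp_alpha` parameter of `DecisionTreeClassifier`, and `cost_complexity_pruning_path` lists the candidate alpha values. The sketch below uses synthetic data so that it runs on its own. Picking the middle alpha from the path is an arbitrary choice for illustration, and in practice you would select it with cross-validation.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic data, used only so that the example is self-contained
X_syn, y_syn = make_classification(
    n_samples=300, n_features=6, n_informative=4,
    n_classes=3, random_state=7)

# An unpruned tree grows until its leaves are pure
full_tree = DecisionTreeClassifier(random_state=7).fit(X_syn, y_syn)

# The pruning path lists the alpha values at which subtrees get collapsed
path = full_tree.cost_complexity_pruning_path(X_syn, y_syn)
alpha = path.ccp_alphas[len(path.ccp_alphas) // 2]  # arbitrary mid-range choice

pruned_tree = DecisionTreeClassifier(random_state=7, ccp_alpha=alpha).fit(X_syn, y_syn)

print("Leaves before pruning:", full_tree.get_n_leaves())
print("Leaves after pruning: ", pruned_tree.get_n_leaves())
```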

You can create a decision tree model using scikit-learn’s `DecisionTreeClassifier` class.  You can find more details in the [API documentation](https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html). Here’s an example of how we build one on our data:

In [None]:
from sklearn.model_selection import KFold
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report

kfold = KFold(n_splits=10, random_state=seed, shuffle=True)

# Define the model
model = DecisionTreeClassifier(criterion="gini", random_state=seed)

results = cross_val_score(model, X_train, Y_train, cv=kfold)

print(f'k-Fold Accuracy: {results.mean()*100:.3f}% ({results.std()*100:.3f}%)')

# Train the model on the full training data
model.fit(X_train, Y_train)

# Make predictions on the test data
y_pred = model.predict(X_test)

# Generate and print the classification report
report = classification_report(Y_test, y_pred)
print(report)

### Decision trees and overfitting
Decision trees can sometimes become too complex, trying to fit every little detail in the training data—even the noise or outliers.

When this happens, the tree performs well on training data, but struggles to make accurate predictions on new, unseen data. This problem is known as *overfitting*.

To prevent overfitting, it's important to limit the complexity of the tree. A simple and effective way to do this is to *prune* the tree by constraining its size or shape with a few key parameters. Some of the most commonly used include:

- `max_leaf_nodes`: Limits the number of leaf nodes in the tree. Fewer leaf nodes generally mean a simpler tree.
- `min_samples_leaf`: Sets the minimum number of samples required to be at a leaf node. Helps avoid tiny splits that are not meaningful.
- `max_depth`: Limits how deep the tree can go, controlling how many decisions the model can make in sequence.

In the example below, we apply these pruning parameters to a `DecisionTreeClassifier`, and evaluate the pruned model using cross-validation and our trusty classification report:


In [None]:
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import cross_val_score
from sklearn.metrics import classification_report

# Set random seed for reproducibility
seed = 7

# Define a pruned decision tree model
model = DecisionTreeClassifier(
    criterion="gini",           # Use Gini impurity to measure splits
    splitter="random",          # Choose the best random split rather than the best overall split
    max_leaf_nodes=10,          # Limit the number of leaf nodes (controls tree size)
    min_samples_leaf=5,         # Require at least 5 samples in each leaf node
    max_depth=5,                # Limit tree depth to 5 levels
    random_state=seed           # Set seed for reproducibility
)

# Evaluate the model using k-fold cross-validation
results = cross_val_score(model, X_train, Y_train, cv=kfold)

# Print mean accuracy and standard deviation from cross-validation
print(f'k-Fold Accuracy: {results.mean() * 100:.3f}% ({results.std() * 100:.3f}%)')

# Train the model on the full training set
model.fit(X_train, Y_train)

# Predict the labels for the test set
y_pred = model.predict(X_test)

# Generate a classification report showing precision, recall, F1-score, and support
report = classification_report(Y_test, y_pred)
print(report)


This approach demonstrates how adjusting just a few parameters can result in a simpler, more generalisable decision tree that performs well on unseen data. You can experiment with different values to see how they affect accuracy and overfitting.

### Comparing Machine Learning algorithms
When working on a machine learning project, it’s common to build and test several different models. Each algorithm has its own strengths and weaknesses, and their performance can vary depending on the dataset. To choose the best one, we need a fair and consistent way to compare them.

One effective approach is to use *cross-validation*, which gives you an estimate of how well each model is likely to perform on new, unseen data. These estimates help you decide which models are worth keeping and fine-tuning.

Just like it’s helpful to visualise your dataset from different angles, it’s also important to visualise model performance from multiple perspectives. You can use graphs and plots to compare things like:
- Average accuracy
- Variability (spread or consistency)
- Distribution of scores across different folds

This gives you a clearer picture of which models are both accurate and reliable.

To ensure a fair comparison, every model should be tested in exactly the same way—using the same data splits, same evaluation method, and same random seed. This ensures that differences in performance come from the models themselves, not from how they were tested.

In the example below, six popular classification algorithms are compared using *10-fold cross-validation*, with a consistent test setup (we have explored some of these before):

- Logistic Regression  
- Linear Discriminant Analysis  
- k-Nearest Neighbours (KNN)  
- Decision Trees (CART)  
- Naive Bayes  
- Support Vector Machines (SVM)

Each algorithm is assigned a short label to make it easier to summarise and visualise results later on. Creating models and storing them in a dictionary or list can really help test a range of models in one go:

In [None]:
from sklearn.model_selection import KFold
from sklearn.model_selection import cross_val_score
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC

# Prepare models
models = []
# The default lbfgs solver handles multiclass problems directly
models.append(('LR', LogisticRegression(max_iter=2000)))
models.append(('LDA', LinearDiscriminantAnalysis()))
models.append(('KNN', KNeighborsClassifier()))
models.append(('CART', DecisionTreeClassifier(random_state=seed)))
models.append(('NB', GaussianNB()))
models.append(('SVM', SVC()))

# Evaluate each model in turn, using the same folds and metric for all of them
results = []
names = []

scoring = 'accuracy'
kfold = KFold(n_splits=10, shuffle=True, random_state=seed)

for name, model in models:
    cv_results = cross_val_score(model, X_train, Y_train, cv=kfold, scoring=scoring)
    results.append(cv_results)
    names.append(name)
    print(f"{name}: {cv_results.mean():.6f} ({cv_results.std():.6f})")

Looking at just the average accuracy allows us to pick a smaller selection of models to work with and improve, provided they already show good performance in this initial investigation.
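To look beyond the average, the per-fold scores can be drawn as a box plot, which shows the spread and distribution for each model side by side. The sketch below is self-contained: it uses the Iris dataset and two models purely as stand-ins, and the names `demo_models`, `demo_names` and `demo_scores` are made up. The same plotting calls should work with the `results` and `names` lists collected in the comparison loop.

```python
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import KFold, cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier

# Iris is used here only to keep the example self-contained
X_iris, y_iris = load_iris(return_X_y=True)

demo_models = [("KNN", KNeighborsClassifier()),
               ("CART", DecisionTreeClassifier(random_state=7))]
folds = KFold(n_splits=10, shuffle=True, random_state=7)

# One array of 10 fold scores per model
demo_names = [name for name, _ in demo_models]
demo_scores = [cross_val_score(m, X_iris, y_iris, cv=folds, scoring="accuracy")
               for _, m in demo_models]

# One box per model: the line is the median, the box spans the middle 50% of folds
fig, ax = plt.subplots()
ax.boxplot(demo_scores)
ax.set_xticks(range(1, len(demo_names) + 1))
ax.set_xticklabels(demo_names)
ax.set_ylabel("Accuracy")
ax.set_title("Cross-validation accuracy by model")
fig.tight_layout()
```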

### Tuning models

Tuning a model means making small adjustments to improve how well it performs. It’s one of the final steps in building a machine learning system, just before deciding that the model is ready to use.

This process is often called *hyperparameter tuning* (or *hyperparameter optimisation*). The word *hyperparameter* refers to the settings or choices you make before the model starts learning — for example, how many neighbours to use in KNN, or how deep a decision tree can go. These are different from *parameters*, which are the values the model learns by itself from the data (like the weights in linear regression).

We call it *optimisation* because tuning is a search: you are trying to find the best combination of settings for your model, like searching for the right ingredients in a recipe to get the best result.
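The difference between hyperparameters and learned parameters can be seen directly on a scikit-learn estimator. The sketch below uses the Iris dataset only as a convenient stand-in, and the values `n_neighbors=7` and `C=1.0` are arbitrary.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier

X_iris, y_iris = load_iris(return_X_y=True)

# Hyperparameters: chosen by us before training
knn = KNeighborsClassifier(n_neighbors=7)
print(knn.get_params()["n_neighbors"])  # 7

# Parameters: learned from the data during fit()
logreg = LogisticRegression(C=1.0, max_iter=2000)
print(hasattr(logreg, "coef_"))  # False, nothing has been learned yet
logreg.fit(X_iris, y_iris)
print(logreg.coef_.shape)        # (3, 4): one weight per class and feature
```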

#### Grid search
The first approach we can try is *Grid search*, which is a methodical way to test different combinations of settings (hyperparameters). You first define a list of possible values for each setting, and the algorithm will try out *every* possible combination of those values to see which performs best.
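To see what "every possible combination" means in practice, scikit-learn's `ParameterGrid` can enumerate a grid without training anything. The candidate values below are hypothetical.

```python
from sklearn.model_selection import ParameterGrid

# Hypothetical candidate values for two decision-tree hyperparameters
demo_grid = {"max_depth": [4, 6, 8], "min_samples_leaf": [1, 5]}

combinations = list(ParameterGrid(demo_grid))
print(len(combinations))  # 6, i.e. 3 x 2
for combo in combinations:
    print(combo)
```

The number of combinations is the product of the list lengths, so it grows quickly as more hyperparameters are added.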

You can do this in code using the `GridSearchCV` class from scikit-learn. This class runs all the combinations, checks the performance, and shows you the best setup. See the [API Documentation](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html) to learn how to use it in more detail.

Here we will use it to help prune a decision tree, by picking the best value for `max_depth`:

In [None]:
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

# The candidate max_depth values to test and choose the best from
vals = [12, 10, 8, 6, 4]

param_grid = dict(max_depth=vals) # The key is the parameter name

model = DecisionTreeClassifier(
    criterion="gini",           # Use Gini impurity to measure splits
    splitter="random",          # Choose the best random split rather than the best overall split
    random_state=seed           # Set seed for reproducibility
)

grid = GridSearchCV(estimator=model, param_grid=param_grid)
grid.fit(X_train, Y_train)

print("Best score", grid.best_score_)
print("Recommended max_depth value", grid.best_estimator_.max_depth)

#### Random search

Random search is another way to tune a model by trying out different settings (called *hyperparameters*) — but instead of trying *every* possible combination like grid search, it picks them *randomly*. Think of it like reaching into a bag of possibilities and pulling out a few combinations to test, rather than checking them all. This can be much faster, especially when there are a lot of combinations, and it often still finds a very good result.

In random search, values for each hyperparameter are picked randomly from a range you specify (usually using a *uniform distribution*, which means all values in the range have an equal chance of being chosen). The model is built and tested for each combination, just like in grid search.
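
As a rough illustration of that sampling step, the sketch below draws a handful of random combinations using only Python's standard `random` module. The setting names, ranges, and the budget of five draws are assumptions for illustration. It shows the idea of uniform sampling rather than scikit-learn's own implementation.

```python
import random

rng = random.Random(7)  # fixed seed so the draws are repeatable

n_iter = 5  # hypothetical budget: how many random combinations to try

candidates = [
    {
        "C": rng.randrange(1, 100),            # uniform over the integers 1..99
        "max_iter": rng.randrange(100, 2000),  # uniform over the integers 100..1999
    }
    for _ in range(n_iter)
]

for candidate in candidates:
    print(candidate)
```

Each candidate would then be trained and scored just like a grid search candidate. The difference is that the budget (`n_iter`) is fixed up front, however large the ranges are.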

To do this in Python, you can use the `RandomizedSearchCV` class from scikit-learn. It handles the random selection, training, and evaluation for you. See the [API Documentation](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html) for details on how to use it.

In the example below, we explore a wide variety of settings without needing to try every single possible one. This time we will tune a logistic regression model:

In [None]:
from scipy.stats import randint
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import RandomizedSearchCV

# Suppress convergence warnings - not all parameters tested will be sensible
import warnings
from sklearn.exceptions import ConvergenceWarning
warnings.filterwarnings("ignore", category=ConvergenceWarning)

# Define integer-based parameter grid
param_grid = {
    'C': randint(1, 100),             # Inverse regularisation strength (converted to float in model)
    'max_iter': randint(100, 2000),   # Number of iterations
}
model = LogisticRegression(solver='saga', random_state=seed)

rsearch = RandomizedSearchCV(estimator=model, param_distributions=param_grid, n_iter=100, random_state=seed)
rsearch.fit(X_train, Y_train)

print("Best score:", rsearch.best_score_)
print("Recommended C:", rsearch.best_estimator_.C)
print("Recommended max_iter:", rsearch.best_estimator_.max_iter)


Both tuning approaches narrow down the search for suitable parameters, leaving you to experiment and judge the results from the accuracy scores the models achieve.

### Saving and loading machine learning models

Once you've trained a machine learning model and found one that performs well, it’s a good idea to save it so you can use it later — either in another script, in production, or simply to avoid retraining every time. This process is useful for sharing models or deploying them in real-world applications.

In Python, a common way to save a model is by using a built-in tool called *pickle*. Pickle is a standard Python module that allows you to *serialise* objects. *Serialising* means converting the model into a format that can be stored in a file. Later, you can *deserialise* the file, which means loading the model back into memory exactly as it was.
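
The round trip can be seen in miniature with a plain dictionary standing in for a trained model. The dictionary and its values are hypothetical. The same `pickle` calls work on most Python objects, including scikit-learn estimators.

```python
import pickle

# A plain dictionary standing in for a trained model (hypothetical values)
toy_model = {"name": "toy", "weights": [0.25, -1.5, 3.0]}

# Serialise: convert the object into bytes that could be written to a file
payload = pickle.dumps(toy_model)

# Deserialise: rebuild an equal object from those bytes
restored = pickle.loads(payload)

print(type(payload).__name__)  # bytes
print(restored == toy_model)   # True: same contents
print(restored is toy_model)   # False: a new copy, not the original object
```

Only load pickle files from sources you trust, since unpickling can execute arbitrary code.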

This is especially useful in machine learning, where training a model can take time. Saving the model once lets you skip that step when you want to make new predictions later. You can find more details in the [Pickle API documentation](https://docs.python.org/3/library/pickle.html), but the basic process is very straightforward. The example below shows how to:
- Train a *logistic regression* model on a dataset.
- Save the trained model to a file using *pickle*.
- Load the model back from the file.
- Use the loaded model to make predictions on a separate *test set*.

This workflow helps you build once, reuse many times — an essential step when moving from experimentation to production or deployment:

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

from pickle import dump
from pickle import load

model = LogisticRegression(max_iter=1000)
model.fit(X_train, Y_train)

# save the model to disk
filename = './finalised_model.sav'

print("Saving", filename)
with open(filename, 'wb') as f:   # the file is closed automatically
    dump(model, f)


# some time later...


# load the model from disk
print("Loading", filename)
with open(filename, 'rb') as f:   # the file is closed automatically
    loaded_model = load(f)

result = loaded_model.score(X_test, Y_test)

print("Test accuracy:", result)

### What have we learnt?

We explored the process of building, evaluating, and comparing machine learning models using the SDSS17 astronomical dataset. Along the way, we broke down key concepts into clear, beginner-friendly explanations to help build a strong foundation in both data science and astronomy.

We started by understanding the structure of the dataset, including features like photometric magnitudes and redshift, and why it's important to remove irrelevant columns (such as object IDs and technical metadata) during preprocessing. We then saw how to prepare the data using techniques like *label encoding*, *feature scaling*, and *train-test splitting*.

From there, we explored several popular machine learning algorithms:

- *Logistic Regression*: A model that estimates probabilities and is well-suited for classification, including multi-class problems when configured properly.
- *K-Nearest Neighbours (KNN)*: A simple, intuitive model that predicts labels based on the most similar examples in the training data.
- *Support Vector Machines (SVM)*: A powerful classifier that finds the optimal boundary between classes, with the ability to handle complex patterns using kernels.
- *Decision Trees*: Models that split the data based on the most informative features, but can overfit if not properly controlled with techniques like max depth and pruning.

We also discussed how to avoid overfitting and why it's important to evaluate models fairly and consistently, using tools like *cross-validation*. Finally, we looked at how to compare multiple models side by side and inspect their accuracy and variability to select the best-performing ones for the task.

We demonstrated how to tune your chosen model to find the best parameters, and how to then save and load the model for future use. You should hopefully now feel better equipped to build robust machine learning solutions, even for complex datasets like those in astronomy.