# Tutorial 2: Source Classification
https://astronomers.skatelescope.org/ska-science-data-challenge-1/

The previous tutorial produced two sets of PyBDSF source lists, stored as data frames keyed by frequency: `sources_training`, used to train the machine learning model, and `sources_full`, used to test the trained model. This notebook covers the following steps:

- Data exploration
- Data pre-processing
- Training
- Testing

For this tutorial, we have three kinds of astronomical sources: star-forming galaxies (`SFGs`), steep-spectrum (`SS`) AGN and flat-spectrum (`FS`) AGN. The goal is to classify these sources using the features that PyBDSF measures for each detected source.

---

First, let us restore the data frames stored by the previous notebook:

In [None]:
%store -r sources_training
%store -r sources_full

#### Examining data

Let's take a look at the raw 1400 MHz training source list:

In [None]:
sources_training[1400]

We may verify the shape of the data frame:

In [None]:
print(sources_training[1400].shape)

In its current state, the data frame above is not suitable for ML, so we need to:
   * Exclude some columns that are not useful as features, such as `Source_ID`.
   * Cross-match the source list against the truth catalogue and keep the matched output, i.e. only the detections where PyBDSF correctly identified a real source.
   * Add the class of each matched truth source (the ground truth) as the target column, so that we can perform supervised learning.

In conclusion, we need to pre-process the data frame produced by PyBDSF before training.
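The cross-matching step can be sketched as a nearest-neighbour positional match. The function below is a hypothetical illustration and may differ from what `SKLClassification.pre_process` actually does: the column names `ra`, `dec` (in degrees) and `class`, the 1 arcsec matching radius and the `drop_cols` parameter are all assumptions. It uses a flat-sky approximation and ignores RA wrap-around at 0°/360°, so it only suits small fields.

```python
import numpy as np
import pandas as pd


def crossmatch_labels(srl_df, truth_df, max_sep_arcsec=1.0, drop_cols=()):
    """Label detected sources with the class of their nearest truth source.

    Hypothetical sketch: assumes both frames have 'ra' and 'dec' columns in
    degrees and that the truth catalogue has a 'class' column.
    """
    ra_s, dec_s = srl_df["ra"].to_numpy(), srl_df["dec"].to_numpy()
    ra_t, dec_t = truth_df["ra"].to_numpy(), truth_df["dec"].to_numpy()

    # Pairwise offsets in degrees (flat-sky): RA offsets are scaled by cos(dec)
    d_ra = (ra_s[:, None] - ra_t[None, :]) * np.cos(np.deg2rad(dec_s))[:, None]
    d_dec = dec_s[:, None] - dec_t[None, :]
    sep_arcsec = np.hypot(d_ra, d_dec) * 3600.0

    # Nearest truth source for every detection, kept only if it is close enough
    nearest = sep_arcsec.argmin(axis=1)
    nearest_sep = sep_arcsec[np.arange(len(srl_df)), nearest]
    matched = nearest_sep <= max_sep_arcsec

    # Keep matched detections, drop non-feature columns, attach the true class
    out = srl_df.loc[matched].drop(columns=list(drop_cols), errors="ignore").copy()
    out["class_t"] = truth_df["class"].to_numpy()[nearest[matched]]
    return out


# Tiny synthetic example: two detections, only one with a truth counterpart
demo_srl = pd.DataFrame({"Source_id": [0, 1], "ra": [10.0, 10.1],
                         "dec": [-30.0, -30.0], "flux": [1.2, 0.4]})
demo_truth = pd.DataFrame({"ra": [10.0001], "dec": [-30.0], "class": [3]})
demo_train = crossmatch_labels(demo_srl, demo_truth, drop_cols=["Source_id"])
```

This builds the full pairwise separation matrix, which is fine for a demo but memory-hungry for large catalogues; a KD-tree (for example `scipy.spatial.cKDTree`) or astropy's `SkyCoord.match_to_catalog_sky` avoids that.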

---

#### Data pre-processing

Importing some libraries:

In [None]:
from source.utils.bdsf_utils import  load_truth_df, cat_df_from_srl_df 
from source.path import train_truth_path, full_truth_path, write_df_to_disk, submission_df_path
from source.utils.columns import SRL_CAT_COLS, SRL_COLS_TO_DROP, SRL_NUM_COLS
from source.utils.classification import SKLClassification

In [None]:
model_pre = SKLClassification()
train_truth_cat_df = load_truth_df(train_truth_path(1400), skiprows=18)  # truth catalogue for the training area

# pre_process arguments:
#   srl_df (pandas.DataFrame): source list produced by PyBDSF.
#   truth_cat_df (pandas.DataFrame): truth catalogue.
#   regressand_col (str): name of the target (output) column.
#   freq (int): frequency band (MHz).
# Returns the crossmatched source list (pandas.DataFrame) used for training.
train_df = model_pre.pre_process(sources_training[1400], train_truth_cat_df, regressand_col="class_t", freq=1400)

---

#### Examining the training data

In [None]:
train_df

Now let us check how many sources there are in each class.

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

ax = sns.countplot(x="class_t", data=train_df)
# Annotate each bar with its (integer) count
for p in ax.patches:
    ax.annotate("{:.0f}".format(p.get_height()), (p.get_x() + 0.25, p.get_height() + 0.01))

plt.show()

Where:
   - Class 1 ->   SS-AGN
   - Class 2 ->   FS-AGN
   - Class 3 ->   SFG
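The same counts can also be tabulated directly with pandas, which makes the class imbalance explicit. The helper `class_summary` and the `CLASS_NAMES` dictionary below are illustrative names; the mapping itself is taken from the list above.

```python
import pandas as pd

# Mapping from the numeric labels in class_t to source types (from the list above)
CLASS_NAMES = {1: "SS-AGN", 2: "FS-AGN", 3: "SFG"}


def class_summary(labels):
    """Count each class and compute its fraction of the total."""
    counts = labels.map(CLASS_NAMES).value_counts()
    return pd.DataFrame({"count": counts, "fraction": counts / counts.sum()})


# Synthetic labels standing in for train_df["class_t"]
demo_summary = class_summary(pd.Series([3, 3, 3, 1, 2, 3]))
```

On the real data this would be called as `class_summary(train_df["class_t"])`. With a strongly imbalanced sample, overall accuracy can look good even if the minority classes are poorly classified, so the per-class precision, recall and F1 score reported by `classification_report` are more informative.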

I have been asking myself why there is a majority of SFGs in the simulated data. If you do not know, ask the nearest postdoc you see.

**Exercise 1:** Pre-process and visualise the training data for the other two frequencies.
<br>

---

### Training

First, let us separate the data frame into the input features (`train_x`) and the target labels (`train_y`).

In [None]:
from sklearn.utils import shuffle
 
train_df = shuffle(train_df) # we are shuffling the data.
train_x = train_df.drop(['class_t'], axis = 1) # input data
train_y = train_df['class_t'] # output data

In [None]:
print("Input shape:", train_x.shape)   # (n_sources, n_features)
print("Output shape:", train_y.shape)  # (n_sources,)

In this notebook, we will use a random forest classifier to distinguish the three classes above.

In [None]:
from sklearn.ensemble import RandomForestClassifier  # the ML model
from sklearn.metrics import classification_report  # per-class precision, recall and F1 score


In [None]:
forest = RandomForestClassifier(random_state=0)
forest.fit(train_x, train_y) # train the ML model on the training data
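Once fitted, a random forest exposes `feature_importances_`, which gives a first idea of which PyBDSF columns drive the classification (useful later for feature reduction). The helper `ranked_importances` is a hypothetical name, and the demo trains on synthetic data rather than `train_x`/`train_y`. Note that these impurity-based importances can favour high-cardinality features; `sklearn.inspection.permutation_importance` is a common alternative.

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier


def ranked_importances(model, feature_names):
    """Impurity-based importances of a fitted forest, largest first."""
    return pd.Series(model.feature_importances_, index=feature_names).sort_values(ascending=False)


# Synthetic stand-in for train_x / train_y
X_demo, y_demo = make_classification(n_samples=200, n_features=6, n_informative=3,
                                     n_redundant=0, n_classes=3, random_state=0)
demo_names = [f"feature_{i}" for i in range(X_demo.shape[1])]
demo_forest = RandomForestClassifier(random_state=0).fit(X_demo, y_demo)
demo_ranking = ranked_importances(demo_forest, demo_names)
```

On the real data, `ranked_importances(forest, train_x.columns).head(10)` would list the ten most important features.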

We can first assess the model against the training data:

In [None]:
y_pred = forest.predict(train_x)

In [None]:
print(classification_report(train_y, y_pred))

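The scores should be at or very close to 100%. This is expected, since the model has already seen these data, so it tells us little about how well the model generalises. Next, we apply the model to the full source list and score it against the whole truth catalogue, excluding the training data from the evaluation.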
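A more honest estimate, before touching the test set, comes from holding out part of the training data. The sketch below is one way to do this, not part of the tutorial's pipeline: it uses a stratified split (keeping class proportions the same in both parts, which matters given the SFG majority) and macro-averaged F1, which weights every class equally regardless of size. The function name and the 25% hold-out fraction are illustrative choices, and the demo uses synthetic data.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split


def train_vs_holdout_f1(X, y, test_size=0.25, seed=0):
    """Macro-F1 on the training split and on a stratified hold-out split."""
    X_tr, X_val, y_tr, y_val = train_test_split(
        X, y, test_size=test_size, stratify=y, random_state=seed
    )
    model = RandomForestClassifier(random_state=seed).fit(X_tr, y_tr)
    train_f1 = f1_score(y_tr, model.predict(X_tr), average="macro")
    holdout_f1 = f1_score(y_val, model.predict(X_val), average="macro")
    return train_f1, holdout_f1


# Synthetic stand-in for train_x / train_y
X_demo, y_demo = make_classification(n_samples=400, n_features=10, n_informative=4,
                                     n_classes=3, random_state=1)
demo_train_f1, demo_holdout_f1 = train_vs_holdout_f1(X_demo, y_demo)
```

On the real data, `train_vs_holdout_f1(train_x, train_y)` would show the gap between the near-perfect training score and the hold-out score.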

**Exercise 2:** Train a model for each of the other two frequencies on its own training data.
<br>

---

### Testing

In [None]:
sources_full[1400]["class"] = model_pre.test(forest, sources_full[1400]) # add the predicted output to the source df

In [None]:
sources_full[1400]

In [None]:
from source.path import full_truth_path, write_df_to_disk, submission_df_path, score_report_path, image_path
from ska_sdc import Sdc1Scorer
import os
from pathlib import Path


In [None]:
# Create the final catalogues and calculate the scores
print("\nCalculating final scores")
for freq, source_df in sources_full.items():
    # Assemble submission and truth catalogues for scoring
    sub_cat_df = cat_df_from_srl_df(source_df, guess_class=False)
    truth_cat_df = load_truth_df(full_truth_path(freq), skiprows=0)
    
    # (Optional) Write final submission catalogue to disk
    write_df_to_disk(sub_cat_df, submission_df_path(freq))

    # Calculate score
    scorer = Sdc1Scorer(sub_cat_df, truth_cat_df, freq)
    score = scorer.run(mode=0, train=False, detail=True) # train=False -> means that we are removing the training data from the evaluation, so there is no data leakage

    # Write short score report:
    score_path = score_report_path(freq)
    score_dir = os.path.dirname(score_path)
    Path(score_dir).mkdir(parents=True, exist_ok=True)

    with open(score_path, "w+") as report:
        report.write(
            "Image: {}, frequency: {} MHz\n".format(image_path(freq), freq)
        )
        report.write("Score was {}\n".format(score.value))
        report.write("Number of detections {}\n".format(score.n_det))
        report.write("Number of matches {}\n".format(score.n_match))
        report.write(
            "Number of matches too far from truth {}\n".format(score.n_bad)
        )
        report.write("Number of false detections {}\n".format(score.n_false))
        report.write("Score for all matches {}\n".format(score.score_det))
        report.write("Accuracy percentage {}\n".format(score.acc_pc))
        report.write("Classification report: \n")
        report.write(
            classification_report(
                score.match_df["class_t"],
                score.match_df["class"],
                labels=[1, 2, 3],
                target_names=["1 (SS-AGN)", "2 (FS-AGN)", "3 (SFG)"],
                digits=4,
            )
        )

print("\nComplete")

Your results can be found in:
`data/score/1400mhz_score.txt`

---

## Now it is your turn
- On the other two frequencies, show us:
    - the images before and after clipping
    - the images before and after primary-beam (PB) correction
    - the training image
- On all three training data frames (one per frequency):
    - Try to improve or extend the pipeline introduced above on the training data. For example, you could try [feature reduction](https://scikit-learn.org/stable/modules/feature_selection.html).
    - Try more [ML models](https://scikit-learn.org/stable/supervised_learning.html).
- Compare the results across all three frequencies.
- Present your results.
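As a possible starting point for the feature-reduction and model exercises, the sketch below compares a few scikit-learn classifiers with stratified cross-validation, each preceded by univariate feature selection (`SelectKBest`). The candidate models, `k_features`, the number of folds and the macro-F1 scoring are illustrative assumptions, not part of the tutorial's pipeline; the demo uses synthetic data in place of the training data frames.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


def compare_models(X, y, k_features=10, cv=5, seed=0):
    """Cross-validated macro-F1 (mean, std) for each candidate model."""
    candidates = {
        "random_forest": RandomForestClassifier(random_state=seed),
        # Linear models benefit from standardised features
        "logistic_regression": make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000)),
    }
    splitter = StratifiedKFold(n_splits=cv, shuffle=True, random_state=seed)
    results = {}
    for name, estimator in candidates.items():
        # Feature selection sits inside the pipeline, so it is refit on each training fold
        pipe = make_pipeline(SelectKBest(f_classif, k=min(k_features, X.shape[1])), estimator)
        scores = cross_val_score(pipe, X, y, cv=splitter, scoring="f1_macro")
        results[name] = (scores.mean(), scores.std())
    return results


# Synthetic stand-in for one frequency's train_x / train_y
X_demo, y_demo = make_classification(n_samples=300, n_features=20, n_informative=6,
                                     n_classes=3, random_state=0)
demo_results = compare_models(X_demo, y_demo, k_features=8, cv=3)
```

Keeping the selector inside the pipeline avoids leaking information from the validation folds into the feature choice; on the real data, call `compare_models(train_x, train_y)` once per frequency and compare the scores.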