In [None]:
# default_exp core

# Core library

> Helper functions used throughout the lessons

In [None]:
# export
import pandas as pd
from nbdev.showdoc import *
import os
import gdown
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

## Datasets

In [None]:
# export
def download_dataset(dataset_name: str, path: str = "../data/"):
    """Download one of the course datasets from Google Drive.

    The file is saved as ``os.path.join(path, dataset_name)``. If a file with
    that name already exists, it is left untouched and nothing is downloaded.

    Args:
        dataset_name: File name of the dataset, e.g. ``"susy.feather"``. Must be
            one of the keys of ``name_to_id`` below.
        path: Directory to save the dataset in. It is created if missing.
            Defaults to ``"../data/"``.

    Raises:
        KeyError: If ``dataset_name`` is not one of the known Google Drive files.
        Exception: Any error raised by ``gdown.download`` is re-raised after
            printing a short message.
    """

    # Maps each downloadable file name to its Google Drive file ID.
    name_to_id = {
        "susy.csv.gz": "1rnR1v-BkMOtzV80R7jFyU1cwO3fGYrQs",
        "susy.feather": "1PxCruwO42GV7FKtwZDXah7iGjDib7YPM",
        "susy_train.feather": "1ezeCZycZ3BrEh-qOLiSJF40YowYEbbTH",
        "susy_test.feather": "1UM8sheb4jzQa16haG6HnVbpJCxZwN2yE",
        "susy_sample.feather": "1l4x_uBeup4eciLDK4YjnfY_G8yTpXLkP",
    }

    # Validate the name before touching the filesystem, so that an unknown
    # dataset does not leave an empty data directory behind.
    if dataset_name not in name_to_id:
        raise KeyError("File not on Google Drive.")

    os.makedirs(path, exist_ok=True)
    output = os.path.join(path, dataset_name)
    if os.path.exists(output):
        print(f"Dataset already exists at '{output}' and is not downloaded again.")
        return

    file_url = "https://drive.google.com/uc?id=" + name_to_id[dataset_name]
    try:
        gdown.download(file_url, output, quiet=True)
    except Exception:
        print("Something went wrong during the download! Try again.")
        # A bare `raise` re-raises the original exception unchanged.
        raise
    print(f"Download of {dataset_name} dataset complete.")

### SUSY

The SUSY dataset from the [UCI Machine Learning Repository](http://archive.ics.uci.edu/ml/datasets/SUSY), as a gzipped CSV file:

In [None]:
download_dataset(dataset_name="susy.csv.gz")  # raw gzipped CSV from UCI

Download of susy.csv.gz dataset complete.


A compressed version in [feather format](https://blog.rstudio.com/2016/03/29/feather/) is also available; it loads much faster during class:

In [None]:
download_dataset(dataset_name="susy.feather")  # same data, feather format

Download of susy.feather dataset complete.


To get the training (first 4,500,000 rows) and test (last 500,000 rows) sets, run:

In [None]:
# Train/test split files
download_dataset(dataset_name="susy_train.feather")
download_dataset(dataset_name="susy_test.feather")

Download of susy_train.feather dataset complete.
Download of susy_test.feather dataset complete.


To get a random sample of 100,000 rows from `susy_train`, run:

In [None]:
download_dataset(dataset_name="susy_sample.feather")  # random sample of the training set

Download of susy_sample.feather dataset complete.


## Data wrangling

In [None]:
# export
def display_large(df, max_rows: int = 1000, max_columns: int = 1000):
    """Display a pandas.DataFrame or pandas.Series with raised display limits.

    By default, up to 1000 rows and 1000 columns are shown. The limits are set
    through ``pd.option_context``, so pandas' global display options are
    restored as soon as the call returns.

    Args:
        df: The pandas.DataFrame or pandas.Series to show.
        max_rows: Value of pandas' ``display.max_rows`` option during the call.
        max_columns: Value of pandas' ``display.max_columns`` option during the call.
    """
    # NOTE(review): `display` is not imported explicitly in this module. It
    # presumably resolves to IPython's `display` when run inside Jupyter/IPython
    # (or via the nbdev.showdoc star-import). Confirm before using outside a notebook.
    with pd.option_context(
        "display.max_rows", max_rows, "display.max_columns", max_columns
    ):
        display(df)

In [None]:
# export
def rf_feature_importance(fitted_model, df):
    """Build a pandas.DataFrame of a fitted Random Forest's feature importances.

    Each row pairs one column name of ``df`` with the matching entry of
    ``fitted_model.feature_importances_``. Rows are sorted from most to least
    important, and the original positional index is kept.
    """
    # Assumes feature_importances_ is ordered like df.columns (i.e. df has the
    # same columns the model was fitted on) — the pairing is purely positional.
    importance_table = pd.DataFrame(
        data={
            "Column": df.columns,
            "Importance": fitted_model.feature_importances_,
        }
    )
    return importance_table.sort_values(by="Importance", ascending=False)

## Data visualisation

In [None]:
# export
def plot_feature_importance(feature_importance, figsize=(12, 8)):
    """Draw a horizontal bar chart of feature importances on a new figure.

    Args:
        feature_importance: pandas.DataFrame with "Column" and "Importance"
            columns, e.g. the output of ``rf_feature_importance``.
        figsize: Figure size in inches, passed to ``plt.subplots``.
            Defaults to ``(12, 8)``.

    Returns:
        The matplotlib Axes returned by ``sns.barplot``.
    """
    _, ax = plt.subplots(figsize=figsize)
    # Pass the Axes explicitly rather than relying on seaborn to pick up
    # pyplot's implicit "current" Axes.
    return sns.barplot(
        y="Column", x="Importance", data=feature_importance, color="b", ax=ax
    )