In [2]:
:dep polars = { version = "0.35.4", features = ["describe", "to_dummies", "ndarray", "random"] }
:dep polars-core
:dep candle-core

In [3]:
use polars::prelude::*;
use polars_core::prelude::*;
use candle_core::{Device, Tensor, Result};
use polars::frame::DataFrame;
use std::path::Path;

In [4]:
/// Loads the CSV file at `csv_file_path` into a [`DataFrame`], with the
/// reader configured to expect a header row.
///
/// # Panics
///
/// Panics with "Cannot open file." if the reader cannot be created for the
/// given path, and panics (via `unwrap`) if `finish()` returns an error.
fn read_data_frame_from_csv(
    csv_file_path: &Path,
) -> DataFrame {
    let reader = CsvReader::from_path(csv_file_path).expect("Cannot open file.");
    let parsed = reader.has_header(true).finish();
    parsed.unwrap()
}

In [5]:

// Load the Iris dataset from `Iris.csv` (path is relative to the kernel's
// working directory). `iris_df` is declared `mut` because a later cell
// modifies it in place (`hstack_mut` / `drop_in_place`).
let iris_file_path: &Path = Path::new("Iris.csv");
let mut iris_df: DataFrame = read_data_frame_from_csv(iris_file_path);

In [6]:
// Display per-column summary statistics (count, null_count, mean, std, ...).
// `None` presumably selects the default percentiles — confirm against the
// polars `describe` documentation.
iris_df.describe(None)?

shape: (9, 7)
┌────────────┬───────────┬──────────────┬──────────────┬──────────────┬──────────────┬─────────────┐
│ describe   ┆ Id        ┆ SepalLengthC ┆ SepalWidthCm ┆ PetalLengthC ┆ PetalWidthCm ┆ Species     │
│ ---        ┆ ---       ┆ m            ┆ ---          ┆ m            ┆ ---          ┆ ---         │
│ str        ┆ f64       ┆ ---          ┆ f64          ┆ ---          ┆ f64          ┆ str         │
│            ┆           ┆ f64          ┆              ┆ f64          ┆              ┆             │
╞════════════╪═══════════╪══════════════╪══════════════╪══════════════╪══════════════╪═════════════╡
│ count      ┆ 150.0     ┆ 150.0        ┆ 150.0        ┆ 150.0        ┆ 150.0        ┆ 150         │
│ null_count ┆ 0.0       ┆ 0.0          ┆ 0.0          ┆ 0.0          ┆ 0.0          ┆ 0           │
│ mean       ┆ 75.5      ┆ 5.843333     ┆ 3.054        ┆ 3.758667     ┆ 1.198667     ┆ null        │
│ std        ┆ 43.445368 ┆ 0.828066     ┆ 0.433594     ┆ 1.76442      ┆ 0.763

In [7]:
// Preview the first five rows of the freshly loaded data.
iris_df.head(Some(5))

shape: (5, 6)
┌─────┬───────────────┬──────────────┬───────────────┬──────────────┬─────────────┐
│ Id  ┆ SepalLengthCm ┆ SepalWidthCm ┆ PetalLengthCm ┆ PetalWidthCm ┆ Species     │
│ --- ┆ ---           ┆ ---          ┆ ---           ┆ ---          ┆ ---         │
│ i64 ┆ f64           ┆ f64          ┆ f64           ┆ f64          ┆ str         │
╞═════╪═══════════════╪══════════════╪═══════════════╪══════════════╪═════════════╡
│ 1   ┆ 5.1           ┆ 3.5          ┆ 1.4           ┆ 0.2          ┆ Iris-setosa │
│ 2   ┆ 4.9           ┆ 3.0          ┆ 1.4           ┆ 0.2          ┆ Iris-setosa │
│ 3   ┆ 4.7           ┆ 3.2          ┆ 1.3           ┆ 0.2          ┆ Iris-setosa │
│ 4   ┆ 4.6           ┆ 3.1          ┆ 1.5           ┆ 0.2          ┆ Iris-setosa │
│ 5   ┆ 5.0           ┆ 3.6          ┆ 1.4           ┆ 0.2          ┆ Iris-setosa │
└─────┴───────────────┴──────────────┴───────────────┴──────────────┴─────────────┘

In [8]:
iris_df
.hstack_mut(
    iris_df["Species"]
    .to_dummies(None, false)?
    .get_columns()
)?
.drop_in_place("Species")?;

In [9]:
// Preview the first five rows again; at this point `Species` has been
// replaced by the `Species_*` indicator columns.
iris_df.head(Some(5))

shape: (5, 8)
┌─────┬─────────────┬─────────────┬────────────┬────────────┬────────────┬────────────┬────────────┐
│ Id  ┆ SepalLength ┆ SepalWidthC ┆ PetalLengt ┆ PetalWidth ┆ Species_Ir ┆ Species_Ir ┆ Species_Ir │
│ --- ┆ Cm          ┆ m           ┆ hCm        ┆ Cm         ┆ is-setosa  ┆ is-versico ┆ is-virgini │
│ i64 ┆ ---         ┆ ---         ┆ ---        ┆ ---        ┆ ---        ┆ lor        ┆ ca         │
│     ┆ f64         ┆ f64         ┆ f64        ┆ f64        ┆ i32        ┆ ---        ┆ ---        │
│     ┆             ┆             ┆            ┆            ┆            ┆ i32        ┆ i32        │
╞═════╪═════════════╪═════════════╪════════════╪════════════╪════════════╪════════════╪════════════╡
│ 1   ┆ 5.1         ┆ 3.5         ┆ 1.4        ┆ 0.2        ┆ 1          ┆ 0          ┆ 0          │
│ 2   ┆ 4.9         ┆ 3.0         ┆ 1.4        ┆ 0.2        ┆ 1          ┆ 0          ┆ 0          │
│ 3   ┆ 4.7         ┆ 3.2         ┆ 1.3        ┆ 0.2        ┆ 1          ┆ 0 

In [22]:
// Sample a fraction of 1.0 of the rows, i.e. keep every row but in a new
// order. Positional arguments per the polars `sample_frac` signature:
// with_replacement = false, shuffle = true, seed = Some(42) (fixed seed so
// the order is reproducible). The new binding shadows the earlier mutable
// `iris_df` with an immutable one.
let iris_df = iris_df.sample_frac(&Series::new("frac", &[1.0]), false, true, Some(42))?; // shuffle the data

In [23]:
// Display the shuffled frame.
iris_df

shape: (150, 8)
┌─────┬─────────────┬─────────────┬────────────┬────────────┬────────────┬────────────┬────────────┐
│ Id  ┆ SepalLength ┆ SepalWidthC ┆ PetalLengt ┆ PetalWidth ┆ Species_Ir ┆ Species_Ir ┆ Species_Ir │
│ --- ┆ Cm          ┆ m           ┆ hCm        ┆ Cm         ┆ is-setosa  ┆ is-versico ┆ is-virgini │
│ i64 ┆ ---         ┆ ---         ┆ ---        ┆ ---        ┆ ---        ┆ lor        ┆ ca         │
│     ┆ f64         ┆ f64         ┆ f64        ┆ f64        ┆ i32        ┆ ---        ┆ ---        │
│     ┆             ┆             ┆            ┆            ┆            ┆ i32        ┆ i32        │
╞═════╪═════════════╪═════════════╪════════════╪════════════╪════════════╪════════════╪════════════╡
│ 75  ┆ 6.4         ┆ 2.9         ┆ 4.3        ┆ 1.3        ┆ 0          ┆ 1          ┆ 0          │
│ 10  ┆ 4.9         ┆ 3.1         ┆ 1.5        ┆ 0.1        ┆ 1          ┆ 0          ┆ 0          │
│ 116 ┆ 6.4         ┆ 3.2         ┆ 5.3        ┆ 2.3        ┆ 0          ┆ 

In [48]:
// Fraction of rows assigned to the training split.
const TRAIN_FRAC: f64 = 0.8;
let n_rows = iris_df.height();
// The `as usize` cast truncates toward zero, so a fractional product is
// rounded down.
let n_train_examples = (TRAIN_FRAC * n_rows as f64) as usize;


In [49]:
// Rows [0, n_train_examples) form the training split.
let df_train = iris_df.slice(0, n_train_examples);
// The remaining rows form the test split. The length argument is the number
// of rows left after the training split, not the total row count, so the
// call no longer depends on `slice` clamping an over-long length.
let df_test = iris_df.slice(n_train_examples as i64, n_rows.saturating_sub(n_train_examples));

In [50]:
// Show the last nine rows of the training split.
df_train.tail(Some(9))

shape: (9, 8)
┌─────┬─────────────┬─────────────┬────────────┬────────────┬────────────┬────────────┬────────────┐
│ Id  ┆ SepalLength ┆ SepalWidthC ┆ PetalLengt ┆ PetalWidth ┆ Species_Ir ┆ Species_Ir ┆ Species_Ir │
│ --- ┆ Cm          ┆ m           ┆ hCm        ┆ Cm         ┆ is-setosa  ┆ is-versico ┆ is-virgini │
│ i64 ┆ ---         ┆ ---         ┆ ---        ┆ ---        ┆ ---        ┆ lor        ┆ ca         │
│     ┆ f64         ┆ f64         ┆ f64        ┆ f64        ┆ i32        ┆ ---        ┆ ---        │
│     ┆             ┆             ┆            ┆            ┆            ┆ i32        ┆ i32        │
╞═════╪═════════════╪═════════════╪════════════╪════════════╪════════════╪════════════╪════════════╡
│ 50  ┆ 5.0         ┆ 3.3         ┆ 1.4        ┆ 0.2        ┆ 1          ┆ 0          ┆ 0          │
│ 51  ┆ 7.0         ┆ 3.2         ┆ 4.7        ┆ 1.4        ┆ 0          ┆ 1          ┆ 0          │
│ 42  ┆ 4.5         ┆ 2.3         ┆ 1.3        ┆ 0.3        ┆ 1          ┆ 0 

In [51]:
// Show the first nine rows of the test split (presumably to eyeball the
// boundary against the tail of the training split).
df_test.head(Some(9))

shape: (9, 8)
┌─────┬─────────────┬─────────────┬────────────┬────────────┬────────────┬────────────┬────────────┐
│ Id  ┆ SepalLength ┆ SepalWidthC ┆ PetalLengt ┆ PetalWidth ┆ Species_Ir ┆ Species_Ir ┆ Species_Ir │
│ --- ┆ Cm          ┆ m           ┆ hCm        ┆ Cm         ┆ is-setosa  ┆ is-versico ┆ is-virgini │
│ i64 ┆ ---         ┆ ---         ┆ ---        ┆ ---        ┆ ---        ┆ lor        ┆ ca         │
│     ┆ f64         ┆ f64         ┆ f64        ┆ f64        ┆ i32        ┆ ---        ┆ ---        │
│     ┆             ┆             ┆            ┆            ┆            ┆ i32        ┆ i32        │
╞═════╪═════════════╪═════════════╪════════════╪════════════╪════════════╪════════════╪════════════╡
│ 27  ┆ 5.0         ┆ 3.4         ┆ 1.6        ┆ 0.4        ┆ 1          ┆ 0          ┆ 0          │
│ 84  ┆ 6.0         ┆ 2.7         ┆ 5.1        ┆ 1.6        ┆ 0          ┆ 1          ┆ 0          │
│ 30  ┆ 4.7         ┆ 3.2         ┆ 1.6        ┆ 0.2        ┆ 1          ┆ 0 

In [60]:
// Show the first nine rows of the training split.
df_train.head(Some(9))

shape: (9, 8)
┌─────┬─────────────┬─────────────┬────────────┬────────────┬────────────┬────────────┬────────────┐
│ Id  ┆ SepalLength ┆ SepalWidthC ┆ PetalLengt ┆ PetalWidth ┆ Species_Ir ┆ Species_Ir ┆ Species_Ir │
│ --- ┆ Cm          ┆ m           ┆ hCm        ┆ Cm         ┆ is-setosa  ┆ is-versico ┆ is-virgini │
│ i64 ┆ ---         ┆ ---         ┆ ---        ┆ ---        ┆ ---        ┆ lor        ┆ ca         │
│     ┆ f64         ┆ f64         ┆ f64        ┆ f64        ┆ i32        ┆ ---        ┆ ---        │
│     ┆             ┆             ┆            ┆            ┆            ┆ i32        ┆ i32        │
╞═════╪═════════════╪═════════════╪════════════╪════════════╪════════════╪════════════╪════════════╡
│ 75  ┆ 6.4         ┆ 2.9         ┆ 4.3        ┆ 1.3        ┆ 0          ┆ 1          ┆ 0          │
│ 10  ┆ 4.9         ┆ 3.1         ┆ 1.5        ┆ 0.1        ┆ 1          ┆ 0          ┆ 0          │
│ 116 ┆ 6.4         ┆ 3.2         ┆ 5.3        ┆ 2.3        ┆ 0          ┆ 0 

In [63]:
// Column groups shared by the train and test conversions below.
let feature_columns = ["SepalLengthCm", "SepalWidthCm", "PetalLengthCm", "PetalWidthCm"];
let label_columns = ["Species_Iris-setosa", "Species_Iris-versicolor", "Species_Iris-virginica"];

// Training split: measurements -> X_train, one-hot species columns -> y_train.
let train_features = df_train.select(feature_columns)?;
let X_train = dataframe_to_tensor(&train_features)?;
let train_labels = df_train.select(label_columns)?;
let y_train = dataframe_to_tensor(&train_labels)?;

// Test split, converted the same way.
let test_features = df_test.select(feature_columns)?;
let X_test = dataframe_to_tensor(&test_features)?;
let test_labels = df_test.select(label_columns)?;
let y_test = dataframe_to_tensor(&test_labels)?;

In [65]:
// Sanity-check the tensor shapes of the four splits.
(X_train.shape(), y_train.shape(), X_test.shape(), y_test.shape())

([120, 4], [120, 3], [30, 4], [30, 3])

In [None]:
struct DataSet {
    pub X_train: Tensor,
    pub y_test: Tensor,
    pub X_test: Tensor,
    pub y_train: Tensor,
}

impl DataSet {
    pub fn new(X_train: Tensor, y_train: Tensor, X_test: Tensor, y_test: Tensor) -> Self {
        Self {
            X_train,
            y_train,
            X_test,
            y_test,
        }
    }

    pub fn from_df(df: &DataFrame, train_frac: f64, input_columns: &[&str], predict_columns: &[&str], seed: Option<u64>) -> Result<Self> {
        let shuffled_df = df.sample_frac(&Series::new("frac", &[1.0]), false, true, Some(42))?;  // shuffle the dataframe
        
        let n_rows = df.height();
        let n_train_examples = (train_frac * n_rows as f64) as usize;

        let df_train = shuffled_df.slice(0, n_train_examples);
        let df_test = shuffled_df.slice(n_train_examples as i64, n_rows);

        let X_train = dataframe_to_tensor(&df_train.select(input_columns)?)?;
        let y_train = dataframe_to_tensor(&df_train.select(predict_columns)?)?;
        let X_test = dataframe_to_tensor(&df_test.select(input_columns)?)?;
        let y_test = dataframe_to_tensor(&df_test.select(predict_columns)?)?;

        Ok(Self::new(X_train, y_train, X_train, y_train))
    }
}

In [None]:
// NOTE(review): `X` is not defined in any cell visible in this export and
// this cell has no execution count — presumably a leftover from an earlier
// draft; confirm before running.
X.head(Some(5))

In [52]:
fn dataframe_to_tensor(df: &DataFrame) -> Result<Tensor> {
    let n_rows = df.height();
    let n_cols = df.width();

    // Collect DataFrame values into a Vec of Vecs
    let mut values = Vec::with_capacity(n_cols);
    for col in df.iter() {
        let col_vec: Vec<f64> = col
            .iter()
            .map(|val| val.extract::<f64>().unwrap())
            .collect();
        values.push(col_vec);
    };

    // Create Tensor from flattened Vec of Vecs
    Ok(Tensor::from_vec(
        values.into_iter().flatten().collect(),
        (n_cols, n_rows),
        &Device::Cpu,
    )
    .expect("error")
    .t()
    .expect("error"))
}

In [None]:
// NOTE(review): `X` and `y` are not defined in any cell visible in this
// export and this cell has no execution count — presumably superseded by the
// `X_train` / `y_train` / `X_test` / `y_test` conversion cell; confirm.
let X_tensor = dataframe_to_tensor(&X)?;
let y_tensor = dataframe_to_tensor(&y)?;

In [None]:
// NOTE(review): depends on `X` and `y`, which are not defined in any visible
// cell; this cell has no execution count.
(X.shape(), y.shape())

In [None]:
/// Container for the four tensors of a train/test split.
///
/// NOTE(review): a `DataSet` struct (capital `S`) with the same four fields
/// plus constructors is defined in another cell of this notebook; this
/// variant has private fields and no methods — presumably an earlier draft.
/// Confirm which of the two should be kept.
struct Dataset {
    // Inputs of the training split.
    X_train: Tensor,
    // Targets of the training split.
    y_train: Tensor,
    // Inputs of the test split.
    X_test: Tensor,
    // Targets of the test split.
    y_test: Tensor,
}

In [None]:
// NOTE(review): `X` is not defined in any visible cell; this cell has no
// execution count.
X