Skip to content

Commit

Permalink
Loads the CSV just once
Browse files Browse the repository at this point in the history
  • Loading branch information
cuducos committed May 9, 2024
1 parent 0fbb5d4 commit c6fdada
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 35 deletions.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ crate-type = ["cdylib"]
anyhow = "1.0.83"
chrono = "0.4.38"
csv = "1.3.0"
lazy_static = "1.4.0"
pyo3 = { version = "0.18", features = ["abi3-py311"] }
rand = "0.8.5"
rayon = "1.10.0"
Expand Down
11 changes: 7 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use pyo3::{
exceptions::PyValueError, pyfunction, pymodule, types::PyModule, wrap_pyfunction, PyResult,
Python,
};
use rand::Rng;
use walkdir::WalkDir;

mod correlation;
Expand Down Expand Up @@ -40,13 +41,15 @@ fn recommendations_for(name: String) -> PyResult<Vec<(whisky::PyWhisky, whisky::
}

#[pyfunction]
fn random_whisky() -> PyResult<whisky::PyWhisky> {
whisky::random_whisky().map_err(|e| PyValueError::new_err(e.to_string()))
// Returns every entry of `whisky::WHISKIES`, each converted with `py()`.
fn all_whiskies() -> PyResult<Vec<whisky::PyWhisky>> {
    let mut converted = Vec::with_capacity(whisky::WHISKIES.len());
    for entry in whisky::WHISKIES.iter() {
        converted.push(entry.py());
    }
    Ok(converted)
}

#[pyfunction]
fn all_whiskies() -> PyResult<Vec<whisky::PyWhisky>> {
whisky::all_whiskies().map_err(|e| PyValueError::new_err(e.to_string()))
fn random_whisky() -> PyResult<whisky::PyWhisky> {
let mut rng = rand::thread_rng();
let idx = rng.gen_range(0..whisky::WHISKIES.len());
Ok(whisky::WHISKIES[idx].py())
}

#[pymodule]
Expand Down
61 changes: 30 additions & 31 deletions src/whisky.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,27 @@ use std::cmp::Ordering;

use anyhow::{anyhow, Result};
use csv::{ReaderBuilder, StringRecord};
use rand::seq::SliceRandom;
use rand::thread_rng;
use rayon::iter::{ParallelIterator, IntoParallelRefIterator};
use lazy_static::lazy_static;
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use regex::Regex;

use crate::correlation::Correlation;

const DATA: &str = include_str!("data/whisky.csv");

fn load_whiskies() -> Result<Vec<Whisky>> {
let mut reader = ReaderBuilder::new()
.has_headers(true)
.from_reader(DATA.as_bytes());
let mut whiskies: Vec<Whisky> = vec![];
for row in reader.records() {
let record = Whisky::from_csv_row(&row?)?;
whiskies.push(record);
}
Ok(whiskies)
lazy_static! {
    /// Every whisky parsed from the CSV text embedded in `DATA`.
    ///
    /// The CSV is read with a header row and each record is converted with
    /// `Whisky::from_csv_row`.
    ///
    /// # Panics
    ///
    /// Panics during initialization if a record cannot be read or converted.
    /// The data is compiled into the binary, so such a failure is a bug in the
    /// bundled file rather than a recoverable runtime condition.
    pub static ref WHISKIES: Vec<Whisky> = {
        let mut reader = ReaderBuilder::new()
            .has_headers(true)
            .from_reader(DATA.as_bytes());

        reader
            .records()
            .map(|row| {
                let row = row.expect("embedded whisky.csv contains a malformed CSV record");
                Whisky::from_csv_row(&row)
                    .expect("embedded whisky.csv contains a row that is not a valid whisky")
            })
            .collect()
    };
}

fn distillery(row: &StringRecord) -> Result<String> {
Expand Down Expand Up @@ -100,24 +102,10 @@ impl Whisky {
}
}

/// Loads the whiskies and returns one chosen at random, converted with `py()`.
///
/// # Errors
///
/// Propagates any error from `load_whiskies`, and returns "No whiskies found"
/// when the loaded list is empty.
pub fn random_whisky() -> Result<PyWhisky> {
    let whiskies = load_whiskies()?;
    let mut rng = thread_rng();
    whiskies
        .choose(&mut rng)
        .map(|w| w.py())
        // Build the error lazily so nothing is constructed on the success path.
        .ok_or_else(|| anyhow!("No whiskies found"))
}

/// Loads the whiskies and returns all of them, each converted with `py()`.
///
/// # Errors
///
/// Propagates any error from `load_whiskies`.
pub fn all_whiskies() -> Result<Vec<PyWhisky>> {
    let loaded = load_whiskies()?;
    let mut converted = Vec::with_capacity(loaded.len());
    for entry in &loaded {
        converted.push(entry.py());
    }
    Ok(converted)
}

pub fn recommendations_for(name: String) -> Result<Vec<(PyWhisky, PyWhisky, f64)>> {
let mut whisky: Option<Whisky> = None;
let mut others: Vec<Whisky> = vec![];
for w in load_whiskies()? {
let mut whisky: Option<&Whisky> = None;
let mut others: Vec<&Whisky> = vec![];
for w in WHISKIES.iter() {
if w.distillery == name || w.slug() == name {
whisky = Some(w);
} else {
Expand All @@ -127,7 +115,7 @@ pub fn recommendations_for(name: String) -> Result<Vec<(PyWhisky, PyWhisky, f64)
let reference = whisky.ok_or(anyhow!("Whisky {} not found", name))?;
let mut correlations: Vec<Correlation> = others
.par_iter()
.map(|w| Correlation::new(&reference, w))
.map(|w| Correlation::new(reference, w))
.collect();

correlations.sort_by(|a, b| {
Expand All @@ -148,3 +136,14 @@ pub fn recommendations_for(name: String) -> Result<Vec<(PyWhisky, PyWhisky, f64)

Ok(best)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `WHISKIES` holds exactly 86 entries.
    #[test]
    fn test_all_whiskies_are_loaded() {
        let loaded = WHISKIES.len();
        assert_eq!(loaded, 86);
    }
}

0 comments on commit c6fdada

Please sign in to comment.