Skip to content

Commit

Permalink
classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
franzbischoff committed Jun 3, 2023
1 parent 1304029 commit c13a091
Show file tree
Hide file tree
Showing 10 changed files with 444 additions and 160 deletions.
2 changes: 1 addition & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ export(mpx_stream_start)
export(muinvn)
export(normalize)
export(num_shapelets_par)
export(redundance_par)
export(redundancy_par)
export(regime_landmark_par)
export(regime_threshold_par)
export(scrimp)
Expand Down
6 changes: 3 additions & 3 deletions R/contrast_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ register_contrast_model <- function() {
parsnip::set_model_arg(
model = "contrast_model",
eng = "contrast_profile",
parsnip = "redundance",
original = "redundance",
func = list(fun = "redundance_par"),
parsnip = "redundancy",
original = "redundancy",
func = list(fun = "redundancy_par"),
has_submodel = FALSE
)
}
18 changes: 9 additions & 9 deletions R/contrast_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
contrast_model <- function(
mode = "classification",
num_shapelets = NULL,
redundance = NULL,
redundancy = NULL,
engine = "contrast_profile") {
# Check for correct mode
if (mode != "classification") {
Expand All @@ -20,7 +20,7 @@ contrast_model <- function(
# Capture the arguments in quosures
args <- list(
num_shapelets = rlang::enquo(num_shapelets),
redundance = rlang::enquo(redundance)
redundancy = rlang::enquo(redundancy)
)

# Save some empty slots for future parts of the specification
Expand Down Expand Up @@ -52,7 +52,7 @@ print.contrast_model <- function(x, ...) { # nolint
update.contrast_model <- function(object,
parameters = NULL,
num_shapelets = NULL,
redundance = NULL,
redundancy = NULL,
fresh = FALSE, ...) {
# nolint
# cli::cli_inform(c("*" = "update.contrast_model"))
Expand All @@ -63,7 +63,7 @@ update.contrast_model <- function(object,

args <- list(
num_shapelets = rlang::enquo(num_shapelets),
redundance = rlang::enquo(redundance)
redundancy = rlang::enquo(redundancy)
)

# function currently not exported by parsnip
Expand Down Expand Up @@ -195,7 +195,7 @@ check_args.contrast_model <- function(object) {
.check_contrast_profile_fit <- function(x) {
# cli::cli_inform(c("*" = ".check_contrast_profile_fit"))
num_shapelets <- rlang::eval_tidy(x$args$num_shapelets)
redundance <- rlang::eval_tidy(x$args$redundance)
redundancy <- rlang::eval_tidy(x$args$redundancy)

if (length(num_shapelets) != 1L) {
rlang::abort(c(
Expand All @@ -206,12 +206,12 @@ check_args.contrast_model <- function(object) {
))
}

if (length(redundance) != 1L) {
if (length(redundancy) != 1L) {
rlang::abort(c(
"For the contrast_profile engine, `redundance` must be a single number (or a value of `tune()`).",
glue::glue("There are {length(redundance)} values for `redundance`."),
"For the contrast_profile engine, `redundancy` must be a single number (or a value of `tune()`).",
glue::glue("There are {length(redundancy)} values for `redundancy`."),
"To try multiple values for total regularization, use the tune package.",
"To predict multiple redundance, use `multi_predict()`"
"To predict multiple redundancy, use `multi_predict()`"
))
}
}
Expand Down
88 changes: 88 additions & 0 deletions R/contrast_train.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,94 @@ contrast_train_model <- function(truth, ts, ..., window_size, regime_threshold,
return(trained)
}

#' Evaluate every top-k plato on the "true" and "false" data
#'
#' For each window size stored in `contrast_profiles`, and for each of its
#' top-k platos, the minimum of the plato's distance profile on `false_data`
#' is used as a threshold. The distance profile on `true_data` is split into
#' consecutive segments of `segment_size` points, and a segment counts as
#' covered by the plato when at least one of its distances is strictly below
#' that threshold. Missing distances never count as being below the threshold.
#'
#' @param true_data Numeric vector without missing values and with at least
#'   one element (checked with `checkmate::qassert(, "N+")`).
#' @param false_data Numeric vector with the same constraints as `true_data`.
#' @param contrast_profiles Non-empty list whose names are the window sizes
#'   (they are converted with `as.numeric()`). The first element must contain
#'   a `cps` matrix; its number of columns sets the number of rows of the
#'   `contrast` and `thresholds` outputs.
#' @param quantiles Numeric vector of length two. `quantiles[1]` is the
#'   probability used to summarise the below-threshold distances of each
#'   covered segment; `quantiles[2]` is the probability applied to those
#'   per-segment summaries to obtain the overall contrast.
#' @param segment_size Number of distance-profile points per segment.
#'
#' @return A list with:
#'   * `contrast`: matrix (plato x window size) with the overall contrast,
#'     `(threshold - quantile) / sqrt(2 * window_size)`, or `0` when the plato
#'     covers no segment.
#'   * `coverage`: list named by window size; each element is a 0/1 matrix
#'     (plato x segment).
#'   * `thresholds`: matrix (plato x window size) with the minimum distance
#'     found on `false_data`.
#'   * `cov_counts`: matrix with the number of covered segments per plato,
#'     one column per window size.
#'   * `num_segments`: number of segments of the last window size processed.
#'
#' @export
contrast_evaluate_all_platos <- function(true_data,
                                         false_data,
                                         contrast_profiles,
                                         quantiles = c(0.1, 1 / 3),
                                         segment_size = 2800) {
  checkmate::qassert(true_data, "N+")
  checkmate::qassert(false_data, "N+")
  checkmate::qassert(contrast_profiles, "L+")

  w_sizes <- names(contrast_profiles) # retrieve the window sizes that are stored as list labels
  c_sizes <- as.numeric(w_sizes) # then convert them to numeric

  num_platos <- ncol(contrast_profiles[[1]]$cps)

  segs <- list()
  cont <- matrix(Inf, ncol = length(w_sizes), nrow = num_platos)
  thlds <- matrix(Inf, ncol = length(w_sizes), nrow = num_platos)

  for (i in seq_along(c_sizes)) {
    # get all the distance profiles of all top-k platos for the true and false data
    tp <- topk_distance_profiles(true_data, contrast_profiles, c_sizes[i])
    fp <- topk_distance_profiles(false_data, contrast_profiles, c_sizes[i])

    # boundaries of each segment: segment j spans (segments[j] + 1):segments[j + 1]
    segments <- unique(c(seq(0, nrow(tp), by = segment_size), nrow(tp)))
    n_seg <- length(segments) - 1L

    # number of top-k platos is equal to the ncol of the distance profiles
    knn <- ncol(tp)

    # this matrix is equivalent to "k" vectors of TRUE/FALSE values for each segment
    seg <- matrix(0, ncol = n_seg, nrow = knn)

    for (k in seq_len(knn)) {
      # distance profile of this plato on the true data
      tpk <- tp[, k]

      # get the minimum value on the false data to use as a threshold
      fpk_min <- min(fp[, k], na.rm = TRUE)
      thlds[k, i] <- fpk_min

      # one summary value per segment; stays NA for segments that are not covered
      seg_quantiles <- rep(NA_real_, n_seg)

      # iterate over the segments
      for (j in seq_len(n_seg)) {
        # get just that segment distance profile
        sm <- tpk[seq(segments[j] + 1, segments[j + 1])]
        # keep only the ones that are smaller than the threshold; `which()`
        # drops NA comparisons, so a segment holding only missing distances
        # is not mistaken for a covered one
        sm <- sm[which(sm < fpk_min)]

        # if there are any value left, we can classify the segment
        if (length(sm) > 0) {
          # low quantile (10% by default) of the values below the threshold;
          # this gives us a hint of how well the shapelet is doing by segment
          seg_quantiles[j] <- quantile(sm, quantiles[1], na.rm = TRUE, names = FALSE)
          seg[k, j] <- 1 # store as TRUE, since we have values below the threshold
        }
      }

      covered <- seg[k, ] == 1

      if (any(covered)) {
        # overall contrast: threshold minus the quantile (1/3 by default) of the
        # per-segment values, normalized by sqrt(2 * w) (same factor used on the
        # contrast profile), so we can compare different window sizes
        seg_summary <- quantile(seg_quantiles[covered], quantiles[2], na.rm = TRUE, names = FALSE)
        cont[k, i] <- (fpk_min - seg_summary) / sqrt(2 * c_sizes[i])
      } else {
        # if the plato could not classify any segment, we set the contrast to 0
        cont[k, i] <- 0
      }
    }
    segs[[w_sizes[i]]] <- seg
  }

  # here we compute the total number of segments that each plato could classify
  total_counts <- as.matrix(purrr::map_dfr(segs, function(x) apply(x, 1, sum)))
  colnames(cont) <- w_sizes # set the column names on the overall contrast matrix
  colnames(thlds) <- w_sizes # set the column names on the thresholds matrix

  # NOTE: cont and segs seems to complement each other
  # cont is similar to plot_topk_contrasts (contrasts values by k vs shapelet size)
  # but it is more specific, since it takes into account the segments and not just the
  # best contrast.

  # NOTE(review): `num_segments` comes from the last window size processed; if
  # the distance profiles have different lengths per window size, the
  # per-window value is ncol(coverage[[w]]) -- confirm against callers.
  list(
    contrast = cont, # cont == overall contrast of each plato (~specificity)
    coverage = segs, # segs == coverage of each plato (~sensitivity)
    thresholds = thlds, # thlds == threshold of each plato
    cov_counts = total_counts, # sum of segs == 1. Best is sum == num_segments
    num_segments = (length(segments) - 1)
  )
}

#' @export
contrast_train_regimes <- function(ecg_data, window_size, mp_threshold, time_constraint,
ez = 0.5, history = 5000L, sample_freq = 250L, batch = 100L) {
Expand Down
4 changes: 2 additions & 2 deletions R/contrast_tune.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ num_shapelets_par <- function(range = c(1L, 10L), trans = NULL) {
}

#' @export
redundance_par <- function(range = c(0L, 10L), trans = NULL) {
redundancy_par <- function(range = c(0L, 10L), trans = NULL) {
dials::new_quant_param(
type = "integer",
range = range,
inclusive = c(TRUE, TRUE),
trans = trans,
default = 0L,
values = seq.int(range[1], range[2]),
label = c(redundance = "Redundance"),
label = c(redundancy = "redundancy"),
finalize = NULL
)
}
Expand Down
42 changes: 22 additions & 20 deletions _contrast_profile/meta/meta

Large diffs are not rendered by default.

107 changes: 107 additions & 0 deletions flowcharts/contrast_profile.mmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
---
title: Contrast Profile Classifier
---

%%{
init: {
"securityLevel": "loose",
"theme": "dark",
"fontFamily": "Fira Code Medium, Trebuchet MS, Verdana, Arial, Sans-Serif",
"flowchart": {
"diagramPadding": 10
}
}
}%%

classDiagram
%% <|-- Inheritance (is-a relationship)
%% ..> Dependency (needs but not part of)
%% ..|> Realization (interface implementation)
%% *-- Composition (both live and die together)
%% o-- Aggregation (lifecycle is independent)
%% --> Association (generic relationship, that may use cardinality)
%% -- Link (Solid) (Association without arrows)
%% .. Link (Dashed) (Association without arrows and not navigable)

Data "many * classes" <.. "1 * window_sizes" Contrast
Contrast "1" <.. "1" Shapelet
Shapelet "1" <.. "1" ShapeletMeta
Data <.. ShapeletMeta
%% Shapelet "1..window_sizes" *-- "1" PanContrast_TopK
%% Contrast "1..window_sizes" *-- "1" PanContrast_TopK

%% class PanContrast_TopK {
%% Contrast contrasts
%% Shapelet shapelets
%% }


class Data {
List~Factor~ classes
List~Numeric~ ts
List~int~ ids
}


%% class is the positive class
class Contrast {
List~int~ window_sizes*
Factor class*
List~Numeric~ contrast_profiles
}

class Shapelet {
List~int~ window_sizes*
Factor class*
List~Numeric~ platos
List~int~ platos_indices
List~Numeric~ platos_twin
List~int~ platos_twin_indices
List~float~ plato_nary_contrasts
}

%% All Lists have dim (m, n), where m == num_of_shapelets (k) and n == length(window_sizes),
%% except coverages, which has length(window_sizes) elements, each one with
%% dim (m, n), where m == num_of_shapelets (k) and n == num_segments.
%% TODO: this needs to be reshaped
%% TODO: num_segments reflects the number of positive samples
class ShapeletMeta {
List~int~ window_sizes*
Factor class*
List~Numeric~ thresholds
List~Numeric~ overall_contrasts
List~bool~ coverages
List~int~ coverages_counts
int num_segments
}

ShapeletMeta *-- Fitted
Fitted *-- Model
Terms *-- Model
ShapeletMeta <-- Terms : optimizes

class Fitted {
Factor class*
ShapeletMeta best_shapelets
List~Numeric~ platos
List~Numeric~ thresholds
}

class Terms {
float contrast_total
float contrast_median
float contrast_mean
float contrast_std
float cov_con_ratio_mean
float k_mean
float cov_mean
float coverage
float cov_percent
int redundancy
int num_shapelets
}

class Model {
Fitted fitted_values
Terms terms
}
2 changes: 1 addition & 1 deletion scripts/_classifier.R
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ list(
# contrast_model(
# # coverage_quantiles = tune::tune(), # score_by_segment_window
# num_shapelets = tune::tune(), # find_solutions
# redundance = tune::tune() # find_solutions
# redundancy = tune::tune() # find_solutions
# ) |>
# parsnip::set_engine("contrast_profile") |>
# parsnip::set_mode("classification")
Expand Down
Loading

0 comments on commit c13a091

Please sign in to comment.