Conversation
lib/scholar/impute/knn_imputer.ex
Outdated
| if opts[:missing_values] != :nan and | ||
| Nx.any(Nx.is_nan(x)) == Nx.tensor(1, type: :u8) do | ||
| raise ArgumentError, | ||
| ":missing_values other than :nan possible only if there is no Nx.Constant.nan() in the array" | ||
| end | ||
|
|
There was a problem hiding this comment.
This check does not really work in Nx. If you call fit inside Nx.Defn.jit, then x is an expression, and we can't read its values to find out if there is a nan or not. The best we can do is to remove this check and document it.
There was a problem hiding this comment.
I found this check in simple imputer
https://github.com/elixir-nx/scholar/blob/main/lib/scholar/impute/simple_imputer.ex
Are you sure it won't work?
There was a problem hiding this comment.
It is also broken there. :)
lib/scholar/impute/knn_imputer.ex
Outdated
|
|
||
| all_nan_rows_count = Nx.sum(all_nan_rows) | ||
|
|
||
| if num_neighbors > rows - 1 - Nx.to_number(all_nan_rows_count) do |
There was a problem hiding this comment.
Same here, this code won't work because, when you have an expression, you can't get a number from it. Can we remove this check? What happens if we don't check for this condition?
There was a problem hiding this comment.
You can test this by calling fit after jitting it with Nx.Defn.fit.
lib/scholar/impute/knn_imputer.ex
Outdated
|
|
||
| # if potential neighbor has nan in nan_col, we don't want to calculate distance and the case if potential_neighbour is the row to impute | ||
| {potential_neighbor} = | ||
| if potential_neighbor[nan_col] == Nx.Constants.nan() do |
There was a problem hiding this comment.
I am not sure if this check is guaranteed to work, given two NaNs are not guaranteed to be equal. Using Nx.is_nan would be more appropriate.
lib/scholar/impute/knn_imputer.ex
Outdated
|
|
||
| x = | ||
| if opts[:missing_values] != :nan, | ||
| do: Nx.select(Nx.equal(x, opts[:missing_values]), Nx.Constants.nan(), x), |
There was a problem hiding this comment.
Use Nx.is_nan here NaN is not equal to itself
| coordinates = coordinates - 1 | ||
|
|
||
| # inputes zeros in nan_col to calculate distance with squared_euclidean | ||
| new_row = Nx.indexed_put(row, Nx.new_axis(nan_col, 0), Nx.tensor(0)) |
There was a problem hiding this comment.
Generally, when you write in defn, you don't need to wrap this zero in Nx.tensor. I prefer to explicitly use Nx.<type> or Nx.tensor(x, type: type) to indicate the type of the tensor. Now, there are some cases where imputter has fixed type like :f32. I think that this might cause undesired upcasts when e.g. I have tensor of type :bf16. So I suggest to check if there are any unwanted casts / upcast.
There was a problem hiding this comment.
I changed it but I don't know how to change this line
row_distances = Nx.iota({rows}, type: {:f, 32})
because i don't know what the type calculated distance will be at this point
|
|
||
| # if row has all nans we skip it | ||
| {weight, potential_neighbor} = | ||
| if present_coordinates == 0 do |
There was a problem hiding this comment.
As mentioned in comment up, try to replace "bare" numbers with typed tensors
lib/scholar/impute/knn_imputer.ex
Outdated
| @@ -0,0 +1,256 @@ | |||
| defmodule Scholar.Impute.KNNImputer do | |||
There was a problem hiding this comment.
I think it should be written with double t KNNImputter like formatter etc.
msluszniak
left a comment
There was a problem hiding this comment.
Thanks for the PR, I dropped some comments :))
|
Hi @srzeszut and thanks for the pull request. I’m traveling now and don’t have my laptop with me. Will be back this Sunday, so I will have a look probably next week. |
|
Thanks for the review, I apply suggested changes and left some comments. |
lib/scholar/impute/knn_imputter.ex
Outdated
|
|
||
| if num_neighbors > rows - 1 - Nx.to_number(all_nan_rows_count) do | ||
| raise ArgumentError, | ||
| "Number of neighbors rows must be less than number valid of rows - 1 (valid row is row with more than 1 non nan value)" |
There was a problem hiding this comment.
error messages start in lowercase. :)
| "Number of neighbors rows must be less than number valid of rows - 1 (valid row is row with more than 1 non nan value)" | |
| "number of neighbors rows must be less than number valid of rows - 1 (valid row is row with more than 1 non nan value)" |
lib/scholar/impute/knn_imputter.ex
Outdated
|
|
||
| all_nan_rows_count = Nx.sum(all_nan_rows) | ||
|
|
||
| if num_neighbors > rows - 1 - Nx.to_number(all_nan_rows_count) do |
There was a problem hiding this comment.
Can you please add some tests? In particular, please add a test where you call jit this function and then you call it: Nx.Defn.jit(...).(arg1, arg2). It should reveal some errors around here. :)
There was a problem hiding this comment.
I added tests and checked it. I removed those checks and added them in the description
lib/scholar/impute/knn_imputter.ex
Outdated
| `n_neighbors` nearest neighbors found in the training set. Two samples are | ||
| close if the features that neither is missing are close. |
There was a problem hiding this comment.
| `n_neighbors` nearest neighbors found in the training set. Two samples are | |
| close if the features that neither is missing are close. | |
| `n_neighbors` nearest neighbors found in the training set. Two samples are | |
| close if the features that neither is missing are close. |
lib/scholar/impute/knn_imputter.ex
Outdated
|
|
||
| Preconditions: | ||
| * `number_of_neighbors` is a positive integer. | ||
| * number of neighbors must be less than number valid of rows - 1 (valid row is row with more than 1 non nan value) otherwise it is better to use simple imputter |
There was a problem hiding this comment.
Please try to break this long line :)
| test "Wrong impute rank" do | ||
| x = Nx.tensor([1, 2, 2, 3]) | ||
|
|
||
| assert_raise ArgumentError, | ||
| "Wrong input rank. Expected: 2, got: 1", | ||
| fn -> | ||
| KNNImputter.fit(x, missing_values: 1, number_of_neighbors: 2) | ||
| end | ||
| end | ||
|
|
||
| test "Invalid n_neighbors value" do |
There was a problem hiding this comment.
Test names start in lowercase :)
| test "Wrong impute rank" do | |
| x = Nx.tensor([1, 2, 2, 3]) | |
| assert_raise ArgumentError, | |
| "Wrong input rank. Expected: 2, got: 1", | |
| fn -> | |
| KNNImputter.fit(x, missing_values: 1, number_of_neighbors: 2) | |
| end | |
| end | |
| test "Invalid n_neighbors value" do | |
| test "invalid impute rank" do | |
| x = Nx.tensor([1, 2, 2, 3]) | |
| assert_raise ArgumentError, | |
| "Wrong input rank. Expected: 2, got: 1", | |
| fn -> | |
| KNNImputter.fit(x, missing_values: 1, number_of_neighbors: 2) | |
| end | |
| end | |
| test "invalid n_neighbors value" do |
josevalim
left a comment
There was a problem hiding this comment.
I dropped the last round of nitpicks and we are good to go!
There was a problem hiding this comment.
First review. Some features we might wanna have:
- Make k-NN algorithm configurable.
- Make the metric configurable.
You can leave these for another pull request. Have a look at e.g. KNNClassifier how it is done over there.
I should have another look tonight.
lib/scholar/impute/knn_imputter.ex
Outdated
| The default value expects there are no NaNs in the input tensor. | ||
| """ | ||
| ], | ||
| number_of_neighbors: [ |
There was a problem hiding this comment.
I would suggest changing this to num_neighbors to be consistent with the rest of Scholar.
krstopro
left a comment
There was a problem hiding this comment.
Several minor comments for now. I have to go through the code at least once more as I don't exactly understand the logic here.
lib/scholar/impute/knn_imputter.ex
Outdated
|
|
||
| x = | ||
| if opts[:missing_values] != :nan, | ||
| do: Nx.select(Nx.equal(x, opts[:missing_values]), Nx.Constants.nan(), x), |
There was a problem hiding this comment.
You should be able to use == instead of Nx.equal/2.
There was a problem hiding this comment.
This is a deftransform, so Nx.equal is the proper function. == will be Elixir.Kernel.==
lib/scholar/impute/knn_imputter.ex
Outdated
| placeholder_value = Nx.Constants.nan() |> Nx.tensor() | ||
|
|
||
| statistics = knn_impute(x, placeholder_value, num_neighbors: num_neighbors) | ||
| missing_values = opts[:missing_values] |
There was a problem hiding this comment.
I would move this line above so that you don't access opts[:missing_values] multiple times.
lib/scholar/impute/knn_imputter.ex
Outdated
|
|
||
| {_, values_to_impute} = | ||
| while {{row = 0, mask, num_neighbors, num_rows, x}, values_to_impute}, | ||
| Nx.less(row, num_rows) do |
There was a problem hiding this comment.
You can use < instead of Nx.less/2 over here.
lib/scholar/impute/knn_imputter.ex
Outdated
| Nx.less(row, num_rows) do | ||
| {_, values_to_impute} = | ||
| while {{col = 0, mask, num_neighbors, num_cols, row, x}, values_to_impute}, | ||
| Nx.less(col, num_cols) do |
lib/scholar/impute/knn_imputter.ex
Outdated
| {_, values_to_impute} = | ||
| while {{col = 0, mask, num_neighbors, num_cols, row, x}, values_to_impute}, | ||
| Nx.less(col, num_cols) do | ||
| if mask[row][col] > 0 do |
There was a problem hiding this comment.
I think if mask[row][col] do should work here.
lib/scholar/impute/knn_imputter.ex
Outdated
| * `number_of_neighbors` is a positive integer. | ||
| * number of neighbors must be less than number valid of rows - 1 (valid row is row with more than 1 non nan value) otherwise it is better to use simple imputter |
There was a problem hiding this comment.
| * `number_of_neighbors` is a positive integer. | |
| * number of neighbors must be less than number valid of rows - 1 (valid row is row with more than 1 non nan value) otherwise it is better to use simple imputter | |
| * The number of neighbors must be less than the number of valid rows - 1. | |
| A valid row is a row with more than 1 non-NaN values. Otherwise it is better to use a simpler imputer. |
lib/scholar/impute/knn_imputter.ex
Outdated
| Preconditions: | ||
| * `number_of_neighbors` is a positive integer. | ||
| * number of neighbors must be less than number valid of rows - 1 (valid row is row with more than 1 non nan value) otherwise it is better to use simple imputter | ||
| * when you set a value different than :nan in `missing_values` there should be no NaNs in the input tensor |
There was a problem hiding this comment.
| * when you set a value different than :nan in `missing_values` there should be no NaNs in the input tensor | |
| * When you set a value different than `:nan` in `missing_values` there should be no NaNs in the input tensor |
lib/scholar/impute/knn_imputter.ex
Outdated
| * `:missing_values` - the same value as in `:missing_values` | ||
|
|
||
| * `:statistics` - The imputation fill value for each feature. Computing statistics can result in | ||
| [`Nx.Constant.nan/0`](https://hexdocs.pm/nx/Nx.Constants.html#nan/0) values. |
There was a problem hiding this comment.
| [`Nx.Constant.nan/0`](https://hexdocs.pm/nx/Nx.Constants.html#nan/0) values. | |
| [`Nx.Constants.nan/0`](https://hexdocs.pm/nx/Nx.Constants.html#nan/0) values. |
There was a problem hiding this comment.
Do you need the explicit linking in hexdoc?
lib/scholar/impute/knn_imputter.ex
Outdated
|
|
||
| The function returns a struct with the following parameters: | ||
|
|
||
| * `:missing_values` - the same value as in `:missing_values` |
There was a problem hiding this comment.
| * `:missing_values` - the same value as in `:missing_values` | |
| * `:missing_values` - the same value as in the `:missing_values` option |
lib/scholar/impute/knn_imputter.ex
Outdated
|
|
||
| num_neighbors = opts[:number_of_neighbors] | ||
|
|
||
| placeholder_value = Nx.Constants.nan() |> Nx.tensor() |
There was a problem hiding this comment.
| placeholder_value = Nx.Constants.nan() |> Nx.tensor() | |
| placeholder_value = Nx.Constants.nan() |
There was a problem hiding this comment.
you probably want to pass the input type here to avoid upcasts
lib/scholar/impute/knn_imputter.ex
Outdated
|
|
||
| opts_schema = [ | ||
| missing_values: [ | ||
| type: {:or, [:float, :integer, {:in, [:nan]}]}, |
There was a problem hiding this comment.
| type: {:or, [:float, :integer, {:in, [:nan]}]}, | |
| type: {:or, [:float, :integer, {:in, [:nan]}]}, |
I believe this should allow :infinity and :neg_infinity too for completeness
lib/scholar/impute/knn_imputter.ex
Outdated
| indices = | ||
| [Nx.stack(row), Nx.stack(col)] | ||
| |> Nx.concatenate() | ||
| |> Nx.stack() |
There was a problem hiding this comment.
| indices = | |
| [Nx.stack(row), Nx.stack(col)] | |
| |> Nx.concatenate() | |
| |> Nx.stack() | |
| indices = Nx.stack([row, col]) |> Nx.reshape({1, 2}) |
If I read the code correctly, row and col are scalars and this should yield the same result
lib/scholar/impute/knn_imputter.ex
Outdated
| |> Nx.concatenate() | ||
| |> Nx.stack() | ||
|
|
||
| values_to_impute = Nx.indexed_put(values_to_impute, indices, Nx.stack(neighbor_avg)) |
There was a problem hiding this comment.
| values_to_impute = Nx.indexed_put(values_to_impute, indices, Nx.stack(neighbor_avg)) | |
| values_to_impute = Nx.put_slice(values_to_impute, [row, col], Nx.reshape(neighbor_avg, {1, 1})) |
I think this is even simpler
| {_, row_distances} = | ||
| while {{i = 0, x, row_with_value_to_fill, rows, nan_row, nan_col}, row_distances}, | ||
| Nx.less(i, rows) do | ||
| potential_donor = x[i] | ||
|
|
||
| distance = | ||
| if i == nan_row do | ||
| Nx.Constants.infinity(Nx.type(row_with_value_to_fill)) | ||
| else | ||
| nan_euclidian(row_with_value_to_fill, nan_col, potential_donor) | ||
| end | ||
|
|
||
| row_distances = Nx.indexed_put(row_distances, Nx.new_axis(i, 0), distance) | ||
| {{i + 1, x, row_with_value_to_fill, rows, nan_row, nan_col}, row_distances} | ||
| end |
There was a problem hiding this comment.
try this:
potential_donors = Nx.vectorize(x, :rows)
distances = nan_euclidean(row_with_value_to_fill, nan_col, potential_donors) |> Nx.devectorize()
row_distances = Nx.indexed_put(distances, [i], Nx.Constants.infinity())|
Thanks for all the comments, I applied your suggested changes to the code. |
Knn imputer
mix format
|
💚 💙 💜 💛 ❤️ |
I have added the KNNImputer and I am currently implementing tests to ensure that it behaves as expected across various scenarios, including edge cases.