Skip to content

Commit

Permalink
update to pytorch 1.10.1 (#1518)
Browse files Browse the repository at this point in the history
* update to pytorch 1.10.1

* fix histogram tests and small refactor

* disable augmentation test

* enable more cases in test histogram

* Use of legacy backtick

* enable crop tests again
  • Loading branch information
edgarriba committed Jan 9, 2022
1 parent 053092e commit 562d42f
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests_cpu.yml
Expand Up @@ -17,7 +17,7 @@ jobs:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: [3.6, 3.8, 3.9]
pytorch-version: [1.8.1, 1.9.1]
pytorch-version: [1.8.1, 1.9.1, 1.10.1]
steps:
- uses: actions/checkout@v2
- name: Setup conda dependencies
Expand Down
27 changes: 16 additions & 11 deletions kornia/enhance/histogram.py
Expand Up @@ -214,25 +214,29 @@ def image_histogram2d(

if bandwidth is None:
bandwidth = (max - min) / n_bins

if centers is None:
centers = min + bandwidth * (torch.arange(n_bins, device=image.device, dtype=image.dtype).float() + 0.5)
centers = min + bandwidth * (torch.arange(n_bins, device=image.device, dtype=image.dtype) + 0.5)
centers = centers.reshape(-1, 1, 1, 1, 1)

u = torch.abs(image.unsqueeze(0) - centers) / bandwidth
if kernel == "triangular":
mask = (u <= 1).to(u.dtype)
kernel_values = (1 - u) * mask
elif kernel == "gaussian":

if kernel == "gaussian":
kernel_values = torch.exp(-0.5 * u ** 2)
elif kernel == "uniform":
elif kernel in ("triangular", "uniform", "epanechnikov",):
# compute the mask and cast to floating point
mask = (u <= 1).to(u.dtype)
kernel_values = torch.ones_like(u, dtype=u.dtype, device=u.device) * mask
elif kernel == "epanechnikov":
mask = (u <= 1).to(u.dtype)
kernel_values = (1 - u ** 2) * mask
if kernel == "triangular":
kernel_values = (1. - u) * mask
elif kernel == "uniform":
kernel_values = torch.ones_like(u) * mask
else: # kernel == "epanechnikov"
kernel_values = (1. - u ** 2) * mask
else:
raise ValueError(f"Kernel must be 'triangular', 'gaussian', " f"'uniform' or 'epanechnikov'. Got {kernel}.")

hist = torch.sum(kernel_values, dim=(-2, -1)).permute(1, 2, 0)

if return_pdf:
normalization = torch.sum(hist, dim=-1, keepdim=True) + eps
pdf = hist / normalization
Expand All @@ -248,4 +252,5 @@ def image_histogram2d(
hist = hist.squeeze()
elif image.dim() == 3:
hist = hist.squeeze(0)
return hist, torch.zeros_like(hist, dtype=hist.dtype, device=hist.device)

return hist, torch.zeros_like(hist)
4 changes: 2 additions & 2 deletions setup_dev_env.sh
Expand Up @@ -16,7 +16,7 @@ conda_bin=$conda_bin_dir/conda

# download and install miniconda
# check the operating system: Mac or Linux
platform=`uname`
platform=$(uname)
if [[ "$platform" == "Darwin" ]];
then
download_link=https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh
Expand All @@ -36,7 +36,7 @@ fi
# define a python version to initialise the conda environment.
# by default we assume python 3.7.
python_version=${PYTHON_VERSION:-"3.7"}
pytorch_version=${PYTORCH_VERSION:-"1.9.1"}
pytorch_version=${PYTORCH_VERSION:-"1.10.1"}
pytorch_mode=${PYTORCH_MODE:-""} # use `cpuonly` for CPU or leave it in blank for GPU
cuda_version=${CUDA_VERSION:-"10.2"}

Expand Down
4 changes: 3 additions & 1 deletion test/enhance/test_histogram.py
Expand Up @@ -51,7 +51,9 @@ def test_jit(self, device, dtype, kernel):
op = TestImageHistogram2d.fcn
op_script = torch.jit.script(op)

assert_close(op(*inputs), op_script(*inputs))
out, out_script = op(*inputs), op_script(*inputs)
assert_close(out[0], out_script[0])
assert_close(out[1], out_script[1])

@pytest.mark.parametrize("kernel", ["triangular", "gaussian", "uniform", "epanechnikov"])
@pytest.mark.parametrize("size", [(1, 1), (3, 1, 1), (8, 3, 1, 1)])
Expand Down

0 comments on commit 562d42f

Please sign in to comment.