{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Debug backward passes through KeOps model\n",
    "\n",
    "Setup:\n",
    "- pytorch 1.2\n",
    "- pykeops 1.2\n",
    "- gpytorch master"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import gpytorch\n",
    "\n",
    "from gpytorch.kernels import ScaleKernel\n",
    "from gpytorch.kernels.keops import MaternKernel as KeOpsMaternKernel\n",
    "from gpytorch.kernels.keops import RBFKernel as KeOpsRBFKernel\n",
    "from gpytorch.means import ConstantMean\n",
    "from gpytorch.models import ExactGP\n",
    "from gpytorch.likelihoods import GaussianLikelihood\n",
    "from gpytorch.distributions import MultivariateNormal\n",
    "from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood\n",
    "\n",
    "\n",
    "class KeOpsGP(ExactGP):\n",
    "    \"\"\"Exact GP with a constant mean and a scaled KeOps RBF kernel.\"\"\"\n",
    "\n",
    "    def __init__(self, train_X, train_Y):\n",
    "        \"\"\"Build the model on (train_X, train_Y) with a new GaussianLikelihood.\"\"\"\n",
    "        super().__init__(train_X, train_Y, GaussianLikelihood())\n",
    "        self.mean_module = ConstantMean()\n",
    "        # No `nu` argument here: smoothness is a Matern setting (see\n",
    "        # KeOpsMaternKernel) and has no meaning for an RBF kernel.\n",
    "        self.covar_module = ScaleKernel(KeOpsRBFKernel())\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return MultivariateNormal(mean, covariance) evaluated at x.\"\"\"\n",
    "        mean_x = self.mean_module(x)\n",
    "        covar_x = self.covar_module(x)\n",
    "        return MultivariateNormal(mean_x, covar_x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Fixed seed so every run of this repro draws the same random data.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "n_train = 100  # any number here will work\n",
    "n_test = 81  # works for 80, breaks for 81+\n",
    "\n",
    "train_X = torch.rand(n_train, 2)\n",
    "train_Y = 1 + torch.rand(n_train)\n",
    "test_X = torch.rand(n_test, 2)\n",
    "\n",
    "model = KeOpsGP(train_X, train_Y)\n",
    "model.eval()\n",
    "\n",
    "# NOTE(review): max_cholesky_size(1) presumably keeps GPyTorch off the\n",
    "# Cholesky path for these sizes so the KeOps code is exercised - confirm.\n",
    "with gpytorch.settings.max_cholesky_size(1):\n",
    "    out = model(test_X)\n",
    "    out.mean.sum().backward()"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": [],
  "anp_cloned_from": {
   "notebook_id": "773930246357510",
   "revision_id": "360687611544715"
  },
  "anp_metadata": {
   "fetch_marker": "a0cfed57-08e7-4365-8c5c-6fcd34a70f05",
   "path":
"notebooks/malaria_AL_SysML_v3_rerun_lesspoints.ipynb" }, "bento_stylesheets": { "bento/extensions/flow/main.css": true, "bento/extensions/kernel_selector/main.css": true, "bento/extensions/kernel_ui/main.css": true, "bento/extensions/new_kernel/main.css": true, "bento/extensions/system_usage/main.css": true, "bento/extensions/theme/main.css": true }, "disseminate_notebook_info": { "backup_notebook_id": "502117130366686" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 2 }