-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
cbf7376
commit 5a31c33
Showing
6 changed files
with
727 additions
and
31 deletions.
There are no files selected for viewing
240 changes: 240 additions & 0 deletions
240
examples/.ipynb_checkpoints/multi_task_test_notebook-checkpoint.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,240 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "ff0ce27e", | ||
"metadata": {}, | ||
"source": [ | ||
"# FVGP MULTI Task Notebook\n", | ||
"In this notebook we will go through many features of FVGP. We will be primarily concerned with regression over a single dimension output and multiple tasks." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "346f9628", | ||
"metadata": {}, | ||
"source": [ | ||
"## This first cell has nothing to do with gpCAM, it's just a function to plot some results later" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "f17458ce", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import plotly.graph_objects as go\n", | ||
"import numpy as np\n", | ||
"def plot(x,y,z,data = None):\n", | ||
" fig = go.Figure()\n", | ||
" fig.add_trace(go.Surface(x = x, y = y,z=z))\n", | ||
" if data is not None: \n", | ||
" fig.add_trace(go.Scatter3d(x=data[:,0], y=data[:,1], z=data[:,2],\n", | ||
" mode='markers'))\n", | ||
"\n", | ||
" fig.update_layout(title='Posterior Mean', autosize=True,\n", | ||
" width=800, height=800,\n", | ||
" margin=dict(l=65, r=50, b=65, t=90))\n", | ||
"\n", | ||
"\n", | ||
" fig.show()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "82296522", | ||
"metadata": {}, | ||
"source": [ | ||
"## Import fvgp and relevant libraries" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "3cc21dc8", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import fvgp\n", | ||
"from fvgp import gp, fvgp\n", | ||
"import numpy as np\n", | ||
"import matplotlib.pyplot as plt" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "87ae086d", | ||
"metadata": {}, | ||
"source": [ | ||
"## Defining some input data and testing points" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "31d4da07", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def test_data_function(x):\n", | ||
" data_1 = 100*np.sin(x)+np.cos(x)\n", | ||
" data_2 = 5*np.ones(x.shape)\n", | ||
" data_3 = 1*np.cos(x/10 + 5)\n", | ||
" data_4 = 5*np.sin(x/200)\n", | ||
" data_5 = 10*np.cos(x)\n", | ||
" \n", | ||
" return np.column_stack((data_1, data_2, data_3, data_4, data_5))\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "d3b9ede4", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"x_input = np.linspace(-2*np.pi, 10*np.pi,100)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "4dda18dd", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"y_output = test_data_function(x_input)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "e7ff3403", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"x_input_test = np.linspace(3*np.pi, 4*np.pi, 100)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "2b0cc16d", | ||
"metadata": {}, | ||
"source": [ | ||
"## Setting up the fvgp multi task object\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "a2874db0", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"obj = fvgp.fvGP(input_space_dim = 1, \n", | ||
" output_space_dim = 1, \n", | ||
" output_number = 5, \n", | ||
" points = x_input.reshape(-1,1),\n", | ||
" values = y_output,\n", | ||
" init_hyperparameters = np.array([10,10,10]))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "ca3686f5", | ||
"metadata": {}, | ||
"source": [ | ||
"## Training our gaussian process regression on given data\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "75dcee32", | ||
"metadata": { | ||
"scrolled": false | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"hyper_param_bounds = np.array([[0.0001, 1000],[ 0.0001, 1000],[ 0.0001, 1000]])\n", | ||
"obj.train(hyper_param_bounds)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "3b143d85", | ||
"metadata": {}, | ||
"source": [ | ||
"## Looking at the posterior mean at the test points (remember that we did not define a particularly good kernel)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "1acce22c", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"task_idx = 1" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "6a2d2781", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"x_linspace = np.linspace(3*np.pi, 4*np.pi,100)\n", | ||
"y_linspace = np.linspace(0,4,100)\n", | ||
"x_grid, y_grid = np.meshgrid(x_linspace, y_linspace)\n", | ||
"post_mean= obj.posterior_mean(np.column_stack((x_input_test, task_idx*np.ones(x_input_test.shape))))\n", | ||
"posterior_mean_data = obj.posterior_mean(np.column_stack((x_grid.flatten(), y_grid.flatten())))\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "9e7f6b42", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"fig = plt.figure(figsize = (10,10))\n", | ||
"ax = fig.add_subplot(projection='3d')\n", | ||
"\n", | ||
"\n", | ||
"ax.scatter(post_mean['x'][:,0], np.ones(100)*0, y_output[:,0])\n", | ||
"ax.scatter(post_mean['x'][:,0], np.ones(100)*1,y_output[:,1])\n", | ||
"ax.scatter(post_mean['x'][:,0], np.ones(100)*2,y_output[:,2])\n", | ||
"ax.scatter(post_mean['x'][:,0], np.ones(100)*3,y_output[:,3])\n", | ||
"ax.scatter(post_mean['x'][:,0], np.ones(100)*4,y_output[:,4])\n", | ||
"ax.plot_surface(x_grid, y_grid, posterior_mean_data['f(x)'].reshape(100,100), rstride=1, cstride=1,cmap='inferno')\n", | ||
"# ax.scatter(x_grid.flatten(), y_grid.flatten(), posterior_mean_data['f(x)'])\n", | ||
"\n", | ||
"# pyplot.pcolormesh(x_grid, y_grid, posterior_mean_data['f(x)'].reshape(100,100))\n" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "fvgp_env", | ||
"language": "python", | ||
"name": "fvgp_env" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.8.10" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Oops, something went wrong.