In [None]:
#|hide
#|eval: false
# Colab-only setup: the shell test `[ -e /content ]` gates everything on the Colab filesystem
# layout. It installs fastrl's dev extras and pyvirtualdisplay via pip, then xvfb and OpenGL via apt.
# NOTE(review): `python-opengl` may not exist on newer Ubuntu images (there it is `python3-opengl`).
# apt-get's output is redirected to /dev/null, so a failure here would be silent. TODO confirm on current Colab.
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [None]:
#|hide
#|eval: false
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.showdoc import *
    from nbdev.imports import *
    # `os`, IN_NOTEBOOK, IN_COLAB and IN_IPYTHON are not imported explicitly in this cell;
    # presumably they are provided by the star imports above. TODO confirm.
    # Sanity-check the execution environment unless running under the test runner (IN_TEST set).
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virtual display is needed for colab.
    # Bound to `virtual_display` rather than `display` so it does not shadow
    # IPython's built-in `display()` function for the rest of the notebook.
    from pyvirtualdisplay import Display
    virtual_display = Display(visible=0, size=(400, 300))
    virtual_display.start()

In [None]:
#|default_exp funcs.conjugation

In [None]:
#|export
# Python native modules

# Third party libs
import torch
import numpy as np
# Local modules

# Conjugation
> Notes and functions illustrated by [Shewchuk, 1994](https://www.cs.cmu.edu/~quake-papers/painless-conjugate-gradient.pdf), which will be
referenced throughout this notebook.

Based on the sample problem given in equation (4) of the paper:

In [None]:
# Sample problem from Shewchuk (1994): symmetric matrix A, column vector b, scalar offset c.
A = torch.tensor([[3., 2.],
                  [2., 6.]])
b = torch.tensor([2., -8.]).reshape(2, 1)
c = 0

In [None]:
# The minimizer of `f` (defined below), stored as a (2, 1) column vector; it also solves A @ x = b.
x = torch.tensor([2., -2.]).unsqueeze(1)

We define a quadratic function whose minimum value is -10. The challenge is
to pretend we don't know this, and to automatically find the value of `x`
at which that minimum is attained.

In [None]:
def f(x):
    "Quadratic form 1/2 x^T A x - b^T x + c using the notebook-level `A`, `b`, `c`; a (n, 1) column `x` gives a (1, 1) tensor."
    quadratic_term = (1/2) * x.T @ A @ x
    linear_term = b.T @ x
    return quadratic_term - linear_term + c

In [None]:
import plotly.express as px
import plotly.io as pio
# Combined renderer: emit Plotly's MIME bundle (for frontends that render it natively) and
# classic-notebook HTML that loads plotly.js from a CDN ("connected"). Both are produced, so
# figures should display in either kind of frontend.
pio.renderers.default = "plotly_mimetype+notebook_connected"

In [None]:
# 20x20 grid of (x1, x2) points, row-major, covering [-10, 9] in each coordinate, as float32.
xx = torch.cartesian_prod(torch.arange(20), torch.arange(20)).float() - 10

def plot3d(xx, f, z_label="f(x)"):
    """3D scatter plot of `f` evaluated at every row of `xx`.

    xx:      (N, 2) tensor of points; column 0 is plotted on the x axis, column 1 on the y axis.
    f:       callable taking a (2, 1) column vector and returning a 2-D tensor whose [0][0]
             entry is the value to plot (a (1, 1) tensor for the quadratic `f`).
    z_label: axis title for the plotted values.
    Returns the plotly figure.
    """
    # Each (2,) row is reshaped into the column-vector form `f` expects.
    z = [f(point.reshape(-1, 1)).numpy()[0][0] for point in xx]
    return px.scatter_3d(
        x=xx[:, 0].numpy(),  # explicit numpy arrays instead of relying on plotly to coerce tensors
        y=xx[:, 1].numpy(),
        z=z,
        labels={"x": "x1", "y": "x2", "z": z_label},
    )
plot3d(xx,f)

In [None]:
def f_prime(x):
    """Gradient of the quadratic form `f` at the column vector `x`.

    Uses the general expression 1/2 A^T x + 1/2 A x - b. This reduces to
    A x - b when `A` is symmetric, which it is for the sample problem.
    """
    from_transpose = (1/2) * A.T @ x
    from_matrix = (1/2) * A @ x
    return from_transpose + from_matrix - b

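Using `f_prime` above, we can evaluate the gradient at the points of the grid `xx`; the first five results are shown below.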

In [None]:
# Gradient at the first five grid points. Slice the grid before evaluating so we
# don't compute all 400 gradients only to display five of them.
[f_prime(point.reshape(-1, 1)) for point in xx[:5]]

[tensor([[-52.],
         [-72.]]),
 tensor([[-50.],
         [-66.]]),
 tensor([[-48.],
         [-60.]]),
 tensor([[-46.],
         [-54.]]),
 tensor([[-44.],
         [-48.]])]

In [None]:
# Re-centered 20x20 grid, row-major, covering [-15, 4] per coordinate (so it contains the minimizer (2, -2)), as float32.
xx = torch.cartesian_prod(torch.arange(20), torch.arange(20)).float() - 15

def get_magnitude(x):
    """Euclidean (L2) length of the vector `x`, returned as a 0-d numpy array.

    Works for column vectors such as the (2, 1) output of `f_prime` as well as flat vectors.
    """
    # With no `ord`/`dim`, torch.linalg.norm flattens its input and takes the vector 2-norm.
    # The previous `torch.zeros((1,2)) - x` broadcast a (2,1) column against a (1,2) row into
    # a (2,2) matrix, which inflated the result by sqrt(2).
    return torch.linalg.norm(x).numpy()

def plot3d(xx, f, z_label="||f(x)||"):
    """3D scatter plot of the magnitude of `f` at every row of `xx`.

    NOTE: this intentionally redefines the earlier `plot3d`, which plotted `f` itself.
    xx:      (N, 2) tensor of points; column 0 is plotted on the x axis, column 1 on the y axis.
    f:       vector-valued callable taking a (2, 1) column vector (e.g. `f_prime`).
    z_label: axis title for the plotted magnitudes.
    Returns the plotly figure.
    """
    # Convert each 0-d magnitude to a plain float so plotly receives a numeric column.
    z = [float(get_magnitude(f(point.reshape(-1, 1)))) for point in xx]
    return px.scatter_3d(
        x=xx[:, 0].numpy(),
        y=xx[:, 1].numpy(),
        z=z,
        labels={"x": "x1", "y": "x2", "z": z_label},
    )
plot3d(xx,f_prime)

In [None]:
#|hide
#|eval: false
# Re-imported here so this cell also works when run on its own.
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    # Export the `#|export` cells of this notebook to the module named by `#|default_exp`.
    from nbdev import nbdev_export
    nbdev_export()