In [0]:
# https://gist.github.com/sbarratt/37356c46ad1350d4c30aefbd488a4faa
# https://gist.github.com/sbarratt/37356c46ad1350d4c30aefbd488a4faa#gistcomment-2987842

import torch
torch.manual_seed(123)

def compute_jacobian(f, x, output_dims):
    '''
    Normal:
        f: input_dims -> output_dims
    Jacobian mode:
        f: output_dims x input_dims -> output_dims x output_dims
    '''
    repeat_dims = tuple(output_dims) + (1,) * len(x.shape)
    jac_x = x.detach().repeat(*repeat_dims)
    jac_x.requires_grad_()
    # print(jac_x.shape)
    jac_y = f(jac_x)
    
    ml = torch.meshgrid([torch.arange(dim) for dim in output_dims])
    index = [m.flatten() for m in ml]
    gradient = torch.zeros(output_dims + output_dims)
    gradient.__setitem__(tuple(index)*2, 1)
    
    jac_y.backward(gradient)
        
    return jac_x.grad.data

def g(x):
  return x*x @ w

def h(x):
  return 3*x @ w


w = torch.randn(4, 3)
f = lambda x: x @ w
x = torch.randn(2,3,4)

jac_f = compute_jacobian(f, x, [2,3,3])
jac_g = compute_jacobian(g, x, [2,3,3])
jac_h = compute_jacobian(h, x, [2,3,3])

print(jac_f.sum())
print(jac_g.sum())
print(jac_h.sum())


tensor(-19.6881)
tensor(-7.6844)
tensor(-59.0642)


## Minibatch version of original get_jacobian code

In [0]:
# https://gist.github.com/sbarratt/37356c46ad1350d4c30aefbd488a4faa#gistcomment-3060003

import torch

def get_jacobian(net, x, num_outputs, batch_size=None, verbose=0):
    """
    Compute jacobian matrix of network outputs w.r.t input x.
    
    Parameters
    ----------
    net: A pytorch callable (e.g a network instance)

    num_outputs: int
        Number of outputs produced by net (per input instance)
        
    batch_size: int, optional
        If None, then do run in full-back mode. Else run in minibatch mode
        with mini-batches of size `batch_size`
    """

    from sklearn.utils import gen_batches
    import torch

    if batch_size is None:
        batch_size = num_outputs
    num_batches = num_outputs / float(batch_size) + (num_outputs % batch_size != 0)
    x.requires_grad_(False)
    x = x.squeeze(0)
    shape = list(x.shape)
    ones = [1] * len(shape)
    jacs = torch.zeros([num_outputs] + shape)
    
    for b, batch in enumerate(gen_batches(num_outputs, batch_size)):
        this_batch_size = len(jacs[batch])
        x_ = x.repeat(this_batch_size, *ones).requires_grad_(True)
        output = net(x_)
        assert (len(output.shape) == 2 and len(output) == this_batch_size)
        output.backward(torch.eye(num_outputs)[batch, :])
        jacs[batch] = x_.grad
        if verbose and num_batches > 1:
            print("Batch %02i / %02i" % (b + 1, num_batches))
   
    return jacs.data

# Worked example
num_features = 2
num_outputs = 10
x = torch.ones(num_features)
W = torch.randn(num_features, num_outputs)
y = lambda z: z @ W + 2019
batch_size = 3
jacs = get_jacobian(y, x, num_outputs, batch_size, verbose=1)
print("dy/dx:\n%s" % jacs)
print("W.T:\n%s" % W.T)

Batch 01 / 04
Batch 02 / 04
Batch 03 / 04
Batch 04 / 04
dy/dx:
tensor([[ 0.2212,  0.4296],
        [ 0.6006,  1.0388],
        [-1.1592, -0.9876],
        [-0.2304,  0.4367],
        [ 1.3516,  1.6069],
        [-0.4817, -0.0126],
        [-0.7625,  0.0338],
        [ 0.7208, -1.3271],
        [-1.0591,  0.2027],
        [ 0.5967, -0.2911]])
W.T:
tensor([[ 0.2212,  0.4296],
        [ 0.6006,  1.0388],
        [-1.1592, -0.9876],
        [-0.2304,  0.4367],
        [ 1.3516,  1.6069],
        [-0.4817, -0.0126],
        [-0.7625,  0.0338],
        [ 0.7208, -1.3271],
        [-1.0591,  0.2027],
        [ 0.5967, -0.2911]])


## The n-th derivative of a function

In [0]:
# https://stackoverflow.com/a/50375367/5270873

import torch
from torch.autograd import grad

def nth_derivative(f, wrt, n):

    for i in range(n):

        grads = grad(f, wrt, create_graph=True)[0]
        f = grads.sum()

    return grads

x = torch.arange(4.0, requires_grad=True).reshape(2, 2)
print(x)
loss = (x ** 3).sum() # a scalar
print(loss)

print(nth_derivative(f=loss, wrt=x, n=1))
print(nth_derivative(f=loss, wrt=x, n=2))
print(nth_derivative(f=loss, wrt=x, n=3))
print(nth_derivative(f=loss, wrt=x, n=4))
# error if: print(nth_derivative(f=loss, wrt=x, n=5))

tensor([[0., 1.],
        [2., 3.]], grad_fn=<AsStridedBackward>)
tensor(36., grad_fn=<SumBackward0>)
tensor([[ 0.,  3.],
        [12., 27.]], grad_fn=<MulBackward0>)
tensor([[ 0.,  6.],
        [12., 18.]], grad_fn=<MulBackward0>)
tensor([[6., 6.],
        [6., 6.]], grad_fn=<MulBackward0>)
tensor([[0., 0.],
        [0., 0.]])


## computes the Jacobian of any tensor w.r.t. any dimensional inputs

In [0]:
# https://stackoverflow.com/q/50322833/5270873

import torch
import torch.autograd as ag

def nd_range(stop, dims = None):
    if dims == None:
        dims = len(stop)
    if not dims:
        yield ()
        return
    for outer in nd_range(stop, dims - 1):
        for inner in range(stop[dims - 1]):
            yield outer + (inner,)


def full_jacobian(f, wrt):    
    f_shape = list(f.size())
    wrt_shape = list(wrt.size())
    fs = []


    f_range = nd_range(f_shape)
    wrt_range = nd_range(wrt_shape)

    for f_ind in f_range:
        grad = ag.grad(f[tuple(f_ind)], wrt, retain_graph=True, create_graph=True)[0]
        for i in range(len(f_shape)):
            grad = grad.unsqueeze(0)
        fs.append(grad)

    fj = torch.cat(fs, dim=0)
    fj = fj.view(f_shape + wrt_shape)
    return fj

In [0]:
x = torch.arange(4.0, requires_grad=True).reshape(2, 2)

loss = (x ** 3).sum() # a scalar  
print(x)
print(loss)

full_jacobian(loss, x)

tensor([[0., 1.],
        [2., 3.]], grad_fn=<AsStridedBackward>)
tensor(36., grad_fn=<SumBackward0>)


tensor([[ 0.,  3.],
        [12., 27.]], grad_fn=<ViewBackward>)

In [0]:
import torch

from torch.autograd import grad

def jacobian(inputs, outputs):
    return torch.stack([grad([outputs[:, i].sum()], [inputs], retain_graph=True, create_graph=True)[0]
                        for i in range(outputs.size(1))], dim=-1)
    

In [19]:
w = torch.randn(4,3)
f = lambda x: x @ w
x = torch.randn(2,3,4)

jacobian(f, x)

RuntimeError: ignored

In [20]:
x = torch.arange(4.0, requires_grad=True).reshape(2, 2) 
loss = (x ** 3) # a scalar  
print(x)
print(loss) 
jacobian(loss, torch.tensor([1.0]))

tensor([[0., 1.],
        [2., 3.]], grad_fn=<AsStridedBackward>)
tensor([[ 0.,  1.],
        [ 8., 27.]], grad_fn=<PowBackward0>)


IndexError: ignored