In [52]:
import torch

# Autograd

In [64]:
# Two leaf tensors tracked by autograd; built in one pass since they differ
# only in their initial values.
a, b = (
    torch.tensor(values, requires_grad=True)
    for values in ([2.0, 3.0], [6.0, 4.0])
)

In [65]:
# Display both tensors; the repr shows requires_grad=True for each.
a, b

(tensor([2., 3.], requires_grad=True), tensor([6., 4.], requires_grad=True))

In [66]:
# backward() with no arguments only works on a single-element (scalar) output.
# q has two elements, so either reduce it first (q.sum().backward()) or pass
# the vector of the vector-Jacobian product explicitly, as done below.
#
# backward() accumulates into .grad, so clear any gradients left by a previous
# run of this cell; otherwise re-running it would double a.grad and b.grad.
a.grad = None
b.grad = None

q = a**3 + 5*b**2
external_gradient = torch.tensor([1., 1.])  # one weight per element of q
q.backward(gradient=external_gradient)

In [67]:
# Gradients written into the leaf tensors by q.backward(...).
a.grad, b.grad

(tensor([12., 27.]), tensor([60., 40.]))

In [68]:
# Compare autograd's result with the analytic derivatives of q = a^3 + 5 b^2:
# dq/da = 3 a^2 and dq/db = 10 b (element-wise).
a.grad == 3 * a**2, b.grad == 10 * b

(tensor([True, True]), tensor([True, True]))

# Extending PyTorch

In [162]:
class LinearFunction(torch.autograd.Function):
	"""Custom autograd function computing ``input.mm(weight.t()) + bias``.

	Use it through ``LinearFunction.apply(input, weight, bias)``.
	"""

	@staticmethod
	def forward(ctx, input, weight, bias):
		"""Compute the affine map and stash what backward needs.

		Args:
			ctx: autograd context used to pass state to ``backward``.
			input: 2-D tensor; multiplied on the right by ``weight.t()``.
			weight: 2-D tensor whose transpose is the projection matrix.
			bias: value added (with broadcasting) to the matrix product.

		Returns:
			``input.mm(weight.t()) + bias``.
		"""
		ctx.save_for_backward(input, weight)
		# Record the bias shape so backward can reduce the broadcast gradient
		# back to exactly that shape (None when bias is not a tensor).
		ctx.bias_shape = bias.shape if torch.is_tensor(bias) else None
		return input.mm(weight.t()) + bias

	@staticmethod
	def backward(ctx, grad_output):
		"""Return gradients w.r.t. ``input``, ``weight`` and ``bias``.

		Each gradient is only computed when autograd asks for it via
		``ctx.needs_input_grad``; otherwise ``None`` is returned in its slot.
		"""
		input, weight = ctx.saved_tensors
		grad_input = grad_weight = grad_bias = None

		if ctx.needs_input_grad[0]:
			grad_input = grad_output.mm(weight)
		if ctx.needs_input_grad[1]:
			grad_weight = grad_output.t().mm(input)
		if ctx.needs_input_grad[2]:
			# Sum over the broadcast dimensions so the gradient has the same
			# shape as bias (e.g. (1, out) stays 2-D, (out,) stays 1-D);
			# a plain sum(0) would be rejected by autograd for a 2-D bias.
			grad_bias = grad_output.sum_to_size(ctx.bias_shape)

		return grad_input, grad_weight, grad_bias

In [163]:
# Custom autograd Functions are invoked through .apply, not by instantiating
# the class; bind it to a short name so it reads like a regular function.
linear = LinearFunction.apply

In [164]:
# Smoke test: a (1, 2) input that requires grad, a (2, 2) weight and a (1, 2) bias.
# NOTE(review): no seed is set in the visible cells, so the printed values
# presumably change on every run.
linear(torch.randn(1,2, requires_grad=True), torch.randn(2,2), torch.randn(1,2))

tensor([[ 2.3402, -4.2041]], grad_fn=<LinearFunctionBackward>)