#### Computational Graphs in PyTorch

In [45]:
# Load the required packages
import torch
from IPython.display import display, Math

In [80]:
# Build the computational graph for e = a*b + log(d)*c.
# PyTorch graphs are dynamic: each operation below extends the graph
# at the moment the line is executed.

# Leaf nodes: user-created tensors that track gradients.
a, b, c, d = (
    torch.tensor([value], requires_grad=True)
    for value in (2.0, 3.0, 5.0, 10.0)
)

# Intermediate (non-leaf) nodes.
u = a * b
t = d.log()
t.retain_grad()  # keep t.grad after backward even though t is not a leaf
v = t * c

# Root node of the graph.
e = u + v

In [76]:
# Report, for every node of the graph, whether it is a leaf, which
# function produced it (grad_fn) and the gradient currently stored on it.
# One three-line group per node, groups separated by a blank line.
print('\n\n'.join(
    f'{name}.is_leaf: {node.is_leaf}\n'
    f'{name}.grad_fn: {node.grad_fn}\n'
    f'{name}.grad: {node.grad}'
    for name, node in zip('abcdeuvt', (a, b, c, d, e, u, v, t))
))

a.is_leaf: True
a.grad_fn: None
a.grad: None

b.is_leaf: True
b.grad_fn: None
b.grad: None

c.is_leaf: True
c.grad_fn: None
c.grad: None

d.is_leaf: True
d.grad_fn: None
d.grad: None

e.is_leaf: False
e.grad_fn: <AddBackward0 object at 0x0000020EB13FAFD0>
e.grad: None

u.is_leaf: False
u.grad_fn: <MulBackward0 object at 0x0000020EB13FAFD0>
u.grad: None

v.is_leaf: False
v.grad_fn: <MulBackward0 object at 0x0000020EB13FAFD0>
v.grad: None

t.is_leaf: False
t.grad_fn: <LogBackward object at 0x0000020EB13FAFD0>
t.grad: None


In [81]:
# Run backpropagation starting from the root node e
# (functional spelling of e.backward(retain_graph=True)).
# retain_graph=True keeps the graph alive so that backward can be
# called on it again later.
torch.autograd.backward(e, retain_graph=True)

In [85]:
# First line:  t.grad -- t is not a leaf, but retain_grad() was called on
#              it, so the gradient that reached t was kept.
# Second line: u.grad -- u is not a leaf and retain_grad() was never called
#              on it, so nothing was stored and the attribute is None.
print(t.grad, u.grad, sep='\n')

tensor([5.])
None


In [86]:
# Same per-node report as before the backward pass: leaf status, the
# producing function (grad_fn) and the gradient now stored on each node.
# One three-line group per node, groups separated by a blank line.
print('\n\n'.join(
    f'{name}.is_leaf: {node.is_leaf}\n'
    f'{name}.grad_fn: {node.grad_fn}\n'
    f'{name}.grad: {node.grad}'
    for name, node in zip('abcdeuvt', (a, b, c, d, e, u, v, t))
))

a.is_leaf: True
a.grad_fn: None
a.grad: tensor([3.])

b.is_leaf: True
b.grad_fn: None
b.grad: tensor([2.])

c.is_leaf: True
c.grad_fn: None
c.grad: tensor([2.3026])

d.is_leaf: True
d.grad_fn: None
d.grad: tensor([0.5000])

e.is_leaf: False
e.grad_fn: <AddBackward0 object at 0x0000020EAF69DCF8>
e.grad: None

u.is_leaf: False
u.grad_fn: <MulBackward0 object at 0x0000020EAF69DCF8>
u.grad: None

v.is_leaf: False
v.grad_fn: <MulBackward0 object at 0x0000020EAF69DCF8>
v.grad: None

t.is_leaf: False
t.grad_fn: <LogBackward object at 0x0000020EAF69DCF8>
t.grad: tensor([5.])


Gradients of $e = a \cdot b + c \cdot \ln(d)$ with respect to the leaf nodes, as computed by PyTorch autograd:

In [64]:
# Render the partial derivative of e with respect to each leaf node as
# LaTeX, with an empty line printed between consecutive formulas.
for position, (symbol, leaf_tensor) in enumerate(zip('abcd', (a, b, c, d))):
    if position:
        print()
    display(Math(
        fr'\frac{{\partial e}}{{\partial {symbol}}} = {leaf_tensor.grad.item()}'
    ))

<IPython.core.display.Math object>




<IPython.core.display.Math object>




<IPython.core.display.Math object>




<IPython.core.display.Math object>