# MindSpore Basics - Computation Graph

Make Your First GAN With MindSpore, 2022

In [1]:
import mindspore
import mindspore.nn as nn

## Simple Computation Graph

```
  (x) --> (y) --> (z)
```

> y = x^2
>
> z = 2y + 3

In [2]:
# set up simple graph relating x, y and z

class Func1(nn.Cell):
    def construct(self, x):
        y = x * x
        z = 2 * y + 3
        return z

func1 = Func1()

In [3]:
# work out gradients
# by default, GradOperation() differentiates with respect to the first input only

# https://mindspore.cn/docs/api/zh-CN/r1.6/api_python/ops/mindspore.ops.GradOperation.html?highlight=gradoperation#mindspore.ops.GradOperation
class GradNet(nn.Cell):
    def __init__(self, func):
        super().__init__()
        self.func = func
        self.grad_func = mindspore.ops.GradOperation()(self.func)
    
    def construct(self, x):
        return self.grad_func(x)

grad_net = GradNet(func1)

In [6]:
# what is the gradient at x = 3.5?

x = mindspore.Tensor(3.5)

grad_net(x)

Tensor(shape=[], dtype=Float32, value= 14)
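
The value 14 can be checked without MindSpore. By the chain rule, dz/dx = dz/dy * dy/dx = 2 * 2x = 4x, which is 14 at x = 3.5. The following plain-Python sketch compares that analytic value with a central finite difference. The helper names `z_of_x` and `central_difference` are introduced here for illustration only and are not part of the notebook or of MindSpore.

```python
# Plain-Python sanity check of the gradient of z = 2 * x^2 + 3 (no MindSpore needed).

def z_of_x(x):
    """Same computation as Func1.construct, on ordinary Python floats."""
    y = x * x
    return 2 * y + 3

def central_difference(f, x, h=1e-5):
    """Approximate df/dx at x with a symmetric finite difference."""
    return (f(x + h) - f(x - h)) / (2 * h)

x0 = 3.5
analytic = 4 * x0                          # chain rule: dz/dx = 4x
numeric = central_difference(z_of_x, x0)   # numerical estimate of the same slope

print(f'analytic: {analytic}, numeric: {numeric:.6f}')
```

Both numbers should agree with the MindSpore output above to within floating-point rounding.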

## Computation Graph With Multiple Links To A Node

```

  (a) --> (x)
       \ /     \
       .       (z)
      / \     /
  (b) --> (y)

 
  x = 2a + 3b
 
  y = 5a^2 + 3b^3
 
  z = 2x + 3y

```

In [11]:
# set up graph relating a, b, x, y and z

class Func2(nn.Cell):
    def construct(self, a, b):
        x = 2 * a + 3 * b
        y = 5 * a * a + 3 * b * b * b
        z = 2 * x + 3 * y
        return z

func2 = Func2()

In [13]:
# work out gradients
# get_all=True returns the gradient with respect to every input (here a and b)

class GradNet2(nn.Cell):
    def __init__(self, func):
        super().__init__()
        self.func = func
        self.grad_func = mindspore.ops.GradOperation(get_all=True)(self.func)
    
    def construct(self, a, b):
        return self.grad_func(a, b)

grad_net2 = GradNet2(func2)

In [14]:
# what are the gradients at a = 2.0, b = 1.0?

a = mindspore.Tensor(2.0)
b = mindspore.Tensor(1.0)

grad_a, grad_b = grad_net2(a, b)
print(f'Grad of a: {grad_a}. Grad of b: {grad_b}')

Grad of a: 64.0. Grad of b: 33.0


## Manually Check MindSpore Result


```

dz/da = dz/dx * dx/da + dz/dy * dy/da

      = 2 * 2 + 3 * 10a

      = 4  + 30a

When a = 2.0, dz/da = 64  ... correct!

```
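
The same reasoning applies to b: dz/db = dz/dx * dx/db + dz/dy * dy/db = 2 * 3 + 3 * 9b^2 = 6 + 27b^2, which is 33 at b = 1.0. As a rough cross-check of both partial derivatives at the values passed to `grad_net2` (a = 2.0, b = 1.0), the sketch below uses plain Python and a central finite difference. The names `z_of_ab` and `partial` are hypothetical helpers introduced here, not part of the notebook or of MindSpore.

```python
# Plain-Python check of both partial derivatives of z with respect to a and b.

def z_of_ab(a, b):
    """Same computation as Func2.construct, on ordinary Python floats."""
    x = 2 * a + 3 * b
    y = 5 * a * a + 3 * b * b * b
    return 2 * x + 3 * y

def partial(f, args, index, h=1e-5):
    """Central finite difference of f with respect to args[index]."""
    up = list(args)
    down = list(args)
    up[index] += h
    down[index] -= h
    return (f(*up) - f(*down)) / (2 * h)

a0, b0 = 2.0, 1.0

# Analytic values from the chain rule
dz_da = 4 + 30 * a0           # 2 * 2 + 3 * 10a
dz_db = 6 + 27 * b0 ** 2      # 2 * 3 + 3 * 9b^2

# Numerical estimates of the same slopes
num_da = partial(z_of_ab, (a0, b0), 0)
num_db = partial(z_of_ab, (a0, b0), 1)

print(f'dz/da: analytic {dz_da}, numeric {num_da:.6f}')
print(f'dz/db: analytic {dz_db}, numeric {num_db:.6f}')
```

The analytic values match the `Grad of a` and `Grad of b` printed by MindSpore, and the finite-difference estimates should agree with them up to a small numerical error.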

