# 02. Gradient calculation

Most deep learning models are trained using some variant of gradient descent. The idea behind this method is as follows:

1. You define a loss function: something you want to minimize, e.g. the squared difference between predicted and actual output.
2. Then you calculate the partial derivatives of this loss function with respect to its parameters. The vector of partial derivatives is called the gradient.
3. Finally, you update the parameters to reduce the loss: each parameter moves in the direction opposite to its partial derivative, by a step proportional to that derivative's magnitude.
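
The three steps above can be sketched in plain Julia, without Lilith. This is only an illustration under simple assumptions: the gradient is derived by hand rather than computed automatically, and the names `loss`, `loss_grad`, `descend`, the learning rate `lr` and the step count `steps` are hypothetical choices made for this example.

```julia
# Step 1: the loss function, here a sum of squared differences
loss(w, target) = sum((w .- target) .^ 2)

# Step 2: its gradient w.r.t. w, derived by hand: d/dw (w - t)^2 = 2(w - t)
loss_grad(w, target) = 2 .* (w .- target)

# Step 3: repeatedly move the parameters against the gradient
function descend(w, target; lr=0.1, steps=100)
    w = copy(w)                          # do not modify the caller's array
    for _ in 1:steps
        w .-= lr .* loss_grad(w, target)
    end
    return w
end

target = [1.0, -2.0, 3.0]
w = descend(zeros(3), target)            # w ends up very close to target
```

With these settings every step shrinks the distance to `target` by a constant factor, so the loss decreases monotonically. A learning rate that is too large would make the updates overshoot instead.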

For more information on gradient descent see e.g. [this cheatsheet](https://ml-cheatsheet.readthedocs.io/en/latest/gradient_descent.html).

In Lilith you can calculate the gradient of any scalar-valued function. For example, consider the following loss function:

In [5]:
function my_loss(y_pred::Vector, y_true::Vector)
    return sum((y_pred .- y_true) .^ 2)
end

y_pred = rand(10)
y_true = rand(10)

my_loss(y_pred, y_true)


2.2992492304954895

Now we can calculate the value of the function and its gradients w.r.t. its parameters with a single call to `grad()`. Note that due to function tracing and gradient compilation, the first run may take quite long.

In [6]:
using Lilith

val, g = grad(my_loss, y_pred, y_true)

(2.2992492304954895, GradResult(2))

`val` is the same value that the direct function call returned earlier, and `g` is a `GradResult` object containing the gradients for each parameter:

In [7]:
println(g[1])    # gradient of my_loss w.r.t. 1st parameter, i.e. y_pred
println(g[2])    # gradient of my_loss w.r.t. 2nd parameter, i.e. y_true

[-0.87848431743822, 1.2165144839180422, -0.5007782854022085, -1.6989762691502182, 0.15510918492280146, 0.2399934876212826, 1.3228758302794104, -1.4021136439424584, 0.08987909227352509, -0.04897310185926251]
[0.87848431743822, -1.2165144839180422, 0.5007782854022085, 1.6989762691502182, -0.15510918492280146, -0.2399934876212826, -1.3228758302794104, 1.4021136439424584, -0.08987909227352509, 0.04897310185926251]
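
For this particular loss the gradients can also be derived by hand: the gradient w.r.t. `y_pred` is `2 .* (y_pred .- y_true)` and the gradient w.r.t. `y_true` is its negation, which is why the two printed vectors mirror each other. The sketch below does not use Lilith. It checks the hand-derived formulas against a central finite-difference approximation. The helpers `grad_pred`, `grad_true` and `numeric_grad`, as well as the step size `h`, are hypothetical names introduced for this illustration.

```julia
my_loss(y_pred::Vector, y_true::Vector) = sum((y_pred .- y_true) .^ 2)

# Gradients derived by hand for this specific loss
grad_pred(y_pred, y_true) = 2 .* (y_pred .- y_true)
grad_true(y_pred, y_true) = -2 .* (y_pred .- y_true)

# Central finite-difference approximation of the gradient of f at x
function numeric_grad(f, x; h=1e-6)
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h         # nudge the i-th coordinate up
        xm = copy(x); xm[i] -= h         # and down
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

y_pred = rand(10)
y_true = rand(10)

# Differentiate w.r.t. one argument while holding the other fixed
num_pred = numeric_grad(p -> my_loss(p, y_true), y_pred)
num_true = numeric_grad(t -> my_loss(y_pred, t), y_true)

println(isapprox(num_pred, grad_pred(y_pred, y_true); atol=1e-6))
println(isapprox(num_true, grad_true(y_pred, y_true); atol=1e-6))
```

A numerical check like this is a common way to sanity-test gradients of small functions, although it needs two function evaluations per input element and so does not scale to large models.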