## Matrices

A matrix is a two-dimensional array of numbers arranged in rows and columns.
For a matrix with real-valued entries, we write $\mathbf{A} \in \mathbb{R}^{m \times k}$, where $m$ is the number of rows and $k$ is the number of columns:

$$
\mathbf{A} = 
\begin{bmatrix}
a_{11} & a_{12} & \cdots & a_{1k} \\
a_{21} & a_{22} & \cdots & a_{2k} \\
\vdots & \vdots & \ddots & \vdots \\
a_{m1} & a_{m2} & \cdots & a_{mk}
\end{bmatrix}
$$

For example, here is a $3 \times 4$ matrix $\mathbf{A}$:

$$
\mathbf{A} = 
\begin{bmatrix}
0 & 1 & -2.3 & 0.1 \\
1.3 & 4 & -0.1 & 1 \\
4.1 & -1 & 0 & 1.7
\end{bmatrix}
$$
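
The mathematical notation above counts from 1 ($a_{11}$ is the top-left entry), while JAX arrays are indexed from 0. The following minimal sketch shows that correspondence on the $3 \times 4$ example matrix. The names `a_13` and `a_32` are only illustrative.

```python
import jax.numpy as jnp

# The 3 x 4 example matrix from the text
A = jnp.array([
    [0, 1, -2.3, 0.1],
    [1.3, 4, -0.1, 1],
    [4.1, -1, 0, 1.7],
])

m, k = A.shape  # m = 3 rows, k = 4 columns

# a_{ij} in math notation (1-based) corresponds to A[i - 1, j - 1] in code (0-based)
a_13 = A[0, 2]  # first row, third column
a_32 = A[2, 1]  # third row, second column

print(m, k)  # 3 4
print(a_13, a_32)
```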

### Creating matrices

As with vectors, JAX provides several ways to create matrices through `jax.numpy`.

**Hard-coded**

In [1]:
import jax.numpy as jnp
import jax
A = jnp.array([
    [0, 1, -2.3, 0.1], 
    [1.3, 4, -0.1, 1], 
    [4.1, -1, 0, 1.7]
])

m, k = A.shape # get matrix dimension

print(f"Value: {A}")
print(f"Shape: {m, k}")

Value: [[ 0.   1.  -2.3  0.1]
 [ 1.3  4.  -0.1  1. ]
 [ 4.1 -1.   0.   1.7]]
Shape: (3, 4)


**Random**

In [7]:
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
A = jax.random.normal(subkey, (5, 4))
print(f"A : \n{A}")

A : 
[[-2.4424558  -2.0356805   0.20554423 -0.3535502 ]
 [-0.76197404 -1.1785518  -1.1482196   0.29716578]
 [-1.3105359   2.1302025  -0.18957235  0.96401215]
 [-1.3011001  -0.7486938  -0.3729984   0.4427907 ]
 [-1.1902995  -0.06925564 -0.95605886 -1.9587638 ]]


**Zero, identity, and diagonal matrices**

In [8]:
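Z = jnp.zeros((3, 3))                     # zero matrix
I = jnp.eye(3)                            # identity matrix (jnp.ones would give an all-ones matrix instead)
D = jnp.diag(jnp.array([1.0, 2.0, 3.0]))  # diagonal matrix built from its diagonal entries

print(f"Zeros: \n{Z}")
print(f"Identity: \n{I}")
print(f"Diagonal: \n{D}")

Zeros: 
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
Identity: 
[[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]
Diagonal: 
[[1. 0. 0.]
 [0. 2. 0.]
 [0. 0. 3.]]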
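
`jnp.diag` works in both directions. Given a vector, it builds a diagonal matrix. Given a matrix, it extracts the main diagonal. The following small sketch uses arbitrary values and also checks that the identity matrix equals the diagonal matrix built from a vector of ones.

```python
import jax.numpy as jnp

d = jnp.array([1.0, 2.0, 3.0])
D = jnp.diag(d)          # 3 x 3 matrix with d on the main diagonal, zeros elsewhere
diag_of_D = jnp.diag(D)  # applied to a matrix: returns the main diagonal as a vector

I = jnp.eye(3)           # identity matrix: ones on the diagonal, zeros elsewhere

print(D)
print(diag_of_D)                                  # [1. 2. 3.]
print(jnp.array_equal(I, jnp.diag(jnp.ones(3))))  # True
```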


**Submatrix / Slicing**

In [9]:
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
A = jax.random.randint(subkey, (4, 4), -5, 5)
print(f"A: \n{A}")

B = A[2:4, 2:4]  # slicing: rows 2-3 and columns 2-3 (0-based; the end index is excluded)
print(f"B: \n{B}")

A: 
[[ 0 -3 -1 -3]
 [ 0 -1 -4  3]
 [ 1  0 -5  1]
 [ 3 -4 -2  3]]
B: 
[[-5  1]
 [-2  3]]


**Concatenating matrices**

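A matrix can also be built by concatenating other matrices with `jnp.hstack()` (*horizontal merge*) or `jnp.vstack()` (*vertical merge*). A horizontal merge requires the matrices to have the same number of rows, and a vertical merge requires the same number of columns.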

In [13]:
key, subkey = jax.random.split(key)
A = jax.random.normal(subkey, (3, 2))

key, subkey = jax.random.split(key)
B = jax.random.normal(subkey, (3, 4))

print(f"A: {A}")
print(f"B: {B}")
print(f"A shape: {A.shape}")
print(f"B shape: {B.shape}")

C = jnp.hstack((A, B)) # horizontal merge
print(f"(hstack) C shape: {C.shape}")

key, subkey = jax.random.split(key)
A = jax.random.normal(subkey, (5, 3))

key, subkey = jax.random.split(key)
B = jax.random.normal(subkey, (2, 3))

print(f"A: {A}")
print(f"B: {B}")
print(f"A shape: {A.shape}")
print(f"B shape: {B.shape}")

C = jnp.vstack((A, B)) # vertical merge
print(f"(vstack) C shape: {C.shape}")

A: [[ 0.8162971  -1.0413978 ]
 [-0.09381925  1.2889506 ]
 [-0.8274624   0.5524278 ]]
B: [[-1.0291611  -0.31853303  2.0335982   0.5695654 ]
 [ 0.27310836 -0.49904412  0.13288276  1.0810927 ]
 [ 1.0500706  -0.7995201   0.56557906  0.9537454 ]]
A shape: (3, 2)
B shape: (3, 4)
(hstack) C shape: (3, 6)
A: [[-0.45876116 -1.820128    0.09428611]
 [-1.1634603   0.5830652   0.8229201 ]
 [-1.2494267   0.4772025   0.38525388]
 [ 0.6251398   0.9846895   0.21361268]
 [ 0.2637497  -0.25655866  0.5197545 ]]
B: [[ 0.6236211  -0.5397633   0.13570921]
 [-1.5921305  -0.21137282  0.30612153]]
A shape: (5, 3)
B shape: (2, 3)
(vstack) C shape: (7, 3)
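
For 2-D arrays, `jnp.hstack` and `jnp.vstack` behave like `jnp.concatenate` along axis 1 and axis 0, respectively. The following minimal sketch checks that equivalence and the resulting shapes. It uses deterministic `ones`/`zeros` matrices, chosen only for illustration, with the same shapes as the cell above.

```python
import jax.numpy as jnp

A = jnp.ones((3, 2))
B = jnp.zeros((3, 4))

# Horizontal merge: row counts must match (3 and 3); column counts add up (2 + 4)
H = jnp.hstack((A, B))
H2 = jnp.concatenate((A, B), axis=1)

# Vertical merge: column counts must match (3 and 3); row counts add up (5 + 2)
P = jnp.ones((5, 3))
Q = jnp.zeros((2, 3))
V = jnp.vstack((P, Q))
V2 = jnp.concatenate((P, Q), axis=0)

print(H.shape, jnp.array_equal(H, H2))  # (3, 6) True
print(V.shape, jnp.array_equal(V, V2))  # (7, 3) True
```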


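**Reshaping a $d$-dimensional vector into a matrix**

In JAX, the `reshape()` function changes the shape of a vector or matrix without changing its data. The total number of elements must stay the same; for example, a vector of length 4 can become a $1 \times 4$ or a $2 \times 2$ matrix.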

In [15]:
key, subkey = jax.random.split(key)
v = jax.random.normal(subkey, (4,))
print(f"v: {v}")
print(f"v shape: {v.shape}")

M = jnp.reshape(v, (1, 4))
print(f"M: {M}")
print(f"M shape: {M.shape}")

w = M.flatten() # vectorization: go back to original vector
print(f"w: {w}")
print(f"w shape: {w.shape}")

print(jnp.allclose(v, w))

v: [-3.0202537  -0.09792765  1.2203779  -0.09892032]
v shape: (4,)
M: [[-3.0202537  -0.09792765  1.2203779  -0.09892032]]
M shape: (1, 4)
w: [-3.0202537  -0.09792765  1.2203779  -0.09892032]
w shape: (4,)
True
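
The same vector can be reshaped into any shape with the same number of elements. The following small sketch uses `jnp.arange` values, chosen only so the row-by-row layout is easy to read.

```python
import jax.numpy as jnp

v = jnp.arange(6.0)            # [0. 1. 2. 3. 4. 5.], shape (6,)

row = v.reshape(1, 6)          # row vector, shape (1, 6)
col = v.reshape(6, 1)          # column vector, shape (6, 1)
M = v.reshape(2, 3)            # filled row by row: [[0, 1, 2], [3, 4, 5]]
M_inferred = v.reshape(-1, 3)  # -1 lets JAX infer the missing dimension (here 2)

print(row.shape, col.shape, M.shape, M_inferred.shape)  # (1, 6) (6, 1) (2, 3) (2, 3)
print(M)
print(jnp.array_equal(M.flatten(), v))  # True: flattening restores the vector

# v.reshape(4, 2) would raise an error: 6 elements cannot fill a 4 x 2 matrix
```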


### Basic Matrix Operations

Consider matrices $\mathbf{A}, \mathbf{B} \in \mathbb{R}^{m \times k}$ and a scalar $c \in \mathbb{R}$. Each operation below acts entry by entry, and each result is again an $m \times k$ matrix:

- **Addition**: $\mathbf{C} = \mathbf{A} + \mathbf{B}$
- **Subtraction**: $\mathbf{C} = \mathbf{A} - \mathbf{B}$
- **Scalar multiplication**: $\mathbf{C} = c \mathbf{A}$
- **Element-wise (Hadamard) product**: $\mathbf{C} = \mathbf{A} \odot \mathbf{B}$, i.e., $c_{ij} = a_{ij} b_{ij}$ (written `A * B` in JAX)

In [16]:
A = jnp.array([[0, 4], [7,0], [3,1]])
B = jnp.array([[1, 2], [2,3], [0,4]])
C = A + B
print(f"Addition: \n{C}")

C = 2.5 * A
print(f"Scalar mult: \n{C}")

C = A * B
print(f"Element-wise mult: \n{C}")

Addition: 
[[1 6]
 [9 3]
 [3 5]]
Scalar mult: 
[[ 0.  10. ]
 [17.5  0. ]
 [ 7.5  2.5]]
Element-wise mult: 
[[ 0  8]
 [14  0]
 [ 0  4]]
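
The operations above include subtraction, which the cell does not demonstrate. The following minimal sketch reuses the same `A` and `B` to show it. It also checks the element-wise rule $c_{ij} = a_{ij} b_{ij}$ for one entry.

```python
import jax.numpy as jnp

A = jnp.array([[0, 4], [7, 0], [3, 1]])
B = jnp.array([[1, 2], [2, 3], [0, 4]])

C = A - B  # element-wise subtraction: [[-1, 2], [5, -3], [3, -3]]
print(f"Subtraction: \n{C}")

# Element-wise product: each entry is a_ij * b_ij
P = A * B
print(P[1, 0] == A[1, 0] * B[1, 0])  # True (14 == 7 * 2)
```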


### Inner Product

#### Matrix-vector multiplication

Matrix-vector multiplication generalizes the inner product of two vectors. Given a matrix $\mathbf{A} \in \mathbb{R}^{m \times k}$ and a vector $\mathbf{v} \in \mathbb{R}^k$, their product is a new vector $\mathbf{y} \in \mathbb{R}^m$. Each entry $y_i = \sum_{j=1}^{k} a_{ij} v_j$ is the inner product of the $i$-th row of $\mathbf{A}$ with $\mathbf{v}$:

$$
\mathbf{y} = \mathbf{A} \mathbf{v}
$$

In [None]:
key, subkey = jax.random.split(key)
A = jax.random.normal(subkey, (5, 3))

key, subkey = jax.random.split(key)
v = jax.random.normal(subkey, (3,))

print(f"A: \n{A}")
print(f"v: \n{v}")

y = jnp.dot(A, v)
print(f"Matrix-vector mult: \n{y}")

A: 
[[ 0.93723524 -0.7031259   0.14395906]
 [ 0.584948    1.0222079   1.2163118 ]
 [-0.18599552 -1.7047697   1.1065272 ]
 [-0.10622101 -1.1216718   1.9150256 ]
 [-1.0020611   0.02767283 -0.25238422]]
v: 
[-0.35832682  1.8695863  -0.6105054 ]
Matrix-vector mult: 
[-1.7382789   0.95893836 -3.796108   -3.228134    0.56488407]
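
Each entry $y_i$ is the inner product of the $i$-th row of $\mathbf{A}$ with $\mathbf{v}$. The following minimal sketch checks this row-by-row view against `jnp.dot` using small hand-picked values.

```python
import jax.numpy as jnp

A = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])  # shape (2, 3): m = 2, k = 3
v = jnp.array([1.0, 0.0, -1.0])   # shape (3,)

y = jnp.dot(A, v)                 # shape (2,)

# Row by row: y_i = sum_j a_ij * v_j
y_rows = jnp.array([jnp.dot(A[i], v) for i in range(A.shape[0])])

print(y)                        # [-2. -2.]
print(jnp.allclose(y, y_rows))  # True
print(jnp.allclose(y, A @ v))   # True: the @ operator gives the same result
```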


#### Matrix-matrix multiplication
Given matrices $\mathbf{A} \in \mathbb{R}^{m \times d}$ and $\mathbf{B} \in \mathbb{R}^{d \times k}$, where the number of columns of $\mathbf{A}$ must equal the number of rows of $\mathbf{B}$, their product is a new matrix $\mathbf{C} \in \mathbb{R}^{m \times k}$ with entries $c_{ij} = \sum_{l=1}^{d} a_{il} b_{lj}$:

$$
\mathbf{C} = \mathbf{A} \mathbf{B}
$$

In [22]:
key, subkey = jax.random.split(key)
A = jax.random.randint(subkey, (4, 2), -10, 10)

key, subkey = jax.random.split(key)
B = jax.random.randint(subkey, (2, 3), -5, 5)

print(f"A: \n{A}")
print(f"B: \n{B}")

C = jnp.dot(A, B)
print(f"C: \n{C}")

A: 
[[-8  5]
 [-3  0]
 [ 1  6]
 [-7  7]]
B: 
[[ 4  1 -3]
 [ 0 -4  2]]
C: 
[[-32 -28  34]
 [-12  -3   9]
 [  4 -23   9]
 [-28 -35  35]]
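
Entry $c_{ij}$ is the inner product of row $i$ of $\mathbf{A}$ with column $j$ of $\mathbf{B}$. The following sketch hard-codes the integer matrices printed above so it is self-contained. It recomputes one entry by hand and checks it against the full product.

```python
import jax.numpy as jnp

# The A (4 x 2) and B (2 x 3) printed above, copied here
A = jnp.array([[-8, 5], [-3, 0], [1, 6], [-7, 7]])
B = jnp.array([[4, 1, -3], [0, -4, 2]])

C = A @ B  # same as jnp.dot(A, B) for 2-D arrays; shape (4, 3)

# c_21 (1-based) = row 2 of A times column 1 of B = (-3)(4) + (0)(0) = -12
c_21 = jnp.sum(A[1, :] * B[:, 0])

print(C.shape)          # (4, 3)
print(c_21 == C[1, 0])  # True

# jnp.dot(B, A) is not defined here: B has 3 columns but A has 4 rows
```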


#### Transpose
The transpose swaps the rows and columns of a matrix. The transpose of $\mathbf{A} \in \mathbb{R}^{m \times k}$ is written $\mathbf{A}^\top \in \mathbb{R}^{k \times m}$, with entries $(\mathbf{A}^\top)_{ij} = a_{ji}$.

For example, consider the matrix $\mathbf{A} \in \mathbb{R}^{2 \times 3}$:
$$
\mathbf{A} = 
\begin{bmatrix}
1 & 2 & 3 \\
4 & 5 & 6
\end{bmatrix}
$$

Its transpose is

$$
\mathbf{A}^\top = 
\begin{bmatrix}
1 & 4 \\
2 & 5 \\
3 & 6
\end{bmatrix}
$$

In [23]:
key, subkey = jax.random.split(key)
A = jax.random.normal(subkey, (4, 2))

key, subkey = jax.random.split(key)
B = jax.random.normal(subkey, (3, 2))

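# C = jnp.dot(A, B)  # fails: A is (4, 2) and B is (3, 2), so the inner dimensions (2 and 3) do not match

C = jnp.dot(A, B.T)  # B.T has shape (2, 3), so the product is defined and has shape (4, 3)

print(f"B shape: {B.shape}")
print(f"B.T shape: {B.T.shape}")
print(f"C shape: {C.shape}")

B shape: (3, 2)
B.T shape: (2, 3)
C shape: (4, 3)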
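
The following small sketch reproduces the $2 \times 3$ example above. It also checks the entry-wise definition $(\mathbf{A}^\top)_{ij} = a_{ji}$ and the fact that transposing twice returns the original matrix.

```python
import jax.numpy as jnp

A = jnp.array([[1, 2, 3],
               [4, 5, 6]])  # shape (2, 3)

At = A.T                    # shape (3, 2); jnp.transpose(A) is equivalent

print(At)
# [[1 4]
#  [2 5]
#  [3 6]]

# (A^T)_{ij} = a_{ji} for every i, j (0-based indices in code)
ok = all(At[i, j] == A[j, i] for i in range(3) for j in range(2))
print(ok)                        # True
print(jnp.array_equal(At.T, A))  # True: transposing twice returns A
```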


### Inverse

**Left inverse**: A matrix $\mathbf{X} \in \mathbb{R}^{k \times m}$ is a *left inverse* of a matrix $\mathbf{A} \in \mathbb{R}^{m \times k}$ if

$$
\mathbf{X} \mathbf{A} = \mathbf{I}_k
$$

**Right inverse**: A matrix $\mathbf{X} \in \mathbb{R}^{k \times m}$ is a *right inverse* of $\mathbf{A}$ if

$$
\mathbf{A} \mathbf{X} = \mathbf{I}_m
$$

Here $\mathbf{I}_k$ and $\mathbf{I}_m$ denote the $k \times k$ and $m \times m$ identity matrices.

If $\mathbf{X}$ is both a left inverse and a right inverse of $\mathbf{A}$, it is called the inverse of $\mathbf{A}$ and written $\mathbf{A}^{-1}$.

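A necessary condition for $\mathbf{A}$ to have an inverse is that it is square, i.e., $\mathbf{A} \in \mathbb{R}^{m \times m}$. Being square is not sufficient: the matrix must also be nonsingular (its determinant is nonzero).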

In [25]:
key, subkey = jax.random.split(key)
X = jax.random.randint(subkey, (4, 4), -2, 10).astype(jnp.float32)

print(f"X: \n {X}")

Xinv = jnp.linalg.inv(X)
print(f"Inverse of X: \n {Xinv}")

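print(f"{jnp.dot(X, Xinv)}")  # X @ Xinv: the identity, up to float32 round-off
print(f"{jnp.dot(Xinv, X)}")  # Xinv @ X: also the identity, so Xinv is both a left and a right inverse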

X: 
 [[ 6.  7. -2.  1.]
 [ 0.  7.  3.  7.]
 [-2. -2.  9.  0.]
 [ 7.  7. -1.  5.]]
Inverse of X: 
 [[-4.7674436e-02 -1.3604653e-01  5.6976750e-02  2.0000002e-01]
 [ 2.3081398e-01  1.0988375e-01 -7.5581460e-03 -2.0000002e-01]
 [ 4.0697671e-02 -5.8139516e-03  1.2209302e-01  4.8856280e-10]
 [-2.4825582e-01  3.5465099e-02 -4.4767436e-02  2.0000002e-01]]
[[ 1.0000000e+00  6.7055225e-08  3.7252903e-09 -5.9604645e-08]
 [ 8.9406967e-08  1.0000001e+00 -3.7252903e-09  2.9802322e-08]
 [-5.2154064e-08 -2.7939677e-09  1.0000000e+00  4.3970654e-09]
 [ 2.9802322e-08 -2.2351742e-08  4.8428774e-08  1.0000000e+00]]
[[ 1.0000000e+00 -8.9406967e-08  1.4901161e-08 -2.9802322e-08]
 [ 8.9406967e-08  1.0000002e+00 -1.4901161e-08  1.4901161e-07]
 [-1.1481221e-08  3.4199397e-09  1.0000000e+00  1.2687362e-08]
 [ 1.4901161e-07  2.9802322e-08  0.0000000e+00  9.9999994e-01]]
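
The full inverse requires a square matrix, but the left-inverse definition also applies to non-square matrices. For a tall matrix $\mathbf{A} \in \mathbb{R}^{m \times k}$ with $m > k$ and linearly independent columns, one left inverse is $\mathbf{X} = (\mathbf{A}^\top \mathbf{A})^{-1} \mathbf{A}^\top$. The following minimal sketch uses a hand-picked $3 \times 2$ matrix, chosen only for illustration.

```python
import jax.numpy as jnp

A = jnp.array([[1.0, 0.0],
               [0.0, 1.0],
               [1.0, 1.0]])        # shape (3, 2); the columns are linearly independent

# One left inverse, valid when the columns of A are linearly independent
X = jnp.linalg.inv(A.T @ A) @ A.T  # shape (2, 3)

print(jnp.allclose(X @ A, jnp.eye(2), atol=1e-5))          # True: X A = I_2
print(jnp.allclose(A @ X, jnp.eye(3), atol=1e-5))          # False: X is not a right inverse
print(jnp.allclose(X, jnp.linalg.pinv(A), atol=1e-5))      # True: matches the pseudoinverse here
```

In this full-column-rank case, the left inverse coincides with the Moore-Penrose pseudoinverse returned by `jnp.linalg.pinv`. Explicitly inverting $\mathbf{A}^\top \mathbf{A}$ is shown only to mirror the formula.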


### Matrix norms

The concept of a norm also extends to matrices: a matrix norm is a scalar that measures the size (*magnitude*) of a matrix.
For example, the Euclidean norm of a matrix $\mathbf{A} \in \mathbb{R}^{m \times k}$, also called the Frobenius norm, is:

$$
\| \mathbf{A} \| = \sqrt{\left( \sum_{i=1}^{m} \sum_{j=1}^{k} a^2_{ij} \right)}
$$

In [26]:
key, subkey = jax.random.split(key)
A = jax.random.uniform(subkey, (4, 3))
print(f"A: \n{A}")

print(f"Norm(A): {jnp.linalg.norm(A)}")

A: 
[[0.5555557  0.25089288 0.6705233 ]
 [0.93133795 0.7595793  0.8994366 ]
 [0.12667322 0.4431491  0.6564672 ]
 [0.94205177 0.5898868  0.02611053]]
Norm(A): 2.2257614135742188
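
The following minimal sketch checks the formula above directly against `jnp.linalg.norm`. For a 2-D array, `jnp.linalg.norm` defaults to this norm, which can also be requested explicitly with `ord="fro"`. The values are hand-picked so that the result is easy to verify by hand.

```python
import jax.numpy as jnp

A = jnp.array([[1.0, 2.0],
               [2.0, 4.0]])  # sum of squares = 1 + 4 + 4 + 16 = 25

manual = jnp.sqrt(jnp.sum(A ** 2))     # sqrt(sum_ij a_ij^2), as in the formula
builtin = jnp.linalg.norm(A)           # default for a matrix: Frobenius norm
fro = jnp.linalg.norm(A, ord="fro")

print(manual, builtin, fro)            # 5.0 5.0 5.0
print(jnp.linalg.norm(A.flatten()))    # same value: the 2-norm of the flattened matrix
```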
