Copyright (c) 2023 Graphcore Ltd. All rights reserved.

# ReplicaGrouping

In this short complementary notebook we explain the concept of **ReplicaGrouping**. You can read more in the [official user guide](https://docs.graphcore.ai/projects/popxl/en/3.1.0/replication.html?highlight=replica%20grouping#replica-grouping).

Generally speaking, a `ReplicaGrouping` object defines how IPUs (replicas) are partitioned into groups.
Replica groupings are used in two different contexts:
- **Variables**: the replica group of a variable identifies the IPUs over which the variable is the same. The variable has the same value on all IPUs that belong to the same group.
- **Collectives**: in collectives, replica groups define which IPUs need to communicate. Communication happens only between IPUs in the same group.

## Definition
A ReplicaGrouping is specified using a `group_size` and a `stride`.
- The **group_size** is the number of IPUs in each group.
- The **stride** defines which IPUs are in the same group: it is the distance between the indices of consecutive IPUs within a group.

To create a replica grouping, you can use the following syntax:
```python
import popxl

replicas = 8  # total number of replicas
ir = popxl.Ir(replication=replicas)
rg = ir.replica_grouping(stride=1, group_size=4)
```

Given a certain `group_size`,

$$\text{num\_groups} = \frac{\text{replicas}}{\text{group\_size}}$$

groups are created.

Since the number of groups must be an integer, `group_size` must divide the number of replicas, so not all parameter values are allowed.

You don't need to specify both `stride` and `group_size`:
- If you only specify the `stride`, the group size defaults to `group_size = replicas // stride`.
- If you only specify the `group_size`, a stride of 1 is assumed.
- The **default** grouping, created with `ir.replica_grouping()`, puts *all replicas in the same group*: a single group with `group_size=replicas` and `stride=1`.
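The defaulting rules above can be summarised in a few lines of plain Python. `resolve_grouping` below is a hypothetical helper, not part of popxl: it only mirrors the rules described in this notebook, including the requirement that `group_size` divides the number of replicas.

```python
def resolve_grouping(replicas, stride=None, group_size=None):
    """Hypothetical helper: fill in a missing stride/group_size as described above.

    Returns (stride, group_size, num_groups). This is not the popxl implementation.
    """
    if stride is None:
        stride = 1  # only group_size given (or nothing): unit stride
    if group_size is None:
        group_size = replicas // stride  # only stride given (or nothing)
    if replicas % group_size != 0:
        raise ValueError(f"group_size={group_size} does not divide replicas={replicas}")
    num_groups = replicas // group_size
    return stride, group_size, num_groups


print(resolve_grouping(8))                # (1, 8, 1): the default, one group of all replicas
print(resolve_grouping(8, stride=4))      # (4, 2, 4)
print(resolve_grouping(8, group_size=2))  # (1, 2, 4)
```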

A `ReplicaGrouping` defines an **assignment**: when a grouping is created, each IPU is given a group index in `0, ..., num_groups - 1`. IPUs with the same index belong to the same group.
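The assignments printed further down can be reproduced with a short NumPy formula. This is a sketch inferred from those printed examples, not the popxl implementation: it assumes replicas are laid out in consecutive blocks of `stride * group_size`, and that within a block the replica index modulo `stride` selects the group. `grouping_assignment` is a hypothetical name.

```python
import numpy as np


def grouping_assignment(replicas, stride=1, group_size=None):
    """Hypothetical re-derivation of ReplicaGrouping.assignment (not popxl code)."""
    if group_size is None:
        group_size = replicas // stride
    block = stride * group_size  # replicas covered by one block of `stride` groups
    if replicas % block != 0:
        raise ValueError("stride * group_size must divide the number of replicas")
    r = np.arange(replicas)
    # r // block picks the block, r % stride picks one of the `stride` groups inside it
    return ((r // block) * stride + r % stride).tolist()


print(grouping_assignment(8, stride=1, group_size=4))  # [0, 0, 0, 0, 1, 1, 1, 1]
print(grouping_assignment(8, stride=2, group_size=2))  # [0, 1, 0, 1, 2, 3, 2, 3]
print(grouping_assignment(8, stride=4))                # [0, 1, 2, 3, 0, 1, 2, 3]
```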

Depending on your preference, it may be more intuitive to think about groups in terms of assignments or in terms of IPU indices.
In the cells below, we provide a short snippet that prints both representations.

In [1]:
import popxl
import numpy as np
import popxl_addons as addons
from typing import Optional
from popxl import ReplicaGrouping
from popxl import ops

In [2]:
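def print_rg(ipu_indices, rg):
    # Print the grouping, its assignment and the IPU indices in each group
    print("#######################################")
    print(rg)
    print("Assignment: ", rg.assignment)
    for i in range(rg.num_groups):
        # IPUs whose assignment equals i form group i
        print(f" Group {i} : ", ipu_indices[np.where(np.asarray(rg.assignment) == i)])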

In [3]:
replicas = 8
ir = popxl.Ir(replication=replicas)
ipu_indices = np.arange(0, replicas)

rg = ir.replica_grouping(stride=1, group_size=4)
rg1 = ir.replica_grouping(stride=2, group_size=2)
rg2 = ir.replica_grouping(stride=4)
rg3 = ir.replica_grouping(group_size=2)
rg4 = ir.replica_grouping(stride=2, group_size=4)

print_rg(ipu_indices,rg)
print_rg(ipu_indices,rg1)
print_rg(ipu_indices,rg2)
print_rg(ipu_indices,rg3)
print_rg(ipu_indices,rg4)

#######################################
ReplicaGrouping(num_replicas=8, stride=1, group_size=4, num_groups=2)
Assignment:  [0, 0, 0, 0, 1, 1, 1, 1]
 Group 0 :  [0 1 2 3]
 Group 1 :  [4 5 6 7]
#######################################
ReplicaGrouping(num_replicas=8, stride=2, group_size=2, num_groups=4)
Assignment:  [0, 1, 0, 1, 2, 3, 2, 3]
 Group 0 :  [0 2]
 Group 1 :  [1 3]
 Group 2 :  [4 6]
 Group 3 :  [5 7]
#######################################
ReplicaGrouping(num_replicas=8, stride=4, group_size=2, num_groups=4)
Assignment:  [0, 1, 2, 3, 0, 1, 2, 3]
 Group 0 :  [0 4]
 Group 1 :  [1 5]
 Group 2 :  [2 6]
 Group 3 :  [3 7]
#######################################
ReplicaGrouping(num_replicas=8, stride=1, group_size=2, num_groups=4)
Assignment:  [0, 0, 1, 1, 2, 2, 3, 3]
 Group 0 :  [0 1]
 Group 1 :  [2 3]
 Group 2 :  [4 5]
 Group 3 :  [6 7]
#######################################
ReplicaGrouping(num_replicas=8, stride=2, group_size=4, num_groups=2)
Assignment:  [0, 1, 0, 1, 0, 1, 0, 1]
 Group 0 :  [0 2 4 6]
 Group 1 :  [1 3 5 7]


## Transpose of a ReplicaGrouping

A `ReplicaGrouping` has a `.transpose()` method that returns another grouping.
To understand what this grouping looks like, let's represent the IPU indices as an array.

<img src="images/ipu_indices.png" alt="IPU indices" style="width:500px;"/>

Defining a stride and a group size amounts to changing this flat visualisation into a matrix visualisation, where indices are arranged in `num_groups` columns and `group_size` rows. The easiest way to build this matrix is to look at the assignment: the assigned ID denotes the column you have to place the IPU in. So if the assignment is `[0, 1, 0, 1, 0, 1, 0, 1]`, you place IPUs 0,2,4,6 in column 0 and IPUs 1,3,5,7 in column 1. If the assignment is `[0, 1, 2, 3, 0, 1, 2, 3]`, you place IPUs 0,4 in column 0, IPUs 1,5 in column 1, IPUs 2,6 in column 2 and IPUs 3,7 in column 3.

Another easy way to build it is to fill columns of length `group_size` with IPU indices so that adjacent indices in a column are `stride` apart: with `stride=2, group_size=2`, you build the first column with indices 0,2, the second with indices 1,3, and so on.

If you now take the transpose of this matrix, its columns are the groups of the transpose grouping.

The image below illustrates this concept for `ReplicaGrouping(stride=2, group_size=4)`.

<img src="images/transpose.png" alt="The transpose of a ReplicaGrouping" style="width:500px;"/>
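The matrix construction can be sketched in NumPy: stack the members of each group as a column, transpose, and read the columns of the result as the new groups. `group_matrix` and `transpose_assignment` are hypothetical helpers that only reproduce this picture; they do not check whether the result can be expressed as a `stride`/`group_size` pair, and they are not how popxl implements `.transpose()`.

```python
import numpy as np


def group_matrix(assignment):
    """One column per group: shape (group_size, num_groups)."""
    assignment = np.asarray(assignment)
    num_groups = assignment.max() + 1
    columns = [np.flatnonzero(assignment == g) for g in range(num_groups)]
    return np.stack(columns, axis=1)


def transpose_assignment(assignment):
    """Assignment whose groups are the columns of group_matrix(assignment).T."""
    m_t = group_matrix(assignment).T  # shape (num_groups, group_size)
    result = np.empty(len(assignment), dtype=int)
    for group_id in range(m_t.shape[1]):
        result[m_t[:, group_id]] = group_id
    return result.tolist()


rg_assignment = [0, 1, 0, 1, 0, 1, 0, 1]  # stride=2, group_size=4 on 8 replicas
print(group_matrix(rg_assignment))
print(transpose_assignment(rg_assignment))  # [0, 0, 1, 1, 2, 2, 3, 3]
```

The transposed assignment matches the one printed for `rg.transpose()` in the next cell.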


In [4]:
replicas = 8
ir = popxl.Ir(replication=replicas)
ipu_indices = np.arange(0, replicas)

rg = ir.replica_grouping(stride=2, group_size=4)
rg_t = rg.transpose()

print_rg(ipu_indices,rg)
print_rg(ipu_indices,rg_t)

#######################################
ReplicaGrouping(num_replicas=8, stride=2, group_size=4, num_groups=2)
Assignment:  [0, 1, 0, 1, 0, 1, 0, 1]
 Group 0 :  [0 2 4 6]
 Group 1 :  [1 3 5 7]
#######################################
ReplicaGrouping(num_replicas=8, stride=1, group_size=2, num_groups=4)
Assignment:  [0, 0, 1, 1, 2, 2, 3, 3]
 Group 0 :  [0 1]
 Group 1 :  [2 3]
 Group 2 :  [4 5]
 Group 3 :  [6 7]


## ReplicaGrouping for variables

As mentioned in the introduction, when you create a variable in `popxl` you can specify its replica_grouping:
```python
var = popxl.variable(
        ...
        replica_grouping=ir.replica_grouping(group_size=2),
    )
```
If no group is specified, the **default** group is used, which means the variable is assumed to be the same on all replicas.

In `popxl-addons`, factories carry the information about the variable replica grouping, so that a variable with the appropriate grouping is created on initialisation.
Each individual factory has a `replica_group` property, and `NamedVariableFactories` have an associated `NamedReplicaGrouping` collection accessible via the `factories.replica_groupings` property.

Likewise, all `add_variable_input` functions used to add variables in addons modules accept a `replica_grouping` parameter.

### Example: simple data parallelism
When implementing data parallelism with replication, all replicas have the same weights. Therefore, the replica grouping of the variables is the default grouping, `ReplicaGrouping(stride=1, group_size=DP)`, where `DP` is the number of replicas.


### Example: data parallelism + tensor parallelism
We haven't covered tensor parallelism in a tutorial yet, but a full understanding of it is not needed for this example.

In short, tensor parallelism consists of splitting some weight tensors (those too big to fit in memory) into shards, computing with the sharded tensors, and then using collectives to reconstruct the full result.

For the purpose of this tutorial, just think about a linear layer that performs matrix multiplication.
$$
\mathbf{A} = \begin{pmatrix}
a & b \\
c & d
\end{pmatrix}, \quad
\vec{x} = \begin{pmatrix}
x_1 & x_2
\end{pmatrix}
$$
$$
\vec{y} = \vec{x} \cdot \mathbf{A} = \begin{pmatrix}
x_1 & x_2
\end{pmatrix} \begin{pmatrix}
a & b \\
c & d
\end{pmatrix} = \begin{pmatrix}
x_1\cdot a + x_2\cdot c & x_1\cdot b + x_2\cdot d
\end{pmatrix} = \begin{pmatrix}
y_1 & y_2
\end{pmatrix}
$$
The same result can be obtained by splitting (**sharding**) the matrix $\mathbf{A}$ *column-wise* and then **gathering** the results.
$$
y_1 = 
\begin{pmatrix}
x_1 & x_2 \\
\end{pmatrix} \begin{pmatrix}
a \\
c
\end{pmatrix} = x_1\cdot a + x_2\cdot c
$$
$$
y_2 =
\begin{pmatrix}
x_1 & x_2 \\
\end{pmatrix} \begin{pmatrix}
b \\
d
\end{pmatrix} = x_1\cdot b + x_2\cdot d
$$
$$
\text{gather} \to \begin{pmatrix}
y_1 & y_2 \\
\end{pmatrix} = \begin{pmatrix}
x_1\cdot a + x_2\cdot c & x_1\cdot b + x_2\cdot d \\
\end{pmatrix}
$$
In passing, notice that another possible split is a *row-wise* sharding of $\mathbf{A}$ combined with a matching split of the input $\vec{x}$ (each row of $\mathbf{A}$ is multiplied by the corresponding element of $\vec{x}$), followed by an **all reduce** operation at the end.
$$
 \begin{pmatrix}
x_1 \\
\end{pmatrix} \begin{pmatrix}
a & b \\
\end{pmatrix} =  \begin{pmatrix}
x_1 \cdot a & x_1 \cdot b\\
\end{pmatrix}
$$
$$
 \begin{pmatrix}
x_2 \\
\end{pmatrix} \begin{pmatrix}
c & d \\
\end{pmatrix} =  \begin{pmatrix}
x_2 \cdot c & x_2 \cdot d\\
\end{pmatrix}
$$
$$
\text{all reduce} \to \begin{pmatrix}
x_1\cdot a + x_2 \cdot c & x_1\cdot b + x_2\cdot d \\
\end{pmatrix}
$$
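The row-wise variant can be checked numerically on the host. In this sketch the matrix values are arbitrary, each "replica" is just a list entry, and the all reduce is emulated with a plain sum over the partial results.

```python
import numpy as np

x = np.array([1.0, 2.0], dtype=np.float32)    # (x_1, x_2)
A = np.array([[3.0, 4.0],                     # (a, b)
              [5.0, 6.0]], dtype=np.float32)  # (c, d)

# Row-wise sharding: "replica" i holds row i of A and the matching element of x
partials = [x[i:i + 1] @ A[i:i + 1, :] for i in range(2)]
print(partials)  # x_1 * (a, b) and x_2 * (c, d)

# All reduce emulated as a sum of the partial results
y_reduced = np.sum(partials, axis=0)
print(y_reduced)  # [13. 16.]
assert np.allclose(y_reduced, x @ A)
```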

We can implement tensor parallelism using replication, letting each replica hold a different shard of the weight matrix. For the column-wise sharding example above, we need two replicas (`TP` = 2): one replica holds the `(a, c)` slice and the other the `(b, d)` slice. The model spans 2 IPUs.

If we combine this with data parallelism over 4 replicas (`DP` = 4), we replicate this 2-IPU model 4 times, using 8 IPUs in total.

The ReplicaGrouping for the weight is `ReplicaGrouping(stride=TP, group_size=DP)=ReplicaGrouping(stride=2, group_size=4)`, with assignment `[0, 1, 0, 1, 0, 1, 0, 1]`. This means that IPUs 0,2,4,6 hold the same first shard, and IPUs 1,3,5,7 hold the same second shard.
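The shard placement can be computed directly. This sketch assumes, based on the assignments printed earlier, that with `stride=TP` and `group_size=DP` the group index of replica `r` is simply `r % TP`: replicas `TP` apart share a group, and one group of `DP` such replicas already covers all `TP * DP` replicas. It runs on the host and does not use popxl.

```python
import numpy as np

TP, DP = 2, 4
replicas = TP * DP

# Assumed group (= shard) index for stride=TP, group_size=DP: replica % TP
shard_of_replica = np.arange(replicas) % TP
print(shard_of_replica.tolist())  # [0, 1, 0, 1, 0, 1, 0, 1]

for shard in range(TP):
    holders = np.flatnonzero(shard_of_replica == shard).tolist()
    print(f"shard {shard} is held by replicas {holders}")
```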

<img src="images/variables_tp_dp.png" alt="Variables grouping for a combination of TP and DP" style="width:300px;"/>

## ReplicaGrouping in collectives

When using collectives, the replica grouping defines the underlying **communication group**: communication happens only between IPUs that are in the same group.

All `popxl` collectives have a `group` parameter that lets you specify the ReplicaGrouping used for the communication.
If no group is specified, the default group is used, meaning communication happens between all replicas.
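To make "communication only within a group" concrete, here is a host-side NumPy emulation of an all reduce (sum) restricted to the groups of an assignment. `grouped_all_reduce` is a hypothetical helper that illustrates the semantics only; it is not a popxl op and performs no communication.

```python
import numpy as np


def grouped_all_reduce(values, assignment):
    """Each replica receives the sum of the values held by the replicas in its group."""
    values = np.asarray(values, dtype=np.float32)
    assignment = np.asarray(assignment)
    out = np.empty_like(values)
    for g in np.unique(assignment):
        members = assignment == g
        out[members] = values[members].sum(axis=0)
    return out


per_replica = np.arange(8, dtype=np.float32)  # replica i holds the value i
print(grouped_all_reduce(per_replica, [0] * 8))                   # [28. 28. 28. 28. 28. 28. 28. 28.]
print(grouped_all_reduce(per_replica, [0, 1, 0, 1, 0, 1, 0, 1]))  # [12. 16. 12. 16. 12. 16. 12. 16.]
```

With the default grouping every replica sees the sum over all 8 replicas; with `stride=2, group_size=4` the even and odd replicas reduce separately.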

In `popxl-addons` there are extra collective-related utilities in `rts`, `remote` and the collectives custom ops. They all accept one or more replica groupings.

### The relation between the variable replica group and the communication group

Notice that in several situations the transpose of a replica grouping is used in collectives. For example, in the Tensor Parallelism + Data Parallelism example above, the replica grouping for the gather collective is the transpose of the variable replica grouping, because we want to gather *different* shards together.

<img src="images/collectives.png" alt="Replica grouping for TP collectives" style="width:500px;"/>
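The same host-side emulation idea shows why the transpose grouping is the right communication group for the tensor-parallel gather: each transpose group contains exactly one replica holding each shard, so concatenating within it rebuilds the full output. The sketch assumes the TP=2, DP=4 layout above, with the weight assignment `[0, 1, 0, 1, 0, 1, 0, 1]` and its transpose `[0, 0, 1, 1, 2, 2, 3, 3]` as printed earlier; `grouped_all_gather` and the shard values are illustrative, not popxl code.

```python
import numpy as np


def grouped_all_gather(values, assignment):
    """Each replica receives the concatenation (in replica order) of its group's values."""
    assignment = np.asarray(assignment)
    out = [None] * len(values)
    for g in np.unique(assignment):
        members = np.flatnonzero(assignment == g)
        gathered = np.concatenate([values[m] for m in members], axis=-1)
        for m in members:
            out[m] = gathered
    return out


weight_assignment = [0, 1, 0, 1, 0, 1, 0, 1]           # ReplicaGrouping(stride=2, group_size=4)
transpose_group_assignment = [0, 0, 1, 1, 2, 2, 3, 3]  # its transpose: stride=1, group_size=2

# Illustrative y shards: shard 0 gives [2, 2], shard 1 gives [4, 4]
y_shards = [np.full(2, 2.0 * (s + 1), dtype=np.float32) for s in weight_assignment]

for r, y in enumerate(grouped_all_gather(y_shards, transpose_group_assignment)):
    print(f"replica {r}: {y}")  # every replica: [2. 2. 4. 4.]

# Gathering in the variable grouping instead concatenates copies of the SAME shard
print(grouped_all_gather(y_shards, weight_assignment)[0])  # [2. 2. 2. 2. 2. 2. 2. 2.]
```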

However, this is not a general rule. For example, you will see in the `mnist.ipynb` notebook that replicated tensor sharding collectives use the variable group as communication group.

## Example: Tensor Parallelism + Data Parallelism Linear Layer
We now implement a small example of a Linear Layer executed using tensor parallelism (column-wise) and data parallelism, to illustrate the concepts in practice.
You will see:
- How replica groupings can be used in `addons` modules to create variables
- How to initialise each replica with different data.
- How replica groupings can be used in collectives.
- How the `transpose` group can be used in tensor parallel collectives.

We will use the handy `print_tensor` operation to print tensor values on all replicas.

In [5]:
def constant_weight(tp, in_features, shard_size):
    shards = [np.full((in_features,shard_size), i+1, dtype=np.float32) for i in range(tp)]
    return shards

In [6]:
class Linear(addons.Module):
    def __init__(self, out_features: int, tp: int, dp: int):
        super().__init__()
        self.shard_size = out_features//tp
        self.replica_grouping = popxl.gcg().ir.replica_grouping(stride=tp, group_size=dp)
        self.tp = tp
        
    def build(self, x: popxl.Tensor) -> popxl.Tensor:
        # pass an iterator so that each replica group is initialised with different data
        w = self.add_variable_input("weight",
                                    iter(constant_weight(self.tp,x.shape[-1],self.shard_size)),
                                    x.dtype,
                                    replica_grouping=self.replica_grouping)
        print("sharded weight shape: ", w.shape)
        ops.print_tensor(w, title='w_shards', digits=2)
        # sharded computation
        y = x @ w
        print("y shard shape: ", y.shape)
        ops.print_tensor(y, title='y_shards',digits=2)
        # gather shards in the transpose group
        y_full = ops.collectives.replicated_all_gather(y, group=self.replica_grouping.transpose(),axis=-1)
        print("y full shape: ", y_full.shape)

        ops.print_tensor(y_full, title='y_full', digits=2)

        return y_full


In [7]:
# We want to project a vector (2,) to a vector (4,) with a matmul with weight of shape (2,4)
# using tensor parallelism, hence splitting the weight in two shards of shape (2,2).
TP=2
DP=2
out_features = 4
in_features = 2
x_data = np.ones((in_features), np.float32)

In [9]:
# numpy version
shards = constant_weight(TP, in_features, out_features//TP)
full_tensor = np.concatenate(shards, axis=-1)
print("Full tensor shape: ", full_tensor.shape)
print("Full tensor: ")
print(full_tensor)
print("Shard (column wise) shape: ", shards[0].shape)
for s in shards:
    print(s)
print("Matmul full")
print(x_data @ full_tensor)
print("Matmul sharded")
y_shards = []
for i,s in enumerate(shards):
    y_shard = x_data @ s
    print(f"   y_shard_{i}: ", y_shard)
    y_shards.append(y_shard)
print("Gathered: ", np.concatenate(y_shards, axis=-1))

Full tensor shape:  (2, 4)
Full tensor: 
[[1. 1. 2. 2.]
 [1. 1. 2. 2.]]
Shard (column wise) shape:  (2, 2)
[[1. 1.]
 [1. 1.]]
[[2. 2.]
 [2. 2.]]
Matmul full
[2. 2. 4. 4.]
Matmul sharded
   y_shard_0:  [2. 2.]
   y_shard_1:  [4. 4.]
Gathered:  [2. 2. 4. 4.]


In [10]:
ir = popxl.Ir(replication=TP*DP)


with ir.main_graph:
    x = popxl.constant(np.ones((in_features)), popxl.float32)
    facts, linear = Linear(out_features, TP, DP).create_graph(x)
    print("replica groupings in factories: ")
    print(facts.replica_groupings)
    print("assignment for weight: ")
    print(facts.replica_groupings.weight.assignment)
    vars = facts.init()
    y, = linear.bind(vars).call(x)

with popxl.Session(ir,'ipu_hw') as session:
    session.run()

sharded weight shape:  (2, 2)
y shard shape:  (2,)
y full shape:  (4,)
replica groupings in factories: 
NamedReplicaGrouping({
weight: ReplicaGrouping(num_replicas=4, stride=2, group_size=2, num_groups=2),
})
assignment for weight: 
[0, 1, 0, 1]


w_shards_replica_0: [
 [1.0 1.0]
 [1.0 1.0]
]

w_shards_replica_1: [
 [2.0 2.0]
 [2.0 2.0]
]

w_shards_replica_2: [
 [1.0 1.0]
 [1.0 1.0]
]

w_shards_replica_3: [
 [2.0 2.0]
 [2.0 2.0]
]

y_shards_replica_0: [2.0 2.0]

y_shards_replica_1: [4.0 4.0]

y_shards_replica_2: [2.0 2.0]

y_shards_replica_3: [4.0 4.0]

y_full_replica_0: [2.0 2.0 4.0 4.0]

y_full_replica_1: [2.0 2.0 4.0 4.0]

y_full_replica_2: [2.0 2.0 4.0 4.0]

y_full_replica_3: [2.0 2.0 4.0 4.0]

