# Tensor Parallelism
In this session, we will learn about Tensor Parallelism.

# 1. Intra-layer Model Parallelism

Tensor Parallelism is a form of **intra-layer model parallelism**, where the model is split at the *tensor level within a layer*.  
Inter-layer model parallelism is generally intuitive, but intra-layer parallelism can be harder to understand at first.

![Intra-layer Parallelism](../images/intra_layer.png)


## Why Tensor Parallelism Works

Matrix multiplication has a key property:

The operands can be split into blocks, the block products computed independently, and the partial results concatenated or summed, without changing the final output.

This makes it possible to parallelize computation within a single layer.
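
To make the property concrete, here are the two block identities it refers to, written for a two-way split. The block names $A_1, A_2, X_1, X_2$ are notation chosen here for illustration: $A_1, A_2$ are either column blocks or row blocks of $A$, and $X_1, X_2$ are column blocks of $X$.

$$
X \begin{bmatrix} A_1 & A_2 \end{bmatrix}
= \begin{bmatrix} X A_1 & X A_2 \end{bmatrix}
\qquad \text{(split by columns, then concatenate)}
$$

$$
\begin{bmatrix} X_1 & X_2 \end{bmatrix}
\begin{bmatrix} A_1 \\ A_2 \end{bmatrix}
= X_1 A_1 + X_2 A_2
\qquad \text{(split by rows, then sum)}
$$

In the first case each block product yields a slice of the output. In the second, each block product has the full output shape and the blocks are added together.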


## What is Tensor Parallelism?

Tensor Parallelism takes advantage of this property by splitting tensors across GPUs and computing parts of the operation in parallel.


## Terminology Clarification

The terminology can be confusing:

- **Intra-layer Parallelism**  
  Any type of parallelism that happens *within a layer*.

- **Tensor Parallelism**  
  A specific implementation of intra-layer parallelism using tensor slicing and distributed computation.


### In Short

**Tensor Parallelism = splitting tensors inside a layer to parallelize computation across devices.**


**Megatron-LM** is an implementation of **intra-layer model parallelism** and is one of the most important in large-scale model training today.

<img src="../images/megatron_lm.jpeg" width=540>



## Column & Row Parallelism

Below are illustrations of **Column Parallelism** and **Row Parallelism** used in Megatron-LM.

- **Column Parallelism**  
  The weight matrix `A` is split **vertically** (by columns) into submatrices `(A₁, A₂)`.

- **Row Parallelism**  
  The weight matrix `A` is split **horizontally** (by rows) into submatrices `(A₁, A₂)`.

![Column & Row Parallelism](../images/intra_layer_2.png)


## Let’s Implement It with Code

Now let’s test the idea in code.

First, we compute the matrix multiplication result of tensor `X` and tensor `A`.


In [1]:
"""
src/non_parallelism.py
"""

import torch

X = torch.tensor(
    [
        [0, 1, 2, 3],
        [4, 5, 6, 7],
    ]
)

A = torch.tensor(
    [
        [10, 14],
        [11, 15],
        [12, 16],
        [13, 17],        
    ]
)

# Reference result: the full, unsplit matrix product.
Y = X @ A

print(Y)

tensor([[ 74,  98],
        [258, 346]])


### Column Parallelism

In Column Parallelism, the weight matrix `A` is split **vertically** (by columns), and the partial outputs are **concatenated** after computation.

As shown in the figure:
- Tensor `X` is **replicated** across devices
- Tensor `A` is split vertically into `(A₁, A₂)`
- Each device computes its own slice of the output, `Y₁ = X A₁` and `Y₂ = X A₂`, and the slices are concatenated column-wise: `Y = [Y₁, Y₂]`


In [2]:
"""
src/column_parallelism.py
"""

import torch

X = torch.tensor(
    [
        [0, 1, 2, 3],
        [4, 5, 6, 7],
    ]
)

A1 = torch.tensor(
    [
        [10],
        [11],
        [12],
        [13],        
    ]
)

A2 = torch.tensor(
    [
        [14],
        [15],
        [16],
        [17],        
    ]
)

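# Each "device" multiplies the replicated input X by its own column shard of A.
Y1 = X @ A1
Y2 = X @ A2

print(Y1)
print(Y2)

# Concatenate the output slices along the column dimension to rebuild Y.
Y = torch.cat([Y1, Y2], dim=1)
print(Y)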

tensor([[ 74],
        [258]])
tensor([[ 98],
        [346]])
tensor([[ 74,  98],
        [258, 346]])


We can confirm that the result **before and after parallelization is identical**.
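
The hand-written shards above can also be produced programmatically. The following is a minimal sketch, not Megatron-LM's actual implementation: `column_parallel_matmul` is a hypothetical helper name, the "devices" are simulated on a single process, and the number of columns in `A` is assumed to be divisible by `num_partitions`.

```python
import torch


def column_parallel_matmul(X, A, num_partitions):
    """Simulate column parallelism on one device (hypothetical helper)."""
    # Split A by columns; each shard would live on its own device.
    A_shards = torch.chunk(A, num_partitions, dim=1)
    # X is replicated, so every "device" multiplies the full X by its shard.
    Y_shards = [X @ A_i for A_i in A_shards]
    # Concatenating along the column dimension rebuilds the full output.
    return torch.cat(Y_shards, dim=1)


# Same X and A as in the cells above.
X = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]])
A = torch.tensor([[10, 14], [11, 15], [12, 16], [13, 17]])

Y = column_parallel_matmul(X, A, num_partitions=2)
print(torch.equal(Y, X @ A))  # True
```

Using `torch.chunk` keeps the split generic: changing `num_partitions` changes how many column shards are created without touching the rest of the code.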

Now let’s move on to **Row Parallelism**.


### Row Parallelism

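In Row Parallelism, the weight matrix `A` is split **horizontally** (by rows), and the partial results are **summed across devices**.

As shown in the figure:
- Both `X` and `A` are split: `X` column-wise into `(X₁, X₂)` and `A` row-wise into `(A₁, A₂)`, so that the inner dimensions still match
- Each device computes a full-size partial product, `Y₁ = X₁ A₁` and `Y₂ = X₂ A₂`
- The partial products are summed element-wise: `Y = Y₁ + Y₂`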


In [3]:
"""
src/row_parallelism.py
"""

import torch

X1 = torch.tensor(
    [
        [0, 1],
        [4, 5],
    ]
)

X2 = torch.tensor(
    [
        [2, 3],
        [6, 7],
    ]
)

A1 = torch.tensor(
    [
        [10, 14],
        [11, 15],      
    ]
)

A2 = torch.tensor(
    [
        [12, 16],
        [13, 17],        
    ]
)

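# Each "device" multiplies its column shard of X by the matching row shard of A.
Y1 = X1 @ A1
Y2 = X2 @ A2

print(Y1)
print(Y2)

# Every partial product already has the full output shape, so they are summed.
Y = Y1 + Y2

print(Y)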

tensor([[ 11,  15],
        [ 95, 131]])
tensor([[ 63,  83],
        [163, 215]])
tensor([[ 74,  98],
        [258, 346]])
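
Again the result matches the unsplit product. As a minimal sketch of the same idea for an arbitrary number of shards, the hypothetical helper `row_parallel_matmul` below simulates the devices on a single process. It assumes the shared inner dimension is divisible by `num_partitions`. In a real multi-GPU run the final sum would typically be carried out by an all-reduce collective, and a plain sum stands in for it here.

```python
import torch


def row_parallel_matmul(X, A, num_partitions):
    """Simulate row parallelism on one device (hypothetical helper)."""
    # Split A by rows and X by columns so the inner dimensions still match.
    A_shards = torch.chunk(A, num_partitions, dim=0)
    X_shards = torch.chunk(X, num_partitions, dim=1)
    # Each "device" produces a partial product with the full output shape.
    partials = [X_i @ A_i for X_i, A_i in zip(X_shards, A_shards)]
    # Summing the partials stands in for the cross-device reduction.
    return torch.stack(partials).sum(dim=0)


# Same full X and A as in the non-parallel example.
X = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]])
A = torch.tensor([[10, 14], [11, 15], [12, 16], [13, 17]])

Y = row_parallel_matmul(X, A, num_partitions=2)
print(torch.equal(Y, X @ A))  # True
```

Unlike column parallelism, no device holds a finished slice of `Y` before the reduction: each partial product is only one term of the sum.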
