## PyTorch for Scientists!
by **nikitaved@github**


A rough sketch of things covered:
1. Memory <br>
a. Topology <br>
b. Performance implications <br>
c. ~Broadcasting~
3. ~Reproducibility related to determinism and nondeterminism~
4. ~Basics of high-level optimizations on Computational Graphs~
5. ~Other things?~

**IMPORTANT NOTE**: to, hopefully, enjoy the content even more, I recommend trying to forget anything you know and just follow the definitions. For example, if you know about contiguous tensors and their stride structure implications, forget about it before these things are introduced in the text.

**Python/PyTorch** versions used

In [1]:
import sys
import torch
torch.manual_seed(13)
print(f"Python version: {'.'.join(map(str, sys.version_info[:2]))}")
print(f"PyTorch version: {torch.__version__}")

Python version: 3.13
PyTorch version: 2.8.0.dev20250521+cu128


## PyTorch in a nutshell (not exhaustive)

1. Tensors

2. Operations over Tensors

3. Computational Graphs and operations with them (this includes forward/backward)

## Tensors

For simplicity, we are talking about the standard and very familiar, so-called, **strided** tensors (as opposed to sparse, nested and other types).

A **strided** tensor is an N-dimensional array with the following basic methods/attributes (the list is not exhaustive)

* Allocated data or memory storage (accessible via `untyped_storage()`; the older `storage()` accessor still works but is deprecated in PyTorch 2.x)

* Underlying local (i.e. device-specific) memory topology meta-data encoded in `shape` and `stride()`

* Other meta-data such as `device`, `dtype` and many more...
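
As a quick illustration of the three bullet points above, here is a minimal sketch that prints the storage size and the meta-data of a small tensor. It assumes PyTorch 2.x, where `untyped_storage()` is available; the variable names are arbitrary.

```python
import torch

t = torch.zeros(2, 3, 4, dtype=torch.float32)

# Storage: the flat buffer of bytes that backs the tensor
storage_nbytes = t.untyped_storage().nbytes()

# Memory topology meta-data
shape, strides = tuple(t.shape), t.stride()

print(f"{storage_nbytes=}, {shape=}, {strides=}")
print(f"{t.device=}, {t.dtype=}")
```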

## Memory topology of a strided tensor

Memory is linear. **How do we map an N-dimensional index to a specific memory address for individual element accesses** then?

For a tensor `t`, `t.data_ptr()` returns an address in memory at which the data of `t` begins. <br>
When we access an element $\text{t}[i_0, \dots, i_{N - 1}]$, 
we read `t.dtype.itemsize` bytes at address<br> $\big(\text{t.data\_ptr}() + \text{t.dtype.itemsize} \ast \sum_{j=0}^{N-1} i_j \ast \text{t.stride}()[j]\big)$<br>
**NOTE**: the strides are assumed to be non-negative!

Examples:

In [2]:
def memaddr_from_idx(t: torch.Tensor, idx: tuple) -> int:
    assert len(idx) == t.ndim and isinstance(idx, tuple)
    assert all(0 <= i < t.shape[d] for d, i in enumerate(idx))
    mem_offset = 0
    for d, i in enumerate(idx):
        mem_offset += i * t.stride()[d]
    mem_offset *= t.dtype.itemsize
    return t.data_ptr() + mem_offset

def value_from_memaddr(memaddr: int, c_type_name: str):
    import ctypes
    assert hasattr(ctypes, f"c_{c_type_name}")
    value = getattr(ctypes, f"c_{c_type_name}").from_address(memaddr).value
    return value

def idx_to_value(t: torch.Tensor, idx: tuple, c_type_name: str):
    memaddr = memaddr_from_idx(t, idx)
    value = value_from_memaddr(memaddr, c_type_name)
    return value
    
t = torch.rand(24, dtype=torch.float32).reshape(2, 3, 4)

import itertools
for i, idx in enumerate(itertools.product(*(range(d) for d in t.shape))):  # Generate all possible indices into t
    assert t[*idx].item() == idx_to_value(t, idx, "float")  # NOTE: this is a bitwise equality!
assert i == t.numel() - 1  # Check we looped through all the indices

Let `t` be a Tensor, then

$
\text{memset(t)} = \{ \text{t.data\_ptr}() + k + \text{t.dtype.itemsize} \ast \sum_{j=0}^{N-1} i_j \ast \text{t.stride}()[j] : 0 \le k < \text{t.dtype.itemsize}, 0 \le i_j < \text{t.shape}[j]\},\\
$

is its **memory set** or, simply, the set of addresses (in bytes) spanning the data. Do **not overlook** dependence upon **shape** and **stride**s. It is also helpful to think of **stride**s as a "generating set".

For two tensors `t1` and `t2`, `t2` is a **view** of `t1`, if $\text{memset(t2)} \subseteq \text{memset(t1)}$.
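
The definition can be checked by brute force on small tensors. The sketch below enumerates the memory set exactly as in the formula above and tests the subset relation; `memset` is a hypothetical helper written for this notebook, not a PyTorch function, and it is only practical for tiny tensors.

```python
import itertools
import torch

def memset(t: torch.Tensor) -> set:
    """Set of byte addresses spanned by the elements of `t` (brute force)."""
    itemsize = t.element_size()
    addrs = set()
    for idx in itertools.product(*(range(n) for n in t.shape)):
        start = t.data_ptr() + itemsize * sum(i * s for i, s in zip(idx, t.stride()))
        addrs.update(range(start, start + itemsize))
    return addrs

base = torch.arange(12, dtype=torch.float32).reshape(3, 4)
sliced = base[:, ::2]   # every other column
copied = base.clone()   # freshly allocated memory

assert memset(sliced) <= memset(base)            # a view by the definition above
assert memset(copied).isdisjoint(memset(base))   # not a view
```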

**View**s are very important and PyTorch will try to create them whenever possible (unless copy is needed), because:
* no **memory allocation** is needed,
* no **kernel launch** is needed to populate that data with the exception of **in-place** operations.<br>
  **Caveat**: only at the time of tensor creation, not necessarily tensor materialization!

A **view** is trivial, if it *can* be implemented without any **kernel launch**es (**NOTE**: unless copy is a specified semantics!),<br>
i.e. the result is just a manipulation of meta-data like `data_ptr`, `strides`, `shape` and others which **does not read the underlying data memory**; pointer arithmetic with `data_ptr` is fine.

**Question**: when we run `t2 = op(t1)` for some PyTorch operation, how certain are we to get a **view** `t2`?

**Answer**:
* when `op` is **in-place**, we are certain that `t2` will be exactly `t1` (albeit not necessarily trivial),
* otherwise when `op` *can* be implemented by finding appropriate `data_ptr`, `shape` and `strides` (and maybe even `dtype`) for `t2` such that $\text{memset(t2)} \subseteq \text{memset(t1)}$, i.e. when `op` is a trivial **view**,
* even if `op` *can* be implemented as a trivial **view**, it does not mean it is implemented as such.<br>
  One should **always check documentation** and **create issues if there are inconsistencies**!
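
When the documentation is unclear, a quick empirical check can help. The sketch below compares the storage pointers of input and output; `shares_storage` is a hypothetical helper, and sharing a storage is only a necessary condition for the memset inclusion above, so treat it as a first filter rather than a proof.

```python
import torch

def shares_storage(a: torch.Tensor, b: torch.Tensor) -> bool:
    """Hypothetical helper: True if both tensors are backed by the same storage."""
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()

t1 = torch.arange(6.).reshape(2, 3)

transposed = t1.t()                 # expected: a view
materialized = t1.t().contiguous()  # expected: a copy, the transpose is not contiguous
added = t1.add(1)                   # out-of-place arithmetic allocates a new tensor
added_inplace = t1.add_(1)          # in-place: the result is t1 itself

for name, t2 in [("t()", transposed), ("t().contiguous()", materialized),
                 ("add(1)", added), ("add_(1)", added_inplace)]:
    print(f"{name}: shares storage with t1 = {shares_storage(t1, t2)}")
```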
  

### View example: accessing a single element

Accessing a single element in an N-dim tensor, i.e. $\bf{\textbf{t}[i_1, \dots, i_{N}]}$, which is a 0-dim tensor, is a trivial **view**. **Why?**

Answer:
* set `data_ptr` to the address of the element index points to,
* set `strides` and `shape` to `()`, i.e. empty.

In [3]:
def mem_offset(idx: tuple, strides: tuple):
    assert isinstance(idx, tuple) and isinstance(strides, tuple)
    assert len(idx) == len(strides)
    mem_offset = 0
    for i, s in zip(idx, strides):
        mem_offset += i * s
    return mem_offset

In [4]:
x = torch.arange(24, device="cuda").reshape(2, 3, 4)

idx = (0, 2, 2)
y = x[*idx]
print(f"{y=}, {y.shape=}. {y.stride()=}")

y=tensor(10, device='cuda:0'), y.shape=torch.Size([]). y.stride()=()


In [5]:
y_data_ptr_estimate = x.data_ptr() + x.dtype.itemsize * mem_offset(idx, x.stride())
print(y_data_ptr_estimate == y.data_ptr())

True


**Critical takeaway**: modifying that single element in-place will alter the initial data!
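
A minimal sketch of this takeaway, run on the CPU so that it does not assume a GPU is present: writing through the 0-dim view changes `x`, while writing to a `clone()` does not.

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)

y = x[0, 2, 2]   # 0-dim view of a single element
y.fill_(-1)      # in-place write through the view
assert x[0, 2, 2].item() == -1   # the original tensor changed

z = x[0, 2, 2].clone()   # explicit copy, detached from the memory of x
z.fill_(100)
assert x[0, 2, 2].item() == -1   # the original tensor is untouched
```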

**Isn't this behavior a bit weird?** Maybe, depends on the underlying semantics and how consistent it is with other indexing patterns. I can only speculate, but it seems like in this case the developers might have found it hard to justify launching a CUDA kernel for copying just a single element. A general rule of thumb: if certain things appear to be weird and inconsistent with, for example, NumPy, it could be a tradeoff decision favouring GPU performance.

### View example: slicing along a dimension
Slicing along a dimension $\bf{d}$, i.e. $\bf{\textbf{res} = \textbf{t}[\dots, \textbf{start}_d:\textbf{end}_d:\textbf{step}_d, \dots]}$, is a trivial **view**. **Why?**

`res.data_ptr` of the result should be set to

* `res.data_ptr = t.data_ptr + t.dtype.itemsize * (start * t.strides[d])`

`res.shape` should be set to

* `res.shape = t.shape[::]; res.shape[d] = len(range(start, end, step))`

`res.strides` should be set to

* `res.strides = t.strides[::]; res.strides[d] = step * t.strides[d]`

In [6]:
x = torch.rand(20, 30, 40).cuda()
slice_test = slice(2, 15, 3)  # start=2, stop=15, step=3
slice_length = len(range(slice_test.start, slice_test.stop, slice_test.step))  # a slice object has no length of its own
for d in range(x.ndim):  # testing our hypothesis for data_ptr, shape and strides for all dims
    idx = [slice(0, x.shape[i]) for i in range(x.ndim)]
    idx[d] = slice_test
    y = x[*idx]
    assert y.data_ptr() == x.data_ptr() + x.dtype.itemsize * (slice_test.start * x.stride()[d])
    assert all(
        s1 == s2 if dim != d else s1 == slice_length
        for dim, (s1, s2) in enumerate(zip(y.shape, x.shape))
    )
    assert all(
        s1 == s2 if dim != d else s1 == slice_test.step * x.stride()[d]
        for dim, (s1, s2) in enumerate(zip(y.stride(), x.stride()))
    )
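
The three meta-data rules for slicing can also be used constructively. The sketch below rebuilds a slice by hand with `torch.as_strided`, which takes a size, strides and a storage offset measured in elements; the concrete dimension and slice values are arbitrary choices for illustration.

```python
import torch

x = torch.rand(20, 30, 40)
d, start, stop, step = 1, 2, 15, 3

# Apply the rules for shape, strides and the starting offset
new_shape = list(x.shape)
new_shape[d] = len(range(start, stop, step))
new_strides = list(x.stride())
new_strides[d] = step * x.stride()[d]
new_offset = x.storage_offset() + start * x.stride()[d]  # in elements, not bytes

manual = torch.as_strided(x, new_shape, new_strides, new_offset)
reference = x[:, start:stop:step, :]

assert torch.equal(manual, reference)
assert manual.data_ptr() == reference.data_ptr()
```

No data is read or copied to build `manual`; only meta-data is computed.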

#### Slicing along a dim: a peculiar observation

Take a look at

In [7]:
x = torch.rand(20, 20, 20)  # NOTE: uniform from [0, 1)
slice_dim = slice(0, 5)
y = x[slice_dim, :, :]
x[0, 0, 0] = 10.
print(f"{x[0, 0, 0]=}, {y[0, 0, 0]=}")

x[0, 0, 0]=tensor(10.), y[0, 0, 0]=tensor(10.)


It just agrees with our trivial **view** assumptions tested above.

Now take a look at:

In [8]:
x = torch.rand(20, 20, 20)  # NOTE: uniform from [0, 1)
range_dim = range(0, 5)
slice_dim = slice(0, 5)
y_range = x[range_dim, :, :]
y_slice = x[slice_dim, :, :]
assert torch.all(y_range == y_slice)  # the values are the same for range and slice
x[0, 0, 0] = 10.  # NOTE: this value is definitely out of [0, 1) range
print(f"{x[0, 0, 0]=}, {y_range[0, 0, 0]=}")  # NOTE: y_range is not a view of x. Modifying x did not change y_range!

x[0, 0, 0]=tensor(10.), y_range[0, 0, 0]=tensor(0.3641)


A switch from **slice** to **range** breaks semantic consistency.
At least one of these variants is incorrect.<br>
**Finding** such issues and **reporting** them on GitHub helps the community!
If the issue has already been reported, it never hurts to chime in anyway - that can speed up resolution!
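
For comparison, indexing with an explicit list of integers also returns a copy. One plausible reading, offered here as an assumption and not as a documented fact, is that a `range` is handled like such a sequence of integer indices. The sketch below only demonstrates the list case.

```python
import torch

x = torch.rand(20, 20, 20)

y_slice = x[0:5]              # basic indexing with a slice
y_list = x[[0, 1, 2, 3, 4]]   # indexing with a sequence of integers

same_values = torch.equal(y_slice, y_list)
slice_is_view = y_slice.untyped_storage().data_ptr() == x.untyped_storage().data_ptr()
list_is_view = y_list.untyped_storage().data_ptr() == x.untyped_storage().data_ptr()

print(f"{same_values=}, {slice_is_view=}, {list_is_view=}")
```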

### View example: permutation of dimensions
Suppose `t` is a N-dimensional tensor and $\pi$ is a permutation of N elements.<br>
Consider a tensor `res` defined as $\bf{(\textbf{res})_{i_{\pi(1)}, \dots, i_{\pi(N)}} = (\textbf{t})_{i_1, \dots, i_N}}$ (this is what `t.permute(perm)` computes when `perm` encodes $\pi$), then `res` is a trivial **view** of `t`. **Why?**

`res.data_ptr` = ?

* = `t.data_ptr`

`res.shape` = ?

* = `[t.shape[perm[i]] for i in range(N)]`

`res.strides` = ?

* = `[t.strides[perm[i]] for i in range(N)]`

In [9]:
def apply_perm(perm: tuple, vals: tuple) -> tuple:
    assert len(perm) == len(vals)
    perm_range = frozenset(perm)
    assert len(perm_range) == len(perm) and min(perm_range) == 0 and max(perm_range) == len(perm) - 1
    return tuple(vals[perm[i]] for i in range(len(perm)))
                 
x = torch.rand(2, 3, 4)
for i, perm in enumerate(itertools.permutations(range(x.ndim))):
    y = x.permute(perm)
    assert y.data_ptr() == x.data_ptr()
    assert y.shape == apply_perm(perm, x.shape)
    assert y.stride() == apply_perm(perm, x.stride())
assert i == (3 * 2 * 1 - 1)  # check we looped through all possible dim permutations

### View example: diagonal of a (square) matrix
Suppose `t` is a tensor of shape `(n, n)`, then the 1-dimensional `res`, defined as $\bf{(\textbf{res})_i = (\textbf{t})_{ii}}$, is a trivial **view**. **Why?**

`res.data_ptr` = ?

* = `t.data_ptr`

`res.shape` is trivial. What about `res.stride`?

* When moving from $\bf{t_{ii}}$ to $\bf{t_{i+1, i+1}}$, we advance once along the columns with `t.stride()[-1]`, and once along the rows with `t.stride()[-2]`. That gives us the total offset of `t.stride()[-1] + t.stride()[-2]`. Since this offset is independent of $\bf{i}$, this must be the stride value we are looking for!

In [10]:
x = torch.rand(5, 5, 5)
for perm in itertools.permutations(range(x.ndim)):  # All permutations to test all kinds of interesting stride values
    y = x.permute(perm)[..., 0]  # Permute and then index to get a 2d tensor
    y_diag = y.diagonal(dim1=-2, dim2=-1)
    assert y_diag.ndim == 1
    assert y_diag.data_ptr() == y.data_ptr()
    assert y_diag.shape[0] == y.shape[-1]
    assert y_diag.stride() == (y.stride()[-1] + y.stride()[-2],)
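
Because the diagonal is a view, it can be used to modify a matrix in place without touching the off-diagonal entries. A small sketch with arbitrary values:

```python
import torch

m = torch.arange(1., 10.).reshape(3, 3)   # entries 1..9, row by row
diag = m.diagonal()                       # 1-dim view with stride 3 + 1
diag.zero_()                              # in-place write through the view

print(m)
```

The sum of the remaining entries is the original total minus the old diagonal, which is an easy way to confirm that nothing else was overwritten.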

### Contiguous tensors
There is a special class of **strided** tensors in PyTorch, they are *canonical*, in a way. These are so-called **contiguous** tensors.<br>
**NOTE**: most tensor factory methods, like `ones`, `zeros`, `empty`, create **contiguous** tensors by default.

A tensor `t` is **contiguous**, if it is empty (i.e. there are dims of size 0) or if it is 0-dimensional (a scalar) or if it is N-dimensional with N $\ge$ 1 and

$
\text{memset(t)} = \text{t.data\_ptr()} + \{0, 1, 2, \dots, \text{t.dtype.itemsize} * \text{t.numel()} - 1 \},
$

so $\text{memset(t)}$ does not have "holes" (and we also call it **contiguous**), and, additionally, `t` has the following stride structure:

$
\begin{cases}
\text{t.strides}[-1] &= 1, \\
\text{t.strides}[-i - 1] &= \text{t.shape[-i]} * \text{t.strides[-i]}, \text{ for } i \in \{1, \dots, \text{t.ndim} - 1\}.
\end{cases}
$

**NOTE**: PyTorch only supports non-negative strides! This is why, for example, `torch.flip` returns a copy and not a view.
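
The recurrence above translates directly into a few lines of plain Python. `contiguous_strides` is a hypothetical helper name; it assumes a non-empty shape with N $\ge$ 1 and computes strides in elements, from the last dimension to the first.

```python
def contiguous_strides(shape: tuple) -> tuple:
    """Strides (in elements) of a contiguous tensor with the given shape."""
    strides = [1] * len(shape)
    # Walk from the second-to-last dimension towards the first one
    for d in range(len(shape) - 2, -1, -1):
        strides[d] = shape[d + 1] * strides[d + 1]
    return tuple(strides)

print(contiguous_strides((2, 3, 4, 5)))  # (60, 20, 5, 1)
```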

In [11]:
t = torch.rand(2, 3, 4, 5)
print(t.stride())
assert t.is_contiguous()

def check_contiguous(t):
    if t.numel() == 0:  # Empty is contiguous
        return True
    if t.ndim == 0:  # Scalar is contiguous
        return True
    isc = (t.stride(-1) == 1)
    for i in range(1, t.ndim):
        isc = (t.stride(-i - 1) == t.shape[-i] * t.stride(-i)) and isc
    return isc

assert check_contiguous(t) == True
assert check_contiguous(t.transpose(-1, -2)) == False
assert check_contiguous(t.transpose(0, -1)) == False

# check empty tensor
t = torch.rand(0, 1, 0)
assert t.is_contiguous() == check_contiguous(t)

# check scalar tensor
t = torch.rand(())
assert t.ndim == 0 and t.is_contiguous() == check_contiguous(t)

(60, 20, 5, 1)


**NOTE**: If `t` is non-empty and not a scalar, so it is of dimension N $\ge$ 1, then `t.strides` is non-increasing as per the definition above.<br>
This implies that if $i$ and $j$ are two N-dimensional indices such that $i < j$ lexicographically, then $\text{t[*i]}.\text{data\_ptr()} < \text{t[*j]}.\text{data\_ptr()}$.
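
This ordering property is easy to test on a small tensor. In the sketch below, `itertools.product` yields the indices in lexicographic order, so the list of element addresses should be strictly increasing for a contiguous tensor and, in general, not for a transposed one.

```python
import itertools
import torch

def addresses_in_lex_order(t: torch.Tensor) -> list:
    """Element addresses, visited in lexicographic index order."""
    return [t[idx].data_ptr() for idx in itertools.product(*(range(n) for n in t.shape))]

def strictly_increasing(values: list) -> bool:
    return all(a < b for a, b in zip(values, values[1:]))

t = torch.rand(2, 3, 4)
assert strictly_increasing(addresses_in_lex_order(t))

t_transposed = t.transpose(0, 2)   # not contiguous
assert not strictly_increasing(addresses_in_lex_order(t_transposed))
```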

#### An immediately useful property of contiguous tensors
Any **reshape** (or any other shape-based operation that preserves `numel`) of a **contiguous** tensor is **contiguous** and is a **view**.

In [12]:
x = torch.rand(24)
y = x.reshape(12, 2)
z = x.reshape(2, 3, 4)
assert x.is_contiguous() == y.is_contiguous() == z.is_contiguous()
x[0] = 10.
assert 10. == x[0] == y[0, 0] == z[0, 0, 0]

**NOTE**: there are cases when **reshape** returns a view of a **non-contiguous** tensor. I recommend checking `torch.Tensor.view` in the documentation.
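
A sketch of both situations for non-contiguous inputs, with `shares_storage` as a hypothetical helper: one `reshape` that can be expressed through strides and stays a view, one that has to copy, and `view`, which raises an error in place of copying.

```python
import torch

def shares_storage(a: torch.Tensor, b: torch.Tensor) -> bool:
    """Hypothetical helper: True if both tensors are backed by the same storage."""
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()

base = torch.rand(4, 6)

# Non-contiguous, yet reshape returns a view: dim 0 splits cleanly into (2, 2)
cols = base[:, :3]
assert not cols.is_contiguous()
r_view = cols.reshape(2, 2, 3)

# Non-contiguous and reshape has to copy: flattening a transpose
r_copy = base.t().reshape(24)

# `view` never copies; it raises when no valid strides exist
try:
    base.t().view(24)
    view_failed = False
except RuntimeError:
    view_failed = True

print(shares_storage(base, r_view), shares_storage(base, r_copy), view_failed)
```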

### View Summary
* **view** operations, unless **in-place**, require $\Theta(1)$ memory and take $\Theta(1)$ time.<br>
  **Caveat**: this means tensor creation, not necessarily tensor materialization!
* one can foresee memory management techniques with allocating sufficiently large buffers and then creating **view**s from them.
* one has to be careful with modifying tensors **in-place** - it may change the data of all other tensors they share storage with.
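
A minimal sketch of the buffer idea from the summary: one allocation, two tensors carved out of it as views. The sizes are arbitrary, and the example also shows the flip side, namely that in-place writes to the views are visible in the buffer.

```python
import torch

buffer = torch.zeros(10)      # single allocation

a = buffer[:4].view(2, 2)     # first 4 elements as a 2x2 matrix
b = buffer[4:].view(2, 3)     # remaining 6 elements as a 2x3 matrix

a.fill_(1.)
b.fill_(2.)

print(buffer)  # the writes through `a` and `b` landed in the shared buffer
```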

## Memory and performance implications

Now that we are a bit familiar with the **stride**d arrays, let's see how memory and stride structure impact performance.

### Dimension and index hierarchy
It is very helpful to understand how PyTorch traverses memory internally.<br>
For example, if `t` is a N-dimensional tensor, how does PyTorch traverse `t` to, for example, compute something like `t.add_(3)`?

Suppose `t` has strides $(s_1, \dots, s_{N})$ and $\pi$ is a dimension-sorting permutation such that

$
\begin{align}
s_{\pi^{-1}(1)} \ge \dots \ge s_{\pi^{-1}(N)}, \text{ with } [s_{\pi^{-1}(i)} = s_{\pi^{-1}(j)}] \implies [\pi^{-1}(i) < \pi^{-1}(j)],
\end{align}
$

which reads as *sort dimensions by stride value in decreasing order, and within groups of equal strides sort by the dimension index in ascending order*.<br>

For two N-dim indices $i=(i_1, \dots, i_N)$ and $j=(j_1, \dots, j_N)$, let $k$ be the smallest integer such that $i_{\pi^{-1}(k)} \neq j_{\pi^{-1}(k)}$, then <br>
$
  \begin{equation}
    \begin{cases}
      i < j, & \text{if}\ i_{\pi^{-1}(k)} < j_{\pi^{-1}(k)}, \\
      i > j, & \text{if}\ i_{\pi^{-1}(k)} > j_{\pi^{-1}(k)}, \\
      i = j, & \text{if no such}\ k \ \text{exists}.
    \end{cases}
  \end{equation}
$

This index total order defines how PyTorch traverses memory.
**NOTE**: if `t` is **contiguous**, then $\pi = \text{id}$.

**In simple words**, PyTorch traverses dimensions from smallest to largest stride. This order enforces a very important property of<br>
**Memory Locality**: $i < j \implies \text{t[*i]}.\text{data\_ptr()} \le \text{t[*j]}.\text{data\_ptr()}$ (**NOTE**: $\le$, it becomes $<$ for **contiguous** tensors, $\le$ is relevant in the context of *broadcasting*).<br>
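
The dimension-sorting rule can be sketched in a couple of lines. `dims_by_decreasing_stride` is a hypothetical helper that mirrors the definition above (decreasing stride, ties broken by ascending dimension index); it illustrates the ordering and is not PyTorch's internal implementation.

```python
import torch

def dims_by_decreasing_stride(t: torch.Tensor) -> tuple:
    """Dimensions sorted by stride (descending), ties broken by dimension index."""
    return tuple(sorted(range(t.ndim), key=lambda d: (-t.stride(d), d)))

t = torch.rand(2, 3, 4)
t_perm = t.permute(2, 1, 0)              # strides (1, 4, 12)
t_bcast = torch.rand(3, 1).expand(3, 4)  # strides (1, 0), a broadcast dimension

print(dims_by_decreasing_stride(t))        # identity for a contiguous tensor
print(dims_by_decreasing_stride(t_perm))
print(dims_by_decreasing_stride(t_bcast))

# Permuting by this order recovers a contiguous layout for a permuted contiguous tensor
assert t_perm.permute(dims_by_decreasing_stride(t_perm)).is_contiguous()
```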

### Why does PyTorch do that?

**Memory Locality**, i.e. $i < j \implies \text{t[*i]}.\text{data\_ptr()} \le \text{t[*j]}.\text{data\_ptr()}$,<br> is **incredibly** important, because

* CPU - accessing addresses close to each other reduces chances of cache misses.
* GPU - increased chance of DRAM bursts when threads in a warp initiate coalesced memory access (contiguous memory chunks).<br>
  Hitting the same global memory address over and over again is fine because all global memory accesses are cached.<br>
  I recommend reading and revisiting [CUDA C Best Practices Guide](https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/).<br>
  The memory access part is [this](https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/#coalesced-access-to-global-memory) one. **Important**: pay attention to memory alignment!

### Performance implications: Unary Element-wise Operations
**Semantics**: $\bf{\textbf{res}_{i_1, \dots, i_N} = \textbf{op}(\text{t}_{i_1, \dots, i_N})}$.

**Question**: taking into account the semantics of unary operations and **Memory Locality**, which memory layout is likely (i.e. it is sufficient, but not necessary) to be the best for performance?

**Answer**: **contiguous** $\text{memset(t)}$. This is guaranteed when, for example, the strides `t.stride()` sorted in descending order are those of a **contiguous** tensor.<br>
**NOTE**: it could be helpful to see **contiguous** tensors as the tensors with $\text{memset(t)}$ of the minimal *diameter*, with a bijective relationship between the indices $i_1, \dots, i_N$ and the addresses of the elements $\text{t}[i_1, \dots, i_N]$, and with the additional structure of the indices being *lexicographically* sorted. From that it is also easy to derive the stride structure which we used above in the definition:

$
\begin{cases}
\text{t.strides}[-1] &= 1, \\
\text{t.strides}[-i - 1] &= \text{t.shape[-i]} * \text{t.strides[-i]}, \text{ for } i \in \{1, \dots, \text{t.ndim} - 1\}.
\end{cases}
$

In [13]:
t_contig = torch.rand(512, 512, 512)
t_ncontig = t_contig.permute(2, 1, 0)
assert t_ncontig.is_contiguous() == False

%timeit t_contig.add(3)
%timeit t_ncontig.add(3)

t_ncontig = torch.rand(512, 512, 512, 5)[..., 0]
assert t_ncontig.shape == (512, 512, 512) and t_ncontig.stride()[-1] == 5
%timeit t_ncontig.add(3)

del t_contig, t_ncontig

49.1 ms ± 54.3 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
50.4 ms ± 703 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
77.4 ms ± 1.87 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [14]:
def warmup(f, n_runs: int = 5):
    """
    Do warmup runs of `f` assuming it does compute on the GPU.
    Run it before profiling just to make sure that the CUDA context
    is properly initialized.
    """
    assert n_runs >= 0
    for i in range(n_runs):
        f()
    torch.cuda.synchronize()

t_contig = torch.rand(512, 512, 512, device="cuda")
t_ncontig = t_contig.permute(2, 1, 0)
assert t_ncontig.is_contiguous() == False

warmup(lambda: t_contig.add(3))
%timeit t_contig.add(3); torch.cuda.synchronize()
del t_contig

warmup(lambda: t_ncontig.add(3))
%timeit t_ncontig.add(3); torch.cuda.synchronize()

t_ncontig = torch.rand(512, 512, 512, 5).cuda()[..., 0]
assert t_ncontig.shape == (512, 512, 512) and t_ncontig.stride()[-1] == 5
warmup(lambda: t_ncontig.add(3))
%timeit t_ncontig.add(3); torch.cuda.synchronize()

# An interesting example where having "holes" in memset does not reduce performance by much!
# This implies that contiguous memset is not necessary for performance!
# Why is that? The answer is in https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/#coalesced-access-to-global-memory.
t_ncontig = torch.rand(512, 512, 32, 512).cuda()[..., 0, :]
assert t_ncontig.shape == (512, 512, 512) and t_ncontig.stride()[-2] == 32 * 512 and not t_ncontig.is_contiguous()
warmup(lambda: t_ncontig.add(3))
%timeit t_ncontig.add(3); torch.cuda.synchronize()

del t_ncontig
torch.cuda.empty_cache()

1.5 ms ± 161 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
1.5 ms ± 120 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
4.38 ms ± 1.63 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.53 ms ± 279 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


### Performance implications: Reductions
**Semantics**: $\bf{\textbf{res}_{i_1, \dots, i_{d-1}, i_{d+1}, \dots i_N} = \textbf{op}_d(\text{t}_{i_1, \dots, i_{d-1}, 0, i_{d+1}, \dots i_N},\textbf{t}_{i_1, \dots, i_{d-1}, 1, i_{d+1}, \dots i_N}, \dots, \textbf{t}_{i_1, \dots, i_{d-1}, \textbf{t.shape[d]-1}, i_{d+1}, \dots i_N})}$

Examples of such operations: `max/min`, `sort`, `mean`, `var` and many others! They always have a `dim` argument.

**Question**: taking into account the semantics of reductions and **Memory Locality**, which memory layout is the best for performance?

**Answer**: `t.stride()[d]` should be `1`, so that the elements being reduced together are adjacent in memory.

In [15]:
x = torch.rand(1024, 1024)
%timeit x.max(-1)
%timeit x.max(-2)

x = x.cuda()

warmup(lambda: x.max(-1))
%timeit x.max(-1); torch.cuda.synchronize()

warmup(lambda: x.max(-2))
%timeit x.max(-2); torch.cuda.synchronize()

48.3 μs ± 1.75 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
318 μs ± 874 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
18 μs ± 30.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
26.4 μs ± 15.1 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
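
If a tensor is reduced along dimension 0 many times, one option is to pay for a single copy into a layout where that dimension has stride 1. The sketch below shows one way to build such a layout; whether the copy pays off depends on the workload, so it should be measured, not assumed.

```python
import torch

x = torch.rand(8, 16)   # contiguous: strides (16, 1)

# Same shape and values, but dimension 0 now has stride 1
x_dim0_fast = x.t().contiguous().t()

print(x.stride(), x_dim0_fast.stride())
assert torch.equal(x, x_dim0_fast)
assert torch.equal(x.max(0).values, x_dim0_fast.max(0).values)
```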


### Performance implications: Binary Element-wise Operations
**Semantics**: $\bf{\textbf{res}_{i_1, \dots, i_N} = \textbf{op}(\textbf{x}_{i_1, \dots, i_N}, \textbf{y}_{i_1, \dots, i_N})}$.

**Question**: taking into account the semantics of binary element-wise operations and **Memory Locality**, which memory layout is likely (i.e. it is sufficient, but not necessary) to be the best for performance?

**Answer**: same as for unary operations, but, additionally, it is best for `x` and `y` to have the same strides (**Why?**)!

In [16]:
x = torch.rand(512, 512, 512)
y = torch.rand(512, 512, 512)

%timeit x * y
%timeit x * y.transpose(0, 2)
%timeit x.transpose(0, 2) * y
%timeit x.transpose(0, 2) * y.transpose(0, 2)

55 ms ± 115 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
205 ms ± 97.7 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
206 ms ± 127 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
55.4 ms ± 504 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [17]:
x = x.cuda()
y = y.cuda()

f = lambda: x * y
warmup(f)
%timeit f(); torch.cuda.synchronize()

f = lambda: x * y.transpose(0, 2)
warmup(f)
%timeit f(); torch.cuda.synchronize()

f = lambda: x.transpose(0, 2) * y
warmup(f)
%timeit f(); torch.cuda.synchronize()

f = lambda: x.transpose(0, 2) * y.transpose(0, 2)
warmup(f)
%timeit f(); torch.cuda.synchronize()

del x
torch.cuda.empty_cache()

2.24 ms ± 443 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
19.7 ms ± 2.11 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
19.7 ms ± 3.24 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.25 ms ± 399 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)


**NOTE**: not having the same stride structure (i.e. `x.stride() != y.stride()`) does not imply that performance is going to be significantly worse. PyTorch performs all kinds of non-trivial optimizations to yield best performance. Most of them will try to improve **Memory Locality** in the access patterns. For example, try replacing `transpose(0, 2)` with `transpose(0, 1)` or `transpose(1, 2)` and see what happens!

#### Strides of the output
We have seen that the strides of the inputs have performance implications. Now, suppose `res = binary_elemwise_op(x, y)`, then what is `res.stride()`?

In [18]:
t1 = torch.rand(4, 4, 4)
t2 = torch.rand(4, 4, 4)

all_dim_perms = list(itertools.permutations(range(t1.ndim)))
for device_type in ("cpu", "cuda"):
    t1 = t1.to(device=device_type)
    t2 = t2.to(device=device_type)
    assert t1.device.type == t2.device.type == device_type
    
    # Let us exhaustively loop through all permutations and see what happens with the strides!
    for perm_idx1, perm_idx2 in itertools.product(range(len(all_dim_perms)), repeat=2):
        perm1 = all_dim_perms[perm_idx1]
        perm2 = all_dim_perms[perm_idx2]
    
        x = t1.permute(perm1)
        y = t2.permute(perm2)
    
        if perm1 != perm2:
            assert x.stride() != y.stride()
        else:
            # because x and y are then the same permutation of equally-shaped contiguous tensors
            assert x.stride() == y.stride()
    
        xy = x * y
        yx = y * x
    
        assert torch.all(xy == yx).item()
        assert xy.stride() == x.stride()
        assert yx.stride() == y.stride()
        
assert perm_idx1 == perm_idx2 == len(all_dim_perms) - 1  # Sanity check for the last element in the Cartesian product

So, at least for `mul`, `res.stride() == <first argument>.stride()`. One can also see that `x * y` is not equivalent to `y * x` in terms of the strides of the output.<br>
**Is this properly documented?** Again, one should always consult the documentation. If it is not there, create a pull request/issue and help the community!

### Performance implications: Matrix Multiplication
**Semantics**: $\bf{\textbf{res}_{ij} = \sum_k \textbf{a}_{ik} \ast \textbf{b}_{kj}}$. For simplicity, we assume that the arguments are 2-dimensional only!


Suppose `a` is a matrix of shape `(m, n)`. Then we say it is stored in the
* **Row-Major** (or C-contiguous) format, if $a$ is **contiguous**.
* **Column-Major** (or F-contiguous) format, if $a^T$ is **contiguous**.

<img src="colrow_major.png" width="500" height="300">

**Question**: what are the strides of a row/column-major matrix?

**Answer**: `(n, 1)` for a row-major, and `(1, m)` for a column-major. You will need these when writing matrix operations in CUDA!
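
A short sketch that builds both layouts for the same shape `(m, n)` and prints their strides; the sizes are arbitrary.

```python
import torch

m, n = 3, 5

row_major = torch.rand(m, n)      # contiguous
col_major = torch.rand(n, m).mT   # same shape (m, n), but its transpose is contiguous

print(row_major.shape, row_major.stride())
print(col_major.shape, col_major.stride())
```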

**Question**: taking into account the semantics of matrix multiplications and **Memory Locality**, what are the *best* memory layouts for inputs `a` and `b` to compute their matrix product `a @ b`, if implemented *naively*?

**Answer**: `a` is row-major, `b` is column-major.
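
To see why, it helps to list the element offsets touched by the inner loop over $k$ of a naive implementation when computing a single entry $\textbf{res}_{ij}$. The plain-Python sketch below uses made-up sizes and only does the offset arithmetic; no actual multiplication is performed.

```python
m, k_dim, n = 4, 6, 5

a_row_major = (k_dim, 1)   # strides of `a`, shape (m, k_dim)
b_row_major = (n, 1)       # strides of `b`, shape (k_dim, n), row-major
b_col_major = (1, k_dim)   # strides of `b`, shape (k_dim, n), column-major

i, j = 2, 3   # which entry of the result is being computed

# Offsets (in elements) read while k runs over the reduction dimension
a_offsets = [i * a_row_major[0] + k * a_row_major[1] for k in range(k_dim)]
b_offsets_rm = [k * b_row_major[0] + j * b_row_major[1] for k in range(k_dim)]
b_offsets_cm = [k * b_col_major[0] + j * b_col_major[1] for k in range(k_dim)]

print(a_offsets)      # consecutive: row i of a row-major `a`
print(b_offsets_rm)   # jumps of size n: column j of a row-major `b`
print(b_offsets_cm)   # consecutive: column j of a column-major `b`
```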

PyTorch, however, dispatches to quite fast routines that, ideally, should perform almost the same for all combinations of row/column-major inputs.

In [19]:
a = torch.rand(2048, 2048)
b = torch.rand(2048, 2048)

%timeit a @ b
%timeit a @ b.mT
%timeit a.mT @ b
%timeit a.mT @ b.mT

14.5 ms ± 21.4 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
14.7 ms ± 64.2 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
16.5 ms ± 29 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
16.6 ms ± 31 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [20]:
a = a.cuda()
b = b.cuda()

f = lambda: a @ b
warmup(f)
%timeit f(); torch.cuda.synchronize()

f = lambda: a @ b.mT
warmup(f)
%timeit f(); torch.cuda.synchronize()

f = lambda: a.mT @ b
warmup(f)
%timeit f(); torch.cuda.synchronize()

f = lambda: a.mT @ b.mT
warmup(f)
%timeit f(); torch.cuda.synchronize()

del a, b
torch.cuda.empty_cache()

833 μs ± 1.26 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
981 μs ± 1.47 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
797 μs ± 2.15 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
825 μs ± 1.55 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


#### Low-Rank matrix product on the GPU
Suppose `a` is of shape `(m, k)` and `b` is of shape `(k, n)`, and `k` is relatively small.<br>
How does `a @ b` perform with different values of `k`?

In [21]:
def get_matrices(k):
    x = torch.rand(2048, k, device="cuda")
    y = torch.rand(k, 2048, device="cuda")
    return x, y

x, y = get_matrices(2)
f = lambda: x @ y

warmup(f)
%timeit f(); torch.cuda.synchronize()

x, y = get_matrices(4)
warmup(f)
%timeit f(); torch.cuda.synchronize()

x, y = get_matrices(8)
warmup(f)
%timeit f(); torch.cuda.synchronize()

x, y = get_matrices(16)
warmup(f)
%timeit f(); torch.cuda.synchronize()

x, y = get_matrices(32)
warmup(f)
%timeit f(); torch.cuda.synchronize()

del x, y
torch.cuda.empty_cache()

38.6 μs ± 36.1 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
38.6 μs ± 31.7 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
37.8 μs ± 50.5 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
40.8 μs ± 33.3 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
40.8 μs ± 58.7 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


**Takeaway**: PyTorch is a general-purpose library which is not optimized for all types of shapes (especially for tall and wide matrices)! Use reasonable memory layout for matrix multiplications and use small rank operations with caution!

### Memory and Performance implications Summary
* Stride structure of the inputs might have a huge impact on performance of operations.
* If that is the case, sometimes making a copy with a suitable stride structure could be of benefit if you can tolerate the penalty of memory allocation and copy.
* **Contiguous** memory format, or row/column-major for matrices, is probably your best friend most of the time. If one does not deviate much from them, it is pretty hard to improve your program with direct methods (unless there is some room in the backward pass with custom `nn.Module`s) and one needs to resort to some high-level advanced techniques operating on the Computational Graph itself, i.e. the so-called Deep Learning Compilers like `torch.compile` and others.
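
A minimal sketch of the "make a copy with a suitable stride structure" advice: `contiguous()` materializes a non-contiguous view into a fresh contiguous tensor, and returns the tensor itself when it is already contiguous. The shapes are arbitrary.

```python
import torch

x = torch.rand(64, 64, 64)
x_t = x.transpose(0, 2)      # a view with strides (1, 64, 4096)

x_t_copy = x_t.contiguous()  # one allocation plus one copy kernel

assert not x_t.is_contiguous() and x_t_copy.is_contiguous()
assert torch.equal(x_t, x_t_copy)                # same values
assert x_t_copy.data_ptr() != x_t.data_ptr()     # different memory
assert x.contiguous().data_ptr() == x.data_ptr() # no copy when already contiguous
```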

### We have just scratched the surface...

PyTorch is immense, and there is always something more to tell about it.<br>

If you liked what you saw here, the best thing you can do is to share this notebook with your colleagues and friends!

Also, feel free to
* create an issue on GitHub if there is an issue,
* create an issue on GitHub if there is a topic you would like to know more about,
* star the repository.

Thank you and have fun learning High Performance Computing with PyTorch!<br>
nikitaved@github