In [1]:
import torch
import torch.nn as nn
import torch.distributed as dist
import os

#### Tensor Parallelism

In [2]:
class ColumnParallelism(nn.Module):
    """Linear layer whose weight is sharded along the output (column) axis.

    Each rank stores an ``(in_features, out_features // world_size)`` slice of
    the weight plus the matching bias slice, and computes only its own slice of
    the output features. ``forward`` performs no communication.
    """

    def __init__(self, in_features, out_features, rank, world_size):
        super().__init__()

        self.in_features, self.out_features = in_features, out_features
        self.rank, self.world_size = rank, world_size

        # Number of output columns owned by this rank (floor division).
        shard_width = out_features // world_size
        self.out_features_per_rank = shard_width

        # Standard-normal weight shard, zero bias shard.
        self.weight = nn.Parameter(torch.randn(in_features, shard_width))
        self.bias = nn.Parameter(torch.zeros(shard_width))

    def forward(self, x):
        """Return ``x @ weight + bias`` for this rank's output slice."""
        return torch.matmul(x, self.weight) + self.bias

In [3]:
class RowParallelism(nn.Module):
    """Linear layer whose weight is sharded along the input (row) axis.

    Each rank multiplies its slice of the input features by its
    ``(in_features // world_size, out_features)`` weight shard. The partial
    products are summed across ranks with an all-reduce, and the (replicated)
    bias is then added once on every rank so that all ranks return the same
    output.
    """

    def __init__(self, in_features, out_features, rank, world_size):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.world_size = world_size

        # Number of input rows owned by this rank (floor division).
        self.in_features_per_rank = in_features // world_size

        self.weight = nn.Parameter(
            torch.randn(self.in_features_per_rank, self.out_features)
        )

        self.bias = nn.Parameter(
            torch.zeros(self.out_features)
        )

    def forward(self, x):
        """Return the full ``x_full @ W_full + bias`` given this rank's input slice."""
        partial = torch.matmul(x, self.weight)

        # With a single rank the partial product already is the full product,
        # so no process group is needed.
        if self.world_size > 1:
            dist.all_reduce(partial, op=dist.ReduceOp.SUM)

        # The bias must be added on EVERY rank after the reduction; adding it
        # on rank 0 only would leave the other ranks with a different output.
        return partial + self.bias

In [4]:
class TensorParallelismAttention(nn.Module):
    """Multi-head self-attention with the heads split across tensor-parallel ranks.

    The fused QKV projection is column-parallel, so each rank produces the
    queries/keys/values of ``n_heads // world_size`` heads. The output
    projection is row-parallel and combines the per-rank head outputs.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads`` or
            ``n_heads`` is not divisible by ``world_size``.
    """

    def __init__(self, n_heads, d_model, rank, world_size):
        super().__init__()

        if d_model % n_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
        if n_heads % world_size != 0:
            raise ValueError(f"n_heads ({n_heads}) must be divisible by world_size ({world_size})")

        self.n_heads = n_heads
        self.d_model = d_model
        self.rank = rank
        self.world_size = world_size

        # Width of ONE head. It depends on the number of heads, not on the
        # number of ranks; the local QKV width is 3 * n_heads_per_rank * head_dim.
        self.head_dim = d_model // n_heads
        self.n_heads_per_rank = n_heads // world_size

        self.qkv_proj = ColumnParallelism(
            in_features=d_model,
            out_features=d_model * 3,
            rank=rank,
            world_size=world_size
        )

        self.out_proj = RowParallelism(
            in_features=d_model,
            out_features=d_model,
            rank=rank,
            world_size=world_size
        )

    def forward(self, x):
        """Apply attention over the local heads to ``x`` of shape (batch, seq, d_model)."""
        batch_size, seq_len, _ = x.shape

        # (batch, seq, 3 * local_heads * head_dim) -> (3, batch, local_heads, seq, head_dim)
        qkv = self.qkv_proj(x).reshape(batch_size, seq_len, 3, self.n_heads_per_rank, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        scores = q @ k.transpose(-2, -1) / (self.head_dim ** 0.5)
        attn = nn.functional.softmax(scores, dim=-1)
        attn_output = attn @ v

        # Merge the local heads back into a single feature axis for the
        # row-parallel output projection.
        attn_output = attn_output.transpose(1, 2).contiguous().reshape(
            batch_size, seq_len, self.n_heads_per_rank * self.head_dim
        )
        output = self.out_proj(attn_output)

        return output

In [5]:
class TensorParallelismMLP(nn.Module):
    """Feed-forward block: column-parallel expansion, GELU, row-parallel contraction."""

    def __init__(self, d_model, d_ff, rank, world_size):
        super().__init__()

        self.d_model, self.d_ff = d_model, d_ff
        self.rank, self.world_size = rank, world_size

        # Both projections share the same sharding coordinates.
        shard = dict(rank=rank, world_size=world_size)

        self.fc1 = ColumnParallelism(in_features=d_model, out_features=d_ff, **shard)
        self.fc2 = RowParallelism(in_features=d_ff, out_features=d_model, **shard)
        self.activation = nn.GELU()

    def forward(self, x):
        """Return ``fc2(gelu(fc1(x)))``."""
        hidden = self.activation(self.fc1(x))
        return self.fc2(hidden)

In [6]:
class TensorParallelismBlock(nn.Module):
    """Pre-norm transformer block built from the tensor-parallel attention and MLP.

    Each sub-layer is wrapped as ``x + dropout(sublayer(layer_norm(x)))`` with a
    dropout probability of 0.1.
    """

    def __init__(self, n_heads, d_model, d_ff, rank, world_size):
        super().__init__()

        self.n_heads, self.d_model, self.d_ff = n_heads, d_model, d_ff
        self.rank, self.world_size = rank, world_size

        shard = dict(rank=rank, world_size=world_size)

        self.attention = TensorParallelismAttention(n_heads=n_heads, d_model=d_model, **shard)
        self.mlp = TensorParallelismMLP(d_model=d_model, d_ff=d_ff, **shard)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)

    def forward(self, x):
        """Run the attention sub-layer, then the MLP sub-layer, each with a residual."""
        attended = self.attention(self.norm1(x))
        after_attention = x + self.dropout1(attended)

        transformed = self.mlp(self.norm2(after_attention))
        return after_attention + self.dropout2(transformed)

In [7]:
from torch.autograd import Function

In [8]:
class ColumnParallelismFunction(Function):
    """Autograd function for a column-parallel linear layer.

    Forward computes this rank's output slice ``input @ weight (+ bias)``.
    Backward produces the local weight/bias gradients and sums the input
    gradient across ranks, because every rank consumed the same input.
    """

    @staticmethod
    def forward(ctx, input, weight, bias):
        # save_for_backward (rather than a plain attribute) lets autograd track
        # the saved tensors and detect in-place modification.
        ctx.save_for_backward(input, weight)
        ctx.has_bias = bias is not None

        out = torch.matmul(input, weight)
        if bias is not None:
            out = out + bias

        return out

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors

        # Collapse all leading (batch/sequence) axes so the weight gradient is
        # a plain (in_features, out_features_per_rank) matrix product.
        grad_output_2d = grad_output.reshape(-1, grad_output.shape[-1])
        grad_weight = torch.matmul(
            input.reshape(-1, input.shape[-1]).t(),
            grad_output_2d
        )

        # A gradient may only be returned for inputs that were tensors.
        grad_bias = grad_output_2d.sum(dim=0) if ctx.has_bias else None

        grad_input = torch.matmul(grad_output, weight.t())

        # Each rank only sees the contribution of its own output columns; the
        # true input gradient is the sum over ranks. Skipped when no process
        # group exists so the function also works in a single process.
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(grad_input, op=dist.ReduceOp.SUM)

        return grad_input, grad_weight, grad_bias

In [9]:
class ColumnParallelism(nn.Module):
    """Column-sharded linear layer that routes its math through ``ColumnParallelismFunction``.

    Holds an ``(in_features, out_features // world_size)`` weight shard and the
    matching bias shard for this rank.
    """

    def __init__(self, in_features, out_features, rank, world_size):
        super().__init__()

        self.in_features, self.out_features = in_features, out_features
        self.rank, self.world_size = rank, world_size

        # Number of output columns owned by this rank (floor division).
        shard_width = out_features // world_size
        self.out_features_per_rank = shard_width

        self.weight = nn.Parameter(torch.randn(in_features, shard_width))
        self.bias = nn.Parameter(torch.zeros(shard_width))

    def forward(self, x):
        """Apply the custom autograd function to ``x`` with this rank's parameters."""
        local_params = (self.weight, self.bias)
        return ColumnParallelismFunction.apply(x, *local_params)

In [10]:
class RowParallelismFunction(Function):
    """Autograd function for a row-parallel linear layer.

    Forward multiplies this rank's input slice by its weight shard, sums the
    partial products over ranks and then adds the replicated bias once.
    Backward needs no communication: the output gradient is identical on all
    ranks and each rank only needs the gradient of its own input slice.

    ``rank`` and ``world_size`` are optional, non-differentiable extras so the
    function can be invoked as ``apply(x, weight, bias, rank, world_size)``.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, rank=None, world_size=None):
        ctx.save_for_backward(input, weight)
        ctx.has_bias = bias is not None

        out = torch.matmul(input, weight)

        # Sum the partial products. Skipped for a single rank or when no
        # process group exists, so the function also works in one process.
        if world_size != 1 and dist.is_available() and dist.is_initialized():
            dist.all_reduce(out, op=dist.ReduceOp.SUM)

        # Bias is added AFTER the reduction; adding it before would count it
        # once per rank.
        if bias is not None:
            out = out + bias

        return out

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors

        grad_output_2d = grad_output.reshape(-1, grad_output.shape[-1])

        # dL/dW = input^T @ grad_output (leading axes flattened).
        grad_weight = torch.matmul(
            input.reshape(-1, input.shape[-1]).t(),
            grad_output_2d
        )

        grad_bias = grad_output_2d.sum(dim=0) if ctx.has_bias else None
        grad_input = torch.matmul(grad_output, weight.t())

        # Trailing Nones correspond to the optional rank / world_size inputs.
        return grad_input, grad_weight, grad_bias, None, None

In [11]:
class RowParallelism(nn.Module):
    """Row-sharded linear layer that routes its math through ``RowParallelismFunction``.

    Holds an ``(in_features // world_size, out_features)`` weight shard and a
    full-width bias for this rank.
    """

    def __init__(self, in_features, out_features, rank, world_size):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.world_size = world_size

        # Number of input rows owned by this rank (floor division).
        self.in_features_per_rank = in_features // world_size

        self.weight = nn.Parameter(
            torch.randn(self.in_features_per_rank, out_features)
        )

        self.bias = nn.Parameter(
            torch.zeros(self.out_features)
        )

    def forward(self, x):
        """Apply the custom autograd function to this rank's input slice ``x``."""
        # An autograd Function must be invoked through ``.apply``; calling the
        # class directly constructs an object instead of running forward().
        return RowParallelismFunction.apply(x, self.weight, self.bias, self.rank, self.world_size)

#### Mixed Precision Training

In [12]:
import torch
import torch.nn
import os

In [13]:
class Model(nn.Module):
    """Three-layer MLP: Linear -> GELU -> Linear -> GELU -> Linear."""

    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()

        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features

        self.layers = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.GELU(),
            nn.Linear(hidden_features, hidden_features),
            nn.GELU(),
            nn.Linear(hidden_features, out_features)
        )

    def forward(self, x):
        """Return the network output for ``x`` of shape (..., in_features)."""
        # The Sequential must be CALLED on x; returning the module itself
        # would hand the caller an nn.Sequential instead of a tensor.
        return self.layers(x)

In [14]:
def train_without_amp():
    """Train ``Model`` with hand-written FP16 mixed precision (no ``torch.amp``).

    Reads the integer settings ``in_features``, ``out_features``,
    ``hidden_features`` and ``n_epochs`` from environment variables, builds a
    random classification data set, and trains with:

    * an FP16 copy of the model for forward/backward,
    * FP32 master weights that the optimizer updates,
    * static loss scaling (factor 8192) to keep FP16 gradients from underflowing.

    Steps whose unscaled gradients contain inf/NaN are skipped.
    """
    # Environment variables are always strings; convert before using as sizes.
    in_features = int(os.environ['in_features'])
    out_features = int(os.environ['out_features'])
    hidden_features = int(os.environ['hidden_features'])
    n_epochs = int(os.environ['n_epochs'])

    loss_scale = 8192.

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    X = torch.randn((1000, in_features))
    y = torch.randint(0, out_features, (1000,))
    dataloader = torch.utils.data.DataLoader(list(zip(X, y)), batch_size=32, shuffle=True)

    model = Model(in_features, hidden_features, out_features).to(device)

    # clone() gives the master copy its own storage, independent of the
    # parameters that are converted to half below.
    fp32_master_weights = [
        p.detach().clone().float().requires_grad_(True) for p in model.parameters()
    ]
    model = model.half()

    # Labels are class indices, so this is a classification loss.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(fp32_master_weights, lr=1e-3)

    model.train()

    for epoch in range(n_epochs):
        for batch, label in dataloader:
            model.zero_grad()

            batch = batch.to(device).half()
            label = label.to(device)  # stays int64: class indices

            output = model(batch)
            # Compute the loss in FP32 for numerical stability.
            loss = criterion(output.float(), label)
            (loss * loss_scale).backward()

            # Copy FP16 grads into FP32 and undo the loss scaling.
            overflow = False
            for fp16_param, fp32_param in zip(model.parameters(), fp32_master_weights):
                grad = fp16_param.grad.detach().float() / loss_scale
                if not bool(torch.isfinite(grad).all()):
                    overflow = True
                fp32_param.grad = grad

            if overflow:
                # A non-finite gradient would corrupt the master weights.
                continue

            optimizer.step()

            # Refresh the FP16 working copy from the updated master weights.
            with torch.no_grad():
                for fp16_param, fp32_param in zip(model.parameters(), fp32_master_weights):
                    fp16_param.copy_(fp32_param)

In [15]:
from torch.amp import GradScaler

In [16]:
def train_with_amp():
    in_features = os.environ['in_features']
    out_features = os.environ['out_features']
    hidden_features = os.environ['hidden_features']

    X = torch.randn((1000, in_features))
    y = torch.randint(0, out_features, (1000,))
    dataloader = torch.utils.data.DataLoader(list(zip(X, y)), batch_size=32, shuffle=True)

    model = Model(in_features, hidden_features, out_features)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scaler = GradScaler()

    n_epochs = os.environ['n_epochs']

    model.train()

    for epoch in range(n_epochs):
        for _, (batch, label) in enumerate(dataloader):
            optimizer.zero_grad()

            with torch.amp.autocast(device_type=torch.float16):
                output = model(batch)
                loss = criterion(output, label)

            scaler.scale(loss)
            scaler.step(optimizer)
            scaler.update()

#### NCCL

In [17]:
import os
import torch
import torch.distributed as dist

In [20]:
class NCCL:
    """Hand-rolled collectives built on point-to-point ``torch.distributed`` ops.

    Provides a ring all-reduce plus binary-tree broadcast / scatter / reduce /
    gather. The tree is laid out over *virtual* ranks ``(rank - root) % world_size``
    so that any rank can act as the root; virtual rank ``v`` has children
    ``2v + 1`` and ``2v + 2``.

    All tensor-splitting collectives operate along dimension 0.
    """

    def __init__(self, rank, world_size):
        self.rank = rank
        self.world_size = world_size

        # Respect values provided by a launcher; fall back to single-node defaults.
        os.environ.setdefault('MASTER_ADDR', 'localhost')
        os.environ.setdefault('MASTER_PORT', '12355')

        # Bind this process to its GPU before NCCL creates communicators.
        torch.cuda.set_device(rank % max(torch.cuda.device_count(), 1))

        dist.init_process_group(
            backend='nccl',
            rank=rank,
            world_size=world_size
        )

    # ------------------------------------------------------------------
    # Topology helpers (pure functions of rank / world_size)
    # ------------------------------------------------------------------
    def _virtual_rank(self, rank, root):
        """Position of ``rank`` in the tree rooted at ``root``."""
        return (rank - root) % self.world_size

    def _real_rank(self, vrank, root):
        """Inverse of ``_virtual_rank``."""
        return (vrank + root) % self.world_size

    def _children(self, vrank):
        """Virtual ranks of the children of ``vrank`` that exist in the tree."""
        return [c for c in (2 * vrank + 1, 2 * vrank + 2) if c < self.world_size]

    def _subtree(self, vrank):
        """Virtual ranks of the subtree rooted at ``vrank`` in pre-order."""
        order = [vrank]
        for child in self._children(vrank):
            order.extend(self._subtree(child))
        return order

    @staticmethod
    def _ring_schedule(rank, world_size):
        """Per-step ``(send_chunk, recv_chunk)`` indices for a ring all-reduce.

        Returns ``(reduce_scatter, all_gather)``, each a list with
        ``world_size - 1`` entries. Data flows from ``rank`` to ``rank + 1``,
        so the chunk a rank receives at a step is the one its predecessor
        sends at that step. After reduce-scatter, ``rank`` holds the fully
        reduced chunk ``(rank + 1) % world_size``.
        """
        steps = range(world_size - 1)
        reduce_scatter = [
            ((rank - s) % world_size, (rank - s - 1) % world_size) for s in steps
        ]
        all_gather = [
            ((rank + 1 - s) % world_size, (rank - s) % world_size) for s in steps
        ]
        return reduce_scatter, all_gather

    @staticmethod
    def _exchange(send_tensor, dst, recv_tensor, src):
        """Send to ``dst`` and receive from ``src`` as one batched P2P operation.

        Batching avoids the deadlock that separately issued send/recv pairs can
        hit on NCCL. Zero-sized transfers are skipped; both peers agree on the
        size, so the skip is symmetric.
        """
        ops = []
        if send_tensor.numel() > 0:
            ops.append(dist.P2POp(dist.isend, send_tensor, dst))
        if recv_tensor.numel() > 0:
            ops.append(dist.P2POp(dist.irecv, recv_tensor, src))
        if ops:
            for request in dist.batch_isend_irecv(ops):
                request.wait()

    @staticmethod
    def _combine(accumulator, incoming, op):
        """Fold ``incoming`` into ``accumulator`` in place according to ``op``."""
        if op == dist.ReduceOp.SUM:
            accumulator.add_(incoming)
        elif op == dist.ReduceOp.PRODUCT:
            accumulator.mul_(incoming)
        elif op == dist.ReduceOp.MAX:
            torch.maximum(accumulator, incoming, out=accumulator)
        elif op == dist.ReduceOp.MIN:
            torch.minimum(accumulator, incoming, out=accumulator)
        else:
            raise ValueError(f"Unsupported reduce op: {op}")

    # ------------------------------------------------------------------
    # Collectives
    # ------------------------------------------------------------------
    def all_reduce_ring(self, tensor: torch.Tensor):
        """Sum ``tensor`` across all ranks in place with a ring all-reduce.

        The tensor is split into ``world_size`` pieces along dim 0; a
        reduce-scatter phase followed by an all-gather phase leaves every rank
        with the full sum. Returns ``tensor``.
        """
        if self.world_size == 1:
            return tensor

        send_rank = (self.rank + 1) % self.world_size
        recv_rank = (self.rank - 1 + self.world_size) % self.world_size

        # tensor_split always yields exactly world_size views (possibly empty),
        # so chunk indices are valid for any tensor length.
        chunks = torch.tensor_split(tensor, self.world_size)
        reduce_scatter, all_gather = self._ring_schedule(self.rank, self.world_size)

        for send_idx, recv_idx in reduce_scatter:
            recv_buffer = torch.empty_like(chunks[recv_idx], memory_format=torch.contiguous_format)
            self._exchange(chunks[send_idx].contiguous(), send_rank, recv_buffer, recv_rank)
            chunks[recv_idx].add_(recv_buffer)

        for send_idx, recv_idx in all_gather:
            recv_buffer = torch.empty_like(chunks[recv_idx], memory_format=torch.contiguous_format)
            self._exchange(chunks[send_idx].contiguous(), send_rank, recv_buffer, recv_rank)
            chunks[recv_idx].copy_(recv_buffer)

        return tensor

    def broadcast_mst(self, tensor: torch.Tensor, root: int):
        """Broadcast ``tensor`` from ``root`` to every rank down a binary tree.

        Non-root ranks receive into ``tensor`` in place. Returns ``tensor``.
        """
        if self.world_size == 1:
            return tensor

        vrank = self._virtual_rank(self.rank, root)

        if vrank != 0:
            parent = self._real_rank((vrank - 1) // 2, root)
            dist.recv(tensor, src=parent)

        for child in self._children(vrank):
            dist.send(tensor, dst=self._real_rank(child, root))

        return tensor

    def scatter_mst(self, tensor: torch.Tensor, root: int):
        """Scatter equal dim-0 slices of the root's ``tensor`` down a binary tree.

        Rank ``r`` receives slice ``r``. On non-root ranks ``tensor`` is used
        only as a template for shape, dtype and device.

        Raises:
            ValueError: if ``tensor.shape[0]`` is not a multiple of ``world_size``.
        """
        if self.world_size == 1:
            return tensor

        if tensor.shape[0] % self.world_size != 0:
            raise ValueError(
                f"scatter_mst needs dim 0 ({tensor.shape[0]}) divisible by "
                f"world_size ({self.world_size})"
            )
        rows_per_rank = tensor.shape[0] // self.world_size

        vrank = self._virtual_rank(self.rank, root)

        if vrank == 0:
            pieces = torch.chunk(tensor, chunks=self.world_size, dim=0)
            # Key every slice by the virtual rank that must end up with it.
            chunk_of = {
                v: pieces[self._real_rank(v, root)] for v in range(self.world_size)
            }
        else:
            subtree = self._subtree(vrank)
            parent = self._real_rank((vrank - 1) // 2, root)

            buffer = torch.empty(
                (rows_per_rank * len(subtree), *tensor.shape[1:]),
                dtype=tensor.dtype,
                device=tensor.device
            )
            dist.recv(buffer, src=parent)
            # The parent packs slices in pre-order of this subtree.
            chunk_of = dict(zip(subtree, torch.chunk(buffer, chunks=len(subtree), dim=0)))

        for child in self._children(vrank):
            payload = torch.cat([chunk_of[v] for v in self._subtree(child)], dim=0)
            dist.send(payload, dst=self._real_rank(child, root))

        return chunk_of[vrank]

    def reduce_tree(self, tensor: torch.Tensor, root: int, op = dist.ReduceOp.SUM):
        """Reduce ``tensor`` onto ``root`` up a binary tree.

        ``tensor`` is modified in place on every rank (it accumulates the
        subtree's partial result). Supports SUM, PRODUCT, MAX and MIN.

        Returns:
            The reduced tensor on ``root``; ``None`` on every other rank.
        """
        if self.world_size == 1:
            return tensor

        vrank = self._virtual_rank(self.rank, root)
        recv_buffer = torch.empty_like(tensor, memory_format=torch.contiguous_format)

        for child in self._children(vrank):
            dist.recv(recv_buffer, src=self._real_rank(child, root))
            self._combine(tensor, recv_buffer, op)

        if vrank != 0:
            parent = self._real_rank((vrank - 1) // 2, root)
            dist.send(tensor.contiguous(), dst=parent)
            return None

        return tensor

    def gather_mst(self, tensor: torch.Tensor, root: int):
        """Gather every rank's ``tensor`` onto ``root`` up a binary tree.

        All ranks must pass tensors of identical shape and dtype.

        Returns:
            On ``root``: the per-rank tensors concatenated along dim 0 in rank
            order. On every other rank: ``None``.
        """
        if self.world_size == 1:
            return tensor

        vrank = self._virtual_rank(self.rank, root)

        # Own data first, then each child's subtree: pre-order packing.
        pieces = [tensor]
        for child in self._children(vrank):
            subtree_size = len(self._subtree(child))
            recv_buffer = torch.empty(
                (tensor.shape[0] * subtree_size, *tensor.shape[1:]),
                dtype=tensor.dtype,
                device=tensor.device
            )
            dist.recv(recv_buffer, src=self._real_rank(child, root))
            pieces.append(recv_buffer)

        data = torch.cat(pieces, dim=0)

        if vrank != 0:
            parent = self._real_rank((vrank - 1) // 2, root)
            dist.send(data, dst=parent)
            return None

        # Undo the pre-order packing so the result is ordered by real rank.
        per_rank = torch.chunk(data, chunks=self.world_size, dim=0)
        ordered = [None] * self.world_size
        for piece, v in zip(per_rank, self._subtree(0)):
            ordered[self._real_rank(v, root)] = piece

        return torch.cat(ordered, dim=0)

#### Data Parallelism

In [22]:
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Adam

In [23]:
class Model(nn.Module):
    """Three-layer MLP with a GELU after each of the first two linear layers."""

    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()

        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features

        # Consecutive pairs of widths describe the linear layers in order.
        widths = (in_features, hidden_features, hidden_features, out_features)

        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.GELU())
        stack.pop()  # the final projection has no activation

        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack."""
        return self.layers(x)

In [24]:
class DataParallelism:
    """Minimal DistributedDataParallel training harness configured from the environment.

    Reads ``RANK``, ``WORLD_SIZE``, ``IN_FEATURES``, ``HIDDEN_FEATURES``,
    ``OUT_FEATURES`` (and ``N_EPOCHS`` in ``train``) as integers. ``LOCAL_RANK``
    selects the GPU when present; otherwise ``RANK`` is used.
    """

    def __init__(self):
        self.rank = self._env_int('RANK')
        self.world_size = self._env_int('WORLD_SIZE')
        # GPU index on this node; equals the global rank on a single node.
        self.local_rank = int(os.environ.get('LOCAL_RANK', self.rank))

        # Respect values provided by a launcher; fall back to single-node defaults.
        os.environ.setdefault('MASTER_ADDR', 'localhost')
        os.environ.setdefault('MASTER_PORT', '12355')

        self.setup()

        self.in_features = self._env_int('IN_FEATURES')
        self.hidden_features = self._env_int('HIDDEN_FEATURES')
        self.out_features = self._env_int('OUT_FEATURES')

        # The module must live on this process's GPU before DDP wraps it.
        self.model = Model(
            in_features=self.in_features,
            hidden_features=self.hidden_features,
            out_features=self.out_features
        ).to(self.local_rank)

        self.ddp_model = DistributedDataParallel(
            module=self.model,
            device_ids=[self.local_rank],
            bucket_cap_mb=25,
            gradient_as_bucket_view=True
        )

    @staticmethod
    def _env_int(name):
        """Return environment variable ``name`` as an int (env values are strings)."""
        return int(os.environ[name])

    def setup(self):
        """Bind this process to its GPU and join the NCCL process group."""
        torch.cuda.set_device(self.local_rank)

        dist.init_process_group(
            backend='nccl',
            rank=self.rank,
            world_size=self.world_size
        )

    def cleanup(self):
        """Tear down the process group."""
        dist.destroy_process_group()

    def train(self, n_samples):
        """Train the DDP model on ``n_samples`` random samples for ``N_EPOCHS`` epochs."""
        dataset = torch.randn(n_samples, self.in_features)
        labels = torch.randint(0, self.out_features, (n_samples,))
        samples = list(zip(dataset, labels))

        # Each rank iterates over a disjoint shard of the samples.
        sampler = DistributedSampler(
            dataset=samples,
            num_replicas=self.world_size,
            rank=self.rank,
            shuffle=True
        )

        dataloader = DataLoader(
            samples,
            batch_size=32,
            sampler=sampler
        )

        n_epochs = self._env_int('N_EPOCHS')
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.ddp_model.parameters(), lr=1e-3)

        self.ddp_model.train()

        for epoch in range(n_epochs):
            # Re-seed the sampler so every epoch uses a different shuffle.
            sampler.set_epoch(epoch)

            for X, y in dataloader:
                X = X.to(self.local_rank)
                y = y.to(self.local_rank)

                optimizer.zero_grad()

                output = self.ddp_model(X)
                loss = criterion(output, y)
                # DDP all-reduces gradients across ranks during backward().
                loss.backward()

                optimizer.step()

#### Sequence Parallelism

In [25]:
import torch
import torch.nn as nn
import torch.distributed as dist

In [26]:
class SequenceParallelismLayerNorm(nn.Module):
    """Layer normalisation over the last (feature) dimension.

    Normalisation is per token, so it can be applied independently to each
    rank's sequence shard without communication.

    Args:
        d_model: size of the feature dimension.
        eps: value added to the variance for numerical stability.
    """

    def __init__(self, d_model, eps=1e-15):
        super().__init__()

        self.d_model = d_model
        self.eps = eps

        # Start as the identity affine map (scale 1, shift 0) so a freshly
        # built layer returns the normalised input unchanged.
        self.alpha = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))

    def forward(self, x: torch.Tensor):
        """Return ``alpha * (x - mean) / sqrt(var + eps) + beta`` over the last axis."""
        mean = x.mean(dim=-1, keepdim=True)
        # Layer norm uses the biased (population) variance; the default
        # unbiased estimator would divide by d_model - 1 instead.
        var = x.var(dim=-1, keepdim=True, unbiased=False)

        normalised = (x - mean) / torch.sqrt(var + self.eps)
        output = self.alpha * normalised + self.beta

        return output

In [27]:
class SequenceParallelismDropout(nn.Module):
    """Dropout whose random seed is agreed on by all ranks.

    In training mode rank 0 draws a fresh seed, the seed is broadcast to every
    rank, and dropout is applied under that seed inside a forked RNG context so
    the surrounding random state is left untouched. In eval mode the input is
    returned unchanged.

    Args:
        p: dropout probability.
        rank: this process's rank; rank 0 generates the seed.
    """

    def __init__(self, p, rank):
        super().__init__()

        self.p = p
        self.rank = rank

        # Buffer so the seed follows the module across devices for broadcast.
        self.register_buffer('rng_seed', torch.zeros(1, dtype=torch.float32))

    def forward(self, x):
        if not self.training:
            return x

        distributed = dist.is_available() and dist.is_initialized()

        # Without a process group this process is the only participant and
        # must draw its own seed regardless of its nominal rank.
        if self.rank == 0 or not distributed:
            self.rng_seed.random_()

        if distributed:
            dist.broadcast(self.rng_seed, src=0)
        seed_val = int(self.rng_seed.item())

        # fork_rng's ``devices`` lists CUDA devices only; the CPU generator is
        # always forked, so a CPU input needs an empty list.
        cuda_devices = [x.device] if x.is_cuda else []

        with torch.random.fork_rng(devices=cuda_devices):
            torch.random.manual_seed(seed_val)

            out = nn.functional.dropout(x, p=self.p, training=True)
            return out

In [28]:
class ColumnParallelism(nn.Module):
    """Linear layer sharded over output columns; each rank computes its own slice.

    The local weight has shape ``(in_features, out_features // world_size)`` and
    the local bias has the matching width. ``forward`` does not communicate.
    """

    def __init__(self, in_features, out_features, rank, world_size):
        super().__init__()

        self.in_features, self.out_features = in_features, out_features
        self.rank, self.world_size = rank, world_size

        # Output columns owned by this rank (floor division).
        local_width = out_features // world_size
        self.out_features_per_rank = local_width

        # Standard-normal weight shard, zero bias shard.
        self.weight = nn.Parameter(torch.randn(in_features, local_width))
        self.bias = nn.Parameter(torch.zeros(local_width))

    def forward(self, x):
        """Return ``x @ weight + bias`` for the local output slice."""
        projected = torch.matmul(x, self.weight)
        return projected + self.bias

In [29]:
class RowParallelism(nn.Module):
    """Row-sharded linear layer that routes its math through ``RowParallelismFunction``.

    Holds an ``(in_features // world_size, out_features)`` weight shard and a
    full-width bias for this rank.
    """

    def __init__(self, in_features, out_features, rank, world_size):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.world_size = world_size

        # Number of input rows owned by this rank (floor division).
        self.in_features_per_rank = in_features // world_size

        self.weight = nn.Parameter(
            torch.randn(self.in_features_per_rank, out_features)
        )

        self.bias = nn.Parameter(
            torch.zeros(self.out_features)
        )

    def forward(self, x):
        """Apply the custom autograd function to this rank's input slice ``x``."""
        # An autograd Function must be invoked through ``.apply``; calling the
        # class directly constructs an object instead of running forward().
        return RowParallelismFunction.apply(x, self.weight, self.bias, self.rank, self.world_size)

In [30]:
class TensorParallelismAttention(nn.Module):
    """Multi-head self-attention with the heads split across tensor-parallel ranks.

    The fused QKV projection is column-parallel, so each rank produces the
    queries/keys/values of ``n_heads // world_size`` heads. The output
    projection is row-parallel and combines the per-rank head outputs.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads`` or
            ``n_heads`` is not divisible by ``world_size``.
    """

    def __init__(self, n_heads, d_model, rank, world_size):
        super().__init__()

        if d_model % n_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
        if n_heads % world_size != 0:
            raise ValueError(f"n_heads ({n_heads}) must be divisible by world_size ({world_size})")

        self.n_heads = n_heads
        self.d_model = d_model
        self.rank = rank
        self.world_size = world_size

        # Width of ONE head. It depends on the number of heads, not on the
        # number of ranks; the local QKV width is 3 * n_heads_per_rank * head_dim.
        self.head_dim = d_model // n_heads
        self.n_heads_per_rank = n_heads // world_size

        self.qkv_proj = ColumnParallelism(
            in_features=d_model,
            out_features=d_model * 3,
            rank=rank,
            world_size=world_size
        )

        self.out_proj = RowParallelism(
            in_features=d_model,
            out_features=d_model,
            rank=rank,
            world_size=world_size
        )

    def forward(self, x):
        """Apply attention over the local heads to ``x`` of shape (batch, seq, d_model)."""
        batch_size, seq_len, _ = x.shape

        # (batch, seq, 3 * local_heads * head_dim) -> (3, batch, local_heads, seq, head_dim)
        qkv = self.qkv_proj(x).reshape(batch_size, seq_len, 3, self.n_heads_per_rank, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        scores = q @ k.transpose(-2, -1) / (self.head_dim ** 0.5)
        attn = nn.functional.softmax(scores, dim=-1)
        attn_output = attn @ v

        # Merge the local heads back into a single feature axis for the
        # row-parallel output projection.
        attn_output = attn_output.transpose(1, 2).contiguous().reshape(
            batch_size, seq_len, self.n_heads_per_rank * self.head_dim
        )
        output = self.out_proj(attn_output)

        return output

In [31]:
class TensorParallelismMLP(nn.Module):
    """Two-layer feed-forward network sharded across tensor-parallel ranks.

    ``fc1`` expands to ``d_ff`` with a column-parallel layer, GELU is applied,
    and ``fc2`` contracts back to ``d_model`` with a row-parallel layer.
    """

    def __init__(self, d_model, d_ff, rank, world_size):
        super().__init__()

        self.d_model, self.d_ff = d_model, d_ff
        self.rank, self.world_size = rank, world_size

        # Both projections share the same sharding coordinates.
        placement = {'rank': rank, 'world_size': world_size}

        self.fc1 = ColumnParallelism(in_features=d_model, out_features=d_ff, **placement)
        self.fc2 = RowParallelism(in_features=d_ff, out_features=d_model, **placement)
        self.activation = nn.GELU()

    def forward(self, x):
        """Return ``fc2(gelu(fc1(x)))``."""
        expanded = self.fc1(x)
        activated = self.activation(expanded)
        return self.fc2(activated)

In [33]:
class TPSPTransformerBlock(nn.Module):
    """Transformer block combining tensor parallelism (TP) and sequence parallelism (SP).

    Layer norm, dropout and the residual adds run on this rank's sequence
    shard (SP region). Attention and the MLP run on the full sequence with
    sharded weights (TP region). ``sp_to_tp_transition`` gathers the sequence
    shards before a TP region and ``tp_to_sp_transition`` slices this rank's
    shard out again afterwards.

    Args:
        n_heads, d_model, d_ff: transformer dimensions.
        rank, world_size: position in, and size of, the parallel group.
        p: dropout probability.
    """

    def __init__(self, n_heads, d_model, d_ff, rank, world_size, p):
        super().__init__()

        self.n_heads = n_heads
        self.d_model = d_model
        self.d_ff = d_ff
        self.rank = rank
        self.world_size = world_size
        self.p = p

        self.attention = TensorParallelismAttention(
            n_heads=n_heads,
            d_model=d_model,
            rank=rank,
            world_size=world_size
        )

        self.mlp = TensorParallelismMLP(
            d_model=d_model,
            d_ff=d_ff,
            rank=rank,
            world_size=world_size
        )

        self.norm1 = SequenceParallelismLayerNorm(d_model)
        self.norm2 = SequenceParallelismLayerNorm(d_model)

        # The dropout layer needs the rank to decide who generates the seed.
        self.dropout1 = SequenceParallelismDropout(p, rank)
        self.dropout2 = SequenceParallelismDropout(p, rank)

    def sp_to_tp_transition(self, tensor: torch.Tensor):
        """All-gather the per-rank sequence shards and concatenate them along dim 1.

        NOTE(review): ``dist.all_gather`` is not autograd-aware, so gradients
        do not propagate back through this transition — confirm whether a
        differentiable gather is required for training.
        """
        tensor_list = [torch.zeros_like(tensor) for _ in range(self.world_size)]
        dist.all_gather(tensor_list, tensor=tensor)

        tensor = torch.cat(tensor_list, dim=1)
        return tensor

    def tp_to_sp_transition(self, tensor: torch.Tensor):
        """Return this rank's slice of ``tensor`` along the sequence axis (dim 1)."""
        chunks = torch.chunk(tensor, chunks=self.world_size, dim=1)
        chunk = chunks[self.rank].contiguous()

        return chunk

    def forward(self, x):
        """Apply the attention and MLP sub-layers to this rank's sequence shard ``x``."""
        output = x

        residual = output
        output = self.norm1(output)
        output = self.sp_to_tp_transition(output)
        output = self.attention(output)
        output = self.tp_to_sp_transition(output)
        output = self.dropout1(output)
        # Out-of-place add: ``output`` may alias the attention result (dropout
        # is the identity in eval mode), which must not be overwritten.
        output = output + residual

        residual = output
        output = self.norm2(output)
        output = self.sp_to_tp_transition(output)
        output = self.mlp(output)
        output = self.tp_to_sp_transition(output)
        output = self.dropout2(output)
        output = output + residual

        return output

In [34]:
class RingAttention(nn.Module):
    """Self-attention over a sequence that is sharded across ranks.

    Every rank keeps the queries of its own sequence shard and passes key/value
    blocks around a ring. Attention is accumulated block by block with an
    online softmax, so the full score matrix is never materialised.

    NOTE(review): key/value blocks received from other ranks carry no autograd
    history, so gradients do not reach the remote ranks' projections — confirm
    whether this module is meant for training.
    """

    def __init__(self, n_heads, d_model, rank, world_size):
        super().__init__()

        self.n_heads = n_heads
        self.d_model = d_model
        self.rank = rank
        self.world_size = world_size

        self.head_dim = d_model // n_heads

        self.qkv_proj = nn.Linear(
            in_features=d_model,
            out_features=d_model * 3,
            bias=False
        )

    def forward(self, x):
        """Return attention output of shape (batch, seq_len_local, n_heads * head_dim)."""
        batch_size, seq_len_local, d_model = x.shape

        # (batch, seq, 3 * heads * head_dim) -> (3, batch, heads, seq, head_dim)
        qkv = self.qkv_proj(x).reshape(batch_size, seq_len_local, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        stats_shape = (batch_size, self.n_heads, seq_len_local, 1)

        # Running row maximum, softmax denominator and normalised output.
        m = torch.full(stats_shape, float('-inf'), dtype=q.dtype, device=x.device)
        d = torch.zeros(stats_shape, dtype=q.dtype, device=x.device)
        o = torch.zeros(
            (batch_size, self.n_heads, seq_len_local, self.head_dim),
            dtype=q.dtype,
            device=x.device
        )

        k_recv = k.contiguous()
        v_recv = v.contiguous()

        # world_size blocks in total: the local one plus one from every peer.
        for step in range(self.world_size):
            scores = q @ k_recv.transpose(-2, -1) / (self.head_dim ** 0.5)

            m, d, o = self._online_softmax(m, d, o, scores, v_recv)

            # After the last block there is nothing left to fetch.
            if step < self.world_size - 1:
                k_recv, v_recv = self._ring_exchange(k_recv, v_recv)

        return o.transpose(1, 2).reshape(batch_size, seq_len_local, self.n_heads * self.head_dim)

    def _online_softmax(self, m_prev, d_prev, o_prev, scores, v):
        """Fold one key/value block into the running softmax statistics.

        Args:
            m_prev: running row maximum, shape (..., q, 1).
            d_prev: running softmax denominator, shape (..., q, 1).
            o_prev: running normalised output, shape (..., q, head_dim).
            scores: scaled scores against the new key block, shape (..., q, k).
            v: the new value block, shape (..., k, head_dim).

        Returns:
            Updated ``(m, d, o)`` with the same shapes as the inputs.
        """
        m_new = torch.max(m_prev, scores.max(dim=-1, keepdim=True).values)

        # Un-normalised probabilities of the new block under the new maximum.
        p = torch.exp(scores - m_new)
        # Factor that re-expresses the old statistics under the new maximum.
        rescale = torch.exp(m_prev - m_new)

        d_new = d_prev * rescale + p.sum(dim=-1, keepdim=True)
        o_new = (o_prev * d_prev * rescale + torch.matmul(p, v)) / d_new

        return m_new, d_new, o_new

    def _ring_exchange(self, k_recv, v_recv):
        """Send the current K/V block to the next rank and receive the previous rank's block."""
        send_rank = (self.rank + 1) % self.world_size
        recv_rank = (self.rank - 1 + self.world_size) % self.world_size

        k_send = k_recv.detach().contiguous()
        v_send = v_recv.detach().contiguous()

        k_recv_buffer = torch.empty_like(k_send)
        v_recv_buffer = torch.empty_like(v_send)

        # Batch all four transfers so sends and receives are posted together;
        # separately issued send/recv pairs can deadlock on NCCL.
        ops = [
            dist.P2POp(dist.isend, k_send, send_rank),
            dist.P2POp(dist.isend, v_send, send_rank),
            dist.P2POp(dist.irecv, k_recv_buffer, recv_rank),
            dist.P2POp(dist.irecv, v_recv_buffer, recv_rank),
        ]
        for request in dist.batch_isend_irecv(ops):
            request.wait()

        return k_recv_buffer, v_recv_buffer

#### Pipeline Parallelism

In [35]:
def setup():
    """Join the NCCL process group described by the environment and select this process's GPU.

    Reads ``RANK`` and ``WORLD_SIZE`` as integers. ``LOCAL_RANK`` selects the
    GPU when present; otherwise ``RANK`` is used. ``MASTER_ADDR`` and
    ``MASTER_PORT`` default to ``localhost:12355`` when not already set.
    """
    # Environment variables are strings; the distributed API needs ints.
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    local_rank = int(os.environ.get('LOCAL_RANK', rank))

    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')

    dist.init_process_group(
        backend='nccl',
        rank=rank,
        world_size=world_size
    )

    torch.cuda.set_device(local_rank)

In [36]:
def cleanup():
    """Destroy the default process group if one exists; otherwise do nothing."""
    # Guarding makes cleanup safe to call twice or after a failed setup.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

In [37]:
class PipelineStage(nn.Module):
    """One pipeline stage: a sequence of layers applied in order.

    Args:
        layers: iterable of ``nn.Module`` objects making up this stage.
    """

    def __init__(self, layers):
        super().__init__()

        # A plain Python list would hide the layers from nn.Module, so their
        # parameters would be missing from parameters(), state_dict() and
        # .to(device). ModuleList registers them.
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        """Feed ``x`` through every layer of the stage in order."""
        out = x
        for layer in self.layers:
            out = layer(out)

        return out

In [38]:
def create_model_stages(hidden_dim, n_layers, world_size):
    """Build ``n_layers`` Linear+GELU layers and split them into ``world_size`` stages.

    Layers are assigned in contiguous runs. When ``n_layers`` is not a multiple
    of ``world_size`` the first ``n_layers % world_size`` stages receive one
    extra layer, so every layer is placed in exactly one stage.

    Returns:
        A list of ``world_size`` ``PipelineStage`` objects, in pipeline order.
    """
    all_layers = [
        torch.nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU()
        )
        for _ in range(n_layers)
    ]

    base, extra = divmod(n_layers, world_size)

    stages = []
    start_idx = 0
    for idx in range(world_size):
        # Spread the remainder over the leading stages instead of dropping it.
        end_idx = start_idx + base + (1 if idx < extra else 0)
        stages.append(PipelineStage(all_layers[start_idx : end_idx]))
        start_idx = end_idx

    return stages

In [None]:
class OneFOneBPipeline:
    """One-forward-one-backward (1F1B) schedule for a single pipeline rank.

    Each rank runs ``warmup_states`` forward passes, then alternates one forward
    and one backward pass for ``active_states`` steps, and finally drains the
    remaining backward passes. Activations travel to ``rank + 1`` and gradients
    to ``rank - 1`` using blocking ``dist.send`` / ``dist.recv``.

    NOTE(review): in the steady state a rank sends its activation forward and
    only then receives the gradient from the same neighbour. Whether two
    neighbours can stall on each other depends on the backend's buffering;
    confirm on the target backend or switch to batched isend/irecv.

    Args:
        stage: Module holding the layers owned by this rank.
        rank: Index of this pipeline stage.
        world_size: Number of pipeline stages.
        n_micro_batches: Micro-batches processed per ``train_step``.
        micro_batch_size: Number of samples per micro-batch.
        input_shape: Per-sample shape of the activation received from the
            previous stage (without the batch dimension).
        dtype: dtype of the exchanged activations and gradients.
    """

    def __init__(
        self,
        stage: nn.Module,
        rank: int,
        world_size: int,
        n_micro_batches: int,
        micro_batch_size: int,
        input_shape: tuple,
        dtype=torch.float32
    ):
        # No trailing comma here: the stage must stay a callable module.
        self.stage = stage
        self.rank = rank
        self.world_size = world_size
        self.n_micro_batches = n_micro_batches
        self.micro_batch_size = micro_batch_size
        self.input_shape = input_shape
        self.dtype = dtype

        self.next_rank = self.rank + 1 if self.rank < self.world_size - 1 else None
        self.prev_rank = self.rank - 1 if self.rank > 0 else None

        # Receive buffers follow the stage's parameters when it has any;
        # None keeps torch's default device.
        first_param = next(stage.parameters(), None) if isinstance(stage, nn.Module) else None
        self.device = first_param.device if first_param is not None else None

        # Every rank must run exactly n_micro_batches forward passes, so the
        # steady-state length is whatever the warm-up did not cover.
        self.warmup_states = max(0, min(world_size - rank - 1, n_micro_batches))
        self.active_states = n_micro_batches - self.warmup_states

    def send(self, tensor: torch.Tensor, dst: int):
        """Send ``tensor`` to rank ``dst``; no-op when ``dst`` is None."""
        if dst is not None:
            dist.send(tensor.contiguous(), dst=dst)

    def recv(self, tensor_buffer: torch.Tensor, src: int):
        """Fill ``tensor_buffer`` from rank ``src``; return None when ``src`` is None."""
        if src is not None:
            dist.recv(tensor_buffer, src=src)
            return tensor_buffer

        return None

    def forward(self, input: torch.Tensor):
        """Run the stage on one micro-batch.

        Returns:
            ``(output, saved_input)`` where ``saved_input`` is a detached leaf
            that requires grad, so its ``.grad`` is populated by ``backward``.
        """
        # Detach first: the caller's tensor is left alone and the result is
        # always a leaf, which is what makes .grad available afterwards.
        input = input.detach().requires_grad_(True)
        output = self.stage(input)

        return output, input

    def backward(self, grad_input: torch.Tensor, output: torch.Tensor, saved_input: torch.Tensor):
        """Backpropagate ``grad_input`` through ``output``; return the input gradient."""
        output.backward(grad_input)
        return saved_input.grad

    def _forward_micro_batch(self, batch, index, outputs, saved_inputs):
        """Fetch micro-batch ``index``, run it forward, queue it, send it on."""
        if self.rank == 0:
            start_idx = index * self.micro_batch_size
            end_idx = (index + 1) * self.micro_batch_size
            micro_batch = batch[start_idx:end_idx]
        else:
            micro_batch = torch.zeros(
                self.micro_batch_size, *self.input_shape,
                dtype=self.dtype, device=self.device
            )
            dist.recv(micro_batch, src=self.prev_rank)

        output, saved_input = self.forward(micro_batch)

        outputs.append(output)
        saved_inputs.append(saved_input)

        self.send(output, dst=self.next_rank)

    def _backward_micro_batch(self, output, saved_input):
        """Obtain the output gradient, backpropagate, send the input gradient back."""
        if self.rank == self.world_size - 1:
            # The last stage seeds backprop with ones (no loss is applied here).
            grad_input = torch.ones_like(output, dtype=self.dtype)
        else:
            grad_input = torch.zeros_like(output, dtype=self.dtype)
            dist.recv(grad_input, src=self.next_rank)

        grad_output = self.backward(grad_input, output, saved_input)
        self.send(grad_output, dst=self.prev_rank)

    def train_step(self, batch):
        """Run forward and backward for all micro-batches of ``batch``.

        ``batch`` is only read on rank 0, where it is sliced along dim 0 into
        micro-batches of ``micro_batch_size``; other ranks receive activations
        from their predecessor. Parameter gradients accumulate in ``stage``.
        """
        # FIFO queues of in-flight micro-batches awaiting their backward pass.
        saved_inputs = []
        outputs = []

        # Warm-up: forward only, to fill the pipeline.
        for i in range(self.warmup_states):
            self._forward_micro_batch(batch, i, outputs, saved_inputs)

        # Steady state: one forward, then one backward for the oldest entry.
        for i in range(self.active_states):
            self._forward_micro_batch(batch, i + self.warmup_states, outputs, saved_inputs)
            self._backward_micro_batch(outputs.pop(0), saved_inputs.pop(0))

        # Cool-down: drain the backward passes left over from warm-up.
        while outputs:
            self._backward_micro_batch(outputs.pop(0), saved_inputs.pop(0))

#### LoRA

In [40]:
class LoRA(nn.Module):
    """Low-rank adapter computing ``(x @ b @ a) * (alpha / rank)``.

    ``b`` (in_features x rank) starts at zero and ``a`` (rank x out_features)
    is Kaiming-normal initialised, so the adapter output is exactly zero at
    construction while gradients can still flow into ``b``.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        rank: Inner dimension of the low-rank factorisation.
        alpha: Numerator of the output scaling factor.
    """

    def __init__(self, in_features, out_features, rank, alpha):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.alpha = alpha
        self.scaling_factor = alpha / rank

        self.b = nn.Parameter(torch.zeros(in_features, rank))
        self.a = nn.Parameter(torch.zeros(rank, out_features))

        # Without this both factors stay zero, every gradient is zero and the
        # adapter can never learn.
        self._reset_parameters()

    def _reset_parameters(self):
        """Zero ``b`` and draw ``a`` from a Kaiming-normal distribution."""
        nn.init.zeros_(self.b)
        nn.init.kaiming_normal_(self.a)

    def forward(self, x):
        """Return the scaled low-rank update for ``x``."""
        output = torch.matmul(x, self.b)
        output = torch.matmul(output, self.a)

        return output * self.scaling_factor

In [41]:
class LinearWithLoRA(nn.Module):
    """Frozen ``nn.Linear`` augmented with a trainable LoRA adapter.

    Args:
        linear: Base layer; its weight and bias are frozen in place.
        rank: Rank of the adapter.
        alpha: Scaling numerator of the adapter.
    """

    def __init__(self, linear: nn.Linear, rank, alpha):
        super().__init__()

        self.linear = linear
        self.rank = rank
        self.alpha = alpha

        self.linear.weight.requires_grad = False
        if self.linear.bias is not None:
            self.linear.bias.requires_grad = False

        self.lora = LoRA(
            linear.in_features,
            linear.out_features,
            rank,
            alpha
        )

        self.merged = False

    def _delta_weight(self):
        """Return the adapter update in the base weight's (out, in) layout."""
        delta_w = torch.matmul(self.lora.b, self.lora.a) * self.lora.scaling_factor
        # b @ a is (in_features, out_features); nn.Linear stores the transpose.
        return delta_w.t().to(self.linear.weight.dtype)

    def forward(self, x):
        """Apply the base layer, plus the adapter unless it is already merged."""
        if self.merged:
            return self.linear(x)
        return self.linear(x) + self.lora(x)

    def merge(self):
        """Fold the adapter into the base weight (idempotent)."""
        if not self.merged:
            with torch.no_grad():
                self.linear.weight += self._delta_weight()
            self.merged = True

    def unmerge(self):
        """Subtract a previously merged adapter from the base weight (idempotent)."""
        if self.merged:
            with torch.no_grad():
                self.linear.weight -= self._delta_weight()
            self.merged = False

#### Decoding

In [None]:
def greedy_decoding(model, input_ids, eos_token_id, max_token_length):
    """Autoregressively extend ``input_ids`` with the arg-max token at each step.

    Args:
        model: Callable mapping token ids ``(batch, seq)`` to logits
            ``(batch, seq, vocab)``.
        input_ids: Prompt token ids; not modified.
        eos_token_id: Generation stops once every sequence in the batch emits
            this token in the same step.
        max_token_length: Maximum number of tokens to append.

    Returns:
        Tensor with the prompt followed by the generated tokens.
    """
    generated = input_ids.clone()

    for _ in range(max_token_length):
        with torch.no_grad():
            # Feed everything generated so far, not just the original prompt.
            logits = model(generated)
            next_logits = logits[:, -1, :]

        # Softmax is monotonic, so the arg-max of the logits is sufficient.
        next_token = torch.argmax(next_logits, dim=-1, keepdim=True)

        generated = torch.concat([generated, next_token], dim=-1)

        if (next_token == eos_token_id).all():
            break

    return generated

In [2]:
def temperature_sampling(model, input_ids, temperature, max_length, eos_token_id):
    """Sample tokens autoregressively from the temperature-scaled distribution.

    Args:
        model: Callable mapping token ids ``(batch, seq)`` to logits
            ``(batch, seq, vocab)``.
        input_ids: Prompt token ids; not modified.
        temperature: Positive divisor applied to the logits before softmax.
        max_length: Maximum number of tokens to append.
        eos_token_id: Generation stops once every sequence in the batch emits
            this token in the same step.

    Returns:
        Tensor with the prompt followed by the sampled tokens.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be > 0")

    generated = input_ids.clone()

    for _ in range(max_length):
        with torch.no_grad():
            # Condition on the full sequence produced so far.
            logits = model(generated)
            next_logits = logits[:, -1, :] / temperature

        probs = torch.softmax(next_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        generated = torch.concat([generated, next_token], dim=-1)

        if (next_token == eos_token_id).all():
            break

    return generated

In [3]:
def topk_sampling(model, input_ids, temperature, k, eos_token_id, max_length):
    """Sample autoregressively from the ``k`` most likely tokens at each step.

    Args:
        model: Callable mapping token ids ``(batch, seq)`` to logits
            ``(batch, seq, vocab)``.
        input_ids: Prompt token ids; not modified.
        temperature: Positive divisor applied to the logits.
        k: Number of highest-scoring tokens kept; clamped to the vocabulary size.
        eos_token_id: Generation stops once every sequence in the batch emits
            this token in the same step.
        max_length: Maximum number of tokens to append.

    Returns:
        Tensor with the prompt followed by the sampled tokens.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be > 0")

    generated = input_ids.clone()

    for _ in range(max_length):
        with torch.no_grad():
            # Condition on the full sequence produced so far.
            logits = model(generated)
            next_logits = logits[:, -1, :] / temperature

        # torch.topk raises when k exceeds the size of the reduced dimension.
        top_k = min(k, next_logits.shape[-1])
        topk_values, topk_indices = torch.topk(next_logits, k=top_k, dim=-1)

        # Everything outside the top-k gets -inf, i.e. probability zero.
        filtered_logits = torch.full_like(next_logits, fill_value=float('-inf'))
        filtered_logits.scatter_(dim=-1, index=topk_indices, src=topk_values)

        probs = torch.softmax(filtered_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        generated = torch.concat([generated, next_token], dim=-1)

        if (next_token == eos_token_id).all():
            break

    return generated


In [None]:
def top_p_sampling(model, input_ids, p, temperature, eos_token_id, max_length):
    """Nucleus sampling: draw from the smallest token set whose mass reaches ``p``.

    Args:
        model: Callable mapping token ids ``(batch, seq)`` to logits
            ``(batch, seq, vocab)``.
        input_ids: Prompt token ids; not modified.
        p: Cumulative probability threshold of the nucleus.
        temperature: Positive divisor applied to the logits.
        eos_token_id: Generation stops once every sequence in the batch emits
            this token in the same step.
        max_length: Maximum number of tokens to append.

    Returns:
        Tensor with the prompt followed by the sampled tokens.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be > 0")

    generated = input_ids.clone()

    for _ in range(max_length):
        with torch.no_grad():
            # Condition on the full sequence produced so far.
            logits = model(generated)
            next_logits = logits[:, -1, :] / temperature

        sorted_logits, sorted_indices = torch.sort(next_logits, dim=-1, descending=True)

        sorted_probs = torch.softmax(sorted_logits, dim=-1)
        prob_cumsum = torch.cumsum(sorted_probs, dim=-1)

        # Shift the mask right by one so the token that crosses the threshold
        # stays in the nucleus; the most likely token is always kept.
        over_threshold = prob_cumsum > p
        mask = torch.zeros_like(over_threshold)
        mask[..., 1:] = over_threshold[..., :-1]
        sorted_logits = sorted_logits.masked_fill(mask, float('-inf'))

        # Undo the sort so indices refer to vocabulary ids again.
        filtered_logits = torch.full_like(sorted_logits, fill_value=float('-inf'))
        filtered_logits.scatter_(dim=-1, index=sorted_indices, src=sorted_logits)

        probs = torch.softmax(filtered_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        generated = torch.concat([generated, next_token], dim=-1)

        if (next_token == eos_token_id).all():
            break

    return generated



#### Expert Parallelism

In [5]:
import torch.nn as nn

In [6]:
class Expert(nn.Module):
    """Single feed-forward expert: Linear -> GELU -> Linear.

    Args:
        d_model: Width of the expert's input and output.
        d_ff: Width of the hidden layer.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()

        self.d_model, self.d_ff = d_model, d_ff

        # Built in the same order as they are applied.
        up_projection = nn.Linear(d_model, d_ff)
        activation = nn.GELU()
        down_projection = nn.Linear(d_ff, d_model)

        self.mlp = nn.Sequential(up_projection, activation, down_projection)

    def forward(self, x):
        """Run ``x`` through the expert MLP."""
        transformed = self.mlp(x)
        return transformed

In [7]:
class TopKRouter(nn.Module):
    """Token router selecting the ``top_k`` experts per token.

    Args:
        d_model: Width of the token representations.
        n_experts: Number of experts to route between.
        top_k: Experts selected per token.
        capacity_factor: Stored on the instance; not used by the router itself.
    """

    def __init__(self, d_model, n_experts, top_k, capacity_factor):
        super().__init__()

        self.d_model = d_model
        self.n_experts = n_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor
        # Misspelled name kept so existing readers of the attribute still work.
        self.capacity_factory = capacity_factor

        self.router = nn.Linear(d_model, n_experts, bias=False)

    def forward(self, input):
        """Route a ``(batch, seq, d_model)`` tensor.

        Returns:
            expert_weights: ``(batch * seq, top_k)`` weights renormalised to
                sum to one per token.
            expert_indices: ``(batch * seq, top_k)`` selected expert ids.
            load_balancing_loss: Scalar tensor, differentiable with respect to
                the router weights.
        """
        batch_size, seq_len, d_model = input.shape

        n_tokens = batch_size * seq_len
        input = input.reshape(n_tokens, d_model)

        router_logits = self.router(input)
        router_probs = torch.softmax(input=router_logits, dim=-1)

        expert_weights, expert_indices = torch.topk(router_probs, self.top_k, dim=-1)
        expert_weights = expert_weights / expert_weights.sum(dim=-1, keepdim=True)

        load_balancing_loss = self._compute_load_balancing_loss(router_probs)

        return expert_weights, expert_indices, load_balancing_loss

    def _compute_load_balancing_loss(self, router_probs):
        """Return ``n_experts * sum_i(f_i * p_i)`` as a scalar tensor.

        ``f_i`` is the fraction of tokens whose arg-max expert is ``i`` and
        ``p_i`` is the mean router probability of expert ``i``.
        """
        num_tokens = router_probs.shape[0]

        expert_assignment = torch.argmax(router_probs, dim=-1)
        # minlength keeps one slot per expert even when trailing experts get
        # no tokens, so f_i always lines up with p_i.
        expert_count = torch.bincount(
            expert_assignment, minlength=self.n_experts
        ).to(router_probs.dtype)

        f_i = expert_count / num_tokens
        p_i = router_probs.mean(dim=0)

        # Returned as a tensor (not .item()) so the loss can be backpropagated.
        return self.n_experts * (f_i * p_i).sum()

In [8]:
class DistributedMoELayer(nn.Module):
    """Mixture-of-experts layer whose experts are sharded across ranks.

    Each rank owns ``n_experts // world_size`` experts plus a full router.

    NOTE(review): ``forward`` relies on ``_prepare_all_to_all``,
    ``_all_to_all_scatter``, ``_process_local_experts``, ``_all_to_all_gather``
    and ``_combine_outputs``. None of them is defined in this class as shown
    here, so ``forward`` raises AttributeError until they are provided.

    Args:
        rank: Index of this process.
        world_size: Number of processes the experts are sharded over.
        d_model: Width of the token representations.
        d_ff: Hidden width of every expert.
        n_experts: Total number of experts across all ranks.
        top_k: Experts selected per token.
        capacity_factor: Multiplier on the even per-expert token share.
    """

    def __init__(self, rank, world_size, d_model, d_ff, n_experts, top_k, capacity_factor):
        super().__init__()

        self.rank = rank
        self.world_size = world_size

        self.d_model = d_model
        self.d_ff = d_ff
        self.n_experts = n_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor

        self.n_experts_per_rank = n_experts // world_size
        self.experts = nn.ModuleList([
            Expert(d_model, d_ff) for _ in range(self.n_experts_per_rank)
        ])

        self.router = TopKRouter(d_model, n_experts, top_k, capacity_factor)

    def forward(self, x):
        """Route, dispatch, process and recombine a ``(batch, seq, d_model)`` input.

        Returns:
            ``(output, load_balancing_loss)`` with ``output`` reshaped to
            ``(batch, seq, d_model)``.
        """
        batch_size, seq_len, d_model = x.shape
        num_tokens = batch_size * seq_len

        expert_weights, expert_indices, load_balancing_loss = self.router(x)

        # Per-expert token budget: the even share scaled by capacity_factor.
        capacity = int((num_tokens / self.n_experts) * self.capacity_factor)

        dispatch_data, combine_weights = self._prepare_all_to_all(
            x, expert_indices, expert_weights, capacity
        )

        received_data = self._all_to_all_scatter(dispatch_data)

        expert_output = self._process_local_experts(received_data)

        gathered_output = self._all_to_all_gather(expert_output)

        output = self._combine_outputs(gathered_output, combine_weights)

        output = output.reshape(batch_size, seq_len, d_model)

        return output, load_balancing_loss

In [9]:
class MoETransformerBlock(nn.Module):
    """
    Complete pre-norm transformer block with MoE instead of dense FFN.

    Args:
        hidden_size: Model width.
        num_heads: Number of attention heads.
        num_experts: Total number of experts in the MoE layer.
        intermediate_size: Hidden width of every expert.
        top_k: Experts selected per token.
        capacity_factor: Expert capacity multiplier.
        rank: Process index forwarded to the MoE layer (default 0).
        world_size: Number of expert-parallel processes (default 1).
    """
    def __init__(self, hidden_size, num_heads, num_experts, intermediate_size,
                 top_k=2, capacity_factor=1.25, rank=0, world_size=1):
        super().__init__()
        self.hidden_size = hidden_size

        # Layer norms
        self.ln1 = nn.LayerNorm(hidden_size)
        self.ln2 = nn.LayerNorm(hidden_size)

        # Attention (standard multi-head attention)
        self.attention = nn.MultiheadAttention(
            hidden_size,
            num_heads,
            batch_first=True
        )

        # MoE layer (replaces dense FFN). Keyword arguments keep every value
        # bound to the parameter it is meant for.
        self.moe = DistributedMoELayer(
            rank=rank,
            world_size=world_size,
            d_model=hidden_size,
            d_ff=intermediate_size,
            n_experts=num_experts,
            top_k=top_k,
            capacity_factor=capacity_factor
        )

    def forward(self, x, attn_mask=None):
        """
        Args:
            x: (batch, seq, hidden)
            attn_mask: Optional attention mask

        Returns:
            output: (batch, seq, hidden)
            aux_loss: Load balancing loss
        """
        # Attention block
        normed = self.ln1(x)
        attn_output, _ = self.attention(
            normed, normed, normed, attn_mask=attn_mask, need_weights=False
        )
        x = x + attn_output

        # MoE block
        normed = self.ln2(x)
        moe_output, aux_loss = self.moe(normed)
        x = x + moe_output

        return x, aux_loss

In [None]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a ``(batch, seq, d_model)`` input.

    Even feature indices hold ``sin(pos / 10000^(i / d_model))`` and odd indices
    the matching ``cos``. Works for both even and odd ``d_model``.

    Args:
        max_seq_len: Longest sequence length supported.
        d_model: Feature width of the input.
    """

    def __init__(self, max_seq_len, d_model):
        super().__init__()

        self.max_seq_len = max_seq_len
        self.d_model = d_model

        pe = torch.zeros(max_seq_len, d_model)
        positions = torch.arange(0, max_seq_len, dtype=torch.float32).unsqueeze(1)

        div_term = torch.pow(10000.0, - torch.arange(0, d_model, 2).float() / d_model)
        angles = positions * div_term

        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        pe = pe.unsqueeze(0)

        # Buffer: follows .to()/.cuda() and is saved, but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encoding of its first ``seq_len`` positions."""
        _, seq_len, _ = x.shape
        x = x + self.pe[:, :seq_len, :]
        return x

## Fundamentals

In [1]:
class LinearRegression:
    """Linear regression trained with full-batch gradient descent on MSE.

    Args:
        in_features: Number of input features.
        out_features: Number of regression targets.
    """

    def __init__(self, in_features, out_features):
        self.w = torch.randn(in_features, out_features)
        self.b = torch.randn(out_features)

    def forward(self, x, y_true, lr, iterations):
        """Fit the model in place.

        Args:
            x: Inputs of shape ``(n_samples, in_features)``.
            y_true: Targets with as many elements as the predictions; reshaped
                to ``(n_samples, out_features)``.
            lr: Learning rate.
            iterations: Number of gradient steps.

        Returns:
            The loss measured before the last update, or None when
            ``iterations`` is 0.
        """
        loss = None
        for _ in range(iterations):
            y_pred = x @ self.w + self.b
            # Prevents an (n,) target from broadcasting against (n, 1) output.
            target = y_true.reshape(y_pred.shape)
            loss = self.mean_squared_error(y_pred, target)
            dw, db = self.gradients(y_pred, target, x)
            self.backward(dw, db, lr)
        return loss

    def mean_squared_error(self, y_pred, y_true):
        """Mean of the squared residuals over all elements."""
        return torch.mean((y_pred - y_true) ** 2)

    def gradients(self, y_pred, y_true, x):
        """Gradients of ``mean_squared_error`` with respect to ``w`` and ``b``.

        Returns:
            ``(dw, db)`` with the same shapes as ``self.w`` and ``self.b``.
        """
        diff = (y_pred - y_true)
        # The loss averages over every element, hence the 2 / numel factor.
        scale = 2.0 / diff.numel()

        dw = scale * (x.t() @ diff)
        db = scale * diff.sum(dim=0)

        return dw, db

    def backward(self, dw, db, lr):
        """Take one gradient-descent step."""
        self.w = self.w - lr * dw
        self.b = self.b - lr * db

In [2]:
class LogisticRegression:
    """Logistic regression trained with full-batch gradient descent on BCE.

    Args:
        in_features: Number of input features.
        n_labels: Number of independent binary outputs.
    """

    def __init__(self, in_features, n_labels):
        self.w = torch.randn(in_features, n_labels)
        self.b = torch.zeros(n_labels)

    def binary_cross_entropy(self, y_pred, y_true, x=None):
        """Mean binary cross-entropy between probabilities and 0/1 targets.

        ``x`` is accepted for backward compatibility and ignored.
        """
        # Keep probabilities strictly inside (0, 1) so both logs stay finite.
        eps = torch.finfo(y_pred.dtype).eps
        y_pred = torch.clamp(y_pred, eps, 1 - eps)
        return - torch.mean(y_true * torch.log(y_pred) + (1 - y_true) * torch.log(1 - y_pred))

    def compute_gradients(self, y_true, y_pred, x):
        """Gradients of the mean BCE with respect to ``w`` and ``b``.

        ``y_pred`` holds sigmoid probabilities; the gradient with respect to
        the logits is then simply ``y_pred - y_true``.
        """
        diff = (y_pred - y_true)
        scale = 1.0 / diff.numel()

        dw = scale * (x.t() @ diff)
        db = scale * diff.sum(dim=0)

        return dw, db

    def backward(self, dw, db, learning_rate):
        """Take one gradient-descent step."""
        self.w = self.w - learning_rate * dw
        self.b = self.b - learning_rate * db

    def forward(self, x, y_true, learning_rate, iterations):
        """Fit the model in place.

        Args:
            x: Inputs of shape ``(n_samples, in_features)``.
            y_true: Float 0/1 targets of shape ``(n_samples, n_labels)``.
            learning_rate: Step size.
            iterations: Number of gradient steps.

        Returns:
            The loss measured before the last update, or None when
            ``iterations`` is 0.
        """
        loss = None
        for _ in range(iterations):
            # Squash the logits into probabilities before the loss.
            y_pred = torch.sigmoid(x @ self.w + self.b)
            loss = self.binary_cross_entropy(y_pred, y_true, x)
            dw, db = self.compute_gradients(y_true, y_pred, x)
            self.backward(dw, db, learning_rate)
        return loss

In [3]:
class KMeans:
    """Lloyd's k-means on a 2-D float tensor ``(n_samples, n_features)``.

    Args:
        n_clusters: Number of centroids.
        max_iterations: Upper bound on assignment/update rounds.
    """

    def __init__(self, n_clusters, max_iterations):
        self.n_clusters = n_clusters
        self.max_iterations = max_iterations

        self.centroids = None

    def fit(self, X):
        """Cluster ``X`` and store ``centroids`` and ``labels_``.

        Returns:
            ``self``, to allow chaining.

        Raises:
            ValueError: If ``X`` has fewer rows than ``n_clusters``.
        """
        n_samples = X.shape[0]
        if self.n_clusters > n_samples:
            raise ValueError("n_clusters cannot exceed the number of samples")

        # Sampling without replacement avoids starting two centroids on the
        # same row.
        idx = torch.randperm(n_samples, device=X.device)[: self.n_clusters]
        self.centroids = X[idx].clone()

        for _ in range(self.max_iterations):
            labels = self.predict(X)

            new_centroids = self.centroids.clone()
            for k in range(self.n_clusters):
                members = X[labels == k]
                # An empty cluster keeps its centroid; the mean of zero rows
                # would be NaN.
                if members.shape[0] > 0:
                    new_centroids[k] = members.mean(dim=0)

            converged = torch.allclose(new_centroids, self.centroids)
            self.centroids = new_centroids
            if converged:
                break

        # Recomputed so labels_ always matches the final centroids, including
        # when max_iterations is 0.
        self.labels_ = self.predict(X)
        return self

    def predict(self, x):
        """Return the index of the nearest centroid for every row of ``x``."""
        distances = torch.cdist(
            x, self.centroids, compute_mode='donot_use_mm_for_euclid_dist'
        )
        return torch.argmin(distances, dim=1)

In [4]:
class KNN:
    """k-nearest-neighbours classifier using Euclidean distance and majority vote.

    Args:
        k: Number of neighbours consulted per query (default 3).
    """

    def __init__(self, k=3):
        self.k = k

    def fit(self, X, y):
        """Store the training points ``X`` (n, d) and their labels ``y`` (n,)."""
        self.X = X
        self.y = y

    def predict(self, X):
        """Return the most frequent label among the nearest training points.

        Args:
            X: Query points of shape ``(n_queries, d)``.

        Returns:
            1-D tensor with one predicted label per query row (empty when
            ``X`` has no rows).
        """
        if X.shape[0] == 0:
            return self.y.new_empty((0,))

        # torch.topk raises if asked for more neighbours than there are
        # training points.
        k = min(self.k, self.X.shape[0])

        predictions = []
        for query in X:
            distances = torch.norm(self.X - query, dim=1)
            _, nearest = torch.topk(distances, k=k, largest=False)
            predictions.append(torch.mode(self.y[nearest]).values)

        return torch.stack(predictions)

In [None]:
def pca(data, n_components):
    """Principal component analysis on standardised data.

    Args:
        data: Float tensor of shape ``(n_samples, n_features)``.
        n_components: Number of leading components to keep.

    Returns:
        reduced_data: ``(n_samples, n_components)`` projection.
        principle_components: ``(n_features, n_components)`` eigenvectors.
        sorted_eigen_values: All eigenvalues in descending order.
    """
    mean = torch.mean(data, dim=0)
    std = torch.std(data, dim=0)

    # A constant column has zero std; dividing by 1 instead leaves it at zero
    # rather than producing NaN.
    safe_std = torch.where(std > 0, std, torch.ones_like(std))

    normalized = (data - mean) / safe_std
    n_samples = data.shape[0]

    cov_matrix = torch.matmul(normalized.t(), normalized) / (n_samples - 1)

    # eigh returns eigenvalues in ascending order.
    eigen_values, eigen_vectors = torch.linalg.eigh(cov_matrix)

    # Tensors do not support negative-step slicing; sort descending directly.
    sorted_index = torch.argsort(eigen_values, descending=True)
    sorted_eigen_values = eigen_values[sorted_index]
    sorted_eigen_vectors = eigen_vectors[:, sorted_index]

    principle_components = sorted_eigen_vectors[:, : n_components]

    reduced_data = torch.matmul(normalized, principle_components)

    return reduced_data, principle_components, sorted_eigen_values

In [1]:
class SVM:
    """Linear soft-margin SVM trained with per-sample sub-gradient descent.

    Objective: ``lambda_p * ||W||^2 + mean(max(0, 1 - y * (X @ W + b)))`` with
    labels mapped to {-1, +1}.

    Args:
        lambda_p: L2 regularisation strength.
    """

    def __init__(self, lambda_p):
        self.lambda_p = lambda_p
        self.W, self.b = None, None

    @staticmethod
    def _signed_labels(y):
        """Map labels to {-1, +1}: positive values become +1, the rest -1."""
        return torch.where(y > 0, 1.0, -1.0)

    def fit(self, X, y, n_epochs, learning_rate):
        """Train on ``X`` (n_samples, n_features) with binary labels ``y``."""
        y = self._signed_labels(y)

        self.W = torch.randn(X.shape[1])
        # The bias is a single scalar shared by all samples.
        self.b = torch.zeros(())
        X = X.to(self.W.dtype)

        n_samples = X.shape[0]
        for epoch in range(n_epochs):
            for i in range(n_samples):
                # Functional margin: dot product, not elementwise product.
                margin = y[i] * (torch.dot(self.W, X[i]) + self.b)

                if margin >= 1:
                    # Outside the margin: only the regulariser contributes.
                    self.W = self.W - learning_rate * (2 * self.lambda_p * self.W)
                else:
                    self.W = self.W - learning_rate * (2 * self.lambda_p * self.W - y[i] * X[i])
                    self.b = self.b - learning_rate * (-y[i])

    def hinge_loss(self, X, y):
        """Mean hinge loss over the samples; ``y`` may be 0/1 or -1/+1."""
        y = self._signed_labels(y)
        margins = 1 - y * (X.to(self.W.dtype) @ self.W + self.b)
        return torch.clamp(margins, min=0).mean()

    def total_loss(self, X, y):
        """Regularisation term plus mean hinge loss, as a scalar."""
        return self.lambda_p * torch.sum(torch.square(self.W)) + self.hinge_loss(X, y)

    def predict(self, X):
        """Return 1 where the decision value is non-negative, otherwise 0."""
        y = X.to(self.W.dtype) @ self.W + self.b
        return torch.where(y >= 0, 1, 0)

#### Loss Functions

In [5]:
def mse_loss(y_true, y_pred):
    """Mean squared error averaged over every element."""
    residual = y_true - y_pred
    return torch.square(residual).mean()

In [6]:
def mae_loss(y_true, y_pred):
    """Mean absolute error averaged over every element."""
    residual = y_true - y_pred
    return residual.abs().mean()

In [8]:
def bce_loss(y_true, y_pred, eps=1e-8):
    """Mean binary cross-entropy between 0/1 targets and probabilities.

    ``eps`` is added inside both logarithms to avoid ``log(0)``.
    """
    positive_term = y_true * torch.log(y_pred + eps)
    negative_term = (1 - y_true) * torch.log(1 - y_pred + eps)
    return - torch.mean(positive_term + negative_term)

In [9]:
def cross_entropy_loss(y_true, logits, eps=1e-8):
    """Categorical cross-entropy from raw logits.

    Args:
        y_true: Target distribution (e.g. one-hot) with classes on the last dim.
        logits: Unnormalised scores with the same shape as ``y_true``.
        eps: Unused; kept for backward compatibility. ``log_softmax`` is
            numerically stable without an additive epsilon.

    Returns:
        Scalar: per-sample cross-entropy (summed over the class dimension)
        averaged over all remaining dimensions.
    """
    log_probs = torch.log_softmax(logits, dim=-1)
    # Sum over classes first; a plain mean would also divide by n_classes.
    return - torch.mean(torch.sum(y_true * log_probs, dim=-1))

#### Activations

In [10]:
def sigmoid(x):
    """Elementwise logistic function ``1 / (1 + exp(-x))``."""
    denominator = 1 + torch.exp(-x)
    return 1 / denominator

In [11]:
def tanh(x):
    """Elementwise hyperbolic tangent, stable for large ``|x|``.

    The textbook form ``(e^x - e^-x) / (e^x + e^-x)`` evaluates to inf / inf
    (NaN) once ``exp`` overflows. Here each branch only ever exponentiates a
    non-positive number, so both values and gradients stay finite.
    """
    # m = exp(-2|x|) - 1 lies in (-1, 0]; tanh(|x|) = -m / (m + 2).
    m_pos = torch.expm1(-2 * torch.clamp(x, min=0))
    m_neg = torch.expm1(2 * torch.clamp(x, max=0))
    return torch.where(x >= 0, -m_pos / (m_pos + 2), m_neg / (m_neg + 2))

In [12]:
def relu(x):
    """Elementwise rectified linear unit, ``max(x, 0)``.

    ``torch.max(x, 0)`` would be read as a reduction over dimension 0 and
    return a ``(values, indices)`` pair, so the clamp is used instead.
    """
    return torch.clamp(x, min=0)

In [13]:
def leaky_relu(x, alpha=0.1):
    """Leaky ReLU computed as the elementwise maximum of ``x`` and ``alpha * x``."""
    scaled = alpha * x
    return torch.maximum(x, scaled)

In [14]:
class PReLU:
    """Parametric ReLU with a single learnable negative slope ``alpha`` (0.25 at start)."""

    def __init__(self):
        self.alpha = torch.tensor(0.25, requires_grad=True)

    def forward(self, x):
        """Return ``x`` where it is positive and ``alpha * x`` elsewhere.

        Selecting by the sign of ``x`` stays correct for any value ``alpha``
        takes during training; ``max(x, alpha * x)`` is only right while
        ``alpha <= 1``.
        """
        return torch.where(x > 0, x, self.alpha * x)

    def __call__(self, x):
        """Allow the instance to be called like a function."""
        return self.forward(x)

In [15]:
def elu(x, alpha=1.0):
    """Exponential linear unit: ``x`` if ``x > 0`` else ``alpha * (exp(x) - 1)``.

    Args:
        x: Input tensor.
        alpha: Scale of the negative branch (default 1.0).
    """
    # torch.where evaluates both branches. Clamping keeps exp() from
    # overflowing on large positive inputs, which would otherwise leak NaN
    # into the gradient of the unselected branch.
    negative_branch = alpha * torch.expm1(torch.clamp(x, max=0))
    return torch.where(x > 0, x, negative_branch)

In [None]:
def softmax(x, dim=-1):
    """Numerically stable softmax along ``dim`` (default: last dimension).

    Args:
        x: Input tensor with at least one dimension.
        dim: Dimension over which the outputs sum to one.
    """
    # Subtracting the per-slice maximum leaves the result unchanged but keeps
    # exp() from overflowing.
    shifted = x - torch.max(x, dim=dim, keepdim=True).values
    exp_x = torch.exp(shifted)
    return exp_x / torch.sum(exp_x, dim=dim, keepdim=True)

In [17]:
def log_softmax(x, dim=-1):
    """Log of the softmax along ``dim`` (default: last dimension).

    Uses ``log(softmax(x)) = x - logsumexp(x)``, which avoids forming the
    probabilities explicitly.
    """
    return x - torch.logsumexp(x, dim=dim, keepdim=True)

#### Optimizers

In [18]:
# Stochastic Gradient Descent
class SGD:
    """Plain gradient descent: ``p <- p - lr * p.grad``.

    Args:
        params: Iterable of tensors to optimise (e.g. ``model.parameters()``).
        lr: Learning rate.
    """

    def __init__(self, params, lr=1e-3):
        # Materialise once: a generator would be empty after the first step.
        self.params = list(params)
        self.lr = lr

    def step(self):
        """Update every parameter that currently has a gradient."""
        with torch.no_grad():
            for p in self.params:
                if p.grad is None:
                    continue
                p.sub_(self.lr * p.grad)

In [19]:
# Stochastic Gradient Descent with Momentum
class SGDMomentum:
    """Gradient descent with momentum.

    Update rule: ``v <- momentum * v - lr * grad`` followed by ``p <- p + v``.

    Args:
        params: Iterable of tensors to optimise.
        lr: Learning rate.
        momentum: Fraction of the previous velocity that is retained.
    """

    def __init__(self, params, lr=1e-3, momentum=0.9):
        # Materialise once: the iterable is walked here and again in step().
        self.params = list(params)
        self.lr = lr
        self.momentum = momentum

        self.v = [torch.zeros_like(p) for p in self.params]

    def step(self):
        """Update every parameter that currently has a gradient."""
        with torch.no_grad():
            for i, p in enumerate(self.params):
                if p.grad is None:
                    continue
                self.v[i] = self.momentum * self.v[i] - self.lr * p.grad
                p.add_(self.v[i])

In [None]:
class Adagrad:
    """Adagrad: per-element step sizes scaled by accumulated squared gradients.

    Args:
        params: Iterable of tensors to optimise.
        lr: Learning rate.
        eps: Added to the denominator to avoid division by zero.
    """

    def __init__(self, params, lr=0.01, eps=1e-8):
        # Materialise once: the iterable is walked here and again in step().
        self.params = list(params)
        self.lr = lr
        self.eps = eps
        self.cache = [torch.zeros_like(p) for p in self.params]

    def step(self):
        """Update every parameter that currently has a gradient."""
        with torch.no_grad():
            for i, p in enumerate(self.params):
                if p.grad is None:
                    continue
                self.cache[i] += p.grad ** 2
                p.sub_(self.lr * p.grad / (torch.sqrt(self.cache[i]) + self.eps))

In [None]:
class RMSProp:
    """RMSProp: step sizes scaled by a running average of squared gradients.

    Args:
        params: Iterable of tensors to optimise.
        decay_rate: Weight of the previous running average.
        lr: Learning rate.
        eps: Added to the denominator to avoid division by zero.
    """

    def __init__(self, params, decay_rate, lr=0.01, eps=1e-8):
        # Materialise once: the iterable is walked here and again in step().
        self.params = list(params)
        self.decay_rate = decay_rate
        self.lr = lr
        self.eps = eps
        self.cache = [torch.zeros_like(p) for p in self.params]

    def step(self):
        """Update every parameter that currently has a gradient."""
        with torch.no_grad():
            for i, p in enumerate(self.params):
                if p.grad is None:
                    continue
                self.cache[i] = self.cache[i] * self.decay_rate + (1 - self.decay_rate) * p.grad ** 2
                p.sub_(self.lr * p.grad / (torch.sqrt(self.cache[i]) + self.eps))

In [22]:
class Adam:
    """Adam optimiser with bias-corrected first and second moment estimates.

    Args:
        params: Iterable of tensors to optimise.
        lr: Learning rate.
        beta1: Decay rate of the first-moment (mean) estimate.
        beta2: Decay rate of the second-moment (uncentred variance) estimate.
        eps: Added to the denominator to avoid division by zero.
    """

    def __init__(self, params, lr=0.0001, beta1=0.9, beta2=0.999, eps=1e-8):
        # Materialise once: the iterable is walked here and again in step().
        self.params = list(params)
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps

        self.m = [torch.zeros_like(p) for p in self.params]
        self.v = [torch.zeros_like(p) for p in self.params]

        # Global step counter used by the bias correction.
        self.t = 0

    def step(self):
        """Update every parameter that currently has a gradient."""
        self.t += 1

        with torch.no_grad():
            for i, p in enumerate(self.params):
                if p.grad is None:
                    continue

                self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * p.grad
                self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * p.grad ** 2

                # Both moments start at zero; dividing by (1 - beta^t) removes
                # that bias during the first steps.
                m_hat = self.m[i] / (1 - self.beta1 ** self.t)
                v_hat = self.v[i] / (1 - self.beta2 ** self.t)

                p.sub_(self.lr * m_hat / (torch.sqrt(v_hat) + self.eps))

## Vision Transformer

In [4]:
import torch
import torch.nn as nn

In [5]:
class PatchEmbedding(nn.Module):
    """Split images into patches, embed them, prepend a class token, add positions.

    Config keys read: ``image_height``, ``image_width`` (the legacy misspelling
    ``image_widht`` is still accepted), ``image_channels``, ``d_model``,
    ``patch_height``, ``patch_width``.

    Output shape is ``(batch, n_patches + 1, d_model)`` with the class token at
    sequence index 0.
    """

    def __init__(self, config):
        super().__init__()

        self.image_height = config['image_height']
        # Prefer the correctly spelled key; fall back to the historical typo.
        self.image_width = config['image_width'] if 'image_width' in config else config['image_widht']
        self.image_channels = config['image_channels']
        self.d_model = config['d_model']

        self.patch_height = config['patch_height']
        self.patch_width = config['patch_width']

        # Patches per column and per row of the image.
        self.grid_height = self.image_height // self.patch_height
        self.grid_width = self.image_width // self.patch_width

        n_patches = self.grid_height * self.grid_width
        patch_dim = self.patch_height * self.patch_width * self.image_channels

        self.patch_embedding = nn.Sequential(
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, self.d_model),
            # The second norm sees the projected features, so it is d_model wide.
            nn.LayerNorm(self.d_model)
        )

        self.position_embedding = nn.Embedding(n_patches + 1, self.d_model)
        self.cls_token = nn.Parameter(torch.randn(self.d_model))

        self.n_patches = n_patches

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Images laid out ``(batch, height, width, channels)``. A 4-D
                input whose dim 1 equals ``image_channels`` while its last dim
                does not is treated as ``(batch, channels, height, width)``
                and permuted first.

        Returns:
            Tensor of shape ``(batch, n_patches + 1, d_model)``.
        """
        batch_size = x.shape[0]
        channels = self.image_channels

        if x.dim() == 4 and x.shape[1] == channels and x.shape[-1] != channels:
            x = x.permute(0, 2, 3, 1)

        # (B, H, W, C) -> (B, gh, ph, gw, pw, C) -> (B, gh, gw, ph, pw, C)
        x = x.reshape(
            batch_size, self.grid_height, self.patch_height,
            self.grid_width, self.patch_width, channels
        )
        x = x.permute(0, 1, 3, 2, 4, 5)
        x = x.reshape(
            batch_size, self.n_patches,
            self.patch_height * self.patch_width * channels
        )

        out = self.patch_embedding(x)
        cls_token = self.cls_token.view(1, 1, -1).expand(batch_size, -1, -1)

        # Class token goes first so that index 0 is the classification token.
        out = torch.cat([cls_token, out], dim=1)

        # Embeddings are looked up by integer position, not by feature values.
        positions = torch.arange(out.shape[1], device=out.device)
        out = out + self.position_embedding(positions)

        return out


In [7]:
class Attention(nn.Module):
    """Multi-head self-attention with a fused query/key/value projection.

    Config keys read: ``n_heads`` and ``d_model``.
    """

    def __init__(self, config):
        super().__init__()

        n_heads, d_model = config['n_heads'], config['d_model']

        self.n_heads = n_heads
        self.d_model = d_model
        self.head_dim = d_model // n_heads

        # One matmul yields queries, keys and values together.
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Attend over the sequence dimension of a ``(batch, seq, d_model)`` input."""
        bsz, n_tokens, _ = x.shape

        fused = self.qkv_proj(x).reshape(bsz, n_tokens, 3, self.n_heads, self.head_dim)
        # (3, batch, heads, seq, head_dim), then split along the leading axis.
        queries, keys, values = fused.permute(2, 0, 3, 1, 4).unbind(dim=0)

        scale = self.head_dim ** 0.5
        weights = (torch.matmul(queries, keys.transpose(-2, -1)) / scale).softmax(dim=-1)

        # Merge the heads back into the model dimension.
        context = torch.matmul(weights, values)
        context = context.transpose(1, 2).reshape(bsz, n_tokens, self.d_model)

        return self.out_proj(context)

In [8]:
class TransformerLayer(nn.Module):
    """Pre-norm transformer encoder layer: attention and MLP, each with a residual.

    Config keys read: ``n_heads``, ``d_model``, ``d_ff``, ``dropout_p`` (plus
    whatever ``Attention`` reads from the same config).
    """

    def __init__(self, config):
        super().__init__()

        self.n_heads = config['n_heads']
        self.d_model = config['d_model']
        self.d_ff = config['d_ff']

        self.attention = Attention(config)

        self.mlp = nn.Sequential(
            nn.Linear(self.d_model, self.d_ff),
            nn.GELU(),
            nn.Linear(self.d_ff, self.d_model)
        )

        self.norm1 = nn.LayerNorm(self.d_model)
        self.norm2 = nn.LayerNorm(self.d_model)

        dropout_p = config['dropout_p']

        self.dropout1 = nn.Dropout(dropout_p)
        self.dropout2 = nn.Dropout(dropout_p)

    def forward(self, x):
        """Apply the layer to a ``(batch, seq, d_model)`` tensor; ``x`` is not modified."""
        # Attention sub-block: normalise the layer input, attend, add back.
        residual = x
        out = self.norm1(x)
        out = self.attention(out)
        out = self.dropout1(out)
        out = out + residual

        # MLP sub-block.
        residual = out
        out = self.norm2(out)
        out = self.mlp(out)
        out = self.dropout2(out)
        out = out + residual

        return out

In [10]:
class VIT(nn.Module):
    """Vision Transformer classifier.

    Pipeline: patch embedding -> ``n_layers`` transformer layers -> LayerNorm
    and linear head applied to the token at sequence index 0.

    Config keys read here: ``n_layers``, ``d_model``, ``n_classes``; the same
    config is passed on to ``PatchEmbedding`` and ``TransformerLayer``.
    """

    def __init__(self, config):
        super().__init__()

        self.patch_embedding_block = PatchEmbedding(config)

        n_layers = config['n_layers']
        self.layers = torch.nn.ModuleList([
            TransformerLayer(config) for _ in range(n_layers)
        ])
        self.n_layers = n_layers

        self.norm = nn.LayerNorm(config['d_model'])
        self.linear = nn.Linear(config['d_model'], config['n_classes'])

    def forward(self, x):
        """Return class logits of shape ``(batch, n_classes)``."""
        out = self.patch_embedding_block(x)

        # Iterate over the modules, not over the integer layer count.
        for layer in self.layers:
            out = layer(out)

        # Only the token at index 0 feeds the head; LayerNorm and Linear act
        # per token, so selecting it first gives the same result with less work.
        cls_out = out[:, 0]
        cls_out = self.norm(cls_out)

        return self.linear(cls_out)

## DDPM

In [11]:
class LinearNoiseSchedule:
    """DDPM noise schedule with linearly spaced betas.

    Precomputes ``alpha_t = 1 - beta_t``, the cumulative product
    ``alpha_bar_t`` and the square roots used by the forward (noising) and
    reverse (denoising) steps.

    Args:
        n_steps: Number of diffusion steps.
        beta_start: First beta of the schedule.
        beta_end: Last beta of the schedule.
    """

    def __init__(self, n_steps, beta_start, beta_end):
        self.n_steps = n_steps
        self.beta_start = beta_start
        self.beta_end = beta_end

        self.betas = torch.linspace(beta_start, beta_end, steps=self.n_steps)

        self.alpha = 1. - self.betas
        self.alpha_cum_prod = torch.cumprod(self.alpha, dim=0)

        self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
        self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1. - self.alpha_cum_prod)

    def noise(self, original, noise, t):
        """Forward process: ``sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise``.

        Args:
            original: Clean sample(s).
            noise: Noise with the same shape as ``original``.
            t: Integer step, or a 1-D index tensor with one step per batch row.
        """
        device = original.device
        sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(device)[t]
        sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(device)[t]

        # Append trailing singleton dims so a per-sample coefficient of shape
        # (batch,) broadcasts along the batch axis, not the last one.
        for _ in range(len(original.shape) - 1):
            sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1)
            sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1)

        return sqrt_alpha_cum_prod * original + sqrt_one_minus_alpha_cum_prod * noise

    def denoise(self, xt, noise_pred, t):
        """One reverse step from ``x_t`` to ``x_{t-1}`` for a scalar step ``t``.

        Returns:
            ``(x_prev, x0)``: the sample for step ``t - 1`` (the posterior mean
            when ``t == 0``, otherwise mean plus scaled Gaussian noise) and the
            estimate of the clean sample clamped to ``[0, 1]``.
        """
        device = xt.device
        alpha_t = self.alpha.to(device)[t]
        alpha_bar_t = self.alpha_cum_prod.to(device)[t]
        sqrt_alpha_bar_t = self.sqrt_alpha_cum_prod.to(device)[t]
        sqrt_one_minus_alpha_bar_t = self.sqrt_one_minus_alpha_cum_prod.to(device)[t]

        x0 = (xt - sqrt_one_minus_alpha_bar_t * noise_pred) / sqrt_alpha_bar_t
        x0 = torch.clamp(x0, 0.0, 1.0)

        # Posterior mean: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
        xt_1 = xt - ((1. - alpha_t) / sqrt_one_minus_alpha_bar_t) * noise_pred
        xt_1 = xt_1 / torch.sqrt(alpha_t)

        if t == 0:
            return xt_1, x0
        else:
            alpha_bar_prev = self.alpha_cum_prod.to(device)[t - 1]
            variance = (1. - alpha_t) * ((1. - alpha_bar_prev) / (1. - alpha_bar_t))
            std = variance ** 0.5

            # randn_like matches xt's device and dtype.
            z = torch.randn_like(xt)

            return xt_1 + std * z, x0