In [None]:
!pip install triton
!pip install torch
!pip install torchvision

In [None]:
import torch
from torch.optim.optimizer import Optimizer
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

import triton
import triton.language as tl

In [None]:
# Check that the installed torch build, its CUDA toolkit, and the GPU runtime line up.
# Only the last expression of a notebook cell is displayed, so the three values are
# collected into named variables and echoed together as a single tuple.
cuda_available = torch.cuda.is_available()
cuda_version = torch.version.cuda
torch_version = torch.__version__
cuda_available, cuda_version, torch_version

In [None]:
"""
Tiger optimizer for PyTorch, with the per-parameter update fused into an efficient Triton CUDA kernel. By @juvi21 (tiger.py).
"""


# Fused Tiger update over flat buffers. Each program instance handles the elements
# [pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE), masking off indices >= n_elements.
# Per element:
#     p       <- p * (1 - lr * weight_decay)          (decoupled weight decay)
#     exp_avg <- beta * exp_avg + (1 - beta) * grad   (momentum EMA)
#     p       <- p - lr * sign(exp_avg)
# `p` and `exp_avg` are overwritten in place; `grad` is only read.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 128}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 1024}, num_warps=8),
    ],
    key=["n_elements"],
    # While autotuning, Triton launches every config several times to time it.
    # Because this kernel mutates `p` and `exp_avg` in place, those benchmark
    # launches would each apply an extra optimizer step unless the buffers are
    # restored afterwards.
    restore_value=["p_ptr", "exp_avg_ptr"],
)
@triton.jit
def tiger_kernel(
    p_ptr,
    grad_ptr,
    exp_avg_ptr,
    lr,
    weight_decay,
    beta,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    p = tl.load(p_ptr + offsets, mask=mask)
    grad = tl.load(grad_ptr + offsets, mask=mask)
    exp_avg = tl.load(exp_avg_ptr + offsets, mask=mask)

    # Decoupled weight decay.
    p = p * (1 - lr * weight_decay)

    # Momentum: exponential moving average of the gradient.
    exp_avg = beta * exp_avg + (1 - beta) * grad

    # Emulate `p -= lr * torch.sign(exp_avg)`. sign(0) == 0, so elements whose
    # momentum is exactly zero must not move; a bare `where(exp_avg > 0, -lr, lr)`
    # would push them up by `lr` on every step.
    signed_lr = tl.where(exp_avg > 0, -lr, lr)
    p = p + tl.where(exp_avg != 0, signed_lr, 0.0)

    tl.store(p_ptr + offsets, p, mask=mask)
    tl.store(exp_avg_ptr + offsets, exp_avg, mask=mask)


def tiger_step(p, grad, exp_avg, lr, weight_decay, beta):
    """Launch ``tiger_kernel`` to apply one Tiger update to ``p`` in place.

    Args:
        p: parameter tensor; overwritten with the updated values.
        grad: gradient of ``p``; only read.
        exp_avg: momentum buffer for ``p``; overwritten with the new moving average.
        lr: learning rate forwarded to the kernel.
        weight_decay: decoupled weight-decay coefficient forwarded to the kernel.
        beta: momentum coefficient forwarded to the kernel.

    Raises:
        ValueError: if any tensor is not a CUDA tensor, or if the three tensors do
            not share the same shape and strides. The kernel addresses all three
            buffers with identical flat offsets, so mismatched layouts would mix
            up elements.

    NOTE(review): the kernel also assumes each buffer occupies ``numel()``
    consecutive elements starting at ``data_ptr()`` (true for contiguous tensors).
    That is not checked here.
    """
    if not all(t.is_cuda for t in (p, grad, exp_avg)):
        raise ValueError("tiger_step: p, grad and exp_avg must all be CUDA tensors")
    if not (p.shape == grad.shape == exp_avg.shape):
        raise ValueError(
            f"tiger_step: shape mismatch (p={tuple(p.shape)}, "
            f"grad={tuple(grad.shape)}, exp_avg={tuple(exp_avg.shape)})"
        )
    if not (p.stride() == grad.stride() == exp_avg.stride()):
        raise ValueError(
            f"tiger_step: stride mismatch (p={p.stride()}, "
            f"grad={grad.stride()}, exp_avg={exp_avg.stride()})"
        )

    n_elements = p.numel()
    if n_elements == 0:
        # Nothing to update: skip the launch (and the autotuner) entirely.
        return

    def grid(meta):
        # One program per BLOCK_SIZE chunk; BLOCK_SIZE is chosen by the autotuner.
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    tiger_kernel[grid](p, grad, exp_avg, lr, weight_decay, beta, n_elements)


class Tiger(Optimizer):
    """Tiger optimizer whose per-parameter update runs in a fused Triton kernel.

    For every parameter ``p`` with a gradient ``g``, each call to :meth:`step` does:

        p       <- p * (1 - lr * weight_decay)
        exp_avg <- beta * exp_avg + (1 - beta) * g
        p       <- p - lr * sign(exp_avg)

    The update is delegated to ``tiger_step``, so parameters and gradients must be
    dense CUDA tensors.

    Args:
        params: iterable of parameters, or of dicts defining parameter groups.
        lr: learning rate; must be >= 0.
        beta: momentum coefficient; must satisfy 0 <= beta < 1.
        weight_decay: decoupled weight-decay coefficient; must be >= 0.

    Raises:
        ValueError: if ``lr``, ``beta`` or ``weight_decay`` is out of range.
    """

    def __init__(self, params, lr=1e-3, beta=0.965, weight_decay=0.01):
        if lr < 0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0 <= beta < 1:
            raise ValueError(f"Invalid beta parameter: {beta}")
        if weight_decay < 0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = dict(lr=lr, beta=beta, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step over all parameter groups.

        Args:
            closure: optional callable that re-evaluates the model and returns the
                loss. It is run with gradient tracking enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # Hyper-parameters are per group; read them once per group.
            lr = group["lr"]
            beta = group["beta"]
            weight_decay = group["weight_decay"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("Tiger does not support sparse gradients")

                state = self.state[p]
                if not state:
                    # Allocate the momentum on p's own device (not a hard-coded
                    # "cuda", which means the current device) and with p's memory
                    # layout, so the kernel can walk p, grad and exp_avg with the
                    # same flat offsets.
                    state["exp_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

                tiger_step(p.data, grad.data, state["exp_avg"], lr, weight_decay, beta)

        return loss


In [None]:
class SimpleNN(nn.Module):
    """Two-layer MLP classifier for 28x28 single-channel images (used here on MNIST).

    Architecture: flatten to 784 features -> Linear(784, 500) -> ReLU -> Linear(500, 10),
    followed by log_softmax, so ``forward`` returns log-probabilities.
    """

    def __init__(self):
        super().__init__()
        # Layer creation order (fc1, then fc2) is kept so seeded initialisation
        # and state_dict keys stay the same.
        self.fc1 = nn.Linear(28 * 28, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        """Return log-probabilities of shape (x.numel() // 784, 10)."""
        flat = x.view(-1, 28 * 28)
        hidden = F.relu(self.fc1(flat))
        scores = self.fc2(hidden)
        return F.log_softmax(scores, dim=1)

# Reproducibility: torch.manual_seed seeds the CPU and CUDA generators, which drive
# SimpleNN's weight initialisation and the DataLoader's shuffling.
SEED = 42
torch.manual_seed(SEED)

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

model = SimpleNN().cuda()
# SimpleNN.forward already ends in log_softmax, so NLLLoss is the matching criterion.
# CrossEntropyLoss would apply a second, redundant log_softmax (same loss value,
# wasted work).
criterion = nn.NLLLoss()
optimizer = Tiger(model.parameters(), lr=0.001)

num_epochs = 5
model.train()
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.cuda(), labels.cuda()

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

# download=True is a no-op when the files are already present, but keeps this part
# working on its own.
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1000, shuffle=False)

model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.cuda(), labels.cuda()
        outputs = model(images)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the model on the {total} test images: {100 * correct / total}%')