In [16]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

class Net(nn.Module):
    """Small CNN classifier for 28x28 single-channel images (MNIST in this notebook).

    Architecture: two unpadded 3x3 convolutions with ReLU, a 2x2 max-pool,
    dropout, then two fully connected layers. After the convolutions and the
    pool a 28x28 input becomes 64 x 12 x 12 = 9216 features, which is the input
    width of ``fc1``. ``forward`` returns log-probabilities of shape ``(N, 10)``.
    """

    def __init__(self):
        super().__init__()
        # Layers are registered in the same order as the original definition,
        # so seeded weight initialisation and state_dict keys are unchanged.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a batch of images to per-class log-probabilities."""
        # Feature extractor: conv -> ReLU -> conv -> ReLU -> 2x2 max-pool.
        h = F.relu(self.conv1(x))
        h = F.max_pool2d(F.relu(self.conv2(h)), 2)
        # Classifier head on the flattened feature map.
        h = torch.flatten(self.dropout1(h), 1)
        h = self.dropout2(F.relu(self.fc1(h)))
        return F.log_softmax(self.fc2(h), dim=1)

# --- Configuration -----------------------------------------------------------
SEED = 42
BATCH_SIZE = 32
DATA_ROOT = './data'
# Per-channel mean / std passed to transforms.Normalize for MNIST.
MNIST_MEAN, MNIST_STD = (0.1307,), (0.3081,)

# Seed before the model is built so its weight initialisation is reproducible.
torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Data --------------------------------------------------------------------
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD),
])
dataset1 = datasets.MNIST(DATA_ROOT, train=True, download=True, transform=transform)  # training split
dataset2 = datasets.MNIST(DATA_ROOT, train=False, transform=transform)  # test split
train_loader = torch.utils.data.DataLoader(dataset1, batch_size=BATCH_SIZE)
test_loader = torch.utils.data.DataLoader(dataset2, batch_size=BATCH_SIZE)

# --- Model & optimiser -------------------------------------------------------
model = Net().to(device)
# NOTE(review): lr=3e-4 is far below Adadelta's usual lr of 1.0 (3e-4 is a
# typical Adam value); confirm this is intentional, otherwise training will be very slow.
optimizer = optim.Adadelta(model.parameters(), lr=3e-4)

In [38]:
import torch
import warnings
# torch._six was removed in PyTorch 2.0; it only re-exported math.inf, so
# importing from math is equivalent and works on every PyTorch version.
from math import inf
from typing import Union, Iterable

# Type alias for AGC's `parameters` argument: a single tensor or an iterable of tensors.
_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]

@torch.no_grad()
def AGC(parameters: _tensor_or_tensors, clip: float = 1e-3, eps: float = 1e-3) -> None:
    """Adaptively clip the gradients of an iterable of parameters, in place.

    For every parameter ``p`` that has a gradient, each unit of ``p.grad`` (as
    grouped by ``unitwise_norm``) whose norm exceeds
    ``clip * max(unitwise_norm(p), eps)`` is rescaled down to that norm; all
    other units are left unchanged.

    Call this after ``loss.backward()`` and before ``optimizer.step()``.
    Parameters whose ``.grad`` is ``None`` are skipped, so calling it before
    ``backward()`` silently does nothing.

    Runs under ``torch.no_grad()`` so the norm computations on ``p`` do not
    build an autograd graph.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor whose gradients will be clipped in place.
        clip: (float) Maximum allowed ratio of update norm to parameter norm.
        eps: (float) epsilon term to prevent clipping of zero-initialized params.

    Returns:
        None. Gradients are modified in place.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    for p in parameters:
        if p.grad is None:
            continue  # parameter received no gradient; nothing to clip
        # Floor the parameter norm at eps so zero-initialised weights still get
        # a non-zero clipping threshold.
        eps_tensor = torch.tensor(eps, device=p.device)
        max_norm = clip * torch.max(unitwise_norm(p), eps_tensor)
        g_norm = unitwise_norm(p.grad)
        p.grad.copy_(my_clip(g_norm, max_norm, p.grad))
    
    
def my_clip(g_norm, max_norm, grad):
    """Rescale the units of ``grad`` whose norm is not below ``max_norm``.

    Args:
        g_norm: unit-wise norms of ``grad`` (broadcastable against ``grad``).
        max_norm: unit-wise maximum allowed norms (broadcastable likewise).
        grad: the gradient tensor to clip.

    Returns:
        A new tensor. Units with ``g_norm < max_norm`` are copied from ``grad``
        unchanged; every other unit is multiplied by ``max_norm / g_norm``.
    """
    # Floor on the divisor, separate from AGC's own eps; it exists only so the
    # division below can never divide by zero.
    floor = torch.tensor(1e-6, device=g_norm.device)
    scale = max_norm / torch.max(g_norm, floor)
    within_limit = g_norm < max_norm
    return torch.where(within_limit, grad, grad * scale)
        
    
def unitwise_norm(x: torch.Tensor) -> torch.Tensor:
    """Compute the Euclidean norm of each output unit of ``x`` separately.

    PyTorch stores weights output-dimension first: ``nn.Linear`` weights are
    ``(out_features, in_features)`` and ``nn.ConvNd`` weights are
    ``(out_channels, in_channels / groups, *kernel_size)``. The per-unit norm
    therefore reduces over every dimension *except* dim 0. (Reference AGC code
    written for JAX uses ``IO`` / ``HWIO`` layouts with the output axis last,
    which is why it reduces over the leading axes instead; copying that choice
    for linear layers would produce per-input norms in PyTorch.)

    Args:
        x: a parameter or gradient tensor of any rank.

    Returns:
        For 0-D / 1-D inputs (scalars, biases, gains): a 0-D tensor holding the
        norm of the whole tensor. For ``ndim >= 2``: a tensor of shape
        ``(x.shape[0], 1, ..., 1)`` with the same rank as ``x``, so it
        broadcasts against ``x`` (covers Linear, Conv1d/2d/3d weights).
    """
    if x.ndim <= 1:  # Scalars and vectors: one norm for the whole tensor.
        return compute_norm(x, 0, False)
    # Every other dim belongs to the same output unit.
    return compute_norm(x, tuple(range(1, x.ndim)), True)


def compute_norm(x, dim, keepdims):
    """Return the Euclidean (L2) norm of ``x`` taken over the axis/axes ``dim``.

    Args:
        x: input tensor.
        dim: an int, or a list/tuple of ints, naming the axes to reduce.
        keepdims: if True, reduced axes are kept as size-1 dimensions.
    """
    squared_sum = x.pow(2).sum(dim=dim, keepdim=keepdims)
    return squared_sum.pow(0.5)

# Sanity check: one optimisation step on a single training batch with AGC applied.
# NOTE: re-running this cell takes another optimiser step, so `model` changes each run.
data, target = next(iter(train_loader))
data, target = data.to(device), target.to(device)

optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)

# AGC clips p.grad in place and skips parameters whose .grad is None, so it must
# run *after* backward(); called before it, every gradient is None and it is a no-op.
loss.backward()
AGC(model.parameters(), 0.1, 0.1)
optimizer.step()

In [None]:
class AGC(optim.Optimizer):
    """Optimizer wrapper: unit-wise Adaptive Gradient Clipping, then an inner step.

    On every ``step`` each gradient unit (as grouped by ``unitwise_norm``) whose
    norm exceeds ``clipping * max(unitwise_norm(p), eps)`` is rescaled down to
    that norm. The wrapped optimizer then performs the actual update.

    NOTE(review): this class reuses the name ``AGC`` of the clipping function
    defined in an earlier cell; running this cell shadows that function for the
    rest of the kernel session.

    Args:
        optim: the wrapped optimizer. AGC clips the gradients of the parameters
            in its ``param_groups`` as they exist at construction time. Groups
            added to it later are not seen.
        clipping: maximum allowed ratio of unit-wise gradient norm to unit-wise
            parameter norm.
        eps: floor on the parameter norm so zero-initialised parameters can
            still receive updates.

    Raises:
        ValueError: if ``clipping`` or ``eps`` is negative.
    """

    def __init__(self, optim, clipping=1e-2, eps=1e-3):
        if clipping < 0.0:
            raise ValueError(f"Invalid clipping value: {clipping}")
        if eps < 0.0:
            raise ValueError(f"Invalid eps value: {eps}")
        # Mirror the wrapped optimizer's groups with *fresh* dicts holding only
        # AGC's hyper-parameters. Passing optim.param_groups directly would let
        # Optimizer.__init__ mutate the inner optimizer's dicts. Its
        # setdefault('eps', ...) would also silently keep the inner optimizer's
        # own 'eps' (e.g. Adadelta's) instead of AGC's.
        groups = [{'params': group['params']} for group in optim.param_groups]
        super().__init__(groups, dict(clipping=clipping, eps=eps))
        self.optim = optim
        self.clipping = clipping
        self.eps = eps

    @torch.no_grad()
    def step(self, closure=None):
        """Clip every gradient unit-wise, then take one step of the wrapped optimizer.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss. It is evaluated here, *before* clipping, and is not
                forwarded to the wrapped optimizer. A closure that calls
                ``backward()`` again would otherwise overwrite the clipped
                gradients. Optimizers that require a closure for their own
                re-evaluations (e.g. ``torch.optim.LBFGS``) are therefore not
                supported.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            clipping, eps = group['clipping'], group['eps']
            for p in group['params']:
                if p.grad is None:
                    continue  # parameter received no gradient; nothing to clip
                # Floor the parameter norm at eps so zero-initialised weights
                # still get a non-zero threshold.
                max_norm = unitwise_norm(p).clamp(min=eps) * clipping
                grad_norm = unitwise_norm(p.grad)
                trigger = grad_norm > max_norm
                # The 1e-6 floor only protects the division from zero.
                clipped_grad = p.grad * (max_norm / grad_norm.clamp(min=1e-6))
                p.grad.copy_(torch.where(trigger, clipped_grad, p.grad))

        self.optim.step()
        return loss