In [None]:
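References:

- [torchao optimizers (`torchao/optim`)](https://github.com/pytorch/ao/tree/main/torchao/optim)
- [PyTorch blog: PyTorch Native Architecture Optimization (torchao)](https://pytorch.org/blog/pytorch-native-architecture-optimization/)
- [arXiv:2502.10940](https://arxiv.org/pdf/2502.10940)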

# Project Index

[Custom Model Notebook](../../../notebooks/custom_model.ipynb)  
[Training Notebook](../../../notebooks/train.ipynb)  
[Project Config Notebook](../../../notebooks/project_config.ipynb)  
[Forgather Notebook](../../../notebooks/forgather.ipynb)  

In [None]:
import forgather.nb.notebooks as nb

# Render the project index for the chosen config template.
# NOTE(review): this previews "my_sgd.yaml", while the Project loaded in the next cell
# uses "my_adafactor2.yaml" — confirm which template is intended here.
nb.display_project_index(
    config_template="my_sgd.yaml",
    show_pp_config=True,
    show_generated_code=False,
)

In [None]:
from forgather.ml.trainer import Trainer
from forgather.ml.trainer_types import TrainingArguments
from forgather.ml.training_script import TrainingScript
from pprint import pp
import torch
from forgather.project import Project
import forgather.nb.notebooks as nb

# Load the project from the `my_adafactor2.yaml` config template.
PROJECT_CONFIG = "my_adafactor2.yaml"
proj = Project(PROJECT_CONFIG)

In [None]:
# List every model parameter together with its dtype.
# NOTE(review): `outputs` is only created by the `proj([...])` cell further down; on a
# fresh kernel this cell raises NameError — run it after that cell (or move it below).
model_params = outputs["model"].named_parameters()
for param_name, tensor in model_params:
    print(param_name, tensor.dtype)

In [None]:
# Build the listed project targets with a fixed seed; each target is then read
# back out of `outputs` by name.
torch.manual_seed(42)
targets = [
    "meta",
    "distributed_env",
    "train_dataset",
    "eval_dataset",
    "trainer_args",
    "model",
    "data_collator",
    "trainer_callbacks",
    "tokenizer",
    "optimizer",
    "trainer",
    "main",
]
outputs = proj(targets)

# Optional overrides merged on top of the configured trainer arguments
# (right-hand side of `|` wins on key collisions).
trainer_arg_overrides = dict(
    # num_train_epochs=1,
    # learning_rate=1.0e-3,
    # lr_scheduler_type=None,
)
training_args = TrainingArguments(**(outputs["trainer_args"] | trainer_arg_overrides))
pp(training_args)

In [None]:
model = outputs["model"]  # direct handle on the instantiated model for interactive inspection

In [None]:
# Display the model's repr for a quick structural overview (presumably an nn.Module,
# since `.named_parameters()` is called on it elsewhere — TODO confirm).
model

In [None]:
# Launch the configured training entry point.
training_main = outputs["main"]
training_main.run()

In [None]:
import torch
from pickle import dump

# Start recording the CUDA caching allocator's full alloc/free history so a later
# cell can save it with torch.cuda.memory._snapshot().
# NOTE(review): training is launched by an earlier cell, so its allocation history is
# not recorded — run this cell before the training cell to profile training.
history_mode = "all"
torch.cuda.memory._record_memory_history(enabled=history_mode)

In [None]:
from pickle import dump

# Save a snapshot of the CUDA allocator state/history to a pickle file
# (viewable at https://pytorch.org/memory_viz), then stop recording.
# The try/finally ensures recording is switched off even if taking or writing the
# snapshot fails; otherwise history would keep accumulating in the background.
SNAPSHOT_PATH = "snapshot.pickle"
try:
    s = torch.cuda.memory._snapshot()
    with open(SNAPSHOT_PATH, "wb") as f:
        dump(s, f)
finally:
    # tell CUDA to stop recording memory allocations now
    torch.cuda.memory._record_memory_history(enabled=None)

In [None]:
# Generate a standalone training script for this project.
# NOTE(review): "0" is presumably the device selection — confirm against forgather.nb.notebooks.
script_devices = "0"
nb.generate_trainingscript(proj, script_devices)

In [None]:
# Show the TensorBoard command for this project's logs.
tb_options = {"local_host": False}
nb.display_tb_command(proj, **tb_options)

In [None]:
# Regenerate the training script (same call as the earlier generate cell).
# NOTE(review): this looks like a duplicate cell — consider removing one of them.
script_devices = "0"
nb.generate_trainingscript(proj, script_devices)

In [None]:
# Snap
class ProjectorOptim(Optimizer):
    """
    Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
    Regularization](https://arxiv.org/abs/1711.05101).

    Parameters:
        params (`Iterable[nn.parameter.Parameter]`):
            Iterable of parameters to optimize or dictionaries defining parameter groups.
        lr (`float`, *optional*, defaults to 0.001):
            The learning rate to use.
        betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`):
            Adam's betas parameters (b1, b2).
        eps (`float`, *optional*, defaults to 1e-06):
            Adam's epsilon for numerical stability.
        weight_decay (`float`, *optional*, defaults to 0.0):
            Decoupled weight decay to apply.
        correct_bias (`bool`, *optional*, defaults to `True`):
            Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
    """

    def __init__(
        self,
        params: Iterable[nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        correct_bias: bool = True,
    ):
        require_version("torch>=1.5.0")  # add_ with alpha
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
        defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias}
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure: Callable = None):
        """
        Performs a single optimization step.

        Arguments:
            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                if "proj_args" not in group:
                    self.update_adamw(group, p)
                else:
                    self.update_projection(group, p)

        return loss

    def update_adamw(self, group, p):
        grad = p.grad
        if grad.is_sparse:
            raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

        state = self.state[p]
        
        if "step" not in state:
            state["step"] = 0

        # State initialization
        if "exp_avg" not in state:
            # Exponential moving average of gradient values
            state["exp_avg"] = torch.zeros_like(grad)
            # Exponential moving average of squared gradient values
            state["exp_avg_sq"] = torch.zeros_like(grad)

        exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
        beta1, beta2 = group["betas"]

        state["step"] += 1

        # Decay the first and second moment running average coefficient
        # In-place operations to update the averages at the same time
        exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
        denom = exp_avg_sq.sqrt().add_(group["eps"])

        step_size = group["lr"]
        if group["correct_bias"]:  # No bias correction for Bert
            bias_correction1 = 1.0 - beta1 ** state["step"]
            bias_correction2 = 1.0 - beta2 ** state["step"]
            step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

        # compute norm gradient
        norm_grad = exp_avg / denom
        
        p.add_(norm_grad, alpha=-step_size)

        # Just adding the square of the weights to the loss function is *not*
        # the correct way of using L2 regularization/weight decay with Adam,
        # since that will interact with the m and v parameters in strange ways.
        #
        # Instead we want to decay the weights in a manner that doesn't interact
        # with the m/v parameters. This is equivalent to adding the square
        # of the weights to the loss with plain (non-momentum) SGD.
        # Add weight decay at the end (fixed version)
        if group["weight_decay"] > 0.0:
            p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))

    def update_projection(self, group, p):
        grad = p.grad
        state = self.state[p]
        
        if "step" not in state:
            state["step"] = 0

        beta1, beta2 = group["betas"]
        
        # Projection
        if "projector" not in state:
            projector = state["projector"] = SubspaceProjector(
                grad,
                **group["proj_args"],
            )
        else:
            projector = state["projector"]
            projector.update(grad)
        
        grad = projector.down(grad)
        match projector.proj_type:
            case "right":
                S = grad.square().mean(dim=1).sqrt().view(-1, 1)
            case "left":
                S = grad.square().mean(dim=0).sqrt()
            case _:
                raise Exception("Unknown projection type")
        
        state["step"] += 1

        norm_grad = grad / (S + group["eps"])
        
        # Project up
        norm_grad = projector.up(norm_grad)
        
        p.add_(norm_grad, alpha=(-group["lr"]))

        if group["weight_decay"] > 0.0:
            p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))

In [None]:
import math
import warnings
from typing import Callable, Iterable, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from torch.optim import Optimizer

# Scratch check: eight non-negative samples (absolute values of standard normals).
i = torch.abs(torch.randn(8))
i

In [None]:
# Build a random rank-1 test matrix M = outer(r_, c_) and its elementwise square.
# A dedicated, seeded generator makes these values reproducible on every re-run
# without disturbing the global RNG state used by other cells.
SEED = 0
rows = 4
cols = 8
eps = 1e-9  # small guard added to denominators in later cells

rng = torch.Generator().manual_seed(SEED)
r_ = torch.randn(rows, generator=rng)
c_ = torch.randn(cols, generator=rng)

# Rank-1 matrix
M = torch.outer(r_, c_)
M_square = M ** 2

r_, c_, M, M_square

In [None]:
# Row statistic: mean of the squared entries in each row of M.
r = torch.mean(M_square, dim=1)
r

In [None]:
# Column statistic: mean of the squared entries in each column of M.
c = torch.mean(M_square, dim=0)
c

In [None]:
# Row scaling: inverse square root of the row statistic normalized by its mean
# (eps keeps the square root away from zero).
r_normalized = r / r.mean(dim=-1, keepdim=True)
r_factor = 1 / torch.sqrt(r_normalized + eps)
r_factor

In [None]:
# Column scaling: inverse square root of the column statistic. eps guards against an
# all-zero column (which would otherwise give inf), matching the eps used for r_factor.
c_factor = 1 / torch.sqrt(c + eps)
c_factor

In [None]:
# Full scaling matrix as a rank-1 product: row factor (as a column) times column factor.
M_f = r_factor.reshape(-1, 1) * c_factor
M_f

In [None]:
# Same scaling via rsqrt and broadcasting instead of an explicit outer product:
# r_factor becomes a column vector and c_factor a row vector.
r_inv_rms = torch.rsqrt(r / r.mean(dim=-1, keepdim=True))
r_factor = r_inv_rms.unsqueeze(-1)
c_factor = torch.rsqrt(c).unsqueeze(-2)
r_factor * c_factor

In [None]:
# Recompute the factored inverse-RMS scaling from scratch (this variant has no eps guard).
M_square = torch.pow(M, 2)
r = M_square.mean(1)
c = M_square.mean(0)
r_factor = (r / r.mean()).rsqrt()
c_factor = c.rsqrt()
M_f = r_factor.reshape(-1, 1) * c_factor
M_f

In [None]:
# Pure low-rank (Adafactor-style) factorization of M**2: row means normalized by their
# mean, times column means. For the rank-1 M built above, M**2 is also rank-1, so
# outer(r, c) reconstructs M**2 up to eps and floating-point rounding.
M_square = torch.pow(M, 2)
r = M_square.mean(dim=1)
r = r / (r.mean() + eps)
c = M_square.mean(dim=0)
M_f = r.reshape(-1, 1) * c
M_f
# Corresponding update scale: 1 / (torch.sqrt(M_f) + eps)

In [None]:
M_f.sqrt()  # elementwise RMS estimate from the factored second moment

In [None]:
# Rank-1 factorization of M itself (signed values, no squaring).
# NOTE(review): for signed M the mean of the row means can be near zero or negative;
# `+ eps` does not protect against that, so this variant can blow up numerically.
row_means = M.mean(dim=1)
r = row_means / (row_means.mean() + eps)
c = M.mean(dim=0)
M_f = r.reshape(-1, 1) * c
M_f