In [1]:
import matplotlib.pyplot as plt
import torch
from torch.nn import Parameter
from torch.optim import Adam

from truetype_vs_postscript_transformer.modules.scheduler import WarmupDecayLR


def plot_scheduler(warmup_steps: int, total_steps: int, base_lr: float) -> None:
    """Simulate a ``WarmupDecayLR`` run and plot the learning rate per step.

    A throwaway 2x2 parameter is wrapped in an Adam optimizer purely so the
    scheduler has an optimizer to drive; no backward pass is performed. The
    optimizer and scheduler are stepped ``total_steps`` times, and the value
    reported by ``get_last_lr()[0]`` after each step is drawn against the
    1-based step index.

    Args:
        warmup_steps: Forwarded to ``WarmupDecayLR``; also the x-position of
            the dashed red "Warmup End" marker.
        total_steps: Number of optimizer/scheduler steps to simulate.
        base_lr: Learning rate the Adam optimizer is constructed with.
    """
    dummy_params = [Parameter(torch.randn(2, 2, requires_grad=True))]
    adam = Adam(dummy_params, lr=base_lr)
    lr_schedule = WarmupDecayLR(adam, warmup_steps)

    # One range serves as both the loop driver and the x-axis values.
    steps = range(1, total_steps + 1)
    lr_history = []
    for _ in steps:
        # Optimizer first, then scheduler; the learning rate is read back
        # only after the scheduler has advanced.
        adam.step()
        lr_schedule.step()
        lr_history.append(lr_schedule.get_last_lr()[0])

    _, ax = plt.subplots(figsize=(8, 5))
    ax.plot(steps, lr_history, label="Learning Rate")
    ax.axvline(x=warmup_steps, color="r", linestyle="--", label="Warmup End")
    ax.set_title("Learning Rate Schedule with Warmup and Decay")
    ax.set_xlabel("Steps")
    ax.set_ylabel("Learning Rate")
    ax.legend()
    ax.grid(visible=True)
    plt.show()


In [None]:
# Render the schedule: 100 warmup steps out of 1500 total, Adam created with lr=1e-3.
plot_scheduler(
    warmup_steps=100,
    total_steps=1_500,
    base_lr=1e-3,
)
