https://kozodoi.me/blog/20220329/discriminative-lr

In [1]:
from typing import Dict, Iterator, List, Literal, Tuple, Union

from rich.pretty import pprint
from torch import nn
from torch.optim import AdamW
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    DebertaV2ForSequenceClassification,
    get_cosine_schedule_with_warmup,
    get_linear_schedule_with_warmup,
    logging,
    DebertaV2PreTrainedModel,
)
from rich.pretty import pprint
from omnivault.utils.torch_utils.model_utils import get_named_parameters

# Only surface errors from the transformers logger. A preceding
# `set_verbosity_warning()` call would be overridden immediately by this one,
# so a single call is sufficient.
logging.set_verbosity_error()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Instantiate the sequence-classification model from the
# "microsoft/deberta-v3-xsmall" checkpoint.
BASE_MODEL: DebertaV2PreTrainedModel = DebertaV2ForSequenceClassification.from_pretrained(
    "microsoft/deberta-v3-xsmall"
)

# List every named parameter together with its position, to expose the layout
# (embeddings -> encoder layers -> pooler -> classifier).
for index, named_parameter in enumerate(BASE_MODEL.named_parameters()):
    print(f"{index}: {named_parameter[0]}")

0: deberta.embeddings.word_embeddings.weight
1: deberta.embeddings.LayerNorm.weight
2: deberta.embeddings.LayerNorm.bias
3: deberta.encoder.layer.0.attention.self.query_proj.weight
4: deberta.encoder.layer.0.attention.self.query_proj.bias
5: deberta.encoder.layer.0.attention.self.key_proj.weight
6: deberta.encoder.layer.0.attention.self.key_proj.bias
7: deberta.encoder.layer.0.attention.self.value_proj.weight
8: deberta.encoder.layer.0.attention.self.value_proj.bias
9: deberta.encoder.layer.0.attention.output.dense.weight
10: deberta.encoder.layer.0.attention.output.dense.bias
11: deberta.encoder.layer.0.attention.output.LayerNorm.weight
12: deberta.encoder.layer.0.attention.output.LayerNorm.bias
13: deberta.encoder.layer.0.intermediate.dense.weight
14: deberta.encoder.layer.0.intermediate.dense.bias
15: deberta.encoder.layer.0.output.dense.weight
16: deberta.encoder.layer.0.output.dense.bias
17: deberta.encoder.layer.0.output.LayerNorm.weight
18: deberta.encoder.layer.0.output.LayerNo

The base model has an embedding layer, which we group as `embeddings`. Note
that the embeddings module is not iterable by itself, so it is natural to treat
it as one independent group. According to the research literature, the
embeddings group is notoriously difficult to tune from its initial weights, but
we leave that for another session.

In [3]:
# `.deberta` is accessed below, so make sure we hold the concrete
# sequence-classification class.
assert isinstance(BASE_MODEL, DebertaV2ForSequenceClassification)

# The embeddings module is treated as a single group.
embeddings_group = BASE_MODEL.deberta.embeddings
embeddings_named_parameters = get_named_parameters(embeddings_group)

pprint(embeddings_group)

# Each entry exposes `.keys()`; print them to see the parameter names.
for named_parameter in embeddings_named_parameters:
    print(named_parameter.keys())


dict_keys(['word_embeddings.weight'])
dict_keys(['LayerNorm.weight'])
dict_keys(['LayerNorm.bias'])


The backbone of the base model is called the `encoder`; it is a stack of
transformer layers, which we group together as the `encoder` group. This group
is iterable, so we can treat each stacked encoder block as a "layer". Note the
abuse of notation here: by "layer" in the encoder backbone we actually mean an
entire encoder block.

In [4]:
# The full encoder module has additional relative embeddings + LayerNorm on top
# of the stack of transformer blocks. Print it directly rather than binding it
# to `backbone_group`, which is reserved for the block stack below.
pprint(BASE_MODEL.deberta.encoder)

# `encoder.layer` is the iterable stack of transformer blocks.
backbone_group = BASE_MODEL.deberta.encoder.layer
pprint(backbone_group)

# Number of transformer blocks in the stack.
pprint(len(backbone_group))

backbone_named_parameters = get_named_parameters(backbone_group)

for named_parameter in backbone_named_parameters:
    print(named_parameter.keys())


dict_keys(['0.attention.self.query_proj.weight'])
dict_keys(['0.attention.self.query_proj.bias'])
dict_keys(['0.attention.self.key_proj.weight'])
dict_keys(['0.attention.self.key_proj.bias'])
dict_keys(['0.attention.self.value_proj.weight'])
dict_keys(['0.attention.self.value_proj.bias'])
dict_keys(['0.attention.output.dense.weight'])
dict_keys(['0.attention.output.dense.bias'])
dict_keys(['0.attention.output.LayerNorm.weight'])
dict_keys(['0.attention.output.LayerNorm.bias'])
dict_keys(['0.intermediate.dense.weight'])
dict_keys(['0.intermediate.dense.bias'])
dict_keys(['0.output.dense.weight'])
dict_keys(['0.output.dense.bias'])
dict_keys(['0.output.LayerNorm.weight'])
dict_keys(['0.output.LayerNorm.bias'])
dict_keys(['1.attention.self.query_proj.weight'])
dict_keys(['1.attention.self.query_proj.bias'])
dict_keys(['1.attention.self.key_proj.weight'])
dict_keys(['1.attention.self.key_proj.bias'])
dict_keys(['1.attention.self.value_proj.weight'])
dict_keys(['1.attention.self.value_proj.

In [5]:
# The pooler sits between the encoder output and the classification head.
pooler_group = BASE_MODEL.pooler
pprint(pooler_group)

pooler_named_parameters = get_named_parameters(pooler_group)

# Show the parameter names held by the pooler.
for named_parameter in pooler_named_parameters:
    print(named_parameter.keys())

dict_keys(['dense.weight'])
dict_keys(['dense.bias'])


In [6]:
# The classification head is the model's `classifier` module.
head_group = BASE_MODEL.classifier
pprint(head_group)

head_named_parameters = get_named_parameters(head_group)

# Show the parameter names held by the head.
for named_parameter in head_named_parameters:
    print(named_parameter.keys())


dict_keys(['weight'])
dict_keys(['bias'])


In [7]:

def get_optimizer_grouped_parameters_by_category(
    model: nn.Module,
    learning_rate: float,
    weight_decay: float,
    layerwise_learning_rate_decay_mulitplier: float = 0.95,
    pooler_lr: float | None = None,
    head_lr: float | None = None,
    pooler_weight_decay: float | None = None,
    head_weight_decay: float | None = None,
) -> List[Dict[str, str | float | List[nn.Parameter]]]:

    # LayerNorm.bias is automatically included in no decay since bias is in no decay
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]

    embeddings_group = model.deberta.embeddings
    backbone_group = model.deberta.encoder.layer
    pooler_group = model.pooler
    head_group = model.classifier

    head_no_decay = {
        "params": [
            parameter
            for parameter_name, parameter in head_group.named_parameters()
            if any(nd in parameter_name for nd in no_decay)
        ],
        "weight_decay": 0.0,
        "lr": learning_rate if head_lr is None else head_lr,
        "name": "head_no_decay",
    }

    head_decay = {
        "params": [
            parameter
            for parameter_name, parameter in head_group.named_parameters()
            if not any(nd in parameter_name for nd in no_decay)
        ],
        "weight_decay": weight_decay if head_weight_decay is None else head_weight_decay,
        "lr": learning_rate if head_lr is None else head_lr,
        "name": "head_decay",
    }

    # this group applies no weight decay
    pooler_no_decay = {
        "params": [
            parameter
            for parameter_name, parameter in pooler_group.named_parameters()
            if any(nd in parameter_name for nd in no_decay)
        ],
        "weight_decay": 0.0,
        "lr": learning_rate if pooler_lr is None else pooler_lr,
        "name": "pooler_no_decay",
    }

    pooler_decay = {
        "params": [
            parameter
            for parameter_name, parameter in pooler_group.named_parameters()
            if not any(nd in parameter_name for nd in no_decay)
        ],
        "weight_decay": weight_decay if pooler_weight_decay is None else pooler_weight_decay,
        "lr": learning_rate if pooler_lr is None else pooler_lr,
        "name": "pooler_decay",
    }


    optimizer_grouped_parameters = [pooler_no_decay, pooler_decay, head_no_decay, head_decay]
    embeddings_and_backbone_group = [embeddings_group] + list(backbone_group)
    embeddings_and_backbone_group.reverse()

    lr = learning_rate
    # NOTE: decay only happens at a embedding + backbone level
    for index, layer in enumerate(embeddings_and_backbone_group):
        lr *= layerwise_learning_rate_decay_mulitplier
        # NOTE: add no decay and decay groups for encoder/backbone

        optimizer_grouped_parameters += [
            {
                "params": [parameter for parameter_name, parameter in layer.named_parameters() if not any(nd in parameter_name for nd in no_decay)],
                "weight_decay": weight_decay,
                "lr": lr,
                "name": f"{layer.__class__.__name__}_{index}_decay",
            },
            {
                "params": [parameter for parameter_name, parameter in layer.named_parameters() if any(nd in parameter_name for nd in no_decay)],
                "weight_decay": 0.0,
                "lr": lr,
                "name": f"{layer.__class__.__name__}_{index}_no_decay",
            },
        ]
    return optimizer_grouped_parameters

In [8]:
# def get_optimizer_grouped_parameters_by_layer(
#     model: nn.Module,
#     group_configs: List[Dict[str, str | float | bool]],
#     default_learning_rate: float,
#     default_weight_decay: float,
#     layerwise_learning_rate_decay_mulitplier: float = 0.95
# ) -> List[Dict[str, str | float | List[nn.Parameter]]]:
#     no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
#     optimizer_parameter_groups = []
#     named_parameters = list(model.named_parameters())

#     for parameter_name, parameter in named_parameters:
#         weight_decay = 0.0 if any(nd in parameter_name for nd in no_decay) else default_weight_decay
#         # if weight_decay == 0.0:
#         #     print(f"parameter_name: {parameter_name} has no weight decay")

#         applied = False

#         for group_config in group_configs:
#             if parameter_name.startswith(group_config["prefix"]):
#                 layer_lr = group_config.get("base_lr", default_learning_rate)

#                 if group_config.get("llrd", False):
#                     layer_lr *= layerwise_learning_rate_decay_mulitplier
#                     optimizer_parameter_groups.append(
#                         {
#                             "params": parameter,
#                             "weight_decay": weight_decay,
#                             "lr": layer_lr,
#                             "name": f"{group_config['prefix']}_decay",
#                         }
#                     )

#                 else:
#                     optimizer_parameter_groups.append(
#                         {
#                             "params": parameter,
#                             "weight_decay": weight_decay,
#                             "lr": layer_lr,
#                             "name": f"{group_config['prefix']}_decay",
#                         }
#                     )
#                 applied = True
#                 break

#         if not applied:
#             optimizer_parameter_groups.append(
#                 {
#                     "params": parameter,
#                     "weight_decay": weight_decay,
#                     "lr": default_learning_rate,
#                     "name": "default",
#                 }
#             )

#     return optimizer_parameter_groups


# group_configs = [
#     {"prefix": "deberta.encoder", "base_lr": 1e-4, "llrd": True},
#     {"prefix": "deberta.embeddings", "base_lr": 1e-5, "llrd": False},
#     {"prefix": "pooler", "base_lr": 1e-3, "llrd": False},
#     {"prefix": "classifier", "base_lr": 1e-3, "llrd": False},
# ]

In [9]:
# Optimizer hyperparameters.
learning_rate = 1e-2
weight_decay = 0.01
layerwise_learning_rate_decay_mulitplier = 0.9

# Scheduler hyperparameters: `num_epochs` is used as the total number of
# scheduler steps, with no warmup.
num_warmup_steps = 0
num_epochs = 20

In [10]:
# grouped_optimizer_params = get_optimizer_grouped_parameters_by_layer(
#     model=BASE_MODEL,
#     default_learning_rate=learning_rate,
#     default_weight_decay=weight_decay,
#     layerwise_learning_rate_decay_mulitplier=layerwise_learning_rate_decay_mulitplier,
#     group_configs=group_configs
# )

# # grouped_optimizer_params

In [11]:
# Build the parameter groups. The function is named `..._by_category`; calling
# the non-existent `..._per_category` raises NameError on a fresh kernel.
grouped_optimizer_params = get_optimizer_grouped_parameters_by_category(
    model=BASE_MODEL,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    layerwise_learning_rate_decay_mulitplier=layerwise_learning_rate_decay_mulitplier,
)

# grouped_optimizer_params

In [12]:
# Every parameter group already carries its own `lr` and `weight_decay`; the
# keyword arguments below are the optimizer-level defaults.
optimizer = AdamW(grouped_optimizer_params, lr=learning_rate, weight_decay=weight_decay)

# Cosine schedule stepped `num_epochs` times in total.
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_epochs,
)

# Display the optimizer so the per-group settings can be inspected.
optimizer

AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 0.01
    lr: 0.01
    maximize: False
    name: pooler_no_decay
    weight_decay: 0.0

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 0.01
    lr: 0.01
    maximize: False
    name: pooler_decay
    weight_decay: 0.01

Parameter Group 2
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 0.01
    lr: 0.01
    maximize: False
    name: head_no_decay
    weight_decay: 0.0

Parameter Group 3
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 0.01
    lr: 0.01
    maximize: False
 

In [13]:
# Name -> parameter mapping for the whole model.
model_params = dict(BASE_MODEL.named_parameters())

# Every parameter that appears in some optimizer parameter group.
opt_params = {param for group in optimizer.param_groups for param in group['params']}

# Model parameters that the optimizer does not know about.
uncovered_params = {name: param for name, param in model_params.items() if param not in opt_params}

if not uncovered_params:
    print("All parameters are covered by the optimizer.")
else:
    print("Some parameters are not covered by the optimizer:")
    for name in uncovered_params:
        print(name)

Some parameters are not covered by the optimizer:
deberta.encoder.rel_embeddings.weight
deberta.encoder.LayerNorm.weight
deberta.encoder.LayerNorm.bias


In [14]:
def collect_lr(optimizer, learning_rate_storage):
    """Record the current learning rate of every optimizer parameter group.

    The i-th parameter group's ``'lr'`` value is appended to the list stored at
    ``learning_rate_storage[i][1]``; the storage is modified in place.

    Args:
        optimizer (Optimizer): Object exposing ``param_groups``, a sequence of
            dicts that each contain an ``'lr'`` entry.
        learning_rate_storage (list of tuples): One ``(group_name, list_of_lrs)``
            entry per parameter group, in the same order as ``param_groups``.
    """
    groups = optimizer.param_groups
    for position, group in enumerate(groups):
        history = learning_rate_storage[position][1]
        history.append(group['lr'])

# Number of parameter groups registered with the optimizer.
num_param_groups = len(optimizer.param_groups)

# One (group_name, lr_history) pair per parameter group; a group without a
# 'name' entry falls back to "Group <index>".
learning_rate_storage = []
for i, pg in enumerate(optimizer.param_groups):
    learning_rate_storage.append((pg.get('name', f'Group {i}'), []))

# Record the learning rates before any scheduler step has been taken.
collect_lr(optimizer, learning_rate_storage)


In [15]:
# Simulated training: no forward/backward pass is run here, the loop only
# advances the optimizer and scheduler so the schedule can be recorded.
for epoch_index in range(num_epochs):
    # Stand-in for a real training step.
    optimizer.step()
    # Advance the learning-rate schedule by one step.
    scheduler.step()

    # Store the learning rate of every parameter group after this step.
    collect_lr(optimizer, learning_rate_storage)

In [17]:
# Each plotly name is imported once (the repeated import lines were redundant).
import plotly.graph_objs as go
import plotly.offline as pyo
from plotly.offline import init_notebook_mode, iplot

init_notebook_mode(connected=True)  # This makes plotly display in the notebook

# x-axis: one point per recorded learning rate (initial value + one per epoch).
epochs = list(range(len(learning_rate_storage[0][1])))  # Assuming all groups have the same number of entries

# One line per parameter group, labelled with the stored group name.
traces = [
    go.Scatter(
        x=epochs,
        y=rates,
        mode='lines+markers',
        name=group_name,
    )
    for group_name, rates in learning_rate_storage
]

# Setting up a clean layout
layout = go.Layout(
    title='Learning Rate per Epoch by Parameter Group',
    xaxis=dict(title='Epoch'),
    yaxis=dict(title='Learning Rate'),
    template='plotly_white'
)

fig = go.Figure(data=traces, layout=layout)
iplot(fig)  # Display the figure inline in a notebook

pyo.plot(fig, filename='learning_rates.html')  # This will save the plot to an HTML file and open it in your browser


'learning_rates.html'