In [19]:
from collections import defaultdict
import argparse

def merge_params_dicts(param_group_1: list, param_group_2: list) -> list:
    """Intersect two groupings of the same parameters into a finer list of param groups.

    Every dict in ``param_group_2`` is expected to have exactly two keys, in this order:
    the parameter-list key and one hyperparameter key (e.g. ``{"params": [...], "fraction": 0.5}``).
    The key names are taken from the first dict.

    For every pair ``(d1, d2)`` whose parameter lists overlap, the output gets a copy of ``d1``
    whose parameter list is restricted to the overlap and which also carries ``d2``'s
    hyperparameter value. The pairs are visited with ``d1`` in the outer loop and ``d2`` in the
    inner loop, so output order follows that nesting.

    Args:
        param_group_1: list of dicts, each holding the parameter-list key plus any other keys.
        param_group_2: non-empty list of two-key dicts as described above.

    Returns:
        A new list of dicts. The inputs are not mutated. Parameters inside each output group
        keep their order from ``d1`` (duplicates removed).

    Raises:
        IndexError: if ``param_group_2`` is empty.
        ValueError: if the first dict of ``param_group_2`` does not have exactly two keys.
    """
    key_param, key2 = param_group_2[0].keys()
    # Build each group-2 membership set once instead of once per group-1 entry.
    group_2_members = [(set(d2[key_param]), d2[key2]) for d2 in param_group_2]
    out = []
    for d1 in param_group_1:
        # Materialize once (the list may be a one-shot iterator) and dedupe while keeping order.
        params_1 = list(dict.fromkeys(d1[key_param]))
        for members_2, val2 in group_2_members:
            shared = [param for param in params_1 if param in members_2]
            if shared:
                merged = d1.copy()
                merged[key_param] = shared
                merged[key2] = val2
                out.append(merged)
    return out

def make_param_groups_for_relative_lr(args, model):
    """Build param groups whose ``lr`` and ``fraction`` are scaled per parameter name.

    For each hyperparameter (``lr`` from ``args.learning_rate``, ``fraction`` from
    ``args.final_lr_fraction``), every parameter gets a ratio. The ratio comes from the first
    key, in iteration order, of the matching relative dict (``args.relative_lr`` /
    ``args.relative_scheduler_fraction``) that is a substring of the parameter's name. If no
    key matches, or the dict is ``None``, the ratio is 1.0. The group value is
    ``ratio * baseline``. The per-hyperparameter groupings are intersected with
    ``merge_params_dicts``, so every returned group has a single ``lr`` and a single ``fraction``.

    Args:
        args: object exposing ``learning_rate``, ``final_lr_fraction``, ``relative_lr`` and
            ``relative_scheduler_fraction``. Each relative attribute is a dict mapping a name
            substring to a ratio, or ``None``.
        model: object whose ``named_parameters()`` yields ``(name, param)`` pairs. Params must
            be hashable, and each is assumed to be a distinct object.

    Returns:
        ``(param_groups, ratios_in_group_order)``:
        - ``param_groups`` is a list of dicts with keys ``"params"``, ``"lr"`` and ``"fraction"``.
        - ``ratios_in_group_order`` maps ``"lr"`` and ``"fraction"`` to the list of ratios of
          each group, aligned with ``param_groups``.
        A model with no parameters yields ``([], {"lr": [], "fraction": []})``.
    """
    baseline_args = [args.learning_rate, args.final_lr_fraction]
    baseline_keys = ['lr', 'fraction']
    relative_params_all = [args.relative_lr, args.relative_scheduler_fraction]

    # Read once: named_parameters() may be a one-shot iterator, and it is scanned once per key.
    named_params = list(model.named_parameters())
    if not named_params:
        return [], {key: [] for key in baseline_keys}

    all_groups = []
    # key -> {param: ratio}. Reporting ratios from here avoids dividing by a baseline that may be 0.
    ratio_by_param = {key: {} for key in baseline_keys}
    for baseline_arg, key, relative_params in zip(baseline_args, baseline_keys, relative_params_all):
        # None means "no overrides": every parameter keeps ratio 1.0 and lands in one group.
        relative_params = relative_params or {}
        # Group by the ratio itself, not by ratio * baseline: with a zero baseline every product
        # would collide and distinct ratios would be merged into one group.
        params_by_ratio = {}
        for name, param in named_params:
            ratio = 1.0
            for possible_name, candidate_ratio in relative_params.items():
                if possible_name in name:
                    ratio = candidate_ratio
                    break
            params_by_ratio.setdefault(ratio, []).append(param)
            ratio_by_param[key][param] = float(ratio)
        all_groups.append([
            {"params": params, key: ratio * baseline_arg}
            for ratio, params in params_by_ratio.items()
        ])

    param_groups = all_groups[0]
    for group2 in all_groups[1:]:
        param_groups = merge_params_dicts(param_groups, group2)

    # All params in a merged group share the same ratios, so the first one is representative.
    ratios_in_group_order = {
        key: [ratio_by_param[key][group["params"][0]] for group in param_groups]
        for key in baseline_keys
    }
    return param_groups, ratios_in_group_order

In [20]:
class DummyModel:
    """Minimal stand-in for a model, exposing ``named_parameters()`` and ``parameters()``.

    Parameters are plain ints so that printed param groups are easy to read.
    """

    # Default (name, param) pairs used by the notebook demo.
    _DEFAULT_PARAMS = [('param.1', 1), ('param.2', 2), ('param.3', 3), ('arg.1', 4), ('arg.2', 5), ('arg.3', 6), ('arg.4', 7)]

    def __init__(self, params=None):
        # Copy so instances never share (or mutate) the class-level default list.
        self.params = list(self._DEFAULT_PARAMS if params is None else params)

    def named_parameters(self):
        """Return the list of ``(name, param)`` pairs."""
        return self.params

    def parameters(self):
        """Return the params without their names, in ``named_parameters()`` order."""
        return [param for _, param in self.params]

def create_mock_args():
    """Return an ``argparse.Namespace`` standing in for parsed CLI arguments.

    Both baselines are 1.0. The two relative dicts map a parameter-name substring to a ratio
    (``relative_lr``: '1' -> 2, '2' -> 0.5; ``relative_scheduler_fraction``: '1' -> 2, '3' -> 0.5).
    """
    return argparse.Namespace(
        learning_rate=1.0,
        final_lr_fraction=1.0,
        relative_lr={'1': 2, '2': 0.5},
        relative_scheduler_fraction={'1': 2, '3': 0.5},
    )

def test_make_param_groups_for_relative_lr():
    """Run ``make_param_groups_for_relative_lr`` on the mock args/model and check the result.

    Returns the ``(param_groups, ratios_in_group_order)`` tuple so the caller can print it.
    """
    args = create_mock_args()
    model = DummyModel()
    param_groups, ratios_in_group_order = make_param_groups_for_relative_lr(args, model)

    # With the mock args: names containing '1' get lr x2 and fraction x2, '2' gets lr x0.5,
    # '3' gets fraction x0.5; everything else stays at 1.0.
    expected_groups = [
        {'params': [1, 4], 'lr': 2.0, 'fraction': 2.0},
        {'params': [2, 5], 'lr': 0.5, 'fraction': 1.0},
        {'params': [7], 'lr': 1.0, 'fraction': 1.0},
        {'params': [3, 6], 'lr': 1.0, 'fraction': 0.5},
    ]
    expected_ratios = {'lr': [2.0, 0.5, 1.0, 1.0], 'fraction': [2.0, 1.0, 1.0, 0.5]}
    assert param_groups == expected_groups, param_groups
    assert ratios_in_group_order == expected_ratios, ratios_in_group_order

    return param_groups, ratios_in_group_order

# Run the demo, then dump each resulting param group followed by the per-hyperparameter ratios.
param_groups, ratios_in_group_order = test_make_param_groups_for_relative_lr()
for group in param_groups:
    print(group)

for hparam_name, ratios in ratios_in_group_order.items():
    print(f'arg: {hparam_name}\n{ratios}')

{'params': [1, 4], 'lr': 2.0, 'fraction': 2.0}
{'params': [2, 5], 'lr': 0.5, 'fraction': 1.0}
{'params': [7], 'lr': 1.0, 'fraction': 1.0}
{'params': [3, 6], 'lr': 1.0, 'fraction': 0.5}
arg: lr
[2.0, 0.5, 1.0, 1.0]
arg: fraction
[2.0, 1.0, 1.0, 0.5]
