In [87]:
from torch.nn import Module, Parameter, MSELoss, Linear, LazyLinear
from torch import Tensor
from torch import nn
import torch as th
from typing import Any, cast, Self
from collections.abc import Callable, Mapping, MutableMapping
from math import inf

In [88]:
def flat_to_params(func: Module) -> Callable[[Tensor], Mapping[str, Tensor]]:
    """Build a function that unflattens a 1-D vector into ``func``'s state-dict layout.

    The key order, tensor shapes and element counts are read once from
    ``func.state_dict()`` when this factory is called; the returned closure
    reuses that snapshot on every call.

    Args:
        func: Module whose ``state_dict()`` defines the keys, their order and
            the shape of each tensor.

    Returns:
        A callable mapping a tensor with exactly as many elements as the whole
        state dict to a ``{key: tensor}`` mapping, where each tensor has the
        shape of the corresponding state-dict entry. Entries are consumed from
        the flat vector in state-dict iteration order.

    Raises:
        ValueError: (from the returned callable) if the input's element count
            differs from the total element count of the state dict.
    """
    # Snapshot (key, shape, element count) once instead of walking the state dict twice.
    layout = [(key, value.shape, value.numel()) for key, value in func.state_dict().items()]
    total = sum(count for _, _, count in layout)

    def _map_flat_to_params(flat: Tensor) -> Mapping[str, Tensor]:
        # Fail loudly: slicing would otherwise silently drop surplus values,
        # or hit an opaque reshape error part-way through when values are missing.
        if flat.numel() != total:
            raise ValueError(
                f"Expected a flat tensor with {total} elements; got {flat.numel()}"
            )
        flat = flat.reshape(-1)
        params: MutableMapping[str, Tensor] = {}
        offset = 0
        for key, shape, count in layout:
            params[key] = flat[offset:offset + count].reshape(shape)
            offset += count
        return params

    return _map_flat_to_params

In [89]:
class MLP(Module):
    """Fully connected network with ReLU between layers and a final Sigmoid.

    For ``widths = [w0, w1, ..., wn]`` the stack is
    ``Linear(w0, w1) -> ReLU -> ... -> Linear(w(n-1), wn) -> Sigmoid``;
    no ReLU follows the last ``Linear``.
    """

    # The layer stack; assigned an ``nn.Sequential`` in ``__init__``.
    mlp: Module

    def __init__(self, widths: list[int]):
        """Create the layer stack.

        Args:
            widths: Layer sizes from input to output. Must contain at least
                two entries (input and output dimension); with exactly two the
                network is a single ``Linear`` followed by ``Sigmoid``.

        Raises:
            AssertionError: if fewer than two widths are given.
        """
        super().__init__() # type: ignore
        # Two widths already describe a valid network, so accept len == 2 as
        # the message promises (the previous strict `2 <` rejected it).
        assert 2 <= len(widths), f"Need at least input and output dimensions; got {widths}"
        layers: list[Module] = []
        last_linear = len(widths) - 2
        for i, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            # Hidden layers get a ReLU; the output layer is followed by Sigmoid instead.
            if i < last_linear:
                layers.append(nn.ReLU())
        layers.append(nn.Sigmoid())
        self.mlp = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the layer stack to ``x`` and return its output."""
        return self.mlp(x)

In [90]:
# Instantiate the demo network: 2 inputs, two hidden layers of width 42, 1 output.
mlp = MLP([2, 42, 42, 1])
# Bare last expression so the notebook shows the module's repr.
mlp

MLP(
  (mlp): Sequential(
    (0): Linear(in_features=2, out_features=42, bias=True)
    (1): ReLU()
    (2): Linear(in_features=42, out_features=42, bias=True)
    (3): ReLU()
    (4): Linear(in_features=42, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

In [91]:
# List every state-dict entry of the model together with its tensor shape.
for key, p in mlp.state_dict().items():
    print(f"{key}: {p.shape}")

mlp.0.weight: torch.Size([42, 2])
mlp.0.bias: torch.Size([42])
mlp.2.weight: torch.Size([42, 42])
mlp.2.bias: torch.Size([42])
mlp.4.weight: torch.Size([1, 42])
mlp.4.bias: torch.Size([1])


In [92]:
# Round-trip check: flatten the whole state dict into one vector, unflatten it
# again, and compare each rebuilt tensor with the original entry.
state_dict = mlp.state_dict()
flat_params = th.cat([v.flatten() for v in state_dict.values()])
for key, p in flat_to_params(mlp)(flat_params).items():
    # Exact equality is expected because the values are only sliced and reshaped.
    assert th.equal(p, state_dict[key]), f"round trip changed {key}"
    # Report the max absolute difference: a signed mean can cancel to 0 even
    # when individual entries differ.
    print(f"{key}: {p.shape} {(p - state_dict[key]).abs().max()}")

mlp.0.weight: torch.Size([42, 2]) 0.0
mlp.0.bias: torch.Size([42]) 0.0
mlp.2.weight: torch.Size([42, 42]) 0.0
mlp.2.bias: torch.Size([42]) 0.0
mlp.4.weight: torch.Size([1, 42]) 0.0
mlp.4.bias: torch.Size([1]) 0.0


In [94]:
# Add 12 to every flattened value, rebuild the per-key mapping, and load it
# back into `mlp`; the displayed result is load_state_dict's key-matching report.
mlp.load_state_dict(flat_to_params(mlp)(flat_params + 12.))

<All keys matched successfully>