# PyTorch `state_dict` Tutorial

In [27]:
import torch
import torch.nn as nn

# --- 1. Standard Python __dict__ ---

In [30]:

class MyClass:
    """Plain Python object carrying two instance attributes, `a` and `b`."""

    def __init__(self):
        # Instance attributes are stored in the per-instance __dict__,
        # in the order they are assigned (a first, then b).
        self.a, self.b = 10, 20


obj = MyClass()
# vars(obj) returns obj.__dict__: the mapping of instance attribute names
# to their current values.
print("1. Standard Python __dict__:", vars(obj), sep="\n")
print()

1. Standard Python __dict__:
{'a': 10, 'b': 20}



# --- 2. Empty state_dict when no parameters ---

In [31]:
class NetEmpty(nn.Module):
    """Module whose only attribute is a plain Python int."""

    def __init__(self):
        super().__init__()
        # An int is kept as an ordinary attribute; nn.Module does not
        # register it, so it never reaches state_dict().
        self.a = 2


model_empty = NetEmpty()
# Nothing was registered on the module, so the printed mapping is empty.
print("2. Model with no parameters -> state_dict():", model_empty.state_dict(), sep="\n")
print()

2. Model with no parameters -> state_dict():
OrderedDict()



# --- 3. state_dict with nn.Linear layer ---

In [32]:
# Linear layer with 3 input features and 2 output features. Its weight and
# bias are randomly initialised, so the printed values vary from run to run
# unless a seed is set beforehand.
linear_layer = nn.Linear(in_features=3, out_features=2)
print("3. state_dict of nn.Linear:", linear_layer.state_dict(), sep="\n")
# Inspect the attribute types that back the 'weight' and 'bias' entries.
print("Type of weight:", type(linear_layer.weight))
print("Type of bias:", type(linear_layer.bias))
print()

3. state_dict of nn.Linear:
OrderedDict([('weight', tensor([[ 0.4049, -0.3305,  0.4929],
        [-0.1182, -0.1007,  0.5056]])), ('bias', tensor([-0.0890,  0.5624]))])
Type of weight: <class 'torch.nn.parameter.Parameter'>
Type of bias: <class 'torch.nn.parameter.Parameter'>



# --- 4. Custom module with both a layer (parameters) and a plain attribute ---

In [37]:
class NetWithLayer(nn.Module):
    """Module holding one nn.Linear submodule plus one plain int attribute."""

    def __init__(self):
        super().__init__()
        # Submodule: its weight and bias are exposed under the "fc." prefix.
        self.fc = nn.Linear(in_features=10, out_features=5)
        # Plain Python value: stays an ordinary attribute.
        self.a = 2


model = NetWithLayer()
# Only the linear layer's entries are listed; `a` is absent.
print("4. state_dict with layer + normal attribute:", model.state_dict().keys(), sep="\n")
print()

4. state_dict with layer + normal attribute:
odict_keys(['fc.weight', 'fc.bias'])



# --- 5. Registering a parameter manually ---

In [38]:
class NetWithParam(nn.Module):
    """Module whose attribute `a` is a scalar nn.Parameter."""

    def __init__(self):
        super().__init__()
        # Wrapping the tensor in nn.Parameter makes nn.Module register it
        # on assignment, so it shows up in state_dict() under the key "a".
        scalar = torch.tensor(2.0)
        self.a = nn.Parameter(scalar)


model_param = NetWithParam()
print("5. state_dict after registering custom nn.Parameter:", model_param.state_dict(), sep="\n")
print()

5. state_dict after registering custom nn.Parameter:
OrderedDict([('a', tensor(2.))])



# --- 6. Registering a buffer (non-learnable but saved) ---

In [39]:
class NetWithBuffer(nn.Module):
    """Module that stores the scalar `a` as a registered buffer."""

    def __init__(self):
        super().__init__()
        # A buffer is saved in state_dict() like a parameter, but it is not
        # returned by parameters(), so an optimizer never updates it.
        self.register_buffer(name="a", tensor=torch.tensor(2.0))


model_buffer = NetWithBuffer()
print("6. state_dict after registering buffer:", model_buffer.state_dict(), sep="\n")
print()

6. state_dict after registering buffer:
OrderedDict([('a', tensor(2.))])



# --- 7. Saving and loading state_dict ---

In [34]:
# Save only the state_dict (name -> tensor mapping) of `model`, not the
# pickled module object itself.
path = "model.pth"
torch.save(model.state_dict(), path)

# A new instance starts with its own random initialisation; loading the
# saved state_dict overwrites those values in place.
new_model = NetWithLayer()
# weights_only=True restricts unpickling to tensors and plain containers, so
# a tampered checkpoint file cannot run arbitrary code on load. It is the
# default from PyTorch 2.6 and is accepted as a keyword since PyTorch 1.13.
new_model.load_state_dict(torch.load(path, weights_only=True))

# Sanity check: every restored tensor equals the one that was saved.
assert all(
    torch.equal(saved, new_model.state_dict()[key])
    for key, saved in model.state_dict().items()
), "reloaded weights differ from the saved ones"

print("7. Loaded state_dict into new model:")
print(new_model.state_dict().keys())

7. Loaded state_dict into new model:
odict_keys(['fc.weight', 'fc.bias'])


# --- 8. Optimizer state_dict ---

In [40]:
# SGD with momentum over the parameters of the reloaded model.
optimizer = torch.optim.SGD(params=new_model.parameters(), lr=0.01, momentum=0.9)
# Optimizers expose a state_dict as well; only its top-level keys are printed
# here rather than the full contents.
print("8. Optimizer state_dict (truncated keys):", optimizer.state_dict().keys(), sep="\n")

8. Optimizer state_dict (truncated keys):
dict_keys(['state', 'param_groups'])
