In [18]:
# encoding: utf8


import torch
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary


class Model(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_size: Number of features in each input sample.
        hidden_size: Width of the hidden layer.
        output_size: Number of features in each output sample.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int) -> None:
        super().__init__()

        # Submodules are registered in the order they are applied in forward().
        self.linear1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Apply linear1, then ReLU, then linear2 to ``x`` and return the result."""
        hidden = self.relu(self.linear1(x))
        return self.linear2(hidden)


batch_size = 32
input_size = 100
hidden_size = 1000
output_size = 10

# Define the model, loss function, and optimizer.
model = Model(input_size, hidden_size, output_size)
model = model.cuda()
loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Forward pass on one random batch against a random regression target.
x = torch.randn(batch_size, input_size).cuda()
label = torch.randn(batch_size, output_size).cuda()
y_pred = model(x)
loss = loss_fn(y_pred, label)

# Backward pass, then one optimizer step to update the parameters.
optimizer.zero_grad()
loss.backward()
optimizer.step()

# After backward() each parameter has a .grad; show that it matches the
# parameter's own shape and which device the parameter lives on.
for name, value in model.named_parameters():
    print(name, value.size(), value.grad.size(), value.device)


# Reuse batch_size rather than repeating the literal, so the summary table
# stays consistent with the batch actually used above.
# NOTE(review): summary() presumably runs its own forward pass, which may be
# included in the peak-memory figures printed below — confirm.
summary(model, (input_size,), batch_size=batch_size, device="cuda")

# CUDA allocator statistics, converted from bytes to KB by integer division.
print(f"cuda memory_allocated: {torch.cuda.memory_allocated()//1024} KB")
print(f"cuda max_memory_allocated: {torch.cuda.max_memory_allocated()//1024} KB")
print(f"cuda max_memory_reserved: {torch.cuda.max_memory_reserved()//1024} KB")

linear1.weight torch.Size([1000, 100]) torch.Size([1000, 100]) cuda:0
linear1.bias torch.Size([1000]) torch.Size([1000]) cuda:0
linear2.weight torch.Size([10, 1000]) torch.Size([10, 1000]) cuda:0
linear2.bias torch.Size([10]) torch.Size([10]) cuda:0
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                 [32, 1000]         101,000
              ReLU-2                 [32, 1000]               0
            Linear-3                   [32, 10]          10,010
Total params: 111,010
Trainable params: 111,010
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 0.49
Params size (MB): 0.42
Estimated Total Size (MB): 0.93
----------------------------------------------------------------
cuda memory_allocated: 21877 KB
cuda max_memory_allocated: 22910 KB
cuda max_memory_reserved: 28672 KB


In [15]:
# Report the dtype and shape of each tensor in the optimizer's first parameter group.
first_group_params = optimizer.param_groups[0]["params"]
for idx, param in enumerate(first_group_params):
    print(f"optimizer parameter {idx}, dtype: {param.dtype}, shape {param.size()}")

optimizer parameter 0, dtype: torch.float32, shape torch.Size([1000, 100])
optimizer parameter 1, dtype: torch.float32, shape torch.Size([1000])
optimizer parameter 2, dtype: torch.float32, shape torch.Size([10, 1000])
optimizer parameter 3, dtype: torch.float32, shape torch.Size([10])


In [20]:
# Snapshot the optimizer as a plain dictionary and pull out its two top-level
# entries; `state` is inspected by later cells.
state_dict = optimizer.state_dict()
state, param_groups = state_dict["state"], state_dict["param_groups"]

In [54]:
# Shapes of the tensors registered in the optimizer's first parameter group.
first_group = optimizer.param_groups[0]
for param in first_group["params"]:
    print(param.shape)

torch.Size([1000, 100])
torch.Size([1000])
torch.Size([10, 1000])
torch.Size([10])


In [59]:
# Print the shape of every top-level entry of the optimizer state that is itself a tensor.
# NOTE(review): no output was recorded for this cell; the entries appear to be
# nested per-parameter dicts rather than tensors, so nothing matches at this level.
top_level_tensors = [
    entry for entry in state_dict["state"].values() if isinstance(entry, torch.Tensor)
]
for tensor in top_level_tensors:
    print(tensor.size())

In [70]:
def recursive_print_params(state):
    """Print the id (in hex) and size of every tensor nested inside ``state``.

    Dicts are traversed through their values and lists through their items,
    depth-first and in iteration order. Any other non-tensor object is ignored.
    """
    if isinstance(state, torch.Tensor):
        print(hex(id(state)), state.size())
        return

    if isinstance(state, dict):
        children = state.values()
    elif isinstance(state, list):
        children = state
    else:
        # Not a tensor and not a supported container: nothing to print.
        return

    for child in children:
        recursive_print_params(child)

In [87]:
def recursive_print_params(state, key=""):
    """Print the id (in hex), key path, and size of every tensor nested inside ``state``.

    Dicts are traversed by key and lists by index, depth-first and in iteration
    order. Each level appends ``"." + str(key_or_index)`` to ``key``, so with the
    default empty ``key`` the printed paths start with a dot. Any other
    non-tensor object is ignored.
    """
    if isinstance(state, torch.Tensor):
        print(hex(id(state)), key, state.size())
        return

    if isinstance(state, dict):
        labelled_children = state.items()
    elif isinstance(state, list):
        labelled_children = enumerate(state)
    else:
        # Not a tensor and not a supported container: nothing to print.
        return

    for label, child in labelled_children:
        recursive_print_params(child, key + "." + str(label))


In [88]:
# Walk the optimizer's "state" mapping and print each tensor's id, key path, and shape.
recursive_print_params(state)

0x7fdef27b4ae0 .0.step torch.Size([])
0x7fdef27b4630 .0.exp_avg torch.Size([1000, 100])
0x7fdef27b4310 .0.exp_avg_sq torch.Size([1000, 100])
0x7fdef27b4950 .1.step torch.Size([])
0x7fdef27b4810 .1.exp_avg torch.Size([1000])
0x7fdef27b41d0 .1.exp_avg_sq torch.Size([1000])
0x7fdef27b45e0 .2.step torch.Size([])
0x7fdef27b46d0 .2.exp_avg torch.Size([10, 1000])
0x7fdef27b44f0 .2.exp_avg_sq torch.Size([10, 1000])
0x7fdef27b42c0 .3.step torch.Size([])
0x7fdef27b4b30 .3.exp_avg torch.Size([10])
0x7fdef27b4b80 .3.exp_avg_sq torch.Size([10])


In [73]:
# Python object id (in hex) and shape of every model parameter.
for param in model.parameters():
    print(hex(id(param)), param.size())

0x7fdeff48c900 torch.Size([1000, 100])
0x7fdef80ca4a0 torch.Size([1000])
0x7fdef27b49f0 torch.Size([10, 1000])
0x7fdef27b4a40 torch.Size([10])


In [78]:
# Python object id (in hex) and shape of every tensor in the optimizer's first
# parameter group. In the recorded run these ids equal the ones printed for
# model.parameters(), i.e. the optimizer refers to the same parameter objects.
tracked_params = optimizer.param_groups[0]["params"]
for param in tracked_params:
    print(hex(id(param)), param.size())

0x7fdeff48c900 torch.Size([1000, 100])
0x7fdef80ca4a0 torch.Size([1000])
0x7fdef27b49f0 torch.Size([10, 1000])
0x7fdef27b4a40 torch.Size([10])
