In [1]:
import torch
import torch.nn as nn

In [2]:
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.param1 = nn.Parameter(
            torch.tensor([1, 2, 3], dtype=torch.float32)
        )
        self.param2 = torch.tensor([4, 5, 6], dtype=torch.float32)
        self.li = nn.Linear(3, 1)
        
    def forward(self, x):
        x = x * self.param1
        x = x * self.param2
        x = self.li(x)

        return x

In [3]:
model = Model()

In [4]:
for name, param in model.named_parameters():
    print(name, param)

param1 Parameter containing:
tensor([1., 2., 3.], requires_grad=True)
li.weight Parameter containing:
tensor([[-0.4049, -0.3550,  0.5401]], requires_grad=True)
li.bias Parameter containing:
tensor([-0.3106], requires_grad=True)


In [5]:
def train():
    model = Model()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()

    input_data = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)
    output_data = model(input_data)
    target_data = torch.tensor([1], dtype=torch.float32)

    loss = loss_fn(output_data, target_data)

    print("#" * 10 + "Before Training" + "#" * 10)
    print("param1:", model.param1.data)
    print("param2:", model.param2.data)
    print("li:", model.li.weight.data)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print("#" * 10 + "After Training" + "#" * 10)
    print("param1:", model.param1.data)
    print("param2:", model.param2.data)
    print("li:", model.li.weight.data)


if __name__ == "__main__":
    train()

  return F.mse_loss(input, target, reduction=self.reduction)


##########Before Training##########
param1: tensor([1., 2., 3.])
param2: tensor([4., 5., 6.])
li: tensor([[-0.4704,  0.2276,  0.1017]])
##########After Training##########
param1: tensor([2.2196, 1.0129, 2.3369])
param2: tensor([4., 5., 6.])
li: tensor([[ -3.0632,  -8.4470, -19.4589]])


> ```nn.Parameter()```를 통해 선언된 ```param1```은 학습 수행하여 값이 변경되었으나, ```tensor```로 선언된 ```param2```는 학습을 하지 못해서 값이 그대로임.