In [4]:
import torch
import torch.optim as optim

In [2]:
x_train = torch.FloatTensor(
    [[73, 80, 75], [93, 88, 93], [89, 91, 80], [96, 98, 100], [73, 66, 70]]
)
y_train = torch.FloatTensor([[152], [185], [180], [196], [142]])

x_train.shape, y_train.shape

(torch.Size([5, 3]), torch.Size([5, 1]))

In [3]:
W = torch.zeros((3, 1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)

In [5]:
optimizer = optim.SGD([W, b], lr=1e-5)

In [6]:
nb_epochs = 20

In [8]:
for epoch in range(nb_epochs + 1):
    H = x_train.matmul(W) + b

    cost = torch.mean((H - y_train) ** 2)

    optimizer.zero_grad()
    cost.backward()
    optimizer.step()

    print(
        f"Epoch {epoch:4d}/{nb_epochs} hypothesis: {H.squeeze().detach()}, Cost: {cost.item():.6f}"
    )

Epoch    0/20 hypothesis: tensor([66.7178, 80.1701, 76.1025, 86.0194, 61.1565]), Cost: 9537.694336
Epoch    1/20 hypothesis: tensor([104.5421, 125.6208, 119.2478, 134.7861,  95.8280]), Cost: 3069.590820
Epoch    2/20 hypothesis: tensor([125.9858, 151.3882, 143.7087, 162.4333, 115.4844]), Cost: 990.670288
Epoch    3/20 hypothesis: tensor([138.1429, 165.9963, 157.5768, 178.1071, 126.6283]), Cost: 322.481964
Epoch    4/20 hypothesis: tensor([145.0350, 174.2780, 165.4395, 186.9928, 132.9461]), Cost: 107.717064
Epoch    5/20 hypothesis: tensor([148.9423, 178.9731, 169.8976, 192.0301, 136.5279]), Cost: 38.687401
Epoch    6/20 hypothesis: tensor([151.1574, 181.6347, 172.4254, 194.8856, 138.5585]), Cost: 16.499046
Epoch    7/20 hypothesis: tensor([152.4131, 183.1435, 173.8590, 196.5042, 139.7097]), Cost: 9.365656
Epoch    8/20 hypothesis: tensor([153.1250, 183.9988, 174.6723, 197.4216, 140.3625]), Cost: 7.071105
Epoch    9/20 hypothesis: tensor([153.5285, 184.4835, 175.1338, 197.9415, 140.7325

In [9]:
# torch.no_grad()를 사용해 계산 그래프를 생성하지 않도록 해 성능 향상
with torch.no_grad():
    new_input = torch.FloatTensor([[75, 85, 72]])
    prediction = new_input.matmul(W) + b
    print(f"new_input: {new_input.squeeze().tolist()} prediction: {prediction.item()}")

new_input: [75.0, 85.0, 72.0] prediction: 156.80615234375
