In [10]:
%matplotlib widget

In [11]:
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from pathlib import Path

# Directory holding the exercise data files (ex1data1.txt, ex1data2.txt).
# The path is relative, so it is resolved against the current working directory.
DATA_DIR = Path('../data')

$$
\hat{y} = wX^T +b
$$

$$
\text{loss} = \frac{1}{m}\sum_{i=1}^{m}(\hat{y}_i-y_i)^2
$$

In [12]:
class Ex1Data1(Dataset):
    """Regression dataset read from a comma-separated text file.

    Every column except the last is stored as the input features (``_x``)
    and the last column as the target (``_y``).  Both are kept as 2-D
    tensors whose dtype follows the parsed NumPy array (float64 with
    ``np.loadtxt``'s defaults), so callers feeding a float32 model need to
    cast, e.g. with ``.float()``.

    Args:
        path: Location of the data file.  When omitted (or falsy),
            ``DATA_DIR / 'ex1data1.txt'`` is used.
    """

    def __init__(self, path=None):
        path = path or DATA_DIR / 'ex1data1.txt'
        # ndmin=2 keeps the array two-dimensional even when the file holds a
        # single row; without it loadtxt squeezes the result to 1-D and the
        # column slicing below raises IndexError.
        a = np.loadtxt(path, delimiter=',', ndmin=2)
        self._x = torch.from_numpy(a[:, :-1])
        self._y = torch.from_numpy(a[:, -1:])

    def __len__(self):
        """Return the number of samples (rows) in the dataset."""
        return self._x.size(0)

    def __getitem__(self, index):
        """Return the ``(features, target)`` pair stored at ``index``."""
        return self._x[index], self._y[index]


class Ex1Data2(Dataset):
    """Multi-feature regression dataset read from a comma-separated text file.

    Every column except the last is stored as the input features (``_x``)
    and the last column as the target (``_y``), both as 2-D tensors.

    Args:
        path: Location of the data file.  When omitted (or falsy),
            ``DATA_DIR / 'ex1data2.txt'`` is used.
        normalize: When true, every column -- the target column included --
            is divided by its L2 norm (``F.normalize`` along dim 0).  Note
            that this is unit-norm scaling, not zero-mean / unit-variance
            standardisation.
    """

    def __init__(self, path=None, normalize=True):
        path = path or DATA_DIR / 'ex1data2.txt'
        # ndmin=2 keeps the array two-dimensional even when the file holds a
        # single row; without it loadtxt squeezes the result to 1-D and the
        # column slicing below raises IndexError.
        data = torch.from_numpy(np.loadtxt(path, delimiter=',', ndmin=2))
        if normalize:
            # Scaling happens before the split, so features and target are
            # rescaled with the same per-column rule.
            data = F.normalize(data, dim=0)
        self._x = data[:, :-1]
        self._y = data[:, -1:]

    def __len__(self):
        """Return the number of samples (rows) in the dataset."""
        return self._x.size(0)

    def __getitem__(self, index):
        """Return the ``(features, target)`` pair stored at ``index``."""
        return self._x[index], self._y[index]

In [13]:
def plot_data1(data):
    """Scatter-plot a dataset's inputs against its targets.

    Args:
        data: Object exposing ``_x`` (drawn on the horizontal axis) and
            ``_y`` (drawn on the vertical axis), such as ``Ex1Data1``.

    Returns:
        The ``(fig, ax)`` pair of the newly created figure, so callers can
        add further artists to the same axes.
    """
    fig, ax = plt.subplots()
    ax.scatter(data._x, data._y, c='red', marker='x')

    # Per the exercise description of ex1data1.txt, the first column (x) is
    # the city population and the last column (y) is the profit; the labels
    # were previously attached to the wrong axes.
    ax.set_xlabel('Population of City in 10,000s')
    ax.set_ylabel('Profit in $10,000s')
    return fig, ax

In [14]:
# Visualise the ex1data1 samples (loaded from the default path) before fitting.
fig, ax = plot_data1(Ex1Data1())

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [15]:
class LinearModel(nn.Module):
    """Linear regression model: a single affine layer.

    Args:
        in_features: Number of input features per sample.
        out_features: Number of predicted values per sample (default 1).
    """

    def __init__(self, in_features, out_features=1):
        super().__init__()
        # Kept under the attribute name ``linear`` so the fitted weight and
        # bias can be read back as ``model.linear.weight`` / ``.bias``.
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Apply the affine layer to ``x`` and return the prediction."""
        return self.linear(x)

In [16]:
# Training setup for the ex1data1 regression: a one-input linear model,
# mean-squared-error loss, and plain SGD with a fixed learning rate of 0.01.
# NOTE(review): no random seed is set in the visible cells, so the initial
# weights (and therefore the printed losses) may differ between runs.
data1 = Ex1Data1()
model = LinearModel(1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

In [17]:
# Cast once, outside the loop: the dataset tensors come from NumPy as
# float64, while the model is called on float32 inputs.
features1 = data1._x.float()
targets1 = data1._y.float()

# Full-batch gradient descent: every epoch uses the whole dataset at once.
for epoch in range(10):
    y_pred = model(features1)
    loss = criterion(y_pred, targets1)
    # The loss is reported before this epoch's parameter update.
    print(epoch, loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0 20.90266990661621
1 15.91232967376709
2 13.825238227844238
3 12.94411563873291
4 12.56398868560791
5 12.392047882080078
6 12.306680679321289
7 12.257369995117188
8 12.223135948181152
9 12.195263862609863


In [26]:
# Evaluate the fitted line y = b + w*x on a grid from 4 up to (not
# including) 24 in steps of 0.1.
x = torch.arange(4, 24, 0.1)
y = model.linear.bias.item() + model.linear.weight.item()*x

fig, ax = plot_data1(Ex1Data1())
ax.plot(x, y, label='Prediction')
# Without a legend the line's label is never drawn.
ax.legend()
fig.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [19]:
# Report the learned weight and bias of the one-input model; .item() works
# here because each parameter holds exactly one element.
print(f'weights = {model.linear.weight.item()}')
print(f'b = {model.linear.bias.item()}')

weights = 0.7669470906257629
b = 0.30579498410224915


In [20]:
?np.float32

[1;31mInit signature:[0m [0mnp[0m[1;33m.[0m[0mfloat32[0m[1;33m([0m[0mself[0m[1;33m,[0m [1;33m/[0m[1;33m,[0m [1;33m*[0m[0margs[0m[1;33m,[0m [1;33m**[0m[0mkwargs[0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m
[1;31mDocstring:[0m     
Single-precision floating-point number type, compatible with C ``float``.

:Character code: ``'f'``
:Canonical name: `numpy.single`
:Alias on this platform: `numpy.float32`: 32-bit-precision floating-point number type: sign bit, 8 bits exponent, 23 bits mantissa.
[1;31mFile:[0m           c:\applications\scoop\apps\python\3.9.1\lib\site-packages\numpy\__init__.py
[1;31mType:[0m           type
[1;31mSubclasses:[0m     


In [21]:
# Training setup for the ex1data2 regression: a two-input linear model,
# mean-squared-error loss, and plain SGD with a fixed learning rate of 0.01.
# Ex1Data2() is built with its default normalize=True, so the model is fit
# on rescaled features and a rescaled target.
# NOTE(review): no random seed is set in the visible cells, so the initial
# weights (and therefore the printed losses) may differ between runs.
data2 = Ex1Data2()
model2 = LinearModel(2)
criterion2 = nn.MSELoss()
optimizer2 = torch.optim.SGD(model2.parameters(), lr=0.01)

In [22]:
# Cast once, outside the loop: the dataset tensors come from NumPy as
# float64, while the model is called on float32 inputs.
features2 = data2._x.float()
targets2 = data2._y.float()

# Full-batch gradient descent: every epoch uses the whole dataset at once.
for epoch in range(10):
    y_pred = model2(features2)
    loss = criterion2(y_pred, targets2)
    # The loss is reported before this epoch's parameter update.
    print(epoch, loss.item())
    optimizer2.zero_grad()
    loss.backward()
    optimizer2.step()

0 0.25977739691734314
1 0.24932648241519928
2 0.2393052875995636
3 0.2296960949897766
4 0.22048193216323853
5 0.21164663136005402
6 0.20317454636096954
7 0.19505079090595245
8 0.18726100027561188
9 0.17979146540164948
