In [5]:
import torch

class PointDataset(torch.utils.data.Dataset):
    def __init__(self, filename):
        self.data = []
        
        with open(filename, 'r') as f:
            for line in f:
                x, y = line.split(" ")
                x, y = float(x), float(y)
                self.data.append((x, y))
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]


ds = PointDataset("dataset1.txt")

(-9.220342264640013, -26.542222567962877)

In [7]:
class LineModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.rand(1))
    
    def forward(self, x):
        return self.w * x
    

model = LineModule()
print(list(model.parameters()))
print(model(torch.tensor([2.0])))

[Parameter containing:
tensor([0.4839], requires_grad=True)]
tensor([0.9679], grad_fn=<MulBackward0>)


In [10]:
from tqdm import trange

ds = PointDataset("dataset1.txt")
model = LineModule()
dl = torch.utils.data.DataLoader(ds, batch_size=8)

optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
loss_fn = torch.nn.MSELoss()

for epoch in trange(1000):
    for batch in dl:
        x, y = batch
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

100%|██████████| 1000/1000 [00:01<00:00, 619.21it/s]


In [11]:
print(model.w)

Parameter containing:
tensor([2.3977], requires_grad=True)
