In [1]:
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
import numpy as np
import math

In [16]:
class WineDataset(Dataset):
    """Map-style dataset backed by a comma-separated wine CSV file.

    The first line of the file is skipped as a header. Column 0 is used as the
    label and all remaining columns as features; everything is loaded as
    float32.

    Args:
        transform: optional callable applied to each ``(features, label)``
            pair in ``__getitem__``.
        path: CSV file to load. Defaults to ``'./data/wine.csv'``, the path
            the notebook has always used.
    """

    def __init__(self, transform=None, path='./data/wine.csv'):
        # ndmin=2 keeps the result 2-D even when the file has a single data
        # row, so the column slicing below cannot fail on a 1-D array.
        xy = np.loadtxt(path, delimiter=',', dtype=np.float32, skiprows=1, ndmin=2)
        # The list index [0] keeps the label column 2-D, shape (n_samples, 1).
        self.X, self.y = xy[:, 1:], xy[:, [0]]
        self.n_samples = xy.shape[0]

        self.transform = transform

    def __getitem__(self, index):
        """Return ``(features, label)`` for one row, passed through ``transform`` if set.

        For an integer ``index`` the returned arrays are NumPy views into
        ``self.X`` / ``self.y``. Transforms must therefore not modify them in
        place, or the stored data changes on every access.
        """
        sample = self.X[index], self.y[index]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    def __len__(self):
        """Number of data rows loaded from the CSV."""
        return self.n_samples

In [23]:
class ToTensor:
    """Turn an ``(inputs, targets)`` pair of NumPy arrays into torch tensors.

    ``torch.from_numpy`` does not copy: each returned tensor shares memory
    with the array it was built from.
    """

    def __call__(self, sample):
        features, labels = sample
        features_t = torch.from_numpy(features)
        labels_t = torch.from_numpy(labels)
        return features_t, labels_t

class MulTransform:
    """Scale the inputs of an ``(inputs, target)`` pair by a constant factor.

    Args:
        factor: multiplier applied element-wise to ``inputs``; ``target`` is
            returned unchanged.
    """

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, sample):
        inputs, target = sample
        # Multiply out of place. `inputs` can alias the dataset's stored array
        # (e.g. a torch.from_numpy view of a NumPy row), so an in-place `*=`
        # would rescale the underlying data again on every access.
        return inputs * self.factor, target

In [27]:
# Per-sample pipeline: NumPy arrays -> tensors, then scale the features by 4.
wine_transform = torchvision.transforms.Compose([ToTensor(), MulTransform(4)])
dataset = WineDataset(transform=wine_transform)

In [28]:
# Fetch the first transformed sample and show its features and label.
features, labels = first_data = dataset[0]
print(features, labels)

tensor([5.6920e+01, 6.8400e+00, 9.7200e+00, 6.2400e+01, 5.0800e+02, 1.1200e+01,
        1.2240e+01, 1.1200e+00, 9.1600e+00, 2.2560e+01, 4.1600e+00, 1.5680e+01,
        4.2600e+03]) tensor([1.])


In [20]:
# Shuffled mini-batches loaded by 2 worker processes. A seeded generator makes
# the shuffle order reproducible across Restart Kernel -> Run All.
# NOTE(review): with the 'spawn' multiprocessing start method (default on
# Windows/macOS), workers may be unable to unpickle classes defined in a
# notebook; use num_workers=0 if the loader fails there.
batch_size = 4
dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    generator=torch.Generator().manual_seed(42),
)

In [21]:
# training loop
# Derive the steps-per-epoch from the loader itself rather than repeating the
# batch-size literal. With drop_last left at its default (False), this equals
# ceil(len(dataset) / batch_size).
num_epochs = 2
total_samples = len(dataset)
n_iterations = len(dataloader)
print(num_epochs, total_samples, n_iterations)

2 178 45


In [22]:
for epoch in range(num_epochs):
    # Steps are counted from 1 so the progress line reads "step 5/45", etc.
    for step, (inputs, labels) in enumerate(dataloader, start=1):
        # Placeholder: forward pass -> backward pass -> update weights.
        if step % 5 == 0:
            print(f'epoch: {epoch + 1}/{num_epochs}, step {step}/{n_iterations}, inputs {inputs.shape}')

epoch: 1/2, step 5/45, inputs torch.Size([4, 13])
epoch: 1/2, step 10/45, inputs torch.Size([4, 13])
epoch: 1/2, step 15/45, inputs torch.Size([4, 13])
epoch: 1/2, step 20/45, inputs torch.Size([4, 13])
epoch: 1/2, step 25/45, inputs torch.Size([4, 13])
epoch: 1/2, step 30/45, inputs torch.Size([4, 13])
epoch: 1/2, step 35/45, inputs torch.Size([4, 13])
epoch: 1/2, step 40/45, inputs torch.Size([4, 13])
epoch: 1/2, step 45/45, inputs torch.Size([2, 13])
epoch: 2/2, step 5/45, inputs torch.Size([4, 13])
epoch: 2/2, step 10/45, inputs torch.Size([4, 13])
epoch: 2/2, step 15/45, inputs torch.Size([4, 13])
epoch: 2/2, step 20/45, inputs torch.Size([4, 13])
epoch: 2/2, step 25/45, inputs torch.Size([4, 13])
epoch: 2/2, step 30/45, inputs torch.Size([4, 13])
epoch: 2/2, step 35/45, inputs torch.Size([4, 13])
epoch: 2/2, step 40/45, inputs torch.Size([4, 13])
epoch: 2/2, step 45/45, inputs torch.Size([2, 13])
