In [1]:
%matplotlib inline

In [2]:
import random
import torch

In [3]:
from d2l import torch as d2l

In [4]:
class SyntheticRegressionData(d2l.DataModule):
    """Synthetic linear-regression dataset: y = X @ w + b + Gaussian noise.

    The first ``num_train`` rows of ``X``/``y`` are the training split and the
    remaining ``num_val`` rows are the validation split.

    Args:
        w: weight tensor of length d; d is the number of features.
        b: scalar bias added to every target.
        noise: standard deviation of the additive Gaussian noise.
        num_train: number of training examples.
        num_val: number of validation examples.
        batch_size: minibatch size used by the data loaders.
    """
    def __init__(self, w, b, noise=0.01, num_train=1000, num_val=1000, batch_size=32):
        super().__init__()
        # Exposes the constructor arguments as attributes (read later as
        # self.num_train, self.num_val, self.batch_size).
        self.save_hyperparameters()
        n = num_train + num_val
        # Features drawn i.i.d. from N(0, 1): shape (n, len(w)).
        self.X = torch.randn(n, len(w))
        # Sampled noise gets its own name so the `noise` hyperparameter (a
        # standard deviation) is not shadowed by a tensor.
        eps = torch.randn(n, 1) * noise
        # Column-shaped targets: (n, len(w)) @ (len(w), 1) -> (n, 1).
        self.y = torch.matmul(self.X, w.reshape((-1, 1))) + b + eps

In [5]:
data = SyntheticRegressionData(w=torch.tensor([2.0, -3.4]), b=4.2)  # true w = [2, -3.4], true b = 4.2

In [6]:
data.X.size()  # (num_train + num_val, len(w))

torch.Size([2000, 2])

In [7]:
data.y.size()  # (num_train + num_val, 1)

torch.Size([2000, 1])

In [8]:
@d2l.add_to_class(SyntheticRegressionData)
def get_dataloader(self, train):
    """Yield (X, y) minibatches from the training or validation split.

    Args:
        train: if True, iterate rows [0, num_train) in shuffled order;
            otherwise iterate rows [num_train, num_train + num_val) in order.

    Yields:
        Tuples (X_batch, y_batch) of at most ``self.batch_size`` rows; the last
        batch may be smaller.
    """
    if train:
        indices = list(range(0, self.num_train))
        # Only the training split is shuffled, so each pass sees a new order.
        random.shuffle(indices)
    else:
        # Validation rows follow the training rows and keep a fixed order.
        indices = list(range(self.num_train, self.num_train + self.num_val))
    for i in range(0, len(indices), self.batch_size):
        batch_indices = torch.tensor(indices[i:i + self.batch_size])
        yield self.X[batch_indices], self.y[batch_indices]

In [9]:
data  # shows the default object repr of the data module instance

<__main__.SyntheticRegressionData at 0x1fad2bf1fd0>

In [10]:
data.train_dataloader()  # a generator: the get_dataloader registered in the previous cell uses `yield`

<generator object get_dataloader at 0x000001FAB0A190B0>

In [11]:
X, y = next(iter(data.train_dataloader()))  # take only the first training minibatch

In [12]:
X.size()  # (batch_size, len(w))

torch.Size([32, 2])

In [13]:
y.size()  # (batch_size, 1)

torch.Size([32, 1])

In [14]:
@d2l.add_to_class(d2l.DataModule)
def get_tensorloader(self, tensors, train, indices=slice(0, None)):
    """Build a DataLoader over the selected rows of several parallel tensors.

    Args:
        tensors: sequence of tensors indexed in parallel, e.g. (X, y).
        train: passed as ``shuffle`` to the DataLoader.
        indices: index or slice applied to every tensor; the default
            ``slice(0, None)`` (i.e. ``[0:]``) keeps all rows.

    Returns:
        A ``torch.utils.data.DataLoader`` with batch size ``self.batch_size``.
    """
    selected = [tensor[indices] for tensor in tensors]
    dataset = torch.utils.data.TensorDataset(*selected)
    return torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=train)

@d2l.add_to_class(SyntheticRegressionData)
def get_dataloader(self, train):
    """Return a DataLoader over the training rows (train=True) or the validation rows.

    Replaces the generator-based get_dataloader from the earlier cell; rows
    [0, num_train) are training data and everything after is validation data.
    """
    if train:
        rows = slice(0, self.num_train)
    else:
        rows = slice(self.num_train, None)
    return self.get_tensorloader((self.X, self.y), train, rows)

In [15]:
X, y = next(iter(data.train_dataloader()))  # first training minibatch, now from the DataLoader-based get_dataloader

In [16]:
X.size()  # (batch_size, len(w))

torch.Size([32, 2])

In [17]:
y.size()  # (batch_size, 1)

torch.Size([32, 1])