<a href="https://colab.research.google.com/github/kimgeonhee317/d2l-notes/blob/main/notebook/3_3_Synthetic_Regression_Data.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install d2l==1.0.3

In [2]:
%matplotlib inline
import random
import torch
from d2l import torch as d2l

## 3.3.1 Generating the Dataset

In [6]:
# Build the synthetic regression dataset with weights [2, -3.4] and bias 4.2.
data = d2l.SyntheticRegressionData(
    w=torch.tensor([2, -3.4]),
    b=4.2,
)

In [7]:
# Show the first row of the feature matrix and its corresponding label.
print(
    'feature:', data.X[0],
    '\nlabel:', data.y[0],
)

feature: tensor([-0.4768,  0.4891]) 
label: tensor([1.5746])


## 3.3.2 Reading the Dataset

In [11]:
# The class is referenced through the `d2l` namespace: this notebook only
# imports `d2l`, so the bare name `SyntheticRegressionData` is undefined on a
# fresh kernel and would raise NameError.
@d2l.add_to_class(d2l.SyntheticRegressionData)
def get_dataloader(self, train):
    """Yield minibatches of (features, labels) from the dataset.

    Args:
        train: If truthy, iterate over the first ``self.num_train`` examples
            in a freshly shuffled order; otherwise iterate, in order, over
            the ``self.num_val`` examples that follow them.

    Yields:
        Tuples ``(self.X[idx], self.y[idx])`` where ``idx`` holds at most
        ``self.batch_size`` indices; the final batch may be smaller.
    """
    if train:
        indices = list(range(0, self.num_train))
        # The examples are read in random order
        random.shuffle(indices)
    else:
        indices = list(range(self.num_train, self.num_train+self.num_val))
    for i in range(0, len(indices), self.batch_size):
        # Slicing past the end of the list is safe, so the last batch is
        # simply shorter when the split size is not a multiple of batch_size.
        batch_indices = torch.tensor(indices[i: i+self.batch_size])
        yield self.X[batch_indices], self.y[batch_indices]

In [12]:
# Take the first minibatch from the training loader and report tensor shapes.
X, y = next(iter(data.train_dataloader()))
print(
    'X shape', X.shape,
    '\ny shape:', y.shape,
)

X shape torch.Size([32, 2]) 
y shape: torch.Size([32, 1])


## 3.3.3 Concise Implementation of the Data Loader

In [13]:
@d2l.add_to_class(d2l.DataModule)  #@save
def get_tensorloader(self, tensors, train, indices=slice(0, None)):
    """Wrap a selection of rows from ``tensors`` in a PyTorch DataLoader.

    Args:
        tensors: Iterable of indexable tensors; each one is indexed with
            ``indices`` before being placed in a ``TensorDataset``.
        train: Passed as the loader's ``shuffle`` flag.
        indices: Index applied to every tensor; defaults to all rows.

    Returns:
        A ``torch.utils.data.DataLoader`` with batch size ``self.batch_size``.
    """
    selected = tuple(tensor[indices] for tensor in tensors)
    return torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(*selected),
        self.batch_size,
        shuffle=train,
    )

# The class is referenced through the `d2l` namespace: this notebook only
# imports `d2l`, so the bare name `SyntheticRegressionData` is undefined on a
# fresh kernel and would raise NameError.
@d2l.add_to_class(d2l.SyntheticRegressionData)  #@save
def get_dataloader(self, train):
    """Return a loader over the training or validation rows.

    Args:
        train: If truthy, select rows ``[0, self.num_train)``; otherwise
            select every row from ``self.num_train`` onward. The flag is
            also forwarded to ``self.get_tensorloader``.

    Returns:
        Whatever ``self.get_tensorloader`` returns for ``(self.X, self.y)``
        restricted to the selected rows.
    """
    i = slice(0, self.num_train) if train else slice(self.num_train, None)
    return self.get_tensorloader((self.X, self.y), train, i)

In [15]:
# Take the first minibatch from the training loader and report tensor shapes.
X, y = next(iter(data.train_dataloader()))
print(
    'X shape:', X.shape,
    '\ny shape:', y.shape,
)

X shape: torch.Size([32, 2]) 
y shape: torch.Size([32, 1])


In [16]:
# Number of minibatches in the training loader. NOTE(review): len() needs the
# loader to be a DataLoader rather than a generator, so this presumably relies
# on the concise get_dataloader being the active one — confirm.
len(data.train_dataloader())

32