In [2]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split


# 선형 회귀 데이터셋 클래스
class LinearRegressionDataset(Dataset):
  # 샘플 갯수
  # m : 기울기
  # b : 절편
  def __init__(self, N=50, m=-3, b=2, *args, **kwargs):
    # N: number of samples, e.g. 50
    # m: slope
    # b: offset
    super().__init__(*args, **kwargs)

    # 랜덤 입력 생성
    self.x = torch.rand(N, 2)
    # 노이즈 추가
    self.noise = torch.rand(N) * 0.2
    self.m = m
    self.b = b
    # 값 연산
    self.y = (torch.sum(self.x * self.m) + self.b + self.noise).unsqueeze(-1)

  # 데이터 샘플 수
  def __len__(self):
    return len(self.x)
  
  # 인덱스에 해당하는 샘플 반환
  def __getitem__(self, idx):
    return self.x[idx], self.y[idx]

  # 데이터셋 정보를 문자열로 반환
  def __str__(self):
    str = "Data Size: {0}, Input Shape: {1}, Target Shape: {2}".format(
      len(self.x), self.x.shape, self.y.shape
    )
    return str


if __name__ == "__main__":
  # 데이터셋 인스턴스 생성
  linear_regression_dataset = LinearRegressionDataset()
  
  # 데이터셋 정보 출력
  print(linear_regression_dataset)

  print("#" * 50, 1)

  # 데이터셋 샘플 출력
  for idx, sample in enumerate(linear_regression_dataset):
    input, target = sample
    print("{0} - {1}: {2}".format(idx, input, target))

  # 데이터셋 분리 ( 7: 2: 1)
  train_dataset, validation_dataset, test_dataset = random_split(linear_regression_dataset, [0.7, 0.2, 0.1])

  print("#" * 50, 2)

  print(len(train_dataset), len(validation_dataset), len(test_dataset))

  print("#" * 50, 3)

  # 위에서 생성한 클래스에 넣어 배치 단위로 샘플 로드
  train_data_loader = DataLoader(
    dataset=train_dataset,
    batch_size=4,
    shuffle=True
  )

  for idx, batch in enumerate(train_data_loader):
    input, target = batch
    print("{0} - {1}: {2}".format(idx, input, target))


Data Size: 50, Input Shape: torch.Size([50, 2]), Target Shape: torch.Size([50, 1])
################################################## 1
0 - tensor([0.7773, 0.6902]): tensor([-145.0736])
1 - tensor([0.7356, 0.5330]): tensor([-144.9460])
2 - tensor([0.2819, 0.3951]): tensor([-144.9320])
3 - tensor([0.5880, 0.1353]): tensor([-144.9265])
4 - tensor([0.6957, 0.5215]): tensor([-145.0387])
5 - tensor([0.8713, 0.0659]): tensor([-144.9936])
6 - tensor([0.0066, 0.4594]): tensor([-145.0804])
7 - tensor([0.9534, 0.9841]): tensor([-145.0412])
8 - tensor([0.9626, 0.7372]): tensor([-144.9078])
9 - tensor([0.6990, 0.6643]): tensor([-144.9232])
10 - tensor([0.7659, 0.1426]): tensor([-145.0145])
11 - tensor([0.2814, 0.6219]): tensor([-144.9377])
12 - tensor([0.2064, 0.1072]): tensor([-144.9816])
13 - tensor([0.3411, 0.4145]): tensor([-144.9429])
14 - tensor([0.5710, 0.5827]): tensor([-145.0543])
15 - tensor([0.7044, 0.2401]): tensor([-145.0254])
16 - tensor([0.1678, 0.9024]): tensor([-144.9134])
17 - te