# Batch & Mini Batch
## Dataset & DataLoader
- 미니배치를 잘 수행하도록 도와주는 모듈

In [2]:
import torch
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, data_path):
        # 데이터 불러오기
        import numpy as np
        
        with open(data_path, 'r', encoding='utf8') as file:
            data = file.read().splitlines()
            data = np.array([line.split('\t') for line in data],
                           dtype=np.float)
            
        """
        데이터는 1000*4인 2차원 행렬
        첫 3열은 입력 데이터, 마지막 열은 타깃 데이터
        """
        self.x = data[:, :-1]
        self.t = data[:, -1]
        
    def __getitem__(self, index):
        # index에 해당하는 데이터를 (입력,타깃) 형태로 가져옴
        return (self.x[index], self.t[index])
    
    def __len__(self):
        # 데이터의 총 크기 (배치)
        return len(self.x)
    
    def custom_collate_fn(self, data):
        # __getitem__ 으로 가져온 data=(x,t) 데이터를 전처리
        # 여기서는 텐서로 전환
        x, t = list(zip(*data))
        return (torch.FloatTensor(x), torch.LongTensor(t))
    
example_dataset_path = "./data/example_dataset.tsv"

# 데이터셋 선언
custom_dataset = CustomDataset(data_path = example_dataset_path)

# 데이터로더 선언
train_loader = DataLoader(dataset=custom_dataset,
                         batch_size=64, # 미니배치 크기
                         collate_fn=custom_dataset.custom_collate_fn,
                          # 데이터 전처리
                          # collate: 수집하다, 모으다
                         shuffle=True) # 셔플도 가능

# 테스트
for x,y in train_loader:
    print("size of mini-batch x: {}, t: {}".format(x.size(), y.size()))
    break

size of mini-batch x: torch.Size([64, 3]), t: torch.Size([64])


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  dtype=np.float)
  return (torch.FloatTensor(x), torch.LongTensor(t))
  return (torch.FloatTensor(x), torch.LongTensor(t))


In [3]:
train_loader

<torch.utils.data.dataloader.DataLoader at 0x22ea2320730>

In [12]:
import numpy as np

data = np.array([line.split('\t') for line in data], dtype=np.float)

print(data)

[[0.95254204 0.07071162 0.54187905 1.        ]
 [0.81424592 0.99335066 0.12949903 0.        ]
 [0.14244222 0.81085676 0.21760134 0.        ]
 ...
 [0.89683477 0.52585911 0.26526819 0.        ]
 [0.45416235 0.42939549 0.75025141 0.        ]
 [0.63482525 0.88791162 0.86979562 1.        ]]


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  data = np.array([line.split('\t') for line in data], dtype=np.float)


In [13]:
type(data)

numpy.ndarray

In [14]:
len(data)

1000

In [15]:
data.shape

(1000, 4)

In [18]:
data[:, :-1]

array([[0.95254204, 0.07071162, 0.54187905],
       [0.81424592, 0.99335066, 0.12949903],
       [0.14244222, 0.81085676, 0.21760134],
       ...,
       [0.89683477, 0.52585911, 0.26526819],
       [0.45416235, 0.42939549, 0.75025141],
       [0.63482525, 0.88791162, 0.86979562]])

In [19]:
custom_dataset[0]

(array([0.95254204, 0.07071162, 0.54187905]), 1.0)