# Multi GPU 사용하기
* DataParallel 을 사용하여 여러 GPU를 사용하는 방법
* GPU가 여러 개 있을 때, 그중 일부만 선택하여 학습하는 방법
* 핵심 :
  * `os.environ['CUDA_VISIBLE_DEVICES'] = "0,1,2"` (CUDA가 초기화되기 전에 설정해야 적용됨)
  * `model = nn.DataParallel(model)`

## st1. params setting

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import os

# Sizes for the dummy data set and the single-layer model.
input_size, output_size = 5, 5
batch_size, data_size = 30, 100

# Single-device alternative:
# device = torch.device("cuda:6" if torch.cuda.is_available() else "cpu")
# To use only a subset of the installed GPUs, restrict which ones CUDA can see
# (here physical GPUs 6 and 7). CUDA reads this variable when it initializes,
# so it has to be set before the first CUDA call in this process.
os.environ.update({'CUDA_VISIBLE_DEVICES': "6, 7"})

## st2. dataset (dummy)

In [2]:
class RandomDataset(Dataset):
    """Map-style dataset of `length` random vectors, each `size` wide.

    All samples are drawn once from a standard normal distribution
    (``torch.randn``) when the dataset is built, so repeated passes over the
    dataset return the same values.
    """

    def __init__(self, size, length):
        # Pre-generated samples, one per row: tensor of shape (length, size).
        self.data = torch.randn(length, size)
        self.len = length

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        # Row `index` of the pre-generated tensor.
        sample = self.data[index]
        return sample

# Dummy data wrapped in a shuffling DataLoader that yields batches of
# `batch_size` rows (the final batch may be smaller).
dataset = RandomDataset(input_size, data_size)
rand_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

## st3. model setting

In [3]:
class Model(nn.Module):
    """A single fully connected layer that prints the shapes it processes.

    The prints in ``forward`` show the batch size each call receives. This
    makes it visible how ``nn.DataParallel`` hands each replica only part of
    the batch.
    """

    def __init__(self, input_size, output_size):
        super(Model, self).__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, input):
        # NOTE: `input` shadows the builtin, but it is kept because renaming
        # it would change forward's keyword interface.
        result = self.fc(input)
        print(f'Model input size {input.size()}')
        print(f'output size {result.size()}')
        return result

**multi GPU cuda setting!**

In [4]:
model = Model(input_size, output_size)
if torch.cuda.device_count() > 1:
    print(f"GPUs count : {torch.cuda.device_count()}")
    # Replicate the module on every visible GPU; each forward call splits the
    # input batch across the replicas and gathers the results.
    model = nn.DataParallel(model)

# Fall back to the CPU when no GPU is visible instead of failing in .cuda().
# With a GPU this places the model on the current CUDA device, as .cuda() did.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

GPUs count : 2


DataParallel(
  (module): Model(
    (fc): Linear(in_features=5, out_features=5, bias=True)
  )
)

## st4. training loop (forward pass only)

In [6]:
# Send every batch to the device that holds the model's parameters. This is
# the first CUDA device after .cuda()/.to("cuda"), or the CPU otherwise, so
# the loop does not hard-code CUDA.
target_device = next(model.parameters()).device
for data in rand_loader:
    # Named `inputs` (not `input`) so the builtin input() stays usable in
    # this notebook session.
    inputs = data.to(target_device)

    outputs = model(inputs)

    # The replicas' partial outputs come back merged into one full-size batch.
    print("Outside: input size", inputs.size(),
          "output_size", outputs.size())

Model input size torch.Size([15, 5])
output size torch.Size([15, 5])
Model input size torch.Size([15, 5])
output size torch.Size([15, 5])
Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 5])
Model input size torch.Size([15, 5])
output size torch.Size([15, 5])
Model input size torch.Size([15, 5])
output size torch.Size([15, 5])
Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 5])
Model input size torch.Size([15, 5])
output size torch.Size([15, 5])
Model input size torch.Size([15, 5])
output size torch.Size([15, 5])
Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 5])
Model input size torch.Size([5, 5])
output size torch.Size([5, 5])
Model input size torch.Size([5, 5])
output size torch.Size([5, 5])
Outside: input size torch.Size([10, 5]) output_size torch.Size([10, 5])
