# Pytorch

## Ray Tune
___Hyper Parameter Tuning___  
___Using a model that classifies CIFAR-10___

This notebook is based on the following references:
 - _https://pytorch.org/tutorials/beginner/hyperparameter_tuning_tutorial.html_
 - _https://docs.ray.io/en/latest/tune/index.html_

In [None]:
import torch

# Prefer the GPU when CUDA is usable; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# Bare expression so the notebook shows which device was picked.
device

'cuda'

In [None]:
# Base folder for this notebook's files.
# NOTE(review): looks like a Colab Google Drive location; the mount step is not shown here - confirm.
data_path = "/content/drive/MyDrive/Colab Notebooks/Pytorch"

In [None]:
%%capture
# install the package needed
!pip install ray[tune]

In [None]:
import numpy as np
import os
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim

from torch.utils.data import random_split
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# for hyper parameter tuning, only need to import 3 modules
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler

In [18]:
import matplotlib.pyplot as plt

### Data

In [None]:
def load_data(data_dir="./data"):
    """Return the CIFAR-10 ``(trainset, testset)`` pair stored under ``data_dir``.

    Both datasets are created with ``download=True`` and share a single
    transform: conversion to a tensor followed by per-channel normalization
    with mean 0.5 and std 0.5.
    """
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    pipeline = transforms.Compose([transforms.ToTensor(), normalize])

    def _cifar10(is_train):
        # Same root and transform for both splits; only the `train` flag differs.
        return torchvision.datasets.CIFAR10(
            root=data_dir, train=is_train, download=True, transform=pipeline)

    return _cifar10(True), _cifar10(False)

```
transforms.Normalize(  
    mean,  
    std,  
    inplace=False 
)
```
`mean` and `std` are sequences with one value per image channel; each channel is normalized as `(input - mean) / std`.  
Ex) `mean=(0, 0, 0)` supplies one mean for each of the 3 channels.

In [None]:
# Keep the CIFAR-10 files in a "RayTune" subfolder of data_path.
path = f"{data_path}/RayTune"
train_ds, test_ds = load_data(path)
# Bare tuple so the notebook shows the size of each split.
len(train_ds), len(test_ds)

Files already downloaded and verified
Files already downloaded and verified


(50000, 10000)

### Model

In [32]:
class Net(nn.Module):
  """Small convolutional classifier with two tunable hidden-layer widths.

  Two 5x5 convolutions, each followed by ReLU and 2x2 max pooling, feed
  three fully connected layers that end in 10 logits. ``fc_1`` expects
  ``16*5*5`` input features, which is what the convolution stack produces
  for 3x32x32 inputs.

  Args:
    l1: Output width of the first fully connected layer.
    l2: Output width of the second fully connected layer.
  """

  def __init__(self, l1, l2):
    super().__init__()
    self.conv_1 = nn.Conv2d(3, 6, 5)
    self.conv_2 = nn.Conv2d(6, 16, 5)
    self.pool = nn.MaxPool2d(2, 2)

    self.fc_1 = nn.Linear(16*5*5, l1)
    self.fc_2 = nn.Linear(l1, l2)
    self.fc_3 = nn.Linear(l2, 10)

  def forward(self, x):
    """Return the raw class scores, shape ``(batch, 10)``, for batch ``x``."""
    out = self.pool(F.relu(self.conv_1(x)))
    out = self.pool(F.relu(self.conv_2(out)))
    # Keep the batch dimension, flatten everything else for the linear layers.
    out = torch.flatten(out, 1)

    out = F.relu(self.fc_1(out))
    out = F.relu(self.fc_2(out))
    # The output layer is fc_3 (l2 -> 10). Re-applying fc_2 here would feed
    # l2 features into a layer that expects l1 and never produce 10 logits.
    logits = self.fc_3(out)

    return logits

In [4]:
def train_cifar(config, checkpoint=None, data_dir=None):
  """Train ``Net`` on CIFAR-10 for 10 epochs, validating after each epoch.

  The training set returned by ``load_data`` is split 80/20 into training
  and validation subsets. Each epoch runs one SGD pass (momentum 0.9) over
  the training subset and then accumulates loss and accuracy counts over
  the validation subset.

  Args:
    config: Mapping of hyperparameters. Reads ``'l1'`` and ``'l2'`` (widths
      passed to ``Net``), ``'lr'`` (learning rate) and ``'batch_size'``.
    checkpoint: Accepted but not read anywhere in this function body.
    data_dir: Directory handed to ``load_data``.
  """
  net = Net(config['l1'], config['l2'])

  device = 'cuda' if torch.cuda.is_available() else 'cpu'
  # Spread the model over all GPUs when more than one is visible.
  if device == 'cuda' and torch.cuda.device_count() > 1:
    net = nn.DataParallel(net)
  net.to(device)

  criterion = nn.CrossEntropyLoss()
  optimizer = optim.SGD(net.parameters(), lr=config['lr'], momentum=0.9)

  train_ds, test_ds = load_data(data_dir)
  # 80% of the original training set for training, the rest for validation.
  len_train = int(len(train_ds)*0.8)
  train_ds, val_ds = random_split(train_ds, [len_train, len(train_ds)-len_train])

  trainloader = torch.utils.data.DataLoader(
                    train_ds,
                    batch_size=int(config['batch_size']),
                    shuffle=True,
                    num_workers=8
  )
  valloader = torch.utils.data.DataLoader(
                    val_ds,
                    batch_size=int(config['batch_size']),
                    shuffle=True,
                    num_workers=8
  )

  for epoch in range(10):
    running_loss = 0.
    epoch_steps = 0
    for i, data in enumerate(trainloader):
      inputs, labels = data
      inputs, labels = inputs.to(device), labels.to(device)

      optimizer.zero_grad()

      outputs = net(inputs)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()

      running_loss += loss.item()
      epoch_steps += 1
      # Every 2000 mini-batches, print the mean loss of the epoch so far.
      if i % 2000 == 1999:
        print("[%d, %5d] loss: %.3f" % (epoch+1, epoch_steps, running_loss/epoch_steps))

    val_loss = 0.
    val_steps = 0
    total = 0
    correct = 0
    # No gradients are needed while scoring the validation subset.
    with torch.no_grad():
      for inputs, labels in valloader:
        inputs, labels = inputs.to(device), labels.to(device)

        outputs = net(inputs)
        # Index of the highest score per sample is the predicted class.
        _, pred = torch.max(outputs, 1)
        # Compare element-wise: `pred` holds one prediction per sample, so
        # it must not be collapsed with .item() before the comparison.
        correct += (pred == labels).sum().item()
        total += labels.size(0)
        loss = criterion(outputs, labels)
        val_loss += loss.item()
        val_steps += 1
    # NOTE(review): val_loss, val_steps, correct and total are accumulated
    # here but not reported or returned within this function as written.

