<a href="https://colab.research.google.com/github/hukim1112/one-day-DL/blob/main/(pytorch)flower_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 설정

In [None]:
import numpy as np
import os
import PIL
import PIL.Image
import tensorflow as tf

from matplotlib import pyplot as plt

In [None]:
# Show the installed TensorFlow version; tf is used below to download the dataset.
print(tf.__version__)

### 꽃 데이터세트 다운로드하기

이 튜토리얼에서는 수천 장의 꽃 사진 데이터세트를 사용합니다. 꽃 데이터세트에는 클래스당 하나씩 5개의 하위 디렉토리가 있습니다.

```
flower_photos/
  daisy/
  dandelion/
  roses/
  sunflowers/
  tulips/
```

참고: 모든 이미지에는 CC-BY 라이선스가 있으며 크리에이터는 LICENSE.txt 파일에 나열됩니다.

In [None]:
import pathlib
# Archive containing the flower photos.
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
# Download the archive and extract it (untar=True).
# NOTE(review): the directory layout returned by get_file may differ between
# Keras versions — confirm that data_dir directly contains the class folders.
data_dir = tf.keras.utils.get_file(origin=dataset_url,
                                   fname='flower_photos',
                                   untar=True)
# Wrap the returned path in pathlib.Path so later cells can call glob() on it.
data_dir = pathlib.Path(data_dir)

다운로드한 후 (218MB), 이제 꽃 사진의 사본을 사용할 수 있습니다. 총 3670개의 이미지가 있습니다.

In [None]:
# Count the .jpg files located exactly one directory below data_dir.
jpg_paths = data_dir.glob('*/*.jpg')
image_count = sum(1 for _ in jpg_paths)
print(image_count)

In [None]:
# Path where the dataset was stored
data_dir

In [None]:
# 리눅스 명령어를 통해 content 폴더 아래로 복사
!cp -r /root/.keras/datasets/flower_photos /content/

각 디렉토리에는 해당 유형의 꽃 이미지가 포함되어 있습니다.

In [None]:
# Gather every entry of the roses folder and display the first one.
roses = [*data_dir.glob('roses/*')]
first_rose_path = str(roses[0])
PIL.Image.open(first_rose_path)

In [None]:
# Gather every entry of the sunflowers folder and display the first one.
sunflowers = [*data_dir.glob('sunflowers/*')]
first_sunflower_path = str(sunflowers[0])
PIL.Image.open(first_sunflower_path)

## pytorch 데이터 파이프라인

미션 : 아래 튜토리얼을 참고해 플라워 데이터를 위한 커스텀 데이터셋 클래스를 작성해보세요.

https://tutorials.pytorch.kr/beginner/basics/data_tutorial.html#id9

### 플라워 데이터셋 클래스 정의

In [None]:
import os
from os.path import join
def get_subdir_files(root):
  """Collect the files stored one level below ``root``, grouped by sub-directory.

  Every immediate sub-directory of ``root`` is treated as one category; plain
  files that sit directly in ``root`` are ignored.

  Args:
    root: Path of the directory whose sub-directories hold the files.

  Returns:
    A ``(files, categories)`` tuple. ``categories`` is the alphabetically
    sorted list of sub-directory names. ``files`` is the list of file paths
    (``root/category/name``), ordered by category and then by file name.
  """
  # os.listdir() returns entries in arbitrary order. Sorting makes the category
  # order (and therefore any index derived from it) stable across runs/machines.
  categories = sorted(sub for sub in os.listdir(root) if os.path.isdir(join(root, sub)))
  files = []
  for category in categories:
    sub = join(root, category)
    # Keep regular files only so nested directories never end up in the file list.
    files += [join(sub, name) for name in sorted(os.listdir(sub))
              if os.path.isfile(join(sub, name))]
  return files, categories

In [None]:
# List every file under the copied dataset together with the category (folder) names.
filelist, categories = get_subdir_files("/content/flower_photos")
# Show a few paths and the categories as a quick check.
print(filelist[:5])
print(categories)

In [None]:
from torch.utils.data import Dataset
from torchvision.io import read_image
from matplotlib import pyplot as plt
class FlowerDataset(Dataset):
    """Map-style dataset over image files laid out as ``.../<label>/<file>``.

    The label of a sample is the name of the directory that directly contains
    the image file.

    Args:
        filelist: Sequence of image paths (``str`` or ``os.PathLike``).
        transform: Optional callable applied to the loaded image.
        target_transform: Optional callable applied to the label string.
    """

    def __init__(self, filelist, transform=None, target_transform=None):
        self.filelist = filelist
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.filelist)

    def __getitem__(self, idx):
        """Load the image at ``idx`` and return ``(image, label)``."""
        # os.fspath accepts both plain strings and pathlib.Path entries.
        img_path = os.fspath(self.filelist[idx])
        image = read_image(img_path)
        # Name of the parent directory; os.path copes with the platform's
        # separator instead of assuming '/'.
        label = os.path.basename(os.path.dirname(img_path))
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

In [None]:
# Build a dataset without any transforms and pull its first sample.
flower_dataset = FlowerDataset(filelist)
sample_iter = iter(flower_dataset)
x, y = next(sample_iter)

print(categories)
# Move the first axis last before handing the image to imshow.
image_for_plot = x.permute(1,2,0)
plt.title(f"class : {y}")
plt.imshow(image_for_plot)

### Dataset split

In [None]:
import random

# Fixed seed so the train/validation/test split is reproducible.
SPLIT_SEED = 42

# Sort first, then shuffle with a dedicated seeded generator. The result no longer
# depends on the incoming order, so re-running this cell (or restarting the
# kernel) always yields the same split. The list is still shuffled in place
# because later cells slice `filelist` directly.
filelist.sort()
random.Random(SPLIT_SEED).shuffle(filelist)
dataset_size = len(filelist)

# 80% train, 10% validation, and whatever remains for test.
train_size = int(dataset_size * 0.8)
validation_size = int(dataset_size * 0.1)
test_size = dataset_size - train_size - validation_size

### 데이터전처리 및 증강

In [None]:
from torchvision import transforms

In [None]:
# Target size shared by every pipeline.
IMAGE_SIZE = (128, 128)

# Training pipeline: resize, light colour jitter and random horizontal flip.
train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(IMAGE_SIZE),
    transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# Evaluation pipeline: resize only, no random augmentation.
eval_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

data_transforms = {'train': train_transform, 'test': eval_transform}

def target_transform(label):
  """Return the position of ``label`` in the module-level ``categories`` list.

  Raises:
    ValueError: If ``label`` is not present in ``categories``.
  """
  class_index = categories.index(label)
  return class_index

# Cut the file list into consecutive train / validation / test ranges.
# Only the training set uses the augmenting 'train' pipeline.
validation_end = train_size + validation_size

train_dataset = FlowerDataset(
    filelist[:train_size],
    transform=data_transforms['train'],
    target_transform=target_transform,
)
validation_dataset = FlowerDataset(
    filelist[train_size:validation_end],
    transform=data_transforms['test'],
    target_transform=target_transform,
)
test_dataset = FlowerDataset(
    filelist[validation_end:],
    transform=data_transforms['test'],
    target_transform=target_transform,
)

In [None]:
# Report the number of samples in the train, validation and test datasets.
print(len(train_dataset), len(validation_dataset), len(test_dataset))

In [None]:
# Pull the first transformed training sample and plot it with its class index.
train_iter = iter(train_dataset)
x, y = next(train_iter)

# Move the first axis last before handing the image to imshow.
image_for_plot = x.permute(1,2,0)
plt.title(f"class : {y}")
plt.imshow(image_for_plot)

### 데이터 로더

In [None]:
from torch.utils.data import DataLoader

# Batch size shared by all three loaders.
BATCH_SIZE = 64

# Only the training loader shuffles; validation and test keep their order.
train_dataloader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
validation_dataloader = DataLoader(
    validation_dataset, batch_size=BATCH_SIZE, shuffle=False
)
test_dataloader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False
)

In [None]:
# Fetch one training batch and report the shapes of the images and their labels.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
# Print the shape, as the message says, rather than the label values themselves.
print(f"Labels batch shape: {train_labels.size()}")

## pytorch 모델 구현

학습하고자 하는 CNN 모델을 구축해봅시다.

In [None]:
# write your code here

## pytorch 학습


In [None]:
# write your code here