In [None]:
import cv2
import h5py
import numpy as np
import torch
import torch.utils.data as data
import imgaug.augmenters as iaa
# loader
from dataset.classification.loader import create_validation_split, load_validation_data
# sampler
from dataset.classification.sampler import adopt_sampling

In [None]:
from os.path import join as pj
from os import getcwd as cwd

In [None]:
class args:
    """Notebook-wide settings, used as a plain namespace (attributes are read directly)."""

    # HDF5 file holding the full image array ("X") and label array ("Y"),
    # resolved against the current working directory when this cell runs.
    all_data_path = pj(
        cwd(),
        "data/all_classification_data/classify_insect_std_20200806",
    )

### コード全体

### --- データの構築 ---  
X, Y は学習・テストの両方で共有されるため、外部で直接生成する

In [None]:
# Read the whole image array (X) and label array (Y) into memory.
# Every later cell indexes X and Y with the same index arrays, so they must be aligned.
with h5py.File(args.all_data_path, "r") as f:
    X = f["X"][:]
    Y = f["Y"][:]
assert len(X) == len(Y), "X and Y must be aligned: {} images vs {} labels".format(len(X), len(Y))

In [None]:
# Number of samples per class over the whole dataset (class ids are not needed here).
# Avoid assigning `_`: once user code overwrites it, IPython stops refreshing `_`/`__`/`___`.
ntests = np.unique(Y, return_counts=True)[1]

In [None]:
# Build the cross-validation folds: train_idxs[k] / test_idxs[k] are indexed per fold below.
train_idxs, test_idxs = create_validation_split(
    Y,
    0.2,  # presumably the held-out (test) fraction — confirm in create_validation_split
)

In [None]:
# Index of the cross-validation fold used for the rest of this section.
valid_count = 0
# Reject out-of-range folds explicitly: a negative index would otherwise silently
# select a fold counted from the end instead of raising.
assert 0 <= valid_count < len(train_idxs), "valid_count must be in [0, {})".format(len(train_idxs))

# None as the sampling method — presumably "no resampling" (cf. "RandomSample"/"OverSample" later).
valid_train_idx = adopt_sampling(Y, train_idxs[valid_count], None)
valid_test_idx = test_idxs[valid_count]

In [None]:
xtr, ytr, xte, yte = load_validation_data(X, Y, valid_train_idx, valid_test_idx)
# Images and labels must stay aligned in both splits.
assert len(xtr) == len(ytr), (xtr.shape, ytr.shape)
assert len(xte) == len(yte), (xte.shape, yte.shape)
# Show all four shapes from one cell (the last expression is displayed).
xtr.shape, ytr.shape, xte.shape, yte.shape

### --- データの読み込み ---
主にデータ拡張を適用するためのクラス

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
class insects_dataset(data.Dataset):
    """PyTorch dataset over in-memory insect images with optional imgaug augmentation.

    Each item is ``(image, label)``: ``image`` is a float32, channel-first tensor,
    min-max scaled to [0, 1] per image; ``label`` is ``labels[index]`` unchanged.
    """

    def __init__(self, images, labels, training=False, method_aug=None):
        """
            init function
            Args:
                - images: np.array, insect images; each images[index] is transposed
                  with (2, 0, 1), so it is assumed to be (H, W, C) — TODO confirm
                - labels: np.array, insect labels, aligned with images
                - training: bool; any truthy value (e.g. np.bool_(True)) enables augmentation
                - method_aug: [str, ...], sequence of method name
                    possible choices = [
                        HorizontalFlip, VerticalFlip, Rotate]
                    unknown names are reported and skipped
        """
        self.images = images
        self.labels = labels
        self.training = training
        self.method_aug = method_aug

        # Truthiness instead of `training is True`: `np.bool_(True) is True` is False,
        # which silently disabled augmentation for numpy/config-derived flags.
        if training and method_aug is not None:
            print("augment == method_aug")
            print("---")
            self.aug_seq = self.create_aug_seq()
            print("---")
        else:
            print("augment == None")
            self.aug_seq = None

    def __getitem__(self, index):
        """Return ``(image, label)``; image is a contiguous (C, H, W) float32 tensor in [0, 1]."""
        image = self.images[index]
        # augmentation is re-sampled on every access
        if self.aug_seq is not None:
            image = self.aug_seq(image=image)

        # astype copies, so the in-place normalize below never touches self.images
        image = image.astype("float32")
        # per-image min-max scaling to [0, 1]; dst is the same buffer, so this is in place
        image = cv2.normalize(image, image, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX)

        # HWC -> CHW; ascontiguousarray gives a C-contiguous buffer owned by this item only
        # (the float32 copy above already detached it from self.images)
        image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))

        label = self.labels[index]
        return image, label

    def __len__(self):
        """Number of images (first axis of ``images``)."""
        return self.images.shape[0]

    def create_aug_seq(self):
        """Build an ``iaa.Sequential`` from the names in ``self.method_aug``, in order.

        HorizontalFlip / VerticalFlip use ``Fliplr(0.5)`` / ``Flipud(0.5)``;
        Rotate uses ``Rotate((-90, 90))``. Unknown names are reported and skipped.
        """
        aug_list = []
        for augmentation in self.method_aug:
            if augmentation == "HorizontalFlip":
                print("HorizontalFlip")
                aug_list.append(iaa.Fliplr(0.5))
            elif augmentation == "VerticalFlip":
                print("VerticalFlip")
                aug_list.append(iaa.Flipud(0.5))
            elif augmentation == "Rotate":
                print("Rotate")
                aug_list.append(iaa.Rotate((-90, 90)))
            else:
                # name the offending entry so typos in method_aug are easy to spot
                print("not implemented!: insects_dataset.create_aug_seq ({})".format(augmentation))

        return iaa.Sequential(aug_list)

In [None]:
# Training split with random rotation as the only augmentation.
dataset = insects_dataset(
    images=xtr,
    labels=ytr,
    training=True,
    method_aug=["Rotate"],
)

In [None]:
# Small shuffled batches, loaded in the main process.
data_loader = data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)

In [None]:
# Summarise the labels the loader serves as (class ids, per-class counts)
# instead of dumping the whole label array into the notebook output.
np.unique(data_loader.dataset.labels, return_counts=True)

### 画像、ラベル読み込み

In [None]:
# Read the whole image array (X) and label array (Y) into memory.
# The cells below index X and Y with the same index arrays, so they must be aligned.
with h5py.File(args.all_data_path, "r") as f:
    X = f["X"][:]
    Y = f["Y"][:]
assert len(X) == len(Y), "X and Y must be aligned: {} images vs {} labels".format(len(X), len(Y))

In [None]:
# Shapes of the full image array and the label array, shown together.
X.shape, Y.shape

### クラスごとのサンプル数の取得

In [None]:
# Number of samples per class over the whole dataset, displayed keyed by class id.
# (Named `class_ids` rather than `_`, which IPython reserves for its output history.)
class_ids, ntests = np.unique(Y, return_counts=True)
dict(zip(class_ids.tolist(), ntests.tolist()))

### 各交差検証における学習id、テストidの構築

In [None]:
# Build the cross-validation folds: train_idxs[k] / test_idxs[k] are indexed per fold below.
train_idxs, test_idxs = create_validation_split(
    Y,
    0.2,  # presumably the held-out (test) fraction — confirm in create_validation_split
)

### 学習idにサンプリングを適用

In [None]:
# One index list per sampling strategy, all derived from train_idxs[0].
# The calls are evaluated left to right, in the same order as before.
random_sampled_idx, over_sampled_idx, normal_idx = (
    adopt_sampling(Y, train_idxs[0], "RandomSample"),
    adopt_sampling(Y, train_idxs[0], "OverSample"),
    adopt_sampling(Y, train_idxs[0], None),
)

In [None]:
# Per-class counts of the training indices produced by each sampling strategy.
# Each line is labelled with the strategy that produced it (None = no resampling argument).
for sampling_name, sampled_idx in (
    ("RandomSample", random_sampled_idx),
    ("OverSample", over_sampled_idx),
    ("None", normal_idx),
):
    idx, count = np.unique(Y[np.array(sampled_idx)], return_counts=True)
    print("{}: idx = {}, count = {}".format(sampling_name, idx, count))

In [None]:
# Train on the indices sampled with None; swap None for "RandomSample" or
# "OverSample" to train on a resampled index set instead.
normal_idx = adopt_sampling(Y, train_idxs[0], None)
new_train_idx = normal_idx

### 交差検証データのロード

In [None]:
xtr, ytr, xte, yte = load_validation_data(X, Y, new_train_idx, test_idxs[0])
# Images and labels must stay aligned in both splits.
assert len(xtr) == len(ytr), (xtr.shape, ytr.shape)
assert len(xte) == len(yte), (xte.shape, yte.shape)
# Show all four shapes from one cell (the last expression is displayed).
xtr.shape, ytr.shape, xte.shape, yte.shape