In [1]:
import pandas as pd
from glob import glob
import os
import re
import cv2
import numpy as np
import math

In [6]:
import numpy as np
import cv2
from imblearn.over_sampling import SMOTE
import albumentations as A
import os
from glob import glob
from sklearn.preprocessing import LabelEncoder

# Read one frame from disk and resize it
def read_and_resize_frame(path, target_size=(224, 224)):
    """Load an image with OpenCV and resize it to ``target_size``.

    Args:
        path: Path of the image file to read.
        target_size: ``(width, height)`` tuple handed to ``cv2.resize``.

    Returns:
        The resized frame. When ``cv2.imread`` returns ``None`` (file missing or
        not decodable) an all-zero ``uint8`` frame of shape
        ``(height, width, 3)`` is returned instead so callers can still stack
        the frames of a sequence.
    """
    frame = cv2.imread(path)
    if frame is not None:
        return cv2.resize(frame, target_size)
    # cv2.resize takes (width, height) while NumPy shapes are (rows, cols):
    # swap the two so the placeholder matches a successfully resized frame
    # even when target_size is not square.
    width, height = target_size
    return np.zeros((height, width, 3), dtype=np.uint8)

# Read every frame stored in one directory
def read_frames_from_directory(directory, max_frames=None, target_size=(224, 224)):
    """Load the files of ``directory`` as one stacked array of frames.

    Args:
        directory: Folder whose entries are read as frames.
        max_frames: When truthy, only the first ``max_frames`` paths are read
            (``None`` or ``0`` means no limit).
        target_size: ``(width, height)`` forwarded to ``read_and_resize_frame``.

    Returns:
        ``np.ndarray`` with one entry per file read. An empty directory yields
        an empty ``uint8`` array of shape ``(0, height, width, 3)`` so that
        downstream padding still sees the per-frame shape.
    """
    import re  # local import: this cell does not import re itself

    def _natural_key(file_path):
        # re.split with a capturing group alternates text / digit runs, so
        # odd positions are always numbers and keys compare type-safely.
        tokens = re.split(r'(\d+)', file_path)
        return [int(tok) if idx % 2 else tok for idx, tok in enumerate(tokens)]

    # Natural ordering keeps frame_2 ahead of frame_10; plain string sorting
    # would scramble the temporal order of unpadded frame numbers.
    image_paths = sorted(glob(os.path.join(directory, '*')), key=_natural_key)
    if max_frames:
        image_paths = image_paths[:max_frames]
    if not image_paths:
        width, height = target_size
        return np.zeros((0, height, width, 3), dtype=np.uint8)
    frames = [read_and_resize_frame(path, target_size) for path in image_paths]
    return np.array(frames)

# Bring a stack of frames to a fixed length
def normalize_frames(frames, target_frames):
    """Truncate or zero-pad ``frames`` along axis 0 to ``target_frames`` entries.

    Args:
        frames: NumPy array whose first axis indexes the frames.
        target_frames: Number of frames the result must contain.

    Returns:
        The first ``target_frames`` frames when there are enough of them,
        otherwise ``frames`` followed by all-zero frames of the same
        per-frame shape and dtype.
    """
    missing = target_frames - len(frames)
    if missing <= 0:
        return frames[:target_frames]
    blank = np.zeros((missing,) + tuple(frames.shape[1:]), dtype=frames.dtype)
    return np.vstack((frames, blank))

# Assume X is a list of paths to the directories that hold the frames and y holds the matching labels
# X: List of strings (paths to directories)
# y: numpy array of labels

# Determine the longest sequence (in frames) and the frame size.
# NOTE(review): X and y are not defined in this cell; they have to come from
# another cell that was executed earlier in the kernel.
max_frames = max(len(glob(os.path.join(directory, '*'))) for directory in X)
target_size = (224, 224)  # (width, height) given to cv2.resize; can be adjusted

# Read every sequence and pad it to max_frames
X_processed = []
for directory in X:
    frames = read_frames_from_directory(directory, max_frames=max_frames, target_size=target_size)
    normalized_frames = normalize_frames(frames, max_frames)
    X_processed.append(normalized_frames)

X_processed = np.array(X_processed)

print("Shape của X_processed:", X_processed.shape)

# Flatten each sequence into one row so that SMOTE can process it
X_reshaped = X_processed.reshape(X_processed.shape[0], -1)

# Step 1: apply SMOTE
smote = SMOTE(random_state=42)
X_resampled, y_resampled = smote.fit_resample(X_reshaped, y)

# Restore the per-sequence layout after SMOTE. The shape is taken from
# X_processed instead of being rebuilt from target_size: target_size is
# (width, height) whereas the arrays are laid out as (height, width), so the
# two only coincide for square frames.
X_resampled = X_resampled.reshape(-1, *X_processed.shape[1:])

# Step 2: data augmentation for the minority class (class 0)
# At most one blur flavour is picked when this group fires.
_blur_choices = A.OneOf(
    [
        A.MotionBlur(p=0.2),
        A.MedianBlur(blur_limit=3, p=0.1),
        A.Blur(blur_limit=3, p=0.1),
    ],
    p=0.2,
)

# At most one geometric distortion is picked when this group fires.
_distortion_choices = A.OneOf(
    [
        A.OpticalDistortion(p=0.3),
        A.GridDistortion(p=0.1),
        A.ElasticTransform(p=0.1),
    ],
    p=0.2,
)

augmentation = A.Compose(
    [
        A.RandomRotate90(),
        A.Flip(),
        A.RandomBrightnessContrast(p=0.2),
        A.GaussNoise(p=0.2),
        _blur_choices,
        _distortion_choices,
    ]
)

def augment_minority_class(X, y, num_augmentations=1, minority_label=0):
    """Create augmented copies of every sequence labelled ``minority_label``.

    Args:
        X: Iterable of sequences; each sequence is an iterable of image frames.
        y: Labels aligned with ``X``.
        num_augmentations: Number of augmented copies produced per matching
            sequence.
        minority_label: Label whose sequences are augmented (defaults to 0,
            the value that used to be hard-coded).

    Returns:
        ``(augmented_sequences, augmented_labels)`` as NumPy arrays. When no
        sequence matches, both arrays are empty but keep the per-sequence
        shape and the dtypes of the inputs, so they can still be concatenated
        with ``X`` and ``y``.
    """
    augmented_sequences = []
    augmented_labels = []

    for sequence, label in zip(X, y):
        if label != minority_label:
            continue
        for _ in range(num_augmentations):
            # NOTE(review): the module-level `augmentation` pipeline is invoked
            # once per frame, so each frame of a sequence presumably draws its
            # own random transform - confirm this is wanted for video data.
            augmented_frames = [augmentation(image=frame)['image'] for frame in sequence]
            augmented_sequences.append(np.array(augmented_frames))
            augmented_labels.append(label)

    if not augmented_sequences:
        # np.array([]) would be a 1-D float array that cannot be concatenated
        # with the 5-D image data or the integer labels.
        X_arr, y_arr = np.asarray(X), np.asarray(y)
        return (
            np.empty((0,) + X_arr.shape[1:], dtype=X_arr.dtype),
            np.empty((0,), dtype=y_arr.dtype),
        )

    return np.array(augmented_sequences), np.array(augmented_labels)

# Apply data augmentation.
# NOTE(review): SMOTE has already resampled the data before this step; the
# output recorded for this cell shows 166 samples of class 0 against 83 of
# class 1, i.e. augmenting class 0 again leaves the set imbalanced the other
# way round. Confirm that this is intended.
X_aug, y_aug = augment_minority_class(X_resampled, y_resampled)

# Combine the resampled data with the augmented data
X_final = np.concatenate([X_resampled, X_aug])
y_final = np.concatenate([y_resampled, y_aug])

# Shuffle the data with a fixed seed (same value as the SMOTE random_state) so
# that re-running the cell reproduces the same order
indices = np.random.default_rng(42).permutation(len(X_final))
X_final = X_final[indices]
y_final = y_final[indices]

print("Shape của X sau khi xử lý:", X_final.shape)
print("Shape của y sau khi xử lý:", y_final.shape)
print("Số lượng mẫu lớp 0:", np.sum(y_final == 0))
print("Số lượng mẫu lớp 1:", np.sum(y_final == 1))

# Lưu các chuỗi frame đã xử lý (tuỳ chọn)
# output_dir = "processed_sequences"
# os.makedirs(output_dir, exist_ok=True)

# for i, (sequence, label) in enumerate(zip(X_final, y_final)):
#     if i < len(X):  # Dữ liệu gốc
#         subfolder = os.path.basename(os.path.dirname(X[i]))
#     else:  # Dữ liệu được augment
#         subfolder = f"augmented_{i - len(X)}"
    
#     sequence_dir = os.path.join(output_dir, subfolder)
#     os.makedirs(sequence_dir, exist_ok=True)
    
#     for j, frame in enumerate(sequence):
#         cv2.imwrite(os.path.join(sequence_dir, f"frame_{j}.jpg"), frame)
    
#     # Lưu nhãn
#     with open(os.path.join(sequence_dir, "label.txt"), "w") as f:
#         f.write(str(label))

# print(f"Đã lưu các chuỗi frame đã xử lý vào thư mục {output_dir}")

Shape của X_processed: (111, 172, 224, 224, 3)
Shape của X sau khi xử lý: (249, 172, 224, 224, 3)
Shape của y sau khi xử lý: (249,)
Số lượng mẫu lớp 0: 166
Số lượng mẫu lớp 1: 83


In [None]:
# This script aims to create augmented images from one image to create a larger dataset for our cnn model
# The augmentation this script will perform on each object is 
# orig_img,grayscaled_image,random_rotation_transformation_45_image,random_rotation_transformation_65_image,random_rotation_transformation_85_image,gausian_blurred_image_13_image,gausian_blurred_image_56_image,gausian_image_3,gausian_image_6,gausian_image_9,colour_jitter_image_1,colour_jitter_image_2,colour_jitter_image_3

# Usage: call creating_file_with_augmented_images with the path of the dataset and the path of the folder in which the augmented images should be stored.

import PIL
import torch 
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import sys
import torchvision.transforms as T
import os

#torch.transforms

# Grayscale, written out to three channels
grayscale_transform = T.Grayscale(num_output_channels=3)

# Random rotations with increasing maximum angle (degrees)
random_rotation_transformation_45 = T.RandomRotation(degrees=45)
random_rotation_transformation_65 = T.RandomRotation(degrees=65)
random_rotation_transformation_85 = T.RandomRotation(degrees=85)

# Gaussian blur: same 7x13 kernel, two different sigma ranges
gausian_blur_transformation_13 = T.GaussianBlur(kernel_size=(7, 13), sigma=(6, 9))
gausian_blur_transformation_56 = T.GaussianBlur(kernel_size=(7, 13), sigma=(5, 8))

# Gaussian noise

def addnoise(input_image, noise_factor = 0.3):
    """Return a copy of ``input_image`` with additive Gaussian noise.

    Args:
        input_image: Image accepted by ``T.ToTensor()``.
        noise_factor: Standard deviation of the noise, on the [0, 1] scale of
            the tensor that ``T.ToTensor()`` produces.

    Returns:
        The noisy image converted back with ``T.ToPILImage()``.
    """
    inputs = T.ToTensor()(input_image)
    # randn_like draws zero-mean Gaussian samples. The previous rand_like drew
    # uniform values in [0, 1), which are never negative and therefore only
    # brightened the image instead of adding Gaussian noise.
    noisy = inputs + torch.randn_like(inputs) * noise_factor
    # Keep the values inside the valid range before converting back
    noisy = torch.clip(noisy, 0, 1.)
    return T.ToPILImage()(noisy)

# Colour jitter presets. Some arguments are (min, max) tuples and others are
# plain scalars: the original parentheses around single numbers did not
# create tuples, so they are written as scalars here.
colour_jitter_transformation_1 = T.ColorJitter(
    brightness=(0.5, 1.5),
    contrast=3,
    saturation=(0.3, 1.5),
    hue=(-0.1, 0.1),
)

colour_jitter_transformation_2 = T.ColorJitter(
    brightness=0.7,
    contrast=6,
    saturation=0.9,
    hue=(-0.1, 0.1),
)

colour_jitter_transformation_3 = T.ColorJitter(
    brightness=(0.5, 1.5),
    contrast=2,
    saturation=1.4,
    hue=(-0.1, 0.5),
)

# Random colour inversion (defined here; not used by augment_image)
random_invert_transform = T.RandomInvert()

# Main function: applies the transforms defined above to one image and returns
# the original followed by 12 augmented variants (13 images in total)

def augment_image(img_path):
    """Open ``img_path`` and build every augmented variant of it.

    Args:
        img_path: Path of the source image.

    Returns:
        A list of 13 images: the original, grayscale, rotations (45, 65, 85),
        two Gaussian blurs, three noise levels (default, 0.6, 0.9) and three
        colour-jitter presets, in that order.
    """
    source = Image.open(Path(img_path))

    gray = grayscale_transform(source)

    # The rotations are evaluated in the order 45, 85, 65 (as before) so the
    # sequence of random draws is unchanged; the return order is 45, 65, 85.
    rot_45 = random_rotation_transformation_45(source)
    rot_85 = random_rotation_transformation_85(source)
    rot_65 = random_rotation_transformation_65(source)

    blur_13 = gausian_blur_transformation_13(source)
    blur_56 = gausian_blur_transformation_56(source)

    # Three noise strengths; the first call relies on addnoise's default factor
    noise_3 = addnoise(source)
    noise_6 = addnoise(source, 0.6)
    noise_9 = addnoise(source, 0.9)

    jitter_1 = colour_jitter_transformation_1(source)
    jitter_2 = colour_jitter_transformation_2(source)
    jitter_3 = colour_jitter_transformation_3(source)

    return [
        source,
        gray,
        rot_45,
        rot_65,
        rot_85,
        blur_13,
        blur_56,
        noise_3,
        noise_6,
        noise_9,
        jitter_1,
        jitter_2,
        jitter_3,
    ]

#augmented_images = augment_image(orig_img_path)

def creating_file_with_augmented_images(file_path_master_dataset,file_path_augmented_images):
    """Write augmented copies of every image of a class-per-folder dataset.

    For each sub-folder of ``file_path_master_dataset`` a folder of the same
    name is created under ``file_path_augmented_images`` and every image
    returned by ``augment_image`` is saved there as
    ``{class_no}_{image_no}_{variant_no}_{original file name}``.

    Args:
        file_path_master_dataset: Folder containing one sub-folder per class.
        file_path_augmented_images: Folder that receives the augmented images.

    Returns:
        None. The function only writes files.
    """
    # class_number counts every directory entry (as before) so file names stay
    # the same for the class folders.
    for class_number, class_name in enumerate(os.listdir(file_path_master_dataset), start=1):
        class_source = os.path.join(file_path_master_dataset, class_name)
        if not os.path.isdir(class_source):
            # Stray files such as .DS_Store are not class folders
            continue

        class_target = os.path.join(file_path_augmented_images, class_name)
        # exist_ok lets the function be re-run on an existing output folder
        os.makedirs(class_target, exist_ok=True)

        image_number = 0
        for image in os.listdir(class_source):
            image_source = os.path.join(class_source, image)
            if not os.path.isfile(image_source):
                continue
            image_number += 1
            for variant_number, augmented_image in enumerate(augment_image(image_source), start=1):
                augmented_image.save(
                    os.path.join(class_target, f"{class_number}_{image_number}_{variant_number}_{image}")
                )

"""images = augment_image("dog.png")

for element in images:
    element.show()"""

# Master (source) dataset location
master_dataset = "/Users/software/Desktop/sem_6/Hieroglyphics_nlp/Code_image_augmentation/Master_dataset"

# Folder that receives the augmented images
augmented_dataset = "/Users/software/Desktop/sem_6/Hieroglyphics_nlp/Code_image_augmentation/augmented_images_dataset"

# Run the augmentation over the whole dataset
creating_file_with_augmented_images(
    file_path_master_dataset=master_dataset,
    file_path_augmented_images=augmented_dataset,
)

In [2]:
import pandas as pd
# Load the sample table; later cells read its 'img_path' and 'label' columns.
df = pd.read_csv('data.csv')

In [3]:
# Plain Python lists of the sample directories and of their class labels
paths = list(df['img_path'])
label = list(df['label'])

In [4]:
# Display the loaded table (bare expression: rendered as the cell output).
df

Unnamed: 0,img_path,label
0,img/Thang12/10_220126_well02_zid99_0,1
1,img/Thang12/11_220276A_well10_zid99_1,1
2,img/Thang12/12_220450_well03_zid99_2,1
3,img/Thang12/13_220448_well02_zid99_3,1
4,img/Thang12/14_220431_well02_zid99_4,1
...,...,...
106,img/Thang4/7_230026_well07_zid99_110,1
107,img/Thang4/7_230122_well07_zid99_111,0
108,img/Thang4/8_230119_well08_zid99_112,1
109,img/Thang4/9_220405_well09_zid99_113,0


In [5]:
# Class balance of the table; the recorded output is 83 samples of label 1 and 28 of label 0.
df['label'].value_counts()

label
1    83
0    28
Name: count, dtype: int64

In [26]:
# Positions of the label-0 samples, then the directories found at those positions
list_idx = [position for position, value in enumerate(label) if value == 0]
path_imbalanced = [paths[position] for position in list_idx]

In [31]:
def custom_sort_key(filename):
    """Sort key that orders ``frame_<n>`` names by their number.

    Args:
        filename: File name or path to build the key for.

    Returns:
        A tuple that is always comparable with the key of any other name:
        ``(0, n, '')`` when the name contains ``frame_<n>`` and
        ``(1, 0, filename)`` otherwise. Numbered frames therefore sort
        numerically and come first; all other names follow alphabetically.
        (Returning a bare int for some names and a str for others made
        ``sorted`` raise ``TypeError`` on mixed inputs.)
    """
    match = re.search(r'frame_(\d+)', filename)
    if match:
        return (0, int(match.group(1)), '')
    return (1, 0, filename)

In [40]:
def aug_func1(img_path, select: int):
    """Read an image and apply the augmentation numbered ``select`` (0-9).

    Args:
        img_path: Path of the image to read with ``cv2.imread``.
        select: 0 horizontal flip, 1 rotate 90 degrees clockwise, 2 Gaussian
            noise, 3 brightness +50, 4 Gaussian blur, 5 centre crop,
            6 grayscale, 7 rotation by 30 degrees, 8 contrast x1.5,
            9 salt-and-pepper noise. Any other value returns the image as read.

    Returns:
        The augmented image. Option 5 returns a smaller image and option 6 a
        single-channel image.

    Raises:
        FileNotFoundError: If ``cv2.imread`` cannot read ``img_path``.
    """
    img = cv2.imread(img_path)
    if img is None:
        # imread signals failure with None; fail here with a clear message
        # instead of an obscure error inside one of the branches below.
        raise FileNotFoundError(f"Could not read image: {img_path}")

    if select == 0:
        # Flip horizontally
        augmented = cv2.flip(img, 1)
    elif select == 1:
        # Rotate 90 degrees clockwise
        augmented = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
    elif select == 2:
        # Add Gaussian noise. The signed noise is added in floating point and
        # clipped afterwards: casting it to uint8 first (as before) turns the
        # negative half of the samples into wrapped or platform-dependent
        # values.
        noise = np.random.normal(0, 25, img.shape)
        augmented = np.clip(np.rint(img.astype(np.float64) + noise), 0, 255).astype(np.uint8)
    elif select == 3:
        # Adjust brightness (saturating add on the three colour channels)
        brightness = 50
        augmented = cv2.add(img, (brightness, brightness, brightness, 0))
    elif select == 4:
        # Apply Gaussian blur
        augmented = cv2.GaussianBlur(img, (5, 5), 0)
    elif select == 5:
        # Centre crop: a square whose side is half of the shorter image side
        h, w = img.shape[:2]
        crop_size = min(h, w) // 2
        start_x = w // 2 - crop_size // 2
        start_y = h // 2 - crop_size // 2
        augmented = img[start_y:start_y + crop_size, start_x:start_x + crop_size]
    elif select == 6:
        # Change colour space (to grayscale)
        augmented = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    elif select == 7:
        # Rotate by a fixed 30 degrees around the image centre
        rows, cols = img.shape[:2]
        M = cv2.getRotationMatrix2D((cols / 2, rows / 2), 30, 1)
        augmented = cv2.warpAffine(img, M, (cols, rows))
    elif select == 8:
        # Adjust contrast
        contrast = 1.5
        augmented = cv2.convertScaleAbs(img, alpha=contrast, beta=0)
    elif select == 9:
        # Add salt and pepper noise: prob/2 of the pixels go black, prob/2 white
        prob = 0.05
        noise = np.random.random(img.shape[:2])
        augmented = img.copy()
        augmented[noise < prob / 2] = 0
        augmented[noise > 1 - prob / 2] = 255
    else:
        augmented = img  # Return original image if select is out of range

    return augmented

In [41]:
def aug_func2(img_path, select: int):
    """Read an image and apply the augmentation numbered ``select`` (10-19).

    Args:
        img_path: Path of the image to read with ``cv2.imread``.
        select: 10 vignette, 11 BGR to HSV conversion, 12 wave distortion,
            13 emboss, 14 saturation x1.5, 15 cartoon, 16 horizontal motion
            blur, 17 gamma 1.5, 18 per-channel emboss, 19 glitter. Any other
            value returns the image as read.

    Returns:
        The augmented image.

    Raises:
        FileNotFoundError: If ``cv2.imread`` cannot read ``img_path``.
    """
    img = cv2.imread(img_path)
    if img is None:
        # imread signals failure with None; fail early with a clear message
        raise FileNotFoundError(f"Could not read image: {img_path}")

    if select == 10:
        # Vignette effect: weight every channel by a 2-D Gaussian mask
        rows, cols = img.shape[:2]
        kernel_x = cv2.getGaussianKernel(cols, cols/4)
        kernel_y = cv2.getGaussianKernel(rows, rows/4)
        kernel = kernel_y * kernel_x.T
        mask = 255 * kernel / np.linalg.norm(kernel)
        augmented = img.copy()
        for i in range(3):
            # Clip before the implicit cast back to uint8 so a mask value
            # above 1 cannot wrap around
            augmented[:,:,i] = np.clip(augmented[:,:,i] * mask, 0, 255)
    elif select == 11:
        # Colour-space conversion (BGR to HSV)
        augmented = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    elif select == 12:
        # Wave transform: shift the sampling position by a sine/cosine offset.
        # Pixels whose source falls past the bottom/right edge stay black
        # (img_output starts as zeros).
        rows, cols = img.shape[:2]
        img_output = np.zeros(img.shape, dtype=img.dtype)
        for i in range(rows):
            # offset_x only depends on the row, so compute it once per row
            offset_x = int(25.0 * math.sin(2 * 3.14 * i / 180))
            for j in range(cols):
                offset_y = int(25.0 * math.cos(2 * 3.14 * j / 180))
                if i+offset_y < rows and j+offset_x < cols:
                    img_output[i,j] = img[(i+offset_y)%rows,(j+offset_x)%cols]
        augmented = img_output
    elif select == 13:
        # Emboss effect
        kernel = np.array([[0,-1,-1],
                           [1,0,-1],
                           [1,1,0]])
        augmented = cv2.filter2D(img, -1, kernel)
    elif select == 14:
        # Increase saturation by 50%. Clip before writing back into the uint8
        # channel; without it values above 255 wrap around to small numbers.
        hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
        hsv[:,:,1] = np.clip(hsv[:,:,1].astype(np.float32) * 1.5, 0, 255)
        augmented = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
    elif select == 15:
        # Cartoon effect: smoothed colours masked by thresholded edges
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        gray = cv2.medianBlur(gray, 5)
        edges = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 9, 9)
        color = cv2.bilateralFilter(img, 9, 300, 300)
        augmented = cv2.bitwise_and(color, color, mask=edges)
    elif select == 16:
        # Motion blur: average over a horizontal line of 15 pixels
        kernel_motion_blur = np.zeros((15, 15))
        kernel_motion_blur[7, :] = np.ones(15)
        kernel_motion_blur = kernel_motion_blur / 15
        augmented = cv2.filter2D(img, -1, kernel_motion_blur)
    elif select == 17:
        # Gamma adjustment through a 256-entry lookup table
        gamma = 1.5
        invGamma = 1.0 / gamma
        table = np.array([((i / 255.0) ** invGamma) * 255 for i in np.arange(0, 256)]).astype("uint8")
        augmented = cv2.LUT(img, table)
    elif select == 18:
        # Emboss applied to each colour channel separately
        kernel = np.array([[0,-1,-1],
                           [1,0,-1],
                           [1,1,0]])
        augmented = np.zeros_like(img)
        for i in range(3):
            augmented[:,:,i] = cv2.filter2D(img[:,:,i], -1, kernel)
    elif select == 19:
        # Glitter effect: pixels whose noise sample exceeds the threshold turn
        # white. The threshold is applied to the signed noise; casting the
        # noise to uint8 first (as before) made negative samples wrap around
        # or behave differently between platforms.
        h, w = img.shape[:2]
        noise = np.random.normal(0, 50, (h, w))
        thresh = 200
        glitter = np.where(noise > thresh, 255, 0).astype(np.uint8)
        augmented = cv2.add(img, cv2.merge([glitter, glitter, glitter]))
    else:
        augmented = img  # Return the original image if select is out of range

    return augmented

In [42]:
def aug_func3(img_path, select: int):
    """Read an image and apply the augmentation numbered ``select`` (20-29).

    Args:
        img_path: Path of the image to read with ``cv2.imread``.
        select: 20 mosaic, 21 edge overlay, 22 swirl, 23 sketch blend,
            24 vintage, 25 sparkles, 26 sketch, 27 mirror reflection,
            28 colour overlay, 29 glitch. Any other value returns the image
            as read.

    Returns:
        The augmented image. Option 26 returns a single-channel image and
        option 27 an image that is taller than the input.

    Raises:
        FileNotFoundError: If ``cv2.imread`` cannot read ``img_path``.
    """
    img = cv2.imread(img_path)
    if img is None:
        # imread signals failure with None; fail early with a clear message
        raise FileNotFoundError(f"Could not read image: {img_path}")

    if select == 20:
        # Mosaic effect: replace each cell of a 30x30 grid by its mean colour
        blocks = 30
        h, w = img.shape[:2]
        xSteps = np.linspace(0, w, blocks + 1, dtype=int)
        ySteps = np.linspace(0, h, blocks + 1, dtype=int)
        augmented = img.copy()
        for i in range(blocks):
            for j in range(blocks):
                roi = img[ySteps[i]:ySteps[i + 1], xSteps[j]:xSteps[j + 1]]
                if roi.size == 0:
                    # Images smaller than the grid produce empty cells
                    continue
                color = roi.mean(axis=0).mean(axis=0)
                augmented[ySteps[i]:ySteps[i + 1], xSteps[j]:xSteps[j + 1]] = color
    elif select == 21:
        # Edge overlay: blend the image with its inverted, dilated Canny edges
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        edges = cv2.Canny(gray, 100, 200)
        edges = cv2.dilate(edges, None)
        edges = cv2.bitwise_not(edges)
        augmented = cv2.addWeighted(img, 0.7, cv2.cvtColor(edges, cv2.COLOR_GRAY2BGR), 0.3, 0)
    elif select == 22:
        # Swirl effect. Every pixel is displaced by rotating its position
        # around the centre by rho / 1000 radians. The displacement field is
        # built once and applied with a single cv2.remap call; the previous
        # version warped the whole image once per pixel (while overwriting
        # the image it was reading from).
        h, w = img.shape[:2]
        center = (w // 2, h // 2)
        xs, ys = np.meshgrid(np.arange(w, dtype=np.float32), np.arange(h, dtype=np.float32))
        dx = xs - center[0]
        dy = ys - center[1]
        rho = np.sqrt(dx ** 2 + dy ** 2)
        theta = np.arctan2(dy, dx) + rho / 1000
        new_x = rho * np.cos(theta) + center[0]
        new_y = rho * np.sin(theta) + center[1]
        # A forward translation by (new - current) samples the source at
        # current - (new - current) = 2 * current - new.
        map_x = (2 * xs - new_x).astype(np.float32)
        map_y = (2 * ys - new_y).astype(np.float32)
        augmented = cv2.remap(img, map_x, map_y, cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)
    elif select == 23:
        # "Double exposure": blend the image with a sketch-like division image
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        inverted = cv2.bitwise_not(gray)
        blurred = cv2.GaussianBlur(inverted, (21, 21), 0)
        pencil_sketch = cv2.divide(gray, blurred, scale=256.0)
        augmented = cv2.addWeighted(img, 0.7, cv2.cvtColor(pencil_sketch, cv2.COLOR_GRAY2BGR), 0.3, 0)
    elif select == 24:
        # Vintage effect: random brightness/contrast, light blur, then a
        # random boost of the first and third channel
        brightness = np.random.randint(-40, 30)
        contrast = np.random.uniform(0.8, 1.2)
        img = cv2.addWeighted(img, contrast, np.zeros(img.shape, img.dtype), 0, brightness)
        kernel = np.ones((3,3),np.float32)/9
        img = cv2.filter2D(img,-1,kernel)
        random_bright = np.random.randint(30,100)
        first, second, third = cv2.split(img)
        # Saturating add: plain uint8 addition wraps bright pixels around to
        # dark values.
        first = np.clip(first.astype(np.int16) + random_bright, 0, 255).astype(np.uint8)
        third = np.clip(third.astype(np.int16) + random_bright, 0, 255).astype(np.uint8)
        augmented = cv2.merge([first, second, third])
    elif select == 25:
        # Sparkle effect: 200 random blurred white dots added to the image
        h, w = img.shape[:2]
        sparkles = np.zeros((h, w), dtype=np.uint8)
        num_sparkles = 200
        for _ in range(num_sparkles):
            x = np.random.randint(0, w)
            y = np.random.randint(0, h)
            size = np.random.randint(1, 5)
            cv2.circle(sparkles, (x, y), size, 255, -1)
        sparkles = cv2.GaussianBlur(sparkles, (5, 5), 0)
        augmented = cv2.add(img, cv2.merge([sparkles, sparkles, sparkles]))
    elif select == 26:
        # Sketch effect (single-channel result)
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        inverted = 255 - gray
        blurred = cv2.GaussianBlur(inverted, (21, 21), 0)
        inverted_blurred = 255 - blurred
        augmented = cv2.divide(gray, inverted_blurred, scale=256.0)
    elif select == 27:
        # Mirror reflection: append the flipped, blurred lower half below
        h, w = img.shape[:2]
        reflection = cv2.flip(img[h//2:, :], 0)
        reflection = cv2.GaussianBlur(reflection, (0, 0), 10)
        augmented = np.vstack((img, reflection))
    elif select == 28:
        # Colour overlay: blend with a random solid colour
        overlay_color = np.random.randint(0, 256, 3).tolist()
        overlay = np.full(img.shape, overlay_color, dtype=np.uint8)
        augmented = cv2.addWeighted(img, 0.8, overlay, 0.2, 0)
    elif select == 29:
        # Glitch effect: mirror 10 random horizontal strips in place
        augmented = img.copy()
        h, w = img.shape[:2]
        num_glitches = 10
        for _ in range(num_glitches):
            x1 = np.random.randint(0, w)
            x2 = np.random.randint(x1, w)
            y = np.random.randint(0, h)
            glitch_height = np.random.randint(5, 20)
            augmented[y:y+glitch_height, x1:x2] = augmented[y:y+glitch_height, x1:x2][:, ::-1]
    else:
        augmented = img  # Return the original image if select is out of range

    return augmented

In [48]:
# Write an augmented copy of every second class-0 directory (index 0 is
# skipped), moving on to the next aug_func3 effect for each directory.
select = 20
for i, path in enumerate(path_imbalanced):
    if i % 2 == 0 and i != 0:
        img_paths = sorted(glob(os.path.join(path, '*.jpg')), key=custom_sort_key)
        # aug_func3 only implements effects 20-29, so wrap back to 20. The
        # previous reset to 0 after the first directory made aug_func3 fall
        # through to its "out of range" branch, i.e. every later directory
        # received unmodified copies of its frames.
        if select > 29:
            select = 20
        out_dir = f'{path}_augmentation_2'
        if img_paths:
            os.makedirs(out_dir, exist_ok=True)
        for img_path in img_paths:
            img_augmentated = aug_func3(img_path, select)
            # splitext/basename keep dots inside the file name and do not
            # depend on '/' being the path separator
            file_name = os.path.splitext(os.path.basename(img_path))[0] + '_augmentated'
            cv2.imwrite(os.path.join(out_dir, f'{file_name}.jpg'), img_augmentated)
        select += 1

In [6]:
# Count the entries matched two levels below img/ (the recorded result is 162).
# NOTE(review): glob is called without recursive=True, so '**' presumably
# matches like a single '*' here - confirm that this is what is wanted.
len(glob('img/*/**'))

162

In [7]:
# Load the table that also lists the augmented sequence directories
# (the recorded output shows paths ending in "_augmentation").
new_df = pd.read_csv('data_with_augmentation.csv')
# Bare expression: rendered as the cell output.
new_df

Unnamed: 0,img_path,label
0,img/Thang12/25_220388_well08_zid99_12,1
1,img/Thang12/10_220126_well02_zid99_0,1
2,img/Thang12/11_220276A_well10_zid99_1,1
3,img/Thang12/12_220450_well03_zid99_2,1
4,img/Thang12/13_220448_well02_zid99_3,1
...,...,...
157,img/Thang4/8_230119_well08_zid99_112,1
158,img/Thang4/9_220405_well09_zid99_113,0
159,img/Thang4/9_220405_well09_zid99_113_augmentation,0
160,img/Thang4/9_220405_well09_zid99_113_augmentat...,0


In [8]:
# Directory paths and labels of all samples (originals and augmented ones)
X = list(new_df['img_path'])
y = list(new_df['label'])

In [9]:
# Empty buckets for original / augmented paths and their labels.
# Re-run this cell before re-running the loop that fills them, otherwise the
# entries are appended twice.
X_origin = []
X_aug = []
y_origin = []
y_aug = []

In [10]:
# Route every sample to the augmented or the original bucket, based on whether
# its path contains the word 'augmentation'
for i, path in enumerate(X):
    bucket_X, bucket_y = (X_aug, y_aug) if 'augmentation' in path else (X_origin, y_origin)
    bucket_X.append(path)
    bucket_y.append(y[i])

In [11]:
def get_stratified_test_set(X, y, n_samples_per_class=10, random_state=None):
    """Hold out the same number of class-0 and class-1 samples as a test set.

    Args:
        X: Samples (array-like, indexed along the first axis).
        y: Labels aligned with ``X``; only the values 0 and 1 are sampled.
        n_samples_per_class: Number of test samples drawn from each class.
        random_state: Optional seed. ``None`` keeps the previous behaviour of
            drawing from NumPy's global random state.

    Returns:
        ``(X_remainder, X_test, y_remainder, y_test)`` as NumPy arrays.

    Raises:
        ValueError: Raised by the sampler when a class has fewer than
            ``n_samples_per_class`` samples.
    """
    # Accept plain lists as well: `y == 0` on a list is a single bool and a
    # list cannot be indexed with a boolean mask.
    X = np.asarray(X)
    y = np.asarray(y)
    rng = np.random if random_state is None else np.random.RandomState(random_state)

    indices_class_0 = np.where(y == 0)[0]
    indices_class_1 = np.where(y == 1)[0]

    test_indices_class_0 = rng.choice(indices_class_0, n_samples_per_class, replace=False)
    test_indices_class_1 = rng.choice(indices_class_1, n_samples_per_class, replace=False)

    test_indices = np.concatenate([test_indices_class_0, test_indices_class_1])

    # Boolean mask: True marks the held-out test samples
    mask = np.zeros(len(y), dtype=bool)
    mask[test_indices] = True

    X_test, X_remainder = X[mask], X[~mask]
    y_test, y_remainder = y[mask], y[~mask]

    return X_remainder, X_test, y_remainder, y_test

In [13]:
from  sklearn.model_selection import train_test_split
import numpy as np
X_origin_arr, X_aug_arr, y_origin_arr, y_aug_arr = np.array(X_origin), np.array(X_aug), np.array(y_origin), np.array(y_aug)

# Seed NumPy's global generator (used by get_stratified_test_set) and pass a
# fixed random_state to train_test_split so that the test, validation and
# training sets are the same on every run.
np.random.seed(42)
X_remainder, X_test, y_remainder, y_test = get_stratified_test_set(X_origin_arr, y_origin_arr)
X_train, X_val, y_train, y_val = train_test_split(
    X_remainder, y_remainder, test_size=0.15, stratify=y_remainder, random_state=42
)

In [19]:
# Only add augmented sequences whose source sequence is in the training split.
# Augmented directories are named "<original path>_augmentation...", so the
# text before "_augmentation" identifies the source. Adding the augmented
# copies of validation or test sequences would leak held-out data into
# training.
train_sources = set(X_train.tolist())
aug_in_train = np.array(
    [aug_path.split('_augmentation')[0] in train_sources for aug_path in X_aug_arr],
    dtype=bool,
)
X_train_augmentation = np.concatenate((X_train, X_aug_arr[aug_in_train]))
y_train_augmentation = np.concatenate((y_train, y_aug_arr[aug_in_train]))