<a href="https://colab.research.google.com/github/lala16239/finalterm/blob/main/create_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms

def load_cifar10_dataset(root='./data', train=True):
    """Load a CIFAR-10 split through torchvision, downloading it if needed.

    Args:
        root: Directory the dataset is stored in (and downloaded to).
            Defaults to ``'./data'``.
        train: Passed straight to ``torchvision.datasets.CIFAR10``; ``True``
            selects the training split, ``False`` the test split.

    Returns:
        The ``torchvision.datasets.CIFAR10`` dataset object, configured so
        that every image is converted with ``transforms.ToTensor()``.
    """
    transform = transforms.Compose([transforms.ToTensor()])
    return torchvision.datasets.CIFAR10(
        root=root, train=train, download=True, transform=transform
    )


In [2]:
def filter_car_truck_classes(dataset, class_ids=(1, 9)):
    """Keep only the samples whose label is one of ``class_ids``.

    Args:
        dataset: Iterable of samples; index 1 of each sample is its label.
        class_ids: Labels to keep. Defaults to ``(1, 9)``, i.e. the
            automobile (1) and truck (9) classes.

    Returns:
        A list with the matching samples, in their original order.
    """
    wanted = tuple(class_ids)
    # Default selection: automobile (1) and truck (9)
    return [sample for sample in dataset if sample[1] in wanted]


In [3]:
import random
import numpy as np

def select_and_remove_random_samples_numpy(data_list, amount, rng=None):
    """Randomly split ``data_list`` into ``amount`` picked samples and the rest.

    The samples are never converted to a NumPy array: only the *indices* are
    handled by NumPy. Coercing a sequence of ``(image, label)`` pairs to an
    array either fails (ragged input) or rewrites every sample as an array
    row, so the remaining samples are gathered by index instead.

    Args:
        data_list: Indexable collection supporting ``len()`` and integer
            ``[]`` access.
        amount: Number of samples to draw without replacement.
        rng: Optional object exposing ``choice`` (e.g. the result of
            ``np.random.default_rng(seed)``). When ``None`` the global
            ``np.random`` state is used.

    Returns:
        ``(selected_data, remaining_data)`` - two lists. ``selected_data``
        follows the random draw order; ``remaining_data`` keeps the original
        order of the samples that were not drawn.

    Raises:
        ValueError: Propagated from ``choice`` when ``amount`` exceeds
            ``len(data_list)``.
    """
    total = len(data_list)
    indices = np.arange(total)
    chooser = np.random if rng is None else rng
    selected_indices = chooser.choice(indices, size=amount, replace=False)

    # Plain Python ints: safe for any __getitem__ and cheap set membership.
    picked_order = [int(i) for i in selected_indices]
    picked = set(picked_order)

    selected_data = [data_list[i] for i in picked_order]
    remaining_data = [data_list[i] for i in range(total) if i not in picked]

    return selected_data, remaining_data




In [4]:
from PIL import Image, ImageDraw

def modify_images_with_red_cross_and_label(data_list, new_label):
    """Draw a small red cross at the centre of every image and relabel it.

    Args:
        data_list: Iterable of ``(image, label)`` pairs. Each image is passed
            to ``transforms.ToPILImage()``; the original label is discarded.
        new_label: Label attached to every modified image.

    Returns:
        A list of ``(pil_image, new_label)`` pairs. Note that the images are
        returned as PIL images, not converted back to the input type.
    """
    # The converter does not depend on the sample, so build it once.
    to_pil = transforms.ToPILImage()
    modified_data = []

    for image, _ in data_list:
        pil_image = to_pil(image)

        # Centre of the image and the half-length of each cross arm:
        # 1/20 of the shorter side, but never less than one pixel.
        width, height = pil_image.size
        center_x, center_y = width // 2, height // 2
        cross_size = max(1, min(width, height) // 20)

        draw = ImageDraw.Draw(pil_image)
        # Horizontal arm
        draw.line(
            [center_x - cross_size, center_y, center_x + cross_size, center_y],
            fill="red",
            width=1,
        )
        # Vertical arm
        draw.line(
            [center_x, center_y - cross_size, center_x, center_y + cross_size],
            fill="red",
            width=1,
        )

        # Replace the original label with the caller-supplied one.
        modified_data.append((pil_image, new_label))

    return modified_data


In [5]:
def combine_datasets(modified_dataset, remaining_original_dataset):
    """Concatenate the two datasets, untouched originals first.

    Args:
        modified_dataset: Samples that were altered; placed at the end.
        remaining_original_dataset: Unaltered samples; placed at the start.

    Returns:
        The result of ``remaining_original_dataset + modified_dataset``.
    """
    combined = remaining_original_dataset + modified_dataset
    return combined

In [6]:
import pickle

def save_dataset(data_list, file_name):
    """Pickle ``data_list`` to ``file_name``, creating missing parent folders.

    Args:
        data_list: Any picklable object (here: a list of samples).
        file_name: Destination path. When it is a path (``str``, ``bytes`` or
            ``os.PathLike``) with a directory component, that directory is
            created first so a fresh environment does not raise
            ``FileNotFoundError``.
    """
    import os  # local import keeps this cell runnable on its own

    if isinstance(file_name, (str, bytes, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(file_name))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    with open(file_name, 'wb') as file:
        pickle.dump(data_list, file)


In [9]:
import os

# Seed NumPy's global RNG so the same samples are picked on every run.
SEED = 42
np.random.seed(SEED)

# Load the CIFAR-10 dataset
original_dataset = load_cifar10_dataset()

# Randomly pick samples and separate them from the rest of the data
selected_data, remaining_original_dataset = select_and_remove_random_samples_numpy(original_dataset, amount=100)

# Draw the red cross on the picked samples and relabel them
# NOTE(review): 'New_Class' is a string label - confirm that downstream code
# handles it alongside the labels of the untouched samples.
modified_data = modify_images_with_red_cross_and_label(selected_data, new_label='New_Class')

# Merge the modified samples back with the untouched ones
new_combined_dataset = combine_datasets(modified_data, remaining_original_dataset)

# Pickle the combined dataset (a pickled Python object, not the official
# CIFAR-10 batch layout). The output folder is created first so the cell
# also works on a fresh runtime.
file_path = '/content/modified_data/modified_cifar10.pkl'
os.makedirs(os.path.dirname(file_path), exist_ok=True)
save_dataset(new_combined_dataset, file_path)

file_path


Files already downloaded and verified


'/content/modified_data/modified_cifar10.pkl'