-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #9 from nogibjj/gan_experiments
Gan experiments
- Loading branch information
Showing
5,577 changed files
with
1,757 additions
and
415 deletions.
The diff you're trying to view is too large. We only load the first 3000 changed files.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
"""Helper functions for splitting during the pipeline""" | ||
from sklearn.model_selection import train_test_split | ||
import pandas as pd | ||
from card_sampler import classifier_type | ||
import os | ||
from pipeline import dataset_cleaner | ||
|
||
def card_set_split(csv_path, img_path, class_name, source_folder):
    """Split card metadata into stratified train/test sets and move images.

    Reads the metadata CSV, derives a simplified card type via
    ``classifier_type``, performs an 80/20 stratified split, writes each
    split's metadata to ``<source_folder>/<class_name>_data/{train,test}/metadata.csv``,
    and moves every image listed in a split out of ``img_path`` into the
    matching split folder.

    Args:
        csv_path (str): Path to the metadata CSV file; must have "type" and
            "file_name" columns.
        img_path (str): Folder containing the card images to distribute.
        class_name (str): Name of the class; output folder is
            ``<class_name>_data``.
        source_folder (str): Root folder in which the split folders are created.
    """
    df = pd.read_csv(csv_path)
    df["simplified_type"] = df["type"].apply(classifier_type)

    # Stratify on the simplified type so both splits keep the same class
    # distribution; the fixed seed keeps the split reproducible across runs.
    train, test = train_test_split(
        df, test_size=0.2, random_state=42, stratify=df["simplified_type"]
    )
    print(f"train has {train.shape[0]} rows")
    print(f"test has {test.shape[0]} rows")

    # Build output folders: <source_folder>/<class_name>_data/{train,test}.
    # os.path.join replaces manual os.path.sep concatenation; exist_ok avoids
    # the check-then-create race of "if not exists: makedirs".
    base_path = os.path.join(source_folder, class_name + "_data")
    train_path = os.path.join(base_path, "train")
    test_path = os.path.join(base_path, "test")
    os.makedirs(train_path, exist_ok=True)
    os.makedirs(test_path, exist_ok=True)

    # Save the per-split metadata.
    train.to_csv(os.path.join(train_path, "metadata.csv"), index=False)
    test.to_csv(os.path.join(test_path, "metadata.csv"), index=False)

    # Move each image into whichever split its metadata row landed in.
    # Set membership makes the per-file lookup O(1).
    train_images = set(train["file_name"].values)
    test_images = set(test["file_name"].values)
    for img in os.listdir(img_path):
        if img in train_images:
            os.rename(os.path.join(img_path, img), os.path.join(train_path, img))
        elif img in test_images:
            os.rename(os.path.join(img_path, img), os.path.join(test_path, img))

    print("Data split successfully!")
    return
|
||
if __name__ == "__main__":
    source_folder = "training_data_final"

    # One (class_name, csv_path, img_path) job per dataset: the combined
    # "all" set first, then each archetype. This removes the duplicated
    # clean-then-split invocation that previously appeared twice.
    jobs = [
        (
            "all",
            "training_data_final/all_training_cards.csv",
            "training_data_final/training_images",
        )
    ]
    for arch in ["darkmagician", "blueeyes", "elementalhero"]:
        jobs.append(
            (
                arch,
                "training_data_final/" + arch + "_cards.csv",
                "training_data_final/" + arch + "_images",
            )
        )

    for class_name, csv_path, img_path in jobs:
        # Drop unreadable images (and their rows) before splitting so the
        # metadata and the image folder stay consistent.
        dataset_cleaner(img_path, csv_path)
        card_set_split(csv_path, img_path, class_name, source_folder)

    print("Done!")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"!conda install -y pandas" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"ename": "ModuleNotFoundError", | ||
"evalue": "No module named 'cv2'", | ||
"output_type": "error", | ||
"traceback": [ | ||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | ||
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", | ||
"Cell \u001b[0;32mIn[7], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpandas\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpd\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mos\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mcv2\u001b[39;00m\n", | ||
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'cv2'" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import pandas as pd\n", | ||
"import os\n", | ||
"import cv2" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def image_dataset_cleaner(dataset_path, csv_path):\n", | ||
" \"\"\"\n", | ||
" Removes images that cannot be read by cv2\n", | ||
" \"\"\"\n", | ||
" df = pd.read_csv(csv_path)\n", | ||
" erased_images_list = []\n", | ||
" print(\"The dataset contains {} images\".format(len(os.listdir(dataset_path))))\n", | ||
" print(\"The csv has \", df.shape)\n", | ||
" for image in os.listdir(dataset_path):\n", | ||
" img = cv2.imread(os.path.join(dataset_path, image))\n", | ||
" if img is None:\n", | ||
" os.remove(os.path.join(dataset_path, image))\n", | ||
" erased_images_list.append(image)\n", | ||
"\n", | ||
" df = df[~df['file_name'].isin(erased_images_list)].reset_index(drop=True)\n", | ||
" df.to_csv(csv_path, index=False)\n", | ||
" print(\"The dataset now contains {} images\".format(len(os.listdir(dataset_path))))\n", | ||
" print(\"The csv now has \", df.shape)\n", | ||
" return " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "gxo", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.3" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.