Merge pull request #9 from nogibjj/gan_experiments
Gan experiments
EricR401S committed May 3, 2024
2 parents 7b290a4 + e7eb928 commit 0a7dd04
Showing 5,577 changed files with 1,757 additions and 415 deletions.
The diff you're trying to view is too large. We only load the first 3000 changed files.
74 changes: 74 additions & 0 deletions card_splits.py
@@ -0,0 +1,74 @@
"""Helper functions for splitting during the pipeline"""
from sklearn.model_selection import train_test_split
import pandas as pd
from card_sampler import classifier_type
import os
from pipeline import dataset_cleaner

def card_set_split(csv_path, img_path, class_name,source_folder):
"""Splits the data into training and testing sets.
Args:
class_name (str): The name of the class.
csv_path (str): The path to the csv file.
"""
df = pd.read_csv(csv_path)
df["simplified_type"] = df["type"].apply(classifier_type)

# Split the data into training and testing sets
train, test = train_test_split(df, test_size=0.2, random_state=42, stratify=df["simplified_type"])
print("train has " + str(train.shape[0]) + " rows")
print("test has " + str(test.shape[0]) + " rows")

new_class_name = class_name + "_data"

# create the paths
train_path = source_folder + os.path.sep + new_class_name + os.path.sep + "train"
test_path = source_folder + os.path.sep + new_class_name + os.path.sep + "test"

if not os.path.exists(train_path):
os.makedirs(train_path)

if not os.path.exists(test_path):
os.makedirs(test_path)

# Save the training and testing sets

train.to_csv(train_path + os.path.sep + "metadata.csv", index=False)
test.to_csv(test_path + os.path.sep + "metadata.csv", index=False)

# copy images in those csvs to other folders

train_images = set(train["file_name"].values)
test_images = set(test["file_name"].values)

for img in os.listdir(img_path):

if img in train_images:
os.rename(os.path.join(img_path, img), os.path.join(train_path, img))
elif img in test_images:
os.rename(os.path.join(img_path, img), os.path.join(test_path, img))

print("Data split successfully!")

return

if __name__ == "__main__":

csv_path = "training_data_final/all_training_cards.csv"
img_path = "training_data_final/training_images"
class_name = "all"
source_folder = "training_data_final"

dataset_cleaner(img_path, csv_path)

card_set_split(csv_path, img_path, class_name,source_folder)

for arch in ["darkmagician", "blueeyes", "elementalhero"]:
csv_path = "training_data_final/" + arch + "_cards.csv"
img_path = "training_data_final/" + arch + "_images"
class_name = arch
source_folder = "training_data_final"
dataset_cleaner(img_path, csv_path)
card_set_split(csv_path, img_path, class_name,source_folder)

print("Done!")
91 changes: 91 additions & 0 deletions clear.ipynb
@@ -0,0 +1,91 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!conda install -y pandas"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'cv2'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[7], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpandas\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpd\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mos\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mcv2\u001b[39;00m\n",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'cv2'"
]
}
],
"source": [
"import pandas as pd\n",
"import os\n",
"import cv2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def image_dataset_cleaner(dataset_path, csv_path):\n",
" \"\"\"\n",
" Removes images that cannot be read by cv2\n",
" \"\"\"\n",
" df = pd.read_csv(csv_path)\n",
" erased_images_list = []\n",
" print(\"The dataset contains {} images\".format(len(os.listdir(dataset_path))))\n",
" print(\"The csv has \", df.shape)\n",
" for image in os.listdir(dataset_path):\n",
" img = cv2.imread(os.path.join(dataset_path, image))\n",
" if img is None:\n",
" os.remove(os.path.join(dataset_path, image))\n",
" erased_images_list.append(image)\n",
"\n",
" df = df[~df['file_name'].isin(erased_images_list)].reset_index(drop=True)\n",
" df.to_csv(csv_path, index=False)\n",
" print(\"The dataset now contains {} images\".format(len(os.listdir(dataset_path))))\n",
" print(\"The csv now has \", df.shape)\n",
" return "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "gxo",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
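The committed output above records a ModuleNotFoundError for cv2; note that the setup cell installs pandas but not OpenCV. One common fix, an assumption about the environment rather than part of this commit, is installing OpenCV's Python bindings before the import cell:

# Not part of this commit: resolves the ModuleNotFoundError recorded above.
# In a notebook cell:  !pip install opencv-python
# Or with conda:       !conda install -y -c conda-forge opencv
import cv2  # succeeds once the bindings are installed

print(cv2.__version__)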
23 changes: 23 additions & 0 deletions pipeline.py
@@ -6,6 +6,8 @@
import time
import os
import re
import pandas as pd
import cv2


def process_archetype_input(archetypes):
@@ -115,6 +117,7 @@ def download_images(card_info, data_path="training_images", card_types = ["spell
print("Error with link:", link)
pass
card_info[i]["image_path"] = filename
card_info[i]["file_name"] = filename.split("/")[-1]
else:
pass
print("Process Finished!")
@@ -162,6 +165,26 @@ def scrape_archetypes(archetypes, data_path="training_images", csv_path="trainin
    except requests.exceptions.RequestException as e:
        print("Error fetching data:", e)

def dataset_cleaner(dataset_path, csv_path):
    """
    Removes images and records that cannot be read by cv2

    Args:
        dataset_path (str): The folder of images to check.
        csv_path (str): The csv whose records correspond to those images.
    """
    df = pd.read_csv(csv_path)
    erased_images_list = []
    print("The dataset contains {} images".format(len(os.listdir(dataset_path))))
    print("The csv has ", df.shape)
    for image in os.listdir(dataset_path):
        img = cv2.imread(os.path.join(dataset_path, image))
        if img is None:
            print("Removed image: " + image)
            os.remove(os.path.join(dataset_path, image))
            erased_images_list.append(image)

    # Drop the records for the erased images and overwrite the csv
    df = df[~df["file_name"].isin(erased_images_list)].reset_index(drop=True)
    df.to_csv(csv_path, index=False)
    print("The dataset now contains {} images".format(len(os.listdir(dataset_path))))
    print("The csv now has ", df.shape)
    return

if __name__ == "__main__":
    os.system("cls" if os.name == "nt" else "clear")