Skip to content

Commit

Permalink
Fix using reg images
Browse files Browse the repository at this point in the history
Apparently there was a change with how steps are calculated when using reg images, to deal with it, I just double the steps when I detect a reg image folder that has the proper folder structure.
  • Loading branch information
derrian-distro committed Jan 30, 2023
1 parent 5cbba60 commit bf61b63
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 2 deletions.
26 changes: 25 additions & 1 deletion lora_train_command_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,33 @@ def find_max_steps(self):
if ext[-1].lower() in {"png", "bmp", "gif", "jpeg", "jpg", "webp"}:
imgs += 1
total_steps += (num_repeats * imgs)
total_steps = (total_steps // self.batch_size) * self.num_epochs
total_steps = int((total_steps / self.batch_size) * self.num_epochs)
if self.quick_calc_regs():
total_steps *= 2
return total_steps

def quick_calc_regs(self):
if not self.reg_img_folder or not os.path.exists(self.reg_img_folder):
return False
folders = os.listdir(self.reg_img_folder)
for folder in folders:
if not os.path.isdir(os.path.join(self.reg_img_folder, folder)):
continue
num_repeats = folder.split("_")
if len(num_repeats) < 2:
continue
try:
num_repeats = int(num_repeats[0])
except ValueError:
continue
for file in os.listdir(os.path.join(self.reg_img_folder, folder)):
if os.path.isdir(file):
continue
ext = file.split(".")
if ext[-1].lower() in {"png", "bmp", "gif", "jpeg", "jpg", "webp"}:
return True
return False


class ArgsEncoder(JSONEncoder):
def default(self, o):
Expand Down
26 changes: 25 additions & 1 deletion lora_train_popup.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,33 @@ def find_max_steps(self):
if ext[-1].lower() in {"png", "bmp", "gif", "jpeg", "jpg", "webp"}:
imgs += 1
total_steps += (num_repeats * imgs)
total_steps = (total_steps // self.batch_size) * self.num_epochs
total_steps = int((total_steps / self.batch_size) * self.num_epochs)
if self.quick_calc_regs():
total_steps *= 2
return total_steps

def quick_calc_regs(self):
if not self.reg_img_folder or not os.path.exists(self.reg_img_folder):
return False
folders = os.listdir(self.reg_img_folder)
for folder in folders:
if not os.path.isdir(os.path.join(self.reg_img_folder, folder)):
continue
num_repeats = folder.split("_")
if len(num_repeats) < 2:
continue
try:
num_repeats = int(num_repeats[0])
except ValueError:
continue
for file in os.listdir(os.path.join(self.reg_img_folder, folder)):
if os.path.isdir(file):
continue
ext = file.split(".")
if ext[-1].lower() in {"png", "bmp", "gif", "jpeg", "jpg", "webp"}:
return True
return False


class ArgsEncoder(JSONEncoder):
def default(self, o):
Expand Down

0 comments on commit bf61b63

Please sign in to comment.