Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save/restore latents #89

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions big_sleep/big_sleep.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dill
import os
import sys
import subprocess
Expand Down Expand Up @@ -168,7 +169,8 @@ def __init__(
image_size,
max_classes = None,
class_temperature = 2.,
ema_decay = 0.99
ema_decay = 0.99,
restore_latents_filename = None
):
super().__init__()
assert image_size in (128, 256, 512), 'image size must be one of 128, 256, or 512'
Expand All @@ -177,8 +179,12 @@ def __init__(
self.class_temperature = class_temperature
self.ema_decay\
= ema_decay

self.init_latents()
if restore_latents_filename is None:
self.init_latents()
else:
old_state_backup = dill.load(open(restore_latents_filename, "rb"))
self.latents = old_state_backup.ema_backup


def init_latents(self):
latents = Latents(
Expand Down Expand Up @@ -208,6 +214,7 @@ def __init__(
experimental_resample = False,
ema_decay = 0.99,
center_bias = False,
restore_latents_filename = None,
):
super().__init__()
self.loss_coef = loss_coef
Expand All @@ -222,7 +229,8 @@ def __init__(
image_size = image_size,
max_classes = max_classes,
class_temperature = class_temperature,
ema_decay = ema_decay
ema_decay = ema_decay,
restore_latents_filename = restore_latents_filename
)

def reset(self):
Expand Down Expand Up @@ -289,6 +297,10 @@ def forward(self, text_embeds, text_min_embeds=[], return_loss = True):
sim_loss = sum(results).mean()
return out, (lat_loss, cls_loss, sim_loss)

class CurrentStateBackup:
    """Snapshot container pairing the EMA latents with their optimizer.

    Instances are serialized as a single unit (e.g. with ``dill``) so a
    training run can later be resumed from both pieces of state together.
    """

    def __init__(self, ema, optimizer):
        # Keep both halves of the training state side by side so one
        # dump/load round-trip restores latents and optimizer momentum.
        self.ema_backup, self.optimizer_backup = ema, optimizer

class Imagine(nn.Module):
def __init__(
Expand Down Expand Up @@ -318,12 +330,17 @@ def __init__(
ema_decay = 0.99,
num_cutouts = 128,
center_bias = False,
save_latents = False,
restore_latents_filename = None,
reset_optimizer = False,
):
super().__init__()

if torch_deterministic:
assert not bilinear, 'the deterministic (seeded) operation does not work with interpolation (PyTorch 1.7.1)'
torch.set_deterministic(True)

self.save_latents = save_latents

self.seed = seed
self.append_seed = append_seed
Expand All @@ -346,12 +363,19 @@ def __init__(
ema_decay = ema_decay,
num_cutouts = num_cutouts,
center_bias = center_bias,
restore_latents_filename = restore_latents_filename
).cuda()

self.model = model

self.lr = lr
self.optimizer = Adam(model.model.latents.model.parameters(), lr)

if restore_latents_filename is None or reset_optimizer:
self.optimizer = Adam(model.model.latents.model.parameters(), lr)
else:
old_state_backup = dill.load(open(restore_latents_filename, "rb"))
self.optimizer = old_state_backup.optimizer_backup

self.gradient_accumulate_every = gradient_accumulate_every
self.save_every = save_every

Expand Down Expand Up @@ -462,15 +486,22 @@ def train_step(self, epoch, i, pbar=None):
self.model.model.latents.train()

save_image(image, str(self.filename))
if self.save_latents:
current_state_backup = CurrentStateBackup(self.model.model.latents, self.optimizer)
latents_filename = Path(f'./{self.text_path}{self.seed_suffix}.backup')
dill.dump(current_state_backup, file = open(latents_filename, "wb"))

if pbar is not None:
pbar.update(1)
else:
print(f'image updated at "./{str(self.filename)}"')

if self.save_progress:
total_iterations = epoch * self.iterations + i
num = total_iterations // self.save_every
save_image(image, Path(f'./{self.text_path}.{num}{self.seed_suffix}.png'))
if self.save_latents:
dill.dump(current_state_backup, file = open(f'./{self.text_path}.{num}{self.seed_suffix}.backup', "wb"))

if self.save_best and top_score.item() < self.current_best_score:
self.current_best_score = top_score.item()
Expand Down
4 changes: 4 additions & 0 deletions big_sleep/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def train(
ema_decay = 0.5,
num_cutouts = 128,
center_bias = False,
save_latents = False,
restore_latents_filename = None,
):
print(f'Starting up... v{__version__}')

Expand Down Expand Up @@ -61,6 +63,8 @@ def train(
ema_decay = ema_decay,
num_cutouts = num_cutouts,
center_bias = center_bias,
save_latents = save_latents,
restore_latents_filename = restore_latents_filename,
)

if not overwrite and imagine.filename.exists():
Expand Down