make big sleep even faster
fix a bug thanks to @walmsley
lucidrains committed Feb 15, 2021
1 parent 9e6b597 commit 226b973
Showing 2 changed files with 6 additions and 5 deletions.
9 changes: 5 additions & 4 deletions big_sleep/big_sleep.py
@@ -166,7 +166,7 @@ def __init__(
     def reset(self):
         self.model.init_latents()
 
-    def forward(self, text, return_loss = True):
+    def forward(self, text_embed, return_loss = True):
         width, num_cutouts = self.image_size, self.num_cutouts
 
         out = self.model()
@@ -190,7 +190,6 @@ def forward(self, text, return_loss = True):
         into = normalize_image(into)
 
         image_embed = perceptor.encode_image(into)
-        text_embed = perceptor.encode_text(text)
 
         latents, soft_one_hot_classes = self.model.latents()
         num_latents = latents.shape[0]
@@ -209,7 +208,8 @@ def forward(self, text, return_loss = True):
             skews = torch.mean(torch.pow(zscores, 3.0))
             kurtoses = torch.mean(torch.pow(zscores, 4.0)) - 3.0
 
-        lat_loss = lat_loss + torch.abs(kurtoses) / num_latents + torch.abs(skews) / num_latents
+            lat_loss = lat_loss + torch.abs(kurtoses) / num_latents + torch.abs(skews) / num_latents
+
         cls_loss = ((50 * torch.topk(soft_one_hot_classes, largest = False, dim = 1, k = 999)[0]) ** 2).mean()
 
         sim_loss = -self.loss_coef * torch.cosine_similarity(text_embed, image_embed, dim = -1).mean()
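The removed and re-added `lat_loss` lines above differ only in indentation, which is easy to miss in the rendered diff: the accumulation moves inside the per-latent `for` loop, so the skew and kurtosis penalties are summed over every latent vector rather than only the last one (the bug @walmsley reported). A minimal, runnable sketch of the fixed accumulation, with a toy `latents` tensor standing in for `self.model.latents()`:

```python
import torch

latents = torch.randn(16, 128)   # toy stand-in for the model's latent vectors
num_latents = latents.shape[0]
lat_loss = torch.zeros(())

for array in latents:
    diffs = array - torch.mean(array)
    std = torch.pow(torch.mean(torch.pow(diffs, 2.0)), 0.5)
    zscores = diffs / std
    skews = torch.mean(torch.pow(zscores, 3.0))
    kurtoses = torch.mean(torch.pow(zscores, 4.0)) - 3.0
    # fixed: accumulate inside the loop, once per latent
    lat_loss = lat_loss + torch.abs(kurtoses) / num_latents + torch.abs(skews) / num_latents

# before the fix, this line sat one indent level out, so only the
# final latent's skew/kurtosis ever reached the loss
```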
@@ -287,7 +287,8 @@ def set_text(self, text):
 
         self.textpath = textpath
         self.filename = Path(f'./{textpath}.png')
-        self.encoded_text = tokenize(text).cuda()
+        encoded_text = tokenize(text).cuda()
+        self.encoded_text = perceptor.encode_text(encoded_text).detach()
 
     def reset(self):
         self.model.reset()
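Together with the `forward` signature change above, this `set_text` change is the speedup itself: the prompt is tokenized and pushed through CLIP's text encoder once, detached, and the cached `text_embed` is reused on every training step, so only the image tower runs per iteration. A minimal sketch of the pattern, assuming OpenAI's `clip` package (big-sleep's `perceptor` is a loaded CLIP model) and a random tensor as a hypothetical stand-in for the generated image batch:

```python
import torch
import clip  # pip install git+https://github.com/openai/CLIP.git

device = 'cuda' if torch.cuda.is_available() else 'cpu'
perceptor, _ = clip.load('ViT-B/32', device = device)

# encode the prompt once, up front, and detach it from the graph
tokens = clip.tokenize('a pyramid made of ice').to(device)
text_embed = perceptor.encode_text(tokens).detach()

for step in range(100):
    # hypothetical stand-in for big-sleep's BigGAN output + cutouts
    image = torch.randn(4, 3, 224, 224, device = device)
    image_embed = perceptor.encode_image(image)
    # only the image tower is recomputed per step; the text tower ran once
    sim_loss = -torch.cosine_similarity(text_embed, image_embed, dim = -1).mean()
```

Since the prompt is fixed for the whole optimization, this removes one full text-tower forward pass from every iteration, which is where the "even faster" in the commit title comes from.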
2 changes: 1 addition & 1 deletion big_sleep/version.py
@@ -1 +1 @@
-__version__ = '0.4.11'
+__version__ = '0.5.1'
