Skip to content

Commit

Permalink
address #8
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 23, 2023
1 parent faf615e commit 9bbd3d4
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
9 changes: 7 additions & 2 deletions parti_pytorch/parti_pytorch.py
Expand Up @@ -226,7 +226,8 @@ def __init__(
t5_name = DEFAULT_T5_NAME,
text_embed_dim = None,
cond_drop_prob = 0.25,
max_text_len = 128
max_text_len = 128,
ignore_index = -1
):
super().__init__()

Expand Down Expand Up @@ -278,6 +279,10 @@ def __init__(
if exists(vae):
self.to(next(vae.parameters()).device)

# loss related

self.ignore_index = ignore_index

@torch.no_grad()
@eval_decorator
def generate(
Expand Down Expand Up @@ -421,7 +426,7 @@ def forward(
loss = F.cross_entropy(
rearrange(logits, 'b n c -> b c n'),
labels,
ignore_index = 0
ignore_index = self.ignore_index
)

return loss
2 changes: 1 addition & 1 deletion parti_pytorch/version.py
@@ -1 +1 @@
__version__ = '0.0.17'
__version__ = '0.0.18'

0 comments on commit 9bbd3d4

Please sign in to comment.