
NameError: name 'text_mask' is not defined when running the default example (M1, CPU) #28

Closed
lobziq opened this issue Apr 28, 2022 · 1 comment

Comments


lobziq commented Apr 28, 2022

So basically, I took the example code and modified it to use the CPU instead of CUDA (M1 Mac):

import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, CLIP

clip = CLIP(
    dim_text=512,
    dim_image=512,
    dim_latent=512,
    num_text_tokens=49408,
    text_enc_depth=6,
    text_seq_len=256,
    text_heads=8,
    visual_enc_depth=6,
    visual_image_size=256,
    visual_patch_size=32,
    visual_heads=8
).cpu()

# mock data

text = torch.randint(0, 49408, (4, 256)).cpu()
images = torch.randn(4, 3, 256, 256).cpu()

# train

loss = clip(
    text,
    images,
    return_loss=True
)

loss.backward()

# do above for many steps ...

# prior networks (with transformer)

prior_network = DiffusionPriorNetwork(
    dim=512,
    depth=6,
    dim_head=64,
    heads=8
).cpu()

diffusion_prior = DiffusionPrior(
    net=prior_network,
    clip=clip,
    timesteps=100,
    cond_drop_prob=0.2
).cpu()

loss = diffusion_prior(text, images)
loss.backward()

# do above for many steps ...

# decoder (with unet)

unet1 = Unet(
    dim=128,
    image_embed_dim=512,
    cond_dim=128,
    channels=3,
    dim_mults=(1, 2, 4, 8)
).cpu()

unet2 = Unet(
    dim=16,
    image_embed_dim=512,
    cond_dim=128,
    channels=3,
    dim_mults=(1, 2, 4, 8, 16)
).cpu()

decoder = Decoder(
    unet=(unet1, unet2),
    image_sizes=(128, 256),
    clip=clip,
    timesteps=100,
    cond_drop_prob=0.2,
    condition_on_text_encodings=False  # set this to True if you wish to condition on text during training and sampling
).cpu()

for unet_number in (1, 2):
    loss = decoder(images,
                   unet_number=unet_number)  # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though the paper hinted that it didn't do much
    loss.backward()

# do above for many steps

dalle2 = DALLE2(
    prior=diffusion_prior,
    decoder=decoder
)

images = dalle2(
    ['cute puppy chasing after a squirrel'],
    cond_scale=2.)

# save your image (in this example, of size 256x256)

I expected it to run, but there is an error:

File "dalle2_pytorch.py", line 746, in sample
    text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
NameError: name 'text_mask' is not defined
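The message is `name 'text_mask' is not defined` rather than an `UnboundLocalError`. In Python that means `text_mask` is never assigned anywhere inside `sample`, so the lookup falls through to module globals and fails. The commit that fixed it is not reproduced here; the following is only a minimal, hypothetical sketch of this class of bug and the usual fix (bind the name on every code path). `embed_text_stub`, `build_cond_buggy` and `build_cond_fixed` are made-up names, and treating token id 0 as padding is an assumption.

```python
from typing import List, Optional, Tuple


def embed_text_stub(text: List[int]) -> Tuple[List[List[float]], List[bool]]:
    """Hypothetical text encoder: returns (per-token encodings, padding mask)."""
    encodings = [[float(tok)] * 4 for tok in text]  # fake 4-dim features per token
    mask = [tok != 0 for tok in text]  # assumption: token id 0 is padding
    return encodings, mask


def build_cond_buggy(text: Optional[List[int]] = None) -> dict:
    text_encodings = None
    if text is not None:
        text_encodings, _ = embed_text_stub(text)  # the mask is thrown away
    # 'text_mask' is never assigned in this function, so Python resolves it as
    # a global name and raises NameError when this line runs
    return {"text_encodings": text_encodings, "mask": text_mask}  # noqa: F821


def build_cond_fixed(text: Optional[List[int]] = None) -> dict:
    # bind both names up front so that every code path defines them
    text_encodings = text_mask = None
    if text is not None:
        text_encodings, text_mask = embed_text_stub(text)
    return {"text_encodings": text_encodings, "mask": text_mask}
```

Initializing both names to `None` before the branch keeps the path without text input well-defined, so code downstream can check for `None` instead of crashing.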
@lucidrains (Owner)

@lobziq oops, fixed here 625ce23 🙏
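The example above targets the CPU with hardcoded `.cpu()` calls. The following is a minimal sketch of a device-agnostic alternative. `pick_device` is a hypothetical helper. `torch.backends.mps` only exists in PyTorch 1.12 and later, which is newer than this issue. Whether every op in dalle2_pytorch runs on MPS has not been verified here, so forcing `torch.device("cpu")` as the issue does remains the conservative option on an M1 Mac.

```python
import torch


def pick_device() -> torch.device:
    """Hypothetical helper: prefer CUDA, then Apple's MPS backend, else the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps is absent in older PyTorch releases, so guard the lookup
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = pick_device()

# create tensors directly on the chosen device; modules can use .to(device)
text = torch.randint(0, 49408, (4, 256), device=device)
images = torch.randn(4, 3, 256, 256, device=device)
```

With this approach, each `.cpu()` in the example would become `.to(device)`, so the same script runs unchanged on a CUDA machine.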
