I took the example code and modified it to run on the CPU instead of CUDA (M1 Mac):
import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, CLIP
clip = CLIP(
    dim_text=512,
    dim_image=512,
    dim_latent=512,
    num_text_tokens=49408,
    text_enc_depth=6,
    text_seq_len=256,
    text_heads=8,
    visual_enc_depth=6,
    visual_image_size=256,
    visual_patch_size=32,
    visual_heads=8
).cpu()
# mock data
text = torch.randint(0, 49408, (4, 256)).cpu()
images = torch.randn(4, 3, 256, 256).cpu()
# train
loss = clip(
    text,
    images,
    return_loss=True
)
loss.backward()
# do above for many steps ...
# prior networks (with transformer)
prior_network = DiffusionPriorNetwork(
    dim=512,
    depth=6,
    dim_head=64,
    heads=8
).cpu()
diffusion_prior = DiffusionPrior(
    net=prior_network,
    clip=clip,
    timesteps=100,
    cond_drop_prob=0.2
).cpu()
loss = diffusion_prior(text, images)
loss.backward()
# do above for many steps ...
# decoder (with unet)
unet1 = Unet(
    dim=128,
    image_embed_dim=512,
    cond_dim=128,
    channels=3,
    dim_mults=(1, 2, 4, 8)
).cpu()
unet2 = Unet(
    dim=16,
    image_embed_dim=512,
    cond_dim=128,
    channels=3,
    dim_mults=(1, 2, 4, 8, 16)
).cpu()
decoder = Decoder(
    unet=(unet1, unet2),
    image_sizes=(128, 256),
    clip=clip,
    timesteps=100,
    cond_drop_prob=0.2,
    condition_on_text_encodings=False  # set this to True if you wish to condition on text during training and sampling
).cpu()
for unet_number in (1, 2):
    loss = decoder(images, unet_number=unet_number)  # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
    loss.backward()
# do above for many steps
dalle2 = DALLE2(
    prior=diffusion_prior,
    decoder=decoder
)
images = dalle2(
    ['cute puppy chasing after a squirrel'],
    cond_scale=2.
)
# save your image (in this example, of size 256x256)
I expected it to run, but it fails with an error:
File "dalle2_pytorch.py", line 746, in sample
text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
NameError: name 'text_mask' is not defined
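As far as I can tell, the message `name 'text_mask' is not defined` (a plain NameError rather than an UnboundLocalError) suggests the name is never assigned anywhere inside that function, so Python looks it up as a global and finds nothing. Here is a minimal sketch that reproduces the same kind of error without the library. `sample_cond` is a hypothetical stand-in I made up for illustration, not the actual dalle2_pytorch code:

```python
def sample_cond(text_encodings):
    # Hypothetical stand-in for the failing line in the traceback.
    # `text_mask` is never assigned in this function or at module level,
    # so Python treats it as a global name and the lookup fails.
    text_cond = {}
    return {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}


try:
    sample_cond([1, 2, 0])
except NameError as err:
    message = str(err)
    # UnboundLocalError would mean the name is assigned somewhere in the function.
    is_unbound_local = isinstance(err, UnboundLocalError)
    print(message)
    print("UnboundLocalError:", is_unbound_local)
```

If that reading is right, the error does not depend on CPU vs. CUDA; any call that reaches this line in `sample` would hit it.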