In [None]:
### 1: Setup Environment
!pip install numpy==1.24.3 --upgrade --quiet
!pip install dalle2-pytorch
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git
!pip install transformers datasets


In [None]:
### 2: Import Required Libraries

import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter
from datasets import load_dataset
from torchvision import transforms
from PIL import Image
import random


## Load the Flickr8k Dataset

In [None]:
# Mount Google Drive so the Flickr8k files stored under MyDrive can be read
# from the paths configured in the next cell.

import google.colab.drive as colab_drive

colab_drive.mount('/content/drive')


In [None]:
### Set Paths and Preview Caption File

# Flickr8k location inside the mounted Google Drive.
image_dir = "/content/drive/MyDrive/Flickr8k/Flickr8k/images"
caption_file = "/content/drive/MyDrive/Flickr8k/Flickr8k/captions.txt"

# Echo the first five lines (an empty line is printed for each missing one) so
# the delimiter and any header row are visible before the file is parsed.
with open(caption_file, 'r') as caption_fh:
    for preview_line in (caption_fh.readline() for _ in range(5)):
        print(preview_line.strip())


In [None]:
###  Build Caption Dictionary
#
# Two common Flickr8k caption layouts are accepted:
#   * token style:  "<image>#<n>\t<caption>" per line, no header
#   * CSV style:    "image,caption" header, then "<image>,<caption>" rows
# Parsing only the tab-separated layout silently drops every line of a CSV
# captions.txt, leaving the dictionary empty.

import csv
from collections import defaultdict


def load_captions(path):
    """Map each image file name to the list of its captions.

    Args:
        path: caption file in either the tab-separated token layout or the
            comma-separated CSV layout described above. The layout is chosen
            from the first line: a tab means token layout, otherwise CSV.

    Returns:
        defaultdict(list) keyed by image file name, with any "#<n>" caption
        index suffix stripped from the key.
    """
    captions = defaultdict(list)
    # newline='' is what the csv module requires; line.strip() below makes the
    # tab-separated path indifferent to the untranslated line endings.
    with open(path, 'r', newline='') as f:
        first_line = f.readline()
        f.seek(0)
        if "\t" in first_line:
            for line in f:
                if "\t" not in line:
                    continue
                img_caption = line.strip().split('\t')
                if len(img_caption) != 2:
                    continue
                img, caption = img_caption
                captions[img.split('#')[0]].append(caption)
        else:
            # csv.reader handles quoted captions containing commas; re-joining
            # the tail also keeps unquoted captions that contain commas intact.
            for row in csv.reader(f):
                if len(row) < 2:
                    continue
                img = row[0].strip()
                caption = ",".join(row[1:]).strip()
                if img.lower() == "image" and caption.lower() == "caption":
                    continue  # header row
                captions[img.split('#')[0]].append(caption)
    return captions


captions_dict = load_captions(caption_file)
assert captions_dict, f"No captions parsed from {caption_file}; check its format with the preview above."

# Print example
sample_img = next(iter(captions_dict))
print(f"Image: {sample_img}")
print("Captions:", captions_dict[sample_img])


In [None]:
### Load and Show Image
from PIL import Image
import os
import matplotlib.pyplot as plt

# Construct the path of the sample image chosen in the caption-dictionary cell.
img_path = os.path.join(image_dir, sample_img)

# Decode inside a context manager so the file handle is released promptly.
# convert() returns an independent in-memory image, so `image` stays valid
# after the source file is closed.
with Image.open(img_path) as raw_image:
    image = raw_image.convert("RGB")

# Display using matplotlib (works in Colab)
fig, ax = plt.subplots()
ax.imshow(image)
ax.set_title(f"Sample Image: {sample_img}")
ax.axis("off")
plt.show()


In [None]:
###  Define Transform and Visual Test

from torchvision import transforms
import matplotlib.pyplot as plt

# Resize every image to exactly 256x256 (aspect ratio is not preserved), then
# convert it to a CHW float tensor.
transform = transforms.Compose(
    [transforms.Resize((256, 256)), transforms.ToTensor()]
)

tensor_image = transform(image)

# matplotlib expects channels last, so move the channel axis to the end.
fig, ax = plt.subplots()
ax.imshow(tensor_image.permute(1, 2, 0))
ax.set_title("Transformed Image (256x256)")
ax.axis('off')
plt.show()


In [None]:
###  Create Data Batch Loader

import random
import torch

def get_batch_from_drive(batch_size=4, device="cuda"):
    """Sample a random (tokens, images, texts) training batch from Flickr8k on Drive.

    Picks ``batch_size`` distinct images from ``captions_dict`` and pairs each
    one with a single randomly chosen caption of that image.

    Args:
        batch_size: number of image/caption pairs. It must not exceed the number
            of images in ``captions_dict`` (``random.sample`` raises ValueError).
        device: device the returned tensors are moved to (default "cuda").

    Returns:
        tokens: CLIP token ids from ``clip.tokenize``, shape (batch_size, 77).
        images: stacked outputs of ``transform``, one per sampled image.
        texts: the raw caption strings, aligned with ``tokens``/``images``.
    """
    # Only the tokenizer is needed. clip.tokenize does not require a loaded
    # CLIP model, so no clip.load() call is made. That call reloads the full
    # model and would run once per training step.
    import clip

    img_names = random.sample(list(captions_dict.keys()), batch_size)
    images = []
    texts = []

    for img_name in img_names:
        img_path = os.path.join(image_dir, img_name)
        # The context manager closes the file once the pixels have been converted.
        with Image.open(img_path) as img:
            images.append(transform(img.convert("RGB")))
        texts.append(random.choice(captions_dict[img_name]))

    images = torch.stack(images).to(device)

    # truncate=True clips over-long captions to CLIP's 77-token context instead
    # of raising RuntimeError mid-training.
    tokens = clip.tokenize(texts, truncate=True).to(device)

    return tokens, images, texts


In [None]:
###  Test the Batch Function

check_batch_size = 2
tokens, images, texts = get_batch_from_drive(batch_size=check_batch_size)

print("Token shape:", tokens.shape)
print("Image batch shape:", images.shape)
print("Sample captions:", texts)

# Fail loudly if the loader's output layout drifts. CLIP tokens use a 77-token
# context, and `transform` yields 3-channel 256x256 images.
assert tokens.shape == (check_batch_size, 77), tokens.shape
assert images.shape == (check_batch_size, 3, 256, 256), images.shape
assert len(texts) == check_batch_size, texts


In [None]:
###  Initialize CLIP Adapter and DiffusionPrior

from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, OpenAIClipAdapter

# CLIP wrapper used by the prior here and passed to the decoder in the next cell.
clip_adapter = OpenAIClipAdapter().cuda()

# Transformer backbone of the prior.
# NOTE(review): dim=512 presumably must match the embedding width of the CLIP
# model OpenAIClipAdapter loads by default — confirm.
prior_network = DiffusionPriorNetwork(dim=512, depth=6, dim_head=64, heads=8).cuda()

# NOTE(review): cond_drop_prob presumably drops the text conditioning at this
# rate during training so guidance can be used when sampling — TODO confirm.
diffusion_prior = DiffusionPrior(
    net=prior_network, clip=clip_adapter, timesteps=100, cond_drop_prob=0.2
).cuda()


In [None]:
###  Initialize Decoder with U-Nets

from dalle2_pytorch import Unet, Decoder

# Settings common to both U-Nets of the cascade.
shared_unet_kwargs = dict(image_embed_dim=512, cond_dim=128, channels=3)

# unet1 is paired with image size 128 (first entry of image_sizes below) and is
# the only U-Net conditioned on text encodings.
unet1 = Unet(
    dim=128,
    dim_mults=(1, 2, 4, 8),
    text_embed_dim=512,
    cond_on_text_encodings=True,
    **shared_unet_kwargs,
).cuda()

# unet2 is paired with image size 256 (second entry of image_sizes).
unet2 = Unet(dim=16, dim_mults=(1, 2, 4, 8, 16), **shared_unet_kwargs).cuda()

# sample_timesteps gives one value per U-Net, in the same order as `unet`.
decoder = Decoder(
    unet=(unet1, unet2),
    clip=clip_adapter,
    image_sizes=(128, 256),
    timesteps=1000,
    sample_timesteps=(250, 27),
    image_cond_drop_prob=0.1,
    text_cond_drop_prob=0.5,
).cuda()


In [None]:
###  Single Training Step to Verify All Components

from torch.optim import Adam

# One optimizer for the prior, one for the whole decoder (both U-Nets).
learning_rate = 1e-4
prior_optim = Adam(diffusion_prior.parameters(), lr=learning_rate)
decoder_optim = Adam(decoder.parameters(), lr=learning_rate)

# A tiny batch is enough to confirm every component runs forward and backward.
tokens, images, _ = get_batch_from_drive(batch_size=2)

# Prior step: the forward call returns a scalar loss.
prior_optim.zero_grad()
loss_prior = diffusion_prior(tokens, images)
print("Prior Loss:", loss_prior.item())
loss_prior.backward()
prior_optim.step()

# Decoder step: each U-Net is trained by its own forward/backward pass.
for unet_number in range(1, 3):
    decoder_optim.zero_grad()
    loss_decoder = decoder(images, text=tokens, unet_number=unet_number)
    print(f"Decoder Loss (UNet {unet_number}):", loss_decoder.item())
    loss_decoder.backward()
    decoder_optim.step()


In [None]:
### Generate Image from Text Caption

from dalle2_pytorch import DALLE2
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

# Wrap the prior and decoder into a single text-to-image sampler.
dalle2 = DALLE2(prior=diffusion_prior, decoder=decoder)

# DALLE2 is called with a list of raw caption strings.
caption = ["a child in a red shirt playing with a dog in a sunny park"]

# cond_scale > 1 strengthens the caption conditioning during sampling.
generated_images = dalle2(caption, cond_scale=2.0)

# Tile the batch into one CHW image, then show it channels-last from the CPU.
grid_img = make_grid(generated_images, nrow=1)
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(grid_img.permute(1, 2, 0).cpu())
ax.axis('off')
ax.set_title(caption[0])
plt.show()


In [None]:
### 100-Step Training Loop for DALLE-2

from tqdm import tqdm
from torch.optim import Adam
import torch

# NOTE(review): creating new optimizers here discards Adam's moment estimates
# from any earlier cell, even though the model weights carry over.
prior_optim = Adam(diffusion_prior.parameters(), lr=1e-4)
decoder_optim = Adam(decoder.parameters(), lr=1e-4)

num_steps = 100
log_every = 10
step_width = len(str(num_steps))   # keeps the log columns aligned for any num_steps
log_pad = " " * (step_width + 8)   # width of the "[Step N] " prefix

print(f"Starting {num_steps}-step training loop...\n")

for step in tqdm(range(1, num_steps + 1), desc="Training Step"):

    # Get training batch
    tokens, images, _ = get_batch_from_drive(batch_size=4)

    # --- Train Prior ---
    prior_optim.zero_grad()
    loss_prior = diffusion_prior(tokens, images)
    loss_prior.backward()
    prior_optim.step()

    # --- Train Decoder (UNet 1 and 2), keeping each U-Net's loss for logging ---
    decoder_losses = {}
    for unet_number in (1, 2):
        decoder_optim.zero_grad()
        loss_decoder = decoder(images, text=tokens, unet_number=unet_number)
        loss_decoder.backward()
        decoder_optim.step()
        decoder_losses[unet_number] = loss_decoder.detach()

    # --- Log every `log_every` steps ---
    # tqdm.write prints above the progress bar instead of breaking it.
    if step % log_every == 0 or step == 1:
        tqdm.write(
            f"[Step {step:>{step_width}}] Prior Loss       : {loss_prior.item():.4f}\n"
            f"{log_pad}Decoder Loss (U1): {decoder_losses[1].item():.4f}\n"
            f"{log_pad}Decoder Loss (U2): {decoder_losses[2].item():.4f}\n"
        )


In [None]:
###   Generate Image From Trained Model

from dalle2_pytorch import DALLE2
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

# Wrap the current (trained) prior and decoder into a text-to-image sampler.
dalle2 = DALLE2(prior=diffusion_prior, decoder=decoder)

# Test caption; any caption text can be used here.
caption = ["A little girl climbing into a wooden playhouse"]

# cond_scale > 1 strengthens the caption conditioning during sampling.
generated_images = dalle2(caption, cond_scale=2.0)

# Tile the batch into one CHW image, then show it channels-last from the CPU.
grid_img = make_grid(generated_images, nrow=1)
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(grid_img.permute(1, 2, 0).cpu())
ax.axis('off')
ax.set_title(caption[0])
plt.show()


In [None]:
###  1000-Step Training Loop for DALLE-2

from tqdm import tqdm
from torch.optim import Adam
import torch

# NOTE(review): creating new optimizers here discards Adam's moment estimates
# from the earlier training cells, even though the model weights carry over.
prior_optim = Adam(diffusion_prior.parameters(), lr=1e-4)
decoder_optim = Adam(decoder.parameters(), lr=1e-4)

num_steps = 1000
log_every = 50
step_width = len(str(num_steps))   # keeps the log columns aligned for any num_steps
log_pad = " " * (step_width + 8)   # width of the "[Step N] " prefix

print(f"Starting {num_steps}-step training loop...\n")

for step in tqdm(range(1, num_steps + 1), desc="Training Step"):

    # Get training batch
    tokens, images, _ = get_batch_from_drive(batch_size=4)

    # --- Train Prior ---
    prior_optim.zero_grad()
    loss_prior = diffusion_prior(tokens, images)
    loss_prior.backward()
    prior_optim.step()

    # --- Train Decoder (UNet 1 and 2), keeping each U-Net's loss for logging ---
    decoder_losses = {}
    for unet_number in (1, 2):
        decoder_optim.zero_grad()
        loss_decoder = decoder(images, text=tokens, unet_number=unet_number)
        loss_decoder.backward()
        decoder_optim.step()
        decoder_losses[unet_number] = loss_decoder.detach()

    # --- Log every `log_every` steps ---
    # tqdm.write prints above the progress bar instead of breaking it.
    if step % log_every == 0 or step == 1:
        tqdm.write(
            f"[Step {step:>{step_width}}] Prior Loss       : {loss_prior.item():.4f}\n"
            f"{log_pad}Decoder Loss (U1): {decoder_losses[1].item():.4f}\n"
            f"{log_pad}Decoder Loss (U2): {decoder_losses[2].item():.4f}\n"
        )


In [None]:
###  Generate Image From Trained Model

from dalle2_pytorch import DALLE2
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

# Wrap the current (trained) prior and decoder into a text-to-image sampler.
dalle2 = DALLE2(prior=diffusion_prior, decoder=decoder)

# Test caption; any caption text can be used here.
caption = ["A little girl climbing into a wooden playhouse"]

# cond_scale > 1 strengthens the caption conditioning during sampling.
generated_images = dalle2(caption, cond_scale=2.0)

# Tile the batch into one CHW image, then show it channels-last from the CPU.
grid_img = make_grid(generated_images, nrow=1)
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(grid_img.permute(1, 2, 0).cpu())
ax.axis('off')
ax.set_title(caption[0])
plt.show()


In [None]:
###  5000-Step Training Loop for DALLE-2

from tqdm import tqdm
from torch.optim import Adam
import torch

# NOTE(review): creating new optimizers here discards Adam's moment estimates
# from the earlier training cells, even though the model weights carry over.
prior_optim = Adam(diffusion_prior.parameters(), lr=1e-4)
decoder_optim = Adam(decoder.parameters(), lr=1e-4)

num_steps = 5000
log_every = 100
step_width = len(str(num_steps))   # keeps the log columns aligned for any num_steps
log_pad = " " * (step_width + 8)   # width of the "[Step N] " prefix

print(f"Starting {num_steps}-step training loop...\n")

for step in tqdm(range(1, num_steps + 1), desc="Training Step"):

    # Get training batch
    tokens, images, _ = get_batch_from_drive(batch_size=4)

    # --- Train Prior ---
    prior_optim.zero_grad()
    loss_prior = diffusion_prior(tokens, images)
    loss_prior.backward()
    prior_optim.step()

    # --- Train Decoder (UNet 1 and 2), keeping each U-Net's loss for logging ---
    decoder_losses = {}
    for unet_number in (1, 2):
        decoder_optim.zero_grad()
        loss_decoder = decoder(images, text=tokens, unet_number=unet_number)
        loss_decoder.backward()
        decoder_optim.step()
        decoder_losses[unet_number] = loss_decoder.detach()

    # --- Log every `log_every` steps ---
    # tqdm.write prints above the progress bar instead of breaking it.
    if step % log_every == 0 or step == 1:
        tqdm.write(
            f"[Step {step:>{step_width}}] Prior Loss       : {loss_prior.item():.4f}\n"
            f"{log_pad}Decoder Loss (U1): {decoder_losses[1].item():.4f}\n"
            f"{log_pad}Decoder Loss (U2): {decoder_losses[2].item():.4f}\n"
        )


In [None]:
###  Generate Image From Trained Model

from dalle2_pytorch import DALLE2
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

# Wrap the current (trained) prior and decoder into a text-to-image sampler.
dalle2 = DALLE2(prior=diffusion_prior, decoder=decoder)

# Test caption; any caption text can be used here.
caption = ["A little girl climbing into a wooden playhouse"]

# cond_scale > 1 strengthens the caption conditioning during sampling.
generated_images = dalle2(caption, cond_scale=2.0)

# Tile the batch into one CHW image, then show it channels-last from the CPU.
grid_img = make_grid(generated_images, nrow=1)
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(grid_img.permute(1, 2, 0).cpu())
ax.axis('off')
ax.set_title(caption[0])
plt.show()