In [None]:
import os
import pandas as pd
from PIL import Image
from Flick30KDataset import Flickr30kDataset
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Preprocessing for the inspection cells: resize every image to 512x512,
# convert it to a CHW float tensor in [0, 1], then map each RGB channel to
# [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# Flickr30k images paired with their captions, wrapped in a shuffling loader.
images_dir, captions_file = 'data/flickr30k_images', 'data/results.csv'
dataset = Flickr30kDataset(images_dir, captions_file, transform=transform)
dataloader = DataLoader(dataset, shuffle=True, batch_size=4)

In [None]:
import matplotlib.pyplot as plt

# Inspect one (image, caption) pair from the dataset.
sample_index = 3
image, caption = dataset[sample_index]

# The dataset transform normalised each channel to [-1, 1] with
# (x - 0.5) / 0.5. Undo that so matplotlib gets RGB floats in [0, 1]
# instead of silently clipping every negative value. Then move channels
# last (CHW -> HWC), which is the layout imshow expects.
image_display = (image * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0)

fig, ax = plt.subplots()
ax.imshow(image_display)
ax.axis('off')
plt.show()

print(f"Caption for sample {sample_index}:", caption)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import CLIPTokenizer, CLIPModel
from tqdm import tqdm

from Encoder import VAE_Encoder
from Decoder import VAE_Decoder
from Diffusion import Diffusion
from Clip import CLIP

def train(images_dir='data/flickr30k_images', captions_dir='data/results.csv',
          batch_size=4, num_epochs=10, learning_rate=1e-4, device='cpu',
          num_workers=4):
    """Jointly train the VAE encoder/decoder and the diffusion model on Flickr30k.

    Each batch is encoded to a latent, passed through the diffusion model
    conditioned on CLIP text features (CLIP itself is not trained), decoded
    back to image space, and compared with the input images using an MSE
    reconstruction loss.

    Args:
        images_dir: Directory containing the Flickr30k images.
        captions_dir: Path to the captions file (a CSV in this notebook).
        batch_size: Images per optimisation step.
        num_epochs: Number of passes over the dataset.
        learning_rate: Adam learning rate shared by all trained modules.
        device: Torch device string that the models and tensors are moved to.
        num_workers: Number of DataLoader worker processes.

    Returns:
        A dict with the trained ``vae_encoder``, ``vae_decoder`` and
        ``diffusion`` modules, plus the ``clip_model`` and ``tokenizer`` used
        for conditioning. Callers can generate with the weights just trained.
    """
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    dataset = Flickr30kDataset(images_dir, captions_dir, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    vae_encoder = VAE_Encoder().to(device)
    vae_decoder = VAE_Decoder().to(device)
    diffusion = Diffusion().to(device)
    clip_model = CLIP().to(device)
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

    # CLIP is not handed to the optimizer. Keep it in eval mode and run it
    # under no_grad below. Otherwise its .grad buffers would accumulate on
    # every step, because optimizer.zero_grad() only clears the parameters
    # it owns.
    clip_model.eval()

    criterion = nn.MSELoss()
    optimizer = optim.Adam([
        {'params': vae_encoder.parameters()},
        {'params': vae_decoder.parameters()},
        {'params': diffusion.parameters()},
    ], lr=learning_rate)

    for epoch in range(num_epochs):
        vae_encoder.train()
        vae_decoder.train()
        diffusion.train()
        running_loss = 0.0

        for images, captions in tqdm(dataloader):
            images = images.to(device)

            if all(isinstance(c, str) for c in captions):
                # default_collate returns a batch of strings as a list of
                # strings. Indexing [0] would keep only the first character.
                captions = list(captions)
            else:
                # NOTE(review): original handling for non-string batches. If
                # the dataset yields several captions per image,
                # default_collate transposes them and this picks the first
                # image's captions. TODO confirm against Flickr30kDataset.
                captions = [caption[0] for caption in captions]

            # truncation keeps token sequences within the tokenizer's maximum length.
            tokenized_captions = tokenizer(
                captions, padding=True, truncation=True, return_tensors="pt"
            ).input_ids.to(device)
            with torch.no_grad():
                # NOTE(review): mean over dim 1 presumably pools the token axis
                # into one vector per caption. Confirm this is the context
                # shape Diffusion expects.
                text_features = clip_model(tokenized_captions).mean(dim=1)

            optimizer.zero_grad()

            # Noise with 4 channels at 1/8 of the image resolution, consumed by the encoder.
            noise = torch.randn(
                (images.size(0), 4, images.size(2) // 8, images.size(3) // 8), device=device
            )
            latents = vae_encoder(images, noise)
            outputs = diffusion(latents, text_features, torch.randn_like(latents))
            outputs = vae_decoder(outputs)
            # Reconstruction loss against the normalised ([-1, 1]) input images.
            loss = criterion(outputs, images)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        num_batches = len(dataloader)
        # Guard against an empty dataset instead of raising ZeroDivisionError.
        epoch_loss = running_loss / num_batches if num_batches else 0.0
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')

    print('Training finished.')
    return {
        'vae_encoder': vae_encoder,
        'vae_decoder': vae_decoder,
        'diffusion': diffusion,
        'clip_model': clip_model,
        'tokenizer': tokenizer,
    }

def generate_image(prompt, clip_model, tokenizer, vae_encoder, vae_decoder, diffusion, device,
                   image_size=512):
    """Generate an image for ``prompt`` with a single diffusion pass.

    Mirrors the forward pass used in ``train``. CLIP text features (averaged
    over dim 1) condition one call of ``diffusion`` on a random latent, and
    the result is decoded by ``vae_decoder``. The modules are switched to
    eval mode for the call, and their previous train/eval flags are restored
    afterwards.

    Args:
        prompt: Text prompt to condition on.
        clip_model: Text encoder producing the conditioning features.
        tokenizer: Tokenizer returning an object with ``input_ids``.
        vae_encoder: Unused; kept for signature compatibility.
        vae_decoder: Maps latents back to image space.
        diffusion: Diffusion model called as ``diffusion(latents, context, extra)``.
        device: Torch device string for the generated tensors.
        image_size: Side length of the target image. The latent is
            ``(1, 4, image_size // 8, image_size // 8)``, matching the noise
            shape used in ``train``. This assumes the encoder's latent has the
            same shape as that noise (TODO confirm).

    Returns:
        The decoder output tensor. It is presumably in [-1, 1], because
        training targets are normalised with mean/std 0.5 (TODO confirm).
    """
    modules = (clip_model, vae_decoder, diffusion)
    was_training = [module.training for module in modules]
    for module in modules:
        module.eval()
    try:
        # truncation keeps token sequences within the tokenizer's maximum length.
        tokenized_prompt = tokenizer(prompt, truncation=True, return_tensors="pt").input_ids.to(device)
        with torch.no_grad():
            text_features = clip_model(tokenized_prompt).mean(dim=1)
            latent_side = image_size // 8
            # Start from a random latent with the same layout as the training latents.
            latents = torch.randn((1, 4, latent_side, latent_side), device=device)
            latents = diffusion(latents, text_features, torch.randn_like(latents))
            generated_image = vae_decoder(latents)
    finally:
        # Restore whatever mode the caller left each module in.
        for module, flag in zip(modules, was_training):
            module.train(flag)
    return generated_image

# In a notebook kernel __name__ is '__main__', so executing this cell starts
# training immediately. Keep whatever train() returns so later cells can use
# it instead of discarding it.
if __name__ == '__main__':
    trained_models = train()


In [None]:
# Text-to-image generation for a single prompt, saved as <prompt>.png.
# NOTE(review): these modules are freshly constructed, not the ones optimised
# inside train() above. Unless the constructors load pretrained weights (not
# visible here), the output reflects untrained weights.
prompt = "man drinking water"
device = 'cpu'
vae_encoder = VAE_Encoder().to(device)
vae_decoder = VAE_Decoder().to(device)
diffusion = Diffusion().to(device)
clip_model = CLIP().to(device)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

generated_image = generate_image(prompt, clip_model, tokenizer, vae_encoder, vae_decoder, diffusion, device)

# The decoder is trained against images normalised to [-1, 1]. Map its output
# back to [0, 1] on the CPU before ToPILImage. ToPILImage scales floats by 255
# and casts to uint8, which would corrupt negative values.
image_to_save = (generated_image.squeeze().detach().cpu().clamp(-1, 1) + 1) / 2
transforms.ToPILImage()(image_to_save).save(f'{prompt}.png')