In [1]:
# Import libraries

import os
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from transformers import CLIPTokenizer
from datasets import load_dataset
import torch
from diffusers import UNet2DModel, DDPMScheduler
import wandb

In [2]:
# Initialise Weights & Biases

wandb.init(project="NataliaDiffusion")


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkghamilton[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
# Load dataset

dataset = load_dataset('NevskyCollective/nataliaXton')

Resolving data files:   0%|          | 0/78 [00:00<?, ?it/s]

In [4]:
# Initialize tokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

class ImageCaptionDataset(Dataset):
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform
        self.data = self._load_data()

    def _load_data(self):
        data = []
        for item in self.dataset:
            image_path = item['file']
            caption = os.path.splitext(os.path.basename(image_path))[0]
            data.append((image_path, caption))
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image_path, caption = self.data[idx]
        image = Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        tokenized_caption = tokenizer(caption, padding="max_length", truncation=True, return_tensors="pt")
        return {
            "pixel_values": image,
            "input_ids": tokenized_caption["input_ids"].squeeze(),
            "attention_mask": tokenized_caption["attention_mask"].squeeze()
        }


In [5]:
# Define image transformations

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])


In [6]:
# Create dataset

custom_dataset = ImageCaptionDataset(dataset, transform=transform)

# Create DataLoader

dataloader = DataLoader(custom_dataset, batch_size=8, shuffle=True)


TypeError: string indices must be integers, not 'str'

In [None]:
# Load the UNet model and scheduler

model = UNet2DModel.from_pretrained("google/ddpm-cifar10-32")
scheduler = DDPMScheduler.from_config(model.config)


In [None]:
# Define the training loop

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

num_epochs = 5

try:
    for epoch in range(num_epochs):
        model.train()
        for batch in dataloader:
            optimizer.zero_grad()
            pixel_values = batch["pixel_values"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            # Forward pass
            noise = torch.randn_like(pixel_values)
            timesteps = torch.randint(0, scheduler.num_train_timesteps, (pixel_values.shape[0],)).to(device)
            noisy_images = scheduler.add_noise(pixel_values, noise, timesteps)
            outputs = model(noisy_images, timesteps=timesteps)
            loss = torch.nn.functional.mse_loss(outputs.sample, noise)

            loss.backward()
            optimizer.step()

            # Log metrics to W&B
            wandb.log({"loss": loss.item(), "epoch": epoch})

        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}")
except Exception as e:
    print(f"An error occurred: {e}")
finally:
    wandb.finish()


In [None]:
# Save the model locally
model.save_pretrained("./trained_model")

# Upload to Hugging Face
from huggingface_hub import notebook_login

notebook_login()
model.push_to_hub("your-username/your-model-name")
tokenizer.push_to_hub("your-username/your-model-name")


In [7]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.003 MB uploaded\r'), FloatProgress(value=0.38837555886736214, max=1.…