In [1]:
%pip install torch torchvision transformers
%pip install wandb

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
import os
import json
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

class ImageCaptionDataset(Dataset):
    def __init__(self, img_dir, seg_dir, caption_file, transform=None):
        self.img_dir = img_dir
        self.seg_dir = seg_dir
        self.transform = transform
        with open(caption_file, 'r') as f:
            self.captions = json.load(f)
        self.filenames = list(self.captions.keys())

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        img_name = self.filenames[idx]
        seg_name = 'segmented_' + img_name  #Adjusting name for segmented images

        img_path = os.path.join(self.img_dir, img_name)
        seg_path = os.path.join(self.seg_dir, seg_name)

        image = Image.open(img_path).convert('RGB')
        segmentation = Image.open(seg_path).convert('L')

        if self.transform:
            image = self.transform(image)
            segmentation = self.transform(segmentation)

        caption = self.captions[img_name]
        return image, segmentation, caption


# Transformations
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
import torch
import torch.nn.functional as F

def apply_sobel_operator(segmentation):
    sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
    sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32).view(1, 1, 3, 3)
    
    # Ensure the segmentation tensor is float and has a batch dimension
    if len(segmentation.shape) == 2:
        segmentation = segmentation.unsqueeze(0).unsqueeze(0)
    elif len(segmentation.shape) == 3:
        segmentation = segmentation.unsqueeze(1)

    edges_x = F.conv2d(segmentation, sobel_x, padding=1)
    edges_y = F.conv2d(segmentation, sobel_y, padding=1)

    edges = torch.sqrt(edges_x**2 + edges_y**2)

    threshold = edges.mean() * 1.5
    binary_mask = (edges > threshold).float()

    return binary_mask

In [6]:
import torch.nn as nn
from transformers import BertModel, BertTokenizer

class DiffusionModel(nn.Module):
    def __init__(self):
        super(DiffusionModel, self).__init__()
        self.text_encoder = BertModel.from_pretrained('bert-base-uncased')
        self.image_encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )

        self.mu_network = nn.Sequential(
            nn.Conv2d(128, 3, kernel_size=1)
        )

        # Network to predict dynamic beta values (noise levels)
        self.beta_network = nn.Sequential(
            nn.Linear(768, 100),
            nn.Sigmoid()
        )

    def forward(self, x_t, m, captions, t):
        # Encode text
        inputs = tokenizer(captions, return_tensors="pt", padding=True, truncation=True)
        text_features = self.text_encoder(**inputs).pooler_output

        # Get dynamic noise level (beta_t) from text features
        beta_schedule = self.beta_network(text_features)
        beta_t = beta_schedule[:, t].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)

        # Apply noise where m is 0 (non-segmentation areas)
        noise = torch.randn_like(x_t) * torch.sqrt(beta_t)
        x_t1_pred = torch.sqrt(1 - beta_t) * x_t + (1 - m) * noise

        return x_t1_pred

model = DiffusionModel()
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [7]:
def diffusion_loss(x_t, x_t1_pred, m):
    return ((x_t1_pred - x_t) ** 2 * (1 - m)).mean()

In [11]:
import wandb
from sklearn.model_selection import KFold

T = 1000  #Total number of diffusion steps
beta_start = 0.0001
beta_end = 0.02

# Linear schedule
beta_t = torch.linspace(beta_start, beta_end, steps=T)

n_splits = 5
kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
num_epochs = 10

# Initialize a new W&B run
wandb.init(project='image_captioning_with_diffusion_models', entity='michaelpeng72', config={
    "epochs": num_epochs,
    "batch_size": 16,
    "learning_rate": 1e-4,
    "n_splits": n_splits,
    "shuffle": True,
    "random_state": 42
})

# Config is accessible via wandb.config
config = wandb.config

full_dataset = ImageCaptionDataset('datasets/images', 'datasets/segmented', 'datasets/map.json', transform=transform)

for fold, (train_idx, valid_idx) in enumerate(kf.split(full_dataset)):
    print(f"Starting Fold {fold+1}/{n_splits}")
    wandb.init(project='image_captioning_with_diffusion_models', entity='michaelpeng72',
               group="Experiment-X", job_type=f"Fold-{fold+1}", reinit=True)

    # Splitting the dataset into train and validation for the current fold
    train_subset = torch.utils.data.Subset(full_dataset, train_idx)
    valid_subset = torch.utils.data.Subset(full_dataset, valid_idx)

    # Create DataLoader for train and validation subsets
    train_loader = DataLoader(train_subset, batch_size=16, shuffle=True)
    valid_loader = DataLoader(valid_subset, batch_size=16, shuffle=False)

    # Initialize the model and optimizer
    model = DiffusionModel()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Training phase
    model.train()
    for epoch in range(num_epochs):
        train_loss = 0
        for images, segmentations, captions in train_loader:
            m = apply_sobel_operator(segmentations)
            t = torch.randint(0, 100, (1,)).item()

            # Forward pass
            x_t1_pred = model(images, m, captions, t)
            loss = diffusion_loss(images, x_t1_pred, m)
            train_loss += loss.item()

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        train_loss /= len(train_loader)
        wandb.log({"Train Loss": train_loss})

        # Validation phase
        model.eval() 
        valid_loss = 0
        with torch.no_grad():
            for images, segmentations, captions in valid_loader:
                m = (segmentations > 0.5).float()
                t = torch.randint(0, 100, (1,)).item()
                
                x_t1_pred = model(images, m, captions, t)
                loss = diffusion_loss(images, x_t1_pred, m)
                valid_loss += loss.item()

        valid_loss /= len(valid_loader)

        wandb.log({"Validation Loss": valid_loss})
        print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Validation Loss: {valid_loss:.4f}")
        
wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mmichaelpeng72[0m. Use [1m`wandb login --relogin`[0m to force relogin


Starting Fold 1/5


Epoch 1, Train Loss: 0.3825, Validation Loss: 0.1059
Epoch 2, Train Loss: 0.2413, Validation Loss: 0.0558
Epoch 3, Train Loss: 0.1154, Validation Loss: 0.0299
Epoch 4, Train Loss: 0.0789, Validation Loss: 0.0190
Epoch 5, Train Loss: 0.0456, Validation Loss: 0.0094
Epoch 6, Train Loss: 0.0283, Validation Loss: 0.0090
Epoch 7, Train Loss: 0.0223, Validation Loss: 0.0065
Epoch 8, Train Loss: 0.0163, Validation Loss: 0.0044
Epoch 9, Train Loss: 0.0128, Validation Loss: 0.0035
Epoch 10, Train Loss: 0.0097, Validation Loss: 0.0031
Starting Fold 2/5


0,1
Train Loss,█▅▃▂▂▁▁▁▁▁
Validation Loss,█▅▃▂▁▁▁▁▁▁

0,1
Train Loss,0.00975
Validation Loss,0.00312


Epoch 1, Train Loss: 0.3762, Validation Loss: 0.0933
Epoch 2, Train Loss: 0.2318, Validation Loss: 0.0446
Epoch 3, Train Loss: 0.1155, Validation Loss: 0.0297
Epoch 4, Train Loss: 0.0654, Validation Loss: 0.0183
Epoch 5, Train Loss: 0.0474, Validation Loss: 0.0129
Epoch 6, Train Loss: 0.0290, Validation Loss: 0.0075
Epoch 7, Train Loss: 0.0187, Validation Loss: 0.0062
Epoch 8, Train Loss: 0.0181, Validation Loss: 0.0043
Epoch 9, Train Loss: 0.0124, Validation Loss: 0.0035
Epoch 10, Train Loss: 0.0104, Validation Loss: 0.0032
Starting Fold 3/5


0,1
Train Loss,█▅▃▂▂▁▁▁▁▁
Validation Loss,█▄▃▂▂▁▁▁▁▁

0,1
Train Loss,0.01035
Validation Loss,0.0032


Epoch 1, Train Loss: 0.3588, Validation Loss: 0.0847
Epoch 2, Train Loss: 0.2158, Validation Loss: 0.0427
Epoch 3, Train Loss: 0.0932, Validation Loss: 0.0324
Epoch 4, Train Loss: 0.0851, Validation Loss: 0.0196
Epoch 5, Train Loss: 0.0435, Validation Loss: 0.0129
Epoch 6, Train Loss: 0.0289, Validation Loss: 0.0075
Epoch 7, Train Loss: 0.0201, Validation Loss: 0.0052
Epoch 8, Train Loss: 0.0138, Validation Loss: 0.0034
Epoch 9, Train Loss: 0.0113, Validation Loss: 0.0031
Epoch 10, Train Loss: 0.0100, Validation Loss: 0.0025
Starting Fold 4/5


0,1
Train Loss,█▅▃▃▂▁▁▁▁▁
Validation Loss,█▄▄▂▂▁▁▁▁▁

0,1
Train Loss,0.01002
Validation Loss,0.00248


Epoch 1, Train Loss: 0.3652, Validation Loss: 0.0988
Epoch 2, Train Loss: 0.2257, Validation Loss: 0.0518
Epoch 3, Train Loss: 0.1258, Validation Loss: 0.0212
Epoch 4, Train Loss: 0.0700, Validation Loss: 0.0123
Epoch 5, Train Loss: 0.0357, Validation Loss: 0.0130
Epoch 6, Train Loss: 0.0260, Validation Loss: 0.0078
Epoch 7, Train Loss: 0.0201, Validation Loss: 0.0047
Epoch 8, Train Loss: 0.0156, Validation Loss: 0.0040
Epoch 9, Train Loss: 0.0123, Validation Loss: 0.0032
Epoch 10, Train Loss: 0.0096, Validation Loss: 0.0026
Starting Fold 5/5


0,1
Train Loss,█▅▃▂▂▁▁▁▁▁
Validation Loss,█▅▂▂▂▁▁▁▁▁

0,1
Train Loss,0.00955
Validation Loss,0.00262


Epoch 1, Train Loss: 0.3697, Validation Loss: 0.0804
Epoch 2, Train Loss: 0.2088, Validation Loss: 0.0376
Epoch 3, Train Loss: 0.1089, Validation Loss: 0.0254
Epoch 4, Train Loss: 0.0557, Validation Loss: 0.0104
Epoch 5, Train Loss: 0.0343, Validation Loss: 0.0069
Epoch 6, Train Loss: 0.0218, Validation Loss: 0.0057
Epoch 7, Train Loss: 0.0173, Validation Loss: 0.0045
Epoch 8, Train Loss: 0.0127, Validation Loss: 0.0029
Epoch 9, Train Loss: 0.0099, Validation Loss: 0.0024
Epoch 10, Train Loss: 0.0074, Validation Loss: 0.0023


0,1
Train Loss,█▅▃▂▂▁▁▁▁▁
Validation Loss,█▄▃▂▁▁▁▁▁▁

0,1
Train Loss,0.00742
Validation Loss,0.00231


In [12]:
#save the model to ../models
torch.save(model.state_dict(), '../models/diffusion_model.pth')

In [7]:
#load model from ../models
model = DiffusionModel()
model.load_state_dict(torch.load('../models/diffusion_model.pth'))

<All keys matched successfully>