### Use the below two cells to run shell commands in colab without having to pay for Pro

In [None]:
# Use this to run shell commands in Google Colab
from IPython.display import JSON
from google.colab import output
from subprocess import getoutput
import os

def shell(command):
  if command.startswith('cd'):
    path = command.strip().split(maxsplit=1)[1]
    os.chdir(path)
    return JSON([''])
  return JSON([getoutput(command)])
output.register_callback('shell', shell)

In [None]:
#@title Colab Shell
%%html
<div id=term_demo></div>
<script src="https://code.jquery.com/jquery-latest.js"></script>
<script src="https://cdn.jsdelivr.net/npm/jquery.terminal/js/jquery.terminal.min.js"></script>
<link href="https://cdn.jsdelivr.net/npm/jquery.terminal/css/jquery.terminal.min.css" rel="stylesheet"/>
<script>
  $('#term_demo').terminal(async function(command) {
      if (command !== '') {
          try {
              let res = await google.colab.kernel.invokeFunction('shell', [command])
              let out = res.data['application/json'][0]
              this.echo(new String(out))
          } catch(e) {
              this.error(new String(e));
          }
      } else {
          this.echo('');
      }
  }, {
      greetings: 'Welcome to Colab Shell',
      name: 'colab_demo',
      height: 250,
      prompt: 'colab > '
  });

In [None]:
import zipfile

# If running start from the main directory
main_path = "/content/drive/MyDrive/Greek_Pottery_In_Painting/" if colab else "./"

specific_data = "November_24_Dataset/"
full_path = main_path + specific_data + "train.zip"
BASE_DIR = "./ancient-greek-pottery-restoration/dataset"
with zipfile.ZipFile(full_path, "r") as zip:
    zip.extractall(BASE_DIR)

In [2]:
!git clone https://github.com/ddevaul/ancient-greek-pottery-restoration

Cloning into 'ancient-greek-pottery-restoration'...
remote: Enumerating objects: 56, done.[K
remote: Counting objects: 100% (56/56), done.[K
remote: Compressing objects: 100% (41/41), done.[K
remote: Total 56 (delta 21), reused 49 (delta 14), pack-reused 0 (from 0)[K
Receiving objects: 100% (56/56), 12.33 MiB | 40.21 MiB/s, done.
Resolving deltas: 100% (21/21), done.


In [2]:
!cd PotteryRestoration
!ls
!rm -rf PotterRestoration/dataset/train/full/.ipynb_checkpoints

drive  PotteryRestoration  sample_data


In [5]:
# Check if we are running on google colab or locally
try:
    import google.colab

    print("Running on Google Colab")
    colab = True
except ImportError:
    print("Running locally")
    colab = False

Running locally


# image pipeline code

In [6]:
if colab:
    from google.colab import drive

    drive.mount("/content/drive")

In [7]:
#!rm -rf ./PotteryRestoration/dataset/train/full # Be careful with this command

In [8]:
import zipfile

# If running start from the main directory
main_path = "drive/MyDrive/Greek_Pottery_In_Painting/" if colab else "./"

specific_data = "November_21_Images/"
full_path = main_path + specific_data + "images.zip"
BASE_DIR = "./images"
with zipfile.ZipFile(full_path, "r") as zip:
    zip.extractall(BASE_DIR)

In [9]:
import os
import random

path_to_sub_dir = BASE_DIR + "/" + os.listdir(BASE_DIR)[0]
image_names = [image_name for image_name in os.listdir(path_to_sub_dir)]
random.shuffle(image_names)
print(image_names[0])
print(f"Total Number of Images: {len(image_names)}")

302189-BLACK-FIGURE.jpg
Total Number of Images: 6714


In [10]:
import pandas as pd

column_names = ["image_path", "image_url", "pottery_style", "pottery_shape"]
aggregate_data_df = pd.read_csv(
    f"{main_path + specific_data}/aggregrate_data.csv", names=column_names
)
aggregate_data_df.head(1)

Unnamed: 0,image_path,image_url,pottery_style,pottery_shape
0,images-nicky-test/208223-BLACK-FIGURE.jpg,http://www.beazley.ox.ac.uk/record/8A9E1A2D-DE...,BLACK-FIGURE,LEKYTHOS


In [11]:
new_image_prefix = "BF_image"
aggregate_data_df['uniform_image_name'] = [f"{new_image_prefix}{i}.jpg" for i in range(len(aggregate_data_df))]

In [12]:
aggregate_data_df.head(1)

Unnamed: 0,image_path,image_url,pottery_style,pottery_shape,uniform_image_name
0,images-nicky-test/208223-BLACK-FIGURE.jpg,http://www.beazley.ox.ac.uk/record/8A9E1A2D-DE...,BLACK-FIGURE,LEKYTHOS,BF_image0.jpg


In [13]:
import shutil

image_train_base_path = (
    "./PotteryRestoration/dataset/train/original_images"
    if colab
    else "./dataset/train/original_images"
)
os.makedirs(image_train_base_path, exist_ok=True)

for index, row in aggregate_data_df.iterrows():
    original_image_path = row["image_path"]
    new_name = row["uniform_image_name"]
    shutil.copy(
        f"{BASE_DIR}/{original_image_path}", f"{image_train_base_path}/{new_name}"
    )

In [14]:
# Save our aggregate data to the correct directory (We can just use this as "captions.csv")
aggregate_data_location = (
    "./PotteryRestoration/dataset/train/aggregate_data.csv"
    if colab
    else "./dataset/train/aggregate_data.csv"
)
aggregate_data_df.to_csv(aggregate_data_location, index=False)

In [15]:
mask_images_code_path = "./PotteryRestoration/dataset/mask_images.py" if colab else "./dataset/mask_images.py"
!python {mask_images_code_path}

Generating masks for 6716 images using 9 CPUs...
Processing Images: 100%|████████████████████████| 34/34 [00:47<00:00,  1.39s/it]
Saved mappings to ./dataset/train/mask_mappings.csv


In [1]:
colab = False
if colab:
    from PotteryRestoration.dataset.VaseDataset import VaseDataset
else:
    from dataset.VaseDataset import VaseDataset
from torchvision import transforms
from torch.utils.data import DataLoader

from torch.utils.data import random_split
# We can try 512 x 512 later but it takes much more GPU Ram and training time
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),  # Convert images to PyTorch tensors
    ]
)

dataset_root_dir = "./PotteryRestoration/dataset/train" if colab else "./dataset/train"
agg_data_file_name = "aggregate_data.csv"
mask_mappings_file_path = "mask_mappings.csv"


dataset = VaseDataset(
    dataset_root_dir=dataset_root_dir,
    agg_data_file_name=agg_data_file_name,
    mask_mappings_file_path=mask_mappings_file_path,
    transform=transform,
)

train_fraction = 0.8
val_fraction = 0.1
test_fraction = 0.1

# Get size of each split
total_size = len(dataset)
train_size = int(total_size * train_fraction)
val_size = int(total_size * val_fraction)
test_size = total_size - train_size - val_size # Make sure the sizes add up

# Perform the split
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

# Make data loaders for each of our datasets
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Check on the sizes of each dataset
print(f"Train size: {len(train_dataset)}")
print(f"Validation size: {len(val_dataset)}")
print(f"Test size: {len(test_dataset)}")

Train size: 53728
Validation size: 6716
Test size: 6716


In [2]:
for batch in train_loader:
    print(batch["masked_images"].shape)
    print(batch["full_images"].shape)
    print(batch["masks"].shape)
    print(batch["text"])
    break

(564, 276)
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
(600, 241)
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
(600, 348)
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
(531, 357)
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
(346, 657)
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
(600, 196)
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]

In [3]:
print(batch['masked_images'].shape)
print(batch['full_images'].shape)
print(batch['masks'].shape)
print(batch['text'])

torch.Size([32, 3, 256, 256])
torch.Size([32, 3, 256, 256])
torch.Size([32, 1, 256, 256])
['Ancient Greek Pottery. Style: BLACK-FIGURE, Shape: LEKYTHOS.', 'Ancient Greek Pottery. Style: BLACK-FIGURE, Shape: LEKYTHOS.', 'Ancient Greek Pottery. Style: BLACK-FIGURE, Shape: LEKYTHOS.', 'Ancient Greek Pottery. Style: BLACK-FIGURE, Shape: AMPHORA NECK.', 'Ancient Greek Pottery. Style: BLACK-FIGURE, Shape: AMPHORA NECK.', 'Ancient Greek Pottery. Style: BLACK-FIGURE, Shape: LEKYTHOS.', 'Ancient Greek Pottery. Style: BLACK-FIGURE, Shape: AMPHORA NECK.', 'Ancient Greek Pottery. Style: BLACK-FIGURE, Shape: AMPHORA NECK.', 'Ancient Greek Pottery. Style: BLACK-FIGURE, Shape: AMPHORA NECK.', 'Ancient Greek Pottery. Style: BLACK-FIGURE, Shape: AMPHORA NECK.', 'Ancient Greek Pottery. Style: BLACK-FIGURE, Shape: LEKYTHOS.', 'Ancient Greek Pottery. Style: BLACK-FIGURE, Shape: AMPHORA NECK.', 'Ancient Greek Pottery. Style: BLACK-FIGURE, Shape: AMPHORA NECK.', 'Ancient Greek Pottery. Style: BLACK-FIGURE, 

In [70]:
!mkdir /content/drive/MyDrive/Greek_Pottery_In_Painting/nov_23_image_and_mask/

In [None]:
!zip -r /content/drive/MyDrive/Greek_Pottery_In_Painting/nov_23_image_and_mask/train.zip ./PotteryRestoration/dataset/train

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  adding: PotteryRestoration/dataset/train/masked/BF_image4616_masked_9.png (deflated 2%)
  adding: PotteryRestoration/dataset/train/masked/BF_image4390_masked_7.png (deflated 1%)
  adding: PotteryRestoration/dataset/train/masked/BF_image2968_masked_0.png (deflated 1%)
  adding: PotteryRestoration/dataset/train/masked/BF_image2846_masked_1.png (deflated 1%)
  adding: PotteryRestoration/dataset/train/masked/BF_image4110_masked_0.png (deflated 1%)
  adding: PotteryRestoration/dataset/train/masked/BF_image4590_masked_5.png (deflated 1%)
  adding: PotteryRestoration/dataset/train/masked/BF_image1698_masked_8.png (deflated 2%)
  adding: PotteryRestoration/dataset/train/masked/BF_image3057_masked_3.png (deflated 2%)
  adding: PotteryRestoration/dataset/train/masked/BF_image1610_masked_1.png (deflated 1%)
  adding: PotteryRestoration/dataset/train/masked/BF_image2350_masked_5.png (deflated 1%)
  adding: PotteryRestoration/datase

In [66]:
!zip -r ./PotteryRestoration/dataset/train . -i /content/drive/MyDrive/Greek_Pottery_In_Painting/nov_23_image_and_mask/train



### Some online resources for training unets
- 
- https://huggingface.co/learn/diffusion-course/en/unit2/2
- https://github.com/huggingface/diffusers/discussions/8458
- https://discuss.huggingface.co/t/fine-tuning-controlnet-xs-with-sdxl/92652

In [4]:
from diffusers import StableDiffusionInpaintPipeline
from accelerate import Accelerator
import torch
from torch.utils.data import DataLoader
from transformers import AdamW

# Load pipeline
pipe = StableDiffusionInpaintPipeline.from_pretrained("stabilityai/stable-diffusion-2-inpainting")
device = "cuda" if torch.cuda.is_available() else "mps" if torch.mps.is_available() else "cpu"

# Freeze VAE parameters
for param in pipe.vae.parameters():
    param.requires_grad = False

# Fine-tune only the U-Net and text encoder
for param in pipe.unet.parameters():
    param.requires_grad = True
for param in pipe.text_encoder.parameters():
    param.requires_grad = True

pipe.to(device)

# Optimizer
optimizer = AdamW(
    [{"params": pipe.unet.parameters()}, {"params": pipe.text_encoder.parameters()}],
    lr=5e-5
)

# DataLoader placeholder (replace `train_dataloader` with your actual DataLoader)
# train_dataloader = DataLoader(...)

# Use Accelerator for distributed training
accelerator = Accelerator()
pipe, optimizer, train_dataloader = accelerator.prepare(pipe, optimizer, train_loader)

NUM_EPOCHS = 2

# Training loop
for epoch in range(NUM_EPOCHS):
    pipe.unet.train()
    pipe.text_encoder.train()

    for batch in train_dataloader:
        # Get inputs
        masked_images = batch["masked_images"].to(device)
        full_images = batch["full_images"].to(device)
        masks = batch["masks"].to(device)  # Binary masks
        prompts = batch["text"]

        # Tokenize text prompts
        tokenized_prompts = pipe.tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(device)
        text_embeddings = pipe.text_encoder(**tokenized_prompts).last_hidden_state

        # Encode masked images into latent space
        latents = pipe.vae.encode(masked_images).latent_dist.sample()
        latents = latents * pipe.vae.config.scaling_factor

        # Assert latent dimensions
        assert latents.shape[1] == 4, f"Latent channels should be 4, got {latents.shape[1]}"
        assert latents.shape[2] % 8 == 0 and latents.shape[3] % 8 == 0, \
            "Latent dimensions should be divisible by 8 for the UNet"

        # Add noise to the latents
        batch_size = latents.size(0)
        timesteps = torch.randint(0, pipe.scheduler.num_train_timesteps, (batch_size,), device=device).long()
        noise = torch.randn_like(latents)
        noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)

        # Resize mask to match latent spatial dimensions
        latent_masks = torch.nn.functional.interpolate(masks, size=noisy_latents.shape[-2:])  # Resize mask
        latent_masks = latent_masks[:, None, :, :]  # Ensure shape is (B, 1, H, W)

        # Fix latent_masks shape by squeezing the extra singleton dimension
        if latent_masks.ndim == 5:  # Check if there's an extra dimension
            latent_masks = latent_masks.squeeze(2)  # Remove the extra dimension
        # Assert mask shape matches expected dimensions
        print(f"Noisy latents shape: {noisy_latents.shape}")
        print(f"Latent masks shape: {latent_masks.shape}")

        assert latent_masks.ndim == 4, f"Mask should have 4 dimensions, got {latent_masks.ndim}"
        assert latent_masks.shape[1] == 1, f"Mask must have 1 channel, got {latent_masks.shape[1]}"
        assert latent_masks.shape[2:] == noisy_latents.shape[2:], \
            f"Mask spatial dimensions {latent_masks.shape[2:]} must match latents {noisy_latents.shape[2:]}"

        # Generate spatial encodings
        batch_size, _, height, width = noisy_latents.shape
        x = torch.linspace(-1, 1, steps=width, device=device).view(1, 1, 1, -1).expand(batch_size, 1, height, width)
        y = torch.linspace(-1, 1, steps=height, device=device).view(1, 1, -1, 1).expand(batch_size, 1, height, width)
        spatial_encodings = torch.cat([x, y], dim=1)  # Shape: (B, 2, H, W)

        # Concatenate noisy latents, mask, and spatial encodings
        unet_input = torch.cat([noisy_latents, latent_masks, spatial_encodings], dim=1)

        # Add extra dummy channels (if required)
        extra_channels = torch.zeros(unet_input.shape[0], 2, unet_input.shape[2], unet_input.shape[3], device=device)
        unet_input = torch.cat([unet_input, extra_channels], dim=1)

        # Assert the input shape
        assert unet_input.shape[1] == 9, f"UNet input must have 9 channels, got {unet_input.shape[1]}"

        # Forward pass through UNet
        unet_output = pipe.unet(
            sample=unet_input,
            timestep=timesteps,
            encoder_hidden_states=text_embeddings
        ).sample

        # Assert UNet output shape matches latent input
        assert unet_output.shape == latents.shape, \
            f"UNet output shape mismatch: {unet_output.shape} != {latents.shape}"

        # Decode the output latents back to image space
        reconstructed_images = pipe.vae.decode(unet_output / pipe.vae.config.scaling_factor).sample

        # Assert decoded images match the size of full images
        assert reconstructed_images.shape == full_images.shape, \
            f"Decoded images shape mismatch: {reconstructed_images.shape} != {full_images.shape}"

        # Compute pixel-wise loss
        loss = torch.nn.functional.mse_loss(reconstructed_images, full_images)

        # Backpropagation
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

    print(f"Epoch {epoch + 1}/{NUM_EPOCHS} completed. Loss: {loss.item():.4f}")

  from .autonotebook import tqdm as notebook_tqdm
Loading pipeline components...: 100%|██████████| 6/6 [00:00<00:00, 21.49it/s]


(600, 220)
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
(545, 390)
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
(600, 450)
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
(600, 361)
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
(600, 266)
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
(546, 156)
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]

  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


Noisy latents shape: torch.Size([32, 4, 32, 32])
Latent masks shape: torch.Size([32, 1, 32, 32])


RuntimeError: MPS backend out of memory (MPS allocated: 19.69 GB, other allocations: 648.86 MB, max allowed: 20.40 GB). Tried to allocate 640.00 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [None]:
!rm -rf `find -type d -name .ipynb_checkpoints`