# README:
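Upload the damaged painting and its damage mask (white = area to fill). The notebook classifies the painting's art movement, then inpaints the masked region with a movement-specific model.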

In [None]:
!pip install -q diffusers transformers accelerate torch pillow matplotlib

In [None]:
import io
import os
import numpy as np
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import tensorflow as tf
from huggingface_hub import hf_hub_download, HfApi
from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline
from google.colab import files
# Report whether a CUDA GPU is visible and, if so, which one.
_has_gpu = torch.cuda.is_available()
print("GPU available:", _has_gpu)
print("Device:", torch.cuda.get_device_name(0) if _has_gpu else "CPU")

In [None]:
# Ask the user for the painting to restore; the first uploaded file is used
# and decoded into an RGB PIL image.
print("Please upload your painting:")
uploaded = files.upload()
image_path = next(iter(uploaded))
original_image = Image.open(io.BytesIO(uploaded[image_path])).convert("RGB")

# Preview the painting without axis ticks.
_, _preview_ax = plt.subplots(figsize=(8, 8))
_preview_ax.imshow(original_image)
_preview_ax.axis('off')
plt.show()

In [None]:
# Ask for the damage mask; white pixels mark the region to inpaint.
print("Please upload an existing mask (white=area to fill)")

uploaded_mask = files.upload()
mask_path = next(iter(uploaded_mask))
mask_image = Image.open(io.BytesIO(uploaded_mask[mask_path])).convert("L")

# Threshold the greyscale mask into strict black/white:
# pixels brighter than 128 become 255, everything else becomes 0.
mask_array = np.where(np.array(mask_image) > 128, 255, 0)
mask_image = Image.fromarray(mask_array.astype(np.uint8))

# Preview the binarized mask.
_, _mask_ax = plt.subplots(figsize=(8, 8))
_mask_ax.imshow(mask_image, cmap='gray')
_mask_ax.axis('off')
plt.show()

In [None]:
# classifier for art movement
model_path = hf_hub_download(
    repo_id="maximiliannl/art_movement_classifier",
    filename="trained_cnn_final_8_13.pth",
    revision="main"
)
print(f"Model downloaded to: {model_path}")

In [None]:
class SimpleCNN(nn.Module):
    """Small three-stage CNN that classifies a painting's art movement.

    Each stage is a stride-2 3x3 convolution, a ReLU and a 2x2 max-pool,
    i.e. a 4x spatial reduction per stage (64x overall). The classifier
    head expects a flattened 64 x 32 x 32 feature map, so the network is
    sized for 3-channel inputs of 2048 x 2048 pixels (2048 / 64 == 32).

    All layers live in one ``nn.Sequential`` stored as ``self.model``. Its
    indices form the state-dict keys (``model.0.weight``, ``model.3.weight``
    ...), so the layer order must not change or saved checkpoints will no
    longer load.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        layers = []
        in_channels = 3
        # Three conv/ReLU/pool stages: 3 -> 16 -> 32 -> 64 channels.
        for out_channels in (16, 32, 64):
            layers.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),  # halves H, W
                nn.ReLU(),
                nn.MaxPool2d(2),  # halves H, W again
            ])
            in_channels = out_channels
        # Fully connected head on the flattened [B, 64, 32, 32] features.
        layers.extend([
            nn.Flatten(),
            nn.Linear(64 * 32 * 32, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes),
        ])
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw (unnormalized) class logits of shape [B, num_classes]."""
        return self.model(x)

# Initialize model
# Use the GPU when available; later cells also move tensors to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download model from Hugging Face Hub (returns the local file path).
model_path = hf_hub_download(
    repo_id="maximiliannl/art_movement_classifier",
    filename="trained_cnn_final_8_13.pth",
    revision="main",
)
print(f"Model downloaded to: {model_path}")

# Load the model.
# num_classes=10 must match the 10 class names used by the classification cell.
model_for_classification = SimpleCNN(num_classes=10).to(device)
# The checkpoint comes from a remote repository: weights_only=True restricts
# unpickling to tensors and plain containers, so a tampered file cannot run
# arbitrary code when loaded.
state_dict = torch.load(model_path, map_location=device, weights_only=True)
# load_state_dict is strict by default: it raises if any key or shape is
# missing or unexpected, so reaching the print below means the weights matched.
model_for_classification.load_state_dict(state_dict)
model_for_classification.eval()  # inference mode

print("Model successfully loaded from Hugging Face Hub!")

In [None]:
# Classifier preprocessing: SimpleCNN's head is sized for 2048x2048 inputs;
# ToTensor converts the PIL image to a float CHW tensor.
transform = transforms.Compose([
    transforms.Resize((2048, 2048)),
    transforms.ToTensor(),
])

# Output index -> movement label.
# NOTE(review): presumably this is the label order used when the checkpoint
# was trained — confirm against the training code.
CLASS_NAMES = [
    "cubism", "dada", "early_modern", "impressionist", "medieval",
    "minimalism", "realism", "renaissance", "romantic", "symbolism",
]


def classify_existing_image(image):
    """Predict the art movement of a painting with the loaded classifier.

    Displays the image and prints the predicted class plus the softmax
    probability of every class.

    Args:
        image: A PIL image, or a NumPy array accepted by ``Image.fromarray``.

    Returns:
        The predicted class name from ``CLASS_NAMES``. Returns ``None`` if
        preprocessing or inference raised; the error is printed.
    """
    try:
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        image = image.convert('RGB')

        # Preprocess and add a batch dimension.
        image_tensor = transform(image).unsqueeze(0).to(device)
        print(f"Input tensor shape: {image_tensor.shape}")

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            output = model_for_classification(image_tensor)
            probabilities = F.softmax(output[0], dim=0)

        # A checkpoint whose head disagrees with CLASS_NAMES would otherwise
        # be mislabelled or silently truncated by zip() below.
        if len(probabilities) != len(CLASS_NAMES):
            raise ValueError(
                f"model returned {len(probabilities)} scores but "
                f"{len(CLASS_NAMES)} class names are defined"
            )

        predicted_class = CLASS_NAMES[int(torch.argmax(probabilities))]
        display(image)
        print(f"Predicted class: {predicted_class}")
        print("Class probabilities:")
        for name, prob in zip(CLASS_NAMES, probabilities):
            print(f"{name}: {prob.item():.4f}")
        return predicted_class

    except Exception as e:
        print(f"Error during classification: {e}")
        return None


print("Classifying your existing image...")
movement_prediction = classify_existing_image(original_image)
if movement_prediction is None:
    # Later cells interpolate this label into the expert model id and the
    # prompt; continuing would request "maximiliannl/None_expert".
    raise RuntimeError("Art-movement classification failed; see the error printed above.")

In [None]:
# Load the inpainting pipeline fine-tuned for the predicted art movement.
if movement_prediction is None:
    raise RuntimeError("No art-movement prediction available; run the classification cell first.")
model_id = f"maximiliannl/{movement_prediction}_expert"  # Your fine-tuned model

# Use half precision on CUDA. Without a GPU, fall back to float32 on CPU
# (slow, but it runs instead of crashing on a hard-coded "cuda").
_use_cuda = torch.cuda.is_available()
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if _use_cuda else torch.float32,
).to("cuda" if _use_cuda else "cpu")

# Attention slicing trades some speed for lower peak memory during generation.
pipe.enable_attention_slicing()

In [None]:
# Perform inpainting with simple prompt and negative prompt
prompt = f"Carefully restore and complete this {movement_prediction} era artwork in a manner that preserves its original composition, brushwork, and textures, emphasizing historical accuracy, authentic color palettes, and fine detail true to the {movement_prediction} period."
negative_prompt = "blurry, distorted, artifacts, bad quality, text, watermark"

# Working resolution: keep the aspect ratio, cap the longer side so large
# scans fit in GPU memory, and snap both sides down to multiples of 8 (the
# pipeline's latent grid). Lower MAX_SIDE if you hit out-of-memory errors.
MAX_SIDE = 1024
width, height = original_image.size
scale = min(1.0, MAX_SIDE / max(width, height))
new_width = max(8, int(width * scale) // 8 * 8)
new_height = max(8, int(height * scale) // 8 * 8)
resized_image = original_image.resize((new_width, new_height))
# NEAREST keeps the mask strictly black/white; the default smooth filter
# would introduce grey edge pixels.
resized_mask = mask_image.resize((new_width, new_height), Image.NEAREST)

# Generate. height/width must be given explicitly: if omitted, the pipeline
# falls back to the UNet's default sample resolution (e.g. 512x512) and
# stretches the painting to that size.
result = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    image=resized_image,
    mask_image=resized_mask,
    height=new_height,
    width=new_width,
    guidance_scale=17.5,
    num_inference_steps=50,
).images[0]

# Resize back to original dimensions
result = result.resize(original_image.size)

# Keep undamaged areas exactly as in the original painting: take only the
# masked (white) region from the generated image. Otherwise the VAE
# round-trip and resizing slightly alter every pixel.
full_size_mask = mask_image.resize(original_image.size, Image.NEAREST)
result = Image.composite(result, original_image, full_size_mask)

In [None]:
# Show original, mask and inpainted result side by side, then save the
# result and offer it as a browser download.
_panels = [
    (original_image, "Original Image", None),
    (mask_image, "Mask (white=inpainted area)", 'gray'),
    (result, "Inpainted Result", None),
]

plt.figure(figsize=(16, 8))
for _pos, (_img, _title, _cmap) in enumerate(_panels, start=1):
    plt.subplot(1, 3, _pos)
    if _cmap is None:
        plt.imshow(_img)
    else:
        plt.imshow(_img, cmap=_cmap)
    plt.title(_title)
    plt.axis('off')
plt.tight_layout()
plt.show()

# Persist to the working directory, then trigger the Colab download.
_output_file = "inpainted_result.png"
result.save(_output_file)
print(f"Saved as '{_output_file}'")
files.download(_output_file)