# In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.models import vgg16
from skimage.metrics import structural_similarity as ssim
from PIL import Image, ImageEnhance, ImageFilter
import numpy as np
import tkinter as tk
from tkinter import filedialog

# Compute device shared by the whole script: CUDA when PyTorch reports it as
# available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing applied to every input image before it is fed to the model.
# Only ToTensor is used: no resizing, cropping or normalisation happens here.
transform = transforms.Compose([
    transforms.ToTensor(),
])

class DreamBooth(nn.Module):
    """Small fully-convolutional image-to-image network.

    Four 3x3 convolutions with padding 1 take the channel count through
    3 -> 64 -> 128 -> 256 -> 3, so the output has the same height and width
    as the input. A ReLU follows every convolution except the last one.
    """

    def __init__(self):
        super().__init__()
        # Every layer shares the same kernel geometry; only channels differ.
        conv_kwargs = {"kernel_size": 3, "padding": 1}
        self.conv1 = nn.Conv2d(3, 64, **conv_kwargs)
        self.conv2 = nn.Conv2d(64, 128, **conv_kwargs)
        self.conv3 = nn.Conv2d(128, 256, **conv_kwargs)
        self.conv4 = nn.Conv2d(256, 3, **conv_kwargs)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run ``x`` through the four convolutions and return the result."""
        hidden = x
        for conv in (self.conv1, self.conv2, self.conv3):
            hidden = self.relu(conv(hidden))
        # The final projection back to 3 channels is left un-activated.
        return self.conv4(hidden)

class VGGFeatures(nn.Module):
    """Feature extractor built from the first 16 layers of pretrained VGG-16.

    The ImageNet weights are loaded through torchvision's ``weights=`` API
    when it exists, falling back to the legacy ``pretrained=True`` flag on
    older torchvision releases. Inputs are expected as image tensors scaled
    to [0, 1]; they are normalised with the ImageNet channel statistics
    before being passed to the network, because that is the input
    distribution the pretrained weights were trained on.
    """

    def __init__(self):
        super(VGGFeatures, self).__init__()
        try:
            # torchvision >= 0.13: explicit weights enum; `pretrained=` is
            # deprecated there and only kept for backward compatibility.
            from torchvision.models import VGG16_Weights
        except ImportError:
            vgg = vgg16(pretrained=True).features
        else:
            # IMAGENET1K_V1 is the checkpoint `pretrained=True` used to load.
            vgg = vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features
        self.slice = nn.Sequential(*list(vgg)[:16])

        # ImageNet per-channel statistics, shaped to broadcast over
        # (N, 3, H, W). Non-persistent buffers follow `.to(device)` without
        # adding keys to the state dict.
        self.register_buffer(
            "_mean",
            torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "_std",
            torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1),
            persistent=False,
        )

    def forward(self, x):
        """Return the VGG feature map for a batch of [0, 1] RGB tensors."""
        return self.slice((x - self._mean) / self._std)

def _postprocess_augmented(pil_image):
    """Apply the fixed colour/brightness/contrast boosts and filters."""
    for enhancer_cls, factor in (
        (ImageEnhance.Color, 1.5),
        (ImageEnhance.Brightness, 1.2),
        (ImageEnhance.Contrast, 1.2),
    ):
        pil_image = enhancer_cls(pil_image).enhance(factor)

    # Sharpen first, then a 3x3 median filter to remove isolated speckles.
    pil_image = pil_image.filter(ImageFilter.SHARPEN)
    return pil_image.filter(ImageFilter.MedianFilter(size=3))


def stable_diffusion_augmentation(image, n_iterations=100, initial_sigma=0.1):
    """Produce an augmented copy of ``image`` by briefly training a small CNN.

    A freshly initialised ``DreamBooth`` network is optimised for
    ``n_iterations`` Adam steps to reproduce a noise-perturbed version of the
    input, with an additional VGG feature (perceptual) term against the clean
    input. The network output from the last iteration is clipped to [0, 1],
    converted to an 8-bit PIL image and passed through fixed colour,
    brightness, contrast, sharpen and median-filter steps.

    Because the network weights and the noise are random, results differ
    between calls unless the caller seeds PyTorch.

    Args:
        image: Input accepted by the module-level ``transform``; the script
            passes an RGB PIL image. The network expects 3 channels.
        n_iterations: Number of optimisation steps; must be at least 1.
        initial_sigma: Standard deviation of the Gaussian noise at the first
            step. It decays linearly with the iteration index.

    Returns:
        The post-processed augmented image as a PIL image.

    Raises:
        ValueError: If ``n_iterations`` is smaller than 1 (there would be no
            network output to return).
    """
    if n_iterations < 1:
        raise ValueError(
            f"n_iterations must be at least 1, got {n_iterations!r}"
        )

    image_tensor = transform(image).unsqueeze(0).to(device)

    model = DreamBooth().to(device)

    # The VGG network is only a fixed feature extractor: never train it.
    vgg_features = VGGFeatures().to(device)
    for param in vgg_features.parameters():
        param.requires_grad = False

    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # The clean image's features do not change between iterations, so compute
    # them once instead of on every step.
    with torch.no_grad():
        target_features = vgg_features(image_tensor)

    for iteration in range(n_iterations):
        # Linearly decaying noise level: initial_sigma at step 0, shrinking
        # to initial_sigma / n_iterations on the final step.
        sigma = initial_sigma * (1 - iteration / n_iterations)

        noise = torch.randn_like(image_tensor) * sigma
        perturbed_image = image_tensor + noise

        optimizer.zero_grad()
        output = model(perturbed_image)

        l2_loss = torch.norm(output - perturbed_image)
        perceptual_loss = torch.norm(vgg_features(output) - target_features)
        loss = l2_loss + 0.01 * perceptual_loss

        loss.backward()
        optimizer.step()

    # (1, C, H, W) tensor -> (H, W, C) array for PIL.
    augmented_image = output.squeeze(0).detach().cpu().numpy().transpose(1, 2, 0)
    augmented_image = np.clip(augmented_image, 0, 1)

    augmented_image_pil = Image.fromarray((augmented_image * 255).astype(np.uint8))

    return _postprocess_augmented(augmented_image_pil)

def select_files():
    """Show a file-open dialog and return the paths the user picked.

    A hidden Tk root window is created only to host the dialog and is always
    destroyed afterwards, even if the dialog raises.

    Returns:
        The value of ``filedialog.askopenfilenames()``: the selected paths,
        or an empty (falsy) value when the dialog is cancelled.
    """
    root = tk.Tk()
    try:
        root.withdraw()  # Keep the empty root window off screen.
        return filedialog.askopenfilenames()
    finally:
        root.destroy()

def select_output_directory():
    """Show a directory-chooser dialog and return the selected directory.

    A hidden Tk root window is created only to host the dialog and is always
    destroyed afterwards, even if the dialog raises.

    Returns:
        The value of ``filedialog.askdirectory()``: the chosen directory
        path, or an empty (falsy) value when the dialog is cancelled.
    """
    root = tk.Tk()
    try:
        root.withdraw()  # Keep the empty root window off screen.
        return filedialog.askdirectory()
    finally:
        root.destroy()

def calculate_ssim(image1, image2):
    """Return the structural similarity of two PIL images.

    Both images are converted to 8-bit grayscale (mode ``'L'``) before the
    comparison. They must have the same dimensions.

    Args:
        image1: First PIL image.
        image2: Second PIL image.

    Returns:
        The SSIM score computed by ``skimage.metrics.structural_similarity``.
    """
    gray1 = np.array(image1.convert('L'))
    gray2 = np.array(image2.convert('L'))
    # data_range is the span of *possible* pixel values, not the span found
    # in one particular image. Mode 'L' is 8-bit, so the range is 255. Using
    # the observed min/max skews the score and is 0 for a constant image.
    return ssim(gray1, gray2, data_range=255)

# Script driver: ask for the input images and an output directory, then
# augment each image, save it with an "augmented_" prefix and report its SSIM
# against the original.
file_paths = select_files()
output_directory = select_output_directory()

if file_paths and output_directory:
    for file_path in file_paths:
        try:
            # The context manager closes the underlying file handle;
            # convert() returns an independent in-memory RGB copy.
            with Image.open(file_path) as source_image:
                image = source_image.convert('RGB')
        except OSError as error:
            # Unreadable or non-image files should not abort the whole batch.
            print(f'Skipping {file_path}: {error}')
            continue

        augmented_image = stable_diffusion_augmentation(image)

        ssim_value = calculate_ssim(image, augmented_image)

        output_path = os.path.join(output_directory, 'augmented_' + os.path.basename(file_path))
        augmented_image.save(output_path)
        print(f'Saved augmented image as {output_path}')
        print(f'SSIM: {ssim_value:.4f}')
else:
    print("No files selected or no output directory specified.")