In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
import tkinter as tk
from tkinter import filedialog, messagebox
import os

# Reproducibility: seed the NumPy and torch random number generators.
np.random.seed(42)
torch.manual_seed(42)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# --------------------
# U-Net colorization
# --------------------
class UNetColorization(nn.Module):
    """Three-level U-Net mapping a 1-channel image to a 3-channel image in (-1, 1).

    Encoder (each stage Conv3x3 -> BatchNorm -> ReLU -> Dropout(0.4)):
        enc1: 1 -> 128 channels, full resolution
        enc2: 128 -> 256 channels, stride 2 (half resolution)
        enc3: 256 -> 512 channels, stride 2 (quarter resolution)
    Decoder (each stage ConvTranspose3x3 stride 2 -> BatchNorm -> ReLU):
        dec1: 512 -> 256, output concatenated with enc2 (skip connection)
        dec2: 512 -> 128, output concatenated with enc1 (skip connection)
        dec3: plain 3x3 conv 256 -> 3, followed by tanh.

    Input height and width must be multiples of 4 so the upsampled maps line
    up with the skip connections. Attribute names and the order of modules
    inside each stage are part of the saved ``state_dict`` format
    (keys such as ``enc1.0.weight``) and must stay as they are.
    """

    def __init__(self):
        super().__init__()

        def encoder_stage(in_ch, out_ch, stride):
            # Stride 2 (with padding 1) halves H and W; stride 1 keeps them.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.Dropout(0.4),
            )

        def decoder_stage(in_ch, out_ch):
            # kernel 3, stride 2, padding 1, output_padding 1 exactly doubles H and W.
            return nn.Sequential(
                nn.ConvTranspose2d(in_ch, out_ch, 3, stride=2, padding=1, output_padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            )

        # Construction order matches the checkpoint layout (and RNG draw order).
        self.enc1 = encoder_stage(1, 128, 1)
        self.enc2 = encoder_stage(128, 256, 2)
        self.enc3 = encoder_stage(256, 512, 2)
        self.dec1 = decoder_stage(512, 256)
        self.dec2 = decoder_stage(512, 128)  # 256 upsampled + 256 from enc2
        self.dec3 = nn.Conv2d(256, 3, 3, padding=1)  # 128 upsampled + 128 from enc1
        self.tanh = nn.Tanh()

    def forward(self, x):
        skip_full = self.enc1(x)
        skip_half = self.enc2(skip_full)
        bottleneck = self.enc3(skip_half)
        up_half = self.dec1(bottleneck)
        up_full = self.dec2(torch.cat([up_half, skip_half], dim=1))
        rgb = self.dec3(torch.cat([up_full, skip_full], dim=1))
        return self.tanh(rgb)

# --------------------
# VGG Style Features
# --------------------
class VGGStyleLoss(nn.Module):
    """Frozen VGG19 feature extractor used by the style-transfer losses.

    ``forward(x)`` returns a dict mapping each of the layer names
    'conv1_1', 'conv2_2', 'conv3_2', 'conv4_2', 'conv5_2' to the output of that
    convolution (before its ReLU) in torchvision's pretrained VGG19. The dict
    keys are relied upon by callers (e.g. 'conv4_2' as the content layer).

    Inputs are expected in [-1, 1], the range produced by this file's
    Normalize(0.5, 0.5) transforms and tanh outputs. They are mapped to [0, 1]
    and normalised with the ImageNet mean/std the pretrained weights expect.
    """

    def __init__(self):
        super(VGGStyleLoss, self).__init__()
        weights = torchvision.models.VGG19_Weights.DEFAULT
        vgg = torchvision.models.vgg19(weights=weights).features.to(device).eval()
        # Index into vgg19().features -> name of the conv layer at that index
        # (convX_Y = Y-th conv of block X; each conv is followed by a ReLU and
        # each block ends with a max-pool).
        self.layers = {
            '0': 'conv1_1',
            '7': 'conv2_2',
            '12': 'conv3_2',
            '21': 'conv4_2',
            '30': 'conv5_2',
        }
        # Split the network into consecutive segments that each END at one of
        # the named conv layers. Running the segments in order then yields the
        # output of every named layer in a single pass; layers after conv5_2
        # are never needed and are dropped.
        self.model = nn.ModuleDict()
        start = 0
        for end, name in sorted((int(k), v) for k, v in self.layers.items()):
            segment = nn.Sequential()
            for idx in range(start, end + 1):
                layer = vgg[idx]
                if isinstance(layer, nn.ReLU):
                    # Out-of-place, so a feature map already returned by
                    # forward() is never overwritten by the next segment.
                    layer = nn.ReLU(inplace=False)
                segment.add_module(str(len(segment)), layer)
            self.model[name] = segment
            start = end + 1

        # The extractor is fixed: never compute or accumulate gradients for
        # its weights (gradients still flow to the input image).
        for param in self.model.parameters():
            param.requires_grad_(False)

        # ImageNet statistics of the pretrained weights; non-persistent so
        # they do not appear in state_dict().
        self.register_buffer(
            'mean', torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1), persistent=False
        )
        self.register_buffer(
            'std', torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1), persistent=False
        )

    def forward(self, x):
        """Return ``{layer name: feature map}`` for a batch of 3-channel images in [-1, 1]."""
        out = ((x + 1) / 2 - self.mean) / self.std
        features = {}
        for name, segment in self.model.items():
            out = segment(out)
            features[name] = out
        return features

# --------------------
# Visualization Effects
# --------------------
def torch_rgb_to_hsv(rgb):
    """Convert a batch of RGB images to HSV.

    Args:
        rgb: tensor indexed as ``rgb[:, c, :, :]`` with channels 0/1/2 = R/G/B.
            Expected in [0, 1] (``exaggerate_colors`` rescales to that range
            before calling).

    Returns:
        Tensor with channels H, S, V stacked on dim 1. H is a fraction of a
        full turn (0..1), S = (max - min) / max, and V = per-pixel channel
        maximum. Pixels with max == min get H = 0 and S = 0.
    """
    red = rgb[:, 0, :, :]
    green = rgb[:, 1, :, :]
    blue = rgb[:, 2, :, :]
    value = rgb.max(dim=1).values
    chroma = value - rgb.min(dim=1).values

    # Hue rules in priority order: where several match (ties for the maximum
    # channel) the LATER rule wins, as with sequential masked writes.
    hue_rules = (
        ((value == red) & (green >= blue), (green - blue) / chroma),
        ((value == red) & (green < blue), (green - blue) / chroma + 6.0),
        (value == green, (blue - red) / chroma + 2.0),
        (value == blue, (red - green) / chroma + 4.0),
    )
    hue = torch.zeros_like(red)
    for hit, candidate in hue_rules:
        hue = torch.where(hit, candidate, hue)
    hue = hue / 6.0

    # Achromatic pixels: the divisions above produced 0/0 there, so override.
    achromatic = chroma == 0.0
    hue = torch.where(achromatic, torch.zeros_like(hue), hue)
    saturation = torch.where(achromatic, torch.zeros_like(value), chroma / value)
    return torch.stack([hue, saturation, value], dim=1)

def torch_hsv_to_rgb(hsv):
    """Convert a batch of HSV images (channels H, S, V on dim 1) back to RGB.

    Hue is a fraction of a full turn, as produced by ``torch_rgb_to_hsv``, so
    0 and 1 both map to red. Uses the standard six-sector formula. Pixels
    whose sector matches none of 0..5 (e.g. NaN hue) come out as 0.
    """
    hue = hsv[:, 0, :, :]
    sat = hsv[:, 1, :, :]
    val = hsv[:, 2, :, :]
    scaled = hue * 6.0
    whole = scaled.floor()
    frac = scaled - whole
    p = val * (1.0 - sat)
    q = val * (1.0 - sat * frac)
    t = val * (1.0 - sat * (1.0 - frac))
    sector = whole % 6

    # (R, G, B) source for each of the six hue sectors, in sector order.
    sector_sources = (
        (val, t, p),
        (q, val, p),
        (p, val, t),
        (p, q, val),
        (t, p, val),
        (val, p, q),
    )
    channels = [torch.zeros_like(hue) for _ in range(3)]
    for index, sources in enumerate(sector_sources):
        in_sector = sector == float(index)
        channels = [torch.where(in_sector, src, dst) for src, dst in zip(sources, channels)]
    return torch.stack(channels, dim=1)

def exaggerate_colors(images, saturation_factor=1.5, value_factor=1.2):
    """Boost saturation and brightness of a batch of RGB images.

    Args:
        images: RGB batch in [-1, 1] (channels on dim 1).
        saturation_factor: multiplier for the HSV saturation channel.
        value_factor: multiplier for the HSV value (brightness) channel.

    Returns:
        RGB batch in [-1, 1]. Hue is unchanged; scaled saturation and value
        are clamped to [0, 1] before converting back.
    """
    hsv = torch_rgb_to_hsv((images + 1) / 2.0)
    for channel, factor in ((1, saturation_factor), (2, value_factor)):
        hsv[:, channel, :, :] = torch.clamp(hsv[:, channel, :, :] * factor, 0, 1)
    return torch_hsv_to_rgb(hsv) * 2.0 - 1.0

# --------------------
# Helper functions
# --------------------
def gram_matrix(tensor):
    """Return the normalised Gram matrix of a 4-D feature map.

    The (b, c, h, w) input is flattened to (b*c, h*w) rows, and the matrix of
    row dot products is divided by b*c*h*w. With b == 1 this is the usual
    (c, c) channel-correlation matrix; with b > 1 the batch and channel axes
    are folded together, giving (b*c, b*c) including cross-sample products.
    """
    batch, channels, height, width = tensor.size()
    rows = tensor.view(batch * channels, height * width)
    return (rows @ rows.t()).div(batch * channels * height * width)

def load_style_image(style_path, size=(256, 256)):
    """Load a style reference image as a normalised RGB tensor.

    Args:
        style_path: path to an image file PIL can open; converted to RGB.
        size: (height, width) the image is resized to.

    Returns:
        Tensor of shape (1, 3, H, W) on ``device`` with values in [-1, 1].

    Raises:
        FileNotFoundError: if ``style_path`` does not exist.
    """
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    # The context manager closes the file handle immediately; convert() reads
    # the pixel data before the file is closed.
    with Image.open(style_path) as img:
        image = img.convert('RGB')
    return transform(image).unsqueeze(0).to(device)

def preprocess_image(image_path, size=(256, 256), low_size=(128, 128)):
    """Load an image as grayscale and return it at two resolutions.

    Args:
        image_path: path to an image file PIL can open; converted to 'L'.
        size: (height, width) of the second returned tensor.
        low_size: (height, width) of the first returned tensor, which is the
            one fed to the colorization network. Defaults to the previously
            hard-coded 128x128.

    Returns:
        (low, high): tensors of shape (1, 1, H, W) on ``device`` with values in
        [-1, 1] (ToTensor's [0, 1] range shifted by Normalize(0.5, 0.5)).
    """
    def make_transform(target_size):
        # Resize -> [0, 1] tensor -> [-1, 1]
        return transforms.Compose([
            transforms.Resize(target_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

    # Close the file handle right away; convert() reads the pixels first.
    with Image.open(image_path) as img:
        image = img.convert('L')
    low = make_transform(low_size)(image).unsqueeze(0).to(device)
    high = make_transform(size)(image).unsqueeze(0).to(device)
    return low, high

def postprocess_image(tensor):
    """Convert a (1, C, H, W) image tensor in [-1, 1] to a PIL image.

    Values outside [-1, 1] are clipped, then mapped to 8-bit with rounding to
    the nearest level. C == 3 yields an RGB image and C == 1 a grayscale ('L')
    image. The tensor may live on any device and may require grad.
    """
    # detach(): .numpy() refuses tensors that are part of an autograd graph.
    tensor = tensor.detach().cpu().clamp(-1, 1)
    tensor = (tensor + 1) / 2
    np_img = tensor.squeeze(0).permute(1, 2, 0).numpy()
    if np_img.shape[2] == 1:
        # PIL cannot build an image from an (H, W, 1) array; drop the axis.
        np_img = np_img[:, :, 0]
    # Round instead of truncating so e.g. 0.999 maps to 255, not 254.
    return Image.fromarray(np.round(np_img * 255).astype(np.uint8))

def apply_style_transfer(colorized, style_image, vgg, content_weight=1e2, style_weight=1e7, steps=200, lr=0.01):
    """Optimise an image to keep the content of ``colorized`` and take the style of ``style_image``.

    Starts from a copy of ``colorized`` and runs ``steps`` Adam updates on its
    pixels only; no network weights are modified.

    Args:
        colorized: image tensor the optimisation starts from; its ``'conv4_2'``
            features are the content target.
        style_image: image tensor whose Gram matrices (one per feature map
            returned by ``vgg``) are the style target.
        vgg: callable mapping an image tensor to a dict of feature maps that
            includes the key ``'conv4_2'``.
        content_weight: weight of the mean-squared content loss.
        style_weight: weight of the style loss (mean over layers of the
            mean-squared Gram-matrix difference).
        steps: number of optimisation steps.
        lr: Adam learning rate.

    Returns:
        The optimised image tensor, detached from the autograd graph.
    """
    opt_img = colorized.detach().clone().requires_grad_(True)
    optimizer = optim.Adam([opt_img], lr=lr)

    # The targets are constants, so compute them once with no autograd graph.
    # This is what makes retain_graph unnecessary in the loop below.
    with torch.no_grad():
        content_target = vgg(colorized)['conv4_2']
        style_grams = {layer: gram_matrix(feat) for layer, feat in vgg(style_image).items()}

    for _ in range(steps):
        out_features = vgg(opt_img)
        content_loss = torch.mean((out_features['conv4_2'] - content_target) ** 2)
        style_loss = sum(
            torch.mean((gram_matrix(out_features[layer]) - target) ** 2)
            for layer, target in style_grams.items()
        ) / max(len(style_grams), 1)  # guard against an empty feature dict
        total_loss = content_weight * content_loss + style_weight * style_loss
        # Only the image's gradient is needed. autograd.grad skips computing
        # (and accumulating) gradients for vgg's parameters.
        (grad,) = torch.autograd.grad(total_loss, opt_img)
        opt_img.grad = grad
        optimizer.step()
    return opt_img.detach()

# --------------------
# GUI Class
# --------------------
# Option-menu label -> reference image used for style transfer. The paths are
# relative, so they are resolved against the current working directory.
style_images = dict([
    ('Van Gogh', 'vangogh.jpg'),
    ('Monet', 'monet.jpg'),
    ('Ukiyo-e', 'ukiyoe.jpg'),
])

class ColorizationGUI:
    """Tkinter front end: colorize a grayscale image, exaggerate its colors and apply a painting style.

    Construction builds the colorization U-Net and loads its trained weights,
    then builds the VGG feature extractor and loads every reference image
    listed in the module-level ``style_images`` dict. If the weights file or a
    style image is missing, an error dialog is shown and the window is closed
    as soon as the Tk event loop starts.

    Attributes:
        root: the Tk root window passed in.
        model: ``UNetColorization`` (put in eval mode once weights load).
        vgg: ``VGGStyleLoss`` extractor, or None if start-up aborted before it.
        styles: style name -> style image tensor (empty if start-up aborted).
        style_var: Tk variable with the chosen style (set only on successful start-up).
        image_path: path of the selected input image, or None.
        output_image: last result as a PIL image, or None.
    """

    def __init__(self, root):
        self.root = root
        self.root.title("Image Colorization + Style Transfer with Enhanced Effects")
        self.image_path = None
        self.output_image = None
        self.styles = {}
        self.vgg = None

        self.model = UNetColorization().to(device)
        if not self.load_model():
            self._abort()
            return
        # Built only after the checkpoint check, so a missing checkpoint fails
        # before the pretrained VGG19 extractor is constructed.
        self.vgg = VGGStyleLoss().to(device)
        try:
            self.styles = {name: load_style_image(path) for name, path in style_images.items()}
        except FileNotFoundError as e:
            messagebox.showerror("Error", f"Missing style image: {e}")
            self._abort()
            return

        tk.Label(root, text="Select grayscale image and style:").pack()
        self.style_var = tk.StringVar(value=list(style_images.keys())[0])
        tk.OptionMenu(root, self.style_var, *style_images.keys()).pack()
        tk.Button(root, text="Upload", command=self.upload_image).pack()
        tk.Button(root, text="Colorize & Style", command=self.process_image).pack()
        tk.Button(root, text="Save Output", command=self.save_image).pack()

    def _abort(self):
        """Close the window as soon as the event loop starts.

        ``root.quit()`` issued before ``mainloop()`` does not stop a loop that
        has not started yet, so destruction is scheduled instead; ``mainloop``
        then returns once no window is left.
        """
        self.root.after(0, self.root.destroy)

    def load_model(self):
        """Load the trained colorization weights into ``self.model``.

        Returns:
            True if the weights were loaded, or False (after showing an error
            dialog) if the checkpoint file does not exist.
        """
        weights_path = '../models/model_perceptual_weights.pth'
        if not os.path.exists(weights_path):
            messagebox.showerror("Error", f"Cannot find {weights_path}. Please train the perceptual model using Code 2 first.")
            return False
        print(f"Loading perceptual weights from {weights_path}")
        state_dict = torch.load(weights_path, map_location=device)
        self.model.load_state_dict(state_dict)
        self.model.eval()
        return True

    def upload_image(self):
        """Ask for an input image. A cancelled dialog keeps the previous selection."""
        path = filedialog.askopenfilename(filetypes=[("Image files","*.png *.jpg *.jpeg")])
        # Cancelling returns an empty (falsy) value; don't wipe a prior choice.
        if path:
            self.image_path = path
            print(f"Selected {self.image_path}")

    def process_image(self):
        """Colorize the selected image, exaggerate its colors, apply the chosen style and show the result."""
        if not self.image_path:
            messagebox.showerror("Error", "No image selected!")
            return
        try:
            low_res, _ = preprocess_image(self.image_path)
        except OSError as e:  # missing file or an unreadable/unsupported image
            messagebox.showerror("Error", f"Cannot open image: {e}")
            return

        # Style transfer runs on the Tk thread and blocks the UI until done;
        # show a busy cursor meanwhile.
        self.root.config(cursor='watch')
        self.root.update_idletasks()
        try:
            with torch.no_grad():
                colorized_low = self.model(low_res)
                colorized = torch.nn.functional.interpolate(
                    colorized_low, size=(512, 512), mode='bilinear', align_corners=False
                )
                # Apply color exaggeration
                colorized = exaggerate_colors(colorized, saturation_factor=1.5, value_factor=1.2)
            style_name = self.style_var.get()
            styled = apply_style_transfer(colorized, self.styles[style_name], self.vgg)
        finally:
            self.root.config(cursor='')
        self.output_image = postprocess_image(styled)
        self.output_image.show()

    def save_image(self):
        """Ask for a destination and save the last result there."""
        if self.output_image is None:
            messagebox.showerror("Error", "Nothing to save yet. Run 'Colorize & Style' first.")
            return
        path = filedialog.asksaveasfilename(defaultextension=".png")
        if not path:
            return
        try:
            self.output_image.save(path)
        except (OSError, ValueError) as e:  # unwritable path or unknown extension
            messagebox.showerror("Error", f"Could not save to {path}: {e}")
            return
        messagebox.showinfo("Saved", f"Output saved to {path}")

# --------------------
# Run
# --------------------
def main():
    """Open the Tk window, build the GUI inside it and run the event loop until it closes."""
    window = tk.Tk()
    _gui = ColorizationGUI(window)  # keep a reference for the lifetime of the loop
    window.mainloop()


if __name__ == "__main__":
    main()

Loading perceptual weights from ../models/model_perceptual_weights.pth
Selected C:/Users/NADER KAREEM/Pictures/Screenshots/newss.png
