In [23]:
import torch
from PIL import Image
import numpy as np
import torchvision.transforms.functional as TF
from torchvision.utils import save_image
import random
import os

import cv2

from unet_2 import UNetGenerator, N_CLASSES,DEVICE,IMG_SIZE

In [24]:
OUTPUT_DIR = "./test_outputs"
IMG_SIZE = 256
N_CLASSES = 7
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

os.makedirs(OUTPUT_DIR, exist_ok=True)


In [25]:
color_to_class = {
    (0, 0, 0): 0,
    (0, 255, 255): 1,
    (255, 255, 0): 2,
    (255, 0, 255): 3,
    (0, 255, 0): 4,
    (0, 0, 255): 5,
    (255, 255, 255): 6,
}

In [26]:
def mask_to_onehot(mask_crop):
    mask_np = np.array(mask_crop)
    class_mask = np.zeros((IMG_SIZE, IMG_SIZE), dtype=np.int64)

    for rgb, idx in color_to_class.items():
        class_mask[np.all(mask_np == rgb, axis=-1)] = idx

    mask_tensor = torch.from_numpy(class_mask).long()
    onehot = torch.zeros((N_CLASSES, IMG_SIZE, IMG_SIZE), dtype=torch.float32)
    onehot.scatter_(0, mask_tensor.unsqueeze(0), 1.0)

    return onehot.unsqueeze(0)

In [27]:
G = UNetGenerator().to(DEVICE)
G.load_state_dict(torch.load("G_120.pt", map_location=DEVICE))
G.eval()


UNetGenerator(
  (d1): UNetDown(
    (model): Sequential(
      (0): Conv2d(7, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): LeakyReLU(negative_slope=0.2)
    )
  )
  (d2): UNetDown(
    (model): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
  )
  (d3): UNetDown(
    (model): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.2)
    )
  )
  (d4): UNetDown(
    (model): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(neg

In [28]:
import tkinter as tk
from tkinter import ttk, colorchooser
from PIL import Image, ImageDraw, ImageTk


In [29]:
PALETTE = {
    "background": (0, 0, 0),
    "urban": (0, 255, 255),
    "agriculture": (255, 255, 0),
    "rangeland": (255, 0, 255),
    "forest": (0, 255, 0),
    "water": (0, 0, 255),
    "barren": (255, 255, 255)
}


In [30]:
COLOR_TO_CLASS = {v: i for i, v in enumerate(PALETTE.values())}
CLASS_COLORS = list(PALETTE.values())

In [31]:
class MaskDrawer:
    def __init__(self, root):
        self.root = root
        root.title("DeepGlobe Mask Drawer")

        self.canvas = tk.Canvas(root, width=IMG_SIZE, height=IMG_SIZE, bg="black")
        self.canvas.grid(row=0, column=0, rowspan=20)

        self.image = Image.new("RGB", (IMG_SIZE, IMG_SIZE), (0, 0, 0))
        self.draw = ImageDraw.Draw(self.image)

        self.brush_size = 8
        self.current_color = CLASS_COLORS[0]

        self.canvas.bind("<B1-Motion>", self.paint)

        tk.Label(root, text="Class Colors").grid(row=0, column=1)
        row = 1

        for name, rgb in PALETTE.items():
            b = tk.Button(root, text=name, bg=self.rgb_hex(rgb), width=12,
                          command=lambda c=rgb: self.set_color(c))
            b.grid(row=row, column=1, pady=2)
            row += 1

        tk.Label(root, text="Brush Size").grid(row=row, column=1)
        row += 1

        self.size_slider = tk.Scale(root, from_=1, to=30, orient=tk.HORIZONTAL)
        self.size_slider.set(self.brush_size)
        self.size_slider.grid(row=row, column=1)
        row += 1

        tk.Button(root, text="Clear", command=self.clear).grid(row=row, column=1, pady=10)
        row += 1
        tk.Button(root, text="Generate", command=self.generate).grid(row=row, column=1, pady=10)

    # -------------------------------------------------------------------

    def rgb_hex(self, rgb):
        return "#%02x%02x%02x" % rgb

    def set_color(self, rgb):
        self.current_color = rgb

    def paint(self, event):
        x, y = event.x, event.y
        r = self.size_slider.get()

        self.canvas.create_oval(
            x-r, y-r, x+r, y+r,
            fill=self.rgb_hex(self.current_color),
            outline=self.rgb_hex(self.current_color)
        )

        self.draw.ellipse([x-r, y-r, x+r, y+r], fill=self.current_color)

    def clear(self):
        self.canvas.delete("all")
        self.draw.rectangle([0, 0, IMG_SIZE, IMG_SIZE], fill=(0, 0, 0))

    def clean_mask(self, pil_img):
        img_np = np.array(pil_img)
        h, w, _ = img_np.shape

        colors = np.array(CLASS_COLORS)
        class_ids = np.arange(len(CLASS_COLORS))

        img_flat = img_np.reshape(-1, 3)
        dist = np.sum((img_flat[:, None, :] - colors[None, :, :])**2, axis=2)
        nearest = np.argmin(dist, axis=1)
        class_mask = class_ids[nearest].reshape(h, w).astype(np.uint8)

        kernel = np.ones((5, 5), np.uint8)
        class_mask = cv2.morphologyEx(class_mask, cv2.MORPH_OPEN, kernel)

        class_mask = cv2.medianBlur(class_mask, 5)

        return class_mask


    def generate(self):

        class_mask = self.clean_mask(self.image)

        mask_tensor = torch.from_numpy(class_mask).long()
        mask_onehot = torch.zeros((N_CLASSES, IMG_SIZE, IMG_SIZE), dtype=torch.float32)
        mask_onehot.scatter_(0, mask_tensor.unsqueeze(0), 1.0)

        mask_onehot = mask_onehot.unsqueeze(0).to(DEVICE)

        with torch.no_grad():
            fake = G(mask_onehot)[0].cpu()

        fake_img = ((fake + 1) / 2).clamp(0, 1)
        fake_np = (fake_img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        fake_pil = Image.fromarray(fake_np)

        win = tk.Toplevel(self.root)
        win.title("Generated Image")

        tk_img = ImageTk.PhotoImage(fake_pil)
        panel = tk.Label(win, image=tk_img)
        panel.image = tk_img
        panel.pack()

        fake_pil.save("generated_from_gui.png")
        print("Saved generated_from_gui.png")

In [32]:
root = tk.Tk()
app = MaskDrawer(root)
root.mainloop()

Saved generated_from_gui.png
Saved generated_from_gui.png
