In [5]:
"""
gui_generate.py
Simple Tkinter GUI to load a DCGAN generator checkpoint and sample images.

Usage:
python gui_generate.py --ckpt ./checkpoints/dcgan_epoch_10.pth --nz 100 --out_dir ./samples

Dependencies:
pip install torch torchvision pillow customtkinter
"""
import os
import argparse
import torch
import torch.nn as nn
from PIL import Image, ImageTk
import numpy as np
import customtkinter as ctk
import tkinter as tk
from tkinter import filedialog

# ---- Generator (same architecture as train) ----
class Generator(nn.Module):
    def __init__(self, nz=100, ngf=64, nc=3):
        super().__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)

# ---- Helpers ----
def denorm(img_tensor):
    # tensor in [-1,1] -> [0,255] uint8
    img = (img_tensor + 1.0) / 2.0
    img = img.clamp(0,1)
    img = (img * 255).byte().cpu().numpy()
    # shape: (B,C,H,W)
    img = np.transpose(img, (0,2,3,1))
    return img

# ---- GUI app ----
def run_gui(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    nz = args.nz
    netG = Generator(nz=nz, ngf=args.ngf, nc=3).to(device)
    if not os.path.exists(args.ckpt):
        raise FileNotFoundError("Checkpoint not found: " + args.ckpt)
    ckpt = torch.load(args.ckpt, map_location=device)
    netG.load_state_dict(ckpt['netG_state_dict'])
    netG.eval()

    ctk.set_appearance_mode("System")
    ctk.set_default_color_theme("blue")

    app = ctk.CTk()
    app.title("DCGAN Anime Generator")
    app.geometry("820x560")

    # canvas
    canvas = tk.Canvas(app, width=512, height=512, bg="#222")
    canvas.pack(side="left", padx=10, pady=10)

    # right controls
    panel = ctk.CTkFrame(app)
    panel.pack(side="right", fill="both", expand=True, padx=10, pady=10)

    status = ctk.StringVar(value="Ready")

    def generate_and_show():
        status.set("Generating...")
        app.update()
        with torch.no_grad():
            z = torch.randn(1, nz, 1, 1, device=device)
            fake = netG(z).detach()
            imgs = denorm(fake)  # (1,H,W,3)
            img = Image.fromarray(imgs[0])
            img = img.resize((512,512))
            tkimg = ImageTk.PhotoImage(img)
            # keep reference
            canvas.image = tkimg
            canvas.create_image(0,0, anchor="nw", image=tkimg)
            status.set("Done")

            # store last generated image
            panel.last_image = img

    def save_image():
        if not hasattr(panel, "last_image"):
            status.set("No image to save")
            return
        fname = filedialog.asksaveasfilename(defaultextension=".png", filetypes=[("PNG","*.png")])
        if fname:
            panel.last_image.save(fname)
            status.set("Saved: " + os.path.basename(fname))

    ctk.CTkLabel(panel, text="DCGAN Anime Generator", font=ctk.CTkFont(size=18, weight="bold")).pack(pady=10)
    ctk.CTkButton(panel, text="Generate Random Face", command=generate_and_show).pack(pady=8)
    ctk.CTkButton(panel, text="Save Last Image", command=save_image).pack(pady=8)
    ctk.CTkLabel(panel, textvariable=status).pack(pady=8)

    app.mainloop()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", type=str, required=True)
    parser.add_argument("--nz", type=int, default=100)
    parser.add_argument("--ngf", type=int, default=64)
    args = parser.parse_args()
    run_gui(args)


usage: ipykernel_launcher.py [-h] --ckpt CKPT [--nz NZ] [--ngf NGF]
ipykernel_launcher.py: error: the following arguments are required: --ckpt


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


TclError: image "pyimage3" doesn't exist