In [1]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torchvision import transforms
from PIL import Image, ImageTk
import tempfile
from tkinter import *
from tkinter import Tk, Checkbutton, IntVar

# Define SimpleUNet model
class SimpleUNet(nn.Module):
    """Label-conditioned U-Net with three 2x pooling stages.

    A one-hot encoding of ``label`` is tiled over the image plane and stacked
    onto the single-channel ``image`` before the first encoder block, so the
    network input has ``1 + num_classes`` channels. Each decoder stage
    upsamples, concatenates the matching encoder feature map and applies a
    conv block; a 1x1 conv followed by a sigmoid yields a (0, 1) output map.
    Input height and width must be divisible by 8.
    """

    def __init__(self, num_classes=50):
        super().__init__()
        # forward() needs this to one-hot encode the label.
        self.num_classes = num_classes
        block = self.conv_block

        # Contracting path.
        self.enc1 = block(1 + num_classes, 64)
        self.enc2 = block(64, 128)
        self.enc3 = block(128, 256)
        self.bottleneck = block(256, 512)

        # Expanding path: every dec* consumes upsampled features plus an
        # encoder skip, hence twice the channels it outputs.
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = block(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = block(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = block(128, 64)

        # Head: collapse to one channel and squash into (0, 1).
        self.out_conv = nn.Conv2d(64, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

        # Xavier-initialise all (transposed) convolution weights.
        conv_types = (nn.Conv2d, nn.ConvTranspose2d)
        for module in filter(lambda mod: isinstance(mod, conv_types), self.modules()):
            init.xavier_uniform_(module.weight)

    def conv_block(self, in_channels, out_channels):
        """3x3 same-padding convolution followed by an in-place ReLU."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        return nn.Sequential(conv, nn.ReLU(inplace=True))

    def forward(self, image, label):
        """Return the (0, 1) output map for ``image`` conditioned on ``label``.

        Assumes ``image`` is (batch, 1, H, W) and ``label`` holds int64 class
        indices of shape (batch,) -- TODO confirm against callers.
        """
        condition = F.one_hot(label, num_classes=self.num_classes).float()
        height, width = image.size(2), image.size(3)
        # Tile the one-hot vector over every pixel: (B, C) -> (B, C, H, W).
        condition = condition[..., None, None].expand(-1, -1, height, width)

        features = self.enc1(torch.cat((image, condition), dim=1))
        skips = [features]
        for encoder in (self.enc2, self.enc3):
            features = encoder(F.max_pool2d(features, 2))
            skips.append(features)

        features = self.bottleneck(F.max_pool2d(features, 2))

        # Walk back up, pairing each upsampler/decoder with the encoder output
        # of the same resolution (deepest first).
        stages = zip((self.upconv3, self.upconv2, self.upconv1),
                     (self.dec3, self.dec2, self.dec1),
                     reversed(skips))
        for upsample, decoder, skip in stages:
            features = decoder(torch.cat((upsample(features), skip), dim=1))

        return self.sigmoid(self.out_conv(features))

  warn(
  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

class SimpleUNet32(nn.Module):
    """Label-conditioned encoder/decoder with four 2x pooling stages.

    The integer class label is one-hot encoded, broadcast over the image grid
    and concatenated to the single-channel input, so ``enc1`` sees
    ``1 + num_classes`` channels. Input height and width must be divisible by
    16. The output is a single-channel map squashed into (0, 1) by a sigmoid.

    By default only the shallowest skip connection (``enc1`` -> ``dec1``) is
    used: ``dec4``/``dec3``/``dec2`` are still constructed -- so the
    state_dict layout and parameter-initialisation order are unchanged
    (presumably matching saved checkpoints) -- but are bypassed in
    ``forward``. Pass ``use_deep_skips=True`` to route enc4/enc3/enc2 through
    them as in a full U-Net.

    Args:
        num_classes: number of label classes for the one-hot conditioning.
        use_deep_skips: wire up the enc4/enc3/enc2 skip connections via
            dec4/dec3/dec2. Defaults to False (the original forward path).
    """

    def __init__(self, num_classes=50, use_deep_skips=False):
        super(SimpleUNet32, self).__init__()
        self.num_classes = num_classes  # read by forward() for one-hot encoding
        self.use_deep_skips = use_deep_skips

        # Encoder
        self.enc1 = self.conv_block(1 + num_classes, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)

        # Bottleneck
        self.bottleneck = self.conv_block(512, 1024)

        # Decoder. dec4/dec3/dec2 take (upsampled + skip) channels; they are
        # only used when use_deep_skips is True but are always registered.
        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = self.conv_block(1024, 512)
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2, padding=0)
        self.dec3 = self.conv_block(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2, padding=0)
        self.dec2 = self.conv_block(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = self.conv_block(128, 64)

        # Output layer
        self.out_conv = nn.Conv2d(64, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

        # Xavier-initialise every (transposed) convolution weight; biases keep
        # PyTorch's default init. The un-suffixed ``init.xavier_uniform`` is a
        # deprecated alias of this in-place function and emits a warning.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                init.xavier_uniform_(m.weight)

    def conv_block(self, in_channels, out_channels):
        """3x3 same-padding convolution followed by an in-place ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, image, label):
        """Run the network on ``image`` conditioned on ``label``.

        Args:
            image: float tensor, assumed (batch, 1, H, W) with H and W
                divisible by 16 -- TODO confirm against callers.
            label: integer class indices, assumed shape (batch,), each in
                ``[0, num_classes)``; ``F.one_hot`` requires an int64 tensor.

        Returns:
            Tensor of shape (batch, 1, H, W) with values in (0, 1).
        """
        # One-hot the label and tile it over the image plane -> (B, C, H, W).
        label_onehot = F.one_hot(label, num_classes=self.num_classes).float()
        label_onehot = label_onehot.unsqueeze(-1).unsqueeze(-1)
        label_onehot = label_onehot.expand(-1, -1, image.size(2), image.size(3))

        # Condition the network by stacking the label planes onto the image.
        x = torch.cat((image, label_onehot), dim=1)

        # Encoder path (each pooling stage halves H and W).
        enc1 = self.enc1(x)
        enc2 = self.enc2(F.max_pool2d(enc1, 2))
        enc3 = self.enc3(F.max_pool2d(enc2, 2))
        enc4 = self.enc4(F.max_pool2d(enc3, 2))

        # Bottleneck
        bottleneck = self.bottleneck(F.max_pool2d(enc4, 2))

        # getattr: instances pickled before this flag existed lack the attribute.
        deep = getattr(self, 'use_deep_skips', False)

        # Decoder path. The deep skips are optional (see class docstring);
        # the enc1 -> dec1 skip is always applied.
        dec4 = self.upconv4(bottleneck)
        if deep:
            dec4 = self.dec4(torch.cat((dec4, enc4), dim=1))

        dec3 = self.upconv3(dec4)
        if deep:
            dec3 = self.dec3(torch.cat((dec3, enc3), dim=1))

        dec2 = self.upconv2(dec3)
        if deep:
            dec2 = self.dec2(torch.cat((dec2, enc2), dim=1))

        dec1 = self.upconv1(dec2)
        dec1 = self.dec1(torch.cat((dec1, enc1), dim=1))

        # Output layer
        return self.sigmoid(self.out_conv(dec1))


In [10]:

import numpy as np
from PIL import Image, ImageTk, ImageGrab
import torch
from PIL import ImageOps

from torchvision import transforms

class Paint(object):
    """Tk app: draw a character, then render it through the conditioned UNet.

    The left canvas (``self.c``) is a free-hand drawing surface. "Regenerate"
    screen-grabs it, feeds it -- together with the character selected in the
    drop-down as the class label -- to the model, and shows the result on the
    right canvas (``self.c2``), optionally binarised by the threshold slider.

    Constructing the object builds the window and enters ``mainloop()``, so
    the constructor blocks until the window is closed.
    """

    # Checkpoint loaded at start-up.
    MODEL_PATH = 'Unet32x48/model4.pth'
    # Model input size as (height, width) -- torchvision's convention.
    IMG_SIZE = (32, 48)
    # Size, in pixels, of both canvases and of the rendered output image.
    CANVAS_WIDTH = 640
    CANVAS_HEIGHT = 480

    def __init__(self):
        self.root = Tk()
        # No model output yet: set_output_canvas() does nothing until
        # regenerate() fills this in (slider callbacks may fire before that).
        self.output_image_before_threshold = None

        self.model = SimpleUNet32(num_classes=50)
        # Inference below runs on CPU tensors, so map a checkpoint that may
        # have been saved from a GPU onto the CPU.
        self.model.load_state_dict(torch.load(self.MODEL_PATH, map_location='cpu'))
        # NOTE(review): the former alternative was SimpleUNet(num_classes=50)
        # loaded from 'Unet/model5.pth', presumably with IMG_SIZE = (24, 32).
        self.model.eval()

        # The index of a character in this list is the label id fed to the model.
        self.options = ["ಅ", "ಆ", "ಇ", "ಈ", "ಉ", "ಊ", "ಋ", "ಎ", "ಏ", "ಐ", "ಒ", "ಓ", "ಔ", "ಅಂ", "ಅಃ", 'ಕ', 'ಖ', 'ಗ', 'ಘ', 'ಚ', 'ಛ', 'ಜ', 'ಝ', 'ಟ', 'ಠ', 'ಡ', 'ಢ', 'ಣ', 'ತ', 'ಥ', 'ದ', 'ಧ', 'ನ', 'ಪ', 'ಫ', 'ಬ', 'ಭ', 'ಮ', 'ಯ', 'ರ', 'ಲ', 'ಳ', 'ವ', 'ಶ', 'ಷ', 'ಸ', 'ಹ']
        self.clicked = StringVar()
        self.clicked.set(self.options[0])  # default selection

        drop = OptionMenu(self.root, self.clicked, *self.options)
        drop.grid(row=0, column=4)

        self.label = Label(self.root, text=" ")
        self.label.grid(row=0, column=3)

        self.button = Button(self.root, text="Regenerate", command=self.regenerate)
        self.button.grid(row=1, column=6)

        self.eraser_button = Button(self.root, text='eraser', command=self.use_eraser)
        self.eraser_button.grid(row=0, column=0)

        self.reset_button = Button(self.root, text='reset', command=self.reset_screen)
        self.reset_button.grid(row=0, column=1)

        # Brush width in pixels.
        self.choose_size_button = Scale(self.root, from_=22, to=48, orient=HORIZONTAL)
        self.choose_size_button.set(32)
        self.choose_size_button.grid(row=0, column=2)

        self.c = Canvas(self.root, bg='white', width=self.CANVAS_WIDTH, height=self.CANVAS_HEIGHT)
        self.c.grid(row=1, columnspan=5)

        self.threshold_checkbox_var = IntVar()
        self.threshold_checkbox = Checkbutton(
            self.root, text="Threshold",
            variable=self.threshold_checkbox_var, command=self.threshold_function)
        self.threshold_checkbox.select()
        self.threshold_checkbox.grid(row=0, column=7)

        # Binarisation threshold applied to the 0..255 output image.
        self.choose_threshold_button = Scale(
            self.root, length=90, from_=5, to=240, orient=HORIZONTAL,
            command=self.threshold_function)
        self.choose_threshold_button.set(128)
        self.choose_threshold_button.grid(row=0, column=8)

        # Input noise in percent (std = value / 100); moving it regenerates.
        self.choose_Noise_button = Scale(
            self.root, length=90, from_=0, to=100, orient=HORIZONTAL,
            command=self.regenerate)
        self.choose_Noise_button.set(0)
        self.choose_Noise_button.grid(row=0, column=10)

        self.c2 = Canvas(self.root, bg='white', width=self.CANVAS_WIDTH, height=self.CANVAS_HEIGHT)
        self.c2.grid(row=1, column=7, columnspan=5)

        self.setup()
        self.root.mainloop()  # blocks until the window is closed

    def setup(self):
        """Initialise drawing state and bind mouse events on the drawing canvas."""
        self.old_x = None
        self.old_y = None
        self.line_width = self.choose_size_button.get()
        self.eraser_on = False
        self.c.bind('<B1-Motion>', self.paint)
        self.c.bind('<ButtonRelease-1>', self.reset)

    def preprocess_image(self, pil_image, image_dim):
        """Convert a PIL drawing into a (1, 1, H, W) model input tensor.

        ``image_dim`` is (height, width), as torchvision's ``Resize`` expects.
        Values are inverted so black ink on the white canvas becomes ~1 on a
        ~0 background.
        """
        transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize(image_dim),
            transforms.ToTensor(),                   # float in [0, 1], shape (1, H, W)
            transforms.Lambda(lambda img: 1 - img),  # invert: black ink -> 1
            # Random vertical shift (up to 10% of the height) and a 4-5%
            # shrink, so each regeneration differs slightly.
            # NOTE(review): presumably mirrors training-time augmentation --
            # confirm before removing.
            transforms.RandomAffine(degrees=0, translate=(0, 0.1), scale=(0.95, 0.96)),
        ])
        return transform(pil_image).unsqueeze(0)  # add batch dimension

    def threshold_function(self, event=None):
        """Redraw the output after a threshold change.

        Called with no argument by the checkbox, and with the new value (a
        string) by the threshold slider; moving the slider also ticks the
        checkbox so the new threshold takes effect.
        """
        if event is not None:
            self.threshold_checkbox.select()
        self.set_output_canvas()

    def regenerate(self, event=None):
        """Run the model on the current drawing and refresh the output canvas.

        ``event`` is ignored; it absorbs the value the noise slider passes to
        its command callback.
        """
        height, width = self.IMG_SIZE

        # Screen-grab exactly the drawing canvas (screen coordinates).
        left, top = self.c.winfo_rootx(), self.c.winfo_rooty()
        bbox = (left, top, left + self.c.winfo_width(), top + self.c.winfo_height())
        canvas_image = ImageOps.grayscale(ImageGrab.grab(bbox=bbox))

        # PIL sizes are (width, height) while IMG_SIZE is (height, width);
        # swapping them would squash the drawing into a portrait image before
        # torchvision stretched it back.
        canvas_image = canvas_image.resize((width, height), Image.LANCZOS)
        input_image = self.preprocess_image(canvas_image, self.IMG_SIZE)

        # Optional Gaussian input noise, then keep values in [0, 1].
        noise_std = self.choose_Noise_button.get() / 100
        input_image = (input_image + torch.randn_like(input_image) * noise_std).clamp(0, 1)

        input_label = torch.tensor([self.options.index(self.clicked.get())])

        with torch.no_grad():
            output = self.model(input_image, input_label)

        output_np = output.squeeze(0).squeeze(0).numpy()
        output_pil = Image.fromarray((output_np * 255).astype('uint8'))
        self.output_image_before_threshold = output_pil.resize(
            (self.CANVAS_WIDTH, self.CANVAS_HEIGHT), Image.LANCZOS)
        self.set_output_canvas()

    def set_output_canvas(self):
        """Draw the latest model output on the right canvas, optionally thresholded.

        Does nothing if no output has been generated yet.
        """
        output_image_pil = self.output_image_before_threshold
        if output_image_pil is None:
            return
        if self.threshold_checkbox_var.get() == 1:
            threshold = self.choose_threshold_button.get()
            output_image_pil = output_image_pil.point(lambda v: 0 if v < threshold else 255)

        # The model works in inverted space (ink = bright, see
        # preprocess_image); invert back so ink is dark on white.
        display_pil = ImageOps.invert(output_image_pil)

        # Keep a reference: Tk drops the image if the PhotoImage is collected.
        self.output_image = ImageTk.PhotoImage(display_pil)
        self.c2.delete("all")  # avoid stacking a new item on every redraw
        self.c2.create_image(0, 0, image=self.output_image, anchor='nw')

    def use_eraser(self):
        """Toggle eraser mode and reflect it in the button's relief."""
        self.eraser_on = not self.eraser_on
        self.eraser_button.config(relief=SUNKEN if self.eraser_on else RAISED)

    def paint(self, event):
        """Extend the current stroke to the mouse position (white when erasing)."""
        self.line_width = self.choose_size_button.get()
        paint_color = 'white' if self.eraser_on else 'black'
        # Compare with None: 0 is a valid coordinate at the canvas edge.
        if self.old_x is not None and self.old_y is not None:
            self.c.create_line(self.old_x, self.old_y, event.x, event.y, width=self.line_width, fill=paint_color, capstyle=ROUND, smooth=True, splinesteps=36)
        self.old_x, self.old_y = event.x, event.y

    def reset(self, event):
        """End the current stroke when the mouse button is released."""
        self.old_x, self.old_y = None, None

    def reset_screen(self):
        """Clear both canvases and forget the last model output."""
        self.c.delete("all")
        self.c2.delete("all")
        # Otherwise moving the threshold slider would redraw the stale output.
        self.output_image_before_threshold = None



if __name__ == "__main__":
    # Paint() builds the window and runs Tk's mainloop, so this call blocks
    # until the window is closed.
    Paint()


  init.xavier_uniform(m.weight)


In [9]:

# import gc
# gc.collect()

# del Paint

# if torch.cuda.is_available():
#     # Empty the CUDA cache
#     torch.cuda.empty_cache()
# else:
#     # Print a message indicating that CUDA is not available
#     print("CUDA is not available.")
