In [1]:
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import ssl
from PIL import Image
import torchvision.transforms.functional as TF
import gradio as gr
import os
from PIL import Image

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.maxpool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.linear1 = nn.Linear(16 * 5 * 5, 120)
        self.linear2 = nn.Linear(120, 84)
        self.linear3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.maxpool(F.relu(self.conv1(x)))
        x = self.maxpool(F.relu(self.conv2(x)))

        x = x.reshape(-1, 16 * 5 * 5)

        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = self.linear3(x)
        return x

In [12]:
class MyApp:
    def __init__(self):
        super().__init__()
        self.epoch = 100
        self.lr = 0.001
        self.model = MyModel()

        self.classes = ('circles', 'squares', 'triangles')
        self.labels = (0, 1, 2)

        self.transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
        ])

        self.dataset = torchvision.datasets.ImageFolder(root='./shapes', transform=self.transform)
        self.dataloader = torch.utils.data.DataLoader(self.dataset, batch_size=64, shuffle=True, num_workers=2)

app = MyApp()

In [4]:
def train_model(device):
    loss_fonct = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(app.model.parameters(), lr=0.01)

    train_accuracies = np.zeros(app.epoch)
    train_loss = []

    for epoch in tqdm(range(app.epoch)):
        total_train, correct_train = 0, 0
        for batch_idx, batch in enumerate(tqdm(app.dataloader)):
            images, labels = batch
            images = images.to(device=device)
            labels = labels.to(device=device)

            output = app.model.forward(images)
            loss = loss_fonct(output, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            _, predicted = torch.max(output.data, 1)
            total_train += labels.size(0)
            correct_train += predicted.eq(labels).sum().item()

            if batch_idx % 10 == 0:
                train_loss.append(loss.item())

        train_accuracies[epoch] = correct_train / total_train * 100
        print(correct_train / total_train * 100, "%\n")

    torch.save(app.model.state_dict(), "shapes-model.pth")


In [5]:
def predict(image):
    out = app.model(image.reshape(1, 3, 32, 32))
    _, pred = torch.max(out, dim=1)
    return app.classes[pred.item()]


In [6]:
IMG_SIZE = 32 if torch.cuda.is_available() else 32
COMPOSED_TRANSFORMERS = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
])

In [7]:
NORMALIZE_TENSOR = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)

In [8]:
def np_array_to_tensor_image(img, width=IMG_SIZE, height=IMG_SIZE, device='cpu'):
    image = Image.fromarray(img).convert('RGB').resize((width, height))
    image = COMPOSED_TRANSFORMERS(image).unsqueeze(0)
    image = NORMALIZE_TENSOR(image)
    return image.to(device, torch.float)

In [9]:
def sketch_recognition(img):
    img = np_array_to_tensor_image(img)
    app.model.eval()
    with torch.no_grad():
        img = NORMALIZE_TENSOR(img)
        result = predict(img)
    app.model.train()
    return result

In [10]:
ssl._create_default_https_context = ssl._create_unverified_context

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

app.model.to(device)

MyModel(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (linear1): Linear(in_features=400, out_features=120, bias=True)
  (linear2): Linear(in_features=120, out_features=84, bias=True)
  (linear3): Linear(in_features=84, out_features=10, bias=True)
)

In [13]:
# if os.path.isfile("shapes-model.pth"):
#     app.model.load_state_dict(torch.load("shapes-model.pth"))
# else:
    # train_model(device)
train_model(device)
gr.Interface(fn=sketch_recognition, inputs=["sketchpad"], outputs="label").launch(share=True)


100%|██████████| 5/5 [00:00<00:00, 31.69it/s]
  1%|          | 1/100 [00:00<00:16,  6.14it/s]

28.000000000000004 %



100%|██████████| 5/5 [00:00<00:00, 41.45it/s]
  2%|▏         | 2/100 [00:00<00:13,  7.11it/s]

33.33333333333333 %



100%|██████████| 5/5 [00:00<00:00, 35.30it/s]
  3%|▎         | 3/100 [00:00<00:13,  6.99it/s]

28.999999999999996 %



100%|██████████| 5/5 [00:00<00:00, 32.31it/s]
  4%|▍         | 4/100 [00:00<00:14,  6.66it/s]

34.333333333333336 %



100%|██████████| 5/5 [00:00<00:00, 45.78it/s]
  5%|▌         | 5/100 [00:00<00:13,  7.23it/s]

32.33333333333333 %



100%|██████████| 5/5 [00:00<00:00, 40.08it/s]
  6%|▌         | 6/100 [00:00<00:12,  7.43it/s]

33.33333333333333 %



100%|██████████| 5/5 [00:00<00:00, 41.77it/s]
  7%|▋         | 7/100 [00:00<00:12,  7.63it/s]

30.333333333333336 %



100%|██████████| 5/5 [00:00<00:00, 37.42it/s]
  8%|▊         | 8/100 [00:01<00:12,  7.51it/s]

33.0 %



100%|██████████| 5/5 [00:00<00:00, 42.36it/s]
  9%|▉         | 9/100 [00:01<00:11,  7.73it/s]

33.33333333333333 %



100%|██████████| 5/5 [00:00<00:00, 36.00it/s]
 10%|█         | 10/100 [00:01<00:12,  7.49it/s]

34.66666666666667 %



100%|██████████| 5/5 [00:00<00:00, 33.09it/s]
 11%|█         | 11/100 [00:01<00:12,  7.12it/s]

33.33333333333333 %



100%|██████████| 5/5 [00:00<00:00, 39.97it/s]
 12%|█▏        | 12/100 [00:01<00:12,  7.32it/s]

30.333333333333336 %



100%|██████████| 5/5 [00:00<00:00, 37.23it/s]
 13%|█▎        | 13/100 [00:01<00:11,  7.28it/s]

29.666666666666668 %



100%|██████████| 5/5 [00:00<00:00, 35.77it/s]
 14%|█▍        | 14/100 [00:01<00:11,  7.19it/s]

31.0 %



100%|██████████| 5/5 [00:00<00:00, 43.09it/s]
 15%|█▌        | 15/100 [00:02<00:11,  7.49it/s]

31.333333333333336 %



100%|██████████| 5/5 [00:00<00:00, 28.34it/s]
 16%|█▌        | 16/100 [00:02<00:12,  6.77it/s]

30.666666666666664 %



100%|██████████| 5/5 [00:00<00:00, 36.76it/s]
 17%|█▋        | 17/100 [00:02<00:12,  6.87it/s]

31.666666666666664 %



100%|██████████| 5/5 [00:00<00:00, 39.98it/s]
 18%|█▊        | 18/100 [00:02<00:11,  7.13it/s]

34.333333333333336 %



100%|██████████| 5/5 [00:00<00:00, 35.79it/s]
 19%|█▉        | 19/100 [00:02<00:11,  7.06it/s]

33.33333333333333 %



100%|██████████| 5/5 [00:00<00:00, 36.19it/s]
 20%|██        | 20/100 [00:02<00:11,  7.02it/s]

33.666666666666664 %



100%|██████████| 5/5 [00:00<00:00, 32.32it/s]
 21%|██        | 21/100 [00:02<00:11,  6.79it/s]

33.33333333333333 %



100%|██████████| 5/5 [00:00<00:00, 37.57it/s]
 22%|██▏       | 22/100 [00:03<00:11,  6.95it/s]

32.33333333333333 %



100%|██████████| 5/5 [00:00<00:00, 35.88it/s]
 23%|██▎       | 23/100 [00:03<00:11,  6.95it/s]

34.66666666666667 %



100%|██████████| 5/5 [00:00<00:00, 39.94it/s]
 24%|██▍       | 24/100 [00:03<00:10,  7.16it/s]

33.33333333333333 %



100%|██████████| 5/5 [00:00<00:00, 35.10it/s]
 25%|██▌       | 25/100 [00:03<00:10,  7.07it/s]

33.666666666666664 %



100%|██████████| 5/5 [00:00<00:00, 37.04it/s]
 26%|██▌       | 26/100 [00:03<00:10,  7.13it/s]

30.666666666666664 %



100%|██████████| 5/5 [00:00<00:00, 41.47it/s]
 27%|██▋       | 27/100 [00:03<00:09,  7.35it/s]

33.33333333333333 %



100%|██████████| 5/5 [00:00<00:00, 36.09it/s]
 28%|██▊       | 28/100 [00:03<00:09,  7.26it/s]

32.0 %



100%|██████████| 5/5 [00:00<00:00, 35.40it/s]
 29%|██▉       | 29/100 [00:04<00:09,  7.11it/s]

30.666666666666664 %



100%|██████████| 5/5 [00:00<00:00, 37.84it/s]
 30%|███       | 30/100 [00:04<00:09,  7.21it/s]

30.666666666666664 %



100%|██████████| 5/5 [00:00<00:00, 36.43it/s]
 31%|███       | 31/100 [00:04<00:09,  7.19it/s]

31.666666666666664 %



100%|██████████| 5/5 [00:00<00:00, 36.70it/s]
 32%|███▏      | 32/100 [00:04<00:09,  7.18it/s]

34.66666666666667 %



100%|██████████| 5/5 [00:00<00:00, 33.38it/s]
 33%|███▎      | 33/100 [00:04<00:09,  6.98it/s]

33.666666666666664 %



100%|██████████| 5/5 [00:00<00:00, 34.80it/s]
 34%|███▍      | 34/100 [00:04<00:09,  6.94it/s]

33.33333333333333 %



100%|██████████| 5/5 [00:00<00:00, 35.05it/s]
 35%|███▌      | 35/100 [00:04<00:09,  6.91it/s]

27.333333333333332 %



100%|██████████| 5/5 [00:00<00:00, 40.16it/s]
 36%|███▌      | 36/100 [00:05<00:08,  7.13it/s]

35.333333333333336 %



100%|██████████| 5/5 [00:00<00:00, 36.77it/s]
 37%|███▋      | 37/100 [00:05<00:08,  7.12it/s]

32.33333333333333 %



100%|██████████| 5/5 [00:00<00:00, 37.14it/s]
 38%|███▊      | 38/100 [00:05<00:08,  7.18it/s]

29.333333333333332 %



100%|██████████| 5/5 [00:00<00:00, 44.79it/s]
 39%|███▉      | 39/100 [00:05<00:08,  7.59it/s]

31.0 %



100%|██████████| 5/5 [00:00<00:00, 22.98it/s]
 40%|████      | 40/100 [00:05<00:09,  6.32it/s]

33.33333333333333 %



100%|██████████| 5/5 [00:00<00:00, 34.41it/s]
 41%|████      | 41/100 [00:05<00:09,  6.44it/s]

31.0 %



100%|██████████| 5/5 [00:00<00:00, 44.65it/s]
 42%|████▏     | 42/100 [00:05<00:08,  6.99it/s]

36.666666666666664 %



100%|██████████| 5/5 [00:00<00:00, 26.56it/s]
 43%|████▎     | 43/100 [00:06<00:08,  6.36it/s]

35.66666666666667 %



100%|██████████| 5/5 [00:00<00:00, 37.07it/s]
 44%|████▍     | 44/100 [00:06<00:08,  6.55it/s]

34.333333333333336 %



100%|██████████| 5/5 [00:00<00:00, 37.26it/s]
 45%|████▌     | 45/100 [00:06<00:08,  6.76it/s]

29.333333333333332 %



100%|██████████| 5/5 [00:00<00:00, 36.92it/s]
 46%|████▌     | 46/100 [00:06<00:07,  6.90it/s]

32.33333333333333 %



100%|██████████| 5/5 [00:00<00:00, 36.32it/s]
 47%|████▋     | 47/100 [00:06<00:07,  6.92it/s]

34.0 %



100%|██████████| 5/5 [00:00<00:00, 36.66it/s]
 48%|████▊     | 48/100 [00:06<00:07,  6.99it/s]

32.0 %



100%|██████████| 5/5 [00:00<00:00, 35.72it/s]
 49%|████▉     | 49/100 [00:06<00:07,  6.95it/s]

32.33333333333333 %



100%|██████████| 5/5 [00:00<00:00, 34.00it/s]
 50%|█████     | 50/100 [00:07<00:07,  6.87it/s]

35.0 %



100%|██████████| 5/5 [00:00<00:00, 34.77it/s]
 51%|█████     | 51/100 [00:07<00:07,  6.85it/s]

30.666666666666664 %



100%|██████████| 5/5 [00:00<00:00, 36.63it/s]
 52%|█████▏    | 52/100 [00:07<00:06,  6.94it/s]

33.0 %



100%|██████████| 5/5 [00:00<00:00, 36.01it/s]
 53%|█████▎    | 53/100 [00:07<00:06,  6.94it/s]

35.333333333333336 %



100%|██████████| 5/5 [00:00<00:00, 37.96it/s]
 54%|█████▍    | 54/100 [00:07<00:06,  7.06it/s]

38.666666666666664 %



100%|██████████| 5/5 [00:00<00:00, 35.23it/s]
 55%|█████▌    | 55/100 [00:07<00:06,  7.01it/s]

46.0 %



100%|██████████| 5/5 [00:00<00:00, 32.68it/s]
 56%|█████▌    | 56/100 [00:07<00:06,  6.83it/s]

46.666666666666664 %



100%|██████████| 5/5 [00:00<00:00, 38.43it/s]
 57%|█████▋    | 57/100 [00:08<00:06,  7.03it/s]

41.0 %



100%|██████████| 5/5 [00:00<00:00, 35.26it/s]
 58%|█████▊    | 58/100 [00:08<00:06,  6.97it/s]

32.666666666666664 %



100%|██████████| 5/5 [00:00<00:00, 34.67it/s]
 59%|█████▉    | 59/100 [00:08<00:05,  6.92it/s]

46.0 %



100%|██████████| 5/5 [00:00<00:00, 34.97it/s]
 60%|██████    | 60/100 [00:08<00:05,  6.86it/s]

49.333333333333336 %



100%|██████████| 5/5 [00:00<00:00, 33.35it/s]
 61%|██████    | 61/100 [00:08<00:05,  6.75it/s]

65.66666666666666 %



100%|██████████| 5/5 [00:00<00:00, 32.78it/s]
 62%|██████▏   | 62/100 [00:08<00:05,  6.66it/s]

67.0 %



100%|██████████| 5/5 [00:00<00:00, 33.41it/s]
 63%|██████▎   | 63/100 [00:09<00:05,  6.62it/s]

72.33333333333334 %



100%|██████████| 5/5 [00:00<00:00, 32.57it/s]
 64%|██████▍   | 64/100 [00:09<00:05,  6.50it/s]

80.33333333333333 %



100%|██████████| 5/5 [00:00<00:00, 33.13it/s]
 65%|██████▌   | 65/100 [00:09<00:05,  6.49it/s]

84.0 %



100%|██████████| 5/5 [00:00<00:00, 30.13it/s]
 66%|██████▌   | 66/100 [00:09<00:05,  6.29it/s]

83.33333333333334 %



100%|██████████| 5/5 [00:00<00:00, 32.02it/s]
 67%|██████▋   | 67/100 [00:09<00:05,  6.25it/s]

81.0 %



100%|██████████| 5/5 [00:00<00:00, 33.38it/s]
 68%|██████▊   | 68/100 [00:09<00:05,  6.33it/s]

81.33333333333333 %



100%|██████████| 5/5 [00:00<00:00, 33.22it/s]
 69%|██████▉   | 69/100 [00:09<00:04,  6.35it/s]

84.33333333333334 %



100%|██████████| 5/5 [00:00<00:00, 30.85it/s]
 70%|███████   | 70/100 [00:10<00:04,  6.23it/s]

86.0 %



100%|██████████| 5/5 [00:00<00:00, 30.11it/s]
 71%|███████   | 71/100 [00:10<00:04,  6.13it/s]

88.33333333333333 %



100%|██████████| 5/5 [00:00<00:00, 33.05it/s]
 72%|███████▏  | 72/100 [00:10<00:04,  6.23it/s]

91.66666666666666 %



100%|██████████| 5/5 [00:00<00:00, 31.14it/s]
 73%|███████▎  | 73/100 [00:10<00:04,  6.20it/s]

92.66666666666666 %



100%|██████████| 5/5 [00:00<00:00, 24.42it/s]
 74%|███████▍  | 74/100 [00:10<00:04,  5.68it/s]

94.0 %



100%|██████████| 5/5 [00:00<00:00, 39.90it/s]
 75%|███████▌  | 75/100 [00:10<00:04,  6.17it/s]

94.33333333333334 %



100%|██████████| 5/5 [00:00<00:00, 29.28it/s]
 76%|███████▌  | 76/100 [00:11<00:03,  6.02it/s]

98.0 %



100%|██████████| 5/5 [00:00<00:00, 32.25it/s]
 77%|███████▋  | 77/100 [00:11<00:03,  6.12it/s]

98.66666666666667 %



100%|██████████| 5/5 [00:00<00:00, 32.98it/s]
 78%|███████▊  | 78/100 [00:11<00:03,  6.19it/s]

98.66666666666667 %



100%|██████████| 5/5 [00:00<00:00, 38.82it/s]
 79%|███████▉  | 79/100 [00:11<00:03,  6.53it/s]

98.66666666666667 %



100%|██████████| 5/5 [00:00<00:00, 36.17it/s]
 80%|████████  | 80/100 [00:11<00:03,  6.67it/s]

99.0 %



100%|██████████| 5/5 [00:00<00:00, 35.09it/s]
 81%|████████  | 81/100 [00:11<00:02,  6.67it/s]

99.66666666666667 %



100%|██████████| 5/5 [00:00<00:00, 39.61it/s]
 82%|████████▏ | 82/100 [00:12<00:02,  6.91it/s]

99.66666666666667 %



100%|██████████| 5/5 [00:00<00:00, 34.80it/s]
 83%|████████▎ | 83/100 [00:12<00:02,  6.87it/s]

99.66666666666667 %



100%|██████████| 5/5 [00:00<00:00, 33.64it/s]
 84%|████████▍ | 84/100 [00:12<00:02,  6.80it/s]

99.66666666666667 %



100%|██████████| 5/5 [00:00<00:00, 34.98it/s]
 85%|████████▌ | 85/100 [00:12<00:02,  6.82it/s]

100.0 %



100%|██████████| 5/5 [00:00<00:00, 33.51it/s]
 86%|████████▌ | 86/100 [00:12<00:02,  6.74it/s]

100.0 %



100%|██████████| 5/5 [00:00<00:00, 34.21it/s]
 87%|████████▋ | 87/100 [00:12<00:01,  6.73it/s]

100.0 %



100%|██████████| 5/5 [00:00<00:00, 33.72it/s]
 88%|████████▊ | 88/100 [00:12<00:01,  6.70it/s]

100.0 %



100%|██████████| 5/5 [00:00<00:00, 36.93it/s]
 89%|████████▉ | 89/100 [00:13<00:01,  6.85it/s]

100.0 %



100%|██████████| 5/5 [00:00<00:00, 40.05it/s]
 90%|█████████ | 90/100 [00:13<00:01,  7.06it/s]

100.0 %



100%|██████████| 5/5 [00:00<00:00, 34.32it/s]
 91%|█████████ | 91/100 [00:13<00:01,  6.97it/s]

100.0 %



100%|██████████| 5/5 [00:00<00:00, 38.50it/s]
 92%|█████████▏| 92/100 [00:13<00:01,  7.13it/s]

100.0 %



100%|██████████| 5/5 [00:00<00:00, 35.70it/s]
 93%|█████████▎| 93/100 [00:13<00:00,  7.08it/s]

100.0 %



100%|██████████| 5/5 [00:00<00:00, 33.63it/s]
 94%|█████████▍| 94/100 [00:13<00:00,  6.89it/s]

100.0 %



100%|██████████| 5/5 [00:00<00:00, 38.79it/s]
 95%|█████████▌| 95/100 [00:13<00:00,  7.08it/s]

100.0 %



100%|██████████| 5/5 [00:00<00:00, 36.76it/s]
 96%|█████████▌| 96/100 [00:14<00:00,  7.12it/s]

100.0 %



100%|██████████| 5/5 [00:00<00:00, 35.83it/s]
 97%|█████████▋| 97/100 [00:14<00:00,  7.06it/s]

100.0 %



100%|██████████| 5/5 [00:00<00:00, 33.79it/s]
 98%|█████████▊| 98/100 [00:14<00:00,  6.91it/s]

100.0 %



100%|██████████| 5/5 [00:00<00:00, 42.03it/s]
 99%|█████████▉| 99/100 [00:14<00:00,  7.26it/s]

100.0 %



100%|██████████| 5/5 [00:00<00:00, 33.93it/s]
100%|██████████| 100/100 [00:14<00:00,  6.85it/s]


100.0 %

Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://18b616b1-5742-4490.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces


