In [1]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from DrawtexModel import DrawTexModel
from DrawtexDataset import DrawtexDataset

#### Hyperparameters

In [2]:
TRAIN_BATCH_SZ = 128
TEST_BATCH_SZ = 1000
EPOCHS = 10
LEARN_RATE = 0.01
CLASS_CNT = 82

device: str = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

Using cpu device


#### Dataloader setup

In [3]:
data_set = DrawtexDataset(transforms.ToTensor())
TRAIN_SIZE = int(0.8 * len(data_set))
TEST_SIZE = len(data_set) - TRAIN_SIZE
train_set, test_set = torch.utils.data.random_split(data_set, [TRAIN_SIZE, TEST_SIZE])

train_loader = DataLoader(
    dataset=train_set,
    batch_size=TRAIN_BATCH_SZ,
    shuffle=True,
    num_workers=4
)

test_loader = DataLoader(
    dataset=test_set,
    batch_size=TEST_BATCH_SZ,
    shuffle=False,
    num_workers=4
)

model = DrawTexModel().to(device)
print(model)

DrawTexModel(
  (relu): ReLU()
  (conv1): Conv2d(1, 200, kernel_size=(9, 9), stride=(1, 1), bias=False)
  (conv1_bn): BatchNorm2d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(200, 400, kernel_size=(9, 9), stride=(1, 1), bias=False)
  (conv2_bn): BatchNorm2d(400, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(400, 800, kernel_size=(9, 9), stride=(1, 1), bias=False)
  (conv3_bn): BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv2d(800, 1200, kernel_size=(9, 9), stride=(1, 1), bias=False)
  (conv4_bn): BatchNorm2d(1200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv5): Conv2d(1200, 1800, kernel_size=(9, 9), stride=(1, 1), bias=False)
  (conv5_bn): BatchNorm2d(1800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (lin1): Linear(in_features=45000, out_features=82, bias=False)
)


In [6]:
tens = torch.rand((2, 1, 45, 45))

output1 = model(tens)

print(output1)

tensor([[ 0.2889, -0.2709, -0.2910,  0.9286, -0.4560,  0.1626,  0.1412,  0.0711,
         -0.8352, -0.3151,  0.1168,  0.6586, -0.2296, -0.0791,  0.1228, -0.4435,
          0.1972,  0.8277, -0.3313, -0.1240,  0.0929,  0.4867, -0.1740,  0.3238,
         -0.1622, -0.4580, -0.5672, -0.3615,  0.4676, -0.5255, -0.1395, -0.7301,
         -0.3844, -0.0973,  0.0235, -0.1049,  0.3879,  0.2073,  0.4850,  0.1556,
         -0.2388,  0.1398,  0.2015, -0.5098,  0.3370, -0.0948,  0.3297, -0.1409,
          0.4271,  0.2152, -1.0547,  0.2057, -0.4789,  0.2114,  0.3562,  0.2599,
         -0.5566, -0.0972,  0.0118, -0.5643,  0.3202, -0.1525,  0.6448,  0.4264,
         -0.4849,  0.3997,  0.1793,  0.2557, -0.5028,  0.1553, -0.7345,  0.0840,
          0.9502,  0.5828,  0.2535,  0.1938, -0.0554, -0.4812,  0.0178, -0.0536,
         -0.1998, -1.0176],
        [ 0.2405, -0.0318, -0.1697,  0.4312, -0.3319, -0.4451,  0.5346, -0.0202,
          0.3594, -0.2812, -0.0122, -0.4484,  0.0601,  0.0066,  0.1746,  0.0194,


In [None]:
def train():
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=LEARN_RATE)
    steps = len(train_loader)
    for epoch in range(EPOCHS):
        for i, (img, label) in enumerate(train_loader):
            img: torch.Tensor = img.to(device, non_blocking=True)
            label: torch.Tensor = label.to(device, non_blocking=True)
            output = model(img).to(device)
            loss: torch.Tensor = criterion(output, label).to(device)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (i + 1) % TRAIN_BATCH_SZ == 0:
                print(f"Epoch {epoch + 1}/{EPOCHS}, Batch {i + 1}/{steps}, Loss {loss.item():.4f}")

train()
torch.save(model.state_dict(), "./DrawTexModel.pth")


In [None]:
with torch.no_grad():
    correct = 0
    total = 0
    for img, labels in test_loader:
        img = img.to(device, non_blocking= True)
        labels = labels.to(device, non_blocking= True)

        output = model(img)
        _, prediction = torch.max(output, 1)
        total += labels.size(0)
        correct += (prediction == labels).sum().item()

    acc = 100.0 * correct / total
    print(f"Accuracy: {acc}%")