#### Imports

In [9]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

#### Hyperparameters

In [10]:
TRAIN_BATCH_SZ = 128
TEST_BATCH_SZ = 1000
EPOCHS = 3
LEARN_RATE = 0.01
CLASS_CNT = 78

device: str = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

Using cpu device


#### Dataset definition

In [11]:
class DrawtexDataset(Dataset):
    classes: list[str] = ['!', '(', ')', '+', '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '=', 'A', 'alpha', 'b', 'beta', 'C', 'cos', 'd', 'Delta', 'div', 'e', 'exists', 'f', 'forall', 'forward_slash', 'G', 'gamma', 'geq', 'gt', 'H', 'i', 'in', 'infty', 'int', 'j', 'k', 'l', 'lambda', 'leq', 'lim', 'log', 'lt', 'M', 'mu', 'N', 'neq', 'o', 'p', 'phi', 'pi', 'pm', 'q', 'R', 'rightarrow', 'S', 'sigma', 'sin', 'sqrt', 'sum', 'T', 'tan', 'theta', 'times', 'u', 'v', 'w', 'X', 'y', 'z', '[', ']', '{', '}']
    mapping: dict[str, int] = {'!': 0, '(': 1, ')': 2, '+': 3, '-': 4, '0': 5, '1': 6, '2': 7, '3': 8, '4': 9, '5': 10, '6': 11, '7': 12, '8': 13, '9': 14, '=': 15, 'A': 16, 'alpha': 17, 'b': 18, 'beta': 19, 'C': 20, 'cos': 21, 'd': 22, 'Delta': 23, 'div': 24, 'e': 25, 'exists': 26, 'f': 27, 'forall': 28, 'forward_slash': 29, 'G': 30, 'gamma': 31, 'geq': 32, 'gt': 33, 'H': 34, 'i': 35, 'in': 36, 'infty': 37, 'int': 38, 'j': 39, 'k': 40, 'l': 41, 'lambda': 42, 'leq': 43, 'lim': 44, 'log': 45, 'lt': 46, 'M': 47, 'mu': 48, 'N': 49, 'neq': 50, 'o': 51, 'p': 52, 'phi': 53, 'pi': 54, 'pm': 55, 'q': 56, 'R': 57, 'rightarrow': 58, 'S': 59, 'sigma': 60, 'sin': 61, 'sqrt': 62, 'sum': 63, 'T': 64, 'tan': 65, 'theta': 66, 'times': 67, 'u': 68, 'v': 69, 'w': 70, 'X': 71, 'y': 72, 'z': 73, '[': 74, ']': 75, '{': 76, '}': 77}

    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    def __init__(self, transform=None):
        self.data: np.ndarray = np.load("../data/data_matrix.npy")
        self.labels: np.ndarray = np.load("../data/label_matrix.npy")
        self.transform = transform

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, item: any) -> tuple[torch.Tensor, torch.Tensor]:
        img: np.ndarray = self.data[item]
        label = self.labels[item]

        if self.transform is not None:
            img_tensor: torch.Tensor = self.transform(img).to(self.device)
        else:
            img_tensor: torch.Tensor = torch.from_numpy(img).to(self.device)

        label_tensor: torch.Tensor = torch.tensor(label).to(self.device)
        return img_tensor, label_tensor

#### Model definition

In [12]:
class DrawTexModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(1, 200, 9, bias=False)  # 45x45 -> 37x37
        self.conv1_bn = nn.BatchNorm2d(200)
        self.conv2 = nn.Conv2d(200, 300, 9, bias=False)  # 37x37 -> 29x29
        self.conv2_bn = nn.BatchNorm2d(300)
        self.conv3 = nn.Conv2d(300, 500, 9, bias=False)  # 29x29 -> 21x21
        self.conv3_bn = nn.BatchNorm2d(500)
        self.conv4 = nn.Conv2d(500, 800, 9, bias=False)  # 21x21 -> 13x13
        self.conv4_bn = nn.BatchNorm2d(800)
        self.conv5 = nn.Conv2d(800, 1000, 9, bias=False)  # 13x13 -> 5x5
        self.conv5_bn = nn.BatchNorm2d(1000)
        self.lin1 = nn.Linear(25000, CLASS_CNT, bias=False)
        self.lin1_bn = nn.BatchNorm1d(CLASS_CNT)

    def forward(self, x: torch.Tensor):
        x: torch.Tensor = self.relu(self.conv1_bn(self.conv1(x)))
        x: torch.Tensor = self.relu(self.conv2_bn(self.conv2(x)))
        x: torch.Tensor = self.relu(self.conv3_bn(self.conv3(x)))
        x: torch.Tensor = self.relu(self.conv4_bn(self.conv4(x)))
        x: torch.Tensor = self.relu(self.conv5_bn(self.conv5(x)))
        x = torch.flatten(x.permute(0, 2, 3, 1), 1)
        x = self.lin1_bn(self.lin1(x))
        return x

#### Dataloader setup

In [13]:
data_set = DrawtexDataset(transforms.ToTensor())
TRAIN_SIZE = int(0.8 * len(data_set))
TEST_SIZE = len(data_set) - TRAIN_SIZE
train_set, test_set = torch.utils.data.random_split(data_set, [TRAIN_SIZE, TEST_SIZE])

train_loader = DataLoader(
    dataset=train_set,
    batch_size=TRAIN_BATCH_SZ,
    shuffle=True,
    num_workers=4
)

test_loader = DataLoader(
    dataset=test_set,
    batch_size=TEST_BATCH_SZ,
    shuffle=False,
    num_workers=4
)

model = DrawTexModel().to(device)
print(model)

DrawTexModel(
  (relu): ReLU()
  (conv1): Conv2d(1, 200, kernel_size=(9, 9), stride=(1, 1), bias=False)
  (conv1_bn): BatchNorm2d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(200, 300, kernel_size=(9, 9), stride=(1, 1), bias=False)
  (conv2_bn): BatchNorm2d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(300, 500, kernel_size=(9, 9), stride=(1, 1), bias=False)
  (conv3_bn): BatchNorm2d(500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv2d(500, 800, kernel_size=(9, 9), stride=(1, 1), bias=False)
  (conv4_bn): BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv5): Conv2d(800, 1000, kernel_size=(9, 9), stride=(1, 1), bias=False)
  (conv5_bn): BatchNorm2d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (lin1): Linear(in_features=25000, out_features=78, bias=False)
  (lin1_bn): BatchNorm1d(78, eps=1e-05, momentum=0.1, 

#### Training

In [14]:
tens = torch.rand((2, 1, 45, 45)).to(device)

output1 = model(tens)

print(output1)

tensor([[-1.0000, -0.9997, -1.0000, -1.0000, -1.0000,  0.9655,  0.9999,  1.0000,
          0.9972, -0.9996,  0.9995, -0.9998, -1.0000,  0.9982, -0.9999, -1.0000,
          0.9994, -0.9999, -0.9951, -0.9999, -0.9999, -0.9998, -0.9992, -0.9946,
          0.9966, -0.9998,  0.9995,  1.0000,  0.9999,  0.9995,  1.0000,  1.0000,
         -0.9999,  1.0000, -0.9937,  0.9990, -0.9999,  0.9999, -0.9998,  1.0000,
          0.9997,  0.9996,  0.9990,  0.9999,  0.6200, -0.9868, -0.9985,  0.9966,
         -0.9994,  1.0000, -0.9999,  0.9966,  1.0000, -0.9999, -0.9998, -0.9997,
         -0.9998, -0.9999, -0.9998, -0.9991, -0.9999, -0.9990, -0.9998,  0.9991,
          0.9999, -1.0000,  0.9988, -0.9996, -0.9998,  0.9999,  0.9957, -1.0000,
         -0.9118,  0.9998,  1.0000, -0.9335,  0.9985,  0.9984],
        [ 1.0000,  0.9997,  1.0000,  1.0000,  1.0000, -0.9655, -0.9999, -1.0000,
         -0.9972,  0.9996, -0.9995,  0.9998,  1.0000, -0.9982,  0.9999,  1.0000,
         -0.9994,  0.9999,  0.9951,  0.9999, 

In [None]:
def train():
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=LEARN_RATE)
    steps = len(train_loader)
    start_epoch = 0
    try:
        checkpoint = torch.load("model.pt")
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        model.load_state_dict(checkpoint["model_state"])
        start_epoch = checkpoint["epoch"]
    except:
        pass
    for epoch in range(start_epoch, EPOCHS):

        for i, (img, label) in enumerate(train_loader):
            img: torch.Tensor = img.to(device, non_blocking=True)
            label: torch.Tensor = label.to(device, non_blocking=True)
            output = model(img).to(device)
            loss: torch.Tensor = criterion(output, label).to(device)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if i % 10 == 0:
                print(f"Epoch {epoch + 1}/{EPOCHS}, Batch {i + 1}/{steps}, Loss {loss.item():.4f}")

        # Save checkpoint
        torch.save({"epoch" : epoch,
                    "optimizer_state" : optimizer.state_dict(),
                    "model_state" : model.state_dict(),
                   }, "model.pt")

train()
torch.save(model.state_dict(), "./DrawTexModel.pth")


In [None]:
with torch.no_grad():
    correct = 0
    total = 0
    for img, labels in test_loader:
        img = img.to(device, non_blocking= True)
        labels = labels.to(device, non_blocking= True)

        output = model(img)
        _, prediction = torch.max(output, 1)
        total += labels.size(0)
        correct += (prediction == labels).sum().item()

    acc = 100.0 * correct / total
    print(f"Accuracy: {acc}%")