### Imports

In [8]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms

from DT_Model import DT_Model
from DT_Dataset import DT_Dataset

In [2]:
# Seed torch's global RNG before the model is built so weight initialisation,
# DataLoader shuffle order and augmentation draws should repeat across kernel
# restarts. (Some GPU kernels may still be non-deterministic.)
SEED = 42
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Training-time augmentation: random translation of up to 10% horizontally and
# vertically plus down-scaling to 75-100%; pixels uncovered by the transform are
# filled with 255. NOTE(review): 255 presumably matches the images' background — confirm.
train_transform = transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.75, 1), fill=255)

# The test split must be evaluated deterministically, so it gets an identity
# transform instead of the random augmentation. An empty Compose is used rather
# than None so it does not depend on how DT_Dataset treats a missing transform,
# and it stays picklable for the multi-worker DataLoader.
test_transform = transforms.Compose([])

train_set = DT_Dataset(path="../data",
                       train=True,
                       transform=train_transform)
train_load = DataLoader(dataset=train_set,
                        batch_size=128,
                        num_workers=4,
                        shuffle=True)
test_set = DT_Dataset(path="../data",
                      train=False,
                      transform=test_transform)
test_load = DataLoader(dataset=test_set,
                       batch_size=1000,
                       num_workers=4,
                       shuffle=False)
model = DT_Model().to(device)

### Hyperparameters

In [3]:
# Number of full passes train() makes over the training DataLoader.
EPOCHS = 20
# Intended SGD learning rate (1e-3 == 0.001).
# NOTE(review): make sure train() reads this rather than a hard-coded value.
LEARN_RATE = 1e-3


In [6]:
def train(epoch_cnt, learn_rate=None, save_path="./DT_Model.pt"):
    """
    Train the global ``model`` on ``train_load`` for ``epoch_cnt`` epochs.

    Uses plain SGD with cross-entropy loss. Every 100 steps it prints the
    current batch loss and the accuracy over the batches since the previous
    report. The model weights are checkpointed to ``save_path`` after every epoch.

    Args:
        epoch_cnt: number of full passes over ``train_load``.
        learn_rate: SGD learning rate; ``None`` means use the ``LEARN_RATE``
            hyperparameter defined in the notebook's config cell.
        save_path: file that ``model.state_dict()`` is written to after each epoch.
    """
    if learn_rate is None:
        learn_rate = LEARN_RATE
    total_steps = len(train_load)
    optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)
    loss_fn = nn.CrossEntropyLoss()
    # Enforce training-mode behaviour (e.g. dropout / batch-norm statistics)
    # even if an earlier cell left the model in eval mode.
    model.train()
    for epoch in range(epoch_cnt):
        correct = 0
        total = 0
        for i, (img, label) in enumerate(train_load, start=1):
            img = img.to(device)
            label = label.to(device)

            output = model(img)
            # output and label already live on `device`, so the loss does too.
            loss = loss_fn(output, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate plain Python ints so the report prints a number,
            # not a tensor repr such as "tensor(93.75, device='cuda:0')".
            prediction = output.argmax(dim=1)
            correct += (prediction == label).sum().item()
            total += len(label)
            if i % 100 == 0:
                print(f"Epoch: [{epoch + 1}/{epoch_cnt}], Step: [{i}/{total_steps}], Loss: [{loss.item():.4f}], Accuracy: [{100 * correct / total:.2f}%]")
                correct = 0
                total = 0

        torch.save(model.state_dict(), save_path)


In [7]:
# Run the full training loop for the configured number of epochs.
train(epoch_cnt=EPOCHS)

hi
hi


KeyboardInterrupt: 