In [38]:
import torch
from dataset import PlayingCardDataset
from torch.utils.data import DataLoader
import SpadeClassifier
from mapping import cards
import matplotlib.pyplot as plt

In [39]:
# Seed torch's global RNG (CPU and all CUDA devices) before anything random runs,
# so the train/test split, DataLoader shuffling and model weight initialisation
# are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Run on the first GPU when available, otherwise fall back to CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 1

In [40]:
# Load dataset and print sample batch.
# The file stores a pickled PlayingCardDataset object (not just tensors), so
# weights_only=False is required: PyTorch >= 2.6 defaults to weights_only=True,
# which refuses to unpickle custom classes.
# SECURITY: unpickling can execute arbitrary code; only load files you trust.
dataset: PlayingCardDataset = torch.load("./playing_card_dataset.pt", weights_only=False)
image, label = next(iter(dataset))
# permute(1, 2, 0) moves the first axis last; assumes image is (C, H, W) and
# matplotlib wants (H, W, C). TODO confirm.
plt.imshow(image.permute(1, 2, 0))
# label appears to be one-hot (single nonzero entry); the +1 suggests `cards`
# is keyed starting at 1. TODO confirm against mapping.cards.
print(cards[int(label.nonzero() + 1)])

In [41]:
# Create an 80/20 train/test split and the data loaders.
train_set, test_set = torch.utils.data.random_split(dataset, [0.8, 0.2])
# Both loaders honour BATCH_SIZE from the config cell. The test loader is not
# shuffled: evaluation order does not affect loss or accuracy.
train_load = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
test_load = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

In [46]:
# Development convenience: re-execute the SpadeClassifier module so edits to the
# model definition take effect without restarting the kernel.
# Must run before the model-construction cell.
from importlib import reload

reload(SpadeClassifier)

In [53]:
# Create model and move it to the selected device.
# NOTE(review): 53 is presumably the number of output classes
# (52 standard cards + 1 joker?) — confirm against `cards`.
NUM_CLASSES = 53
model = SpadeClassifier.SpadeClassifier(NUM_CLASSES).to(device)

In [54]:
def train_one_epoch() -> None:
    """Run one optimisation pass over `train_load` and record the mean batch loss.

    Relies on notebook globals defined in other cells: `model`, `train_load`,
    `optimizer`, `loss_fn`, `device`, and the `train_loss` history list, which
    receives one entry (sum of batch losses / number of batches) per call.
    Prints the current batch loss every 500 iterations.
    """
    epoch_loss_sum = 0.0
    model.train()

    for step, (batch_images, batch_labels) in enumerate(train_load):
        # Move the batch onto the model's device.
        batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)

        # Clear gradients left over from the previous step, then forward/backward/update.
        optimizer.zero_grad()
        batch_loss = loss_fn(model(batch_images), batch_labels)
        batch_loss.backward()
        optimizer.step()

        loss_value = batch_loss.item()
        epoch_loss_sum += loss_value
        if not step % 500:
            print(f"Iteration: {step}, Loss: {loss_value}")

    train_loss.append(epoch_loss_sum / len(train_load))
    
def test_one_epoch() -> None:
    """Evaluate `model` on `test_load` once and record the mean batch loss.

    Prints the mean per-batch test loss and the per-sample accuracy, then
    appends the mean loss to the global `test_loss` history.

    Relies on notebook globals defined in other cells: `model`, `test_load`,
    `loss_fn`, `device` and `test_loss`.

    Raises:
        ValueError: if `test_load` yields no batches (nothing to average over).
    """
    num_batches = len(test_load)
    if num_batches == 0:
        raise ValueError("test_load is empty; cannot evaluate the model")

    labeled_correctly = 0
    samples_seen = 0
    running_loss = 0.0

    model.eval()
    with torch.no_grad():  # Disable gradient calculation for evaluation
        for images, labels in test_load:
            # Move data to the same device as the model
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            running_loss += loss.item()

            # Compare predictions per sample along the class dimension; a bare
            # .argmax() flattens the whole batch and is only right for batch size 1.
            predicted = outputs.argmax(dim=1)
            # Labels look one-hot in this notebook (the sample cell uses
            # `label.nonzero()`); plain class-index targets are accepted too.
            # TODO confirm label format.
            target = labels.argmax(dim=1) if labels.dim() == outputs.dim() else labels
            labeled_correctly += (predicted == target).sum().item()
            samples_seen += outputs.size(0)

    running_loss /= num_batches
    print("--------------------")
    print(f"Test Loss: {running_loss}\n")
    # Accuracy is over samples, not batches, so it stays correct for any BATCH_SIZE.
    print(f"Accuracy: {labeled_correctly / samples_seen * 100:.2f}%")

    test_loss.append(running_loss)


In [55]:
# Define Training loop.
# Loss and optimiser are read as globals by train_one_epoch / test_one_epoch.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Per-epoch loss histories, appended to by the train/test helpers.
train_loss, test_loss = [], []

epochs = 5
epoch = 0
while epoch < epochs:
    train_one_epoch()
    test_one_epoch()
    epoch += 1

In [None]:
# Persist only the learned parameters (state_dict), not the whole module object.
MODEL_PATH = "./model.pt"
torch.save(model.state_dict(), MODEL_PATH)