In [13]:
import torch
import torchvision
import torchvision.transforms as transforms
from sklearn.svm import SVC
from tqdm import tqdm
import matplotlib.pyplot as plt
import time

In [14]:
# ToTensor scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps them to
# [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, ), (0.5, ))]
)

DATA_ROOT = './data'
BATCH_SIZE = 1000

train_set = torchvision.datasets.MNIST(root=DATA_ROOT, train=True, download=True, transform=transform)
# The loader is only used to stream the dataset into a single feature matrix for
# an SVM that does not train in mini-batches. Shuffling therefore adds nothing, and
# without a seed it would change the row order of the extracted training features
# on every run. Keep the order deterministic.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

test_set = torchvision.datasets.MNIST(root=DATA_ROOT, train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

def _loader_to_tensors(loader, desc):
    """Drain a DataLoader into one flattened feature matrix and one label vector.

    Each image batch is flattened to (batch, -1) so every row holds one image's
    pixels, which is the 2-D layout scikit-learn estimators expect.

    Args:
        loader: DataLoader yielding ``(images, labels)`` batches.
        desc: text shown on the tqdm progress bar.

    Returns:
        ``(features, labels)``: all batches concatenated along dim 0, in loader order.
    """
    feature_batches, label_batches = [], []
    for inputs, labels in tqdm(loader, desc=desc):
        feature_batches.append(inputs.flatten(start_dim=1))
        label_batches.append(labels)
    return torch.cat(feature_batches, dim=0), torch.cat(label_batches, dim=0)


train_features, train_labels = _loader_to_tensors(train_loader, "train features")
test_features, test_labels = _loader_to_tensors(test_loader, "test features")

# Every feature row must have a matching label.
assert len(train_features) == len(train_labels)
assert len(test_features) == len(test_labels)

# Train an RBF-kernel SVM on the flattened training pixels and score it on the test set.

print("SVM model Train & Test")

# scikit-learn documents that SVC fit time scales at least quadratically with the
# number of samples, so fitting on all 60,000 MNIST images can take a very long
# time. Fit on a reproducible random subset by default; set N_TRAIN_SAMPLES = None
# to use the full training set.
N_TRAIN_SAMPLES = 10_000
SUBSET_SEED = 0

# Fail fast with a clear message if a later cell has overwritten test_labels
# with a single batch.
assert len(test_labels) == len(test_features), "test_labels no longer matches test_features; re-run the feature extraction cell"

if N_TRAIN_SAMPLES is not None and N_TRAIN_SAMPLES < len(train_features):
    subset_generator = torch.Generator().manual_seed(SUBSET_SEED)
    subset_idx = torch.randperm(len(train_features), generator=subset_generator)[:N_TRAIN_SAMPLES]
    fit_features = train_features[subset_idx]
    fit_labels = train_labels[subset_idx]
else:
    fit_features, fit_labels = train_features, train_labels

model = SVC(C=1.0, kernel='rbf', gamma=0.01)

# perf_counter is the monotonic clock intended for measuring elapsed intervals.
start_time = time.perf_counter()
model.fit(fit_features.numpy(), fit_labels.numpy())
fit_time = time.perf_counter() - start_time

start_time = time.perf_counter()
acc = model.score(test_features.numpy(), test_labels.numpy())
score_time = time.perf_counter() - start_time

spent_time = fit_time + score_time

print(f"Trained on {len(fit_labels)} samples")
print(f"Accuracy: {acc}")
print(f"Training time: {fit_time:.2f} seconds, Test time: {score_time:.2f} seconds")
print(f"Training, Test time: {spent_time:.2f} seconds")

# Predict the first test batch for a visual spot-check of the model.

# Take the first test batch for a visual spot-check of the predictions.
# The batch labels get their own name. Reassigning `test_labels` here would
# replace the full test label tensor with this one batch, and re-running the
# train/test cell afterwards would then fail in `model.score` with mismatched
# lengths.
test_images, batch_labels = next(iter(test_loader))
test_images = test_images.flatten(start_dim=1)

# test_loader is built with shuffle=False, so its first batch is the head of the
# full test set and `test_labels[:n]` still lines up with these images.
assert torch.equal(batch_labels, test_labels[:len(batch_labels)]), "first test batch is out of order with test_labels"

test_preds = model.predict(test_images.numpy())

# Plot a few test images titled with their true and predicted labels.

def plot_images(images, labels, preds, n_images=10):
    """Show the first few images in one row, titled with true vs. predicted label.

    Args:
        images: indexable sequence of flattened images; each entry is reshaped
            to 28x28 for display.
        labels: true labels, aligned index-by-index with ``images``.
        preds: predicted labels, aligned index-by-index with ``images``.
        n_images: how many images to show (default 10). Capped at the length of
            the shortest input, so short inputs do not raise IndexError.
    """
    n_images = min(n_images, len(images), len(labels), len(preds))
    if n_images <= 0:
        return

    # squeeze=False always returns a 2-D array of Axes, even for a single image,
    # so the loop below works for any n_images.
    _, axes = plt.subplots(1, n_images, figsize=(n_images, 1), squeeze=False)

    for i, ax in enumerate(axes[0]):
        ax.imshow(images[i].reshape((28, 28)), cmap='gray')
        ax.set_title(f"Label: {labels[i]}\nPredicted: {preds[i]}")
        ax.axis('off')

    plt.subplots_adjust(top=0.5, bottom=0, hspace=0, wspace=0.5)

    plt.show()


plot_images(test_images.numpy(), test_labels.numpy(), test_preds)

KeyboardInterrupt: 