In [None]:
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from utils import train_classifier, plot_images
from datasets import load_nist_data
from architectures import LeNet5, ResNet18, ResNet34

In [None]:
# --- MNIST digits -> LeNet-5 ---
dataname = 'MNIST'

# Train split first, then test split. binary_threshold=0.5 presumably
# binarises pixel intensities inside load_nist_data — confirm there.
train_dataset, test_dataset = (
    load_nist_data(dataset=dataname, binary_threshold=0.5, train=is_train)
    for is_train in (True, False)
)

# Only the training stream is shuffled; evaluation order stays fixed.
train_dataloader, test_dataloader = (
    DataLoader(train_dataset, batch_size=64, shuffle=True),
    DataLoader(test_dataset, batch_size=64, shuffle=False),
)

# Preview a single training batch; the labels are not needed here.
images, _ = next(iter(train_dataloader))
plot_images(images, dataname, figsize=(3.5, 3.5))

# Train the classifier. save_as is presumably where train_classifier
# writes the checkpoint; spaces in the dataset name become underscores.
model = LeNet5(num_classes=10)
train_classifier(
    model,
    train_dataloader,
    test_dataloader,
    accuracy_goal=98.5,
    save_as=f"models/model_{dataname.replace(' ', '_')}_LeNet5.pth",
    epochs=100,
    lr=0.001,
)


In [None]:
# --- EMNIST Letters -> ResNet-18 ---
dataname = 'EMNIST Letters'

# Train split first, then test split. binary_threshold=0.5 presumably
# binarises pixel intensities inside load_nist_data — confirm there.
train_dataset, test_dataset = (
    load_nist_data(dataset=dataname, binary_threshold=0.5, train=is_train)
    for is_train in (True, False)
)

# Larger batches than the MNIST cell; only the training stream is shuffled.
train_dataloader, test_dataloader = (
    DataLoader(train_dataset, batch_size=256, shuffle=True),
    DataLoader(test_dataset, batch_size=256, shuffle=False),
)

# Preview a single training batch; the labels are not needed here.
images, _ = next(iter(train_dataloader))
plot_images(images, dataname, figsize=(3.5, 3.5))

# NOTE(review): 27 outputs for 26 letters — presumably labels run 1..26
# with index 0 unused (as in torchvision's EMNIST 'letters' split); confirm
# against load_nist_data. save_as yields models/model_EMNIST_Letters_ResNet18.pth.
model = ResNet18(num_classes=27)
train_classifier(
    model,
    train_dataloader,
    test_dataloader,
    accuracy_goal=95,
    save_as=f"models/model_{dataname.replace(' ', '_')}_ResNet18.pth",
    epochs=100,
    lr=0.001,
)

In [None]:
# --- FashionMNIST -> ResNet-34 ---
dataname = 'FashionMNIST'

# binary_threshold=None, unlike the digit/letter cells: presumably the
# images are left un-binarised here — confirm in load_nist_data.
train_dataset = load_nist_data(dataset=dataname, binary_threshold=None, train=True)
test_dataset = load_nist_data(dataset=dataname, binary_threshold=None, train=False)

# Only the training stream is shuffled; evaluation order stays fixed.
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Preview a single training batch; the labels are not needed here.
images, _ = next(iter(train_dataloader))
plot_images(images, dataname, figsize=(3.5, 3.5))

# Train the classifier. The checkpoint path now lives under models/ like
# the MNIST and EMNIST cells (it previously omitted the directory and
# landed in the current working directory).
model = ResNet34(num_classes=10)
train_classifier(
    model,
    train_dataloader,
    test_dataloader,
    accuracy_goal=95,
    save_as=f"models/model_{dataname.replace(' ', '_')}_ResNet34.pth",
    epochs=100,
    lr=0.001,
)