In [None]:
import torch
import torchvision
from torchvision.models.resnet import ResNet, BasicBlock
from torchvision.models.inception import Inception3, BasicConv2d
import time

In [None]:
# --- Experiment configuration -------------------------------------------
# Fixed seed so that weight initialisation and the shuffled batch order are
# repeatable on a fresh "Restart Kernel -> Run All".
seed = 0
# Base epoch count; the training cells run multiples of this value.
n_epochs = 30
# Mini-batch sizes handed to the train / test DataLoaders.
batch_size_train = 50
batch_size_test = 50

# Seed torch's global random number generator before any model or loader
# is created.
torch.manual_seed(seed)

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Pre-processing applied to every image: convert to a tensor, then normalise
# with a single-channel mean / std.
# NOTE(review): 0.1307 / 0.3081 are presumably the MNIST training-set
# statistics -- confirm.
to_tensor = torchvision.transforms.ToTensor()
normalize = torchvision.transforms.Normalize((0.1307,), (0.3081,))
transform = torchvision.transforms.Compose([to_tensor, normalize])

# Both splits are stored under ./MNIST and downloaded if missing.
mnist_train = torchvision.datasets.MNIST(
    root='./MNIST', train=True, download=True, transform=transform)
mnist_test = torchvision.datasets.MNIST(
    root='./MNIST', train=False, download=True, transform=transform)

# Training batches are reshuffled each epoch; test batches keep dataset order.
train_loader = torch.utils.data.DataLoader(
    mnist_train, batch_size=batch_size_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    mnist_test, batch_size=batch_size_test, shuffle=False)

In [None]:
class MnistResNet(ResNet):
    """torchvision ``ResNet`` configured for single-channel input, 10 classes.

    The network is built from ``BasicBlock`` with a layer configuration of
    ``[18, 18, 18, 18]``. After construction the stem convolution is swapped
    for one that takes 1 input channel instead of the parent's default.
    ``forward`` is inherited unchanged from ``ResNet``.
    """

    def __init__(self):
        blocks_per_stage = [18] * 4
        super().__init__(BasicBlock, blocks_per_stage, num_classes=10)
        # Replace the first layer so it accepts 1-channel MNIST images.
        self.conv1 = torch.nn.Conv2d(
            in_channels=1,
            out_channels=64,
            kernel_size=(7, 7),
            stride=(2, 2),
            padding=(3, 3),
            bias=False,
        )

In [None]:
class MnistInception(Inception3):
    """torchvision ``Inception3`` configured for single-channel input, 10 classes.

    The auxiliary classifier is disabled (``aux_logits=False``), so ``forward``
    returns a single logits tensor. The first convolution is replaced with one
    that takes 1 input channel, and every batch is bilinearly resized to
    ``INPUT_SIZE`` before it enters the network.
    """

    # Spatial size fed to the backbone. Inception v3 is designed for 299x299
    # inputs; the earlier value of 229 looks like a transposed-digit typo.
    INPUT_SIZE = (299, 299)

    def __init__(self):
        super(MnistInception, self).__init__(num_classes=10, aux_logits=False)
        # Replace the first layer so it accepts 1-channel MNIST images.
        self.Conv2d_1a_3x3 = BasicConv2d(1, 32, kernel_size=3, stride=2)

    def forward(self, x):
        """Resize ``x`` to ``INPUT_SIZE`` and run the Inception v3 forward pass.

        Args:
            x: image batch; assumed to be (N, 1, H, W) -- TODO confirm with
                the caller's data pipeline.

        Returns:
            Whatever ``Inception3.forward`` returns for the resized batch.
        """
        # align_corners=False is the default for bilinear mode; passing it
        # explicitly keeps the result the same and avoids the warning some
        # PyTorch versions emit when it is left unspecified.
        x = torch.nn.functional.interpolate(
            x, size=self.INPUT_SIZE, mode='bilinear', align_corners=False)
        return super(MnistInception, self).forward(x)

In [None]:
def train(net, criterion, optimizer, trainloader):
    """Run one optimisation pass over ``trainloader``.

    Puts ``net`` in training mode, then for each batch computes the loss,
    clears old gradients, back-propagates and takes one optimizer step.
    Batches are moved to the module-level ``device``. Returns nothing.

    Args:
        net: model to update in place.
        criterion: callable taking ``(output, labels)`` and returning the loss.
        optimizer: optimizer that owns the parameters of ``net``.
        trainloader: iterable yielding ``(inputs, labels)`` pairs.
    """
    net.train()
    for batch_inputs, batch_labels in trainloader:
        batch_inputs = batch_inputs.to(device)
        predictions = net(batch_inputs)
        batch_loss = criterion(predictions, batch_labels.to(device))

        # Gradients accumulate by default, so reset them before backward().
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

In [None]:
def test(net, testloader):
    net.eval()
    numRight = 0.0
    for inputs, labels in testloader:
        output = net(inputs.to(device))
        _, pred = torch.max(output, 1)
        numRight += torch.sum(pred == labels.to(device))
    return numRight.item()/len(testloader.dataset)

In [None]:
# Cross-entropy loss shared by both models, placed on the selected device.
criterion = torch.nn.CrossEntropyLoss()
criterion = criterion.to(device)

In [None]:
# Build the ResNet on the selected device and give it an SGD optimizer
# (learning rate 0.1, momentum 0.9, weight decay 1e-4).
resnet = MnistResNet()
resnet = resnet.to(device)
optimizer = torch.optim.SGD(
    resnet.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

# First training phase: one pass over the training set per epoch, followed
# by a test-set accuracy report. `start` is kept so later cells can report
# cumulative wall-clock time.
start = time.time()
for epoch in range(n_epochs):
    # Ask the CUDA caching allocator to release unused blocks around each phase.
    torch.cuda.empty_cache()
    train(resnet, criterion, optimizer, train_loader)
    torch.cuda.empty_cache()
    accuracy = test(resnet, test_loader)
    elapsed = time.time() - start
    print(epoch, 'th epoch:', accuracy, 'time: ', elapsed)

In [None]:
# Second training phase for the ResNet: epochs n_epochs .. 2*n_epochs - 1.
# Reuses the existing `resnet`, `optimizer` and `start`, so the printed time
# is measured from whenever `start` was last set, not from this cell.
for epoch in range(n_epochs, 2 * n_epochs):
    torch.cuda.empty_cache()
    train(resnet, criterion, optimizer, train_loader)
    torch.cuda.empty_cache()
    accuracy = test(resnet, test_loader)
    elapsed = time.time() - start
    print(epoch, 'th epoch:', accuracy, 'time: ', elapsed)

In [None]:
# Persist the whole ResNet object (not just its state_dict), then drop the
# reference and release cached GPU memory before the next model is built.
# NOTE(review): the file name says "resnet110" -- confirm it matches the
# architecture actually trained.
resnet_checkpoint = 'resnet110MNIST.pt'
torch.save(resnet, resnet_checkpoint)
del resnet
torch.cuda.empty_cache()

In [None]:
# Build the Inception model on the selected device with the same SGD
# settings (learning rate 0.1, momentum 0.9, weight decay 1e-4).
inception = MnistInception()
inception = inception.to(device)
optimizer = torch.optim.SGD(
    inception.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

# Train for 2 * n_epochs epochs in a single loop, reporting test accuracy
# and elapsed wall-clock time after every epoch.
start = time.time()
for epoch in range(2 * n_epochs):
    # Ask the CUDA caching allocator to release unused blocks around each phase.
    torch.cuda.empty_cache()
    train(inception, criterion, optimizer, train_loader)
    torch.cuda.empty_cache()
    accuracy = test(inception, test_loader)
    elapsed = time.time() - start
    print(epoch, 'th epoch:', accuracy, 'time: ', elapsed)

In [None]:
# Persist the whole Inception object (not just its state_dict), then drop
# the reference and release cached GPU memory.
inception_checkpoint = 'inceptionMNIST.pt'
torch.save(inception, inception_checkpoint)
del inception
torch.cuda.empty_cache()