## Standard LeNet-style model training on the full EMNIST "letters" split, for the purpose of further evaluation of generative models

In [1]:
import numpy as np
import torch
from lenet_emnist import Model
from torchvision import datasets
from torch.nn import CrossEntropyLoss
from torch.optim import Adam, lr_scheduler
from torch.utils.data import DataLoader
from torchvision import transforms

In [2]:
# Select the compute device. Use GPU 0 when CUDA is present; otherwise fall
# back to the CPU so the notebook still runs on machines without a GPU
# (calling torch.cuda.set_device on such a machine raises).
if torch.cuda.is_available():
    torch.cuda.set_device(0)
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
# Data pipeline: EMNIST "letters" split, images converted to tensors only
# (no normalisation or augmentation is applied).
batch_size = 64
transform = transforms.Compose([
    transforms.ToTensor()])

# NOTE(review): the raw EMNIST "letters" labels are commonly reported as
# 1..26 rather than 0..25 -- confirm they fit the 26-way output head of the
# classifier trained on this data.
train_dataset = datasets.EMNIST(root='./emnist_data/', train=True, transform=transform, split="letters", download=True)
test_dataset = datasets.EMNIST(root='./emnist_data/', train=False, transform=transform, split="letters", download=True)

# Reshuffle the training samples every epoch so mini-batches are not always
# drawn in the same dataset order; the test loader keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

In [4]:
# Number of class names the dataset object exposes; the recorded cell output is 26.
len(train_dataset.classes)

26

In [5]:
from torch.nn import Module
from torch import nn
import torch


class Model(Module):
    """LeNet-style CNN classifier producing 26 raw logits.

    Three 5x5 convolutions (with a single 2x2 max-pool after the first one)
    feed a three-layer fully connected head. ``fc1`` expects 256 inputs,
    which is what a 1x28x28 image yields after the convolutional stack
    (28 -> 24 -> 12 -> 8 -> 4, times 16 channels).

    No softmax is applied; the outputs are logits.
    """

    def __init__(self):
        super().__init__()
        dropout_rate = 0.4
        self.conv1 = nn.Conv2d(1, 32, 5)
        self.relu1 = nn.LeakyReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.dropout_1 = nn.Dropout2d(dropout_rate)
        self.conv2 = nn.Conv2d(32, 64, 5)
        self.relu2 = nn.LeakyReLU()
        self.dropout_2 = nn.Dropout2d(dropout_rate)
        self.conv3 = nn.Conv2d(64, 16, 5)
        # LeakyReLU holds no parameters, so this single instance serves as
        # the activation after both conv3 and fc1.
        self.relu3 = nn.LeakyReLU()
        self.dropout_3 = nn.Dropout2d(dropout_rate)
        self.fc1 = nn.Linear(256, 256)
        self.dropout_4 = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(256, 128)
        self.relu4 = nn.LeakyReLU()
        self.dropout_5 = nn.Dropout(dropout_rate)
        self.fc3 = nn.Linear(128, 26)

    def _hidden(self, x, use_dropout):
        """Run every layer up to (not including) ``fc3``.

        Args:
            x: input batch accepted by ``conv1`` (one input channel).
            use_dropout: when true, the dropout modules are applied in the
                same positions as in ``forward``; when false they are
                skipped entirely.

        Returns:
            The activations after ``relu4`` (and ``dropout_5`` if enabled),
            with 128 features per sample.
        """
        def drop(layer, t):
            return layer(t) if use_dropout else t

        y = self.pool1(self.relu1(self.conv1(x)))
        y = drop(self.dropout_1, y)
        y = self.relu2(self.conv2(y))
        y = drop(self.dropout_2, y)
        y = self.relu3(self.conv3(y))
        y = drop(self.dropout_3, y)

        # Flatten everything except the batch dimension for the linear head.
        y = y.view(y.shape[0], -1)
        y = self.relu3(self.fc1(y))
        y = drop(self.dropout_4, y)
        y = self.relu4(self.fc2(y))
        y = drop(self.dropout_5, y)
        return y

    def forward(self, x):
        """Return the 26 class logits for ``x``, with dropout layers applied."""
        return self.fc3(self._hidden(x, use_dropout=True))

    def part_forward(self, x, include_classifier=True):
        """Dropout-free forward pass, intended for FID-style feature extraction.

        Args:
            x: input batch accepted by ``conv1``.
            include_classifier: with the default ``True`` the final ``fc3``
                layer is applied and the 26 logits are returned, which is
                what this method has always done. Pass ``False`` to get the
                128-dimensional activations *before* the classification
                layer, which is what the method was originally described
                as providing.

        Returns:
            Logits (default) or penultimate features, one row per sample.
        """
        features = self._hidden(x, use_dropout=False)
        if include_classifier:
            return self.fc3(features)
        return features


# Builds a separate, untrained instance of the network.
# NOTE(review): no other cell shown here reads `lenet` (training uses `model`) --
# confirm it is unused before removing.
lenet = Model()


In [6]:
# from model import Model
# Training configuration: 50 epochs of Adam (initial lr 1e-3) with the
# learning rate multiplied by 0.95 on every scheduler step, optimising a
# cross-entropy loss on the model's raw logits.
epoch = 50
cross_error = CrossEntropyLoss()
model = Model().to(device)
optim = Adam(params=model.parameters(), lr=1e-3)
scheduler = lr_scheduler.ExponentialLR(optimizer=optim, gamma=0.95)

In [7]:
# Main loop: one optimisation pass over the training loader per epoch,
# followed by an accuracy measurement over the whole test loader and one
# learning-rate scheduler step.
for _epoch in range(epoch):
    model.train()  # enable dropout for the optimisation pass
    for idx, (train_x, train_label) in enumerate(train_loader):
        train_x = train_x.to(device)
        train_label = train_label.to(device)
        optim.zero_grad()
        predict_y = model(train_x)
        _error = cross_error(predict_y, train_label.long())
        if idx % 100 == 0:
            # Log the scalar loss every 100 mini-batches.
            print('epoch:{}, idx: {}, loss: {}'.format(_epoch, idx, _error.item()))
        _error.backward()
        optim.step()

    # Evaluation: dropout disabled, no gradients tracked. The counters are
    # plain Python ints so the accuracy print also works if the test loader
    # yields no batches.
    model.eval()
    correct = 0
    _sum = 0
    with torch.no_grad():
        for test_x, test_label in test_loader:
            test_x = test_x.to(device)
            test_label = test_label.to(device)
            predict_y = model(test_x)
            predict_ys = torch.argmax(predict_y, dim=1)
            correct_predictions = predict_ys == test_label
            correct += correct_predictions.sum().item()
            _sum += len(test_label)

    print('accuracy: {:.2f}'.format(correct / _sum if _sum else 0.0))
    scheduler.step()

epoch:0, idx: 0, loss: 3.2685587406158447
epoch:0, idx: 100, loss: 2.7975127696990967
epoch:0, idx: 200, loss: 2.2975375652313232
epoch:0, idx: 300, loss: 2.1261141300201416
epoch:0, idx: 400, loss: 1.8837556838989258
epoch:0, idx: 500, loss: 1.4434443712234497
epoch:0, idx: 600, loss: 1.529410481452942
epoch:0, idx: 700, loss: 1.6519105434417725
epoch:0, idx: 800, loss: 1.184548258781433
epoch:0, idx: 900, loss: 1.2210208177566528
epoch:0, idx: 1000, loss: 1.3087923526763916
epoch:0, idx: 1100, loss: 1.319075584411621
epoch:0, idx: 1200, loss: 0.9897813200950623
epoch:0, idx: 1300, loss: 1.3043535947799683
epoch:0, idx: 1400, loss: 0.745810866355896
epoch:0, idx: 1500, loss: 1.1315323114395142
epoch:0, idx: 1600, loss: 1.4407765865325928
epoch:0, idx: 1700, loss: 1.075335144996643
epoch:0, idx: 1800, loss: 0.9410991668701172
epoch:0, idx: 1900, loss: 1.2537108659744263
accuracy: 0.81
epoch:1, idx: 0, loss: 1.3426481485366821
epoch:1, idx: 100, loss: 0.8932811617851257
epoch:1, idx: 20

epoch:9, idx: 300, loss: 0.5972140431404114
epoch:9, idx: 400, loss: 1.007644772529602
epoch:9, idx: 500, loss: 0.5057883262634277
epoch:9, idx: 600, loss: 0.36606132984161377
epoch:9, idx: 700, loss: 0.6791837215423584
epoch:9, idx: 800, loss: 0.48555147647857666
epoch:9, idx: 900, loss: 0.5882794260978699
epoch:9, idx: 1000, loss: 0.6092739105224609
epoch:9, idx: 1100, loss: 0.7305454611778259
epoch:9, idx: 1200, loss: 0.7028895020484924
epoch:9, idx: 1300, loss: 0.6569573879241943
epoch:9, idx: 1400, loss: 0.48185229301452637
epoch:9, idx: 1500, loss: 0.8030858039855957
epoch:9, idx: 1600, loss: 0.7719143033027649
epoch:9, idx: 1700, loss: 0.6259778738021851
epoch:9, idx: 1800, loss: 0.47757673263549805
epoch:9, idx: 1900, loss: 0.6519085764884949
accuracy: 0.88
epoch:10, idx: 0, loss: 0.879122257232666
epoch:10, idx: 100, loss: 0.5775210857391357
epoch:10, idx: 200, loss: 0.6367624998092651
epoch:10, idx: 300, loss: 0.6633958220481873
epoch:10, idx: 400, loss: 0.7212265133857727
ep

epoch:18, idx: 100, loss: 0.631492018699646
epoch:18, idx: 200, loss: 0.509729266166687
epoch:18, idx: 300, loss: 0.5739425420761108
epoch:18, idx: 400, loss: 0.6544072031974792
epoch:18, idx: 500, loss: 0.5183764100074768
epoch:18, idx: 600, loss: 0.7357341051101685
epoch:18, idx: 700, loss: 0.8293105959892273
epoch:18, idx: 800, loss: 0.5184540748596191
epoch:18, idx: 900, loss: 0.5597798824310303
epoch:18, idx: 1000, loss: 0.5427014231681824
epoch:18, idx: 1100, loss: 0.7260130643844604
epoch:18, idx: 1200, loss: 0.9140374660491943
epoch:18, idx: 1300, loss: 0.5888085961341858
epoch:18, idx: 1400, loss: 0.4833970069885254
epoch:18, idx: 1500, loss: 0.4787976145744324
epoch:18, idx: 1600, loss: 0.7982016205787659
epoch:18, idx: 1700, loss: 0.7488061189651489
epoch:18, idx: 1800, loss: 0.4796580672264099
epoch:18, idx: 1900, loss: 0.662778377532959
accuracy: 0.89
epoch:19, idx: 0, loss: 0.7569707036018372
epoch:19, idx: 100, loss: 0.3736054599285126
epoch:19, idx: 200, loss: 0.4136036

epoch:26, idx: 1900, loss: 0.4962373673915863
accuracy: 0.89
epoch:27, idx: 0, loss: 0.7405327558517456
epoch:27, idx: 100, loss: 0.44571200013160706
epoch:27, idx: 200, loss: 0.6797658801078796
epoch:27, idx: 300, loss: 0.6276440620422363
epoch:27, idx: 400, loss: 0.6678110361099243
epoch:27, idx: 500, loss: 0.48219773173332214
epoch:27, idx: 600, loss: 0.4392123818397522
epoch:27, idx: 700, loss: 0.7110159993171692
epoch:27, idx: 800, loss: 0.27225497364997864
epoch:27, idx: 900, loss: 0.399553507566452
epoch:27, idx: 1000, loss: 0.6479641795158386
epoch:27, idx: 1100, loss: 0.6785836219787598
epoch:27, idx: 1200, loss: 0.5211825966835022
epoch:27, idx: 1300, loss: 0.7330880761146545
epoch:27, idx: 1400, loss: 0.5411509275436401
epoch:27, idx: 1500, loss: 0.5284163355827332
epoch:27, idx: 1600, loss: 0.7888737916946411
epoch:27, idx: 1700, loss: 0.849892258644104
epoch:27, idx: 1800, loss: 0.39794594049453735
epoch:27, idx: 1900, loss: 0.5094201564788818
accuracy: 0.89
epoch:28, idx:

epoch:35, idx: 1700, loss: 0.4718596041202545
epoch:35, idx: 1800, loss: 0.3202572464942932
epoch:35, idx: 1900, loss: 0.5263855457305908
accuracy: 0.89
epoch:36, idx: 0, loss: 0.7286061644554138
epoch:36, idx: 100, loss: 0.37976765632629395
epoch:36, idx: 200, loss: 0.46091437339782715
epoch:36, idx: 300, loss: 0.5026434063911438
epoch:36, idx: 400, loss: 0.4722321927547455
epoch:36, idx: 500, loss: 0.3854318857192993
epoch:36, idx: 600, loss: 0.5342596769332886
epoch:36, idx: 700, loss: 0.6330614686012268
epoch:36, idx: 800, loss: 0.23235294222831726
epoch:36, idx: 900, loss: 0.6063835024833679
epoch:36, idx: 1000, loss: 0.4272679388523102
epoch:36, idx: 1100, loss: 0.6629315614700317
epoch:36, idx: 1200, loss: 0.6044929623603821
epoch:36, idx: 1300, loss: 0.761053204536438
epoch:36, idx: 1400, loss: 0.3693996071815491
epoch:36, idx: 1500, loss: 0.41807007789611816
epoch:36, idx: 1600, loss: 0.8235099911689758
epoch:36, idx: 1700, loss: 0.6241347789764404
epoch:36, idx: 1800, loss: 0

epoch:44, idx: 1500, loss: 0.5402437448501587
epoch:44, idx: 1600, loss: 0.7867254614830017
epoch:44, idx: 1700, loss: 0.6375436186790466
epoch:44, idx: 1800, loss: 0.31676018238067627
epoch:44, idx: 1900, loss: 0.49724334478378296
accuracy: 0.90
epoch:45, idx: 0, loss: 0.5738134384155273
epoch:45, idx: 100, loss: 0.4453425705432892
epoch:45, idx: 200, loss: 0.5189844369888306
epoch:45, idx: 300, loss: 0.7662016153335571
epoch:45, idx: 400, loss: 1.1147212982177734
epoch:45, idx: 500, loss: 0.5518215298652649
epoch:45, idx: 600, loss: 0.5387885570526123
epoch:45, idx: 700, loss: 0.6782339811325073
epoch:45, idx: 800, loss: 0.28186890482902527
epoch:45, idx: 900, loss: 0.465010404586792
epoch:45, idx: 1000, loss: 0.6501250863075256
epoch:45, idx: 1100, loss: 0.5659326910972595
epoch:45, idx: 1200, loss: 0.477780818939209
epoch:45, idx: 1300, loss: 0.5951284766197205
epoch:45, idx: 1400, loss: 0.5100560188293457
epoch:45, idx: 1500, loss: 0.33821675181388855
epoch:45, idx: 1600, loss: 0.

In [8]:
# Final test-set accuracy of the trained model.
# The counters are reset here so the result does not depend on values left
# over from an earlier cell, and eval mode is set explicitly so dropout is
# disabled regardless of what ran before.
model.eval()
correct = 0
_sum = 0
with torch.no_grad():
    for test_x, test_label in test_loader:
        test_x = test_x.to(device)
        test_label = test_label.to(device)
        predict_y = model(test_x)
        predict_ys = torch.argmax(predict_y, dim=1)
        correct_predictions = predict_ys == test_label
        correct += correct_predictions.sum().item()
        _sum += len(test_label)

print('accuracy: {:.2f}'.format(correct / _sum if _sum else 0.0))

accuracy: 0.90


In [9]:
# Move the parameters to host memory before the state dict is written out.
model = model.to("cpu")

In [10]:
# Persist only the learned parameters (not the Module object itself).
torch.save(obj=model.state_dict(), f='lenet_emnist_letters')