In [284]:
import matplotlib.pyplot as plt 
from torch import nn
import tensorflow as tf
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader,TensorDataset 
from tqdm import tqdm
import torchsummary

### CUDA check

In [12]:
# Pick the first GPU when CUDA is usable, otherwise fall back to the CPU.
USE_CUDA = torch.cuda.is_available()
if USE_CUDA:
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

### Preparing the data

In [117]:
def get_one_hot(data, type_num):  # valid for any shape
    """One-hot encode integer class indices.

    Parameters
    ----------
    data : array-like of int
        Class indices. May be a NumPy array of any shape, a (nested) list,
        or a single Python integer.
    type_num : int
        Number of classes, i.e. the length of every one-hot vector.

    Returns
    -------
    numpy.ndarray
        Array of shape ``data.shape + (type_num,)`` with the dtype of
        ``np.eye`` (float64), holding a 1 at each class index.

    Raises
    ------
    IndexError
        If an index is outside the ``type_num`` rows of the lookup table or
        ``data`` is not of an integer type.
    """
    # np.asarray lets lists and scalars through; the original required an
    # object that already exposed ``.shape``.
    indices = np.asarray(data)

    # Row i of the identity matrix is the one-hot vector of class i. Fancy
    # indexing appends the class axis to the index shape, so no explicit
    # reshape is required afterwards.
    return np.eye(type_num)[indices]

In [144]:
# Single seed for the whole notebook. It is passed to train_test_split for
# the train/validation split and is also used to seed PyTorch's global RNG,
# so that weight initialisation and DataLoader shuffling are repeatable
# across "Restart & Run All" executions.
_random_state = 32
torch.manual_seed(_random_state)

In [145]:
def get_MNIST_data(device):
    """Load MNIST through Keras and return it as tensors placed on ``device``.

    The Keras training split is divided 90/10 into train and validation
    parts (stratified by label, seeded with ``_random_state``). Images are
    converted to float tensors, divided by 255 and given an extra axis at
    position 1; labels are one-hot encoded over 10 classes and converted to
    float tensors.

    Parameters
    ----------
    device : torch.device
        Device every returned tensor is moved to.

    Returns
    -------
    tuple
        ``(x_train, y_train, x_valid, y_valid, x_test, y_test)``.
    """
    (train_images, train_digits), (test_images, test_digits) = tf.keras.datasets.mnist.load_data()
    train_images, valid_images, train_digits, valid_digits = train_test_split(
        train_images,
        train_digits,
        stratify=train_digits,
        test_size=0.1,
        random_state=_random_state,
    )

    def as_inputs(images):
        # float tensor on the target device, scaled by 1/255, channel axis added
        scaled = torch.FloatTensor(images).to(device)
        scaled /= 255
        return scaled.unsqueeze(1)

    def as_targets(digits):
        # one-hot rows over the 10 digit classes, as a float tensor on the device
        return torch.FloatTensor(get_one_hot(digits, 10)).to(device)

    return (
        as_inputs(train_images), as_targets(train_digits),
        as_inputs(valid_images), as_targets(valid_digits),
        as_inputs(test_images), as_targets(test_digits),
    )

In [146]:
# Build the train / validation / test tensors once; get_MNIST_data already
# places all of them on `device`.
x_train, y_train, x_valid, y_valid, x_test, y_test = get_MNIST_data(device)

### Dataset and data loaders

In [147]:
class MNISTDataset(Dataset):
    """Map-style dataset that pairs ``x[i]`` with ``y[i]``.

    Parameters
    ----------
    x, y : sized, indexable containers
        Inputs and targets. They must have the same length.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` differ in length. Without this check a mismatch
        would only surface later as an IndexError or as silently ignored
        samples.
    """

    def __init__(self, x, y):
        if len(x) != len(y):
            raise ValueError(
                "x and y must have the same length, got {} and {}".format(len(x), len(y)))

        self.len = len(x)
        self.x_data = x
        self.y_data = y

    def __getitem__(self, index):
        """Return the ``(input, target)`` pair stored at ``index``."""
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        """Return the number of samples."""
        return self.len
    

In [148]:
# Training batches are drawn in a new random order every epoch.
train_dataset = MNISTDataset(x_train, y_train)
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)

# Evaluation gains nothing from shuffling; a fixed order keeps per-batch
# results repeatable between runs.
test_dataset = MNISTDataset(x_test, y_test)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

In [149]:
# Sanity check: shape of the input tensor in the first training batch.
next(iter(train_loader))[0].shape

torch.Size([32, 1, 28, 28])

### model

In [150]:
#n_input = 784
#n_hidden = 128
#n_target = 10

#epochs = 50
#learning_rate = 1e-3

In [278]:
# Layer widths of the fully connected network: flattened 28x28 image in,
# five progressively narrower hidden layers, 10 class scores out.
n_input = 28 * 28
n_hidden1, n_hidden2, n_hidden3, n_hidden4, n_hidden5 = 2500, 2000, 1500, 1000, 500
n_target = 10

# Optimisation settings read by MyModel.train().
epochs = 50
learning_rate = 5e-4

In [279]:
class MyModel(nn.Module):
    """Fully connected classifier: five LeakyReLU hidden layers and a linear output.

    Layer widths are taken from the notebook-level names ``n_input``,
    ``n_hidden1`` ... ``n_hidden5`` and ``n_target`` when the model is
    constructed. Training settings (``epochs``, ``learning_rate``,
    ``train_loader``, ``x_valid``, ``y_valid``, ``device``) are read from the
    notebook namespace when :meth:`fit` runs.
    """

    def __init__(self):
        super().__init__()
        self.n_input = n_input
        self.n_hidden1 = n_hidden1
        self.n_hidden2 = n_hidden2
        self.n_hidden3 = n_hidden3
        self.n_hidden4 = n_hidden4
        self.n_hidden5 = n_hidden5
        self.n_target = n_target
        self._build()

    def _build(self):
        """Create ``self.layer``: Linear+LeakyReLU per hidden width, then a final Linear."""
        widths = [
            self.n_input,
            self.n_hidden1,
            self.n_hidden2,
            self.n_hidden3,
            self.n_hidden4,
            self.n_hidden5,
        ]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(nn.LeakyReLU())
        # The output layer has no activation: CrossEntropyLoss expects raw scores.
        modules.append(nn.Linear(widths[-1], self.n_target))
        self.layer = nn.Sequential(*modules)

    def forward(self, x):
        """Flatten ``x`` to rows of ``n_input`` values and return the class scores."""
        x = x.view(-1, self.n_input)
        return self.layer(x)

    def train(self, mode=None):
        """Run training (no argument) or switch train/eval mode (boolean argument).

        The notebook uses ``model.train()`` as its training entry point, so a
        call without arguments still runs :meth:`fit`. ``nn.Module.eval()``
        internally calls ``self.train(False)``; accepting ``mode`` and
        delegating to ``nn.Module.train`` keeps ``model.eval()`` and
        ``model.train(True)`` working instead of raising ``TypeError``.

        Parameters
        ----------
        mode : bool or None
            ``None`` runs the training loop. ``True`` / ``False`` is
            forwarded to ``nn.Module.train``.

        Returns
        -------
        None when the training loop was run, otherwise ``self``.
        """
        if mode is None:
            return self.fit()
        return super().train(mode)

    def fit(self):
        """Train with Adam + StepLR and report validation accuracy after every epoch.

        Reads ``train_loader``, ``epochs``, ``learning_rate``, ``x_valid``,
        ``y_valid`` and ``device`` from the notebook namespace. Stores the
        loss, optimizer, scheduler, loader and epoch count on ``self``.
        """
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
        # Decay the learning rate by 0.6 ten times over the run. max(1, ...)
        # avoids a step_size of 0 when fewer than 10 epochs are requested.
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=max(1, epochs // 10), gamma=0.6)

        self.dataloader = train_loader
        self.epochs = epochs

        global_step = 0
        for epoch in range(self.epochs):

            progress = tqdm(self.dataloader)
            for inputs, labels in progress:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # The optimizer holds every model parameter, so one zero_grad suffices.
                self.optimizer.zero_grad()

                y_pred = self(inputs)
                loss = self.criterion(y_pred, labels)

                loss.backward()
                self.optimizer.step()

                global_step += 1
                progress.set_description("epoch: {} | global_step: {:8d} | loss: {:.4f}".format(
                    epoch + 1, global_step, loss.item()))

            acc = self._accuracy(x_valid, y_valid)
            print("epoch: {} | global_step: {} | valid acc: {:.3%}".format(
                epoch + 1, global_step, acc))

            self.scheduler.step()
            if epoch % 5 == 0:
                print('lr:{}'.format(self.optimizer.param_groups[0]['lr']))

    def _accuracy(self, x, y):
        """Return the fraction of rows in ``x`` whose arg-max score equals the arg-max of ``y``."""
        # No autograd graph is needed for scoring; this avoids retaining
        # activations for the whole evaluation set.
        with torch.no_grad():
            pred = torch.argmax(self(x), 1)
        label = torch.argmax(y, 1).to(pred.device)
        # One vectorised comparison instead of a Python loop with a
        # host/device sync per element.
        return (pred == label).sum().item() / len(pred)

    def evaluate(self, x, y):
        """Print and return the accuracy of the model on ``(x, y)``.

        Parameters
        ----------
        x : torch.Tensor
            Inputs accepted by :meth:`forward`.
        y : torch.Tensor
            Targets with one row per sample; the class is taken as the
            arg-max along dimension 1.

        Returns
        -------
        float
            Fraction of correctly classified samples.
        """
        acc = self._accuracy(x, y)
        print("acc: {}".format(acc))
        return acc

    def inference(self, x, y):
        """Print the predicted class for a single sample next to its label.

        ``y`` may be a one-element tensor holding the class value or a
        one-hot vector; a one-hot vector is reduced with arg-max first,
        because ``.item()`` only works on one-element tensors.
        """
        with torch.no_grad():
            pred = torch.argmax(self(x), 1)

        label = y if y.numel() == 1 else torch.argmax(y, -1)
        print("inference: {}, label: {}".format(pred.item(), label.item()))


In [280]:
# Instantiate the network and move its parameters to the selected device.
model = MyModel().to(device)

In [281]:
# Per-layer output shapes and parameter counts for a single 1x28x28 input.
torchsummary.summary(model,(1,28,28,))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                 [-1, 2500]       1,962,500
         LeakyReLU-2                 [-1, 2500]               0
            Linear-3                 [-1, 2000]       5,002,000
         LeakyReLU-4                 [-1, 2000]               0
            Linear-5                 [-1, 1500]       3,001,500
         LeakyReLU-6                 [-1, 1500]               0
            Linear-7                 [-1, 1000]       1,501,000
         LeakyReLU-8                 [-1, 1000]               0
            Linear-9                  [-1, 500]         500,500
        LeakyReLU-10                  [-1, 500]               0
           Linear-11                   [-1, 10]           5,010
Total params: 11,972,510
Trainable params: 11,972,510
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Fo

### Train

In [282]:
# Run the full training loop: in this notebook MyModel.train() is the
# training entry point, not merely nn.Module's train-mode switch.
model.train()

epoch: 1 | global_step:     1688 | loss: 0.1393: 100%|████████████████████████████| 1688/1688 [00:09<00:00, 171.76it/s]


epoch: 1 | global_step: 1688 | valid acc: 96.133%
lr:0.0005


epoch: 2 | global_step:     3376 | loss: 0.2642: 100%|████████████████████████████| 1688/1688 [00:08<00:00, 190.51it/s]


epoch: 2 | global_step: 3376 | valid acc: 96.733%


epoch: 3 | global_step:     5064 | loss: 0.1220: 100%|████████████████████████████| 1688/1688 [00:09<00:00, 185.13it/s]


epoch: 3 | global_step: 5064 | valid acc: 97.450%


epoch: 4 | global_step:     6752 | loss: 0.0607: 100%|████████████████████████████| 1688/1688 [00:09<00:00, 177.10it/s]


epoch: 4 | global_step: 6752 | valid acc: 97.667%


epoch: 5 | global_step:     8440 | loss: 0.0366: 100%|████████████████████████████| 1688/1688 [00:09<00:00, 180.33it/s]


epoch: 5 | global_step: 8440 | valid acc: 97.717%


epoch: 6 | global_step:    10128 | loss: 0.0066: 100%|████████████████████████████| 1688/1688 [00:09<00:00, 181.70it/s]


epoch: 6 | global_step: 10128 | valid acc: 98.150%
lr:0.0003


epoch: 7 | global_step:    11816 | loss: 0.0437: 100%|████████████████████████████| 1688/1688 [00:09<00:00, 186.51it/s]


epoch: 7 | global_step: 11816 | valid acc: 98.167%


epoch: 8 | global_step:    13504 | loss: 0.0002: 100%|████████████████████████████| 1688/1688 [00:08<00:00, 192.49it/s]


epoch: 8 | global_step: 13504 | valid acc: 97.717%


epoch: 9 | global_step:    15192 | loss: 0.0004: 100%|████████████████████████████| 1688/1688 [00:08<00:00, 196.71it/s]


epoch: 9 | global_step: 15192 | valid acc: 98.183%


epoch: 10 | global_step:    16880 | loss: 0.0024: 100%|███████████████████████████| 1688/1688 [00:08<00:00, 196.74it/s]


epoch: 10 | global_step: 16880 | valid acc: 98.083%


epoch: 11 | global_step:    18568 | loss: 0.0000: 100%|███████████████████████████| 1688/1688 [00:08<00:00, 195.92it/s]


epoch: 11 | global_step: 18568 | valid acc: 98.433%
lr:0.00017999999999999998


epoch: 12 | global_step:    20256 | loss: 0.0001: 100%|███████████████████████████| 1688/1688 [00:09<00:00, 182.46it/s]


epoch: 12 | global_step: 20256 | valid acc: 98.350%


epoch: 13 | global_step:    21944 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:09<00:00, 180.79it/s]


epoch: 13 | global_step: 21944 | valid acc: 98.050%


epoch: 14 | global_step:    23632 | loss: 0.0000: 100%|███████████████████████████| 1688/1688 [00:09<00:00, 179.50it/s]


epoch: 14 | global_step: 23632 | valid acc: 98.333%


epoch: 15 | global_step:    25320 | loss: 0.0000: 100%|███████████████████████████| 1688/1688 [00:09<00:00, 181.50it/s]


epoch: 15 | global_step: 25320 | valid acc: 98.467%


epoch: 16 | global_step:    27008 | loss: 0.0000: 100%|███████████████████████████| 1688/1688 [00:09<00:00, 183.91it/s]


epoch: 16 | global_step: 27008 | valid acc: 98.283%
lr:0.00010799999999999998


epoch: 17 | global_step:    28696 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:08<00:00, 189.59it/s]


epoch: 17 | global_step: 28696 | valid acc: 98.317%


epoch: 18 | global_step:    30384 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:08<00:00, 188.35it/s]


epoch: 18 | global_step: 30384 | valid acc: 98.450%


epoch: 19 | global_step:    32072 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:08<00:00, 190.47it/s]


epoch: 19 | global_step: 32072 | valid acc: 98.467%


epoch: 20 | global_step:    33760 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:08<00:00, 190.25it/s]


epoch: 20 | global_step: 33760 | valid acc: 98.500%


epoch: 21 | global_step:    35448 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:08<00:00, 192.39it/s]


epoch: 21 | global_step: 35448 | valid acc: 98.517%
lr:6.479999999999999e-05


epoch: 22 | global_step:    37136 | loss: 0.0000: 100%|███████████████████████████| 1688/1688 [00:08<00:00, 189.37it/s]


epoch: 22 | global_step: 37136 | valid acc: 98.500%


epoch: 23 | global_step:    38824 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:08<00:00, 189.30it/s]


epoch: 23 | global_step: 38824 | valid acc: 98.483%


epoch: 24 | global_step:    40512 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:09<00:00, 184.87it/s]


epoch: 24 | global_step: 40512 | valid acc: 98.417%


epoch: 25 | global_step:    42200 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:08<00:00, 191.10it/s]


epoch: 25 | global_step: 42200 | valid acc: 98.417%


epoch: 26 | global_step:    43888 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:09<00:00, 179.57it/s]


epoch: 26 | global_step: 43888 | valid acc: 98.417%
lr:3.8879999999999994e-05


epoch: 27 | global_step:    45576 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:09<00:00, 183.05it/s]


epoch: 27 | global_step: 45576 | valid acc: 98.433%


epoch: 28 | global_step:    47264 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:10<00:00, 163.35it/s]


epoch: 28 | global_step: 47264 | valid acc: 98.433%


epoch: 29 | global_step:    48952 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:09<00:00, 177.03it/s]


epoch: 29 | global_step: 48952 | valid acc: 98.433%


epoch: 30 | global_step:    50640 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:09<00:00, 178.33it/s]


epoch: 30 | global_step: 50640 | valid acc: 98.450%


epoch: 31 | global_step:    52328 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:08<00:00, 190.66it/s]


epoch: 31 | global_step: 52328 | valid acc: 98.450%
lr:2.3327999999999994e-05


epoch: 32 | global_step:    54016 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:08<00:00, 190.45it/s]


epoch: 32 | global_step: 54016 | valid acc: 98.450%


epoch: 33 | global_step:    55704 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:09<00:00, 185.07it/s]


epoch: 33 | global_step: 55704 | valid acc: 98.450%


epoch: 34 | global_step:    57392 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:08<00:00, 188.68it/s]


epoch: 34 | global_step: 57392 | valid acc: 98.450%


epoch: 35 | global_step:    59080 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:09<00:00, 184.91it/s]


epoch: 35 | global_step: 59080 | valid acc: 98.450%


epoch: 36 | global_step:    60768 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:08<00:00, 191.01it/s]


epoch: 36 | global_step: 60768 | valid acc: 98.450%
lr:1.3996799999999996e-05


epoch: 37 | global_step:    62456 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:08<00:00, 193.08it/s]


epoch: 37 | global_step: 62456 | valid acc: 98.450%


epoch: 38 | global_step:    64144 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:08<00:00, 188.31it/s]


epoch: 38 | global_step: 64144 | valid acc: 98.450%


epoch: 39 | global_step:    65832 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:09<00:00, 170.10it/s]


epoch: 39 | global_step: 65832 | valid acc: 98.450%


epoch: 40 | global_step:    67520 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:09<00:00, 170.86it/s]


epoch: 40 | global_step: 67520 | valid acc: 98.467%


epoch: 41 | global_step:    69208 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:09<00:00, 185.51it/s]


epoch: 41 | global_step: 69208 | valid acc: 98.467%
lr:8.398079999999997e-06


epoch: 42 | global_step:    70896 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:08<00:00, 192.75it/s]


epoch: 42 | global_step: 70896 | valid acc: 98.483%


epoch: 43 | global_step:    72584 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:09<00:00, 186.16it/s]


epoch: 43 | global_step: 72584 | valid acc: 98.483%


epoch: 44 | global_step:    74272 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:09<00:00, 182.05it/s]


epoch: 44 | global_step: 74272 | valid acc: 98.500%


epoch: 45 | global_step:    75960 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:09<00:00, 181.04it/s]


epoch: 45 | global_step: 75960 | valid acc: 98.500%


epoch: 46 | global_step:    77648 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:09<00:00, 179.35it/s]


epoch: 46 | global_step: 77648 | valid acc: 98.500%
lr:5.038847999999998e-06


epoch: 47 | global_step:    79336 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:09<00:00, 182.93it/s]


epoch: 47 | global_step: 79336 | valid acc: 98.483%


epoch: 48 | global_step:    81024 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:09<00:00, 186.69it/s]


epoch: 48 | global_step: 81024 | valid acc: 98.483%


epoch: 49 | global_step:    82712 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:08<00:00, 194.47it/s]


epoch: 49 | global_step: 82712 | valid acc: 98.483%


epoch: 50 | global_step:    84400 | loss: -0.0000: 100%|██████████████████████████| 1688/1688 [00:09<00:00, 187.28it/s]


epoch: 50 | global_step: 84400 | valid acc: 98.483%


### Evaluation on the test set
- Test accuracy: 98.66%

In [283]:
# Print the accuracy of the trained model on the test split.
model.evaluate(x_test, y_test)

acc: 0.9866
