In [None]:
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import random

In [None]:
# Use the GPU when one is actually usable; otherwise fall back to the CPU so
# the notebook still runs end-to-end on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed every random number generator used below so runs are repeatable.
random.seed(777)
torch.manual_seed(777)
if device == 'cuda':
    torch.cuda.manual_seed_all(777)

In [None]:
# Datasets: the MNIST training and test splits, stored under MNIST_data/ and
# downloaded on first use. ToTensor() turns each image into a float tensor
# (scaled to the [0, 1] range per the torchvision documentation).
mnist_train = dsets.MNIST(
    root='MNIST_data/',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
mnist_test = dsets.MNIST(
    root='MNIST_data/',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

In [None]:
# Hyperparameters
learning_rate = 0.1  # step size handed to the SGD optimizer
epochs = 15          # number of full passes over the training loader
batch_size = 100     # samples per mini-batch in the DataLoader
drop_prob = 0.5      # probability p given to torch.nn.Dropout

In [None]:
# Dataset loader: shuffled mini-batches drawn from the training split.
# drop_last=True discards a trailing incomplete batch, so every batch the
# training loop sees has exactly batch_size samples.
data_loader = torch.utils.data.DataLoader(
    dataset=mnist_train,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

In [None]:
# Layers: fully connected stack 784 -> 512 -> 512 -> 512 -> 512 -> 10,
# every layer with a bias term.
layer1 = torch.nn.Linear(784, 512, bias=True)
layer2 = torch.nn.Linear(512, 512, bias=True)
layer3 = torch.nn.Linear(512, 512, bias=True)
layer4 = torch.nn.Linear(512, 512, bias=True)
layer5 = torch.nn.Linear(512, 10, bias=True)

# Activation and regularisation modules; the same instances are reused
# after each hidden layer.
relu = torch.nn.ReLU()
dropout = torch.nn.Dropout(p=drop_prob)

# Weight initialization: Xavier-uniform on each weight matrix, in layer order
# (biases keep their default initialisation).
for linear in (layer1, layer2, layer3, layer4, layer5):
    torch.nn.init.xavier_uniform_(linear.weight)

# Model: Linear -> ReLU -> Dropout for the four hidden layers, then the
# 10-way output layer, all moved to the selected device.
model = torch.nn.Sequential(
    layer1, relu, dropout,
    layer2, relu, dropout,
    layer3, relu, dropout,
    layer4, relu, dropout,
    layer5,
).to(device)


In [None]:
# Loss: cross-entropy applied to the model's raw output scores.
criterion = torch.nn.CrossEntropyLoss().to(device)

# Optimizer: plain SGD over every parameter of the model.
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)

In [None]:

# Training: run `epochs` passes over data_loader and report the mean
# per-batch loss after each pass.
total_batch = len(data_loader)
model.train()  # training mode, so the Dropout layers are active

for epoch in range(epochs):
    total_cost = 0.0
    for X, Y in data_loader:
        # Flatten each 28x28 image into a 784-long row vector.
        X = X.view(-1, 28 * 28).to(device)
        Y = Y.to(device)

        optimizer.zero_grad()
        hypothesis = model(X)
        cost = criterion(hypothesis, Y)
        cost.backward()
        optimizer.step()

        # Accumulate a plain Python float. Adding the loss tensor itself
        # would keep autograd history from every batch alive for the whole
        # epoch and steadily grow memory use.
        total_cost += cost.item()
    avg_cost = total_cost / total_batch

    print('Epoch: ', '%3d' % (epoch + 1), 'Cost: ', '{:.8f}'.format(avg_cost))
print('learning finished')

In [None]:
# Test: measure accuracy on the whole MNIST test split, then show the
# prediction for one randomly chosen test image.
print(len(mnist_test))
model.eval()  # evaluation mode, so the Dropout layers are disabled
with torch.no_grad():
    # `.data` holds the raw uint8 pixels (0..255) and bypasses the ToTensor()
    # transform used for training, so rescale to [0, 1] here to match the
    # inputs the model was trained on.
    X_test = mnist_test.data.view(-1, 28 * 28).float().div(255).to(device)
    Y_test = mnist_test.targets.to(device)

    prediction = model(X_test)
    correct_prediction = torch.argmax(prediction, 1) == Y_test
    accuracy = correct_prediction.float().mean()
    print('Accuracy:', accuracy.item())

    # random.randint is inclusive on both ends; start at 0 so the first test
    # sample can be selected too.
    r = random.randint(0, len(mnist_test) - 1)
    X_single_data = X_test[r:r + 1]
    Y_single_data = Y_test[r:r + 1]

    print('Label: ', Y_single_data.item())
    single_prediction = model(X_single_data)
    print('Prediction: ', torch.argmax(single_prediction, 1).item())
print(X_single_data)
print(Y_single_data)