In [58]:
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torch.nn.init
import warnings
warnings.filterwarnings('ignore')

In [37]:
device= 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(777)
if device=='cuda':
    torch.cuda.manual_seed_all(777)

In [39]:
learnning_rate=0.001
training_epochs=15
batch_size=100

In [43]:
mnist_train=dsets.MNIST(root="MNIST_data/",
                       train=True,
                       transform=transforms.ToTensor(),
                       download=True)
mnist_test=dsets.MNIST(root="MNIST_data/",
                      train=False,
                      transform=transforms.ToTensor(),
                      download=True)

In [44]:
data_loader=torch.utils.data.DataLoader(dataset=mnist_train,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       drop_last=True)

In [47]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        self.layer1=nn.Sequential(
        nn.Conv2d(1,32,kernel_size=3,stride=1,padding=1),
        nn.ReLU(),
        nn.MaxPool2d(2))
        self.layer2=nn.Sequential(
        nn.Conv2d(32,64,kernel_size=3,stride=1,padding=1),
        nn.ReLU(),
        nn.MaxPool2d(2))
        
        self.fc=nn.Linear(7*7*64, 10,bias=True)
        torch.nn.init.xavier_uniform_(self.fc.weight)
    
    def forward(self,x):
        out=self.layer1(x)
        out=self.layer2(out)
        out=out.view(out.size(0),-1)
        
        out=self.fc(out)
        return out
    

In [48]:
model=CNN().to(device)

In [50]:
model

CNN(
  (layer1): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer2): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Linear(in_features=3136, out_features=10, bias=True)
)

In [51]:
criterion=nn.CrossEntropyLoss().to(device)
optimizer=torch.optim.Adam(model.parameters(),lr=learnning_rate)

In [56]:

total_batch=len(data_loader)

for epoch in range(training_epochs):
    avg_cost=0
    for X,y in data_loader:
        X=X.to(device)
        y=y.to(device)
        
        optimizer.zero_grad()
        hypothesis=model(X)
        cost=criterion(hypothesis,y)
        cost.backward()
        optimizer.step()
        
        avg_cost += cost/total_batch
        
    print("Epoch",epoch,"cost",avg_cost)
print("Learning Finished")

Epoch 0 cost tensor(0.2200, device='cuda:0', grad_fn=<AddBackward0>)
Epoch 1 cost tensor(0.0618, device='cuda:0', grad_fn=<AddBackward0>)
Epoch 2 cost tensor(0.0459, device='cuda:0', grad_fn=<AddBackward0>)
Epoch 3 cost tensor(0.0376, device='cuda:0', grad_fn=<AddBackward0>)
Epoch 4 cost tensor(0.0310, device='cuda:0', grad_fn=<AddBackward0>)
Epoch 5 cost tensor(0.0253, device='cuda:0', grad_fn=<AddBackward0>)
Epoch 6 cost tensor(0.0217, device='cuda:0', grad_fn=<AddBackward0>)
Epoch 7 cost tensor(0.0183, device='cuda:0', grad_fn=<AddBackward0>)
Epoch 8 cost tensor(0.0154, device='cuda:0', grad_fn=<AddBackward0>)
Epoch 9 cost tensor(0.0134, device='cuda:0', grad_fn=<AddBackward0>)
Epoch 10 cost tensor(0.0107, device='cuda:0', grad_fn=<AddBackward0>)
Epoch 11 cost tensor(0.0102, device='cuda:0', grad_fn=<AddBackward0>)
Epoch 12 cost tensor(0.0085, device='cuda:0', grad_fn=<AddBackward0>)
Epoch 13 cost tensor(0.0078, device='cuda:0', grad_fn=<AddBackward0>)
Epoch 14 cost tensor(0.0053, d

In [65]:
with torch.no_grad():
    X_test=mnist_test.test_data.view(len(mnist_test),1,28,28).float().to(device)
    y_test=mnist_test.test_labels.to(device)
    
    prediction=model(X_test)
    correct_prediction=torch.argmax(prediction,1)==y_test
    accuracy=correct_prediction.float().mean()
    print("Accuracy",accuracy.item())

Accuracy 0.9865999817848206
