In [None]:
import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.init

In [None]:
# Dummy image batch shaped (batch size = images loaded at once, channels, height, width).
# torch.rand fills it with defined values in [0, 1); torch.Tensor(1, 1, 28, 28)
# would hand back uninitialized memory whose contents are arbitrary.
# NOTE: `input` shadows the Python builtin; the name is kept because a later
# cell passes it to conv1.
input = torch.rand(1, 1, 28, 28)
print(input.size())
print(input)

In [None]:
# Conv2d arguments: in_channels, out_channels, kernel_size (filter size).
# stride is not passed (library default), padding=1 is given explicitly.
conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
print(conv1)

In [None]:
# Second convolution: 32 input channels -> 64 output channels, 3x3 kernel, padding 1.
conv2 = nn.Conv2d(
    in_channels=32,
    out_channels=64,
    kernel_size=3,
    padding=1,
)
print(conv2)

In [None]:
# Max pooling with a kernel size of 2 (stride is not passed explicitly).
pool = nn.MaxPool2d(kernel_size=2)
print(pool)

In [None]:
# Run the dummy input through the first convolution, then through pooling,
# and show the resulting tensor sizes in that order.
out1 = conv1(input)
out2 = pool(out1)
for _stage_out in (out1, out2):
    print(_stage_out.size())

In [None]:
# Second convolution followed by pooling; print both intermediate sizes.
out3 = conv2(out2)
out4 = pool(out3)
for _stage_out in (out3, out4):
    print(_stage_out.size())

In [None]:
# Flatten every dimension except the first (batch) one.
out = out4.view(out4.shape[0], -1)
print(out.shape)

In [None]:
# Fully connected layer: 3136 (= 64 * 7 * 7) input features -> 10 outputs.
fc = nn.Linear(in_features=64 * 7 * 7, out_features=10)
outf = fc(out)
print(outf.size(), outf, sep='\n')

In [10]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed the default RNG, and every CUDA device when running on the GPU.
torch.manual_seed(777)
if device == 'cuda':
    torch.cuda.manual_seed_all(777)

In [11]:
# Training hyperparameters.
learning_rate = 0.001
# Backward-compatible alias: the misspelled name is still the one read when
# the optimizer is constructed.
learing_rate = learning_rate
epochs = 15
batch_size = 100

In [None]:
# Arguments shared by both MNIST splits: storage directory, conversion of the
# images to tensors, and downloading the data when requested by the library.
_mnist_common = dict(
    root='MNIST_data',
    transform=transforms.ToTensor(),
    download=True,
)

# train=True selects the training split, train=False the test split.
mnist_train = dsets.MNIST(train=True, **_mnist_common)
mnist_test = dsets.MNIST(train=False, **_mnist_common)

In [None]:
# Show the summary of both dataset splits, training split first.
print(mnist_train, mnist_test, sep='\n')

In [14]:
# Training batches are reshuffled by the loader; the last, possibly smaller,
# batch is kept (drop_last=False).
train_loader = DataLoader(dataset=mnist_train,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=False)

# Evaluation does not benefit from shuffling, so the test loader iterates in
# dataset order.
test_loader = DataLoader(dataset=mnist_test,
                         batch_size=batch_size,
                         shuffle=False,
                         drop_last=False)

In [None]:
# Print the image/label tensor sizes of the first batch of each loader.
# The second loop previously unpacked into `Y, Y`, so it printed the size of
# the stale training `X` instead of the test batch.
for _loader in (train_loader, test_loader):
    for X, Y in _loader:
        print(X.size())
        print(Y.size())
        break

In [16]:
class CNN(nn.Module):
    """Small convolutional classifier with two conv blocks and one linear layer.

    Each block is Conv2d(kernel 3, stride 1, padding 1) -> ReLU ->
    MaxPool2d(kernel 2, stride 2). The linear layer expects 64*7*7 flattened
    features, which matches 1x28x28 inputs.

    Args:
        num_classes: number of outputs of the final linear layer (default 10).
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()

        # Conv block 1
        #   input: (N, 1, 28, 28)
        #   conv -> (N, 32, 28, 28)
        #   pool -> (N, 32, 14, 14)
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        # Conv block 2
        #   input: (N, 32, 14, 14)
        #   conv -> (N, 64, 14, 14)
        #   pool -> (N, 64, 7, 7)
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        # Fully connected layer over the flattened (N, 64*7*7) features.
        self.fc = nn.Linear(64 * 7 * 7, num_classes, bias=True)
        # Xavier-uniform initialization is applied to the linear layer's weight only.
        nn.init.xavier_uniform_(self.fc.weight)

    def forward(self, x):
        """Return the linear layer's outputs, shape (N, num_classes), for a batch x."""
        out = self.layer1(x)
        out = self.layer2(out)
        # Flatten everything except the batch dimension before the linear layer.
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

In [17]:
# Instantiate the network and move it to the selected device.
model = CNN()
model = model.to(device)

# Cross-entropy loss, moved to the same device.
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)

# Adam over all model parameters with the configured learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learing_rate)

In [None]:
# Show the module structure, then the list of its parameter tensors.
print(model, list(model.parameters()), sep='\n')

In [None]:
# Number of batches each loader yields per full pass.
train_total_batch, test_total_batch = len(train_loader), len(test_loader)
print(train_total_batch, test_total_batch, sep='\n')

In [None]:
# Train for `epochs` passes over train_loader and report the mean batch loss
# of each pass.
model.train()
for epoch in range(epochs):
    avg_cost = 0.0

    for X, Y in train_loader:
        # Move the batch to the device the model lives on.
        X = X.to(device)
        Y = Y.to(device)

        optimizer.zero_grad()
        y_hat = model(X)
        cost = criterion(y_hat, Y)
        cost.backward()
        optimizer.step()

        # .item() extracts a plain Python number; adding the loss tensor itself
        # would make avg_cost a tensor that carries autograd history.
        avg_cost += cost.item() / train_total_batch
    print('Epoc:', epoch, 'cost:', avg_cost)

In [None]:
# Evaluate on test_loader: accuracy = correct predictions / number of samples.
# The sum of per-batch correct counts was previously divided by the number of
# batches, which is only meaningful for one particular batch size.
model.eval()
with torch.no_grad():
    correct_count = 0
    sample_count = 0

    for X, Y in test_loader:
        X = X.to(device)
        Y = Y.to(device)
        pred = model(X)
        # The predicted class is the index of the largest output per sample.
        correct_count += (torch.argmax(pred, -1) == Y).sum().item()
        sample_count += Y.size(0)
    # Guard against an empty loader to avoid dividing by zero.
    avg_accuracy = correct_count / sample_count if sample_count else 0.0
print('Accuracy:', avg_accuracy)