<a href="https://colab.research.google.com/github/jinnyjinny/Fashion-MNIST-Pytorch/blob/main/Fashion_MNIST_Lenet5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torchvision
import torchvision.datasets as dsets
import torchvision.transforms as transforms 
from torch.utils.data import Dataset, DataLoader
import torch.nn.init
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision import models
import time

In [None]:
# Run on the GPU when PyTorch can see a CUDA device, otherwise on the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

# Seed the torch RNG with 777 for repeatable runs; on a GPU also seed every
# CUDA generator explicitly.
torch.manual_seed(777)
if device == 'cuda':
  torch.cuda.manual_seed_all(777)

In [None]:
# Training hyper-parameters:
#   learning_rate   - step size handed to the Adam optimizer
#   training_epochs - number of full passes over the training loader
#   batch_size      - samples per mini-batch for both data loaders
learning_rate, training_epochs, batch_size = 1e-3, 50, 200

In [None]:
# Input pipeline: resize every image to 35x35 and convert it to a float tensor.
# With two 5x5 valid convolutions and two 2x2/stride-2 average pools,
# 35 -> 31 -> 15 -> 11 -> 5, so each image ends as 16 maps of 5x5, which is
# the 16*5*5 input size Lenet5.l6 expects.
# NOTE: bound to `transform` (singular) so the torchvision.transforms module
# imported above is not shadowed; the old `transforms = transforms.Compose(...)`
# made this cell fail with AttributeError when it was run a second time.
transform = transforms.Compose([
    transforms.Resize((35, 35)),
    transforms.ToTensor(),
])

# Fashion-MNIST train/test splits, downloaded into MNIST_data/ if missing.
mnist_train = dsets.FashionMNIST(root='MNIST_data/',
                                 train=True,
                                 transform=transform,
                                 download=True)

mnist_test = dsets.FashionMNIST(root='MNIST_data/',
                                train=False,
                                transform=transform,
                                download=True)



In [None]:
# Training batches are reshuffled every epoch; drop_last=True keeps every
# training batch at exactly batch_size samples.
train_loader = DataLoader(dataset=mnist_train,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True)

# Evaluation must see every test sample exactly once. Order does not matter,
# so there is no shuffling. The tail batch must not be dropped:
# drop_last=True would silently skip samples whenever len(mnist_test) is not
# a multiple of batch_size.
test_loader = DataLoader(dataset=mnist_test,
                         batch_size=batch_size,
                         shuffle=False,
                         drop_last=False)

In [None]:
class Lenet5(torch.nn.Module):
  """LeNet-5 style CNN returning 10 raw class scores per image.

  Two feature stages (5x5 valid conv -> tanh -> 2x2 average pool) are
  followed by three fully connected layers with tanh in between.
  Layer attributes keep the names l1..l8 / x1..x4, so state_dict keys are
  stable. l6 expects 16*5*5 = 400 features, so inputs must be single-channel
  images whose spatial size shrinks to 5x5 after both stages (e.g. 35x35).
  """

  def __init__(self):
    super().__init__()

    # Stage 1: 1 -> 6 channels.
    self.l1 = torch.nn.Conv2d(1, 6, kernel_size=5, padding=0, stride=1)
    self.x1 = torch.nn.Tanh()
    self.l2 = torch.nn.AvgPool2d(kernel_size=2, padding=0, stride=2)

    # Stage 2: 6 -> 16 channels.
    self.l3 = torch.nn.Conv2d(6, 16, kernel_size=5, padding=0, stride=1)
    self.x2 = torch.nn.Tanh()
    self.l4 = torch.nn.AvgPool2d(kernel_size=2, padding=0, stride=2)

    # Registered for attribute compatibility; forward() flattens with view().
    self.l5 = torch.nn.Flatten()

    # Classifier head: 400 -> 120 -> 84 -> 10.
    self.l6 = torch.nn.Linear(16 * 5 * 5, 120, bias=True)
    self.x3 = torch.nn.Tanh()
    self.l7 = torch.nn.Linear(120, 84, bias=True)
    self.x4 = torch.nn.Tanh()
    self.l8 = torch.nn.Linear(84, 10, bias=True)

  def forward(self, x):
    """Map a batch of images to unnormalised logits (no softmax applied)."""
    features = x
    for layer in (self.l1, self.x1, self.l2, self.l3, self.x2, self.l4):
      features = layer(features)

    # One flat feature vector per sample.
    logits = features.view(features.shape[0], -1)
    for layer in (self.l6, self.x3, self.l7, self.x4, self.l8):
      logits = layer(logits)
    return logits

In [None]:
# Build the network, then move its parameters onto the selected device.
model = Lenet5()
model = model.to(device)

In [None]:
# Adam over every parameter of the model, using the configured step size.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Train the model: one optimisation step per mini-batch, printing the mean
# training loss after every epoch.
# `loss` is not defined in any visible cell, so the loop below used to raise
# NameError. Define the criterion here: CrossEntropyLoss takes the raw scores
# returned by Lenet5 and class-index targets Y.
loss = torch.nn.CrossEntropyLoss()

total_batch = len(train_loader)
print('Learning started. It takes some time.')
model.train()  # back to training mode in case an evaluation cell ran first
for epoch in range(training_epochs):
  avg_cost = 0.0

  for X, Y in train_loader:
      X = X.to(device)
      Y = Y.to(device)

      optimizer.zero_grad()  # clear gradients left over from the previous step
      hypothesis = model(X)  # forward pass
      cost = loss(hypothesis, Y)

      cost.backward()   # backward pass: compute gradients
      optimizer.step()  # update the weights

      # Accumulate a Python float: adding the tensor itself would keep every
      # batch's autograd graph alive until the end of the epoch.
      avg_cost += cost.item() / total_batch

  # Mean per-batch training loss for this epoch.
  print('[Epoch: {:>2}] cost = {:>.9}'.format(epoch + 1, avg_cost))
print('Learning Finished!')

In [None]:
# Compute classification accuracy on the test set as correct / total samples.
# Counting samples is exact even when the last batch is smaller; averaging
# per-batch means is not.
correct = 0
total = 0

model.eval()  # inference mode
with torch.no_grad():  # no training here, so skip gradient tracking
  for X_test, Y_test in test_loader:
      X_test = X_test.to(device)
      Y_test = Y_test.to(device)

      prediction = model(X_test)
      correct_prediction = torch.argmax(prediction, 1) == Y_test
      correct += correct_prediction.sum().item()
      total += Y_test.size(0)

# Guard against an empty loader instead of dividing by zero.
accuracy = correct / total if total else float('nan')
print('Accuracy:', accuracy)