In [21]:
# Turn on IPython's "greedy" tab completion for this kernel session.
# NOTE(review): per the IPython docs this evaluates code on TAB and is
# deprecated in recent IPython releases -- confirm it is still wanted.
%config IPCompleter.greedy = True

In [22]:
import numpy as np

import torch
from torch import nn, optim
import torch.nn.functional as F
from torchvision import datasets, transforms

import matplotlib as plt
%matplotlib inline

import helper

In [23]:
# FashionMNIST images are single-channel, so Normalize must receive
# one-element mean/std tuples. Passing three values makes current torchvision
# fail with a broadcast-shape error.
# With mean = std = 0.5, ToTensor's [0, 1] pixel range is mapped to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Training split; downloaded into FashionMNIST_data/ on first use.
train_set = datasets.FashionMNIST('FashionMNIST_data/', download = True, train = True, transform = transform)
# Reshuffle every epoch so mini-batches differ between epochs.
train_loader = torch.utils.data.DataLoader(train_set, batch_size = 8, shuffle = True)

In [24]:
# Held-out split, preprocessed with the same `transform` as the training data.
test_set = datasets.FashionMNIST('FashionMNIST_data/', download = True, train = False, transform = transform)
# Evaluation only aggregates over the whole split, so shuffling adds nothing;
# a fixed order keeps evaluation runs deterministic.
test_loader = torch.utils.data.DataLoader(test_set, batch_size = 8, shuffle = False)

In [25]:
class Network(nn.Module):
    """Fully connected classifier: (Linear -> ReLU -> Dropout)* -> Linear -> log-softmax.

    Args:
        input_size: Number of features in each (already flattened) input row.
        output_size: Number of classes, i.e. the width of the final layer.
        hidden_layers: Iterable with the width of each hidden layer, in order.
            May be empty, in which case the input feeds the output layer
            directly.
        drop_p: Dropout probability applied after every hidden activation.
            Defaults to 0.5.

    Attributes:
        hidden_layers: ``nn.ModuleList`` holding the hidden ``nn.Linear`` layers.
        output_size: The final ``nn.Linear`` layer. The name is kept (although
            it holds a layer, not a size) because it determines the
            ``state_dict`` keys of saved checkpoints.
        dropout: The shared ``nn.Dropout`` module.
    """

    def __init__(self, input_size, output_size, hidden_layers, drop_p=0.5):
        super().__init__()

        # Layer widths from the input through the last hidden layer; each
        # adjacent pair (n_in, n_out) becomes one hidden Linear layer.
        widths = [input_size, *hidden_layers]
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])]
        )

        # widths[-1] is the last hidden width, or input_size when there are
        # no hidden layers.
        self.output_size = nn.Linear(widths[-1], output_size)

        self.dropout = nn.Dropout(p=drop_p)

    def forward(self, x):
        """Return log-probabilities for a 2-D batch ``x`` of shape (batch, input_size).

        Softmax is taken over dim 1, so the result has shape
        (batch, output_size) and pairs with ``nn.NLLLoss``.
        """
        for layer in self.hidden_layers:
            x = F.relu(layer(x))
            x = self.dropout(x)

        x = self.output_size(x)

        return F.log_softmax(x, dim=1)

In [26]:
# 784 input features -> hidden layers of 516 and 256 units -> 10 classes,
# with dropout probability 0.5 after each hidden layer.
model = Network(
    input_size=784,
    output_size=10,
    hidden_layers=[516, 256],
    drop_p=0.5,
)

In [27]:
# Negative log-likelihood loss: expects log-probabilities, which is what the
# model's log_softmax output provides.
criterion = nn.NLLLoss()
# Adam over all model parameters with a learning rate of 0.001.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [28]:
def validation(model, test_loader, criterion):
    """Accumulate loss and accuracy of ``model`` over every batch in ``test_loader``.

    Each image batch is flattened to (batch, features) before it is passed to
    the model. The caller's tensors are not modified. Gradients are disabled
    inside this function; switching the model to eval mode is left to the
    caller.

    Args:
        model: Module returning per-class scores (e.g. log-probabilities) of
            shape (batch, classes).
        test_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable taking ``(output, labels)`` and returning a
            scalar tensor.

    Returns:
        Tuple ``(test_loss, accuracy)``. Both are SUMS over batches, of the
        per-batch loss and the per-batch accuracy respectively. Divide by the
        number of batches to get averages. ``test_loss`` is a Python number;
        ``accuracy`` is a 0-dim tensor (or the int 0 for an empty loader).
    """
    accuracy = 0
    test_loss = 0

    with torch.no_grad():
        for images, labels in test_loader:
            # Flatten without touching the caller's tensor, and without
            # hard-coding the number of features per image.
            flat_images = images.reshape(images.shape[0], -1)

            output = model(flat_images)
            test_loss += criterion(output, labels).item()

            # The highest score per row is the predicted class.
            predictions = output.argmax(dim=1)
            accuracy += (predictions == labels).float().mean()

    return test_loss, accuracy

In [29]:
# Train for `epochs` passes over train_loader. Every `print_every` optimisation
# steps, evaluate on the full test_loader and report averaged metrics.
epochs = 3
steps = 0
print_every = 40

# `steps` runs across epochs, so the running loss must too. It is reset only
# after each report; resetting it per epoch would make the first report after
# an epoch boundary divide a partial sum by `print_every`.
running_loss = 0

for e in range(epochs):
    model.train()
    for images, labels in train_loader:
        steps += 1
        # Flatten each image to one row without resizing the batch in place.
        images = images.view(images.shape[0], -1)

        optimizer.zero_grad()

        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if steps % print_every == 0:
            # Eval mode disables dropout while measuring test metrics.
            model.eval()

            with torch.no_grad():
                test_loss, accuracy = validation(model, test_loader, criterion)

            # validation() returns sums over batches; divide by the batch count.
            print(
                'running_loss : {:.3f}'.format(running_loss/print_every),
                'test_loss : {:.3f}'.format(test_loss/len(test_loader)),
                'accuracy {:.3f}'.format(accuracy/len(test_loader))
            )

            running_loss = 0
            # Back to training mode (dropout active) for the next steps.
            model.train()

running_loss : 1.814 test_loss : 1.244 accuracy 0.606
running_loss : 1.329 test_loss : 1.169 accuracy 0.579
running_loss : 1.093 test_loss : 0.888 accuracy 0.699
running_loss : 1.043 test_loss : 0.832 accuracy 0.662
running_loss : 1.112 test_loss : 0.823 accuracy 0.697
running_loss : 0.890 test_loss : 0.725 accuracy 0.716
running_loss : 0.879 test_loss : 0.709 accuracy 0.726
running_loss : 0.917 test_loss : 0.692 accuracy 0.733
running_loss : 0.890 test_loss : 0.757 accuracy 0.718
running_loss : 0.900 test_loss : 0.686 accuracy 0.726
running_loss : 0.716 test_loss : 0.671 accuracy 0.761
running_loss : 0.667 test_loss : 0.701 accuracy 0.734
running_loss : 0.838 test_loss : 0.756 accuracy 0.714
running_loss : 0.755 test_loss : 0.698 accuracy 0.735
running_loss : 0.803 test_loss : 0.690 accuracy 0.744
running_loss : 0.718 test_loss : 0.723 accuracy 0.734
running_loss : 0.837 test_loss : 0.626 accuracy 0.770
running_loss : 0.807 test_loss : 0.705 accuracy 0.733
running_loss : 0.847 test_lo

running_loss : 0.585 test_loss : 0.503 accuracy 0.815
running_loss : 0.599 test_loss : 0.502 accuracy 0.821
running_loss : 0.675 test_loss : 0.506 accuracy 0.819
running_loss : 0.543 test_loss : 0.507 accuracy 0.810
running_loss : 0.595 test_loss : 0.544 accuracy 0.800
running_loss : 0.698 test_loss : 0.509 accuracy 0.817
running_loss : 0.615 test_loss : 0.519 accuracy 0.815
running_loss : 0.699 test_loss : 0.509 accuracy 0.822
running_loss : 0.628 test_loss : 0.492 accuracy 0.826
running_loss : 0.595 test_loss : 0.517 accuracy 0.814
running_loss : 0.685 test_loss : 0.517 accuracy 0.820
running_loss : 0.585 test_loss : 0.520 accuracy 0.817
running_loss : 0.624 test_loss : 0.498 accuracy 0.824
running_loss : 0.628 test_loss : 0.517 accuracy 0.824
running_loss : 0.624 test_loss : 0.500 accuracy 0.819
running_loss : 0.594 test_loss : 0.509 accuracy 0.821
running_loss : 0.635 test_loss : 0.507 accuracy 0.819
running_loss : 0.570 test_loss : 0.480 accuracy 0.826
running_loss : 0.653 test_lo

running_loss : 0.588 test_loss : 0.499 accuracy 0.823
running_loss : 0.606 test_loss : 0.491 accuracy 0.826
running_loss : 0.508 test_loss : 0.484 accuracy 0.824
running_loss : 0.659 test_loss : 0.474 accuracy 0.830
running_loss : 0.672 test_loss : 0.505 accuracy 0.818
running_loss : 0.621 test_loss : 0.493 accuracy 0.825
running_loss : 0.570 test_loss : 0.486 accuracy 0.825
running_loss : 0.610 test_loss : 0.507 accuracy 0.814
running_loss : 0.579 test_loss : 0.476 accuracy 0.822
running_loss : 0.618 test_loss : 0.472 accuracy 0.828
running_loss : 0.543 test_loss : 0.481 accuracy 0.819
running_loss : 0.567 test_loss : 0.506 accuracy 0.814
running_loss : 0.743 test_loss : 0.517 accuracy 0.818
running_loss : 0.658 test_loss : 0.494 accuracy 0.834
running_loss : 0.616 test_loss : 0.466 accuracy 0.832
running_loss : 0.574 test_loss : 0.469 accuracy 0.831
running_loss : 0.585 test_loss : 0.490 accuracy 0.826
running_loss : 0.611 test_loss : 0.483 accuracy 0.827
running_loss : 0.600 test_lo

running_loss : 0.537 test_loss : 0.478 accuracy 0.827
running_loss : 0.710 test_loss : 0.497 accuracy 0.829
running_loss : 0.578 test_loss : 0.496 accuracy 0.819
running_loss : 0.574 test_loss : 0.506 accuracy 0.809
running_loss : 0.511 test_loss : 0.479 accuracy 0.833
running_loss : 0.593 test_loss : 0.489 accuracy 0.830
running_loss : 0.612 test_loss : 0.480 accuracy 0.836
running_loss : 0.600 test_loss : 0.479 accuracy 0.830
running_loss : 0.696 test_loss : 0.473 accuracy 0.834
running_loss : 0.587 test_loss : 0.490 accuracy 0.831
running_loss : 0.587 test_loss : 0.492 accuracy 0.821
running_loss : 0.586 test_loss : 0.474 accuracy 0.829
running_loss : 0.760 test_loss : 0.499 accuracy 0.821
running_loss : 0.554 test_loss : 0.515 accuracy 0.830
running_loss : 0.543 test_loss : 0.473 accuracy 0.836
running_loss : 0.503 test_loss : 0.487 accuracy 0.829
running_loss : 0.683 test_loss : 0.481 accuracy 0.831
running_loss : 0.595 test_loss : 0.479 accuracy 0.837
running_loss : 0.533 test_lo

In [30]:
# Everything needed to rebuild the network and restore its weights. Sizes are
# read from the model itself so the checkpoint cannot drift from the
# architecture that was actually trained.
_linear_layers = [*model.hidden_layers, model.output_size]
checkpoint = {
                'input_size': _linear_layers[0].in_features,
                'output_size': _linear_layers[-1].out_features,
                'hidden_layers': [each.out_features for each in model.hidden_layers],
                # Network's constructor takes the dropout probability, so it
                # must be saved to be able to rebuild the same model.
                'drop_p': model.dropout.p,
                'state_dict': model.state_dict()
            }

In [31]:
# Serialise the checkpoint dict to checkpoint.pth in the working directory.
torch.save(obj=checkpoint, f='checkpoint.pth')