In [1]:
import torch
from torch import nn
from torch import optim

import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import datasets, transforms

In [2]:
transform = transforms.Compose([transforms.ToTensor(), 
                              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = datasets.FashionMNIST('../F_MNIST_data', download=False, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

testset = datasets.FashionMNIST('../F_MNIST_data', download=False, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)

In [3]:
class Network(nn.Module):
    def __init__(self, input_size, output_size, hidden_layers, drop_p = 0.5):
        super().__init__()
        
        self.hidden_layers = nn.ModuleList([nn.Linear(input_size, hidden_layers[0])])
        
        layers = zip(hidden_layers[:-1], hidden_layers[1:])
        self.hidden_layers.extend([nn.Linear(h1, h2) for h1, h2 in layers])
        
        self.output = nn.Linear(hidden_layers[-1], output_size)
        self.dropout = nn.Dropout(p = drop_p)
    
    def forward(self, x):
        for linear in self.hidden_layers:
            x = F.relu(linear(x))
            x = self.dropout(x)
        x = self.output(x)
        return F.log_softmax(x, dim=1)

In [4]:
model = Network(784, 10, [128, 64], drop_p=0.5)
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [5]:
def validation(model, testloader, criterion):
    model.eval()   #shut off dropout
    current_loss = 0
    accuracy = 0
    for images, labels in iter(testloader):
        images.resize_(images.size()[0], 28 * 28)
        with torch.no_grad():
            output = model.forward(images)
            current_loss += criterion(output, labels).item()
        
            ps = torch.exp(output).data
            equality = (labels.data == ps.max(dim=1)[1])
            accuracy += equality.type(torch.FloatTensor).mean()
            print(accuracy)
        
    return current_loss / len(testloader), accuracy / len(testloader)

In [6]:
epochs = 1
steps = 0
running_loss = 0
print_every = 40

for e in range(epochs):
    model.train()
    for images, labels in iter(trainloader):
        steps += 1
        
        images.resize_(images.size()[0], 28 * 28)
        inputs = Variable(images)
        targets = Variable(labels)
        optimizer.zero_grad()
        
        output = model.forward(inputs)
        loss = criterion(output, targets)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
        if steps % print_every == 0:
            test_loss, test_accuracy = validation(model, testloader, criterion)
            print('{} / {} epoch'.format(e + 1,  epochs))
            print('train loss: {:.4f}'.format(running_loss / print_every) )
            print('test accuracy: {:.4f}'.format(test_accuracy))
            print('test loss: {:.4f}'.format(test_loss))
            running_loss = 0
                

tensor(0.6094)
tensor(1.2969)
tensor(1.9531)
tensor(2.5469)
tensor(3.1094)
tensor(3.7656)
tensor(4.5000)
tensor(5.0625)
tensor(5.6719)
tensor(6.3281)
tensor(6.8281)
tensor(7.4375)
tensor(8.0781)
tensor(8.7188)
tensor(9.2656)
tensor(9.9219)
tensor(10.6094)
tensor(11.1875)
tensor(11.7969)
tensor(12.4531)
tensor(13.0312)
tensor(13.6719)
tensor(14.3750)
tensor(15.0625)
tensor(15.6875)
tensor(16.3281)
tensor(16.9844)
tensor(17.5469)
tensor(18.1875)
tensor(18.8281)
tensor(19.3594)
tensor(19.9844)
tensor(20.7188)
tensor(21.2969)
tensor(21.9375)
tensor(22.5469)
tensor(23.1562)
tensor(23.8281)
tensor(24.3438)
tensor(25.0312)
tensor(25.6719)
tensor(26.2500)
tensor(26.8906)
tensor(27.5000)
tensor(28.2031)
tensor(28.7812)
tensor(29.4062)
tensor(30.1406)
tensor(30.7969)
tensor(31.4375)
tensor(32.1250)
tensor(32.7812)
tensor(33.3750)
tensor(33.9375)
tensor(34.5625)
tensor(35.2188)
tensor(35.8438)
tensor(36.4844)
tensor(37.0625)
tensor(37.8281)
tensor(38.4844)
tensor(39.0469)
tensor(39.7031)
tensor(4

tensor(48.4531)
tensor(49.2188)
tensor(49.8906)
tensor(50.7188)
tensor(51.5000)
tensor(52.3125)
tensor(53.0312)
tensor(53.7656)
tensor(54.4688)
tensor(55.2188)
tensor(56.)
tensor(56.9062)
tensor(57.6406)
tensor(58.4219)
tensor(59.1875)
tensor(59.9688)
tensor(60.8750)
tensor(61.6875)
tensor(62.4375)
tensor(63.2812)
tensor(64.0625)
tensor(64.8906)
tensor(65.7188)
tensor(66.5312)
tensor(67.2812)
tensor(67.8906)
tensor(68.6250)
tensor(69.3906)
tensor(70.2031)
tensor(71.0469)
tensor(71.7188)
tensor(72.5312)
tensor(73.3438)
tensor(74.0938)
tensor(74.8438)
tensor(75.6094)
tensor(76.4219)
tensor(77.2188)
tensor(78.0938)
tensor(78.8438)
tensor(79.6094)
tensor(80.4062)
tensor(81.2344)
tensor(82.0625)
tensor(82.7812)
tensor(83.6250)
tensor(84.3594)
tensor(85.1250)
tensor(85.9688)
tensor(86.6875)
tensor(87.5156)
tensor(88.2344)
tensor(88.9375)
tensor(89.5781)
tensor(90.3438)
tensor(91.0781)
tensor(91.8750)
tensor(92.6250)
tensor(93.4219)
tensor(94.2500)
tensor(95.0625)
tensor(95.8594)
tensor(96.59

tensor(73.3281)
tensor(74.1250)
tensor(75.0156)
tensor(75.8281)
tensor(76.6406)
tensor(77.5312)
tensor(78.4219)
tensor(79.3125)
tensor(80.0781)
tensor(80.9844)
tensor(81.8438)
tensor(82.5781)
tensor(83.3125)
tensor(84.1406)
tensor(84.9531)
tensor(85.7656)
tensor(86.4844)
tensor(87.3594)
tensor(88.1406)
tensor(89.0156)
tensor(89.8750)
tensor(90.5938)
tensor(91.2812)
tensor(92.0938)
tensor(92.9375)
tensor(93.7031)
tensor(94.5312)
tensor(95.3125)
tensor(96.0781)
tensor(96.9219)
tensor(97.7500)
tensor(98.5938)
tensor(99.3125)
tensor(100.1562)
tensor(100.9688)
tensor(101.7969)
tensor(102.6406)
tensor(103.4062)
tensor(104.2031)
tensor(105.0312)
tensor(105.8281)
tensor(106.6094)
tensor(107.4062)
tensor(108.2656)
tensor(109.0625)
tensor(109.8750)
tensor(110.7969)
tensor(111.6094)
tensor(112.3594)
tensor(113.0781)
tensor(113.9531)
tensor(114.7188)
tensor(115.4844)
tensor(116.2969)
tensor(117.0781)
tensor(117.9062)
tensor(118.7344)
tensor(119.5625)
tensor(120.3906)
tensor(121.2031)
tensor(122.06

tensor(100.3594)
tensor(101.1719)
tensor(102.0625)
tensor(102.8281)
tensor(103.6562)
tensor(104.4375)
tensor(105.3125)
tensor(106.0625)
tensor(106.8750)
tensor(107.6875)
tensor(108.4688)
tensor(109.3906)
tensor(110.2344)
tensor(111.1094)
tensor(111.9062)
tensor(112.7656)
tensor(113.5938)
tensor(114.4844)
tensor(115.3281)
tensor(116.1562)
tensor(116.8750)
tensor(117.7188)
tensor(118.4844)
tensor(119.2344)
tensor(120.)
tensor(120.7500)
tensor(121.5469)
tensor(122.3594)
tensor(123.2031)
tensor(124.0938)
tensor(124.8438)
tensor(125.5781)
tensor(126.3750)
tensor(127.2500)
1 / 1 epoch
train loss: 0.4671
test accuracy: 0.8105
test loss: 0.5265
tensor(0.7969)
tensor(1.5938)
tensor(2.3750)
tensor(3.1875)
tensor(4.0312)
tensor(4.7344)
tensor(5.5469)
tensor(6.3750)
tensor(7.2031)
tensor(8.0156)
tensor(8.7656)
tensor(9.6562)
tensor(10.4062)
tensor(11.1875)
tensor(12.0469)
tensor(12.8125)
tensor(13.7344)
tensor(14.6406)
tensor(15.4219)
tensor(16.2812)
tensor(17.0625)
tensor(17.7812)
tensor(18.5938)

KeyboardInterrupt: 

In [None]:
test_images, test_labels = next(iter(testloader))

model.eval()
test_images.resize_(test_images.size()[0], 28 * 28)

with torch.no_grad():
    output = model.forward(test_images)
    ps = torch.exp(output).data
    print(ps.max(dim=1)[1])
