In [39]:
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import os
import numpy as  np
# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,)),
                              ])
# Download and load the training data
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

In [2]:
from torch import nn, optim
import torch.nn.functional as F

#  device variable for gpu functionality
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 10)
        
    def forward(self, x):
        # make sure input tensor is flattened
        x = x.view(x.shape[0], -1)
        
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.log_softmax(self.fc4(x), dim=1)
        
        return x

In [5]:
model = Classifier().to(device)
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.003)

In [6]:
epochs = 30

for e in range(epochs):
    running_loss = 0

    for images, labels in trainloader:
        
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        
        log_ps = model(images)
        loss = criterion(log_ps, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    print("epoch: {}/{}, loss: {}".format(e,epochs,running_loss))

epoch: 0/30, loss: 321.020621791482
epoch: 1/30, loss: 156.9900071863085
epoch: 2/30, loss: 127.4544461928308
epoch: 3/30, loss: 109.00662127323449
epoch: 4/30, loss: 100.54807238653302
epoch: 5/30, loss: 90.9983499981463
epoch: 6/30, loss: 85.11708540096879
epoch: 7/30, loss: 79.23691769316792
epoch: 8/30, loss: 76.31451794691384
epoch: 9/30, loss: 70.71322200167924
epoch: 10/30, loss: 66.28947063721716
epoch: 11/30, loss: 63.524978118017316
epoch: 12/30, loss: 63.60489966254681
epoch: 13/30, loss: 59.04888444766402
epoch: 14/30, loss: 60.378593353554606
epoch: 15/30, loss: 58.23979605548084
epoch: 16/30, loss: 54.764619713649154
epoch: 17/30, loss: 50.87073614727706
epoch: 18/30, loss: 51.05850944761187
epoch: 19/30, loss: 55.57285486860201
epoch: 20/30, loss: 48.443306404165924
epoch: 21/30, loss: 49.181739325635135
epoch: 22/30, loss: 47.950982017442584
epoch: 23/30, loss: 41.47129496792331
epoch: 24/30, loss: 52.33026483748108
epoch: 25/30, loss: 41.376682286616415
epoch: 26/30, l

In [8]:
torch.save(model.state_dict(), './project/mnist_checkpoint.pth')

In [58]:
base_path = "./project/mnist_results/node "
for images,labels in trainloader:
    images = images.to(device)
    labels = labels.to(device)
    logpbs = model(images)
    values, class_idx = logpbs.topk(1,dim=1)
    
    for image,label in zip(images,class_idx):
        path = base_path + str(label.item())
        if not os.path.exists(path): 
            os.mkdir(path)
    
        p, dirs, files = next(os.walk(path))
        file_count = len(files)
        img = image.cpu().numpy()
        img = np.squeeze(img)
        plt.imsave(path+"/"+ "img_" +str(file_count) + '.jpg',img)
        
        #         print(image.shape,label.item())