In [16]:
import torch
import torchvision
import wandb
from torchvision import transforms, datasets
import torch.nn as nn
import torch.nn.functional as F

train = datasets.MNIST('', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor()
                       ]))

test = datasets.MNIST('', train=False, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor()
                       ]))


trainset = torch.utils.data.DataLoader(train, batch_size=10, shuffle=True)
testset = torch.utils.data.DataLoader(test, batch_size=10, shuffle=False)

wandb.init(project="my-test-project")
wandb.config = {
  "learning_rate": 0.001,
  "epochs": 5,
  "batch_size": len(trainset)
}


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 64)
        self.fc4 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return F.log_softmax(x, dim=1)

net = Net()
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    print("GPU run.")
else:
    device = torch.device("cpu")
    print("CPU run.")
device = torch.device("cuda:0")
net = net.to(device)
print(net)

GPU run.
Net(
  (fc1): Linear(in_features=784, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=64, bias=True)
  (fc4): Linear(in_features=64, out_features=10, bias=True)
)


In [17]:
import torch.optim as optim

loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

for epoch in range(5): # 3 full passes over the data
    for data in trainset:  # `data` is a batch of data
        X, y = data  # X is the batch of features, y is the batch of targets.
        X, y = X.to(device), y.to(device)
        net.zero_grad()  # sets gradients to 0 before loss calc. You will do this likely every step.
        output = net(X.view(-1,784)).to(device)  # pass in the reshaped batch (recall they are 28x28 atm)
        loss = F.nll_loss(output, y).to(device)  # calc and grab the loss value
        loss.backward()  # apply this loss backwards thru the network's parameters
        optimizer.step()  # attempt to optimize weights to account for loss/gradients
    print(loss)  # print loss. We hope loss (a measure of wrong-ness) declines! 
    wandb.log({"loss": loss})
    wandb.watch(net)
    
correct = 0
total = 0

with torch.no_grad():
    for data in testset:
        X, y = data
        X, y = X.to(device), y.to(device)
        output = net(X.view(-1,784)).to(device)
        #print(output)
        for idx, i in enumerate(output):
            #print(torch.argmax(i), y[idx])
            if torch.argmax(i) == y[idx]:
                correct += 1
            total += 1

print("Accuracy: ", round(correct/total, 3))
wandb.log({"Accuracy": round(correct/total, 3)})

tensor(0.0465, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.5917, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.6021, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.0106, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(0.0275, device='cuda:0', grad_fn=<NllLossBackward0>)
Accuracy:  0.973
