In [1]:
import sys
!{sys.executable} -m pip install torch torchvision pillow

You should consider upgrading via the 'pip install --upgrade pip' command.[0m


In [2]:
from torchvision import datasets, transforms
import torch
import PIL

In [3]:
IMG_SIZE = 28 * 28
IS_3 = 0.
IS_7 = 1.

In [4]:
# download and reference dataset
mnist_train = datasets.MNIST(download=True, root="storage", train=True)
mnist_valid = datasets.MNIST(download=True, root="storage", train=False)

In [5]:
threes_or_sevens_train = [x for x in mnist_train if x[1]==3 or x[1]==7]
threes_or_sevens_valid = [x for x in mnist_valid if x[1]==3 or x[1]==7]

len(threes_or_sevens_train), len(threes_or_sevens_valid)

(12396, 2038)

In [6]:
x_train = torch.stack([transforms.ToTensor()(d[0])for d in threes_or_sevens_train]).view(-1, IMG_SIZE)
x_valid = torch.stack([transforms.ToTensor()(d[0]) for d in threes_or_sevens_valid]).view(-1, IMG_SIZE)

x_train.shape, x_valid.shape

(torch.Size([12396, 784]), torch.Size([2038, 784]))

In [7]:
y_train = torch.tensor([IS_3 if d[1] == 3 else IS_7 for d in threes_or_sevens_train]).unsqueeze(1)
y_valid = torch.tensor([IS_3 if d[1] == 3 else IS_7 for d in threes_or_sevens_valid]).unsqueeze(1)

y_train.shape, y_train[3], y_train[0], y_valid.shape

(torch.Size([12396, 1]), tensor([1.]), tensor([0.]), torch.Size([2038, 1]))

In [8]:
# ds is our Dataset, list of (image, value) tuples
train_ds = list(zip(x_train, y_train))
valid_ds = list(zip(x_valid, y_valid))

train_ds[0][0].shape, train_ds[0][1], valid_ds[0][0].shape, valid_ds[0][1]

(torch.Size([784]), tensor([0.]), torch.Size([784]), tensor([1.]))

In [9]:
# Create the DataLoaders
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=256)
valid_dl = torch.utils.data.DataLoader(valid_ds, batch_size=256)

train_dl.dataset[0][0].shape, train_dl.dataset[0][1], valid_dl.dataset[0][0].shape, valid_dl.dataset[0][1]

(torch.Size([784]), tensor([0.]), torch.Size([784]), tensor([1.]))

In [10]:
# create simple ReLU model
net = torch.nn.Sequential(
    torch.nn.Linear(IMG_SIZE, 30),
    torch.nn.ReLU(),
    torch.nn.Linear(30, 1)
)

In [11]:
LR = 0.1
optimizer = torch.optim.SGD(net.parameters(), LR)

In [12]:
def get_loss(predictions, targets):
    predictions = predictions.sigmoid()
    return torch.where(targets==1, 1-predictions, predictions).mean()

In [13]:
def calc_gradient(xb, yb, model):
    predictions = model(xb)
    loss = get_loss(predictions, yb)
    loss.backward()

In [14]:
def train_epoch(model):
    # Iterate over dataset batches
    # xb is a tensor with the independent variables for the batch (tensor of pixel values)
    # yb         ""           dependent             ""            (which digit it is)
    for xb, yb in train_dl:
        calc_gradient(xb, yb, model)
        optimizer.step()
        optimizer.zero_grad()

In [15]:
def batch_accuracy(xb, yb):
    preds = xb.sigmoid()
    correct = (preds>0.5).float() == yb
    return correct.float().mean()

In [16]:
def validate_epoch(model):
    accuracy = [batch_accuracy(model(xb), yb) for xb, yb in valid_dl]
    return round(torch.stack(accuracy).mean().item(), 4)

In [17]:
for epoch in range(40):
    print(f"Epoch Accuracy  : {validate_epoch(net)}")
    train_epoch(net) # or linear
    print("====================================================")
    print(f"Epoch           : {epoch}")
    print(f"Epoch Accuracy  : {validate_epoch(net)}")

Epoch Accuracy  : 0.5045
Epoch           : 0
Epoch Accuracy  : 0.9661
Epoch Accuracy  : 0.9661
Epoch           : 1
Epoch Accuracy  : 0.9691
Epoch Accuracy  : 0.9691
Epoch           : 2
Epoch Accuracy  : 0.9706
Epoch Accuracy  : 0.9706
Epoch           : 3
Epoch Accuracy  : 0.9711
Epoch Accuracy  : 0.9711
Epoch           : 4
Epoch Accuracy  : 0.9725
Epoch Accuracy  : 0.9725
Epoch           : 5
Epoch Accuracy  : 0.9735
Epoch Accuracy  : 0.9735
Epoch           : 6
Epoch Accuracy  : 0.9745
Epoch Accuracy  : 0.9745
Epoch           : 7
Epoch Accuracy  : 0.9745
Epoch Accuracy  : 0.9745
Epoch           : 8
Epoch Accuracy  : 0.976
Epoch Accuracy  : 0.976
Epoch           : 9
Epoch Accuracy  : 0.976
Epoch Accuracy  : 0.976
Epoch           : 10
Epoch Accuracy  : 0.9764
Epoch Accuracy  : 0.9764
Epoch           : 11
Epoch Accuracy  : 0.9774
Epoch Accuracy  : 0.9774
Epoch           : 12
Epoch Accuracy  : 0.9779
Epoch Accuracy  : 0.9779
Epoch           : 13
Epoch Accuracy  : 0.9779
Epoch Accuracy  : 0.