In [1]:
import os

import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms

In [2]:
# Network dimensions: flattened 28*28 input -> two hidden layers -> one logit.
input_size = 784
hidden_size0 = 32
hidden_size1 = 16
out_size = 1

# Optimisation settings.
epochs = 10
batch_size = 64
learning_rate = 0.001

bin_digit = 5 #model predicts "bin_digit or not bin_digit"

# Seed torch's global RNG so weight initialisation, loader shuffling and the
# random resampling done later in the notebook repeat on Restart & Run All.
seed = 0
torch.manual_seed(seed)

In [3]:
# MNIST train/test splits stored under ./data; only the train split requests a download.
mnist_root = './data'
train_dataset = datasets.MNIST(root=mnist_root, train=True,
                               transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root=mnist_root, train=False,
                              transform=transforms.ToTensor())

# Mini-batch iterators: training batches are shuffled, test batches keep dataset order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [4]:
class Net(nn.Module):
    """Three-layer fully connected network with tanh hidden activations.

    The final layer has no activation, so `forward` returns raw logits.
    """

    def __init__(self, input_size, hidden_size0, hidden_size1, out_size):
        super().__init__()
        self.fc0 = nn.Linear(input_size, hidden_size0)
        self.fc1 = nn.Linear(hidden_size0, hidden_size1)
        self.fc2 = nn.Linear(hidden_size1, out_size)
        self.tanh = nn.Tanh()
        self.init_weights()

    def init_weights(self):
        """Re-initialise every linear layer's weight with Xavier-uniform (tanh gain)."""
        tanh_gain = nn.init.calculate_gain('tanh')
        for layer in (self.fc0, self.fc1, self.fc2):
            nn.init.xavier_uniform_(layer.weight, gain=tanh_gain)

    def forward(self, x):
        """Flatten each sample to a vector and return one row of logits per sample."""
        flat = x.view(x.shape[0], -1)
        hidden = self.tanh(self.fc0(flat))
        hidden = self.tanh(self.fc1(hidden))
        return self.fc2(hidden)

In [5]:
# Build the classifier and move it to the GPU when one is available.
CUDA = torch.cuda.is_available()
net = Net(input_size, hidden_size0, hidden_size1, out_size)
net = net.cuda() if CUDA else net

# The loss applies the sigmoid itself, which is why the network emits raw logits.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

In [6]:
#train

# Each loader batch is rebuilt before the optimisation step: all `bin_digit`
# samples are kept, and every other digit present in the batch is resampled to
# exactly n_bin samples (oversampled with replacement when it has fewer,
# subsampled without replacement otherwise).
# NOTE(review): a step therefore sees n_bin positives and up to 9 * n_bin
# negatives, i.e. the non-target digits are balanced among themselves but the
# positive/negative ratio is about 1:9 -- confirm this is the intended balance.

net.train()  # stay in training mode even if this cell is re-run after evaluation

for epoch in range(epochs):
    total_train_samples = 0
    correct_train = 0
    running_loss = 0
    steps = 0  # optimisation steps actually taken; batches without bin_digit are skipped
    for images, labels in train_loader:
        images = images.view(-1, 28*28)
        if CUDA:
            images = images.cuda()
            labels = labels.cuda()

        bin_mask = (labels == bin_digit)
        other_mask = ~bin_mask

        bin_images = images[bin_mask]
        bin_labels = labels[bin_mask]

        other_images = images[other_mask]
        other_labels = labels[other_mask]

        n_bin = bin_images.size(0)

        # Without a positive example there is nothing to balance against.
        if n_bin == 0:
            continue

        target_other_per_digit = n_bin

        balanced_other_images = []
        balanced_other_labels = []

        for digit in range(10):
            if digit == bin_digit:
                continue

            digit_mask = (other_labels == digit)
            digit_images = other_images[digit_mask]
            digit_labels = other_labels[digit_mask]
            n_digit = digit_images.size(0)

            if 0 < n_digit < target_other_per_digit:
                # Too few samples of this digit: draw with replacement.
                indices = torch.randint(0, n_digit, (target_other_per_digit,), device=digit_images.device)
            else:
                # Enough samples (or none at all): random subset without replacement.
                indices = torch.randperm(n_digit, device=digit_images.device)[:target_other_per_digit]

            balanced_other_images.append(digit_images[indices])
            balanced_other_labels.append(digit_labels[indices])

        # The lists always hold one (possibly empty) tensor per non-target digit,
        # so concatenation is safe even when a digit is missing from the batch.
        batch_images = torch.cat([bin_images] + balanced_other_images, dim=0)
        batch_labels = torch.cat([bin_labels] + balanced_other_labels, dim=0)

        perm = torch.randperm(batch_images.size(0), device=batch_images.device)
        batch_images = batch_images[perm]
        batch_labels = batch_labels[perm]

        # Binary target column: 1.0 for bin_digit, 0.0 for every other digit.
        batch_labels_bin = (batch_labels == bin_digit).float().view(-1, 1)

        outputs = net(batch_images)
        loss = criterion(outputs, batch_labels_bin)
        running_loss += loss.item()
        steps += 1

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        predicted = (torch.sigmoid(outputs) >= 0.5).float()
        total_train_samples += batch_labels_bin.size(0)  #total examples used in training
        correct_train += (predicted == batch_labels_bin).sum().item()

    if steps == 0:
        # Every batch was skipped; fail with a clear message instead of a ZeroDivisionError.
        raise ValueError(f'No training batch contained digit {bin_digit}; check bin_digit.')

    accuracy = 100 * correct_train / total_train_samples
    # Average over the steps actually taken, not over len(train_loader), because
    # skipped batches contribute no loss.
    print(f'Epoch [{epoch+1}/{epochs}], Training Loss: {running_loss / steps:.3f}, Training Accuracy: {accuracy:.3f}%')

print("DONE TRAINING!")
# Create the output directory first so the save cannot fail after a full training run.
os.makedirs('stored_model_weights_binary', exist_ok=True)
torch.save(net.state_dict(), f'stored_model_weights_binary/ffn_mnist_binary_{bin_digit}.pth')


Epoch [1/10], Training Loss: 0.081, Training Accuracy: 97.234%
Epoch [2/10], Training Loss: 0.036, Training Accuracy: 98.880%
Epoch [3/10], Training Loss: 0.025, Training Accuracy: 99.215%
Epoch [4/10], Training Loss: 0.021, Training Accuracy: 99.335%
Epoch [5/10], Training Loss: 0.016, Training Accuracy: 99.450%
Epoch [6/10], Training Loss: 0.013, Training Accuracy: 99.573%
Epoch [7/10], Training Loss: 0.010, Training Accuracy: 99.695%
Epoch [8/10], Training Loss: 0.008, Training Accuracy: 99.717%
Epoch [9/10], Training Loss: 0.009, Training Accuracy: 99.688%
Epoch [10/10], Training Loss: 0.007, Training Accuracy: 99.769%
DONE TRAINING!


In [7]:
# Materialise the whole test split as two tensors.
test_batches = list(test_loader)
test_images = torch.cat([batch[0] for batch in test_batches])
test_labels = torch.cat([batch[1] for batch in test_batches])

# Keep every bin_digit sample plus an equally sized random draw of the other
# digits (fewer if not enough are available); original dataset order is preserved.
mask_1 = test_labels.eq(bin_digit)
mask_0 = test_labels.ne(bin_digit)
n = int(mask_1.sum())
idx_0 = mask_0.nonzero(as_tuple=True)[0]
rand_idx = idx_0[torch.randperm(idx_0.numel())[:n]]
new_mask_0 = torch.zeros_like(mask_0)
new_mask_0[rand_idx] = True
mask = torch.logical_or(mask_1, new_mask_0)

images, labels = test_images[mask], test_labels[mask]

In [8]:
# Score the model on the balanced test subset, batch_size samples at a time.
net.eval()
correct = 0
total = 0
with torch.no_grad():
    for start in range(0, len(images), batch_size):
        stop = start + batch_size
        x, y = images[start:stop], labels[start:stop]
        if CUDA:
            x, y = x.cuda(), y.cuda()
        # Binary target column: 1.0 where the true digit is bin_digit.
        y_bin = (y == bin_digit).float().view(-1, 1)
        outputs = net(x.view(-1, 28*28))
        predicted = (torch.sigmoid(outputs) >= 0.5).float()
        correct += predicted.eq(y_bin).sum().item()
        total += y_bin.size(0)

print(f'Accuracy of the network on digit {bin_digit}: {100 * correct / total:.2f} %')


Accuracy of the network on digit 5: 97.76 %
