In [None]:
import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
import os
import numpy as np

def get_mnist_data(train_dir, image_size, digit_0=8, digit_1=9):
  """Build a two-class MNIST training set from two chosen digits.

  Loads the MNIST training split from ``./data`` (``download=True`` is passed
  to torchvision), keeps only the images labelled ``digit_0`` or ``digit_1``
  and relabels them as class 0 and class 1 respectively.

  Args:
    train_dir: Unused. Kept only so existing callers keep working; the
      dataset root is always ``./data``.
    image_size: Side length of the square images produced by the transform.
    digit_0: MNIST digit that becomes class 0.
    digit_1: MNIST digit that becomes class 1.

  Returns:
    The ``torchvision.datasets.MNIST`` object with ``data`` and ``targets``
    replaced by the filtered images and their binary labels. All class-0
    samples come first, followed by all class-1 samples.
  """
  # Per-sample pipeline: resize to 256x256, random horizontal flip, resize to
  # the requested size, replicate to 3 channels, convert to a tensor.
  # NOTE(review): the intermediate 256x256 resize looks redundant, but dropping
  # it would change the interpolated pixel values, so it is left as is.
  data_transform = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.RandomHorizontalFlip(),
    transforms.Resize((image_size,image_size)),
    transforms.Grayscale(3),
    transforms.ToTensor(),
  ])
  dataset = torchvision.datasets.MNIST(root='./data',
                        train=True,
                        download=True,
                        transform=data_transform)

  # Boolean masks selecting the samples of each requested digit.
  idx_0 = dataset.targets == digit_0
  idx_1 = dataset.targets == digit_1

  data_0 = dataset.data[idx_0]
  data_1 = dataset.data[idx_1]
  # Integer (long) labels, matching the dtype of the targets they replace and
  # what nn.CrossEntropyLoss expects for class-index targets.
  targets_0 = torch.zeros(len(data_0), dtype=torch.long)
  targets_1 = torch.ones(len(data_1), dtype=torch.long)

  dataset.data = torch.cat((data_0, data_1), dim=0)
  dataset.targets = torch.cat((targets_0, targets_1), dim=0)
  return dataset

num_classes = 2
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Pretrained DenseNet-121 with its classifier head replaced by a two-layer
# linear head that outputs `num_classes` logits.
model = torchvision.models.densenet121(pretrained=True)
num_ftrs = model.classifier.in_features
# NOTE(review): two stacked Linear layers without an activation collapse to a
# single linear map. Left unchanged so the saved state_dict keys stay the same.
model.classifier = nn.Sequential(
    nn.Linear(num_ftrs, 500),
    nn.Linear(500, num_classes)
    )

model = model.to(device)
batch_size = 32

dataset = get_mnist_data('./', 32)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=2,
                                         shuffle=True, drop_last=True, pin_memory=True)

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Make sure the model is in training mode even if this cell is re-run after a
# later cell switched it to eval mode.
model.train()

for epoch in range(10):  # loop over the dataset multiple times

    # Sum of the batch losses since the last reset (see the periodic log below).
    running_loss = 0.0

    for i, data in enumerate(dataloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        # The model lives on `device`, so the batch must be moved there too;
        # otherwise the forward pass fails with a device mismatch on GPU.
        # non_blocking pairs with pin_memory=True in the DataLoader.
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

    # Summed (not averaged) loss accumulated since the last reset in this epoch.
    print(running_loss)

# torch.save does not create missing directories, so the target folder must
# already exist (presumably a mounted Google Drive - confirm before training).
torch.save(model.state_dict(), 'drive/MyDrive/StyleEX/mnist_classifier.pt')

tensor([1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1,
        1, 0, 0, 1, 1, 0, 1, 1])
tensor([0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0,
        1, 0, 1, 0, 1, 0, 1, 1])
tensor([0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0,
        0, 1, 1, 1, 0, 1, 1, 0])
tensor([0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0,
        0, 0, 0, 1, 1, 1, 1, 0])
tensor([1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1,
        0, 1, 0, 0, 1, 0, 1, 1])
tensor([1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1,
        0, 1, 0, 0, 1, 0, 0, 0])
tensor([1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1,
        0, 1, 1, 0, 1, 0, 1, 0])
tensor([1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1,
        1, 1, 1, 1, 1, 1, 1, 0])
tensor([1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
        0, 0, 0,

In [None]:
# Write the model's current weights to the checkpoint file again (same path as
# the save at the end of the training cell).
classifier_weights = model.state_dict()
torch.save(classifier_weights, 'drive/MyDrive/StyleEX/mnist_classifier.pt')

In [None]:
# Post-training summary built from the variables the training loop left behind:
# a banner, the last epoch/batch position with a scaled loss, and the raw
# accumulated loss. Each item is printed on its own line.
# NOTE(review): the divisor 2 looks unrelated to the number of batches summed
# into running_loss - confirm which average was intended.
print(
    'Finished Training',
    f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2:.3f}',
    running_loss,
    sep='\n',
)