In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from math import floor

import matplotlib.pyplot as plt

from ml_utils import custom_data

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [2]:
def get_accuracy(curr_net, data_loader):
    """
    Use trainloader for train accuracy
    Use testloader for test accuracy
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for data in data_loader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = curr_net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total

def train(curr_net, curr_optimizer, graph_title, num_epoch=5, print_interval=2000):
    train_accu = []
    test_accu = []
    print('HI')
    # reduce precision
    curr_net = curr_net.float()
    for epoch in range(num_epoch):
        running_loss = 0.0
        print('Epoch', epoch)
        for i, data in enumerate(trainloader, 0):

            criterion = nn.MSELoss()
            inputs, labels = data[0].float().to(device), data[1].float().to(device)

            curr_optimizer.zero_grad()  # fresh start

            # the entire training step
            outputs = curr_net(inputs)
            labels = labels.view(labels.size(0), 1)
            # print("size: output {}, labels {}".format(outputs.size(), labels.size()), flush=True)
            # print("values: output {}, labels {}".format(outputs, labels), flush=True)
            loss = criterion(outputs, labels)
            print("iter {} loss {}".format(i, loss))
            loss.backward()
            curr_optimizer.step()

            running_loss += loss.item()
            if i % print_interval == print_interval - 1:
                print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / print_interval))
                running_loss = 0.0
        # once per epoch
        train_accu.append(get_accuracy(curr_net, trainloader))
        test_accu.append(get_accuracy(curr_net, testloader))
                
    # graph test / train accuracy
    plt.title(graph_title + ' ' + 'Accuracy as a function of iteration')
    plt.xlabel('Iteration / Epoch')
    plt.ylabel('Accuracy')
    plt.plot(range(1, num_epoch + 1), train_accu, label = 'train')
    plt.plot(range(1, num_epoch + 1), test_accu, label = 'test')
    plt.legend()
    plt.savefig(graph_title + '.png')
    
    return max(train_accu), max(test_accu)

In [3]:
# Load data
image_folder = "output/pizza_urlc_10000/"
params = {
    'batch_size': 16,
    'shuffle': True,
    'num_workers': 8
}

# Load partition
train_partition, test_partition = custom_data.get_train_test_partition(image_folder)
trainset = custom_data.PizzaDatabase(image_folder, train_partition)
testset = custom_data.PizzaDatabase(image_folder, test_partition)

# Load DataLoader
trainloader = torch.utils.data.DataLoader(trainset, **params)
testloader = torch.utils.data.DataLoader(testset, **params)


## Simpliest Model

In [4]:
class FC_No_Hidden(nn.Module):
    def __init__(self):
        super(FC_No_Hidden, self).__init__()
        self.fc1 = nn.Linear(800*800*3, 1)

    def forward(self, x):
        # first transform to long vector then output through linear
        x = x.view(-1, 800*800*3)
        x = self.fc1(x)
        # squeeze the dimension
        x = x.view(x.size(0), -1)
        return x

In [5]:
part_a_net = FC_No_Hidden()
part_a_net.to(device)
part_a_opt = optim.SGD(part_a_net.parameters(), lr=0.00001, momentum=0)
train(part_a_net, part_a_opt, graph_title='simpliest_model', num_epoch=30, print_interval=100)

HI
Epoch 0
iter 0 loss 0.3730698823928833
iter 1 loss 4.005115509033203
iter 2 loss 61.52337646484375
iter 3 loss 888.7198486328125
iter 4 loss 13382.470703125
iter 5 loss 224991.359375
iter 6 loss 4239744.0
iter 7 loss 74775160.0
iter 8 loss 1077042688.0
iter 9 loss 12655656960.0
iter 10 loss 164620009472.0
iter 11 loss 2449315987456.0
iter 12 loss 41244558360576.0
iter 13 loss 436946786582528.0
iter 14 loss 4855126594420736.0


KeyboardInterrupt: 