In [1]:
import torch.nn as nn
import torch.nn.functional as F
import net_size_utils as nsu


class Net(nn.Module):
    """Small CNN: two conv + max-pool stages followed by three linear layers.

    Args:
        kernel_size: Kernel size used by both convolution layers.
        input_size: ``(height, width)`` of the images the network will be fed.
            It fixes the number of input features of the first linear layer.
        num_classes: Number of output logits produced by the last layer.

    Raises:
        ValueError: If ``input_size`` is too small to leave at least one pixel
            after both conv/pool stages.
    """

    def __init__(self, kernel_size=5, input_size=(224, 224), num_classes=10):
        super().__init__()
        self.kernel_size = kernel_size
        height, width = input_size
        self.input_size = (height, width)
        self.conv1 = nn.Conv2d(3, 6, kernel_size)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size)
        # The flattened feature count depends on the image size, so derive it
        # instead of hard-coding it.
        conv_height, conv_width = self.conv_out_size(height, width)
        if conv_height <= 0 or conv_width <= 0:
            raise ValueError(
                f"input_size {self.input_size} is too small for "
                f"kernel_size={kernel_size}: feature map would be "
                f"{conv_height}x{conv_width}"
            )
        self.fc1 = nn.Linear(16 * conv_width * conv_height, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Run a batch of images through the network and return the raw logits."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten all dimensions except batch. The tensor method is used so this
        # cell does not depend on a `torch` import made in a later cell.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def conv_out_size(self, height, width):
        """Return the feature-map size after both conv/pool stages.

        Applies the ``net_size_utils`` helpers in the same order as the layers
        in ``forward`` (conv, pool, conv, pool) and returns whatever the last
        helper returns; ``__init__`` unpacks it as ``(height, width)``.
        """
        dim = (height, width)
        dim = nsu.dim_conv2d(dim, self.kernel_size)
        dim = nsu.dim_maxpool2d(dim, 2, 2)
        dim = nsu.dim_conv2d(dim, self.kernel_size)
        dim = nsu.dim_maxpool2d(dim, 2, 2)
        return dim


In [5]:
from birds_dataset import Birds270Dataset

import net_training
import torch
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

# Seed torch's global RNG so that weight initialisation and the training
# DataLoader's shuffling order repeat from run to run.
SEED = 0
torch.manual_seed(SEED)

dataset_dir = "../data/birds270"
selected_birds = ["ALBATROSS", "BALD EAGLE", "BARN OWL", "EURASIAN MAGPIE", "FLAMINGO",
                  "MALLARD DUCK", "OSTRICH", "PEACOCK", "PELICAN", "TRUMPTER SWAN"]

# Computes (v - 127.5) / 127.5 per channel, i.e. normalizes colors to range [-1,1]
# (assumes the dataset yields 0-255 pixel values -- TODO confirm in Birds270Dataset).
normalize = transforms.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5))
# Built once and shared so train and test preprocessing cannot drift apart.
train_transform = transforms.Compose([
    normalize
])
test_transform = normalize
train_dataset = Birds270Dataset(dataset_dir, set_type="train", selected_birds=selected_birds, transform=train_transform)
test_dataset = Birds270Dataset(dataset_dir, set_type=["test","valid"], selected_birds=selected_birds, transform=test_transform)

label_set = test_dataset.get_label_set()

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# Evaluation does not need a random order; keep it fixed so runs are comparable.
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

net = Net()
# NOTE(review): this optimizer is never passed to train_and_evaluate below, so
# the lr=0.001 Adam configured here may not be the one used -- confirm against
# net_training.train_and_evaluate's signature.
optimizer = optim.Adam(net.parameters(), lr=0.001)
net_training.train_and_evaluate(net, train_dataloader, test_dataloader, label_set, epochs=50, print_results=True)


Epoch 0:
	train loss: 1.8189263938979654
	validation loss: 1.2686478662490845, validation accuracy: 56.00000000000001%
	Elapsed time: 0:00:11.826597
Epoch 1:
	train loss: 1.167449516909463
	validation loss: 0.8719316911697388, validation accuracy: 66.0%
	Elapsed time: 0:00:11.809948
Epoch 2:
	train loss: 0.7949531539199278
	validation loss: 0.774615193605423, validation accuracy: 74.0%
	Elapsed time: 0:00:11.486090
Epoch 3:
	train loss: 0.5307092223321918
	validation loss: 0.7790631699562073, validation accuracy: 73.0%
	Elapsed time: 0:00:11.926077
Epoch 4:
	train loss: 0.3450522891990566
	validation loss: 0.8453554654121399, validation accuracy: 80.0%
	Elapsed time: 0:00:11.012202
Epoch 5:
	train loss: 0.1657464779002677
	validation loss: 1.011232476234436, validation accuracy: 73.0%
	Elapsed time: 0:00:11.123501
Epoch 6:
	train loss: 0.07895599642496924
	validation loss: 0.9581988191604615, validation accuracy: 80.0%
	Elapsed time: 0:00:12.360266
Epoch 7:
	train loss: 0.0416811102718

{'correct': 79,
 'total': 100,
 'accuracy': 0.79,
 'loss': 1.157432942390442,
 'correct_labels': {'ALBATROSS': 8,
  'BALD EAGLE': 7,
  'BARN OWL': 8,
  'EURASIAN MAGPIE': 8,
  'FLAMINGO': 10,
  'MALLARD DUCK': 7,
  'OSTRICH': 9,
  'PEACOCK': 10,
  'PELICAN': 4,
  'TRUMPTER SWAN': 8},
 'total_labels': {'ALBATROSS': 10,
  'BALD EAGLE': 10,
  'BARN OWL': 10,
  'EURASIAN MAGPIE': 10,
  'FLAMINGO': 10,
  'MALLARD DUCK': 10,
  'OSTRICH': 10,
  'PEACOCK': 10,
  'PELICAN': 10,
  'TRUMPTER SWAN': 10},
 'label_accuracy': {'ALBATROSS': 0.8,
  'BALD EAGLE': 0.7,
  'BARN OWL': 0.8,
  'EURASIAN MAGPIE': 0.8,
  'FLAMINGO': 1.0,
  'MALLARD DUCK': 0.7,
  'OSTRICH': 0.9,
  'PEACOCK': 1.0,
  'PELICAN': 0.4,
  'TRUMPTER SWAN': 0.8}}