In [1]:
import torch.nn as nn
import torch.nn.functional as F
import net_size_utils as nsu


class Net(nn.Module):
    """Small CNN: two conv + max-pool stages followed by three linear layers.

    Args:
        kernel_size: Side length of the square kernel used by both convolutions.
        input_size: ``(height, width)`` of the images the network will be fed.
            It is only used to size the first fully connected layer.
        num_classes: Number of logits produced by the final layer.
    """

    def __init__(self, kernel_size=5, input_size=(224, 224), num_classes=10):
        super().__init__()
        self.kernel_size = kernel_size
        self.conv1 = nn.Conv2d(3, 6, kernel_size)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size)
        # fc1 must match the flattened feature map, whose spatial size depends
        # on both the input resolution and the kernel size.
        conv_height, conv_width = self.conv_out_size(*input_size)
        self.fc1 = nn.Linear(16 * conv_width * conv_height, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for a batch ``x``."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten all dimensions except batch. The tensor method is used so
        # this class does not rely on a bare `torch` name being imported by
        # some other notebook cell.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def conv_out_size(self, height, width):
        """Return the spatial size left after the conv/pool stack.

        Applies the ``net_size_utils`` helpers in the same order as the
        layers in ``forward`` (conv, pool, conv, pool) to ``(height, width)``.
        """
        dim = (height, width)
        dim = nsu.dim_conv2d(dim, self.kernel_size)
        dim = nsu.dim_maxpool2d(dim, 2, 2)
        dim = nsu.dim_conv2d(dim, self.kernel_size)
        dim = nsu.dim_maxpool2d(dim, 2, 2)
        return dim


In [2]:
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze all parameters of ``model`` when ``feature_extracting`` is truthy.

    Each parameter yielded by ``model.parameters()`` gets
    ``requires_grad = False``. When ``feature_extracting`` is falsy the model
    is left untouched. Returns ``None``.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = False


In [3]:
from birds_dataset import Birds270Dataset

import net_training
import torch
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import datasets, models

# --- Configuration: everything a reader might want to tune -----------------
SEED = 42
BATCH_SIZE = 32
NUM_EPOCHS = 100
LEARNING_RATE = 0.001

dataset_dir = "../data/birds270"
# Class names are passed to the dataset verbatim, so the spelling
# ("TRUMPTER SWAN") is kept exactly as is.
selected_birds = ["ALBATROSS", "BALD EAGLE", "BARN OWL", "EURASIAN MAGPIE", "FLAMINGO",
                  "MALLARD DUCK", "OSTRICH", "PEACOCK", "PELICAN", "TRUMPTER SWAN"]
# When True the pretrained layers are frozen and only the new head is trained.
feature_extract = False

# Seed before anything random happens (new classifier init, batch shuffling).
torch.manual_seed(SEED)

# --- Data -------------------------------------------------------------------
# Maps colours to [-1, 1], assuming the dataset yields values in [0, 255].
# NOTE(review): torchvision's pretrained models are documented as expecting
# ImageNet mean/std on [0, 1] inputs -- confirm what Birds270Dataset emits.
normalize = transforms.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5))
train_transform = transforms.Compose([normalize])
test_transform = normalize
train_dataset = Birds270Dataset(dataset_dir, set_type="train", selected_birds=selected_birds, transform=train_transform)
test_dataset = Birds270Dataset(dataset_dir, set_type=["test", "valid"], selected_birds=selected_birds, transform=test_transform)

label_set = test_dataset.get_label_set()

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation does not need shuffling; a fixed order keeps it deterministic.
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# --- Model ------------------------------------------------------------------
# To train the small from-scratch CNN defined earlier instead, use
# `net = Net()` with `optim.Adam(net.parameters(), lr=LEARNING_RATE)`.
net = models.alexnet(pretrained=True)
# Freeze (if requested) BEFORE swapping the head, so the freshly created
# final layer keeps requires_grad=True and is always trained.
set_parameter_requires_grad(net, feature_extract)
num_ftrs = net.classifier[6].in_features
net.classifier[6] = nn.Linear(num_ftrs, len(selected_birds))

# --- Optimizer --------------------------------------------------------------
# Collect exactly the trainable parameters; this single loop is correct for
# both fine-tuning (nothing frozen) and feature extraction (only the head).
params_to_update = []
print("Params to learn:")
for name, param in net.named_parameters():
    if param.requires_grad:
        params_to_update.append(param)
        print("\t", name)

optimizer = optim.Adam(params_to_update, lr=LEARNING_RATE)

# --- Train / evaluate -------------------------------------------------------
# Left as the cell's final expression so its return value is displayed.
net_training.train_and_evaluate(net, train_dataloader, test_dataloader,
                                label_set, epochs=NUM_EPOCHS, optimizer=optimizer, print_results=True)


Downloading: "https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth" to C:\Users\Dominik/.cache\torch\hub\checkpoints\alexnet-owt-4df8aa71.pth
10.1%IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

26.6%IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

43.2%IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--Noteb

Params to learn:
	 features.0.weight
	 features.0.bias
	 features.3.weight
	 features.3.bias
	 features.6.weight
	 features.6.bias
	 features.8.weight
	 features.8.bias
	 features.10.weight
	 features.10.bias
	 classifier.1.weight
	 classifier.1.bias
	 classifier.4.weight
	 classifier.4.bias
	 classifier.6.weight
	 classifier.6.bias
Epoch 0:
	train loss: 2.3084616481994495
	validation loss: 2.128449172973633, validation accuracy: 21.0%
	Elapsed time: 0:00:31.976326
Epoch 1:
	train loss: 2.232446912813257
	validation loss: 2.1550396251678468, validation accuracy: 15.0%
	Elapsed time: 0:00:32.340040
Epoch 2:
	train loss: 2.2000913911434616
	validation loss: 2.2667873764038085, validation accuracy: 8.0%
	Elapsed time: 0:00:32.847337
Epoch 3:
	train loss: 2.14344205842278
	validation loss: 1.9138764905929566, validation accuracy: 32.0%
	Elapsed time: 0:00:33.433848
Epoch 4:
	train loss: 1.9841321392508775
	validation loss: 1.7069512271881104, validation accuracy: 38.0%
	Elapsed time: 0:00:

{'correct': 88,
 'total': 100,
 'accuracy': 0.88,
 'loss': 0.7615122032165528,
 'correct_labels': {'ALBATROSS': 9,
  'BALD EAGLE': 10,
  'BARN OWL': 8,
  'EURASIAN MAGPIE': 9,
  'FLAMINGO': 10,
  'MALLARD DUCK': 9,
  'OSTRICH': 9,
  'PEACOCK': 10,
  'PELICAN': 5,
  'TRUMPTER SWAN': 9},
 'total_labels': {'ALBATROSS': 10,
  'BALD EAGLE': 10,
  'BARN OWL': 10,
  'EURASIAN MAGPIE': 10,
  'FLAMINGO': 10,
  'MALLARD DUCK': 10,
  'OSTRICH': 10,
  'PEACOCK': 10,
  'PELICAN': 10,
  'TRUMPTER SWAN': 10},
 'label_accuracy': {'ALBATROSS': 0.9,
  'BALD EAGLE': 1.0,
  'BARN OWL': 0.8,
  'EURASIAN MAGPIE': 0.9,
  'FLAMINGO': 1.0,
  'MALLARD DUCK': 0.9,
  'OSTRICH': 0.9,
  'PEACOCK': 1.0,
  'PELICAN': 0.5,
  'TRUMPTER SWAN': 0.9}}