In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import net_size_utils as nsu


class Net(nn.Module):
    """Small CNN: two (conv -> ReLU -> 2x2 max-pool) stages followed by three
    fully connected layers.

    Args:
        kernel_size: Kernel size used by both convolution layers.
        num_classes: Number of logits produced by ``fc3``. Defaults to 10,
            the value that was previously hard-coded.
        input_size: Spatial size of the input images, either an int (square
            images) or a ``(height, width)`` tuple. It is only used to size
            ``fc1``; inputs of any other size will not match
            ``fc1.in_features``. Defaults to 224.
    """

    def __init__(self, kernel_size=5, num_classes=10, input_size=224):
        super().__init__()
        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        input_height, input_width = input_size

        self.kernel_size = kernel_size
        self.conv1 = nn.Conv2d(3, 6, kernel_size)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size)
        # fc1's input width depends on the feature-map size left after both
        # conv/pool stages, so derive it from the configured input size.
        conv_height, conv_width = self.conv_out_size(input_height, input_width)
        self.fc1 = nn.Linear(16 * conv_width * conv_height, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return raw class logits of shape ``(batch, num_classes)``.

        ``x`` is expected to have 3 channels and the spatial size given as
        ``input_size`` at construction time.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def conv_out_size(self, height, width):
        """Return the spatial size after both conv + max-pool stages.

        Mirrors the layer sequence in ``forward`` (conv1, pool, conv2, pool),
        delegating the per-layer arithmetic to ``net_size_utils``.

        Args:
            height: Input image height in pixels.
            width: Input image width in pixels.

        Returns:
            The dimensions returned by ``nsu.dim_maxpool2d``; ``__init__``
            unpacks them as ``(height, width)``.
        """
        dim = (height, width)
        dim = nsu.dim_conv2d(dim, self.kernel_size)
        dim = nsu.dim_maxpool2d(dim, 2, 2)
        dim = nsu.dim_conv2d(dim, self.kernel_size)
        dim = nsu.dim_maxpool2d(dim, 2, 2)
        return dim

In [2]:
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every existing parameter of ``model`` when ``feature_extracting`` is truthy.

    With the backbone frozen, only layers attached afterwards (such as a newly
    assigned classifier head) will receive gradients. When
    ``feature_extracting`` is falsy the model is left untouched. This function
    never re-enables gradients.

    Args:
        model: A ``torch.nn.Module``.
        feature_extracting: Whether to freeze the model's parameters.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = False

In [3]:
from birds_dataset import Birds270Dataset

import net_training
import torch
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import datasets, models

dataset_dir = "../data/birds270"

# Species to train on. Names must match the dataset's labels exactly, including
# its own spellings ("TRUMPTER SWAN", "RED WISKERED BULBUL", and the double
# space in "MIKADO  PHEASANT"). The evaluation output reports these same
# spellings, so do not "correct" them.
selected_birds = [
    "ALBATROSS",
    "BALD EAGLE",
    "BARN OWL",
    "EURASIAN MAGPIE",
    "FLAMINGO",
    "MALLARD DUCK",
    "OSTRICH",
    "PEACOCK",
    "PELICAN",
    "TRUMPTER SWAN",
    "MASKED BOOBY",
    "EURASIAN GOLDEN ORIOLE",
    "MIKADO  PHEASANT",
    "HOUSE FINCH",
    "ROSY FACED LOVEBIRD",
    "EASTERN BLUEBIRD",
    "GREY PLOVER",
    "INDIAN BUSTARD",
    "CUBAN TODY",
    "WATTLED CURASSOW",
    "BLUE HERON",
    "RED WISKERED BULBUL",
    "RUBY THROATED HUMMINGBIRD",
    "RED HEADED WOODPECKER",
    "NORTHERN JACANA",
    "GLOSSY IBIS",
    "ANHINGA",
    "GOLDEN CHLOROPHONIA",
    "KING VULTURE",
    "TURQUOISE MOTMOT",
    "KAKAPO",
    "ELEGANT TROGON",
    "WHITE TAILED TROPIC",
    "NICOBAR PIGEON",
    "MYNA",
    "SAND MARTIN",
    "BARRED PUFFBIRD",
    "UMBRELLA BIRD",
    "CAPUCHINBIRD",
    "INDIGO BUNTING",
    "RAINBOW LORIKEET",
    "BIRD OF PARADISE",
    "HOOPOES",
    "WILSONS BIRD OF PARADISE",
    "GUINEAFOWL",
    "JAVA SPARROW",
    "INDIAN PITTA",
    "ROYAL FLYCATCHER",
    "CALIFORNIA CONDOR",
]
# The classifier head is sized with len(selected_birds), so a duplicate name
# would add an output unit that no sample can map to.
assert len(selected_birds) == len(set(selected_birds)), "duplicate bird names in selected_birds"

# Maps each channel via (x - 127.5) / 127.5, i.e. [0, 255] -> [-1, 1].
# NOTE(review): this assumes Birds270Dataset yields float tensors with values in
# [0, 255]. If it already scales to [0, 1] (e.g. via ToTensor), these constants
# are wrong. Confirm against birds_dataset.
normalize = transforms.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5))
train_transform = transforms.Compose([normalize])
test_transform = normalize

train_dataset = Birds270Dataset(dataset_dir, set_type="train", selected_birds=selected_birds, transform=train_transform)
# set_type also accepts a list here; presumably the "test" and "valid" splits are
# combined into one evaluation set. Verify in birds_dataset.
test_dataset = Birds270Dataset(dataset_dir, set_type=["test", "valid"], selected_birds=selected_birds, transform=test_transform)

label_set = test_dataset.get_label_set()

BATCH_SIZE = 32
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation gains nothing from shuffling. A fixed order keeps evaluation passes
# comparable, and the test loader no longer draws from torch's global RNG
# between training epochs.
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


# Alternative (kept for reference): the small custom Net from the first cell.
#net = Net()
#optimizer = optim.Adam(net.parameters(), lr=0.001)

SEED = 0
LEARNING_RATE = 1e-4
NUM_EPOCHS = 100

# Seeds torch's global RNG, which initialises the new fc layer below and drives
# the shuffling of the training DataLoader during training.
torch.manual_seed(SEED)

# True: freeze the pretrained backbone and train only the new head.
# False: fine-tune the whole network.
feature_extract = False
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in favour
# of `weights=models.ResNet18_Weights.DEFAULT`. Switch once the installed
# version supports it.
net = models.resnet18(pretrained=True)
set_parameter_requires_grad(net, feature_extract)
# Replace the pretrained classifier with one output per selected bird species.
num_ftrs = net.fc.in_features
net.fc = nn.Linear(num_ftrs, len(selected_birds))

# Hand the optimizer only the parameters that still require gradients. When
# feature_extract is False, this is every parameter of the network.
params_to_update = [param for param in net.parameters() if param.requires_grad]
print("Params to learn:")
for name, param in net.named_parameters():
    if param.requires_grad:
        print("\t", name)

optimizer = optim.Adam(params_to_update, lr=LEARNING_RATE)

net_training.train_and_evaluate(net, train_dataloader, test_dataloader,
                                label_set, epochs=NUM_EPOCHS, optimizer=optimizer, print_results=True)

Params to learn:
	 conv1.weight
	 bn1.weight
	 bn1.bias
	 layer1.0.conv1.weight
	 layer1.0.bn1.weight
	 layer1.0.bn1.bias
	 layer1.0.conv2.weight
	 layer1.0.bn2.weight
	 layer1.0.bn2.bias
	 layer1.1.conv1.weight
	 layer1.1.bn1.weight
	 layer1.1.bn1.bias
	 layer1.1.conv2.weight
	 layer1.1.bn2.weight
	 layer1.1.bn2.bias
	 layer2.0.conv1.weight
	 layer2.0.bn1.weight
	 layer2.0.bn1.bias
	 layer2.0.conv2.weight
	 layer2.0.bn2.weight
	 layer2.0.bn2.bias
	 layer2.0.downsample.0.weight
	 layer2.0.downsample.1.weight
	 layer2.0.downsample.1.bias
	 layer2.1.conv1.weight
	 layer2.1.bn1.weight
	 layer2.1.bn1.bias
	 layer2.1.conv2.weight
	 layer2.1.bn2.weight
	 layer2.1.bn2.bias
	 layer3.0.conv1.weight
	 layer3.0.bn1.weight
	 layer3.0.bn1.bias
	 layer3.0.conv2.weight
	 layer3.0.bn2.weight
	 layer3.0.bn2.bias
	 layer3.0.downsample.0.weight
	 layer3.0.downsample.1.weight
	 layer3.0.downsample.1.bias
	 layer3.1.conv1.weight
	 layer3.1.bn1.weight
	 layer3.1.bn1.bias
	 layer3.1.conv2.weight
	 layer3.1.b

{'correct': 489,
 'total': 490,
 'accuracy': 0.9979591836734694,
 'loss': 0.008608988314220797,
 'correct_labels': {'ALBATROSS': 10,
  'ANHINGA': 10,
  'BALD EAGLE': 10,
  'BARN OWL': 10,
  'BARRED PUFFBIRD': 10,
  'BIRD OF PARADISE': 10,
  'BLUE HERON': 10,
  'CALIFORNIA CONDOR': 10,
  'CAPUCHINBIRD': 10,
  'CUBAN TODY': 10,
  'EASTERN BLUEBIRD': 10,
  'ELEGANT TROGON': 10,
  'EURASIAN GOLDEN ORIOLE': 10,
  'EURASIAN MAGPIE': 10,
  'FLAMINGO': 10,
  'GLOSSY IBIS': 10,
  'GOLDEN CHLOROPHONIA': 10,
  'GREY PLOVER': 10,
  'GUINEAFOWL': 10,
  'HOOPOES': 10,
  'HOUSE FINCH': 10,
  'INDIAN BUSTARD': 10,
  'INDIAN PITTA': 10,
  'INDIGO BUNTING': 10,
  'JAVA SPARROW': 10,
  'KAKAPO': 10,
  'KING VULTURE': 10,
  'MALLARD DUCK': 10,
  'MASKED BOOBY': 9,
  'MIKADO  PHEASANT': 10,
  'MYNA': 10,
  'NICOBAR PIGEON': 10,
  'NORTHERN JACANA': 10,
  'OSTRICH': 10,
  'PEACOCK': 10,
  'PELICAN': 10,
  'RAINBOW LORIKEET': 10,
  'RED HEADED WOODPECKER': 10,
  'RED WISKERED BULBUL': 10,
  'ROSY FACED LOVEB