In [1]:
import torch.nn as nn
import torch.nn.functional as F
import net_size_utils as nsu


class Net(nn.Module):
    """Small CNN: two convolution + max-pool stages followed by three linear layers.

    Args:
        kernel_size: Kernel size passed to both ``nn.Conv2d`` layers.
        out_features: Number of outputs of the final linear layer.
        input_size: ``(height, width)`` of the images the network will be fed.
            It determines the input width of the first linear layer, so
            ``forward`` only works for images of this spatial size.
            Defaults to ``(224, 224)``, the size that was previously hard-coded.
    """

    def __init__(self, kernel_size=5, out_features=10, input_size=(224, 224)):
        super().__init__()
        self.kernel_size = kernel_size
        self.conv1 = nn.Conv2d(3, 6, kernel_size)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size)
        # The flattened feature count depends on the spatial size left after
        # both conv/pool stages, so derive it from the expected input size.
        conv_height, conv_width = self.conv_out_size(*input_size)
        self.fc1 = nn.Linear(self.conv2.out_channels * conv_width * conv_height, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, out_features)

    def forward(self, x):
        """Run a batch through the network and return the output of ``fc3``.

        ``x`` must have 3 channels (``conv1`` is built with 3 input channels).
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten all dimensions except batch. The tensor method is used so the
        # class does not rely on a module-level ``torch`` name being imported.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def conv_out_size(self, height, width):
        """Compute the spatial size left after the conv/pool stages.

        Applies the ``nsu`` size helpers in the same order as the layers in
        ``forward``: conv, pool, conv, pool. The result is whatever
        ``nsu.dim_maxpool2d`` returns; ``__init__`` unpacks it as
        ``(height, width)``.
        """
        dim = (height, width)
        dim = nsu.dim_conv2d(dim, self.kernel_size)
        dim = nsu.dim_maxpool2d(dim, 2, 2)
        dim = nsu.dim_conv2d(dim, self.kernel_size)
        dim = nsu.dim_maxpool2d(dim, 2, 2)
        return dim


In [2]:
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of ``model`` when ``feature_extracting`` is truthy.

    Freezing means setting ``requires_grad = False`` on each object yielded by
    ``model.parameters()``. When ``feature_extracting`` is falsy the model is
    left untouched. Returns ``None``.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = False

def get_params_to_update(net, feature_extract):
    """Print the trainable parameter names and return the parameters to optimise.

    Every parameter from ``net.named_parameters()`` whose ``requires_grad`` is
    set has its name printed under a "Params to learn:" header.

    Returns:
        A list of those trainable parameters when ``feature_extract`` is
        truthy; otherwise the result of ``net.parameters()`` unchanged.
    """
    print("Params to learn:")
    trainable = []
    for param_name, tensor in net.named_parameters():
        if tensor.requires_grad:
            trainable.append(tensor)
            print("\t", param_name)
    return trainable if feature_extract else net.parameters()

def make_alexnet(out_features, feature_extract = False, lr=0.0001):
    """Build a pretrained AlexNet with a fresh output layer and an Adam optimizer.

    Args:
        out_features: Output size of the replacement final classifier layer.
        feature_extract: When truthy, the pretrained parameters are frozen
            before the new layer is attached, so only the new layer's
            parameters are handed to the optimizer. When falsy, all parameters
            are optimised.
        lr: Learning rate for Adam. Defaults to the previously hard-coded
            value of ``0.0001``.

    Returns:
        Tuple ``(net, optimizer)``.
    """
    net = models.alexnet(pretrained=True)
    # Freeze first: the replacement layer created below is new, so it keeps
    # requires_grad=True even in feature-extraction mode.
    set_parameter_requires_grad(net, feature_extract)
    num_ftrs = net.classifier[6].in_features
    net.classifier[6] = nn.Linear(num_ftrs, out_features)
    params_to_update = get_params_to_update(net, feature_extract)
    optimizer = optim.Adam(params_to_update, lr=lr)
    return net, optimizer

def make_mobilenet_v3(out_features, feature_extract = False, lr=0.0001):
    """Build a pretrained MobileNetV3-Small with a fresh output layer and an Adam optimizer.

    Args:
        out_features: Output size of the replacement final classifier layer.
        feature_extract: When truthy, the pretrained parameters are frozen
            before the new layer is attached, so only the new layer's
            parameters are handed to the optimizer. When falsy, all parameters
            are optimised.
        lr: Learning rate for Adam. Defaults to the previously hard-coded
            value of ``0.0001``.

    Returns:
        Tuple ``(net, optimizer)``.
    """
    net = models.mobilenet_v3_small(pretrained=True)
    # Freeze first: the replacement layer created below is new, so it keeps
    # requires_grad=True even in feature-extraction mode.
    set_parameter_requires_grad(net, feature_extract)
    num_ftrs = net.classifier[3].in_features
    net.classifier[3] = nn.Linear(num_ftrs, out_features)
    params_to_update = get_params_to_update(net, feature_extract)
    optimizer = optim.Adam(params_to_update, lr=lr)
    return net, optimizer


In [None]:
from birds_dataset import Birds270Dataset

import net_training
import torch
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import datasets, models

# --- Configuration -----------------------------------------------------------
SEED = 42          # fixes shuffling order and the new layer's random init
BATCH_SIZE = 32
EPOCHS = 100

# Seed torch's global RNG so DataLoader shuffling and the initialisation of the
# replacement classifier layer are repeatable between runs.
torch.manual_seed(SEED)

dataset_dir = "../data/birds270"
selected_birds = ["ALBATROSS", "BALD EAGLE", "BARN OWL", "EURASIAN MAGPIE", "FLAMINGO",
                  "MALLARD DUCK", "OSTRICH", "PEACOCK", "PELICAN", "TRUMPTER SWAN"]

# --- Data --------------------------------------------------------------------
# (x - 127.5) / 127.5 maps values in [0, 255] onto [-1, 1].
# NOTE(review): assumes Birds270Dataset yields float tensors in [0, 255] - confirm.
train_transform= transforms.Compose([
    transforms.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5))
])
test_transform = transforms.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5)) # normalizes colors to range [-1,1]
train_dataset = Birds270Dataset(dataset_dir, set_type="train", selected_birds=selected_birds, transform=train_transform)
# The evaluation set is requested with both the "test" and "valid" set types.
test_dataset = Birds270Dataset(dataset_dir, set_type=["test","valid"], selected_birds=selected_birds, transform=test_transform)

label_set = test_dataset.get_label_set()

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# NOTE(review): shuffling the evaluation loader is normally unnecessary; kept
# as-is because how train_and_evaluate consumes it is not visible here.
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)


# --- Model -------------------------------------------------------------------
# Alternative: train the Net class defined earlier in this notebook instead.
#net = Net()
#optimizer = optim.Adam(net.parameters(), lr=0.001)

net, optimizer = make_mobilenet_v3(out_features=len(selected_birds), feature_extract = False)

# --- Training ----------------------------------------------------------------
net_training.train_and_evaluate(net, train_dataloader, test_dataloader,
                                label_set, epochs=EPOCHS, optimizer=optimizer, print_results=True)


Params to learn:
	 features.0.0.weight
	 features.0.1.weight
	 features.0.1.bias
	 features.1.block.0.0.weight
	 features.1.block.0.1.weight
	 features.1.block.0.1.bias
	 features.1.block.1.fc1.weight
	 features.1.block.1.fc1.bias
	 features.1.block.1.fc2.weight
	 features.1.block.1.fc2.bias
	 features.1.block.2.0.weight
	 features.1.block.2.1.weight
	 features.1.block.2.1.bias
	 features.2.block.0.0.weight
	 features.2.block.0.1.weight
	 features.2.block.0.1.bias
	 features.2.block.1.0.weight
	 features.2.block.1.1.weight
	 features.2.block.1.1.bias
	 features.2.block.2.0.weight
	 features.2.block.2.1.weight
	 features.2.block.2.1.bias
	 features.3.block.0.0.weight
	 features.3.block.0.1.weight
	 features.3.block.0.1.bias
	 features.3.block.1.0.weight
	 features.3.block.1.1.weight
	 features.3.block.1.1.bias
	 features.3.block.2.0.weight
	 features.3.block.2.1.weight
	 features.3.block.2.1.bias
	 features.4.block.0.0.weight
	 features.4.block.0.1.weight
	 features.4.block.0.1.bias
	 f