In [1]:
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze all parameters of ``model`` when ``feature_extracting`` is truthy.

    When ``feature_extracting`` is falsy the model is left untouched. Returns None.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad_(False)
def get_params_to_update(net, feature_extract):
    """Print the names of trainable parameters and return what the optimizer should update.

    Args:
        net: module whose ``named_parameters()`` are inspected.
        feature_extract: if truthy, return only the parameters with ``requires_grad`` set,
            as a list. Otherwise return ``net.parameters()`` (every parameter).

    Side effect: prints "Params to learn:" followed by one tab-indented line per
    parameter that has ``requires_grad`` set, in either mode.
    """
    print("Params to learn:")
    trainable = []
    for param_name, tensor in net.named_parameters():
        if not tensor.requires_grad:
            continue
        print("\t", param_name)
        trainable.append(tensor)
    return trainable if feature_extract else net.parameters()

def make_alexnet(out_features, feature_extract = False, lr=0.0001):
    """Build a pretrained AlexNet with a fresh ``out_features``-way output layer, plus its Adam optimizer.

    Args:
        out_features: number of target classes produced by the new output layer.
        feature_extract: if True, freeze the pretrained weights so only the new layer trains.
        lr: Adam learning rate (the default is the previously hard-coded value).

    Returns:
        ``(net, optimizer)``.
    """
    net = models.alexnet(pretrained=True)
    # Freeze first, then replace the head: the newly created Linear keeps requires_grad=True.
    set_parameter_requires_grad(net, feature_extract)
    num_ftrs = net.classifier[6].in_features
    net.classifier[6] = nn.Linear(num_ftrs, out_features)
    params_to_update = get_params_to_update(net, feature_extract)
    optimizer = optim.Adam(params_to_update, lr=lr)
    return net, optimizer

def make_vgg16(out_features, feature_extract = False, lr=0.0001, verbose=False):
    """Build a pretrained VGG16 (with batch norm) with a fresh ``out_features``-way output layer, plus its Adam optimizer.

    Args:
        out_features: number of target classes produced by the new output layer.
        feature_extract: if True, freeze the pretrained weights so only the new layer trains.
        lr: Adam learning rate (the default is the previously hard-coded value).
        verbose: if True, print the full network architecture. This was previously always
            printed, as a leftover debug dump that flooded the notebook output.

    Returns:
        ``(net, optimizer)``.
    """
    net = models.vgg16_bn(pretrained=True)
    if verbose:
        print(net)
    # Freeze first, then replace the head: the newly created Linear keeps requires_grad=True.
    set_parameter_requires_grad(net, feature_extract)
    num_ftrs = net.classifier[6].in_features
    net.classifier[6] = nn.Linear(num_ftrs, out_features)
    params_to_update = get_params_to_update(net, feature_extract)
    optimizer = optim.Adam(params_to_update, lr=lr)
    return net, optimizer

def make_mobilenet_v3(out_features, feature_extract = False, lr=0.0001):
    """Build a pretrained MobileNetV3-Small with a fresh ``out_features``-way output layer, plus its Adam optimizer.

    Args:
        out_features: number of target classes produced by the new output layer.
        feature_extract: if True, freeze the pretrained weights so only the new layer trains.
        lr: Adam learning rate (the default is the previously hard-coded value).

    Returns:
        ``(net, optimizer)``.
    """
    net = models.mobilenet_v3_small(pretrained=True)
    # Freeze first, then replace the head: the newly created Linear keeps requires_grad=True.
    set_parameter_requires_grad(net, feature_extract)
    num_ftrs = net.classifier[3].in_features
    net.classifier[3] = nn.Linear(num_ftrs, out_features)
    params_to_update = get_params_to_update(net, feature_extract)
    optimizer = optim.Adam(params_to_update, lr=lr)
    return net, optimizer

def make_inception_v3(out_features, feature_extract = False, lr=0.00002):
    """Build a pretrained Inception v3 with fresh ``out_features``-way heads, plus its Adam optimizer.

    Both the auxiliary head (``AuxLogits.fc``) and the main head (``fc``) are replaced,
    so every classifier output has ``out_features`` logits.

    Args:
        out_features: number of target classes produced by the new heads.
        feature_extract: if True, freeze the pretrained weights so only the new heads train.
        lr: Adam learning rate (the default is the previously hard-coded value, which is
            lower than the other factories' default).

    Returns:
        ``(net, optimizer)``.
    """
    net = models.inception_v3(pretrained=True)
    # Freeze first, then replace the heads: the newly created Linears keep requires_grad=True.
    set_parameter_requires_grad(net, feature_extract)
    num_ftrs = net.AuxLogits.fc.in_features
    net.AuxLogits.fc = nn.Linear(num_ftrs, out_features)
    num_ftrs = net.fc.in_features
    net.fc = nn.Linear(num_ftrs, out_features)
    params_to_update = get_params_to_update(net, feature_extract)
    optimizer = optim.Adam(params_to_update, lr=lr)
    return net, optimizer


In [2]:
import statistics
import torch.utils.data as td
import net_training
from birds_dataset import Birds270Dataset
import torch
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import datasets, models
import pandas as pd

# Splits a dataset randomly into a training set (80% of the samples by default) and a
# validation/test set (the remaining samples), then trains the network on it.
def train_net_random_dataset_split(net, dataset, epochs, optimizer, batch_size, early_stopping, is_inception,
                                   train_fraction=0.8, generator=None):
    """Train ``net`` on a fresh random train/validation split of ``dataset``.

    Args:
        net: model to train.
        dataset: dataset to split; must also provide ``get_label_set()``.
        epochs: maximum number of training epochs.
        optimizer: optimizer bound to ``net``'s trainable parameters.
        batch_size: batch size for both loaders.
        early_stopping: early-stopping object forwarded to ``net_training.train_and_evaluate``.
        is_inception: forwarded to ``net_training.train_and_evaluate``.
        train_fraction: share of samples used for training; the rest is used for validation.
            Defaults to the previously hard-coded 0.8.
        generator: optional ``torch.Generator`` that makes the split reproducible.
            ``None`` keeps the previous behavior (global RNG).

    Returns:
        Whatever ``net_training.train_and_evaluate`` returns.
    """
    train_set_size = int(len(dataset) * train_fraction)
    test_set_size = len(dataset) - train_set_size
    split_kwargs = {} if generator is None else {"generator": generator}
    train_dataset, test_dataset = td.random_split(dataset, [train_set_size, test_set_size], **split_kwargs)
    label_set = dataset.get_label_set()
    train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True, drop_last=True)
    # Evaluation does not depend on sample order, so the validation set is not shuffled.
    test_dataloader = DataLoader(test_dataset, batch_size, shuffle=False)
    device = None  # None: use CUDA if it is available, otherwise the CPU
    # device = "cpu"  # Uncomment if CUDA does not work
    results = net_training.train_and_evaluate(net, train_dataloader, test_dataloader, label_set, early_stopping=early_stopping,
                                              epochs=epochs, optimizer=optimizer, print_results=True, device=device,
                                              is_inception=is_inception)
    return results

# Creates `repeat` networks and trains each one on its own random split of the dataset
# (repeated random sub-sampling, not k-fold: validation sets of different runs can overlap).
# Returns the final validation results of every network.
def cross_validate_net(net_generator, dataset, repeat=5, epochs=20, batch_size=32, is_inception=False, patience=10):
    """Train ``repeat`` freshly built networks on independent random splits and collect their results.

    Args:
        net_generator: zero-argument callable returning a new ``(net, optimizer)`` pair.
        dataset: dataset passed to ``train_net_random_dataset_split`` (re-split on every run).
        repeat: number of networks to train.
        epochs: maximum number of epochs per network.
        batch_size: batch size for the data loaders.
        is_inception: forwarded to the training routine.
        patience: patience for ``net_training.EarlyStoppingByAccuracy``. Defaults to the
            previously hard-coded value.

    Returns:
        list with one result entry per trained network, in training order.
    """
    all_results = []
    for i in range(repeat):
        print(f"Training network {i+1} ...")
        net, optimizer = net_generator()
        early_stopping = net_training.EarlyStoppingByAccuracy(patience=patience)
        results = train_net_random_dataset_split(net, dataset, epochs, optimizer=optimizer,
                                                 early_stopping=early_stopping, batch_size=batch_size,
                                                 is_inception=is_inception)
        all_results.append(results)
        net_training.print_final_results(results)
        # Drop this run's model/optimizer and release cached CUDA memory before building the next network.
        del optimizer, net
        torch.cuda.empty_cache()
    return all_results
    
    
def print_validation_results(cross_validation_results):
    """Print per-run losses/accuracies and their mean and sample standard deviation.

    Args:
        cross_validation_results: sequence of mappings, each with "loss" and "accuracy" keys.

    With a single run the sample standard deviation is undefined; it is printed as ``nan``
    instead of raising ``statistics.StatisticsError``. An empty sequence prints a short notice.
    """
    if not cross_validation_results:
        print("No validation results to summarize.")
        return

    def _sample_std(values):
        # statistics.stdev needs at least two data points.
        return statistics.stdev(values) if len(values) > 1 else float("nan")

    losses = [r["loss"] for r in cross_validation_results]
    accuracies = [r["accuracy"] for r in cross_validation_results]
    print("Losses: ", losses)
    print("Loss:  mean: {:.4f}, std: {:.4f}".format(statistics.mean(losses), _sample_std(losses)))
    print("Accuracies: ", accuracies)
    print("Accuracy:  mean: {:.4f}, std: {:.4f}".format(statistics.mean(accuracies), _sample_std(accuracies)))
    
def results_to_dataframe(cross_validation_results):
    """Flatten per-run result records into one DataFrame row each.

    Nested dictionary keys become dotted column names (``pandas.json_normalize`` default).
    """
    return pd.json_normalize(cross_validation_results)
    


In [3]:
dataset_dir = "../data/birds270"

# Species (class names) to train on. Unusual spellings such as "TRUMPTER SWAN",
# "RED WISKERED BULBUL" and the double space in "MIKADO  PHEASANT" are kept as-is;
# presumably they must match the dataset's own class names — TODO confirm against Birds270Dataset.
# "RAINBOW LORIKEET" used to be listed twice. That made len(selected_birds) == 50 for only
# 49 distinct species, and may have loaded that species' images twice (leaking copies
# across the train/validation split). The duplicate is removed.
# TODO: add the intended 50th species if 50 classes were the goal.
selected_birds = [
    "ALBATROSS",
    "BALD EAGLE",
    "BARN OWL",
    "EURASIAN MAGPIE",
    "FLAMINGO",
    "MALLARD DUCK",
    "OSTRICH",
    "PEACOCK",
    "PELICAN",
    "TRUMPTER SWAN",
    "MASKED BOOBY",
    "EURASIAN GOLDEN ORIOLE",
    "MIKADO  PHEASANT",
    "HOUSE FINCH",
    "ROSY FACED LOVEBIRD",
    "EASTERN BLUEBIRD",
    "GREY PLOVER",
    "INDIAN BUSTARD",
    "CUBAN TODY",
    "WATTLED CURASSOW",
    "BLUE HERON",
    "RED WISKERED BULBUL",
    "RUBY THROATED HUMMINGBIRD",
    "RED HEADED WOODPECKER",
    "NORTHERN JACANA",
    "GLOSSY IBIS",
    "ANHINGA",
    "GOLDEN CHLOROPHONIA",
    "KING VULTURE",
    "TURQUOISE MOTMOT",
    "KAKAPO",
    "ELEGANT TROGON",
    "WHITE TAILED TROPIC",
    "NICOBAR PIGEON",
    "MYNA",
    "SAND MARTIN",
    "BARRED PUFFBIRD",
    "UMBRELLA BIRD",
    "CAPUCHINBIRD",
    "INDIGO BUNTING",
    "RAINBOW LORIKEET",
    "BIRD OF PARADISE",
    "HOOPOES",
    "WILSONS BIRD OF PARADISE",
    "GUINEAFOWL",
    "JAVA SPARROW",
    "INDIAN PITTA",
    "ROYAL FLYCATCHER",
    "CALIFORNIA CONDOR",
]
# len(selected_birds) is used to size the network's output layer, so duplicates must not sneak back in.
assert len(set(selected_birds)) == len(selected_birds), "selected_birds contains duplicate species"

# Pretrained torchvision backbones expect ImageNet-standardized inputs. Pixels are divided
# by 255 first (assumes the dataset yields float tensors in [0, 255] — TODO confirm in
# Birds270Dataset), then standardized with the ImageNet channel statistics below.
# An earlier variant instead mapped pixels to [-1, 1] with Normalize((127.5,)*3, (127.5,)*3).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Normalize((0, 0, 0), (255, 255, 255)),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)

# Same normalization, but resized to 299x299 first for Inception v3.
transform_inception = transforms.Compose(
    [
        transforms.Resize([299, 299]),
        transforms.Normalize((0, 0, 0), (255, 255, 255)),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)

# This run trains Inception v3, so the dataset uses the 299x299 resizing transform.
# For the other architectures, build it with the plain transform instead:
#   dataset = Birds270Dataset(dataset_dir, selected_birds=selected_birds, transform=transform)
dataset = Birds270Dataset(
    dataset_dir,
    selected_birds=selected_birds,
    transform=transform_inception,
)

def net_generator():
    """Build a fresh pretrained Inception v3 (fully fine-tuned) with one output per entry in selected_birds."""
    num_classes = len(selected_birds)
    return make_inception_v3(out_features=num_classes, feature_extract=False)

# Five independently trained networks, up to 100 epochs each (early stopping may end runs sooner).
results = cross_validate_net(
    net_generator,
    dataset,
    repeat=5,
    epochs=100,
    batch_size=64,
    is_inception=True,
)
print_validation_results(results)

Training network 1 ...
Params to learn:
	 Conv2d_1a_3x3.conv.weight
	 Conv2d_1a_3x3.bn.weight
	 Conv2d_1a_3x3.bn.bias
	 Conv2d_2a_3x3.conv.weight
	 Conv2d_2a_3x3.bn.weight
	 Conv2d_2a_3x3.bn.bias
	 Conv2d_2b_3x3.conv.weight
	 Conv2d_2b_3x3.bn.weight
	 Conv2d_2b_3x3.bn.bias
	 Conv2d_3b_1x1.conv.weight
	 Conv2d_3b_1x1.bn.weight
	 Conv2d_3b_1x1.bn.bias
	 Conv2d_4a_3x3.conv.weight
	 Conv2d_4a_3x3.bn.weight
	 Conv2d_4a_3x3.bn.bias
	 Mixed_5b.branch1x1.conv.weight
	 Mixed_5b.branch1x1.bn.weight
	 Mixed_5b.branch1x1.bn.bias
	 Mixed_5b.branch5x5_1.conv.weight
	 Mixed_5b.branch5x5_1.bn.weight
	 Mixed_5b.branch5x5_1.bn.bias
	 Mixed_5b.branch5x5_2.conv.weight
	 Mixed_5b.branch5x5_2.bn.weight
	 Mixed_5b.branch5x5_2.bn.bias
	 Mixed_5b.branch3x3dbl_1.conv.weight
	 Mixed_5b.branch3x3dbl_1.bn.weight
	 Mixed_5b.branch3x3dbl_1.bn.bias
	 Mixed_5b.branch3x3dbl_2.conv.weight
	 Mixed_5b.branch3x3dbl_2.bn.weight
	 Mixed_5b.branch3x3dbl_2.bn.bias
	 Mixed_5b.branch3x3dbl_3.conv.weight
	 Mixed_5b.branch3x3dbl_3

Epoch 0:
	train loss: 4.697661951091657
	validation loss: 2.7425551798777583, validation accuracy: 83.61316319677637%
	Elapsed time: 0:22:17.447268
Epoch 1:
	train loss: 2.846172688409058
	validation loss: 1.4814245029133066, validation accuracy: 92.00805910006716%
	Elapsed time: 0:22:31.477252
Epoch 2:
	train loss: 1.5326712819962427
	validation loss: 0.7472353071234545, validation accuracy: 95.03022162525184%
	Elapsed time: 0:23:15.114644
Epoch 3:
	train loss: 0.8242733104878186
	validation loss: 0.4227543133389157, validation accuracy: 96.91067830758898%
	Elapsed time: 0:22:46.908485
Epoch 4:
	train loss: 0.49325548354946186
	validation loss: 0.2775372736020886, validation accuracy: 97.58226997985226%
	Elapsed time: 0:24:46.320653
Epoch 5:
	train loss: 0.3135494422496076
	validation loss: 0.20287913529864088, validation accuracy: 97.85090664875756%
	Elapsed time: 0:23:02.468499
Epoch 6:
	train loss: 0.21444970868964686
	validation loss: 0.1656219836767285, validation accuracy: 97.71

Epoch 0:
	train loss: 4.7640982726712435
	validation loss: 2.7794279683909213, validation accuracy: 85.35930154466085%
	Elapsed time: 0:22:06.850555
Epoch 1:
	train loss: 2.9289808178797063
	validation loss: 1.490655139839193, validation accuracy: 93.82135661517798%
	Elapsed time: 0:22:06.327194
Epoch 2:
	train loss: 1.5680231805330722
	validation loss: 0.725721160063734, validation accuracy: 95.76897246474144%
	Elapsed time: 0:22:08.565222
Epoch 3:
	train loss: 0.8348348006548529
	validation loss: 0.38796190249991946, validation accuracy: 97.1793149764943%
	Elapsed time: 0:22:19.223195
Epoch 4:
	train loss: 0.48540298337598375
	validation loss: 0.25557132977979785, validation accuracy: 97.64942914707858%
	Elapsed time: 0:22:21.219949
Epoch 5:
	train loss: 0.31689869951004346
	validation loss: 0.18799784893634097, validation accuracy: 97.58226997985226%
	Elapsed time: 0:22:10.442384
Epoch 6:
	train loss: 0.21224793565988181
	validation loss: 0.15243067477796604, validation accuracy: 97

Epoch 0:
	train loss: 4.762658584264073
	validation loss: 2.7799013592878232, validation accuracy: 82.87441235728676%
	Elapsed time: 0:22:22.348679
Epoch 1:
	train loss: 2.894596378602602
	validation loss: 1.521173748095777, validation accuracy: 93.619879113499%
	Elapsed time: 0:22:49.640863
Epoch 2:
	train loss: 1.5638938095115842
	validation loss: 0.744969850302223, validation accuracy: 95.97044996642042%
	Elapsed time: 0:22:37.944569
Epoch 3:
	train loss: 0.8393487339157548
	validation loss: 0.42656595407515424, validation accuracy: 97.1793149764943%
	Elapsed time: 0:22:32.607628
Epoch 4:
	train loss: 0.5018242250144852
	validation loss: 0.274971896275888, validation accuracy: 97.51511081262592%
	Elapsed time: 0:22:25.775342
Epoch 5:
	train loss: 0.31779723966019546
	validation loss: 0.19786386034727257, validation accuracy: 97.9852249832102%
	Elapsed time: 0:22:35.602801
Epoch 6:
	train loss: 0.22226414134997832
	validation loss: 0.1551226332603564, validation accuracy: 98.32102081

Epoch 0:
	train loss: 4.733941217720779
	validation loss: 2.7399869991517694, validation accuracy: 82.26997985224983%
	Elapsed time: 0:29:32.716143
Epoch 1:
	train loss: 2.8877136753363586
	validation loss: 1.4730599531797532, validation accuracy: 92.00805910006716%
	Elapsed time: 0:29:38.262633
Epoch 2:
	train loss: 1.5509690192293244
	validation loss: 0.717864675798858, validation accuracy: 95.70181329751512%
	Elapsed time: 0:30:01.502224
Epoch 3:
	train loss: 0.8267808597307735
	validation loss: 0.40326543459562425, validation accuracy: 97.1793149764943%
	Elapsed time: 0:29:29.974159
Epoch 4:
	train loss: 0.4977965000075051
	validation loss: 0.2579684891181956, validation accuracy: 97.9852249832102%
	Elapsed time: 0:29:25.396963
Epoch 5:
	train loss: 0.3132325764686024
	validation loss: 0.189087889933442, validation accuracy: 98.11954331766286%
	Elapsed time: 0:29:24.662079
Epoch 6:
	train loss: 0.2166998248210018
	validation loss: 0.15154391985018834, validation accuracy: 98.388179

Params to learn:
	 Conv2d_1a_3x3.conv.weight
	 Conv2d_1a_3x3.bn.weight
	 Conv2d_1a_3x3.bn.bias
	 Conv2d_2a_3x3.conv.weight
	 Conv2d_2a_3x3.bn.weight
	 Conv2d_2a_3x3.bn.bias
	 Conv2d_2b_3x3.conv.weight
	 Conv2d_2b_3x3.bn.weight
	 Conv2d_2b_3x3.bn.bias
	 Conv2d_3b_1x1.conv.weight
	 Conv2d_3b_1x1.bn.weight
	 Conv2d_3b_1x1.bn.bias
	 Conv2d_4a_3x3.conv.weight
	 Conv2d_4a_3x3.bn.weight
	 Conv2d_4a_3x3.bn.bias
	 Mixed_5b.branch1x1.conv.weight
	 Mixed_5b.branch1x1.bn.weight
	 Mixed_5b.branch1x1.bn.bias
	 Mixed_5b.branch5x5_1.conv.weight
	 Mixed_5b.branch5x5_1.bn.weight
	 Mixed_5b.branch5x5_1.bn.bias
	 Mixed_5b.branch5x5_2.conv.weight
	 Mixed_5b.branch5x5_2.bn.weight
	 Mixed_5b.branch5x5_2.bn.bias
	 Mixed_5b.branch3x3dbl_1.conv.weight
	 Mixed_5b.branch3x3dbl_1.bn.weight
	 Mixed_5b.branch3x3dbl_1.bn.bias
	 Mixed_5b.branch3x3dbl_2.conv.weight
	 Mixed_5b.branch3x3dbl_2.bn.weight
	 Mixed_5b.branch3x3dbl_2.bn.bias
	 Mixed_5b.branch3x3dbl_3.conv.weight
	 Mixed_5b.branch3x3dbl_3.bn.weight
	 Mixed_5b.b

Epoch 0:
	train loss: 4.750112819687651
	validation loss: 2.761827135822611, validation accuracy: 84.7548690396239%
	Elapsed time: 0:29:35.886421
Epoch 1:
	train loss: 2.925191711702608
	validation loss: 1.4513391048977085, validation accuracy: 92.27669576897246%
	Elapsed time: 0:29:23.690211
Epoch 2:
	train loss: 1.5667989927240926
	validation loss: 0.6975156161271141, validation accuracy: 95.29885829415716%
	Elapsed time: 0:29:20.328072
Epoch 3:
	train loss: 0.8454515042874506
	validation loss: 0.3919531124962345, validation accuracy: 96.57488247145736%
	Elapsed time: 0:29:20.181681
Epoch 4:
	train loss: 0.5073361623473435
	validation loss: 0.2612058224195278, validation accuracy: 97.51511081262592%
	Elapsed time: 0:29:22.507216
Epoch 5:
	train loss: 0.3196689867480711
	validation loss: 0.19101824349729385, validation accuracy: 97.7165883143049%
	Elapsed time: 0:29:17.495211
Epoch 6:
	train loss: 0.22151927364867774
	validation loss: 0.1528165080829144, validation accuracy: 97.783747

In [4]:
from pathlib import Path

# NOTE(review): despite the "50_species" name, selected_birds contains only 49 distinct
# species; the filename is kept so existing consumers still find the file.
results_path = Path("../results/inception_v3_pretrained_50_species.csv")
# Create the output folder up front so hours of training are not lost to a missing directory at save time.
results_path.parent.mkdir(parents=True, exist_ok=True)
dataframe = results_to_dataframe(results)
dataframe.to_csv(results_path)