In [30]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import os
import pandas as pd
import matplotlib.pyplot as plt
from torchvision.io import read_image

class Birds270Dataset(Dataset):
    def make_labels(self, csv_table):
        bird_str_labels = csv_table["labels"].unique()
        self.labels_str_to_int = {label:i  for i, label in enumerate(bird_str_labels)}
        self.labels_int_to_str = {i:label  for i, label in enumerate(bird_str_labels)}
        
    def __init__(self, dataset_dir, set_type="train", transform=None, selected_birds=None):
        csv_table = pd.read_csv(os.path.join(dataset_dir, "birds.csv"))
        if selected_birds != None:
            csv_table = csv_table[csv_table["labels"].isin(selected_birds)]
        self.img_data = csv_table[csv_table["data set"]==set_type]
        self.make_labels(self.img_data)
        self.dataset_dir = dataset_dir
        self.transform = transform

    def __len__(self):
        return len(self.img_data)

    def __getitem__(self, idx):
        filepaths_col = self.img_data.columns.get_loc("filepaths")
        labels_col = self.img_data.columns.get_loc("labels")
        img_path = os.path.join(self.dataset_dir, self.img_data.iat[idx, filepaths_col])
        image = read_image(img_path).float()
        label = self.img_data.iat[idx, labels_col]
        if self.transform:
            image = self.transform(image)
        int_label = self.labels_str_to_int[label]
        return image, int_label
    



In [41]:
def make2d(int_or_tuple):
    if(type(int_or_tuple) is int):
        return (int_or_tuple,int_or_tuple)
    else:
        return int_or_tuple

def dim_conv2d(size_in, kernel_size, stride=1, padding=0, dilation=1):
    kernel_size = make2d(kernel_size)
    stride = make2d(stride)
    padding = make2d(padding)
    dilation = make2d(dilation)
    height_out = int((size_in[0] + 2*padding[0] - dilation[0]*(kernel_size[0]-1) - 1)/stride[0] + 1)
    width_out = int((size_in[1] + 2*padding[1] - dilation[1]*(kernel_size[1]-1) - 1)/stride[1] + 1)
    return (height_out, width_out)

def dim_maxpool2d(size_in, kernel_size, stride=None, padding=0, dilation=1):
    if stride == None:
        stride = kernel_size
    # The formula is the same as for dim_conv2d
    return dim_conv2d(size_in, kernel_size, stride=stride, padding=padding, dilation=dilation)



print(conv_out_size(32,32))


(5, 5)


In [46]:
dataset_dir = "../data/birds270"
selected_birds = ["ALBATROSS", "BALD EAGLE", "BARN OWL", "EURASIAN MAGPIE", "FLAMINGO",
                  "MALLARD DUCK", "OSTRICH", "PEACOCK", "PELICAN", "TRUMPTER SWAN"]
tr = transforms.Normalize((127.5, 127.5, 127.5), (127.5, 127.5, 127.5)) # normalizes colors to range [-1,1]
train_dataset = Birds270Dataset(dataset_dir, set_type="train", selected_birds=selected_birds, transform=tr)
test_dataset = Birds270Dataset(dataset_dir, set_type="test", selected_birds=selected_birds, transform=tr)

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import datetime

kernel_size = 5

def conv_out_size(height, width):
    dim = (height, width)
    dim = dim_conv2d(dim, kernel_size)
    dim = dim_maxpool2d(dim, 2, 2)
    dim = dim_conv2d(dim, kernel_size)
    dim = dim_maxpool2d(dim, 2, 2)
    return dim



class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size)
        conv_height, conv_width = conv_out_size(224, 224)
        self.fc1 = nn.Linear(16 * conv_width * conv_height, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()

criterion = nn.CrossEntropyLoss()
#optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
optimizer = optimizer = optim.Adam(net.parameters())

for epoch in range(20):  # loop over the dataset multiple times
    start_time = datetime.datetime.now()
    net.train()
    for i, data in enumerate(train_dataloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    net.eval()
    correct = 0
    total = 0
    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for data in test_dataloader:
            images, labels = data
            # calculate outputs by running images through the network
            outputs = net(images)
            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    time_elapsed = datetime.datetime.now() - start_time
    print(f"Accuracy of the network on epoch {epoch}: {100 * correct / total} %. Epoch time: {time_elapsed}")

print('Finished Training')

Accuracy of the network on epoch 0: 56.0 %. Epoch time: 0:00:10.223239
Accuracy of the network on epoch 1: 70.0 %. Epoch time: 0:00:10.515220
Accuracy of the network on epoch 2: 68.0 %. Epoch time: 0:00:10.682823
Accuracy of the network on epoch 3: 74.0 %. Epoch time: 0:00:10.538838
Accuracy of the network on epoch 4: 80.0 %. Epoch time: 0:00:10.445501
Accuracy of the network on epoch 5: 72.0 %. Epoch time: 0:00:10.427835
Accuracy of the network on epoch 6: 76.0 %. Epoch time: 0:00:10.507292
Accuracy of the network on epoch 7: 80.0 %. Epoch time: 0:00:10.467949
Accuracy of the network on epoch 8: 80.0 %. Epoch time: 0:00:10.439531
Accuracy of the network on epoch 9: 80.0 %. Epoch time: 0:00:10.527882
Accuracy of the network on epoch 10: 80.0 %. Epoch time: 0:00:10.473555
Accuracy of the network on epoch 11: 80.0 %. Epoch time: 0:00:10.554042
Accuracy of the network on epoch 12: 80.0 %. Epoch time: 0:00:10.445685
Accuracy of the network on epoch 13: 80.0 %. Epoch time: 0:00:10.508557
Ac