In [11]:
import os
import numpy as np
import xml.etree.ElementTree as ET
from PIL import Image
from sklearn.metrics import average_precision_score
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import torchvision.models as models
import torchvision.transforms as transforms

In [2]:
class VOCDataset(Dataset):
    def __init__(self, directory, split, transforms=None, multi_instance=False, verbose=False):
        self.split = split
        self.verbose = verbose
        self.directory = directory
        self.transforms = transforms
        self.multi_instance = multi_instance
        self.labels_dict = self.get_labels_dict()
        self.label_count, self.data = self._load_all_image_paths_labels(split)
        self.classes_count = self._count_classes()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image = self._load_image(self.data[idx]['image_path'])
        if self.transforms is not None:
            image = self.transforms(image)
        labels = self.data[idx]['labels']
        return (image, labels)

    def get_labels_dict(self):
        return {
            'aeroplane' :    0,
            'bicycle' :      1,
            'bird' :         2,
            'boat' :         3,
            'bottle' :       4,
            'bus' :          5,
            'car' :          6,
            'cat' :          7,
            'chair' :        8,
            'cow' :          9,
            'diningtable' :  10,
            'dog' :          11,
            'horse' :        12,
            'motorbike' :    13,
            'person' :       14,
            'pottedplant' :  15,
            'sheep' :        16,
            'sofa' :         17,
            'train' :        18,
            'tvmonitor' :    19
        }
    
    def plot_classes(self):
        import matplotlib.pyplot as plt
        count_dict = self._count_classes()
        x = count_dict.values()
        y = count_dict.keys()
        plt.figure(figsize=(20,20))
        plt.bar(y,x)
        plt.show()

    def _count_classes(self):
        count_dict = {x: 0 for x in self.labels_dict}
        for pairs in self.data:
            for label_list in pairs['labels']:
                for label in np.unique(label_list):
                    count_dict[label] += 1
        return count_dict

    def _load_image(self, image_path):
        img = Image.open(image_path)
        assert(img.mode == 'RGB')
        return img        

    def _get_images_list(self, split):
        image_paths = []
        image_path_file = os.path.join(self.directory, 'ImageSets/Main', split + '.txt')
        with open(image_path_file) as f:
            for image_path in f.readlines():
                candidate_path = image_path.split(' ')[0].strip('\n')
                image_paths.append(candidate_path)
        return image_paths

    def _get_xml_file_path(self, image_name):
        xml_name = image_name + '.xml'
        xml_path = os.path.join(self.directory, 'Annotations', xml_name)
        return xml_path
    
    def _get_labels_from_xml(self, xml_path):
        labels = []
        tree = ET.parse(xml_path)
        root = tree.getroot()
        for child in root.iter('object'):
            labels.append(child.find('name').text)
        return labels

    def _load_all_image_paths_labels(self, split):
        label_count = 0
        all_image_paths_labels = []
        images_list = self._get_images_list(split)
        xml_path_list = [self._get_xml_file_path(image_path)
                        for image_path in images_list]
        for image_path, xml_path in zip(images_list, xml_path_list):
            image_path = os.path.join(self.directory, 'JPEGImages', image_path + '.jpg')
            assert(image_path not in all_image_paths_labels)
            if self.multi_instance:
                labels = self._get_labels_from_xml(xml_path)
            else:
                labels = list(np.unique(self._get_labels_from_xml(xml_path)))
            label_count += len(labels)
            if self.verbose:
                print("Loading labels of size {} for {}...".format(
                    len(labels), image_path))
            image_path_labels = {'image_path': image_path,
                                 'labels': labels}
            all_image_paths_labels.append(image_path_labels)

        print("SET: {} | TOTAL IMAGES: {}".format(self.split, len(all_image_paths_labels)))
        print("SET: {} | TOTAL LABELS: {}".format(self.split, label_count))
        return label_count, all_image_paths_labels

In [3]:
class VOCBatch:
    def __init__(self, data):
        self.transposed_data = list(zip(*data))
        self.image = torch.stack(self.transposed_data[0], 0)
        self.labels = self.construct_int_labels()

    def construct_int_labels(self):
        remap_dict = VOCDataset.get_labels_dict(None)
        labels = self.transposed_data[1]
        batch_size = self.image.shape[0]
        num_classes = len(remap_dict)
        one_hot_int_labels = torch.zeros((batch_size, num_classes))
        for i in range(len(labels)):
            sample_labels = labels[i]
            one_hot = torch.zeros(num_classes)
            sample_int_labels = []
            for string_label in sample_labels:
                int_label = remap_dict[string_label]
                one_hot[int_label] = 1.
            one_hot_int_labels[i] = one_hot
        return one_hot_int_labels

    def pin_memory(self):
        self.image = self.image.pin_memory()
        self.labels = self.labels.pin_memory()
        return self

def collate_wrapper(batch):
    return VOCBatch(batch)

In [5]:
directory = 'VOC2012/'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


tr = transforms.Compose([transforms.RandomResizedCrop(300),
                             transforms.ToTensor(),
                             transforms.Normalize([0.4589, 0.4355, 0.4032],[0.2239, 0.2186, 0.2206])])

train_dataset = VOCDataset(directory, 'train', transforms=tr, multi_instance=False)
train_loader = DataLoader(train_dataset, batch_size=16, collate_fn=collate_wrapper, shuffle=True, num_workers=4)

val_dataset = VOCDataset(directory, 'trainval', transforms=tr, multi_instance=False)
val_loader = DataLoader(val_dataset, batch_size=16, collate_fn=collate_wrapper, shuffle=True, num_workers=4)

test_dataset = VOCDataset(directory, 'val', transforms=tr, multi_instance=False)
test_loader = DataLoader(test_dataset, batch_size=16, collate_fn=collate_wrapper, shuffle=True, num_workers=4)

SET: train | TOTAL IMAGES: 5717
SET: train | TOTAL LABELS: 8863
SET: trainval | TOTAL IMAGES: 11540
SET: trainval | TOTAL LABELS: 17719
SET: val | TOTAL IMAGES: 5823
SET: val | TOTAL LABELS: 8856


In [6]:
model = models.resnet18(pretrained=True)
model.fc = nn.Linear(512, 20)
model.to(device)
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
loss_fun = nn.BCEWithLogitsLoss()

In [7]:
def train_epoch(model, device, train_loader, optimizer, num_epochs, val_loader, loss_fun):
    train_losses = []
    val_losses = []
    val_accuracies = []

    for epoch in range(1, num_epochs + 1):
        train_loss = train(model, device, train_loader, optimizer, epoch, loss_fun)
        val_loss, predictions, targets = validate(model, device, val_loader, epoch, loss_fun)

#         if (len(val_losses) > 0) and (val_loss < min(val_losses)):
#             torch.save(model.state_dict(), "best_model_B.pt")
#             print("Saving model (epoch {}) with lowest validation loss: {}"
#                   .format(epoch, val_loss))

#         train_losses.append(train_loss)
#         val_losses.append(val_loss)
#         val_accuracies.append(val_accuracy)

    print("Training and validation complete.")
    return train_losses, val_losses, val_accuracies

In [8]:
def train(model, device, train_loader, optimizer, epoch, loss_fun):
    model.train()
    train_losses = []
    for idx, batch in enumerate(train_loader):
        data = batch.image.to(device)
        target = batch.labels.to(device)
        optimizer.zero_grad()
        output = model(data)
#         print(target, torch.sigmoid(output))
        loss = loss_fun(output, target)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        if idx % 5 == 0:
            print('Epoch: {}, Training_Samples: {}/{}, Loss: {}'.format(epoch, idx, len(train_loader), loss.item()))
    train_loss = torch.mean(torch.tensor(train_losses))
    print('\nEpoch: {}'.format(epoch))
    print('Training set: Average loss: {:.4f}'.format(train_loss))
    return train_loss

In [9]:
def validate(model, device, val_loader, epoch, loss_fun):
    model.eval()
    val_loss = 0
    correct = 0
    
    with torch.no_grad():
        for idx, batch in enumerate(val_loader):
            data = batch.image.to(device)
            target = batch.labels.to(device)
            output = model(data)
            
            # compute the batch loss
            batch_loss = loss_fun(output, target).item()
            val_loss += batch_loss
            pred = torch.sigmoid(output)
            
            if idx == 0:
                predictions = pred
                targets = target
            else:
                predictions = torch.cat((predictions, pred))
                targets = torch.cat((targets, target))
            if idx % 5 == 0:
                print('Epoch: {}, Validation_Samples: {}/{}, Loss: {}'.format(epoch, idx, len(val_loader), batch_loss))

    val_loss /= len(val_loader)
    print('\nEpoch: {}'.format(epoch))
    print('Validation set: Average loss: {:.4f}, AP: {:.4f})'.format(val_loss, 
                                                                     average_precision_score(targets.reshape(-1, 20).cpu(), 
                                                                                             predictions.reshape(-1, 20).cpu())))
    
    return val_loss, predictions, targets

In [10]:
start = time.time()
train_losses, val_losses, val_accuracies = train_epoch(model, device, train_loader, optimizer, 5, val_loader, loss_fun)

print(time.time() - start)

BrokenPipeError: [Errno 32] Broken pipe

In [10]:
# validate(model, device, val_loader, 1, loss_fun)