In [1]:
from torch.utils.data import Dataset
import torchvision
import torch
import os
from PIL import Image
from torchvision import transforms
import torch.nn as nn
import json
import pickle
from torch.optim import Adam
import argparse
import copy
import numpy as np
import time
from tqdm import tqdm

<h3> Get mappings </h3>

In [2]:
on2onlabel = dict()
for idx, name in enumerate(os.listdir('/storage/jalverio/objectnet-oct-24-d123')):
    on2onlabel[name] = idx
onlabel2name = {v: k for k, v in on2onlabel.items()}

In [3]:
with open('/storage/jalverio/resnet/objectnet2torch.pkl', 'rb') as f:
    objectnet2torch = pickle.load(f)
torch2objectnet = dict()
for objectnet_name, label_list in objectnet2torch.items():
    for label in label_list:
        torch2objectnet[label] = objectnet_name

In [4]:
with open('/storage/jalverio/resnet/dirname_to_objectnet_name.json') as f:
    dirname_to_classname = json.load(f)

with open('/storage/jalverio/resnet/objectnet_subset_to_objectnet_id') as f:
    oncompressed2onlabel = eval(f.read())
    onlabel2oncompressed = {int(v):int(k) for k,v in oncompressed2onlabel.items()}


<h3> Build Dataset </h3>

In [5]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
transformations = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

In [6]:
# 113, 313

In [7]:
# class Objectnet(Dataset):
#     """Dataset wrapping images and target labels for Kaggle - Planet Amazon from Space competition.
#     Arguments:
#         A CSV file path
#         Path to image folder
#         Extension of images
#         PIL transforms
#     """

#     def __init__(self, root, transform, objectnet2imagenet, imagenet2torch):
#         self.root = root
#         self.transform = transform
#         self.images = []
#         success_counter = 0
#         for dirname in os.listdir(root):
#             class_name = dirname.replace('/', '_').replace('-', '_').replace(' ', '_').lower().replace("'", '')
#             if class_name not in objectnet2imagenet:
#                 continue
#             success_counter += 1
#             labels = objectnet2imagenet[class_name]
#             new_labels = []
#             for label in labels:
#                 new_labels.append(int(imagenet2torch[label - 1]))

#             images = os.listdir(os.path.join(root, dirname))
#             for image_name in images:
#                 path = os.path.join(root, dirname, image_name)
#                 self.images.append((path, new_labels))

#     def __getitem__(self, index):
#         full_path, labels = self.images[index]
#         image = Image.open(full_path)
#         image = image.convert('RGB')
#         image = self.transform(image)
#         return image, labels

#     def n_per_class(self, num_examples, valid_classes):
#         quotas = dict()
#         for label in valid_classes:
#             quotas[label] = num_examples
#         remaining_images = []
#         for path, label in self.images:
#             if label in valid_classes:
#                 if quotas[label] < 0:
#                     quotas[label] -= 1
#                     remaining_images.append((path, label))
#         self.images = remaining_images
#         print('Purged some examples. %s classes and %s examples remaining.' % (len(valid_classes), len(self.images)))

#     def __len__(self):
#         return len(self.images)


In [8]:
class Objectnet(Dataset):
    def __init__(self, root, transform, onlabel2oncompressed, num_examples, overlap, test_images=None):
        self.transform = transform
        if test_images is None:
            self.classes_in_dataset = set()
            images_dict = dict()
            for dirname in os.listdir(root):
                label = on2onlabel[dirname]
                if overlap:
                    if label not in onlabel2oncompressed:
                        continue
                    label = onlabel2oncompressed[label]
                    class_name = dirname_to_classname[dirname]
                images = os.listdir(os.path.join(root, dirname))
                if len(images) < num_examples:
                    continue
                for image_name in images:
                    path = os.path.join(root, dirname, image_name)
                    if label not in images_dict:
                        images_dict[label] = []
                    images_dict[label].append(path)
                self.classes_in_dataset.add(dirname)
            self.images = []
            self.test_images = []
            for label in images_dict.keys():
                idxs_to_choose_from = list(range(len(images_dict[label])))
                chosen_idxs = np.random.choice(idxs_to_choose_from, num_examples, replace=False)
                class_training_idxs = set(chosen_idxs.tolist())
                class_training_images = [images_dict[label][idx] for idx in class_training_idxs]
                if len(class_training_images) != 8:
                    import pdb; pdb.set_trace()
                test_training_idxs = [x for x in range(len(images_dict[label])) if x not in class_training_idxs]
                class_test_images = [images_dict[label][idx] for idx in test_training_idxs]
                [self.images.append((image, label)) for image in class_training_images]
                [self.test_images.append((image, label)) for image in class_test_images]
            print('Dataset has %s classes, %s training examples and %s test examples' % (len(self.classes_in_dataset), len(self.images), len(self.test_images)))
        else:
            self.images = test_images

#     def n_per_class(self, num_examples, test):
#         valid_classes = set()
#         [valid_classes.add(label) for _, label in self.images]

#         quotas = dict()
#         for label in valid_classes:
#             quotas[label] = 0
#         test_images = []
#         remaining_images = []
#         for path, objectnet_label in self.images:
#             if not test:
#                 if quotas[objectnet_label] < num_examples:
#                     quotas[objectnet_label] += 1
#                     remaining_images.append((path, objectnet_label))
#                 else:
#                     test_images.append((path, objectnet_label))
#             else:
#                 if quotas[objectnet_label] < num_examples * 2:
#                     if quotas[objectnet_label] >= num_examples:
#                         remaining_images.append((path, objectnet_label))
#                     quotas[objectnet_label] += 1
#                 else:
#                     test_images.append((path, objectnet_label))
#         self.images = remaining_images
#         self.test_images = test_images
#         print('Removed some examples. %s classes and %s examples remaining.' % (len(valid_classes), len(self.images)))

    def __getitem__(self, index):
        full_path, labels = self.images[index]
        image = Image.open(full_path)
        image = image.convert('RGB')
        image = self.transform(image)
        return image, labels

    def __len__(self):
        return len(self.images)

<h3> Helper functions </h3>

In [9]:
def evaluate():
    total_top1, total_top5, total_examples = 0, 0, 0
    for batch_counter, (batch, labels) in enumerate(test_loader):
        print(batch_counter / len(test_loader))
        labels = labels.to(DEVICE)
        batch = batch.to(DEVICE)
        with torch.no_grad():
            logits = model(batch)
        top1, top5 = accuracy(logits, labels)
        total_top1 += top1
        total_top5 += top5
        total_examples += batch.shape[0]
    top1_score = total_top1 / total_examples
    top5_score = total_top5 / total_examples
    return top1_score, top5_score

In [10]:
def accuracy(logits, targets):
    _, pred = logits.topk(5, 1, True, True)
    targets = targets.unsqueeze(1)
    targets_repeat = targets.repeat(1, 5)
    assert pred.shape == targets_repeat.shape
    correct = ((pred - targets_repeat) == 0).float()
    top1_score = correct[:, 0].sum()
    top5_score = correct.sum()
    return top1_score.item(), top5_score.item()

<h3> params and model </h3>

In [11]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
WORKERS = 60
BATCH_SIZE = 32
NUM_EXAMPLES = 8
OVERLAP = True

In [12]:
image_dir = '/storage/jalverio/objectnet-oct-24-d123/'
dataset = Objectnet(image_dir, transformations, onlabel2oncompressed, NUM_EXAMPLES, OVERLAP, test_images=None)
dataset_test = Objectnet(image_dir, transformations, onlabel2oncompressed, NUM_EXAMPLES, OVERLAP, test_images=dataset.test_images)
total_classes = len(dataset.classes_in_dataset)

val_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE, shuffle=False,
        num_workers=WORKERS, pin_memory=True)

test_loader = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=512, shuffle=False,
        num_workers=WORKERS, pin_memory=True)

Dataset has 113 classes, 904 training examples and 16990 test examples


In [17]:
model = torchvision.models.resnet152(pretrained=True).eval()
for param in model.parameters():
    param.requires_grad = False
model.fc = nn.Linear(2048, total_classes, bias=True)
model = model.eval().to(DEVICE)

<h3> Training <h3>

In [18]:
criterion = nn.CrossEntropyLoss().to(DEVICE)
optimizer = Adam(model.parameters(), lr=0.0001)
previous_accuracy = 0.
top_score = 0.
total_top1, total_top5, total_examples = 0, 0, 0

In [21]:
for epoch in range(50):
    total_examples = 0
    total_training_top1 = 0
    total_training_top5 = 0
    print('starting epoch %s' % epoch)
    for batch, labels in val_loader:
        labels = labels.to(DEVICE)
        batch = batch.to(DEVICE)
        logits = model(batch)
        top1, top5 = accuracy(logits, labels)
        total_training_top1 += top1
        total_training_top5 += top5
        
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_examples += batch.shape[0]

    training_top1_performance = total_training_top1 / total_examples
    training_top5_performance = total_training_top5 / total_examples
    print('training top1 score: %s' % training_top1_performance)
    print('training top5 score: %s' % training_top5_performance)

starting epoch 0
training top1 score: 0.9170353982300885
training top5 score: 0.995575221238938
starting epoch 1
training top1 score: 0.9192477876106194
training top5 score: 0.995575221238938
starting epoch 2
training top1 score: 0.9214601769911505
training top5 score: 0.995575221238938
starting epoch 3
training top1 score: 0.9280973451327433
training top5 score: 0.995575221238938
starting epoch 4
training top1 score: 0.9314159292035398
training top5 score: 0.995575221238938
starting epoch 5
training top1 score: 0.9325221238938053
training top5 score: 0.9966814159292036
starting epoch 6
training top1 score: 0.9369469026548672
training top5 score: 0.9977876106194691
starting epoch 7
training top1 score: 0.9369469026548672
training top5 score: 0.9977876106194691
starting epoch 8
training top1 score: 0.9402654867256637
training top5 score: 0.9977876106194691
starting epoch 9
training top1 score: 0.9435840707964602
training top5 score: 0.9977876106194691
starting epoch 10
training top1 sco

<h3> Evaluate when done <h3>

In [22]:
top1, top5 = evaluate()
print('top1 score', top1)
print('top5 score', top5)

0.0
0.029411764705882353
0.058823529411764705
0.08823529411764706
0.11764705882352941
0.14705882352941177
0.17647058823529413
0.20588235294117646
0.23529411764705882
0.2647058823529412
0.29411764705882354
0.3235294117647059
0.35294117647058826
0.38235294117647056
0.4117647058823529
0.4411764705882353
0.47058823529411764
0.5
0.5294117647058824
0.5588235294117647
0.5882352941176471
0.6176470588235294
0.6470588235294118
0.6764705882352942
0.7058823529411765
0.7352941176470589
0.7647058823529411
0.7941176470588235
0.8235294117647058
0.8529411764705882
0.8823529411764706
0.9117647058823529
0.9411764705882353
0.9705882352941176
top1 score 0.3274278987639788
top5 score 0.5958799293702177


In [None]:
# with 0.0001
# top1 score 0.3578546635315194
# top5 score 0.6243421789273318

In [None]:
# with 0.001 mistake
# top1 score 0.36076587168290225
# top5 score 0.6241182398387639

In [None]:
# after refactoring with 0.001 on FULL objectnet
# top1 score 0.20835247442054708
# top5 score 0.4139277510962623

In [None]:
# after refactoring with 0.0001
# top1 score 0.3274278987639788
# top5 score 0.5958799293702177