In [1]:
import argparse
import ast
import copy
import json
import os
import pickle

import numpy as np
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from torch.optim import Adam
from torch.utils.data import Dataset
from torchvision import transforms

In [5]:
# Assign each ObjectNet class directory an integer id, in os.listdir() order, plus the
# reverse id -> directory-name lookup.
# NOTE(review): os.listdir order is filesystem-dependent, so these ids are only stable on
# the filesystem that produced them; presumably they must agree with the ids stored in
# objectnet_subset_to_objectnet_id — confirm before sorting or relocating the data.
on2onlabel = {name: idx for idx, name in enumerate(os.listdir('/storage/jalverio/objectnet-oct-24-d123'))}
onlabel2name = {idx: name for name, idx in on2onlabel.items()}

FileNotFoundError: [Errno 2] No such file or directory: '/storage/jalverio/objectnet-oct-24-d123'

In [None]:
# objectnet2torch: ObjectNet class name -> iterable of labels (presumably torchvision /
# ImageNet class ids — confirm). torch2objectnet inverts it to label -> ObjectNet name;
# if a label appears under several names, the last one encountered wins.
# NOTE: pickle.load can execute arbitrary code — only load this file from a trusted source.
with open('/storage/jalverio/resnet/objectnet2torch.pkl', 'rb') as f:
    objectnet2torch = pickle.load(f)
torch2objectnet = {
    label: objectnet_name
    for objectnet_name, label_list in objectnet2torch.items()
    for label in label_list
}

In [None]:
# ObjectNet directory name -> human-readable ObjectNet class name.
with open('/storage/jalverio/resnet/dirname_to_objectnet_name.json') as f:
    dirname_to_classname = json.load(f)

# Compressed (subset) label <-> ObjectNet label. The file presumably contains a Python
# dict literal (its .items() is used below). ast.literal_eval parses literals only, so —
# unlike eval — the file cannot execute code; it raises if the file holds anything else.
with open('/storage/jalverio/resnet/objectnet_subset_to_objectnet_id') as f:
    oncompressed2onlabel = ast.literal_eval(f.read())
onlabel2oncompressed = {v: k for k, v in oncompressed2onlabel.items()}


In [4]:
# Evaluation-time preprocessing: resize, centre-crop to 224x224, convert to a tensor,
# then normalize per channel. The mean/std values appear to be the standard ImageNet
# statistics used with torchvision's pretrained models (assumption, not stated here).
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
transformations = transforms.Compose(
    [transforms.Resize(224), transforms.CenterCrop(224), transforms.ToTensor(), normalize]
)

In [None]:
class Objectnet(Dataset):
    """ObjectNet images as a map-style torch Dataset of ``(image_tensor, label)`` pairs.

    The constructor walks one sub-directory per class under ``root``. It maps each
    directory name to a compressed label via the notebook-level lookups
    ``on2onlabel`` followed by ``onlabel2oncompressed``. It optionally drops classes with
    too few images, then splits each class's images between ``self.images`` (served by
    ``__getitem__``) and ``self.test_images`` (see ``n_per_class``).

    Relies on module-level globals defined in earlier cells: ``dirname_to_classname``,
    ``on2onlabel`` and ``onlabel2oncompressed``.
    """

    def __init__(self, root, transform, objectnet2torch, num_examples, test, overlap, test_images=None):
        """Scan ``root`` and build the per-class split.

        Args:
            root: directory containing one sub-directory of images per class.
            transform: callable applied to each RGB PIL image in ``__getitem__``.
            objectnet2torch: mapping keyed by ObjectNet class name. When ``overlap`` is
                true, only classes whose name is a key are kept.
            num_examples: per-class quota passed to ``n_per_class``. A value of exactly
                64 additionally drops classes with fewer than 64 images.
            test: selects which slice of each class ``n_per_class`` keeps.
            overlap: restrict the dataset to classes present in ``objectnet2torch``.
            test_images: unused; kept for backward compatibility.
        """
        self.root = root
        self.transform = transform
        self.images = []
        classes_in_dataset = set()
        dirname_to_label = dict()  # directory name -> compressed label, used to re-sync classes below
        for dirname in os.listdir(root):
            if overlap:
                class_name = dirname_to_classname[dirname]
                if class_name not in objectnet2torch:
                    continue
            classes_in_dataset.add(dirname)
            # Two-step remap: directory name -> ObjectNet label -> compressed label.
            label = onlabel2oncompressed[on2onlabel[dirname]]
            dirname_to_label[dirname] = label
            # NOTE(review): images are taken in os.listdir order, which decides which
            # examples n_per_class keeps; sort here if cross-machine reproducibility matters.
            for image_name in os.listdir(os.path.join(root, dirname)):
                path = os.path.join(root, dirname, image_name)
                self.images.append((path, label))

        if num_examples == 64:
            self.remove_small_classes()
            # Keep the class set in sync with the images that survived the filter.
            surviving_labels = {label for _, label in self.images}
            classes_in_dataset = {
                dirname for dirname in classes_in_dataset
                if dirname_to_label[dirname] in surviving_labels
            }

        print('Created objectnet dataset with %s classes' % len(classes_in_dataset))
        self.n_per_class(num_examples, test)

        self.classes_in_dataset = classes_in_dataset

    def remove_small_classes(self, min_count=64):
        """Drop every image whose label has fewer than ``min_count`` images in ``self.images``."""
        counts = dict()
        for _, label in self.images:
            counts[label] = counts.get(label, 0) + 1
        self.images = [(path, label) for path, label in self.images if counts[label] >= min_count]

    def n_per_class(self, num_examples, test):
        """Split ``self.images`` per class, in their current order.

        With ``test`` false, the first ``num_examples`` images of each class stay in
        ``self.images`` and the rest go to ``self.test_images``. With ``test`` true, the
        first ``num_examples`` of each class are discarded (kept in neither list), the
        next ``num_examples`` stay in ``self.images``, and the rest go to
        ``self.test_images``.
        """
        valid_classes = {label for _, label in self.images}
        quotas = dict.fromkeys(valid_classes, 0)
        test_images = []
        remaining_images = []
        for path, objectnet_label in self.images:
            if not test:
                if quotas[objectnet_label] < num_examples:
                    quotas[objectnet_label] += 1
                    remaining_images.append((path, objectnet_label))
                else:
                    test_images.append((path, objectnet_label))
            else:
                if quotas[objectnet_label] < num_examples * 2:
                    # Only the second block of num_examples images is kept.
                    if quotas[objectnet_label] >= num_examples:
                        remaining_images.append((path, objectnet_label))
                    quotas[objectnet_label] += 1
                else:
                    test_images.append((path, objectnet_label))
        self.images = remaining_images
        self.test_images = test_images
        print('Removed some examples. %s classes and %s examples remaining.' % (len(valid_classes), len(self.images)))

    def __getitem__(self, index):
        """Load image ``index`` as RGB, apply ``self.transform``, and return ``(image, label)``."""
        full_path, labels = self.images[index]
        # Context manager closes the underlying file handle once the RGB copy is made.
        with Image.open(full_path) as image:
            image = image.convert('RGB')
        image = self.transform(image)
        return image, labels

    def __len__(self):
        """Number of images in the active split."""
        return len(self.images)

In [None]:
def evaluate(loader=None):
    """Compute top-1 / top-5 accuracy of the global ``model`` over a data loader.

    Args:
        loader: iterable of ``(batch, labels)`` pairs. Defaults to the global
            ``test_loader``, which matches the original argument-free call.

    Returns:
        ``(top1_score, top5_score)``, each the fraction of examples that were hits.

    Raises:
        ValueError: if the loader yields no examples.
    """
    if loader is None:
        loader = test_loader
    total_top1, total_top5 = 0, 0
    total_examples = 0
    for batch, labels in loader:
        labels = labels.to(DEVICE)
        batch = batch.to(DEVICE)
        with torch.no_grad():
            logits = model(batch)
        top1, top5 = accuracy_objectnet(logits, labels)
        total_top1 += top1
        total_top5 += top5
        total_examples += batch.shape[0]
    if total_examples == 0:
        raise ValueError('evaluate() received a loader with no examples')
    top1_score = total_top1 / total_examples
    top5_score = total_top5 / total_examples
    # NOTE(review): SAVER is not created in any visible cell — confirm it is defined elsewhere.
    SAVER.write_evaluation_record(top1_score, top5_score)
    return top1_score, top5_score

In [6]:
def accuracy_objectnet(output, target):
    """Count top-1 and top-5 hits in a batch.

    Args:
        output: 2-D tensor of scores with one row per example.
        target: tensor holding one integer class index per example.

    Returns:
        ``(top1_correct, top5_correct)`` as Python ints. These are the number of rows
        whose highest-scoring class equals the target, and the number whose target is
        among the (up to) 5 highest-scoring classes.
    """
    with torch.no_grad():
        # Clamp k so outputs with fewer than 5 classes don't make topk raise.
        k = min(5, output.shape[1])
        _, pred = output.topk(k, 1, True, True)
        # (batch, k) boolean matrix: does each ranked prediction equal that row's target?
        hits = pred.eq(target.reshape(-1, 1).to(pred.device))
        top1_correct = int(hits[:, 0].sum().item())
        # topk returns distinct indices per row, so each row has at most one hit.
        top5_correct = int(hits.sum().item())
    return top1_correct, top5_correct


In [7]:
# Run-wide settings — tune here rather than in the cells below.
SEED = 0
# Seed torch before the model cell so the freshly created nn.Linear head gets
# reproducible initial weights.
torch.manual_seed(SEED)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Never request more DataLoader worker processes than there are CPUs.
WORKERS = min(50, os.cpu_count() or 1)
BATCH_SIZE = 32
NUM_EXAMPLES = 8   # per-class quota handed to Objectnet / n_per_class
OVERLAP = True     # keep only classes present in objectnet2torch

In [8]:
# Pretrained ResNet-152 with every pretrained weight frozen; only a fresh linear head
# (created after the freeze, so its parameters still require grad) is trainable.
model = torchvision.models.resnet152(pretrained=True)
for p in model.parameters():
    p.requires_grad = False
# NOTE(review): 1000 outputs matches ImageNet's class count, but the labels fed to this
# model are compressed ObjectNet ids — confirm 1000 (rather than the dataset's class
# count) is intended.
model.fc = nn.Linear(model.fc.in_features, 1000, bias=True)
model = model.to(DEVICE).eval()

Downloading: "https://download.pytorch.org/models/resnet152-b121ed2d.pth" to /Users/julianalverio/.cache/torch/checkpoints/resnet152-b121ed2d.pth
100%|██████████| 241530880/241530880 [00:18<00:00, 12775518.92it/s]


In [None]:
# Build the per-class split. The per-class quota comes from NUM_EXAMPLES in the config
# cell (the previous N_EXAMPLES name was never defined and raised a NameError).
image_dir = '/storage/jalverio/objectnet-oct-24-d123/'
dataset = Objectnet(image_dir, transformations, objectnet2torch, NUM_EXAMPLES, test=False, overlap=OVERLAP)
total_classes = len(dataset.classes_in_dataset)
VALID_CLASSES = dataset.classes_in_dataset
# Same dataset object (transform, classes) but serving the held-out images.
dataset_test = copy.deepcopy(dataset)
dataset_test.images = dataset.test_images

# NOTE: despite its name, `val_loader` serves the NUM_EXAMPLES-per-class split that the
# training loop iterates over; `test_loader` serves the held-out remainder.
val_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE, shuffle=False,
        num_workers=WORKERS, pin_memory=True)

TEST_BATCH_SIZE = 256  # evaluation runs under no_grad, so a larger batch than BATCH_SIZE is used
test_loader = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=TEST_BATCH_SIZE, shuffle=False,
        num_workers=WORKERS, pin_memory=True)

In [None]:
# Train the classification head. The model stays in the eval() mode it was given when
# built (no .train() call appears in the visible cells), so mode-dependent layers keep
# their inference behaviour while the head learns.
NUM_EPOCHS = 50
# NOTE(review): `criterion` and `optimizer` were used here but never defined in any
# visible cell (NameError on a fresh kernel). Cross-entropy over integer labels, plus
# Adam (imported at the top) at its default learning rate, are assumed — confirm
# against the intended experiment.
LEARNING_RATE = 1e-3
criterion = nn.CrossEntropyLoss()
# Only parameters that still require grad (the new head) are optimized.
optimizer = Adam([p for p in model.parameters() if p.requires_grad], lr=LEARNING_RATE)

for epoch in range(NUM_EPOCHS):
    total_examples = 0
    total_training_top1 = 0
    total_training_top5 = 0
    print('starting epoch %s' % epoch)
    # `val_loader` serves the per-class training split built in the dataset cell.
    for batch, labels in val_loader:
        labels = labels.to(DEVICE)
        batch = batch.to(DEVICE)
        logits = model(batch)
        # Training accuracy is measured on this batch's logits before the optimizer step.
        top1, top5 = accuracy_objectnet(logits, labels)
        total_training_top1 += top1
        total_training_top5 += top5
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_examples += batch.shape[0]

    training_top1_performance = total_training_top1 / total_examples
    training_top5_performance = total_training_top5 / total_examples
    print('training top1 score: %s' % training_top1_performance)
    print('training top5 score: %s' % training_top5_performance)