Ported BNInception from pt03 to pt04. Using BNInception now for training. Fixed reproduction of results from paper (comparable now).
dichotomies committed Aug 3, 2018
1 parent deeba4c commit 5e2babb
Showing 22 changed files with 1,928 additions and 1,030 deletions.
4 changes: 0 additions & 4 deletions data/__init__.py

This file was deleted.

61 changes: 0 additions & 61 deletions data/cars.py

This file was deleted.

43 changes: 0 additions & 43 deletions data/cub.py

This file was deleted.

20 changes: 0 additions & 20 deletions data/utils.py

This file was deleted.

2 changes: 2 additions & 0 deletions dataset/__init__.py
@@ -0,0 +1,2 @@

from .cub import Birds
36 changes: 36 additions & 0 deletions dataset/cub.py
@@ -0,0 +1,36 @@
import os
from . import utils
import torch
import torchvision
import numpy as np
import PIL.Image


class Birds(torch.utils.data.Dataset):
    def __init__(self, path, labels, transform = None):
        # e.g., labels = range(0, 50) for using first 50 classes only
        self.labels = labels
        if transform: self.transform = transform
        self.ys, self.im_paths = [], []
        for i in torchvision.datasets.ImageFolder(root = path).imgs:
            # i[1]: label, i[0]: path
            y = i[1]
            # fn needed for removing non-images starting with `._`
            fn = os.path.split(i[0])[1]
            if y in self.labels and fn[:2] != '._':
                self.ys += [y]
                self.im_paths.append(os.path.join(path, i[0]))

    def nb_classes(self):
        n = len(np.unique(self.ys))
        assert n == len(self.labels)
        return n

    def __len__(self):
        return len(self.ys)

    def __getitem__(self, index):
        im = PIL.Image.open(self.im_paths[index])
        im = self.transform(im)
        return im, self.ys[index]
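As a usage sketch only (not part of this commit; the dataset root and the class split below are assumptions), the new Birds dataset would be constructed roughly like this:

from dataset import Birds
from dataset.utils import make_transform

# hypothetical CUB-200-2011 layout: one folder per class under cub200/images
train_ds = Birds(
    path = 'cub200/images',
    labels = range(0, 100),                      # e.g. first 100 classes as the train split
    transform = make_transform(is_train = True)  # see dataset/utils.py below
)
im, y = train_ds[0]  # transformed image tensor and its integer class label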

60 changes: 60 additions & 0 deletions dataset/utils.py
@@ -0,0 +1,60 @@
from torchvision import transforms
import PIL.Image
import torch


def std_per_channel(images):
    images = torch.stack(images, dim = 0)
    return images.view(3, -1).std(dim = 1)


def mean_per_channel(images):
    images = torch.stack(images, dim = 0)
    return images.view(3, -1).mean(dim = 1)


class Identity(): # used for skipping transforms
    def __call__(self, im):
        return im


class ScaleIntensities():
    def __init__(self, in_range, out_range):
        """ Scales intensities. For example [-1, 1] -> [0, 255]."""
        self.in_range = in_range
        self.out_range = out_range

    def __call__(self, tensor):
        tensor = (
            tensor - self.in_range[0]
        ) / (
            self.in_range[1] - self.in_range[0]
        ) * (
            self.out_range[1] - self.out_range[0]
        ) + self.out_range[0]
        return tensor


def make_transform(sz_resize = 256, sz_crop = 227, mean = [128, 117, 104],
        std = [1, 1, 1], rgb_to_bgr = True, is_train = True,
        intensity_scale = [[0, 1], [0, 255]]):
    return transforms.Compose([
        transforms.Compose([ # train: horizontal flip and random resized crop
            transforms.RandomResizedCrop(sz_crop),
            transforms.RandomHorizontalFlip(),
        ]) if is_train else transforms.Compose([ # test: else center crop
            transforms.Resize(sz_resize),
            transforms.CenterCrop(sz_crop),
        ]),
        transforms.ToTensor(),
        ScaleIntensities(
            *intensity_scale) if intensity_scale is not None else Identity(),
        transforms.Normalize(
            mean=mean,
            std=std,
        ),
        transforms.Lambda(
            lambda x: x[[2, 1, 0], ...]
        ) if rgb_to_bgr else Identity()
    ])
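The defaults here (intensities rescaled to [0, 255], mean [128, 117, 104], unit std, and an RGB to BGR channel swap) are consistent with Caffe-style preprocessing for the ported BNInception weights mentioned in the commit message. A minimal sketch of wiring the transform into train/eval loaders; the batch size, worker count and 100/100 class split are illustrative assumptions, not taken from the commit:

import torch
from dataset import Birds
from dataset.utils import make_transform

dl_train = torch.utils.data.DataLoader(
    Birds('cub200/images', range(0, 100), make_transform(is_train = True)),
    batch_size = 32, shuffle = True, num_workers = 4
)
dl_eval = torch.utils.data.DataLoader(
    Birds('cub200/images', range(100, 200), make_transform(is_train = False)),
    batch_size = 32, shuffle = False, num_workers = 4
)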

7 changes: 4 additions & 3 deletions evaluation/__init__.py
@@ -1,5 +1,6 @@
__all__ = ["nmi", "ratk"]

from .nmi import cluster_by_kmeans, calc_nmi
from .ratk import assign_by_euclidian_at_k, recall_at_k
from .normalized_mutual_information import calc_normalized_mutual_information
from .normalized_mutual_information import cluster_by_kmeans
from .recall import assign_by_euclidian_at_k
from .recall import calc_recall_at_k

evaluation/nmi.py → evaluation/normalized_mutual_information.py (renamed)
@@ -1,12 +1,12 @@
 import sklearn.cluster
 import sklearn.metrics.cluster
 
-def cluster_by_kmeans(xs, nb_clusters):
+def cluster_by_kmeans(X, nb_clusters):
     """
     xs : embeddings with shape [nb_samples, nb_features]
     nb_clusters : in this case, must be equal to number of classes
     """
-    return sklearn.cluster.KMeans(nb_clusters).fit(xs).labels_
+    return sklearn.cluster.KMeans(nb_clusters).fit(X).labels_
 
-def calc_nmi(ys, xs_clustered):
+def calc_normalized_mutual_information(ys, xs_clustered):
     return sklearn.metrics.cluster.normalized_mutual_info_score(xs_clustered, ys)
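For context, a sketch of how the two renamed functions are typically combined; evaluate_nmi is an illustrative helper, not part of the commit:

import numpy as np
import evaluation

def evaluate_nmi(X, T):
    # X: [nb_samples, nb_features] embeddings, T: [nb_samples] ground-truth labels
    clustered = evaluation.cluster_by_kmeans(X, nb_clusters = len(np.unique(T)))
    return evaluation.calc_normalized_mutual_information(T, clustered)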
23 changes: 0 additions & 23 deletions evaluation/ratk.py

This file was deleted.

24 changes: 24 additions & 0 deletions evaluation/recall.py
@@ -0,0 +1,24 @@

import numpy as np
import sklearn.metrics.pairwise

def assign_by_euclidian_at_k(X, T, k):
    """
    X : [nb_samples x nb_features], e.g. 100 x 64 (embeddings)
    k : for each sample, assign target labels of k nearest points
    """
    distances = sklearn.metrics.pairwise.pairwise_distances(X)
    # get nearest points
    indices = np.argsort(distances, axis = 1)[:, 1 : k + 1]
    return np.array([[T[i] for i in ii] for ii in indices])


def calc_recall_at_k(T, Y, k):
    """
    T : [nb_samples] (target labels)
    Y : [nb_samples x k] (k predicted labels/neighbours)
    """
    s = sum([1 for t, y in zip(T, Y) if t in y[:k]])
    return s / (1. * len(T))
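A hedged sketch of the usual Recall@k evaluation built on these two helpers; the k values and the evaluate_recall name are assumptions, not taken from this commit:

import evaluation

def evaluate_recall(X, T, ks = (1, 2, 4, 8)):
    # compute neighbours once for the largest k, then reuse for smaller k
    Y = evaluation.assign_by_euclidian_at_k(X, T, max(ks))
    return {k: evaluation.calc_recall_at_k(T, Y, k) for k in ks}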


