-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Ported BNInception from pt03 to pt04. Using BNInception now for train…
…ing. Fixed reproduction of results from paper (comparable now).
- Loading branch information
1 parent
deeba4c
commit 5e2babb
Showing
22 changed files
with
1,928 additions
and
1,030 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
|
||
from .cub import Birds |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import os | ||
from . import utils | ||
import torch | ||
import torchvision | ||
import numpy as np | ||
import PIL.Image | ||
|
||
|
||
class Birds(torch.utils.data.Dataset): | ||
def __init__(self, path, labels, transform = None): | ||
# e.g., labels = range(0, 50) for using first 50 classes only | ||
self.labels = labels | ||
if transform: self.transform = transform | ||
self.ys, self.im_paths = [], [] | ||
for i in torchvision.datasets.ImageFolder(root = path).imgs: | ||
# i[1]: label, i[0]: path | ||
y = i[1] | ||
# fn needed for removing non-images starting with `._` | ||
fn = os.path.split(i[0])[1] | ||
if y in self.labels and fn[:2] != '._': | ||
self.ys += [y] | ||
self.im_paths.append(os.path.join(path, i[0])) | ||
|
||
def nb_classes(self): | ||
n = len(np.unique(self.ys)) | ||
assert n == len(self.labels) | ||
return n | ||
|
||
def __len__(self): | ||
return len(self.ys) | ||
|
||
def __getitem__(self, index): | ||
im = PIL.Image.open(self.im_paths[index]) | ||
im = self.transform(im) | ||
return im, self.ys[index] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from torchvision import transforms | ||
import PIL.Image | ||
import torch | ||
|
||
|
||
def std_per_channel(images): | ||
images = torch.stack(images, dim = 0) | ||
return images.view(3, -1).std(dim = 1) | ||
|
||
|
||
def mean_per_channel(images): | ||
images = torch.stack(images, dim = 0) | ||
return images.view(3, -1).mean(dim = 1) | ||
|
||
|
||
class Identity(): # used for skipping transforms | ||
def __call__(self, im): | ||
return im | ||
|
||
|
||
class ScaleIntensities(): | ||
def __init__(self, in_range, out_range): | ||
""" Scales intensities. For example [-1, 1] -> [0, 255].""" | ||
self.in_range = in_range | ||
self.out_range = out_range | ||
|
||
def __call__(self, tensor): | ||
tensor = ( | ||
tensor - self.in_range[0] | ||
) / ( | ||
self.in_range[1] - self.in_range[0] | ||
) * ( | ||
self.out_range[1] - self.out_range[0] | ||
) + self.out_range[0] | ||
return tensor | ||
|
||
|
||
def make_transform(sz_resize = 256, sz_crop = 227, mean = [128, 117, 104], | ||
std = [1, 1, 1], rgb_to_bgr = True, is_train = True, | ||
intensity_scale = [[0, 1], [0, 255]]): | ||
return transforms.Compose([ | ||
transforms.Compose([ # train: horizontal flip and random resized crop | ||
transforms.RandomResizedCrop(sz_crop), | ||
transforms.RandomHorizontalFlip(), | ||
]) if is_train else transforms.Compose([ # test: else center crop | ||
transforms.Resize(sz_resize), | ||
transforms.CenterCrop(sz_crop), | ||
]), | ||
transforms.ToTensor(), | ||
ScaleIntensities( | ||
*intensity_scale) if intensity_scale is not None else Identity(), | ||
transforms.Normalize( | ||
mean=mean, | ||
std=std, | ||
), | ||
transforms.Lambda( | ||
lambda x: x[[2, 1, 0], ...] | ||
) if rgb_to_bgr else Identity() | ||
]) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
__all__ = ["nmi", "ratk"] | ||
|
||
from .nmi import cluster_by_kmeans, calc_nmi | ||
from .ratk import assign_by_euclidian_at_k, recall_at_k | ||
from .normalized_mutual_information import calc_normalized_mutual_information | ||
from .normalized_mutual_information import cluster_by_kmeans | ||
from .recall import assign_by_euclidian_at_k | ||
from .recall import calc_recall_at_k | ||
|
6 changes: 3 additions & 3 deletions
6
evaluation/nmi.py → evaluation/normalized_mutual_information.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,12 @@ | ||
import sklearn.cluster | ||
import sklearn.metrics.cluster | ||
|
||
def cluster_by_kmeans(xs, nb_clusters): | ||
def cluster_by_kmeans(X, nb_clusters): | ||
""" | ||
xs : embeddings with shape [nb_samples, nb_features] | ||
nb_clusters : in this case, must be equal to number of classes | ||
""" | ||
return sklearn.cluster.KMeans(nb_clusters).fit(xs).labels_ | ||
return sklearn.cluster.KMeans(nb_clusters).fit(X).labels_ | ||
|
||
def calc_nmi(ys, xs_clustered): | ||
def calc_normalized_mutual_information(ys, xs_clustered): | ||
return sklearn.metrics.cluster.normalized_mutual_info_score(xs_clustered, ys) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
|
||
import numpy as np | ||
import sklearn.metrics.pairwise | ||
|
||
def assign_by_euclidian_at_k(X, T, k): | ||
""" | ||
X : [nb_samples x nb_features], e.g. 100 x 64 (embeddings) | ||
k : for each sample, assign target labels of k nearest points | ||
""" | ||
distances = sklearn.metrics.pairwise.pairwise_distances(X) | ||
# get nearest points | ||
indices = np.argsort(distances, axis = 1)[:, 1 : k + 1] | ||
return np.array([[T[i] for i in ii] for ii in indices]) | ||
|
||
|
||
def calc_recall_at_k(T, Y, k): | ||
""" | ||
T : [nb_samples] (target labels) | ||
Y : [nb_samples x k] (k predicted labels/neighbours) | ||
""" | ||
s = sum([1 for t, y in zip(T, Y) if t in y[:k]]) | ||
return s / (1. * len(T)) | ||
|
||
|
Oops, something went wrong.