Skip to content

Commit

Permalink
flowers102 one class
Browse files Browse the repository at this point in the history
  • Loading branch information
dreamflasher committed Sep 7, 2020
1 parent 4774a19 commit 9c81414
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 8 deletions.
3 changes: 2 additions & 1 deletion common/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ def parse_args(default=False):
parser = ArgumentParser(description='Pytorch implementation of CSI')

parser.add_argument('--dataset', help='Dataset',
choices=['cifar10', 'cifar100', 'imagenet', 'flowers102', 'fashionmnist', 'mnist'], type=str)
choices=['cifar10', 'cifar100', 'imagenet', 'flowers102', 'fashionmnist',
'mnist', 'oxford102flower'], type=str)
parser.add_argument('--one_class_idx', help='None: multi-class, Not None: one-class',
default=None, type=int)
parser.add_argument('--model', help='Model',
Expand Down
11 changes: 10 additions & 1 deletion datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def get_transform_imagenet():


def get_dataset(P, dataset, test_only=False, image_size=None, download=True, eval=False):
if dataset in ['imagenet', 'cub', 'stanford_dogs', 'flowers102',
if dataset in ['imagenet', 'cub', 'stanford_dogs', 'flowers102', 'oxford102flower',
'places365', 'food_101', 'caltech_256', 'dtd', 'pets']:
if eval:
train_transform, test_transform = get_simclr_eval_transform_imagenet(P.ood_samples,
Expand Down Expand Up @@ -168,6 +168,13 @@ def get_dataset(P, dataset, test_only=False, image_size=None, download=True, eva
train_set = datasets.MNIST(DATA_PATH, train=True, download=download, transform=train_transform)
test_set = datasets.MNIST(DATA_PATH, train=False, download=download, transform=test_transform)

elif dataset == 'oxford102flower':
image_size = (224, 224, 3)
n_classes = 102
base_dir = os.path.join(DATA_PATH, 'oxford102flower')
train_set = datasets.ImageFolder(base_dir + "/train", transform=train_transform)
test_set = datasets.ImageFolder(base_dir + "/valid", transform=test_transform)

elif dataset == 'cifar100':
image_size = (32, 32, 3)
n_classes = 100
Expand Down Expand Up @@ -275,6 +282,8 @@ def get_superclass_list(dataset):
return list(range(10))
elif dataset == 'mnist':
return list(range(10))
elif dataset == 'oxford102flower':
return list(range(1, 103))
else:
raise NotImplementedError()

Expand Down
12 changes: 6 additions & 6 deletions training/unsup/simclr_CSI.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch.optim

import models.transform_layers as TL
from training.contrastive_loss import get_similarity_matrix, NT_xent
from training.contrastive_loss import NT_xent, get_similarity_matrix
from utils.utils import AverageMeter, normalize

device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu")
Expand Down Expand Up @@ -39,13 +39,14 @@ def train(P, epoch, model, criterion, optimizer, scheduler, loader, logger=None,
check = time.time()

### SimCLR loss ###
if P.dataset != 'imagenet':
if P.dataset == 'imagenet' or P.dataset == 'oxford102flower':
batch_size = images[0].size(0)
images1, images2 = images[0].to(device), images[1].to(device)
else:
batch_size = images.size(0)
images = images.to(device)
images1, images2 = hflip(images.repeat(2, 1, 1, 1)).chunk(2) # hflip
else:
batch_size = images[0].size(0)
images1, images2 = images[0].to(device), images[1].to(device)

labels = labels.to(device)

images1 = torch.cat([P.shift_trans(images1, k) for k in range(P.K_shift)])
Expand Down Expand Up @@ -110,4 +111,3 @@ def train(P, epoch, model, criterion, optimizer, scheduler, loader, logger=None,
logger.scalar_summary('train/loss_sim', losses['sim'].average, epoch)
logger.scalar_summary('train/loss_shift', losses['shift'].average, epoch)
logger.scalar_summary('train/batch_time', batch_time.average, epoch)

0 comments on commit 9c81414

Please sign in to comment.