In [1]:
%matplotlib inline
import sys
import os
import argparse
import time
import numpy as np
sys.path.append('../')

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim.lr_scheduler as lr_scheduler
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import easydict as edict

from lib import models, datasets
import math

In [2]:
# parameters
# Build an EasyDict *instance*. The previous `args = edict` bound the easydict
# module itself, so every `args.x = ...` silently added globals to that module.
args = edict.EasyDict()

# imagenet: cached feature/label banks for the train and val splits, plus the
# directory where results of this notebook are written.
args.cache = '../checkpoint/train_features_labels_cache/instance_imagenet_train_feature_resnet50.pth.tar'
args.val_cache = '../checkpoint/train_features_labels_cache/instance_imagenet_val_feature_resnet50.pth.tar'
args.save_path = '../checkpoint/pseudos/unsupervised_imagenet32x32_nc_wrn-28-2'
os.makedirs(args.save_path, exist_ok=True)

args.low_dim = 128    # expected feature dimension (the loaded features print as [N, 128])
args.num_class = 1000
args.rng_seed = 0     # seeds torch CPU/CUDA RNGs in the permutation cell

In [3]:
# Load the cached (features, labels) banks for the train and val splits.
# map_location='cpu' keeps them in host memory even if they were saved as GPU
# tensors: everything below runs on the CPU anyway, so this avoids a needless
# GPU copy (and a load failure on a machine without CUDA).
# NOTE: torch.load unpickles its input; only load trusted checkpoint files.
ckpt = torch.load(args.cache, map_location='cpu')
train_labels, train_features = ckpt['labels'], ckpt['features']

ckpt = torch.load(args.val_cache, map_location='cpu')
val_labels, val_features = ckpt['val_labels'], ckpt['val_features']

# Val rows are placed FIRST: later cells treat rows [0, n_val) as the val split.
train_features = torch.cat([val_features, train_features], dim=0)
train_labels = torch.cat([val_labels, train_labels], dim=0)
assert train_features.shape[0] == train_labels.shape[0], 'features/labels row count mismatch'

print(train_features.dtype, train_labels.dtype)
print(train_features.shape, train_labels.shape)

torch.float32 torch.int64
torch.Size([1331167, 128]) torch.Size([1331167])


# Use the CPU, because the following computation needs a lot of memory

In [None]:
# Keep the feature bank and labels in host memory for the memory-hungry steps below.
device = 'cpu'
train_features = train_features.to(device)
train_labels = train_labels.to(device)

In [None]:
num_train_data = train_labels.shape[0]
# Plain Python int (torch.max would give a 0-d tensor), so it can be used
# directly as a tensor dimension, in range(), or in string formatting.
num_class = int(train_labels.max()) + 1

# Seed the CPU and all CUDA generators so `perm` is reproducible.
torch.manual_seed(args.rng_seed)
torch.cuda.manual_seed_all(args.rng_seed)
# NOTE(review): `perm` is only consumed by the commented-out experiment cell below.
perm = torch.randperm(num_train_data).to(device)
print(perm)

tensor([ 970454, 1058848,  717280,  ...,  462299,  305137,  436069])


# Soft labels: temperature-weighted kNN voting from labeled (train) to unlabeled (val) features

In [None]:
class AverageMeter:
    """Running tracker for a scalar metric.

    Attributes:
        val: the most recently recorded value.
        sum: running total of ``val * n`` over all updates.
        count: running total of the weights ``n``.
        avg: ``sum / count``, the weighted mean so far.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` (e.g. a batch size)."""
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count

In [None]:
# kNN soft-label evaluation. Rows [0, n_val) of train_features are the val split
# (the loading cell concatenates it in front) and are treated as "unlabeled";
# every remaining row is "labeled". Each unlabeled sample gets a class
# distribution from its K most similar labeled samples, each vote weighted by
# exp(similarity / temperature); top-1/top-5 accuracy of that distribution
# against the true label is accumulated over all chunks.
# NOTE(review): dot-product similarity presumably assumes L2-normalised
# features — confirm; with unnormalised features exp(yd / 0.1) can overflow.
n_chunks = 100  # unlabeled set is processed in chunks to bound the similarity matrix size
n_val = val_features.shape[0]
temperature = 0.1

prec_top5 = AverageMeter()
prec_top1 = AverageMeter()
index_labeled = torch.arange(n_val, train_features.shape[0])
index_unlabeled = torch.arange(n_val)
num_labeled_data = index_labeled.shape[0]
K = min(num_labeled_data, 200)

# Loop-invariant: gather the labeled bank and its labels once, instead of
# re-copying train_features[index_labeled] for every chunk.
labeled_features_t = train_features[index_labeled].t()
labeled_labels = train_labels[index_labeled]

for i_chunks, index_unlabeled_chunk in enumerate(index_unlabeled.chunk(n_chunks)):

    # similarity of each chunk sample to every labeled sample: (bs, num_labeled_data)
    dist = torch.mm(train_features[index_unlabeled_chunk], labeled_features_t)

    bs = index_unlabeled_chunk.shape[0]
    yd, yi = dist.topk(K, dim=1, largest=True, sorted=True)
    # yi indexes into the labeled subset, so look the labels up there: (bs, K)
    retrieval = labeled_labels[yi]

    # Sum exp(sim / T) per class. scatter_add_ yields the same per-class sums as
    # weighting a (bs*K, num_class) one-hot matrix, without allocating it.
    yd_transform = (yd / temperature).exp_()
    probs = yd_transform.new_zeros(bs, int(num_class)).scatter_add_(1, retrieval, yd_transform)
    probs.div_(probs.sum(dim=1, keepdim=True))
    probs_sorted, predictions = probs.sort(1, True)
    correct = predictions.eq(train_labels[index_unlabeled_chunk].view(-1, 1))

    top5 = correct[:, :5].any(dim=1).float().mean().item()
    top1 = correct[:, 0].float().mean().item()
    prec_top5.update(top5, bs)
    prec_top1.update(top1, bs)
    print('[{}]/[{}] top5={:.2%}({:.2%}) top1={:.2%}({:.2%})'.format(
        i_chunks, n_chunks, prec_top5.val, prec_top5.avg, prec_top1.val, prec_top1.avg))

[0]/[100] top5=85.00%(85.00%) top1=66.60%(66.60%)
[1]/[100] top5=79.60%(82.30%) top1=52.80%(59.70%)
[2]/[100] top5=81.00%(81.87%) top1=61.40%(60.27%)
[3]/[100] top5=65.20%(77.70%) top1=42.80%(55.90%)
[4]/[100] top5=70.00%(76.16%) top1=47.40%(54.20%)
[5]/[100] top5=69.20%(75.00%) top1=42.60%(52.27%)
[6]/[100] top5=67.20%(73.89%) top1=41.40%(50.71%)
[7]/[100] top5=77.60%(74.35%) top1=52.20%(50.90%)
[8]/[100] top5=84.60%(75.49%) top1=67.00%(52.69%)
[9]/[100] top5=77.40%(75.68%) top1=57.00%(53.12%)
[10]/[100] top5=82.20%(76.27%) top1=67.00%(54.38%)
[11]/[100] top5=71.00%(75.83%) top1=49.00%(53.93%)
[12]/[100] top5=65.20%(75.02%) top1=43.80%(53.15%)
[13]/[100] top5=83.00%(75.59%) top1=62.20%(53.80%)
[14]/[100] top5=85.20%(76.23%) top1=66.80%(54.67%)
[15]/[100] top5=71.80%(75.95%) top1=43.60%(53.97%)
[16]/[100] top5=62.20%(75.14%) top1=39.80%(53.14%)
[17]/[100] top5=61.00%(74.36%) top1=38.80%(52.34%)
[18]/[100] top5=63.60%(73.79%) top1=36.60%(51.52%)
[19]/[100] top5=67.40%(73.47%) top1=41.60

In [None]:
# NOTE(review): disabled experiment, kept only as commented-out code. It would
# re-run the kNN evaluation with `num_labeled_data // args.num_class` labeled
# samples per class (picked in `perm` order) and every other sample treated as
# unlabeled, using up to K=5000 neighbours and reporting top-5 accuracy only.
# Before re-enabling:
#   - `prec_top5` is created outside the `num_labeled_data` loop, so it is never
#     reset between settings.
#   - `retrieval_one_hot` is (bs * K, num_class); with K=5000 and ~1/100 of the
#     whole dataset per chunk this is far too large — use scatter_add_ instead.
# TODO: turn into a parameterised function or delete.
# n_chunks = 100

# prec_top5 = AverageMeter()
# for num_labeled_data in [10000]:
#     index_labeled = []
#     index_unlabeled = []
#     data_per_class = num_labeled_data // args.num_class
#     for c in range(args.num_class):
#         indexes_c = perm[train_labels[perm] == c]
#         index_labeled.append(indexes_c[:data_per_class])
#         index_unlabeled.append(indexes_c[data_per_class:])
#     index_labeled = torch.cat(index_labeled)
#     index_unlabeled = torch.cat(index_unlabeled)

#     for i_chunks, index_unlabeled_chunk in enumerate(index_unlabeled.chunk(n_chunks)):
    
#         # calculate similarity matrix
#         dist = torch.mm(train_features[index_unlabeled_chunk], train_features[index_labeled].t())

#         K = min(num_labeled_data, 5000)
#         bs = index_unlabeled_chunk.shape[0]
#         yd, yi = dist.topk(K, dim=1, largest=True, sorted=True)
#         candidates = train_labels.view(1,-1).expand(bs, -1)
#         retrieval = torch.gather(candidates, 1, index_labeled[yi])
#         retrieval_one_hot = torch.zeros(bs * K, num_class).to(device)
#         retrieval_one_hot.scatter_(1, retrieval.view(-1, 1), 1)

#         temperature = 0.1

#         yd_transform = (yd / temperature).exp_()
#         probs = torch.sum(torch.mul(retrieval_one_hot.view(bs, -1 , num_class), yd_transform.view(bs, -1, 1)), 1)
#         probs.div_(probs.sum(dim=1, keepdim=True))
#         probs_sorted, predictions = probs.sort(1, True)
#         correct = predictions.eq(train_labels[index_unlabeled_chunk].data.view(-1,1))
#         top5 = torch.any(correct[:, :5], dim=1).float().mean() 
        
#         prec_top5.update(top5, bs)
#         print('[{}]/[{}] {:.2%} {:.2%}'.format(i_chunks, n_chunks, prec_top5.val, prec_top5.avg))