In [1]:
%matplotlib inline
import sys
sys.path.append('../')
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import math
import os
import argparse
import time

from lib import models, datasets


import numpy as np
import scipy as sp
import scipy.sparse.linalg as linalg
import scipy.sparse as sparse

import matplotlib.pyplot as plt
import easydict as edict

In [2]:
# parameters
# Hold the settings in a dedicated EasyDict instance. Binding `args` to the
# imported `easydict` module itself would write every setting onto the shared
# module object instead of a config container.
args = edict.EasyDict()

# Path of the cached training-set embeddings (read with torch.load).
args.cache = '../checkpoint/train_features_labels_cache/colorization_embedding_128.t7'
# Output directory for the generated pseudo-label files; created if missing.
args.save_path = '../checkpoint/pseudos/colorization_nc_pseudo_wrn-28-2'
os.makedirs(args.save_path, exist_ok=True)

args.num_class = 10
args.rng_seed = 0

In [3]:
# Load the cached feature matrix.
# NOTE(review): torch.load unpickles the file, so only point args.cache at
# trusted checkpoints.
train_features = torch.load(args.cache)

# Rebuild the label vector from the dataset's `targets` and cast it to int64.
targets = datasets.CIFAR10Instance(root='../data', train=True).targets
train_labels = torch.Tensor(targets).long()

# Sanity check: dtypes first, then shapes.
print(train_features.dtype, train_labels.dtype)
print(train_features.shape, train_labels.shape)

torch.float32 torch.int64
torch.Size([50000, 128]) torch.Size([50000])


# Use the CPU because the following computation needs a lot of memory

In [4]:
# Move both tensors onto the chosen device (the CPU here).
device = 'cpu'
train_features = train_features.to(device)
train_labels = train_labels.to(device)

In [5]:
# Dataset size and class count, both derived from the label vector.
num_train_data = train_labels.shape[0]
# .item() converts the 0-dim max tensor to a plain Python int so that
# num_class can be used directly as a dimension size.
num_class = int(train_labels.max().item()) + 1

# Seed the CPU generator and every CUDA generator so the permutation is
# reproducible across runs.
torch.manual_seed(args.rng_seed)
torch.cuda.manual_seed_all(args.rng_seed)
perm = torch.randperm(num_train_data).to(device)
print(perm)

tensor([36044, 49165, 37807,  ..., 42128, 15898, 31476])


# constrained normalized cut

In [6]:
# Number of smallest-distance entries kept per row when building the graph.
K = 20
def make_column_normalize(X, eps=1e-12):
    """Scale every column of ``X`` to unit L2 norm.

    Args:
        X: tensor whose L2 norms are taken along dim 0 (one norm per column).
        eps: lower bound applied to each norm before dividing, so an
            all-zero column stays all-zero instead of becoming NaN.

    Returns:
        A tensor of the same shape as ``X`` with each column divided by
        ``max(norm, eps)``.
    """
    column_norms = torch.norm(X, p=2, dim=0, keepdim=True).clamp_min(eps)
    return X.div(column_norms)

# Pairwise inner products of the features, mapped to a distance.
# NOTE(review): the name suggests the rows of train_features are
# L2-normalised (making this a cosine similarity) — confirm against the
# code that produced the cache.
cosin_similarity = train_features.mm(train_features.t())
dist = (1 - cosin_similarity) / 2

# K smallest distances per row in ascending order; the last column (the K-th
# smallest) is used as a per-row scale.
dist_sorted, idx = torch.topk(dist, K, dim=1, largest=False, sorted=True)
k_dist = dist_sorted[:, -1:]

# Exponential kernel weights for the selected entries. Column 0 is skipped
# when scattering into the dense matrix (presumably each sample's own entry
# — verify).
similarity_dense = (-dist_sorted * 2 / k_dist).exp()
similarity_sparse = torch.zeros_like(cosin_similarity)
row_ids = torch.arange(num_train_data).view(-1, 1)
similarity_sparse[row_ids, idx[:, 1:]] = similarity_dense[:, 1:]
# Symmetrise: an edge is kept if either endpoint selected the other.
similarity = torch.max(similarity_sparse, similarity_sparse.t())
print('similarity done')

# L_sys = D^{-1/2} (D - W) D^{-1/2}, with D the diagonal degree matrix.
degree = similarity.sum(dim=0)
degree_normed = degree.pow(-0.5)
L_sys = degree_normed.view(-1, 1) * (torch.diag(degree) - similarity) * degree_normed.view(1, -1)
print('L_sys done')

similarity done
L_sys done


In [7]:
num_eigenvectors = 200 # the number of precomputed spectral eigenvectors.

# Eigenpairs of L_sys with the smallest real parts ('SR'); only the real
# parts are kept.
eigenvalues, eigenvectors = linalg.eigs(L_sys.numpy(), k=num_eigenvectors, which='SR', tol=1e-2, maxiter=30000)
eigenvalues, eigenvectors = torch.from_numpy(eigenvalues.real), torch.from_numpy(eigenvectors.real)

# Sort ascending BEFORE discarding anything: eigs does not guarantee the
# order of the returned eigenvalues, so the smallest one (the eigenpair being
# dropped) is not necessarily at index 0 of the raw result.
eigenvalues, order = eigenvalues.sort()
eigenvectors = eigenvectors[:, order]
eigenvalues, eigenvectors = eigenvalues[1:], eigenvectors[:, 1:]
print('eigenvectors done')
print(eigenvalues)

eigenvectors done
tensor([0.0160, 0.0236, 0.0306, 0.0318, 0.0359, 0.0400, 0.0453, 0.0561, 0.0603,
        0.0621, 0.0635, 0.0731, 0.0745, 0.0775, 0.0795, 0.0860, 0.0881, 0.0942,
        0.0967, 0.1003, 0.1035, 0.1047, 0.1070, 0.1094, 0.1170, 0.1221, 0.1237,
        0.1275, 0.1290, 0.1316, 0.1377, 0.1384, 0.1393, 0.1425, 0.1435, 0.1466,
        0.1505, 0.1539, 0.1548, 0.1591, 0.1615, 0.1644, 0.1660, 0.1665, 0.1699,
        0.1715, 0.1721, 0.1727, 0.1734, 0.1756, 0.1794, 0.1803, 0.1805, 0.1819,
        0.1854, 0.1874, 0.1875, 0.1881, 0.1889, 0.1919, 0.1926, 0.1952, 0.1975,
        0.1997, 0.2004, 0.2010, 0.2025, 0.2035, 0.2053, 0.2068, 0.2088, 0.2101,
        0.2117, 0.2138, 0.2145, 0.2150, 0.2183, 0.2190, 0.2205, 0.2220, 0.2245,
        0.2251, 0.2275, 0.2282, 0.2285, 0.2297, 0.2299, 0.2316, 0.2330, 0.2358,
        0.2359, 0.2381, 0.2396, 0.2409, 0.2425, 0.2440, 0.2448, 0.2459, 0.2462,
        0.2480, 0.2492, 0.2495, 0.2518, 0.2520, 0.2531, 0.2543, 0.2545, 0.2562,
        0.2574, 0.2582

In [8]:
# For several labeled-set sizes: propagate the labels through the spectral
# embedding, rank the unlabeled samples by a margin-based weight, save the
# pseudo labels to disk and plot accuracy against the kept fraction.
fig = plt.figure(dpi=200)

for num_labeled_data in [50, 100, 250, 500, 1000, 2000, 4000, 8000]:
    # index of labeled and unlabeled
    # even split: within each class, the first `data_per_class` samples in
    # `perm` order become labeled and the remainder unlabeled.
    index_labeled = []
    index_unlabeled = []
    data_per_class = num_labeled_data // args.num_class
    # Loop over the configured class count (the same value used for
    # data_per_class) rather than a hard-coded 10.
    for c in range(args.num_class):
        indexes_c = perm[train_labels[perm] == c]
        index_labeled.append(indexes_c[:data_per_class])
        index_unlabeled.append(indexes_c[data_per_class:])
    index_labeled = torch.cat(index_labeled)
    index_unlabeled = torch.cat(index_unlabeled)

    # Alternative (disabled): plain random split that ignores class balance.
#     index_labeled = perm[:num_labeled_data]
#     index_unlabeled = perm[num_labeled_data:]

    # prior
    # Labeled rows are -1 everywhere except +1 at the true class; unlabeled
    # rows stay 0.
    unary_prior = torch.zeros([num_train_data, num_class])
    unary_prior[index_labeled, :] = -1
    unary_prior[index_labeled, train_labels[index_labeled]] = 1
    AQ = unary_prior.abs()
    # Degree-weighted positive part (pd) and negative part (nd) of the prior.
    pd = degree.view(-1, 1) * (AQ + unary_prior) / 2
    nd = degree.view(-1, 1) * (AQ - unary_prior) / 2
    # Per-class ratio of positive to negative degree mass.
    np_ratio = pd.sum(dim=0) / nd.sum(dim=0)
    unary_prior_norm = (pd / np_ratio).sqrt() - (nd * np_ratio).sqrt()
    unary_prior_norm = make_column_normalize(unary_prior_norm)

    # logits and prediction
    # logits = V diag(1 / (lambda - alpha)) V^T prior, using the precomputed
    # eigenpairs.
    alpha = 0
    lambda_reverse = (1 / (eigenvalues - alpha)).view(1, -1)
    logits = torch.mm(lambda_reverse * eigenvectors, torch.mm(eigenvectors.t(), unary_prior_norm))
    # Rescale each column to L2 norm sqrt(N), then shift each row so that its
    # maximum is 0.
    logits = make_column_normalize(logits) * math.sqrt(logits.shape[0])
    logits = logits - logits.max(1, keepdim=True)[0]
    _, predict = logits.max(dim=1)

    for temperature_nc in [1]:#, 2, 3, 5, 10, 15, 20, 25, 30, 35, 40, 100]:
        # pseudo weights: 1 - exp(-(top1 - top2) / T), from the gap between
        # the two largest logits of each sample.
        logits_sorted = logits.sort(dim=1, descending=True)[0]
        subtract = logits_sorted[:, 0] - logits_sorted[:, 1]
        pseudo_weights = 1 - torch.exp(- subtract / temperature_nc)

        # Softmax over the temperature-scaled logits.
        exp = (logits * temperature_nc).exp()
        probs = exp / exp.sum(1, keepdim=True)
        probs_sorted, predict_all = probs.sort(1, True)
        assert torch.all(predict == predict_all[:, 0])

        # Order the unlabeled samples by descending pseudo weight and gather
        # their indexes, labels, probabilities and weights in that order.
        idx = pseudo_weights[index_unlabeled].sort(dim=0, descending=True)[1]
        pseudo_indexes = index_unlabeled[idx]
        pseudo_labels = predict[index_unlabeled][idx]
        pseudo_probs = probs[index_unlabeled][idx]
        pseudo_weights = pseudo_weights[index_unlabeled][idx]
        assert torch.all(pseudo_labels == pseudo_probs.max(1)[1])

        # NOTE(review): the file name depends only on num_labeled_data, so
        # running more than one temperature overwrites the earlier file.
        save_dict = {
            'pseudo_indexes': pseudo_indexes,
            'pseudo_labels': pseudo_labels,
            'pseudo_probs': pseudo_probs,
            'pseudo_weights': pseudo_weights,
            'labeled_indexes': index_labeled,
            'unlabeled_indexes': index_unlabeled,
        }
        torch.save(save_dict, os.path.join(args.save_path, '{}.pth.tar'.format(num_labeled_data)))

        # for plot
        correct = pseudo_labels == train_labels[pseudo_indexes]

        # Confidence = exp(-entropy) of the predicted distribution, scaled so
        # that the largest value is 1.
        entropy = - (pseudo_probs * torch.log(pseudo_probs + 1e-7)).sum(dim=1)
        confidence = (- entropy * 1).exp()
        confidence /= confidence.max()

        # Cumulative accuracy over samples taken in descending confidence
        # order, plotted against the fraction of unlabeled data kept.
        arange = 1 + np.arange(confidence.shape[0])
        xs = arange / confidence.shape[0]
        correct_tmp = correct[confidence.sort(descending=True)[1]]
        accuracies = np.cumsum(correct_tmp.numpy()) / arange
        plt.plot(xs, accuracies, label='num_labeled_data={}'.format(num_labeled_data))

        acc = correct.float().mean()

        # "AUC" here is the mean of the cumulative-accuracy curve.
        print('num_labeled={:4} T_nc={}, prec={:.2f}, AUC={:.2f}'.format(
            num_labeled_data, temperature_nc, acc * 100, accuracies.mean() * 100))

plt.xlabel('accumulated unlabeled data ratio')
plt.ylabel('unlabeled top1 accuracy')
plt.xticks(np.arange(0, 1.01, 0.1))
plt.grid()
plt.title('num_eigenvectors={}'.format(num_eigenvectors))
legend = plt.legend(loc='upper left', bbox_to_anchor=(1, 1))
plt.show()

num_labeled=  50 T_nc=1, prec=48.40, AUC=60.85
num_labeled= 100 T_nc=1, prec=51.91, AUC=67.34
num_labeled= 250 T_nc=1, prec=61.03, AUC=76.31
num_labeled= 500 T_nc=1, prec=64.05, AUC=80.04
num_labeled=1000 T_nc=1, prec=64.84, AUC=81.78
num_labeled=2000 T_nc=1, prec=64.84, AUC=81.89
num_labeled=4000 T_nc=1, prec=65.60, AUC=82.93
num_labeled=8000 T_nc=1, prec=65.11, AUC=82.03
