In [8]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.models as models
from tqdm import tqdm
from sklearn.cluster import KMeans
from scipy.stats import dlaplace
from sklearn.cluster import SpectralClustering
import pickle
from mnist_models import InceptionMnist
import os
import scipy
import numpy as np

In [3]:
# ---------------------------------------------------------------------------
# Configuration cell: choose the feature extractor ("prior") and the dataset,
# then build `model`, `transform_train`, `num_clusters` and `eps_p`.
#
#   prior_name   "outdomain" -> DINO ViT-S/8 fetched through torch.hub
#                "indomain"  -> a checkpoint stored next to this notebook
#   dataset      "mnist" or "cifar10"
#   num_clusters number of k-means clusters fitted on the embeddings
#   eps_p        parameter of the discrete-Laplace noise added to the
#                per-cluster label histograms (used as eps_p / 2)
#
# Unsupported combinations raise immediately instead of leaving `model` /
# `transform_train` undefined and failing later with a NameError.
# ---------------------------------------------------------------------------
prior_name = "indomain" # "outdomain"
dataset = "mnist" # "cifar10"
if prior_name == "outdomain":
    checkpoint_name = "dino_vits8"
    model = torch.hub.load('facebookresearch/dino:main', checkpoint_name)
    num_clusters = 50
    eps_p = 0.025
    if dataset == "cifar10":
        transform_train = transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])
    elif dataset == "mnist":
        transform_train = transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            # Replicate the single grey channel three times so the input has
            # the three channels the Normalize call below is configured for.
            transforms.Lambda(lambda x: x.repeat(3,1,1)),
            # NOTE(review): these are the same statistics as the CIFAR-10
            # branch above, not MNIST statistics - confirm this is intended.
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])
    else:
        raise ValueError(f"Unsupported dataset for the outdomain prior: {dataset!r}")

elif prior_name == "indomain":
    num_clusters = 100
    eps_p = 0.05
    # Load checkpoints onto whatever device is actually available so that a
    # checkpoint saved from a GPU can still be read on a CPU-only machine.
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    if dataset == "cifar10":
        # No-argument constructor == randomly initialised weights; avoids the
        # deprecated `pretrained=` keyword while staying compatible with
        # older torchvision releases.
        model = models.resnet18()
        model.load_state_dict(torch.load("./BYOL/resnet18-CIFAR10-final.pt", map_location=device))
        # Drop the last child module (the classification layer) so the network
        # outputs features instead of logits.
        model =  torch.nn.Sequential(*list(model.children())[:-1])
        transform_train = transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
        ])
    elif dataset == "mnist":
        model = InceptionMnist(1, 10)
        checkpoint = torch.load("./model.tar", map_location=device)
        model.load_state_dict(checkpoint)
        # Replace the `linear` attribute with a pass-through so the forward
        # pass returns whatever is fed into it (presumably the penultimate
        # features - verify against mnist_models.InceptionMnist).
        model.linear = torch.nn.Identity()
        MNIST_MEAN = (0.1307,)
        MNIST_STD = (0.3081,)
        normalize = [
            transforms.ToTensor(),
            transforms.Normalize(MNIST_MEAN, MNIST_STD),
        ]
        transform_train = transforms.Compose(normalize)
    else:
        raise ValueError(f"Unsupported dataset for the indomain prior: {dataset!r}")

else:
    raise ValueError(f"Unsupported prior_name: {prior_name!r}")

In [4]:
batch_size = 256

# Pick the training split of the configured dataset. Both datasets are
# downloaded into ~/data on first use and share the same transform pipeline.
if dataset == "mnist":
    vanila_trainset = datasets.MNIST(root="~/data", train=True,
                                     transform=transform_train, download=True)
elif dataset == "cifar10":
    vanila_trainset = datasets.CIFAR10(root='~/data', train=True,
                                       download=True, transform=transform_train)

# One loader definition serves both datasets. shuffle=False keeps the batches
# in dataset order, which the embedding extraction relies on when it writes
# each batch into its slot of the embedding matrix.
if dataset in ("cifar10", "mnist"):
    trainloader = torch.utils.data.DataLoader(
        vanila_trainset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8,
    )

In [5]:
# Move the feature extractor's parameters and buffers to the current CUDA device.
net = model.to(torch.device("cuda"))

In [6]:
# Run the whole training set through the feature extractor and collect the
# flattened outputs, one row per sample, in a CPU tensor called `embeddings`.
#
# Rows are written at a running offset, so the result is correct even when
# batches are not all exactly `batch_size` long (e.g. the last batch). The
# loader must iterate in dataset order (shuffle=False) for row i to match
# sample i.
net.eval()  # inference behaviour for dropout / batch-norm; set once, not per batch
embeddings = None
offset = 0
with torch.no_grad():  # no autograd bookkeeping needed for pure feature extraction
    for inputs, _ in tqdm(trainloader, total=len(trainloader)):
        n_samples = inputs.size(0)
        # Flatten everything after the batch dimension into one feature vector.
        outputs = net(inputs.cuda()).view(n_samples, -1)
        if embeddings is None:
            # Allocate lazily: the feature width is only known after the first forward pass.
            embeddings = torch.zeros([len(vanila_trainset), outputs.size(1)])
        embeddings[offset:offset + n_samples] = outputs.cpu()
        offset += n_samples

100%|██████████| 235/235 [00:05<00:00, 40.44it/s]


In [7]:
# Persist the embeddings so the (slow) forward pass does not have to be repeated.
# This is the first write into pretrained_priors/, so create the folder if needed;
# torch.save does not create missing parent directories.
os.makedirs("pretrained_priors", exist_ok=True)
torch.save(embeddings, f"pretrained_priors/{dataset}_{prior_name}_embeddings.pl")

torch.Size([60000, 8400])


In [8]:
# Reload the cached embeddings and fetch the ground-truth labels of the training set.
embeddings_path = f"pretrained_priors/{dataset}_{prior_name}_embeddings.pl"
embeddings = torch.load(embeddings_path)
# `targets` may already be a tensor or a plain Python list depending on the
# dataset class; torch.as_tensor handles both and, unlike torch.tensor, does
# not emit the "copy construct from a tensor" UserWarning for tensor input.
targets = torch.as_tensor(vanila_trainset.targets)

torch.Size([60000, 8400])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]]) tensor([5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7, 2, 8, 6, 9])


  This is separate from the ipykernel package so we can avoid doing imports until


In [12]:
# Cluster the embeddings with k-means.
# random_state=0 fixes the centroid initialisation; verbose=1 prints the
# inertia after every iteration.
kmeans = KMeans(
    n_clusters=num_clusters,
    random_state=0,
    verbose=1,
)
kmeans.fit(embeddings)

Initialization complete
Iteration 0, inertia 9446305.756535856
Iteration 1, inertia 5885707.127701719
Iteration 2, inertia 5722827.460379831
Iteration 3, inertia 5647301.723425427
Iteration 4, inertia 5603909.109570508
Iteration 5, inertia 5584045.747171975
Iteration 6, inertia 5569354.163598801
Iteration 7, inertia 5559171.288699869
Iteration 8, inertia 5554280.132724583
Iteration 9, inertia 5550829.098225283
Iteration 10, inertia 5547692.310307057
Iteration 11, inertia 5544221.040181344
Iteration 12, inertia 5541282.818534794
Iteration 13, inertia 5539650.514564337
Iteration 14, inertia 5538157.1632159455
Iteration 15, inertia 5536892.686957562
Iteration 16, inertia 5536057.617681321
Iteration 17, inertia 5535442.241048367
Iteration 18, inertia 5534666.360081406
Iteration 19, inertia 5534443.836012003
Iteration 20, inertia 5534282.36364709
Iteration 21, inertia 5534106.562209493
Iteration 22, inertia 5533844.577870936
Iteration 23, inertia 5533416.686562998
Iteration 24, inertia 5532

Initialization complete
Iteration 0, inertia 9370562.75403087
Iteration 1, inertia 5835475.275671811
Iteration 2, inertia 5695700.510605563
Iteration 3, inertia 5644222.2859391505
Iteration 4, inertia 5612673.074697799
Iteration 5, inertia 5595871.24179535
Iteration 6, inertia 5586254.903384874
Iteration 7, inertia 5578592.7315942375
Iteration 8, inertia 5574963.756430222
Iteration 9, inertia 5571714.139875464
Iteration 10, inertia 5570465.13747359
Iteration 11, inertia 5569795.474766269
Iteration 12, inertia 5568690.506439738
Iteration 13, inertia 5568286.700102679
Iteration 14, inertia 5567742.545839316
Iteration 15, inertia 5566232.004242356
Iteration 16, inertia 5563979.5767739825
Iteration 17, inertia 5561725.584770252
Iteration 18, inertia 5560705.542134244
Iteration 19, inertia 5560326.012207529
Iteration 20, inertia 5560069.754038954
Iteration 21, inertia 5559859.47723235
Iteration 22, inertia 5559829.326168074
Iteration 23, inertia 5559815.430643644
Iteration 24, inertia 55598

Iteration 6, inertia 5598730.624508297
Iteration 7, inertia 5584838.287100725
Iteration 8, inertia 5573696.983267502
Iteration 9, inertia 5564965.408341785
Iteration 10, inertia 5558691.523624789
Iteration 11, inertia 5552979.128193947
Iteration 12, inertia 5548233.921633626
Iteration 13, inertia 5543100.187440188
Iteration 14, inertia 5540139.52666965
Iteration 15, inertia 5539029.474028342
Iteration 16, inertia 5538627.052897787
Iteration 17, inertia 5538316.612095686
Iteration 18, inertia 5537765.286485073
Iteration 19, inertia 5537111.469402397
Iteration 20, inertia 5536798.124486733
Iteration 21, inertia 5536484.42882322
Iteration 22, inertia 5536158.709307335
Iteration 23, inertia 5535886.536260902
Iteration 24, inertia 5535785.144374766
Iteration 25, inertia 5535744.5972583955
Iteration 26, inertia 5535667.667452945
Iteration 27, inertia 5535486.251105941
Iteration 28, inertia 5535290.261888112
Iteration 29, inertia 5535169.71848328
Iteration 30, inertia 5535157.906737346
Iterat

In [13]:
# Store the per-sample cluster assignment together with the fitted centroids.
checkpoint = dict(
    cluster_labels=kmeans.labels_,
    cluster_centers=kmeans.cluster_centers_,
)
kmeans_out_path = f"pretrained_priors/{dataset}_{prior_name}_kmeans.pl"
torch.save(checkpoint, kmeans_out_path)

In [14]:
# Read the clustering result back from disk; only the per-sample cluster
# assignment is needed from here on.
kmeans_in_path = f"pretrained_priors/{dataset}_{prior_name}_kmeans.pl"
checkpoint = torch.load(kmeans_in_path)
cluster_labels = checkpoint["cluster_labels"]

In [27]:
# Build a per-sample label prior from the clustering:
#   1. for every cluster, count how often each class label occurs in it,
#   2. perturb each count with independent discrete-Laplace noise,
#   3. clamp negatives to zero and normalise to a probability vector,
#   4. assign that vector to every sample belonging to the cluster.
#
# Results:
#   prior_probs  float tensor [num_samples, num_classes]
#   correct      number of samples whose label equals the majority label of
#                their cluster (computed on the clean counts, before noise)
num_classes = 2 if dataset.startswith("raw_dataset") else 10

# Resolve the labels *before* they are used to size `prior_probs`, so this
# cell does not depend on a `targets` variable left over from another cell.
targets = torch.as_tensor(vanila_trainset.targets)
prior_probs = torch.zeros([len(targets), num_classes])
correct = 0

torch.manual_seed(0)
# scipy draws from NumPy's generator, which torch.manual_seed does not touch;
# a dedicated seeded RandomState makes the noise reproducible.
noise_rng = np.random.RandomState(0)

for c in range(num_clusters):
    if_c = torch.as_tensor(cluster_labels == c)
    if not bool(if_c.any()):
        # Nothing is assigned to this cluster, so there is nothing to fill in.
        continue
    targets_c = targets[if_c]
    # Bin edges at -0.5, 0.5, ... put every integer label in its own bin.
    histo = torch.histc(targets_c.float(), bins=num_classes, min=-0.5, max=num_classes - 0.5)
    correct += (targets_c == torch.argmax(histo)).sum()
    # `size=` must be passed by keyword: the second positional argument of
    # rvs() is `loc`, which would shift a single draw by num_classes and add
    # that same value to every bin instead of drawing one value per class.
    noise = dlaplace.rvs(eps_p / 2, size=num_classes, random_state=noise_rng)
    histo = torch.clamp(histo + torch.as_tensor(noise, dtype=histo.dtype), min=0)
    total = histo.sum()
    if total > 0:
        histo = histo / total
    else:
        # Every noisy count was clamped to zero: fall back to a uniform prior
        # rather than producing NaNs from 0 / 0.
        histo = torch.full_like(histo, 1.0 / num_classes)
    prior_probs[if_c] = histo

tensor([0.9378, 0.0622])
tensor([0.9628, 0.0372])
tensor([0.9683, 0.0317])
tensor([0.9588, 0.0412])
tensor([0.9830, 0.0170])
tensor([0.9699, 0.0301])
tensor([0.9698, 0.0302])
tensor([0.9743, 0.0257])
tensor([0.9601, 0.0399])
tensor([0.9570, 0.0430])
tensor([0.9774, 0.0226])
tensor([0.9757, 0.0243])
tensor([0.9712, 0.0288])
tensor([0.9442, 0.0558])
tensor([0.9608, 0.0392])
tensor([0.9817, 0.0183])
tensor([0.9783, 0.0217])
tensor([0.9769, 0.0231])
tensor([0.9825, 0.0175])
tensor([0.9839, 0.0161])
tensor([0.9691, 0.0309])
tensor([0.9770, 0.0230])
tensor([0.9434, 0.0566])
tensor([0.9734, 0.0266])
tensor([0.9727, 0.0273])
tensor([0.9763, 0.0237])
tensor([0.9670, 0.0330])
tensor([0.9597, 0.0403])
tensor([0.9811, 0.0189])
tensor([0.9947, 0.0053])
tensor([0.9627, 0.0373])
tensor([0.9662, 0.0338])
tensor([0.9765, 0.0235])
tensor([0.9784, 0.0216])
tensor([0.9710, 0.0290])
tensor([0.9679, 0.0321])
tensor([0.9780, 0.0220])
tensor([0.9301, 0.0699])
tensor([0.9582, 0.0418])
tensor([0.9406, 0.0594])


In [29]:
# Write the per-sample prior probabilities to disk.
prior_probs_path = f"pretrained_priors/{dataset}_{prior_name}_prior_probs.pl"
torch.save(prior_probs, prior_probs_path)