In [1]:
import sys
sys.path.append("../")

from PIL import Image

import torch
import torchvision
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset

from relic import ReLIC
from encoders import resnet18
from aug import get_relic_aug_inference

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
IMAGE_SIZE = 32

transform = get_relic_aug_inference()
train_ds = torchvision.datasets.CIFAR10("../data/cifar",
                                  train=True,
                                  transform=transform,
                                  download=True)
val_ds = torchvision.datasets.CIFAR10("../data/cifar",
                                  train=False,
                                  transform=transform,
                                  download=True)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
len(train_ds), len(val_ds)

(50000, 10000)

In [5]:
ckpt = torch.load("../models/relic_model.pth")


encoder = resnet18()
model = ReLIC(encoder, mlp_in_dim=512)
model.load_state_dict(ckpt)

model = model.eval().to(device)

In [6]:
from tqdm.auto import tqdm

def get_embs_targets(ds):
    idx = 0
    embs, targets = [], []
    for idx, (image, target) in enumerate(tqdm(ds)):
        with torch.no_grad():
            image = image.unsqueeze(0).to(device)
            out = model.get_features(image)
            features = out[0].cpu().detach().tolist()   
            embs.append(features)
            targets.append(target)
    return embs, targets

In [7]:
embs, targets = get_embs_targets(train_ds)
embs_val, targets_val = get_embs_targets(val_ds)

  0%|          | 0/50000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

In [8]:
import numpy as np

embeddings = np.array(embs)
labels = np.array(targets)
embeddings_val = np.array(embs_val)
labels_val = np.array(targets_val)

In [9]:
print(embeddings.shape)
print(labels.shape)
print(embeddings_val.shape)
print(labels_val.shape)

(50000, 512)
(50000,)
(10000, 512)
(10000,)


In [10]:
from sklearn.svm import LinearSVC, SVC
from sklearn.calibration import CalibratedClassifierCV
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.model_selection import train_test_split
 
X_train, X_test = embeddings, embeddings_val
y_train, y_test = labels, labels_val

print("train", X_train.shape[0], len(y_train))
print("test", X_test.shape[0], len(y_test))
 
clf = LogisticRegression(max_iter=100)
clf = CalibratedClassifierCV(clf)
 
clf.fit(X_train, y_train)
 
y_pred = clf.predict(X_test)
 
acc = accuracy_score(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)
class_report = classification_report(y_test, y_pred)
 
print("Accuracy: ", acc)
print("Confusion matrix: \n", conf_matrix)
print("Classification report: \n", class_report)
 
y_pred_train = clf.predict(X_train)
class_report = classification_report(y_train, y_pred_train)
print("Classification report train: \n", class_report)

train 50000 50000
test 10000 10000


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

Accuracy:  0.6493
Confusion matrix: 
 [[653  24  58  26  22  14  24  16 114  49]
 [ 18 749   8  15   8   5  21  10  49 117]
 [ 70   5 487  92  83  75 114  39  17  18]
 [ 20  16  56 459  64 170 127  52  14  22]
 [ 15   9  73  69 532  53 110 116  15   8]
 [  9   6  53 174  53 532  71  80   6  16]
 [ 11   7  26  51  27  36 813  11  10   8]
 [  8  14  20  44  58  77  15 735  10  19]
 [ 75  48  18  18   7   2   3   9 790  30]
 [ 30 113  11  22   5  12  14  23  27 743]]
Classification report: 
               precision    recall  f1-score   support

           0       0.72      0.65      0.68      1000
           1       0.76      0.75      0.75      1000
           2       0.60      0.49      0.54      1000
           3       0.47      0.46      0.47      1000
           4       0.62      0.53      0.57      1000
           5       0.55      0.53      0.54      1000
           6       0.62      0.81      0.70      1000
           7       0.67      0.73      0.70      1000
           8       