In [1]:
import sys
sys.path.append("../")

from PIL import Image

import torch
import torchvision
from torch.utils.data import DataLoader

from scl.encoders import resnet18, resnet50
from scl.aug import get_inference_transforms

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
IMAGE_SIZE = 32

transform = get_inference_transforms(image_size=(IMAGE_SIZE, IMAGE_SIZE))
train_ds = torchvision.datasets.CIFAR10("../data/cifar",
                                  train=True,
                                  transform=transform,
                                  download=True)
val_ds = torchvision.datasets.CIFAR10("../data/cifar",
                                  train=False,
                                  transform=transform,
                                  download=True)

train_loader = DataLoader(train_ds,
                          batch_size=256,
                          num_workers=4)
val_loader = DataLoader(val_ds,
                       batch_size=256,
                       num_workers=4)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
len(train_ds), len(val_ds)

(50000, 10000)

In [6]:
ckpt = torch.load("../models/encoder.pth")

# model = resnet50()
model = resnet18(modify_model=True)
model.load_state_dict(ckpt)

model = model.eval().to(device)

In [11]:
from tqdm.auto import tqdm
import numpy as np
import torch.nn.functional as F

def get_embs_labels(dl):
    idx = 0
    embs, labels = [], []
    for idx, (images, targets) in enumerate(tqdm(dl)):
        with torch.no_grad():
            images = images.to(device)
            out = model(images)
            features = out.cpu().detach()
            features = F.normalize(features, p=2, dim=-1)
            embs.extend(features.tolist())
            labels.extend(targets.cpu().detach().tolist())
    return np.array(embs), np.array(labels)

In [12]:
embeddings, labels = get_embs_labels(train_loader)
embeddings_val, labels_val = get_embs_labels(val_loader)

100%|██████████| 196/196 [00:03<00:00, 62.22it/s]
100%|██████████| 40/40 [00:00<00:00, 45.36it/s]


In [13]:
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

def eval():
    X_train, X_test = embeddings, embeddings_val
    y_train, y_test = labels, labels_val

    print("train", X_train.shape[0], len(y_train))
    print("test", X_test.shape[0], len(y_test))
    
    clf = LogisticRegression(max_iter=100)
    clf = CalibratedClassifierCV(clf)
    
    clf.fit(X_train, y_train)
    
    y_pred = clf.predict(X_test)
    
    acc = accuracy_score(y_test, y_pred)
    conf_matrix = confusion_matrix(y_test, y_pred)
    class_report = classification_report(y_test, y_pred)
    
    print("Accuracy: ", acc)
    print("Confusion matrix: \n", conf_matrix)
    print("Classification report: \n", class_report)
    
    y_pred_train = clf.predict(X_train)
    class_report = classification_report(y_train, y_pred_train)
    print("Classification report train: \n", class_report)

(50000, 512)
(50000,)
(10000, 512)
(10000,)


In [16]:
# resnet18
eval()

train 50000 50000
test 10000 10000


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

Accuracy:  0.8253
Confusion matrix: 
 [[879   0  43  22  12   3   7   5  27   2]
 [ 30 908   2   4   1   1   6   0  33  15]
 [ 38   0 750  41  72  15  63  17   3   1]
 [ 13   1  58 699  35  65 104  18   2   5]
 [  4   0  69  30 787  10  51  47   2   0]
 [  2   1  33 191  36 659  43  35   0   0]
 [  5   0  20  28  10   3 929   2   1   2]
 [  5   0  19  35  22  16   5 896   0   2]
 [ 42   1   9  12   4   0   6   0 919   7]
 [ 57  51   6  16   0   0   4   1  38 827]]
Classification report: 
               precision    recall  f1-score   support

           0       0.82      0.88      0.85      1000
           1       0.94      0.91      0.93      1000
           2       0.74      0.75      0.75      1000
           3       0.65      0.70      0.67      1000
           4       0.80      0.79      0.80      1000
           5       0.85      0.66      0.74      1000
           6       0.76      0.93      0.84      1000
           7       0.88      0.90      0.89      1000
           8       