In [1]:
import sys
sys.path.append("../")

from PIL import Image

import torch
import torchvision
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset

from encoders import resnet18, resnet50
from aug import get_relic_aug_inference, get_relic_aug

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
transform = get_relic_aug_inference()

train_ds = torchvision.datasets.STL10("../data",
                                  split='train',
                                  transform=transform,
                                  download=True)
val_ds = torchvision.datasets.STL10("../data",
                                  split='test',
                                  transform=transform,
                                  download=True)

train_loader = DataLoader(train_ds,
                          batch_size=256,
                          num_workers=4)
val_loader = DataLoader(val_ds,
                       batch_size=256,
                       num_workers=4)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
len(train_ds), len(val_ds)

(5000, 8000)

In [5]:
# ckpt = torch.load("../models/sota_stl/encoder.pth")
# ckpt = torch.load("../models/sota_stl_2048/encoder.pth")
ckpt = torch.load("../models/encoder.pth")

# model = resnet18()
model = resnet50()
model.load_state_dict(ckpt)

model = model.eval().to(device)

In [6]:
from tqdm.auto import tqdm
import numpy as np

def get_embs_labels(dl):
    idx = 0
    embs, labels = [], []
    for idx, (images, targets) in enumerate(tqdm(dl)):
        with torch.no_grad():
            images = images.to(device)
            out = model(images)
            features = out.cpu().detach().tolist()
            embs.extend(features)
            labels.extend(targets.cpu().detach().tolist())
    return np.array(embs), np.array(labels)

In [7]:
embeddings, labels = get_embs_labels(train_loader)
embeddings_val, labels_val = get_embs_labels(val_loader)

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

In [8]:
print(embeddings.shape)
print(labels.shape)
print(embeddings_val.shape)
print(labels_val.shape)

(5000, 2048)
(5000,)
(8000, 2048)
(8000,)


### Resnet 50

In [10]:
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.model_selection import train_test_split
 
X_train, X_test = embeddings, embeddings_val
y_train, y_test = labels, labels_val

print("train", X_train.shape[0], len(y_train))
print("test", X_test.shape[0], len(y_test))
 
clf = LogisticRegression(max_iter=100)
clf = CalibratedClassifierCV(clf)
 
clf.fit(X_train, y_train)
 
y_pred = clf.predict(X_test)
 
acc = accuracy_score(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)
class_report = classification_report(y_test, y_pred)
 
print("Accuracy: ", acc)
print("Confusion matrix: \n", conf_matrix)
print("Classification report: \n", class_report)
 
y_pred_train = clf.predict(X_train)
class_report = classification_report(y_train, y_pred_train)
print("Classification report train: \n", class_report)

train 5000 5000
test 8000 8000


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

Accuracy:  0.78525
Confusion matrix: 
 [[720  27   7   6   1   4   5   0  21   9]
 [  9 627   0  53  16  27  11  57   0   0]
 [ 16   2 725   6   0   2   3   1   7  38]
 [  0  57   0 541  43  73  12  69   4   1]
 [  4  27   1  46 614  34  49  24   1   0]
 [  2  44   1 120  41 429  70  91   1   1]
 [  4  14   3   5  29  48 659  34   1   3]
 [  0  76   0  66  21  42  20 575   0   0]
 [ 39   1   6   1   0   0   1   0 733  19]
 [ 17   1  47   7   0   1   4   3  61 659]]
Classification report: 
               precision    recall  f1-score   support

           0       0.89      0.90      0.89       800
           1       0.72      0.78      0.75       800
           2       0.92      0.91      0.91       800
           3       0.64      0.68      0.66       800
           4       0.80      0.77      0.78       800
           5       0.65      0.54      0.59       800
           6       0.79      0.82      0.81       800
           7       0.67      0.72      0.70       800
           8      

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


### Resnet 18

In [9]:
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.model_selection import train_test_split
 
X_train, X_test = embeddings, embeddings_val
y_train, y_test = labels, labels_val

print("train", X_train.shape[0], len(y_train))
print("test", X_test.shape[0], len(y_test))
 
clf = LogisticRegression(max_iter=100)
clf = CalibratedClassifierCV(clf)
 
clf.fit(X_train, y_train)
 
y_pred = clf.predict(X_test)
 
acc = accuracy_score(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)
class_report = classification_report(y_test, y_pred)
 
print("Accuracy: ", acc)
print("Confusion matrix: \n", conf_matrix)
print("Classification report: \n", class_report)
 
y_pred_train = clf.predict(X_train)
class_report = classification_report(y_train, y_pred_train)
print("Classification report train: \n", class_report)

train 5000 5000
test 8000 8000


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

Accuracy:  0.75975
Confusion matrix: 
 [[710  27  12   5   3   2   7   0  23  11]
 [ 19 613   1  56  13  19  14  59   5   1]
 [ 19   0 716  12   0   4   2   1   7  39]
 [  0  66   0 535  57  60  17  62   3   0]
 [  1  29   2  60 587  26  69  24   2   0]
 [  1  53   1 127  53 375  94  95   1   0]
 [  0  14   1  12  37  38 666  26   2   4]
 [  0  86   0  83  32  42  34 523   0   0]
 [ 36   3   3   3   0   1   0   0 733  21]
 [ 30   3  60   9   0   4   9   2  63 620]]
Classification report: 
               precision    recall  f1-score   support

           0       0.87      0.89      0.88       800
           1       0.69      0.77      0.72       800
           2       0.90      0.90      0.90       800
           3       0.59      0.67      0.63       800
           4       0.75      0.73      0.74       800
           5       0.66      0.47      0.55       800
           6       0.73      0.83      0.78       800
           7       0.66      0.65      0.66       800
           8      

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
