In [1]:
import sys
sys.path.append("../")

from PIL import Image

import torch
import torchvision
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset

from encoders import resnet18, resnet50
from aug import get_relic_aug_inference, get_relic_aug

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
transform = get_relic_aug_inference()

train_ds = torchvision.datasets.STL10("../data",
                                  split='train',
                                  transform=transform,
                                  download=True)
val_ds = torchvision.datasets.STL10("../data",
                                  split='test',
                                  transform=transform,
                                  download=True)

train_loader = DataLoader(train_ds,
                          batch_size=256,
                          num_workers=4)
val_loader = DataLoader(val_ds,
                       batch_size=256,
                       num_workers=4)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
len(train_ds), len(val_ds)

(5000, 8000)

In [6]:
# ckpt = torch.load("../models/sota_stl/encoder.pth")
ckpt = torch.load("../models/sota_stl_2048/encoder.pth")

# model = resnet18()
model = resnet50()
model.load_state_dict(ckpt)

model = model.eval().to(device)

In [7]:
from tqdm.auto import tqdm
import numpy as np

def get_embs_labels(dl):
    idx = 0
    embs, labels = [], []
    for idx, (images, targets) in enumerate(tqdm(dl)):
        with torch.no_grad():
            images = images.to(device)
            out = model(images)
            features = out.cpu().detach().tolist()
            embs.extend(features)
            labels.extend(targets.cpu().detach().tolist())
    return np.array(embs), np.array(labels)

In [8]:
embeddings, labels = get_embs_labels(train_loader)
embeddings_val, labels_val = get_embs_labels(val_loader)

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fab049f2050>
Traceback (most recent call last):
  File "/home/wavelet/projects/relic/.env/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/home/wavelet/projects/relic/.env/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fab049f2050>
Traceback (most recent call last):
  File "/home/wavelet/projects/relic/.env/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fab04

In [9]:
print(embeddings.shape)
print(labels.shape)
print(embeddings_val.shape)
print(labels_val.shape)

(5000, 512)
(5000,)
(8000, 512)
(8000,)


### Resnet 50

In [10]:
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.model_selection import train_test_split
 
X_train, X_test = embeddings, embeddings_val
y_train, y_test = labels, labels_val

print("train", X_train.shape[0], len(y_train))
print("test", X_test.shape[0], len(y_test))
 
clf = LogisticRegression(max_iter=100)
clf = CalibratedClassifierCV(clf)
 
clf.fit(X_train, y_train)
 
y_pred = clf.predict(X_test)
 
acc = accuracy_score(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)
class_report = classification_report(y_test, y_pred)
 
print("Accuracy: ", acc)
print("Confusion matrix: \n", conf_matrix)
print("Classification report: \n", class_report)
 
y_pred_train = clf.predict(X_train)
class_report = classification_report(y_train, y_pred_train)
print("Classification report train: \n", class_report)

train 5000 5000
test 8000 8000


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

Accuracy:  0.804
Confusion matrix: 
 [[710  32   3  14   1   3   6   1  18  12]
 [  8 641   0  43  11  33   7  56   1   0]
 [ 16   2 723   9   0   5   3   1   5  36]
 [  0  46   0 540  55  82  19  58   0   0]
 [  0  30   0  45 630  32  41  22   0   0]
 [  1  33   0 111  41 470  67  76   0   1]
 [  0   8   0  11  32  48 672  24   0   5]
 [  0  57   0  53  18  45  13 614   0   0]
 [ 26   1   3   6   1   1   0   0 744  18]
 [ 10   2  35   9   1   2   7   2  44 688]]
Classification report: 
               precision    recall  f1-score   support

           0       0.92      0.89      0.90       800
           1       0.75      0.80      0.78       800
           2       0.95      0.90      0.92       800
           3       0.64      0.68      0.66       800
           4       0.80      0.79      0.79       800
           5       0.65      0.59      0.62       800
           6       0.80      0.84      0.82       800
           7       0.72      0.77      0.74       800
           8       0

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


### Resnet 18

In [11]:
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.model_selection import train_test_split
 
X_train, X_test = embeddings, embeddings_val
y_train, y_test = labels, labels_val

print("train", X_train.shape[0], len(y_train))
print("test", X_test.shape[0], len(y_test))
 
clf = LogisticRegression(max_iter=100)
clf = CalibratedClassifierCV(clf)
 
clf.fit(X_train, y_train)
 
y_pred = clf.predict(X_test)
 
acc = accuracy_score(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)
class_report = classification_report(y_test, y_pred)
 
print("Accuracy: ", acc)
print("Confusion matrix: \n", conf_matrix)
print("Classification report: \n", class_report)
 
y_pred_train = clf.predict(X_train)
class_report = classification_report(y_train, y_pred_train)
print("Classification report train: \n", class_report)

train 5000 5000
test 8000 8000


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

Accuracy:  0.761
Confusion matrix: 
 [[700  30  14   5   0   4   5   0  31  11]
 [ 18 592   1  62  16  27   8  71   4   1]
 [ 17   3 726   3   0   4   3   3   6  35]
 [  0  72   0 488  61  87  18  71   1   2]
 [  6  28   1  43 594  37  58  30   3   0]
 [  1  37   3 125  43 410  93  87   1   0]
 [  4   7   4  14  39  49 655  22   4   2]
 [  3  61   0  69  16  66  28 557   0   0]
 [ 29   4   8   1   0   2   0   0 738  18]
 [ 21   2  63  13   0   1   8   2  62 628]]
Classification report: 
               precision    recall  f1-score   support

           0       0.88      0.88      0.88       800
           1       0.71      0.74      0.72       800
           2       0.89      0.91      0.90       800
           3       0.59      0.61      0.60       800
           4       0.77      0.74      0.76       800
           5       0.60      0.51      0.55       800
           6       0.75      0.82      0.78       800
           7       0.66      0.70      0.68       800
           8       0

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
