In [4]:
import sys
sys.path.append("../")

import torch
import torchvision
from torch.utils.data import DataLoader

from simple_ijepa.model import VisionTransformer
from simple_ijepa.utils import inference_transforms

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
IMAGE_SIZE = 96
transform = inference_transforms(img_size=(IMAGE_SIZE, IMAGE_SIZE))

train_ds = torchvision.datasets.STL10("../data/",
                                  split='train',
                                  transform=transform,
                                  download=True)
val_ds = torchvision.datasets.STL10("../data",
                                  split='test',
                                  transform=transform,
                                  download=True)

train_loader = DataLoader(train_ds,
                          batch_size=128,
                          num_workers=4)
val_loader = DataLoader(val_ds,
                       batch_size=128,
                       num_workers=4)

Files already downloaded and verified
Files already downloaded and verified


In [7]:
len(train_ds), len(val_ds)

(5000, 8000)

In [8]:
ckpt = torch.load("../models/encoder_best.pth")

dim = 512
model = VisionTransformer(image_size=96, patch_size=8, dim=dim, depth=6, heads=6, mlp_dim=dim * 2)

model.load_state_dict(ckpt)

model = model.eval().to(device)

In [None]:
from tqdm.auto import tqdm
import numpy as np

def get_embs_labels(dl):
    embs, labels = [], []
    for images, targets in tqdm(dl):
        with torch.no_grad():
            images = images.to(device)
            out = model(images)
            features = out.cpu().detach()
            features = features.mean(dim = 1)
            
            embs.extend(features.tolist())
            labels.extend(targets.cpu().detach().tolist())

    return np.array(embs), np.array(labels)

In [79]:
embeddings, labels = get_embs_labels(train_loader)
embeddings_val, labels_val = get_embs_labels(val_loader)

100%|██████████| 40/40 [00:02<00:00, 14.12it/s]
100%|██████████| 63/63 [00:04<00:00, 13.58it/s]


In [80]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

def eval():
    X_train, X_test = embeddings, embeddings_val
    y_train, y_test = labels, labels_val

    print("train", X_train.shape[0], len(y_train))
    print("test", X_test.shape[0], len(y_test))
    
    clf = LogisticRegression(max_iter=100)
    
    clf.fit(X_train, y_train)
    
    y_pred = clf.predict(X_test)
    
    acc = accuracy_score(y_test, y_pred)
    conf_matrix = confusion_matrix(y_test, y_pred)
    class_report = classification_report(y_test, y_pred)
    
    print("Accuracy: ", acc)
    print("Confusion matrix: \n", conf_matrix)
    print("Classification report: \n", class_report)
    
    y_pred_train = clf.predict(X_train)
    class_report = classification_report(y_train, y_pred_train)
    print("Classification report train: \n", class_report)

In [81]:
eval()

train 5000 5000
test 8000 8000
Accuracy:  0.77075
Confusion matrix: 
 [[703  14  12   3   3   1   4   2  39  19]
 [ 23 619   1  53  17  29   8  44   5   1]
 [ 23   4 694   3   1   1   3   1  10  60]
 [  0  40   1 509  66  88  16  76   3   1]
 [  1  28   3  53 590  52  49  20   0   4]
 [  1  28   2  95  53 445  83  88   2   3]
 [  1   9   1  13  28  86 631  22   1   8]
 [  2  41   1  47  32  72  14 588   1   2]
 [ 31   2   3   2   1   0   1   0 728  32]
 [ 27   2  40   3   1   1   4   4  59 659]]
Classification report: 
               precision    recall  f1-score   support

           0       0.87      0.88      0.87       800
           1       0.79      0.77      0.78       800
           2       0.92      0.87      0.89       800
           3       0.65      0.64      0.64       800
           4       0.74      0.74      0.74       800
           5       0.57      0.56      0.57       800
           6       0.78      0.79      0.78       800
           7       0.70      0.73      0.

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
