In [20]:
from tqdm.notebook import tqdm
import clip
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

clip_model, clip_preprocess = clip.load("ViT-B/32", device)

cuda


In [21]:
from datasets import *
dataset_obj = Food101(0, 50)
train_loader, _ = dataset_obj.get_train_loaders(transform_fn=clip_preprocess)
test_loader = dataset_obj.get_test_loader(transform_fn=clip_preprocess)
classes = dataset_obj.classes

In [22]:
import copy

In [23]:
def get_clip_features(dataset):
    all_features = []
    all_labels = []

    global clip_model

    with torch.no_grad():
        for images, labels in tqdm(dataset):
            features = clip_model.encode_image(images.to(device))
            all_features.append(features)
            all_labels.append(labels)

    return torch.cat(all_features), torch.cat(all_labels)

train_features, train_labels = get_clip_features(train_loader)
test_features, test_labels = get_clip_features(test_loader)

  0%|          | 0/1515 [00:00<?, ?it/s]

  0%|          | 0/303 [00:00<?, ?it/s]

In [24]:
def batch(iterable1,iterable2, n=1):
    l = len(iterable1)
    for ndx in range(0, l, n):
        yield (iterable1[ndx:min(ndx + n, l)], iterable2[ndx:min(ndx + n, l)])

out_train = batch(train_features,train_labels, 50)
out_test = batch(test_features,test_labels, 50)

In [25]:
import torch.nn as nn
import torch.optim as optim

class LogisticRegression(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LogisticRegression, self).__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)
        
    def forward(self, x):
        outputs = self.linear(x)
        return outputs
    
model = LogisticRegression(512,len(classes))
model.logit_scale = nn.Parameter(torch.ones([], device=device))
criterion = nn.CrossEntropyLoss()
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=1e-5)
num_epochs = 500
# scheduler = optim.lr_scheduler.MultiStepLR(
#             optimizer, milestones=[300, 500, 700, 900], gamma=0.1
#         )

In [26]:
def num_correct_preds(outputs, labels):
    _, predicted = outputs.max(1)
    return predicted.eq(labels).sum().item()

In [27]:
def cosine_loss(output, target):
    loss = 1 - torch.cosine_similarity(output, target)
    return loss

def calc_loss(outputs, labels, loss_name="ce"):
    
    loss_labels = labels
    
    criterion = nn.CrossEntropyLoss()

    if loss_name == "ce":
        return criterion(outputs, loss_labels) 

    elif loss_name == "dot":
        outputs = outputs / outputs.norm(dim=-1, keepdim=True)
        return -(outputs * loss_labels).sum(-1).mean()

    elif loss_name == "cosine":
        print(outputs.shape, loss_labels.shape)
        loss = torch.mean(cosine_loss(outputs, loss_labels))
        return loss

    elif loss_name == "temperature_ce":
        image_features = outputs

        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = model.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        return criterion(logits_per_image, labels)

In [28]:
import torch
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [29]:
import random

model.to(device)
model = model.to(torch.float16)

best_model = None
best_acc = -1

for epoch in tqdm(range(num_epochs+1)):
    
    model.train()
    running_loss = 0.0
    correct = 0.0
    total = 0

    if epoch%50 == 0:
        print("Shuffling")
        c = list(zip(train_features, train_labels))
        random.shuffle(c)
        train_features, train_labels = zip(*c)
        train_features = torch.stack(list(train_features))
        train_labels = torch.stack(list(train_labels))
        
    out_train = batch(train_features,train_labels, 50)
    for inputs, labels in out_train:
        
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = calc_loss(outputs, labels)

        loss.backward()

        optimizer.step()
        running_loss+=loss
        total += len(labels)
        correct += num_correct_preds(outputs, labels)
        
    if epoch%10 == 0:
        model.eval()
        out_test = batch(test_features,test_labels, 50)
        
        test_running_loss = 0.0
        test_correct = 0.0
        test_total = 0
        with torch.no_grad():
            for inputs, labels in tqdm(out_test):
        #         inputs = inputs / inputs.norm(dim=-1, keepdim=True)
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                test_running_loss += loss.item()
                test_total += labels.size(0)
                test_correct += num_correct_preds(outputs, labels)

        epoch_loss = test_running_loss / len(test_loader)
        writer.add_scalar("Loss/val", epoch_loss, epoch)
        epoch_accuracy = test_correct * 100 / test_total
        writer.add_scalar("Accuracy/val", epoch_loss, epoch)
        if epoch_accuracy > best_acc:
            best_model= copy.deepcopy(model)
            best_acc = epoch_accuracy
            print("Found best model")
#         print(
#             f"Testing: Epoch {epoch} || Loss: {epoch_loss:7.3f} || Accuracy: {epoch_accuracy:6.2f}%"
#         )

        
    epoch_loss = running_loss/len(train_loader)
    writer.add_scalar("Loss/train", epoch_loss, epoch)
    epoch_accuracy = correct*100/total
    writer.add_scalar("Accuracy/train", epoch_loss, epoch)
    print(
        f"Training: Epoch {epoch} || Loss: {epoch_loss:7.3f} || Accuracy: {epoch_accuracy:6.2f}%"
    )

    writer.flush()
writer.close()


  0%|          | 0/501 [00:00<?, ?it/s]

Shuffling


0it [00:00, ?it/s]

Found best model
Training: Epoch 0 || Loss:   4.078 || Accuracy:  47.98%
Training: Epoch 1 || Loss:   2.416 || Accuracy:  73.54%
Training: Epoch 2 || Loss:   1.997 || Accuracy:  77.63%
Training: Epoch 3 || Loss:   1.348 || Accuracy:  79.24%
Training: Epoch 4 || Loss:   1.156 || Accuracy:  80.16%
Training: Epoch 5 || Loss:   1.066 || Accuracy:  80.71%
Training: Epoch 6 || Loss:   1.019 || Accuracy:  81.12%
Training: Epoch 7 || Loss:   0.977 || Accuracy:  81.43%
Training: Epoch 8 || Loss:   0.935 || Accuracy:  81.69%
Training: Epoch 9 || Loss:   0.898 || Accuracy:  81.92%


0it [00:00, ?it/s]

Found best model
Training: Epoch 10 || Loss:   0.861 || Accuracy:  82.08%
Training: Epoch 11 || Loss:   0.827 || Accuracy:  82.24%
Training: Epoch 12 || Loss:   0.801 || Accuracy:  82.40%
Training: Epoch 13 || Loss:   0.775 || Accuracy:  82.55%
Training: Epoch 14 || Loss:   0.746 || Accuracy:  82.65%
Training: Epoch 15 || Loss:   0.729 || Accuracy:  82.81%
Training: Epoch 16 || Loss:   0.713 || Accuracy:  82.90%
Training: Epoch 17 || Loss:   0.694 || Accuracy:  83.00%
Training: Epoch 18 || Loss:   0.676 || Accuracy:  83.10%
Training: Epoch 19 || Loss:   0.667 || Accuracy:  83.16%


0it [00:00, ?it/s]

Found best model
Training: Epoch 20 || Loss:   0.657 || Accuracy:  83.20%
Training: Epoch 21 || Loss:   0.652 || Accuracy:  83.27%
Training: Epoch 22 || Loss:   0.644 || Accuracy:  83.35%
Training: Epoch 23 || Loss:   0.639 || Accuracy:  83.41%
Training: Epoch 24 || Loss:   0.634 || Accuracy:  83.45%
Training: Epoch 25 || Loss:   0.629 || Accuracy:  83.52%
Training: Epoch 26 || Loss:   0.626 || Accuracy:  83.56%
Training: Epoch 27 || Loss:   0.624 || Accuracy:  83.60%
Training: Epoch 28 || Loss:   0.619 || Accuracy:  83.63%
Training: Epoch 29 || Loss:   0.617 || Accuracy:  83.66%


0it [00:00, ?it/s]

Found best model
Training: Epoch 30 || Loss:   0.616 || Accuracy:  83.69%
Training: Epoch 31 || Loss:   0.613 || Accuracy:  83.73%
Training: Epoch 32 || Loss:   0.610 || Accuracy:  83.76%
Training: Epoch 33 || Loss:   0.607 || Accuracy:  83.79%
Training: Epoch 34 || Loss:   0.607 || Accuracy:  83.84%
Training: Epoch 35 || Loss:   0.604 || Accuracy:  83.86%
Training: Epoch 36 || Loss:   0.603 || Accuracy:  83.90%
Training: Epoch 37 || Loss:   0.601 || Accuracy:  83.93%
Training: Epoch 38 || Loss:   0.599 || Accuracy:  83.93%
Training: Epoch 39 || Loss:   0.597 || Accuracy:  83.96%


0it [00:00, ?it/s]

Found best model
Training: Epoch 40 || Loss:   0.596 || Accuracy:  83.97%
Training: Epoch 41 || Loss:   0.594 || Accuracy:  84.00%
Training: Epoch 42 || Loss:   0.593 || Accuracy:  84.00%
Training: Epoch 43 || Loss:   0.591 || Accuracy:  84.02%
Training: Epoch 44 || Loss:   0.590 || Accuracy:  84.04%
Training: Epoch 45 || Loss:   0.588 || Accuracy:  84.06%
Training: Epoch 46 || Loss:   0.588 || Accuracy:  84.08%
Training: Epoch 47 || Loss:   0.586 || Accuracy:  84.09%
Training: Epoch 48 || Loss:   0.584 || Accuracy:  84.12%
Training: Epoch 49 || Loss:   0.583 || Accuracy:  84.15%
Shuffling


0it [00:00, ?it/s]

Training: Epoch 50 || Loss:   0.582 || Accuracy:  84.22%
Training: Epoch 51 || Loss:   0.581 || Accuracy:  84.20%
Training: Epoch 52 || Loss:   0.579 || Accuracy:  84.21%
Training: Epoch 53 || Loss:   0.578 || Accuracy:  84.21%
Training: Epoch 54 || Loss:   0.577 || Accuracy:  84.21%
Training: Epoch 55 || Loss:   0.577 || Accuracy:  84.22%
Training: Epoch 56 || Loss:   0.576 || Accuracy:  84.25%
Training: Epoch 57 || Loss:   0.576 || Accuracy:  84.27%
Training: Epoch 58 || Loss:   0.574 || Accuracy:  84.27%
Training: Epoch 59 || Loss:   0.573 || Accuracy:  84.29%


0it [00:00, ?it/s]

Found best model
Training: Epoch 60 || Loss:   0.572 || Accuracy:  84.30%
Training: Epoch 61 || Loss:   0.571 || Accuracy:  84.34%
Training: Epoch 62 || Loss:   0.571 || Accuracy:  84.36%
Training: Epoch 63 || Loss:   0.570 || Accuracy:  84.37%
Training: Epoch 64 || Loss:   0.569 || Accuracy:  84.37%
Training: Epoch 65 || Loss:   0.568 || Accuracy:  84.37%
Training: Epoch 66 || Loss:   0.568 || Accuracy:  84.39%
Training: Epoch 67 || Loss:   0.567 || Accuracy:  84.39%
Training: Epoch 68 || Loss:   0.567 || Accuracy:  84.41%
Training: Epoch 69 || Loss:   0.566 || Accuracy:  84.41%


0it [00:00, ?it/s]

Found best model
Training: Epoch 70 || Loss:   0.565 || Accuracy:  84.43%
Training: Epoch 71 || Loss:   0.564 || Accuracy:  84.44%
Training: Epoch 72 || Loss:   0.564 || Accuracy:  84.46%
Training: Epoch 73 || Loss:   0.563 || Accuracy:  84.48%
Training: Epoch 74 || Loss:   0.563 || Accuracy:  84.50%
Training: Epoch 75 || Loss:   0.562 || Accuracy:  84.50%
Training: Epoch 76 || Loss:   0.562 || Accuracy:  84.52%
Training: Epoch 77 || Loss:   0.561 || Accuracy:  84.53%
Training: Epoch 78 || Loss:   0.560 || Accuracy:  84.55%
Training: Epoch 79 || Loss:   0.560 || Accuracy:  84.56%


0it [00:00, ?it/s]

Found best model
Training: Epoch 80 || Loss:   0.559 || Accuracy:  84.58%
Training: Epoch 81 || Loss:   0.559 || Accuracy:  84.57%
Training: Epoch 82 || Loss:   0.558 || Accuracy:  84.58%
Training: Epoch 83 || Loss:   0.558 || Accuracy:  84.62%
Training: Epoch 84 || Loss:   0.557 || Accuracy:  84.62%
Training: Epoch 85 || Loss:   0.557 || Accuracy:  84.64%
Training: Epoch 86 || Loss:   0.557 || Accuracy:  84.67%
Training: Epoch 87 || Loss:   0.556 || Accuracy:  84.69%
Training: Epoch 88 || Loss:   0.556 || Accuracy:  84.69%
Training: Epoch 89 || Loss:   0.555 || Accuracy:  84.74%


0it [00:00, ?it/s]

Found best model
Training: Epoch 90 || Loss:   0.555 || Accuracy:  84.74%
Training: Epoch 91 || Loss:   0.555 || Accuracy:  84.75%
Training: Epoch 92 || Loss:   0.556 || Accuracy:  84.74%
Training: Epoch 93 || Loss:   0.555 || Accuracy:  84.74%
Training: Epoch 94 || Loss:   0.554 || Accuracy:  84.76%
Training: Epoch 95 || Loss:   0.553 || Accuracy:  84.76%
Training: Epoch 96 || Loss:   0.553 || Accuracy:  84.80%
Training: Epoch 97 || Loss:   0.552 || Accuracy:  84.79%
Training: Epoch 98 || Loss:   0.552 || Accuracy:  84.81%
Training: Epoch 99 || Loss:   0.552 || Accuracy:  84.79%
Shuffling


0it [00:00, ?it/s]

Training: Epoch 100 || Loss:   0.551 || Accuracy:  84.80%
Training: Epoch 101 || Loss:   0.551 || Accuracy:  84.82%
Training: Epoch 102 || Loss:   0.550 || Accuracy:  84.84%
Training: Epoch 103 || Loss:   0.550 || Accuracy:  84.85%
Training: Epoch 104 || Loss:   0.550 || Accuracy:  84.84%
Training: Epoch 105 || Loss:   0.550 || Accuracy:  84.85%
Training: Epoch 106 || Loss:   0.548 || Accuracy:  84.86%
Training: Epoch 107 || Loss:   0.548 || Accuracy:  84.86%
Training: Epoch 108 || Loss:   0.548 || Accuracy:  84.88%
Training: Epoch 109 || Loss:   0.547 || Accuracy:  84.92%


0it [00:00, ?it/s]

Found best model
Training: Epoch 110 || Loss:   0.547 || Accuracy:  84.93%
Training: Epoch 111 || Loss:   0.547 || Accuracy:  84.92%
Training: Epoch 112 || Loss:   0.546 || Accuracy:  84.92%
Training: Epoch 113 || Loss:   0.546 || Accuracy:  84.95%
Training: Epoch 114 || Loss:   0.546 || Accuracy:  84.93%
Training: Epoch 115 || Loss:   0.545 || Accuracy:  84.93%
Training: Epoch 116 || Loss:   0.545 || Accuracy:  84.92%
Training: Epoch 117 || Loss:   0.545 || Accuracy:  84.94%
Training: Epoch 118 || Loss:   0.545 || Accuracy:  84.98%
Training: Epoch 119 || Loss:   0.544 || Accuracy:  84.95%


0it [00:00, ?it/s]

Found best model
Training: Epoch 120 || Loss:   0.544 || Accuracy:  84.94%
Training: Epoch 121 || Loss:   0.544 || Accuracy:  84.94%
Training: Epoch 122 || Loss:   0.544 || Accuracy:  84.93%
Training: Epoch 123 || Loss:   0.545 || Accuracy:  84.92%
Training: Epoch 124 || Loss:   0.545 || Accuracy:  84.91%
Training: Epoch 125 || Loss:   0.545 || Accuracy:  84.91%
Training: Epoch 126 || Loss:   0.546 || Accuracy:  84.91%
Training: Epoch 127 || Loss:   0.547 || Accuracy:  84.87%
Training: Epoch 128 || Loss:   0.548 || Accuracy:  84.85%
Training: Epoch 129 || Loss:   0.548 || Accuracy:  84.85%


0it [00:00, ?it/s]

Training: Epoch 130 || Loss:   0.548 || Accuracy:  84.82%
Training: Epoch 131 || Loss:   0.548 || Accuracy:  84.82%
Training: Epoch 132 || Loss:   0.547 || Accuracy:  84.80%
Training: Epoch 133 || Loss:   0.548 || Accuracy:  84.78%
Training: Epoch 134 || Loss:   0.548 || Accuracy:  84.79%
Training: Epoch 135 || Loss:   0.548 || Accuracy:  84.79%
Training: Epoch 136 || Loss:   0.547 || Accuracy:  84.81%
Training: Epoch 137 || Loss:   0.548 || Accuracy:  84.83%
Training: Epoch 138 || Loss:   0.548 || Accuracy:  84.87%
Training: Epoch 139 || Loss:   0.547 || Accuracy:  84.89%


0it [00:00, ?it/s]

Training: Epoch 140 || Loss:   0.546 || Accuracy:  84.86%
Training: Epoch 141 || Loss:   0.546 || Accuracy:  84.87%
Training: Epoch 142 || Loss:   0.545 || Accuracy:  84.89%
Training: Epoch 143 || Loss:   0.545 || Accuracy:  84.93%
Training: Epoch 144 || Loss:   0.544 || Accuracy:  84.95%
Training: Epoch 145 || Loss:   0.544 || Accuracy:  84.94%
Training: Epoch 146 || Loss:   0.544 || Accuracy:  84.93%
Training: Epoch 147 || Loss:   0.544 || Accuracy:  84.93%
Training: Epoch 148 || Loss:   0.544 || Accuracy:  84.97%
Training: Epoch 149 || Loss:   0.544 || Accuracy:  84.96%
Shuffling


0it [00:00, ?it/s]

Training: Epoch 150 || Loss:   0.546 || Accuracy:  84.95%
Training: Epoch 151 || Loss:   0.546 || Accuracy:  84.97%
Training: Epoch 152 || Loss:   0.546 || Accuracy:  84.96%
Training: Epoch 153 || Loss:   0.546 || Accuracy:  84.98%
Training: Epoch 154 || Loss:   0.546 || Accuracy:  85.02%
Training: Epoch 155 || Loss:   0.546 || Accuracy:  85.03%
Training: Epoch 156 || Loss:   0.546 || Accuracy:  85.03%
Training: Epoch 157 || Loss:   0.546 || Accuracy:  85.05%
Training: Epoch 158 || Loss:   0.545 || Accuracy:  85.06%
Training: Epoch 159 || Loss:   0.546 || Accuracy:  85.07%


0it [00:00, ?it/s]

Training: Epoch 160 || Loss:   0.546 || Accuracy:  85.03%
Training: Epoch 161 || Loss:   0.546 || Accuracy:  85.03%
Training: Epoch 162 || Loss:   0.546 || Accuracy:  85.04%
Training: Epoch 163 || Loss:   0.547 || Accuracy:  85.05%
Training: Epoch 164 || Loss:   0.546 || Accuracy:  85.06%
Training: Epoch 165 || Loss:   0.546 || Accuracy:  85.06%
Training: Epoch 166 || Loss:   0.546 || Accuracy:  85.06%
Training: Epoch 167 || Loss:   0.546 || Accuracy:  85.05%
Training: Epoch 168 || Loss:   0.546 || Accuracy:  85.07%
Training: Epoch 169 || Loss:   0.547 || Accuracy:  85.08%


0it [00:00, ?it/s]

Training: Epoch 170 || Loss:   0.547 || Accuracy:  85.04%
Training: Epoch 171 || Loss:   0.547 || Accuracy:  85.05%
Training: Epoch 172 || Loss:   0.548 || Accuracy:  85.07%
Training: Epoch 173 || Loss:   0.547 || Accuracy:  85.06%
Training: Epoch 174 || Loss:   0.547 || Accuracy:  85.04%
Training: Epoch 175 || Loss:   0.548 || Accuracy:  85.06%
Training: Epoch 176 || Loss:   0.549 || Accuracy:  85.04%
Training: Epoch 177 || Loss:   0.549 || Accuracy:  85.02%
Training: Epoch 178 || Loss:   0.549 || Accuracy:  85.01%
Training: Epoch 179 || Loss:   0.550 || Accuracy:  84.99%


0it [00:00, ?it/s]

Training: Epoch 180 || Loss:   0.549 || Accuracy:  84.97%
Training: Epoch 181 || Loss:   0.550 || Accuracy:  84.98%
Training: Epoch 182 || Loss:   0.551 || Accuracy:  84.97%
Training: Epoch 183 || Loss:   0.551 || Accuracy:  84.96%
Training: Epoch 184 || Loss:   0.552 || Accuracy:  84.97%
Training: Epoch 185 || Loss:   0.552 || Accuracy:  84.96%
Training: Epoch 186 || Loss:   0.551 || Accuracy:  84.98%
Training: Epoch 187 || Loss:   0.553 || Accuracy:  84.93%
Training: Epoch 188 || Loss:   0.553 || Accuracy:  84.93%
Training: Epoch 189 || Loss:   0.554 || Accuracy:  84.90%


0it [00:00, ?it/s]

Training: Epoch 190 || Loss:   0.555 || Accuracy:  84.93%
Training: Epoch 191 || Loss:   0.555 || Accuracy:  84.90%
Training: Epoch 192 || Loss:   0.554 || Accuracy:  84.88%
Training: Epoch 193 || Loss:   0.556 || Accuracy:  84.85%
Training: Epoch 194 || Loss:   0.556 || Accuracy:  84.86%
Training: Epoch 195 || Loss:   0.556 || Accuracy:  84.84%
Training: Epoch 196 || Loss:   0.557 || Accuracy:  84.85%
Training: Epoch 197 || Loss:   0.557 || Accuracy:  84.81%
Training: Epoch 198 || Loss:   0.558 || Accuracy:  84.81%
Training: Epoch 199 || Loss:   0.558 || Accuracy:  84.80%
Shuffling


0it [00:00, ?it/s]

Training: Epoch 200 || Loss:   0.559 || Accuracy:  84.80%
Training: Epoch 201 || Loss:   0.558 || Accuracy:  84.84%
Training: Epoch 202 || Loss:   0.558 || Accuracy:  84.85%
Training: Epoch 203 || Loss:   0.557 || Accuracy:  84.82%
Training: Epoch 204 || Loss:   0.558 || Accuracy:  84.81%
Training: Epoch 205 || Loss:   0.558 || Accuracy:  84.77%
Training: Epoch 206 || Loss:   0.558 || Accuracy:  84.77%
Training: Epoch 207 || Loss:   0.559 || Accuracy:  84.77%
Training: Epoch 208 || Loss:   0.557 || Accuracy:  84.78%
Training: Epoch 209 || Loss:   0.558 || Accuracy:  84.77%


0it [00:00, ?it/s]

Training: Epoch 210 || Loss:   0.558 || Accuracy:  84.76%
Training: Epoch 211 || Loss:   0.559 || Accuracy:  84.77%
Training: Epoch 212 || Loss:   0.560 || Accuracy:  84.74%
Training: Epoch 213 || Loss:   0.559 || Accuracy:  84.76%
Training: Epoch 214 || Loss:   0.559 || Accuracy:  84.75%
Training: Epoch 215 || Loss:   0.560 || Accuracy:  84.74%
Training: Epoch 216 || Loss:   0.560 || Accuracy:  84.73%
Training: Epoch 217 || Loss:   0.561 || Accuracy:  84.74%
Training: Epoch 218 || Loss:   0.560 || Accuracy:  84.72%
Training: Epoch 219 || Loss:   0.560 || Accuracy:  84.73%


0it [00:00, ?it/s]

Training: Epoch 220 || Loss:   0.561 || Accuracy:  84.72%
Training: Epoch 221 || Loss:   0.560 || Accuracy:  84.71%
Training: Epoch 222 || Loss:   0.561 || Accuracy:  84.70%
Training: Epoch 223 || Loss:   0.561 || Accuracy:  84.72%
Training: Epoch 224 || Loss:   0.561 || Accuracy:  84.70%
Training: Epoch 225 || Loss:   0.562 || Accuracy:  84.68%
Training: Epoch 226 || Loss:   0.562 || Accuracy:  84.68%
Training: Epoch 227 || Loss:   0.562 || Accuracy:  84.67%
Training: Epoch 228 || Loss:   0.562 || Accuracy:  84.67%
Training: Epoch 229 || Loss:   0.562 || Accuracy:  84.64%


0it [00:00, ?it/s]

Training: Epoch 230 || Loss:   0.562 || Accuracy:  84.63%
Training: Epoch 231 || Loss:   0.563 || Accuracy:  84.64%
Training: Epoch 232 || Loss:   0.562 || Accuracy:  84.65%
Training: Epoch 233 || Loss:   0.562 || Accuracy:  84.62%
Training: Epoch 234 || Loss:   0.563 || Accuracy:  84.62%
Training: Epoch 235 || Loss:   0.564 || Accuracy:  84.60%
Training: Epoch 236 || Loss:   0.563 || Accuracy:  84.61%
Training: Epoch 237 || Loss:   0.563 || Accuracy:  84.60%
Training: Epoch 238 || Loss:   0.564 || Accuracy:  84.60%
Training: Epoch 239 || Loss:   0.564 || Accuracy:  84.59%


0it [00:00, ?it/s]

Training: Epoch 240 || Loss:   0.564 || Accuracy:  84.58%
Training: Epoch 241 || Loss:   0.565 || Accuracy:  84.58%
Training: Epoch 242 || Loss:   0.564 || Accuracy:  84.55%
Training: Epoch 243 || Loss:   0.566 || Accuracy:  84.55%
Training: Epoch 244 || Loss:   0.565 || Accuracy:  84.51%
Training: Epoch 245 || Loss:   0.566 || Accuracy:  84.52%
Training: Epoch 246 || Loss:   0.566 || Accuracy:  84.53%
Training: Epoch 247 || Loss:   0.566 || Accuracy:  84.51%
Training: Epoch 248 || Loss:   0.566 || Accuracy:  84.51%
Training: Epoch 249 || Loss:   0.567 || Accuracy:  84.48%
Shuffling


0it [00:00, ?it/s]

Training: Epoch 250 || Loss:   0.561 || Accuracy:  84.52%
Training: Epoch 251 || Loss:   0.559 || Accuracy:  84.59%
Training: Epoch 252 || Loss:   0.558 || Accuracy:  84.59%
Training: Epoch 253 || Loss:   0.558 || Accuracy:  84.61%
Training: Epoch 254 || Loss:   0.558 || Accuracy:  84.61%
Training: Epoch 255 || Loss:   0.557 || Accuracy:  84.58%
Training: Epoch 256 || Loss:   0.558 || Accuracy:  84.60%
Training: Epoch 257 || Loss:   0.558 || Accuracy:  84.62%
Training: Epoch 258 || Loss:   0.559 || Accuracy:  84.58%
Training: Epoch 259 || Loss:   0.558 || Accuracy:  84.57%


0it [00:00, ?it/s]

Training: Epoch 260 || Loss:   0.559 || Accuracy:  84.58%
Training: Epoch 261 || Loss:   0.559 || Accuracy:  84.57%
Training: Epoch 262 || Loss:   0.559 || Accuracy:  84.57%
Training: Epoch 263 || Loss:   0.559 || Accuracy:  84.56%
Training: Epoch 264 || Loss:   0.559 || Accuracy:  84.57%
Training: Epoch 265 || Loss:   0.559 || Accuracy:  84.59%
Training: Epoch 266 || Loss:   0.559 || Accuracy:  84.58%
Training: Epoch 267 || Loss:   0.559 || Accuracy:  84.56%
Training: Epoch 268 || Loss:   0.560 || Accuracy:  84.55%
Training: Epoch 269 || Loss:   0.561 || Accuracy:  84.55%


0it [00:00, ?it/s]

Training: Epoch 270 || Loss:   0.562 || Accuracy:  84.55%
Training: Epoch 271 || Loss:   0.562 || Accuracy:  84.55%
Training: Epoch 272 || Loss:   0.562 || Accuracy:  84.53%
Training: Epoch 273 || Loss:   0.562 || Accuracy:  84.52%
Training: Epoch 274 || Loss:   0.562 || Accuracy:  84.52%
Training: Epoch 275 || Loss:   0.562 || Accuracy:  84.52%
Training: Epoch 276 || Loss:   0.562 || Accuracy:  84.51%
Training: Epoch 277 || Loss:   0.562 || Accuracy:  84.50%
Training: Epoch 278 || Loss:   0.563 || Accuracy:  84.49%
Training: Epoch 279 || Loss:   0.563 || Accuracy:  84.47%


0it [00:00, ?it/s]

Training: Epoch 280 || Loss:   0.563 || Accuracy:  84.46%
Training: Epoch 281 || Loss:   0.564 || Accuracy:  84.47%
Training: Epoch 282 || Loss:   0.564 || Accuracy:  84.45%
Training: Epoch 283 || Loss:   0.565 || Accuracy:  84.43%
Training: Epoch 284 || Loss:   0.564 || Accuracy:  84.42%
Training: Epoch 285 || Loss:   0.564 || Accuracy:  84.43%
Training: Epoch 286 || Loss:   0.564 || Accuracy:  84.45%
Training: Epoch 287 || Loss:   0.566 || Accuracy:  84.44%
Training: Epoch 288 || Loss:   0.566 || Accuracy:  84.43%
Training: Epoch 289 || Loss:   0.566 || Accuracy:  84.44%


0it [00:00, ?it/s]

Training: Epoch 290 || Loss:   0.566 || Accuracy:  84.41%
Training: Epoch 291 || Loss:   0.566 || Accuracy:  84.40%
Training: Epoch 292 || Loss:   0.567 || Accuracy:  84.42%
Training: Epoch 293 || Loss:   0.567 || Accuracy:  84.40%
Training: Epoch 294 || Loss:   0.566 || Accuracy:  84.40%
Training: Epoch 295 || Loss:   0.568 || Accuracy:  84.40%
Training: Epoch 296 || Loss:   0.568 || Accuracy:  84.38%
Training: Epoch 297 || Loss:   0.568 || Accuracy:  84.38%
Training: Epoch 298 || Loss:   0.567 || Accuracy:  84.38%
Training: Epoch 299 || Loss:   0.568 || Accuracy:  84.39%
Shuffling


0it [00:00, ?it/s]

Training: Epoch 300 || Loss:   0.573 || Accuracy:  84.37%
Training: Epoch 301 || Loss:   0.572 || Accuracy:  84.37%
Training: Epoch 302 || Loss:   0.571 || Accuracy:  84.39%
Training: Epoch 303 || Loss:   0.572 || Accuracy:  84.38%
Training: Epoch 304 || Loss:   0.573 || Accuracy:  84.39%
Training: Epoch 305 || Loss:   0.573 || Accuracy:  84.39%
Training: Epoch 306 || Loss:   0.573 || Accuracy:  84.34%
Training: Epoch 307 || Loss:   0.574 || Accuracy:  84.33%
Training: Epoch 308 || Loss:   0.574 || Accuracy:  84.32%
Training: Epoch 309 || Loss:   0.574 || Accuracy:  84.29%


0it [00:00, ?it/s]

Training: Epoch 310 || Loss:   0.575 || Accuracy:  84.30%
Training: Epoch 311 || Loss:   0.575 || Accuracy:  84.25%
Training: Epoch 312 || Loss:   0.575 || Accuracy:  84.26%
Training: Epoch 313 || Loss:   0.576 || Accuracy:  84.25%
Training: Epoch 314 || Loss:   0.575 || Accuracy:  84.23%
Training: Epoch 315 || Loss:   0.576 || Accuracy:  84.24%
Training: Epoch 316 || Loss:   0.577 || Accuracy:  84.21%
Training: Epoch 317 || Loss:   0.577 || Accuracy:  84.25%
Training: Epoch 318 || Loss:   0.578 || Accuracy:  84.21%
Training: Epoch 319 || Loss:   0.577 || Accuracy:  84.24%


0it [00:00, ?it/s]

Training: Epoch 320 || Loss:   0.578 || Accuracy:  84.22%
Training: Epoch 321 || Loss:   0.578 || Accuracy:  84.22%
Training: Epoch 322 || Loss:   0.578 || Accuracy:  84.19%
Training: Epoch 323 || Loss:   0.578 || Accuracy:  84.22%
Training: Epoch 324 || Loss:   0.579 || Accuracy:  84.20%
Training: Epoch 325 || Loss:   0.580 || Accuracy:  84.19%
Training: Epoch 326 || Loss:   0.579 || Accuracy:  84.19%
Training: Epoch 327 || Loss:   0.579 || Accuracy:  84.18%
Training: Epoch 328 || Loss:   0.580 || Accuracy:  84.17%
Training: Epoch 329 || Loss:   0.580 || Accuracy:  84.16%


0it [00:00, ?it/s]

Training: Epoch 330 || Loss:   0.579 || Accuracy:  84.16%
Training: Epoch 331 || Loss:   0.580 || Accuracy:  84.17%
Training: Epoch 332 || Loss:   0.580 || Accuracy:  84.15%
Training: Epoch 333 || Loss:   0.580 || Accuracy:  84.14%
Training: Epoch 334 || Loss:   0.580 || Accuracy:  84.15%
Training: Epoch 335 || Loss:   0.580 || Accuracy:  84.13%
Training: Epoch 336 || Loss:   0.581 || Accuracy:  84.14%
Training: Epoch 337 || Loss:   0.580 || Accuracy:  84.13%
Training: Epoch 338 || Loss:   0.580 || Accuracy:  84.13%
Training: Epoch 339 || Loss:   0.580 || Accuracy:  84.11%


0it [00:00, ?it/s]

Training: Epoch 340 || Loss:   0.580 || Accuracy:  84.10%
Training: Epoch 341 || Loss:   0.581 || Accuracy:  84.12%
Training: Epoch 342 || Loss:   0.581 || Accuracy:  84.10%
Training: Epoch 343 || Loss:   0.582 || Accuracy:  84.11%
Training: Epoch 344 || Loss:   0.581 || Accuracy:  84.10%
Training: Epoch 345 || Loss:   0.582 || Accuracy:  84.10%
Training: Epoch 346 || Loss:   0.583 || Accuracy:  84.10%
Training: Epoch 347 || Loss:   0.582 || Accuracy:  84.10%
Training: Epoch 348 || Loss:   0.582 || Accuracy:  84.09%
Training: Epoch 349 || Loss:   0.583 || Accuracy:  84.10%
Shuffling


0it [00:00, ?it/s]

Training: Epoch 350 || Loss:   0.580 || Accuracy:  84.10%
Training: Epoch 351 || Loss:   0.579 || Accuracy:  84.16%
Training: Epoch 352 || Loss:   0.578 || Accuracy:  84.17%
Training: Epoch 353 || Loss:   0.579 || Accuracy:  84.17%
Training: Epoch 354 || Loss:   0.579 || Accuracy:  84.17%
Training: Epoch 355 || Loss:   0.580 || Accuracy:  84.15%
Training: Epoch 356 || Loss:   0.581 || Accuracy:  84.14%
Training: Epoch 357 || Loss:   0.580 || Accuracy:  84.15%
Training: Epoch 358 || Loss:   0.581 || Accuracy:  84.13%
Training: Epoch 359 || Loss:   0.581 || Accuracy:  84.11%


0it [00:00, ?it/s]

Training: Epoch 360 || Loss:   0.581 || Accuracy:  84.08%
Training: Epoch 361 || Loss:   0.582 || Accuracy:  84.10%
Training: Epoch 362 || Loss:   0.582 || Accuracy:  84.11%
Training: Epoch 363 || Loss:   0.582 || Accuracy:  84.10%
Training: Epoch 364 || Loss:   0.582 || Accuracy:  84.07%
Training: Epoch 365 || Loss:   0.583 || Accuracy:  84.08%
Training: Epoch 366 || Loss:   0.583 || Accuracy:  84.09%
Training: Epoch 367 || Loss:   0.584 || Accuracy:  84.06%
Training: Epoch 368 || Loss:   0.585 || Accuracy:  84.05%
Training: Epoch 369 || Loss:   0.585 || Accuracy:  84.06%


0it [00:00, ?it/s]

Training: Epoch 370 || Loss:   0.585 || Accuracy:  84.03%
Training: Epoch 371 || Loss:   0.585 || Accuracy:  84.05%
Training: Epoch 372 || Loss:   0.586 || Accuracy:  84.04%
Training: Epoch 373 || Loss:   0.586 || Accuracy:  84.02%
Training: Epoch 374 || Loss:   0.586 || Accuracy:  84.03%
Training: Epoch 375 || Loss:   0.586 || Accuracy:  84.03%
Training: Epoch 376 || Loss:   0.587 || Accuracy:  84.04%
Training: Epoch 377 || Loss:   0.586 || Accuracy:  84.04%
Training: Epoch 378 || Loss:   0.586 || Accuracy:  84.04%
Training: Epoch 379 || Loss:   0.586 || Accuracy:  84.05%


0it [00:00, ?it/s]

Training: Epoch 380 || Loss:   0.586 || Accuracy:  84.04%
Training: Epoch 381 || Loss:   0.586 || Accuracy:  84.04%
Training: Epoch 382 || Loss:   0.587 || Accuracy:  84.05%
Training: Epoch 383 || Loss:   0.587 || Accuracy:  84.05%
Training: Epoch 384 || Loss:   0.586 || Accuracy:  84.03%
Training: Epoch 385 || Loss:   0.587 || Accuracy:  84.04%
Training: Epoch 386 || Loss:   0.588 || Accuracy:  84.04%
Training: Epoch 387 || Loss:   0.587 || Accuracy:  84.03%
Training: Epoch 388 || Loss:   0.587 || Accuracy:  84.04%
Training: Epoch 389 || Loss:   0.588 || Accuracy:  84.04%


0it [00:00, ?it/s]

Training: Epoch 390 || Loss:   0.587 || Accuracy:  84.03%
Training: Epoch 391 || Loss:   0.589 || Accuracy:  84.01%
Training: Epoch 392 || Loss:   0.589 || Accuracy:  84.01%
Training: Epoch 393 || Loss:   0.589 || Accuracy:  84.02%
Training: Epoch 394 || Loss:   0.589 || Accuracy:  84.01%
Training: Epoch 395 || Loss:   0.590 || Accuracy:  84.02%
Training: Epoch 396 || Loss:   0.590 || Accuracy:  84.04%
Training: Epoch 397 || Loss:   0.590 || Accuracy:  84.02%
Training: Epoch 398 || Loss:   0.590 || Accuracy:  84.04%
Training: Epoch 399 || Loss:   0.590 || Accuracy:  84.02%
Shuffling


0it [00:00, ?it/s]

Training: Epoch 400 || Loss:   0.587 || Accuracy:  84.06%
Training: Epoch 401 || Loss:   0.584 || Accuracy:  84.15%
Training: Epoch 402 || Loss:   0.583 || Accuracy:  84.14%
Training: Epoch 403 || Loss:   0.582 || Accuracy:  84.13%
Training: Epoch 404 || Loss:   0.581 || Accuracy:  84.14%
Training: Epoch 405 || Loss:   0.580 || Accuracy:  84.12%
Training: Epoch 406 || Loss:   0.580 || Accuracy:  84.11%
Training: Epoch 407 || Loss:   0.579 || Accuracy:  84.12%
Training: Epoch 408 || Loss:   0.579 || Accuracy:  84.11%
Training: Epoch 409 || Loss:   0.580 || Accuracy:  84.10%


0it [00:00, ?it/s]

Training: Epoch 410 || Loss:   0.580 || Accuracy:  84.08%
Training: Epoch 411 || Loss:   0.580 || Accuracy:  84.07%
Training: Epoch 412 || Loss:   0.580 || Accuracy:  84.07%
Training: Epoch 413 || Loss:   0.580 || Accuracy:  84.07%
Training: Epoch 414 || Loss:   0.581 || Accuracy:  84.06%
Training: Epoch 415 || Loss:   0.581 || Accuracy:  84.06%
Training: Epoch 416 || Loss:   0.582 || Accuracy:  84.05%
Training: Epoch 417 || Loss:   0.582 || Accuracy:  84.05%
Training: Epoch 418 || Loss:   0.582 || Accuracy:  84.05%
Training: Epoch 419 || Loss:   0.583 || Accuracy:  84.04%


0it [00:00, ?it/s]

Training: Epoch 420 || Loss:   0.583 || Accuracy:  84.04%
Training: Epoch 421 || Loss:   0.583 || Accuracy:  84.06%
Training: Epoch 422 || Loss:   0.582 || Accuracy:  84.06%
Training: Epoch 423 || Loss:   0.582 || Accuracy:  84.04%
Training: Epoch 424 || Loss:   0.582 || Accuracy:  84.05%
Training: Epoch 425 || Loss:   0.582 || Accuracy:  84.07%
Training: Epoch 426 || Loss:   0.583 || Accuracy:  84.05%
Training: Epoch 427 || Loss:   0.583 || Accuracy:  84.06%
Training: Epoch 428 || Loss:   0.583 || Accuracy:  84.07%
Training: Epoch 429 || Loss:   0.583 || Accuracy:  84.05%


0it [00:00, ?it/s]

Training: Epoch 430 || Loss:   0.582 || Accuracy:  84.07%
Training: Epoch 431 || Loss:   0.583 || Accuracy:  84.08%
Training: Epoch 432 || Loss:   0.583 || Accuracy:  84.06%
Training: Epoch 433 || Loss:   0.582 || Accuracy:  84.07%
Training: Epoch 434 || Loss:   0.582 || Accuracy:  84.10%
Training: Epoch 435 || Loss:   0.582 || Accuracy:  84.09%
Training: Epoch 436 || Loss:   0.582 || Accuracy:  84.09%
Training: Epoch 437 || Loss:   0.582 || Accuracy:  84.09%
Training: Epoch 438 || Loss:   0.582 || Accuracy:  84.08%
Training: Epoch 439 || Loss:   0.582 || Accuracy:  84.11%


0it [00:00, ?it/s]

Training: Epoch 440 || Loss:   0.581 || Accuracy:  84.11%
Training: Epoch 441 || Loss:   0.581 || Accuracy:  84.12%
Training: Epoch 442 || Loss:   0.581 || Accuracy:  84.13%
Training: Epoch 443 || Loss:   0.581 || Accuracy:  84.13%
Training: Epoch 444 || Loss:   0.581 || Accuracy:  84.15%
Training: Epoch 445 || Loss:   0.581 || Accuracy:  84.16%
Training: Epoch 446 || Loss:   0.579 || Accuracy:  84.17%
Training: Epoch 447 || Loss:   0.580 || Accuracy:  84.16%
Training: Epoch 448 || Loss:   0.579 || Accuracy:  84.14%
Training: Epoch 449 || Loss:   0.579 || Accuracy:  84.17%
Shuffling


0it [00:00, ?it/s]

Training: Epoch 450 || Loss:   0.578 || Accuracy:  84.19%
Training: Epoch 451 || Loss:   0.577 || Accuracy:  84.26%
Training: Epoch 452 || Loss:   0.577 || Accuracy:  84.23%
Training: Epoch 453 || Loss:   0.576 || Accuracy:  84.23%
Training: Epoch 454 || Loss:   0.576 || Accuracy:  84.23%
Training: Epoch 455 || Loss:   0.577 || Accuracy:  84.22%
Training: Epoch 456 || Loss:   0.578 || Accuracy:  84.18%
Training: Epoch 457 || Loss:   0.577 || Accuracy:  84.20%
Training: Epoch 458 || Loss:   0.577 || Accuracy:  84.18%
Training: Epoch 459 || Loss:   0.577 || Accuracy:  84.14%


0it [00:00, ?it/s]

Training: Epoch 460 || Loss:   0.577 || Accuracy:  84.14%
Training: Epoch 461 || Loss:   0.577 || Accuracy:  84.12%
Training: Epoch 462 || Loss:   0.577 || Accuracy:  84.10%
Training: Epoch 463 || Loss:   0.576 || Accuracy:  84.07%
Training: Epoch 464 || Loss:   0.577 || Accuracy:  84.09%
Training: Epoch 465 || Loss:   0.576 || Accuracy:  84.08%
Training: Epoch 466 || Loss:   0.577 || Accuracy:  84.07%
Training: Epoch 467 || Loss:   0.577 || Accuracy:  84.10%
Training: Epoch 468 || Loss:   0.577 || Accuracy:  84.13%
Training: Epoch 469 || Loss:   0.576 || Accuracy:  84.13%


0it [00:00, ?it/s]

Training: Epoch 470 || Loss:   0.576 || Accuracy:  84.12%
Training: Epoch 471 || Loss:   0.576 || Accuracy:  84.13%
Training: Epoch 472 || Loss:   0.576 || Accuracy:  84.13%
Training: Epoch 473 || Loss:   0.576 || Accuracy:  84.14%
Training: Epoch 474 || Loss:   0.576 || Accuracy:  84.14%
Training: Epoch 475 || Loss:   0.575 || Accuracy:  84.13%
Training: Epoch 476 || Loss:   0.575 || Accuracy:  84.12%
Training: Epoch 477 || Loss:   0.574 || Accuracy:  84.13%
Training: Epoch 478 || Loss:   0.574 || Accuracy:  84.14%
Training: Epoch 479 || Loss:   0.574 || Accuracy:  84.13%


0it [00:00, ?it/s]

Training: Epoch 480 || Loss:   0.574 || Accuracy:  84.14%
Training: Epoch 481 || Loss:   0.574 || Accuracy:  84.15%
Training: Epoch 482 || Loss:   0.574 || Accuracy:  84.17%
Training: Epoch 483 || Loss:   0.573 || Accuracy:  84.16%
Training: Epoch 484 || Loss:   0.573 || Accuracy:  84.16%
Training: Epoch 485 || Loss:   0.573 || Accuracy:  84.16%
Training: Epoch 486 || Loss:   0.573 || Accuracy:  84.17%
Training: Epoch 487 || Loss:   0.573 || Accuracy:  84.18%
Training: Epoch 488 || Loss:   0.573 || Accuracy:  84.19%
Training: Epoch 489 || Loss:   0.574 || Accuracy:  84.18%


0it [00:00, ?it/s]

Training: Epoch 490 || Loss:   0.573 || Accuracy:  84.18%
Training: Epoch 491 || Loss:   0.572 || Accuracy:  84.17%
Training: Epoch 492 || Loss:   0.573 || Accuracy:  84.18%
Training: Epoch 493 || Loss:   0.573 || Accuracy:  84.18%
Training: Epoch 494 || Loss:   0.573 || Accuracy:  84.17%
Training: Epoch 495 || Loss:   0.573 || Accuracy:  84.17%
Training: Epoch 496 || Loss:   0.573 || Accuracy:  84.18%
Training: Epoch 497 || Loss:   0.573 || Accuracy:  84.19%
Training: Epoch 498 || Loss:   0.572 || Accuracy:  84.21%
Training: Epoch 499 || Loss:   0.573 || Accuracy:  84.19%
Shuffling


0it [00:00, ?it/s]

Training: Epoch 500 || Loss:   0.575 || Accuracy:  84.23%


In [30]:
embeddings = torch.from_numpy(best_model.linear.weight.detach().cpu().numpy()).to(torch.float16)

In [31]:
zeroshot_weights = torch.from_numpy(np.array(embeddings).T).to(torch.float16)

def accuracy(output, target, topk=(1,)):
    pred = output.topk(max(topk), 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [
        float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy())
        for k in topk
    ]

# lazy load
if clip_model == None:
    clip_model, clip_preprocess = clip.load(clip_model_name, device)

with torch.no_grad():
    top1, top5, n = 0.0, 0.0, 0.0
    for i, (images, target) in enumerate(tqdm(test_loader)):
        images = images.cuda()
        target = target.cuda()

        # predict

        image_features = clip_model.encode_image(images)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        logits = 100.0 * image_features.to(device) @ zeroshot_weights.to(device)

        # measure accuracy
        acc1, _ = accuracy(logits, target, topk=(1, 5))
        top1 += acc1
        n += images.size(0)

top1 = (top1 / n) * 100

print("acc:", top1)

  0%|          | 0/303 [00:00<?, ?it/s]

acc: 82.69966996699671


## Deterministic - (using sci-kit)

In [32]:
global clip_model, clip_preprocess
from sklearn.linear_model import LogisticRegression

len_classes = len(classes)

train_features, train_labels = get_clip_features(train_loader)
test_features, test_labels = get_clip_features(test_loader)

train_features = train_features / train_features.norm(dim=-1, keepdim=True)
test_features = test_features / test_features.norm(dim=-1, keepdim=True)

classifier = LogisticRegression(C=1, max_iter=1000, n_jobs=4,verbose=1)
classifier.fit(train_features.cpu().numpy(), train_labels.cpu().numpy())
predictions = classifier.predict(test_features.cpu().numpy())
accuracy = np.mean((test_labels.cpu().numpy() == predictions).astype(np.float)) * 100.0

print(accuracy)

  0%|          | 0/1515 [00:00<?, ?it/s]

  0%|          | 0/303 [00:00<?, ?it/s]

[Parallel(n_jobs=4)]: Using backend LokyBackend with 4 concurrent workers.
[Parallel(n_jobs=4)]: Done   1 out of   1 | elapsed:  2.1min finished


83.86138613861385
