In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.models import vit_b_16, vit_b_32#, vit_l_16, vit_l_32, vit_h_14

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
""" 
Extracts the features of the images in the MNIST dataset using the trained vision transformer.
This way the database of MNIST images becomes a database of MNIST trasnformer feature representations.
Python notebook is used to test the code before running it fully as a pure Python script. 

Steps:
    1. Load the trained transformer model
    2. Strip off the final classification layer
    3. Go through images and forward propagate them
    4. For each image save the final transformer layer representation of that image.
"""

run_on_server = True

batch_size = 8

models = {"vit_b_16": vit_b_16, "vit_b_32": vit_b_32}
weights = {"vit_b_16": torchvision.models.ViT_B_16_Weights.IMAGENET1K_V1,\
           "vit_b_32": torchvision.models.ViT_B_32_Weights.IMAGENET1K_V1}

train_features_file = "training_cifar10.pkl"
test_features_file = "test_cifar10.pkl"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Don't change
if run_on_server:
    features_folder = "/usr/itetnas04/data-scratch-01/ddordevic/data/cluster_scripts/vit_copy/CIFAR-10/"
else:
    features_folder = "C:/Users/danil/Desktop/Master thesis/Code/msc-thesis/CIFAR-10/"

train_features_path = features_folder + train_features_file
test_features_path = features_folder + test_features_file

# Dataset preparation
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])



train_set = torchvision.datasets.CIFAR10(root='../datasets/CIFAR-10', train=True,
                                        download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

test_set = torchvision.datasets.CIFAR10(root='../datasets/CIFAR-10', train=False,
                                       download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

train_len = len(train_loader) # number of batches
test_len = len(test_loader)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Check that folders exist 
if not os.path.exists(features_folder):
    raise Exception("Extracted features folder does not exist.")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")

# TRAINING
if not os.path.exists(train_features_path):
    training_df = pd.DataFrame(columns=['image','label'])

    print('Initial train database save file in progress.')

    for i, batch in enumerate(train_loader):
        images, labels = batch
        new_rows = pd.DataFrame({'image': tuple(images), 'label':  tuple(labels)})
        training_df = pd.concat([training_df, new_rows], ignore_index=True)
        if i%99 == 0:
            print(f"{i}/{len(train_loader)}")

    training_df.to_pickle(train_features_path)

    print('Initial train database save file done.')
else:
    print(f'Reading existing {train_features_file}')
    training_df = pd.read_pickle(train_features_path)

# TEST
if not os.path.exists(test_features_path):
    print('Initial test database save file in progress.')

    test_df = pd.DataFrame(columns=['image','label'])
    for i, batch in enumerate(test_loader):
        images, labels = batch
        new_rows = pd.DataFrame({'image': tuple(images), 'label':  tuple(labels)})
        test_df = pd.concat([test_df, new_rows], ignore_index=True)
        if i%99 == 0:
            print(f"{i}/{len(test_loader)}")
    test_df.to_pickle(test_features_path)
    
    print('Initial test database save file done.')
else:
    print(f'Reading existing {test_features_file}')
    test_df = pd.read_pickle(test_features_path)

# Iterate over the models to be used (key is model)
num_train_imgs = len(training_df)
num_test_imgs = len(test_df)

Using device:  cpu 
Initial train database save file in progress.
0/6250
99/6250
198/6250
297/6250
396/6250
495/6250
594/6250
693/6250
792/6250
891/6250
990/6250
1089/6250
1188/6250
1287/6250
1386/6250
1485/6250
1584/6250
1683/6250
1782/6250
1881/6250
1980/6250
2079/6250
2178/6250
2277/6250
2376/6250
2475/6250
2574/6250
2673/6250
2772/6250
2871/6250
2970/6250
3069/6250
3168/6250
3267/6250
3366/6250
3465/6250
3564/6250
3663/6250
3762/6250
3861/6250
3960/6250
4059/6250
4158/6250
4257/6250
4356/6250
4455/6250
4554/6250
4653/6250
4752/6250
4851/6250
4950/6250
5049/6250
5148/6250
5247/6250
5346/6250
5445/6250
5544/6250
5643/6250
5742/6250
5841/6250
5940/6250
6039/6250
6138/6250
6237/6250
Initial train database save file done.
Initial test database save file in progress.
0/1250
99/1250
198/1250
297/1250
396/1250
495/1250
594/1250
693/1250
792/1250
891/1250
990/1250
1089/1250
1188/1250
Initial test database save file done.


In [4]:
for key in models.keys():
    print(f'Key = {key}')

    model_weights = weights[key]
    model_transform = model_weights.transforms()
    model = models[key](weights = model_weights)

    break

Key = vit_b_16


In [28]:
num_features = model.heads[0].in_features
model_features = torch.zeros(num_train_imgs, num_features)

In [30]:
model_features.shape

torch.Size([50000, 768])

In [32]:
training_df["vit_b_32"] = tuple(model_features)

In [33]:
train_features_path

'/usr/itetnas04/data-scratch-01/ddordevic/data/cluster_scripts/vit_copy/CIFAR-10/training_cifar10.pkl'

In [34]:
training_df = training_df.to_pickle('/usr/itetnas04/data-scratch-01/ddordevic/data/cluster_scripts/vit_copy/CIFAR-10/training_cifar10_proba.pkl')

KeyboardInterrupt: 

: 

In [5]:
for key in models.keys():
    print(f'Key = {key}')

    model_weights = weights[key]
    model_transform = model_weights.transforms()
    model = models[key](weights = model_weights)

    num_features = model.heads[0].in_features

    model.heads = torch.nn.Identity()
    model.eval()


    # Main feature extraction loop
    with torch.no_grad():
        # Feature extraction loop: Training set
        print(f'TRAINING ({key}): ')

        model_features = torch.zeros(num_train_imgs, num_features, requires_grad=False)
        
        for i in range(0, len(training_df), batch_size):
            images = torch.stack(tuple(training_df.iloc[i:i+batch_size]['image']))
            images = images.to(device)
            model_features[i:i+batch_size] = model(model_transform(images))
            print(i)
            break
            if i % 99 == 0:
                print(f"{i+1}/{num_train_imgs}")

        training_df[key] = tuple(model_features)

        # Feature extraction loop: Test set
        print(f'TEST ({key}): ')
        model_features = torch.zeros(num_test_imgs, num_features)
        for i in range(0, len(test_df), batch_size):
            images = torch.stack(tuple(test_df.iloc[i:i+batch_size]['image']))
            images = images.to(device)
            model_features[i:i+batch_size] = model(model_transform(images))
            print(i)
            break
            if i % 99 == 0:
                print(f"{i+1}/{num_test_imgs}")

        test_df[key] = tuple(model_features)

Key = vit_b_16
TRAINING (vit_b_16): 
0
TEST (vit_b_16): 
0
Key = vit_b_32
TRAINING (vit_b_32): 
0
TEST (vit_b_32): 
0


In [10]:
a = np.array([[1,2,3],[7,8,9]])
b = torch.tensor([4,5,6])

a[1] = b


In [14]:
a

array([[1, 2, 3],
       [4, 5, 6]])

In [13]:
b[0] = 10