In [None]:
!pip install -q git+https://github.com/huggingface/transformers.git

In [2]:
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow
import torch
from transformers import ViTForImageClassification

In [3]:
def features_for_img(img_path, model, size):
    """Extract first-encoder-layer features and logits for one image file.

    The image is read with matplotlib, cut down to its first three channels,
    resized to ``size`` x ``size`` with TensorFlow and passed to ``model`` as
    a single-image, channels-first batch.

    Args:
        img_path: Path of the image file to read.
        model: A ``torch.nn.Module`` exposing ``model.vit.embeddings`` and
            ``model.vit.encoder.layer`` and whose forward pass returns an
            object with a ``.logits`` attribute.
        size: Side length (in pixels) the image is resized to.

    Returns:
        A ``(int_out, logits)`` pair of NumPy arrays: the output of the first
        encoder layer applied to the embeddings, and the model's logits.
    """
    # Keep at most the first three channels (drops a fourth channel if present).
    img = plt.imread(img_path)[:, :, :3]
    img = tensorflow.image.resize(img, [size, size]).numpy()
    # (H, W, C) -> (1, C, H, W): add a batch axis, then move channels first.
    img = np.expand_dims(img, 0)
    img = np.transpose(img, (0, 3, 1, 2))
    img = torch.from_numpy(img)
    # NOTE(review): pixel values are fed as loaded, without any mean/std
    # normalisation -- confirm this matches what the checkpoint expects.

    # Run on whatever device the model lives on; a CPU tensor fed to a model
    # that was moved to CUDA raises a device-mismatch error.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        img = img.to(first_param.device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(img).logits
        first_layer = model.vit.encoder.layer[0]
        int_out = first_layer(model.vit.embeddings(img))

    # The layer may return a tuple whose first item is the hidden state or
    # (presumably, depending on the installed transformers version) the bare
    # tensor; indexing a bare tensor with [0] would drop the batch dimension.
    if isinstance(int_out, (tuple, list)):
        int_out = int_out[0]

    # .cpu() is required before .numpy() when the model ran on a GPU.
    int_out = int_out.detach().cpu().numpy()
    logits = logits.detach().cpu().numpy()

    return int_out, logits


def _collect_features(folder, prefix, int_model, size, num_images):
    """Run ``features_for_img`` on ``<folder>/<prefix><i>_fin.png`` for i = 1..num_images."""
    low_feats = []
    high_feats = []
    for i in range(1, num_images + 1):
        img_path = os.path.join(folder, prefix + str(i) + '_fin.png')
        low, high = features_for_img(img_path, int_model, size)
        low_feats.append(low)
        high_feats.append(high)
    return low_feats, high_feats


def get_features_for_task(base_path, int_model, size=224, num_images=5):
    """Compute features for every image group of one task directory.

    Four groups of images are processed, each numbered ``1..num_images``:

    * ``<base_path>/train/<i>_fin.png``
    * ``<base_path>/test/in_<i>_fin.png``
    * ``<base_path>/test/out_close_<i>_fin.png``
    * ``<base_path>/test/out_far_<i>_fin.png``

    Args:
        base_path: Directory containing the ``train`` and ``test`` folders.
            A trailing slash is optional.
        int_model: Model handed through to ``features_for_img``.
        size: Side length the images are resized to (default 224).
        num_images: Number of images per group (default 5).

    Returns:
        An 8-tuple of lists, in this order: ``(train_wugs_low,
        train_wugs_high, wugs_low, wugs_high, not_wugs_low_close,
        not_wugs_high_close, not_wugs_low_far, not_wugs_high_far)``.
        Each ``*_low`` list holds the first element returned by
        ``features_for_img`` and each ``*_high`` list the second.
    """
    # os.path.join avoids the malformed paths that plain string concatenation
    # produces when base_path is given without a trailing separator.
    train_dir = os.path.join(base_path, 'train')
    test_dir = os.path.join(base_path, 'test')

    train_wugs_low, train_wugs_high = _collect_features(
        train_dir, '', int_model, size, num_images)
    wugs_low, wugs_high = _collect_features(
        test_dir, 'in_', int_model, size, num_images)
    not_wugs_low_close, not_wugs_high_close = _collect_features(
        test_dir, 'out_close_', int_model, size, num_images)
    not_wugs_low_far, not_wugs_high_far = _collect_features(
        test_dir, 'out_far_', int_model, size, num_images)

    return (train_wugs_low, train_wugs_high, wugs_low, wugs_high,
            not_wugs_low_close, not_wugs_high_close,
            not_wugs_low_far, not_wugs_high_far)


In [6]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Load the pretrained ViT classifier and put it in evaluation mode
# (eval() returns the module itself, so the calls can be chained).
int_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').eval()

# Move the model to the chosen device; the returned module is the cell's output.
int_model.to(device)

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): PatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0): ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_fea

In [8]:
# Directory of the concept to featurise; it must contain the 'train/' and
# 'test/' sub-folders. Replace this path with the location of your own copy.
features = get_features_for_task(
    base_path='/viscam/u/joycj/geoclidean/dataset/geoclidean/constraints/concept_cct/',
    int_model=int_model,
    size=224,
)

