In [1]:
import numpy as np
from datasets import Dataset, load_from_disk
import torch.nn as nn
from torchvision import transforms
import torch
from PIL import Image
from torchvision.datasets import ImageFolder
from torch.utils.data import random_split

In [2]:
from transformers import ViTModel
from transformers.modeling_outputs import SequenceClassifierOutput
import torch.nn as nn

class ViTForImageClassification(nn.Module):
    def __init__(self, num_labels=50):
        super(ViTForImageClassification, self).__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.vit.config.hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(self, pixel_values, labels):
        outputs = self.vit(pixel_values=pixel_values)
        output = self.dropout(outputs.last_hidden_state[:,0])
        logits = self.classifier(output)

        loss = None
        if labels is not None:
          loss_fct = nn.CrossEntropyLoss()
          loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [3]:
# dataset = load_from_disk('awa2-test')

In [5]:
transform = transforms.Compose([transforms.Resize(224),
                                transforms.CenterCrop(224)])
data_set = ImageFolder('AwA2-data/Animals_with_Attributes2/JPEGImages', transform)
# should split dataset before feature extractor to make it smaller so they
# don't run out of memory

splits = [int(0.85*len(data_set)), len(data_set)-int(0.85*len(data_set))]
train_set, test_set = random_split(data_set, splits, generator=torch.Generator().manual_seed(42))

splits = [int(0.9*len(train_set)), len(train_set)-int(0.9*len(train_set))]
train_set, valid_set = random_split(train_set,splits, generator=torch.Generator().manual_seed(42))

In [6]:
from transformers import ViTFeatureExtractor

feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')

In [7]:
from tqdm import tqdm

processed_dataset = [(feature_extractor(item[0]), item[1]) for item in tqdm(valid_set)]
    

100%|██████████| 3173/3173 [00:46<00:00, 68.72it/s]


__uncomment and run this to save dataset to disk, may need to install pickle with conda__

In [8]:
# import pickle

# filename = 'pickled_processed_dataset.p'
# with open(filename, 'wb') as filehandler:
#     pickle.dump(processed_dataset, filehandler)

In [17]:
class ViTDATASET(Dataset):
    def __init__(self, _dataset):
        self.dataset = _dataset
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        imageTensor, target = self.dataset[idx]

        return imageTensor, target

In [18]:
dataset = ViTDATASET(processed_dataset)

In [19]:
model = ViTForImageClassification()
model.load_state_dict(torch.load('vit_base.pt'))

<All keys matched successfully>

In [72]:
from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=1)

In [73]:
def evaluate(model_in, testset, testloader):
#     device = ('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = model_in
    correct = 0
    total = 0
    # model.eval()
    with torch.no_grad():
        for i, data in tqdm(enumerate(testloader), total=len(testloader)):
            features, labels = data[0], data[1]
#             print(features['pixel_values'][0])
            output = model(pixel_values=features['pixel_values'][0], labels=labels)
 
        # should be able to convert pixel values to image
    
            logits = output.logits
            predicted = logits.argmax(1)
            print(predicted)
            print(labels)

            break

    return


In [74]:
# dataset[0][0]['pixel_values']

In [75]:
acc = evaluate(model, dataset, dataloader)
acc

  0%|          | 0/3173 [00:00<?, ?it/s]

tensor([31])
tensor([9])



