<a href="https://colab.research.google.com/github/mtrefilek/cs762/blob/main/resnet50_pynb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import os
cwd = os.getcwd().replace('\\','/')

import numpy as np
import pickle
from copy import deepcopy
from torchvision import datasets
from glob import glob, iglob
from PIL import Image
import torchvision.transforms as transforms
import torchvision.models as models

In [3]:
DSET_NAME = 'MNIST' #('MNIST', 'FMNIST', 'CIFAR10', 'CIFAR100', 'PlantDisease', 'EuroSAT', 'ChestXRay') 'ISIC2018', 'TinyImageNet' are not yet implemented

In [4]:
# Load the pretrained model
model = models.resnet50(pretrained=True)
# Use the model object to select the desired layer
layer = model._modules.get('avgpool')

In [5]:
# Set model to evaluation mode
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [None]:
def batch(iterable, n=1):
    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx:min(ndx + n, l)]

#import torchvision
#transform=torchvision.transforms.Compose([
    # you can add other transformations in this list
    #torchvision.transforms.ToTensor()
#])
dset_path = cwd + '/dataset'
feature_path = cwd+'/extracted_features/'

if DSET_NAME=='MNIST': # MNIST
    scaler = transforms.Scale((224, 224))
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
    to_tensor = transforms.ToTensor()
    mnist_train = datasets.MNIST(root=dset_path, train=True, download=True)#, transform=transform)
    mnist_test = datasets.MNIST(root=dset_path, train=False, download=True)#, transform=transform)
    imgs_tr, imgs_tst, labels_tr, labels_tst = [], [], [], []
    for (img, label) in mnist_train:
        img = normalize(to_tensor(scaler(img.convert(mode='RGB')))).unsqueeze(0)
        imgs_tr.append(img)
        labels_tr.append(label)
    for (img, label) in mnist_test:
        img = normalize(scaler(img.convert(mode='RGB')))
        imgs_tst.append(img)
        labels_tst.append(label)
    classnames = [str(a) for a in range(10)]
    



In [6]:
def get_vector(image_name):
    # 1. Load the image with Pillow library
    img = Image.open(image_name)    # 2. Create a PyTorch Variable with the transformed image
    t_img = Variable(normalize(to_tensor(scaler(img))).unsqueeze(0))    # 3. Create a vector of zeros that will hold our feature vector
    #    The 'avgpool' layer has an output size of 512
    my_embedding = torch.zeros(512)    # 4. Define a function that will copy the output of a layer
    def copy_data(m, i, o):
        my_embedding.copy_(o.data)    # 5. Attach that function to our selected layer
    h = layer.register_forward_hook(copy_data)    # 6. Run the model on our transformed image
    model(t_img)    # 7. Detach our copy function from the layer
    h.remove()    # 8. Return the feature vector
    return my_embedding

In [None]:
### Extract Features
if DSET_NAME in ('MNIST', 'FMNIST', 'CIFAR10', 'CIFAR100'):
    n_tr = len(imgs_tr)
    feature_matrix_tr = np.zeros((1, 2048))
    for img_batch in batch(imgs_tr, n=256):
        inputs = processor(images=img_batch, return_tensors="pt")
        img_features_tr = model.get_image_features(inputs['pixel_values'].to(device)).cpu().detach().numpy()
        feature_matrix_tr = np.concatenate((feature_matrix_tr, img_features_tr), axis=0)
        print('Extracting Training Features: {0:.2f}% done'.format(100*len(feature_matrix_tr[1:])/n_tr) )
    feature_matrix_tr = feature_matrix_tr[1:]

    n_cls = np.max(labels_tr)+1
    labels_tr = np.array(labels_tr)
    feature_matrices_tr = []
    for i in range(n_cls):
        feature_matrices_tr.append(feature_matrix_tr[labels_tr==i])

    n_tst = len(imgs_tst)
    feature_matrix_tst = np.zeros((1, 512))
    for img_batch in batch(imgs_tst, n=256):
        inputs = processor(images=img_batch, return_tensors="pt")
        img_features_tst = model.get_image_features(inputs['pixel_values'].to(device)).cpu().detach().numpy()
        feature_matrix_tst = np.concatenate((feature_matrix_tst, img_features_tst), axis=0)
        print('Extracting Test Features: {0:.2f}% done'.format(100*len(feature_matrix_tst[1:])/n_tst) )
    feature_matrix_tst = feature_matrix_tst[1:]

    n_cls = np.max(labels_tst)+1
    labels_tst = np.array(labels_tst)
    feature_matrices_tst = []
    for i in range(n_cls):
        feature_matrices_tst.append(feature_matrix_tst[labels_tst==i])
        
    inputs = processor(text=classtexts, return_tensors="pt", padding=True).to(device)
    classtext_embeddings = model.get_text_features(**inputs).cpu().detach().numpy()
    
### Save Features
if DSET_NAME in ('MNIST', 'FMNIST', 'CIFAR10', 'CIFAR100'):
    fname_tr = DSET_NAME+'_'+PRETRAINED_MODEL_NAME+'_train.npz'
    fname_tst = DSET_NAME+'_'+PRETRAINED_MODEL_NAME+'_test.npz'

    np.savez(feature_path + fname_tr, feature_matrices = feature_matrices_tr, classnames = classnames, classtexts = classtexts, classtext_embeddings = classtext_embeddings)
    np.savez(feature_path + fname_tst, feature_matrices = feature_matrices_tst, classnames = classnames, classtexts = classtexts, classtext_embeddings = classtext_embeddings)