In [3]:
from feature_extractors import ResnetFeatureExtractor
import torch
from torchvision.models import ResNet101_Weights, resnet101
import os
import h5py
from PIL import Image
import numpy as np

In [None]:
# Command-line arguments kept for reference only (presumably for the feature-extraction
# script whose settings the cells below mirror — TODO confirm script name). As a bare
# Python line this was a SyntaxError that broke Restart & Run All, so it is a comment now:
#   --image_dir ~/clevr-images-unambigous-dale-two/images/ --scene_dir ~/clevr-images-unambigous-dale-two/scenes/ --out_file ~/clevr-images-unambigous-dale-two/features/resnet_3_noavgpool_no-fc2.h5 --feature_extractor ResNet --no-avgpool --no-fc --num_blocks 3 --device cuda --batch_size 32

In [11]:
# Inference device used by later cells to place the model and input batches.
device = torch.device("cuda")

# Feature extractor matching the CLI flags: pretrained weights, fine_tune=False
# (presumably frozen — see feature_extractors.ResnetFeatureExtractor), 3 blocks,
# no average pooling and no fc head.
feature_extractor = ResnetFeatureExtractor(
    pretrained=True,
    fine_tune=False,
    number_blocks=3,
    avgpool=False,
    fc=False,
).to(device)

# Preprocessing transforms bundled with torchvision's ResNet-101 IMAGENET1K_V2 weights.
preprocess = ResNet101_Weights.IMAGENET1K_V2.transforms()

In [20]:
# Source images. Expanding "~" keeps the notebook portable instead of hardcoding one
# user's absolute home directory (matches the ~/ paths in the CLI arguments).
image_dir = os.path.expanduser("~/clevr-images-unambigous-dale-two/images/")
# Sorted so row i of every HDF5 dataset maps to a stable file order.
images = sorted(os.listdir(image_dir))
# The extraction cell reads images[0] for the image size, so fail early and clearly.
assert images, f"no images found in {image_dir}"
# One row per image, followed by the extractor's per-image feature shape.
shape = [len(images), *feature_extractor.feature_shape]

In [21]:
# Number of images stacked into a single forward pass of the extractor.
batch_size: int = 32

In [32]:
# Extract features for every image in `images` and write them to test.h5py.
# "features" and "other" receive identical data; both are read back in a later cell.
with h5py.File("test.h5py", "w") as f:
    feature_dataset = f.create_dataset("features", shape, dtype=np.float32)
    other_dataset = f.create_dataset("other", shape, dtype=np.float32)
    # Store the raw (pre-transform) PIL size of the first image. Image.open is lazy and
    # reading .size does not close the file, so use a context manager to release it.
    with Image.open(os.path.join(image_dir, images[0])) as first_image:
        feature_dataset.attrs["image_size"] = first_image.size
    batch = []
    i_start = 0
    for image_index, image_file in enumerate(images):
        if image_index % batch_size == 0:
            print(f"processing image {image_index}...", end="\r")

        with Image.open(os.path.join(image_dir, image_file)) as image:
            # Keep preprocessed tensors on the CPU; the whole batch is moved to
            # the device in one transfer below instead of one copy per image.
            batch.append(preprocess(image.convert("RGB")))

        # Flush a full batch, or the (possibly partial) last batch.
        if len(batch) == batch_size or image_index == len(images) - 1:
            with torch.no_grad():
                # pylint: disable-next=not-callable
                features = feature_extractor(torch.stack(batch).to(device)).cpu().numpy()
            i_end = i_start + len(batch)
            feature_dataset[i_start:i_end] = features
            other_dataset[i_start:i_end] = features

            i_start = i_end
            batch = []
    print()


processing image 9984...


In [33]:
# Print the first stored row of each dataset written by the extraction cell.
with h5py.File("test.h5py", "r") as h5_file:
    for dataset_name in ("features", "other"):
        print(h5_file[dataset_name][0])

[[[3.8073244  2.6239595  4.339472   ... 5.325759   2.782626   4.3295035 ]
  [8.308824   6.6276407  7.509774   ... 5.86327    7.785393   8.902941  ]
  [3.3584938  3.2338123  5.528773   ... 2.0243895  4.3621306  6.3857875 ]
  ...
  [1.6492925  1.4835947  1.6650604  ... 2.3154776  0.87608    0.8543335 ]
  [6.2306514  7.324578   3.187983   ... 2.046173   1.794639   2.6868978 ]
  [1.8481137  1.4104245  1.4879092  ... 1.6555399  0.6400337  0.83423305]]

 [[1.7858517  0.         0.         ... 0.         0.         1.5440044 ]
  [0.88440037 0.         0.         ... 0.09728736 1.2225254  0.        ]
  [2.2111716  0.7295412  3.8566332  ... 0.3129365  3.9986446  6.581839  ]
  ...
  [0.4817189  0.         0.         ... 0.04168516 0.         1.1300035 ]
  [0.         0.         0.         ... 0.         0.         0.6798581 ]
  [0.         0.         1.2156646  ... 0.         0.         0.        ]]

 [[0.7791131  1.0179651  1.1035043  ... 1.0819919  1.1234807  1.6286753 ]
  [1.5182847  2.291483

In [5]:
# List the rotation sub-directories in pure Python rather than through IPython's
# automagic `ls` shell alias, whose behavior (colored output, `~` handling) varies
# by platform and which only works when automagic is enabled.
sorted(os.listdir(os.path.expanduser("~/clevr-rotation")))

[0m[01;34mrot0[0m/  [01;34mrot180[0m/  [01;34mrot270[0m/  [01;34mrot90[0m/


In [22]:
# Rotation experiment setup: same extractor configuration as the first section,
# rebuilt here so this section can run on its own.
device = torch.device("cuda")
feature_extractor = ResnetFeatureExtractor(pretrained=True, fine_tune=False, number_blocks=3, avgpool=False, fc=False).to(device)
preprocess = ResNet101_Weights.IMAGENET1K_V2.transforms()

# Root holding one sub-directory per rotation (rot0, rot90, rot180, rot270).
# Expand "~" instead of hardcoding one user's absolute home path.
image_dir = os.path.expanduser("~/clevr-rotation/")
# os.path.join instead of string concatenation, which only worked because
# image_dir happened to end with "/".
rot0_dir = os.path.join(image_dir, "rot0")
images = sorted(os.listdir(rot0_dir))
assert images, f"no images found in {rot0_dir}"
# Dataset shape is derived from rot0; every rotation directory is expected to
# contain the same number of images.
shape = [len(images), *feature_extractor.feature_shape]

In [23]:
batch_size = 32
# Rotation sub-directories to process, in order; each gets its own HDF5 dataset.
rotations = ("rot0", "rot90", "rot180", "rot270")

with h5py.File("test.h5py", "w") as f:
    datasets = {
        rotation: f.create_dataset(rotation, shape, dtype=np.float32)
        for rotation in rotations
    }
    # Raw PIL size of the first rot0 image, stored on every dataset. Image.open is
    # lazy, so a context manager is needed to release the file handle.
    with Image.open(os.path.join(image_dir, "rot0", images[0])) as first_image:
        img_size = first_image.size
    for ds in datasets.values():
        ds.attrs["image_size"] = img_size

    for rotation, dataset in datasets.items():
        rotation_dir = os.path.join(image_dir, rotation)
        # Separate name so the rot0 listing held in `images` is not clobbered.
        rotation_images = sorted(os.listdir(rotation_dir))
        # All datasets were sized from rot0: more images would overflow the dataset,
        # fewer would silently leave zero-filled rows.
        if len(rotation_images) != dataset.shape[0]:
            raise ValueError(
                f"{rotation_dir} has {len(rotation_images)} images, "
                f"expected {dataset.shape[0]}"
            )
        batch = []
        i_start = 0
        for image_index, image_file in enumerate(rotation_images):
            if image_index % batch_size == 0:
                print(f"processing image {image_index}...", end="\r")

            with Image.open(os.path.join(rotation_dir, image_file)) as image:
                # Keep tensors on the CPU; the batch is moved to the device once below.
                batch.append(preprocess(image.convert("RGB")))

            # Flush a full batch, or the (possibly partial) last batch.
            if len(batch) == batch_size or image_index == len(rotation_images) - 1:
                with torch.no_grad():
                    # pylint: disable-next=not-callable
                    features = feature_extractor(torch.stack(batch).to(device)).cpu().numpy()
                i_end = i_start + len(batch)
                dataset[i_start:i_end] = features

                i_start = i_end
                batch = []
    print()

processing image 192...


In [24]:
# Load the first stored feature map of the rot0 dataset, then show it.
with h5py.File("test.h5py", "r") as h5_file:
    first_rot0_features = h5_file["rot0"][0]
print(first_rot0_features)

[[[3.2667007e+00 1.9804718e+00 5.4867163e+00 ... 4.7759619e+00
   2.6488540e+00 3.8716872e+00]
  [7.9386778e+00 7.2455015e+00 8.5866747e+00 ... 5.1461329e+00
   4.9513712e+00 8.6591339e+00]
  [3.9532762e+00 4.1714735e+00 5.5034389e+00 ... 2.8516710e+00
   3.7648771e+00 5.6911392e+00]
  ...
  [5.2508745e+00 3.5094647e+00 2.7807615e+00 ... 2.3141162e+00
   1.9528351e+00 1.9380059e+00]
  [6.8675318e+00 5.4736824e+00 4.3544488e+00 ... 3.0802040e+00
   3.1041505e+00 2.5740082e+00]
  [4.7065949e+00 3.9638441e+00 4.7066145e+00 ... 2.5681045e+00
   1.6386139e+00 2.0963702e+00]]

 [[0.0000000e+00 0.0000000e+00 0.0000000e+00 ... 2.9370213e+00
   0.0000000e+00 5.2613678e+00]
  [1.3201824e-01 2.2842543e+00 0.0000000e+00 ... 4.7416210e-01
   1.3833011e+00 0.0000000e+00]
  [2.1585088e+00 0.0000000e+00 1.9286683e+00 ... 1.4459424e+00
   4.1377940e+00 7.2262664e+00]
  ...
  [0.0000000e+00 0.0000000e+00 0.0000000e+00 ... 0.0000000e+00
   0.0000000e+00 0.0000000e+00]
  [5.0276864e-01 6.1952448e-03 1.330