VQA — Image processing with ResNet-50

Setup: imports and device

In [1]:
import os
from pathlib import Path
from collections import defaultdict
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from tqdm import tqdm
import pandas as pd
import numpy as np

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

Device: cuda


Mount Google Drive and set paths

In [2]:
from google.colab import drive
# Expose Google Drive to the Colab runtime.
drive.mount('/content/drive')

# Dataset root on Drive, and the folder holding the raw image files.
BASE = Path('/content/drive/MyDrive/Data')
IMAGES_DIR = BASE.joinpath('images')

print("BASE:", BASE)
print("IMAGES_DIR exists:", IMAGES_DIR.exists())

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
BASE: /content/drive/MyDrive/Data
IMAGES_DIR exists: True


Inspect files and prepare image id -> file path mapping

In [3]:
# Keep regular files only (glob('*') also yields sub-directories) and sort
# them: glob order depends on the filesystem, so without sorting the file that
# wins a stem collision below would change from run to run.
image_files = sorted(p for p in IMAGES_DIR.glob('*') if p.is_file())
print("Number of files found in images folder:", len(image_files))

# Build a mapping from image id token (like 'image1') to full path
imgid_to_path = {}
duplicate_stems = set()
for p in image_files:
    name = p.stem  # 'image1'
    if name in imgid_to_path:
        # Two files share this stem (e.g. image1.png and image1.jpg).
        duplicate_stems.add(name)
    imgid_to_path[name] = str(p)

# Make collisions visible instead of silently overwriting the earlier entry.
if duplicate_stems:
    print("WARNING: image ids backed by more than one file (last one wins):",
          sorted(duplicate_stems)[:20])

Number of files found in images folder: 1449


Determine which images to process from data.csv and the train/test image lists

In [4]:
# Inputs that name the images to process: the question/answer table and the
# train/test image-id lists, all stored directly under BASE.
data_csv_path = BASE.joinpath('data.csv')
train_list_path = BASE.joinpath('train_images_list.txt')
test_list_path = BASE.joinpath('test_images_list.txt')

In [5]:
# Load the question/answer table and keep only rows with a usable image id.
if data_csv_path.exists():
    df = pd.read_csv(data_csv_path)
    # Fall back to the expected three-column layout when the header is absent
    # or does not name an 'image_id' column.
    if list(df.columns) == [0, 1, 2] or 'image_id' not in df.columns:
        df.columns = ['question', 'answer', 'image_id']
    # Drop missing ids BEFORE casting to str: astype(str) turns NaN into the
    # literal string 'nan', which a later notna() check can no longer detect
    # and which would then be treated as a real image id.
    df = df[df['image_id'].notna()].copy()
    df['image_id'] = df['image_id'].astype(str).str.strip()
    # Drop ids that are empty once surrounding whitespace is removed.
    df = df[df['image_id'] != '']
    # Remove accidental header-like rows (safety net)
    df = df[df['image_id'].str.lower() != 'image_id']
    df = df.reset_index(drop=True)
    print("Loaded data.csv:", df.shape)
else:
    # No CSV: continue with an empty frame that has the expected columns.
    df = pd.DataFrame(columns=['question', 'answer', 'image_id'])
    print("data.csv not found:", data_csv_path)

Loaded data.csv: (12468, 3)


In [6]:
def load_list(path):
    """Return the non-blank lines of a text file, stripped of whitespace.

    Args:
        path: ``pathlib.Path`` of the list file (one entry per line).

    Returns:
        A list of stripped, non-empty lines in file order, or an empty list
        when the file does not exist.
    """
    if not path.exists():
        return []

    entries = []
    with open(path, 'r') as handle:
        for raw_line in handle:
            cleaned = raw_line.strip()
            if cleaned:  # skip blank / whitespace-only lines
                entries.append(cleaned)
    return entries

In [7]:
# Read both split lists with the same loader (a missing file yields []).
train_imgs, test_imgs = (
    load_list(list_path) for list_path in (train_list_path, test_list_path)
)
print("Train list count:", len(train_imgs), "Test list count:", len(test_imgs))

Train list count: 795 Test list count: 654


In [8]:
# Ids named by the CSV (none when the frame is empty).
csv_image_ids = [] if df.empty else sorted(df['image_id'].unique())

# Union of every id source, de-duplicated and sorted lexicographically.
all_desired_ids = sorted({*csv_image_ids, *train_imgs, *test_imgs})
print("Unique image ids referenced:", len(all_desired_ids))

# If empty, fallback to all images in folder
if not all_desired_ids:
    all_desired_ids = sorted(imgid_to_path.keys())
    print("No ids in csv/lists — processing all images found in folder:", len(all_desired_ids))

# Final whitespace normalisation of every id.
all_desired_ids = [image_id.strip() for image_id in all_desired_ids]

Unique image ids referenced: 1449


Dataset & transforms

In [9]:
# ImageNet per-channel mean and standard deviation (three values each, one per
# colour channel). They are passed to transforms.Normalize in the transform
# pipeline that prepares images for ResNet-50.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std  = [0.229, 0.224, 0.225]

In [10]:
# Preprocessing applied to every image before feature extraction; the sanity
# check in this notebook shows it yields tensors of shape (3, 224, 224).
transform = transforms.Compose([
    transforms.Resize(256),            # short side -> 256
    transforms.CenterCrop(224),        # crop to 224x224
    transforms.ToTensor(),             # PIL image -> tensor
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),  # per-channel normalisation with the ImageNet statistics
])

In [11]:
class VQAImageDataset(Dataset):
    """Dataset yielding ``(img_id, image_tensor)`` pairs for feature extraction.

    Args:
        img_ids: Sequence of image id strings (file stems such as ``'image1'``).
        imgid_to_path: Mapping from image id to the path of its image file.
        transform: Optional callable applied to the RGB PIL image. When it is
            omitted, ``transforms.ToTensor()`` is used instead.
        images_dir: Folder searched for ``<img_id><ext>`` when an id has no
            usable entry in ``imgid_to_path``. Defaults to the notebook-level
            ``IMAGES_DIR`` so existing three-argument calls keep working.

    A missing image does not raise. A black placeholder image is returned so
    the DataLoader keeps running, and a warning naming the id is emitted so
    that features computed from a placeholder can be traced afterwards.
    """

    # Extensions tried, in this order, when the mapping has no usable path.
    FALLBACK_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.JPG', '.PNG')
    # Size (width, height) of the black image substituted for a missing file.
    PLACEHOLDER_SIZE = (224, 224)

    def __init__(self, img_ids, imgid_to_path, transform=None, images_dir=None):
        self.img_ids = img_ids
        self.imgid_to_path = imgid_to_path
        self.transform = transform
        self.images_dir = images_dir

    def __len__(self):
        return len(self.img_ids)

    def _resolve_path(self, img_id):
        """Return an existing file path for ``img_id``, or ``None`` if none is found."""
        # Try direct path first, otherwise attempt common extensions
        path = self.imgid_to_path.get(img_id, None)
        if path is not None and os.path.exists(path):
            return path

        base_dir = self.images_dir if self.images_dir is not None else IMAGES_DIR
        for ext in self.FALLBACK_EXTENSIONS:
            candidate = str(Path(base_dir) / (img_id + ext))
            if os.path.exists(candidate):
                return candidate
        return None

    def _to_tensor(self, image):
        """Apply the configured transform, or plain ``ToTensor`` when none is set."""
        if self.transform:
            return self.transform(image)
        return transforms.ToTensor()(image)

    def __getitem__(self, idx):
        img_id = self.img_ids[idx]
        path = self._resolve_path(img_id)

        if path is None:
            # Best-effort fallback: keep the DataLoader worker alive with a
            # black image, but say so, because the features computed from it
            # are otherwise indistinguishable from real ones.
            import warnings
            warnings.warn(
                "Image file not found for id {!r}; using a black placeholder image.".format(img_id)
            )
            dummy = Image.new("RGB", self.PLACEHOLDER_SIZE, color=(0, 0, 0))
            return img_id, self._to_tensor(dummy)

        # The context manager closes the underlying file once the pixel data
        # has been converted to an in-memory RGB image.
        with Image.open(path) as opened:
            image = opened.convert('RGB')
        return img_id, self._to_tensor(image)


Sanity check: load the first five images and confirm the tensor shape

In [12]:
# Smoke test on the first five ids: each sample should load and have the
# expected tensor shape.
ds = VQAImageDataset(all_desired_ids[:5], imgid_to_path, transform=transform)
for sample_id, sample_tensor in (ds[pos] for pos in range(len(ds))):
    print(sample_id, sample_tensor.shape)

image1 torch.Size([3, 224, 224])
image10 torch.Size([3, 224, 224])
image100 torch.Size([3, 224, 224])
image1000 torch.Size([3, 224, 224])
image1001 torch.Size([3, 224, 224])


Building ResNet-50 as a feature extractor (224x224 input -> 7x7 spatial feature map)


In [13]:
# Load ImageNet-pretrained ResNet-50. The `pretrained=True` flag is deprecated
# in current torchvision; the explicit weights enum selects the same
# IMAGENET1K_V1 checkpoint without the deprecation warning.
resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
# Inference mode: BatchNorm layers use their stored running statistics.
resnet.eval()
resnet.to(device)



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

Global extractor: outputs the pooled (B, 2048, 1, 1) map after avgpool; it is flattened to a 2048-d vector in the extraction loop

In [14]:
# Every child module of the ResNet except the last one (the final fc layer),
# chained into a single sequential model.
global_extractor = nn.Sequential(*tuple(resnet.children())[:-1])
# Move to the target device and switch to eval mode; the returned module is
# the cell's displayed value.
global_extractor.to(device).eval()


Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)


spatial extractor: outputs 2048 x H x W map before avgpool

In [15]:
# Every child module of the ResNet except the last two, i.e. the network up to
# the layer4 output, chained into a single sequential model.
spatial_extractor = nn.Sequential(*tuple(resnet.children())[:-2])
# Move to the target device and switch to eval mode; the returned module is
# the cell's displayed value.
spatial_extractor.to(device).eval()


Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)


In [16]:
# Freeze parameters (no grads)
for p in global_extractor.parameters():
    p.requires_grad = False
for p in spatial_extractor.parameters():
    p.requires_grad = False

print("Models ready. Global extractor & spatial extractor loaded.")

Models ready. Global extractor & spatial extractor loaded.


Feature extraction loop

In [17]:
from tqdm import tqdm

batch_size = 32
num_workers = 2
dataset = VQAImageDataset(all_desired_ids, imgid_to_path, transform=transform)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,                         # keep the order of all_desired_ids
    num_workers=num_workers,
    # Pinned host memory only helps host-to-GPU copies; enable it only when
    # the run actually targets CUDA.
    pin_memory=(device.type == "cuda"),
    # Identity collate: each batch stays a list of (img_id, image_tensor)
    # tuples; the extraction loop stacks the tensors itself.
    collate_fn=lambda batch: batch,
)

# Output paths
OUT_DIR = BASE / 'features'
OUT_DIR.mkdir(parents=True, exist_ok=True)
# dict saved via torch.save
global_out_path = OUT_DIR / 'features_global.pt'
spatial_out_path = OUT_DIR / 'features_spatial.pt'

In [18]:
# Quick sanity check: list missing image ids (those not present in imgid_to_path with any ext)
missing = []
for iid in all_desired_ids:
    if iid not in imgid_to_path:
        # also check common extensions
        found = False
        for ext in ['.png', '.jpg', '.jpeg', '.JPG', '.PNG']:
            if os.path.exists(str(IMAGES_DIR / (iid + ext))):
                found = True
                break
        if not found:
            missing.append(iid)

print("Missing image ids count:", len(missing))
if len(missing) > 0:
    print("First 20 missing ids:", missing[:20])
    # Optionally remove them from processing:

Missing image ids count: 0


In [19]:
# img_id -> feature tensor, filled batch by batch and kept on the CPU.
features_global = {}
features_spatial = {}

# global_extractor and spatial_extractor are built from the same ResNet
# children and differ only by global_extractor's last module (the pooling
# layer). Running the shared backbone once and applying that last module to
# its output gives the same result as calling both extractors, at about half
# the compute.
pooling_layer = global_extractor[-1]

with torch.no_grad():
    for batch in tqdm(dataloader, desc="Batches"):
        # batch is list of (img_id, image_t)
        img_ids = [item[0] for item in batch]
        imgs = torch.stack([item[1] for item in batch], dim=0).to(device)  # (B,3,224,224)

        # spatial features (B, 2048, 7, 7): the single backbone forward pass
        sp = spatial_extractor(imgs)

        # global features: pool the spatial map, then flatten
        out = pooling_layer(sp)                        # (B, 2048, 1, 1)
        out = out.view(out.size(0), -1).cpu()          # (B, 2048)
        sp = sp.cpu()

        # clone() gives every entry its own storage rather than a view into
        # the whole batch tensor.
        for i, img_id in enumerate(img_ids):
            features_global[img_id] = out[i].clone()
            features_spatial[img_id] = sp[i].clone()

# Save to drive
torch.save(features_global, global_out_path)
torch.save(features_spatial, spatial_out_path)

print("Saved global features to:", global_out_path)
print("Saved spatial features to:", spatial_out_path)
print("Example feature shapes:")
# Guard against an empty run so the summary does not raise StopIteration.
if features_global:
    some_id = next(iter(features_global.keys()))
    print(some_id, "global:", features_global[some_id].shape, "spatial:", features_spatial[some_id].shape)

Batches: 100%|██████████| 46/46 [00:36<00:00,  1.28it/s]


Saved global features to: /content/drive/MyDrive/Data/features/features_global.pt
Saved spatial features to: /content/drive/MyDrive/Data/features/features_spatial.pt
Example feature shapes:
image1 global: torch.Size([2048]) spatial: torch.Size([2048, 7, 7])


Save global features as numpy .npy matrix plus index mapping

In [20]:
# Row order of the matrix: image ids in sorted order.
idx_map = sorted(features_global.keys())
# One numpy vector per id, in that same order.
feat_list = [features_global[row_id].numpy() for row_id in idx_map]

feat_mat = np.stack(feat_list, axis=0)  # (N, 2048)
np.save(OUT_DIR / 'features_global_matrix.npy', feat_mat)

# Line k of the index file names the image stored in row k of the matrix.
with open(OUT_DIR / 'features_index.txt', 'w') as f:
    f.writelines(row_id + '\n' for row_id in idx_map)

print("Saved features_global_matrix.npy (shape {}) and features_index.txt".format(feat_mat.shape))


Saved features_global_matrix.npy (shape (1449, 2048)) and features_index.txt


Utility functions to load features later

In [21]:
def load_global_features(path=None):
    """Load the saved global-feature object with ``torch.load``.

    Args:
        path: File to read. When omitted (or otherwise falsy), the default
            ``OUT_DIR / 'features_global.pt'`` is used.

    Returns:
        Whatever object was stored in the file.
    """
    if path:
        source = path
    else:
        source = OUT_DIR / 'features_global.pt'
    return torch.load(str(source))

def load_spatial_features(path=None):
    """Load the saved spatial-feature object with ``torch.load``.

    Args:
        path: File to read. When omitted (or otherwise falsy), the default
            ``OUT_DIR / 'features_spatial.pt'`` is used.

    Returns:
        Whatever object was stored in the file.
    """
    if path:
        source = path
    else:
        source = OUT_DIR / 'features_spatial.pt'
    return torch.load(str(source))