In [1]:
# %% [code]
# Imports
import os
import json
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd

# Import the teacher's FineTunedResNet model from a_resnet_transfer_trainer.py
from a_resnet_transfer_trainer import FineTunedResNet
# Import the ImageDataset (assumes it takes a CSV file path and an image directory)
from image_dataset_pytorch import ImageDataset

# %% [code]
# Set device (GPU if available, else CPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# %% [code]
# Parameters (adjust as needed)
num_classes = 13  # Update if your dataset has a different number of classes
saved_weights = 'data/model_evaluation/model_20250211-192554/weights/epoch_5.pth' # for future reference: 'data/final_model/image_model.pt'
training_csv = 'data/training_data.csv'   # This CSV should have columns "Image" and "labels"
image_dir = 'cleaned_images/'

# %% [code]
# Instantiate the teacher's model and load the saved weights.
model_training = FineTunedResNet(num_classes)
state_dict = torch.load(saved_weights, map_location=device)
model_training.load_state_dict(state_dict)
model_training.to(device)
print("Model weights loaded.")

# %% [code]
# Convert the model for feature extraction.
# According to the teacher's instruction, we remove the final classification head by taking all children except the last.
# The teacher’s FineTunedResNet is defined with:
#    self.combined_model = nn.Sequential(self.model, self.new_layers)
# Removing the last element gives us just self.model.
model_extractor = nn.Sequential(*list(model_training.children())[:-1])
model_extractor.to(device)
model_extractor.eval()
print("Model converted for feature extraction using teacher's instruction:")
print(model_extractor)

# %% [code]
# Define the transformation pipeline (should match what was used during training).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# %% [code]
# Load a small subset of the training dataset.
# We assume the CSV file (training_csv) is up-to-date and has columns "Image" and "labels".
dataset = ImageDataset(training_csv, image_dir)
# For testing, use a DataLoader with a small batch size.
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

# %% [code]
# Compute embeddings for a few images.
image_embeddings = {}
with torch.no_grad():
    for idx, (image, label, img_name) in enumerate(dataloader):
        image = image.to(device)
        # Use the converted model to extract features.
        embedding = model_extractor(image)
        # Save the embedding using the image index as key.
        image_embeddings[str(idx)] = embedding.cpu().tolist()
        if idx >= 4:  # Process the first 5 images for testing.
            break

# Print sample embeddings.
for key, emb in image_embeddings.items():
    print(f"Image index {key} embedding (first 5 values): {emb[:5]}")

# %% [code]
# Save the computed embeddings to a JSON file.
output_dir = 'data/output'
os.makedirs(output_dir, exist_ok=True)
embeddings_path = os.path.join(output_dir, 'test_image_embeddings.json')
with open(embeddings_path, 'w') as f:
    json.dump(image_embeddings, f)
print(f"Sample image embeddings saved to {embeddings_path}")


Using device: cpu
Model weights loaded.
Model converted for feature extraction using teacher's instruction:
Sequential(
  (0): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, 