In [1]:
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.models import mobilenet_v3_large

# Seed torch's global RNG up front so any stochastic step later in the notebook is reproducible.
torch.manual_seed(42)

# Tensors saved earlier in the project: the inputs and their targets.
# TensorDataset (below) requires both to have the same size along dim 0.
images = torch.load("03_baseline_cnn/data/images.pt")
labels = torch.load("03_baseline_cnn/data/labels.pt")

# Pair each image with its label, then serve the entire dataset as ONE
# unshuffled batch so downstream outputs keep the on-disk sample order.
dataset = torch.utils.data.TensorDataset(images, labels)
loader = DataLoader(
    dataset,
    batch_size=len(dataset),
    shuffle=False,
)

In [2]:
# Frozen, ImageNet-pretrained MobileNetV3-Large used purely as a feature extractor.
# Only `weights` is passed: the legacy `pretrained` flag is deprecated in torchvision
# and redundant once `weights` names the checkpoint.
model = mobilenet_v3_large(weights='IMAGENET1K_V1')

# Freeze the backbone; it is only run for inference in this notebook.
for param in model.parameters():
    param.requires_grad = False

# Replace the classification head with a pass-through so the forward pass returns
# the pooled penultimate features (960-dim per the recorded output) instead of logits.
model.classifier = nn.Identity()

# Batch accessors attached to the model (kept in case other code relies on them).
# Batches come from the TensorDataset loader, i.e. (images, labels).
model.get_model_inputs_from_batch = (lambda batch: [batch[0]])
model.get_labels_from_batch = (lambda batch: batch[1])

# NOTE(review): the pretrained weights expect ImageNet-style preprocessing
# (see MobileNet_V3_Large_Weights.IMAGENET1K_V1.transforms()); confirm images.pt
# was saved resized/normalized that way — TODO verify against the data-prep step.

# eval(): BatchNorm layers use their running statistics instead of per-batch statistics.
model.eval()

# Collect every batch and save once after the loop, so the files hold all samples
# even if the loader is later changed to yield more than one batch.
feature_batches = []
label_batches = []
# no_grad (not inference_mode) so the resulting tensors remain ordinary tensors
# that later cells can feed into autograd, e.g. when training a head on them.
with torch.no_grad():
    for batch in loader:
        feature_batches.append(model(*model.get_model_inputs_from_batch(batch)))
        label_batches.append(model.get_labels_from_batch(batch))

processed_images = torch.cat(feature_batches)
labels = torch.cat(label_batches)
assert len(processed_images) == len(labels), "feature/label count mismatch"
print(processed_images.shape)
print(labels.shape)

OUTPUT_DIR = "03_baseline_cnn/data/extracted_features"
os.makedirs(OUTPUT_DIR, exist_ok=True)  # a fresh checkout may not have this folder yet
torch.save(processed_images, f"{OUTPUT_DIR}/basedata.pt")
torch.save(labels, f"{OUTPUT_DIR}/basedata_labels.pt")

torch.Size([1000, 960])
torch.Size([1000])
