In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader, Dataset, random_split, TensorDataset
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, ToPILImage, CenterCrop, RandomResizedCrop
from torchvision.datasets import ImageFolder
from torchvision.models import alexnet, resnet18, inception_v3

from torchvision.models.alexnet import AlexNet_Weights
from torchvision.models.inception import Inception_V3_Weights
from torchvision.models.resnet import ResNet18_Weights

In [2]:
# Project helper (not part of this notebook) that fetches the dataset folders
# into `data_path`; per its output it skips folders that already exist.
from download_rps import download_rps

# Root data folder. It ends with "/" so split folder names can be appended directly.
data_path = "../data/"
train_data_path = f"{data_path}rps"
val_data_path = f"{data_path}rps-test-set"

download_rps(data_path)

rps folder already exists!
rps-test-set folder already exists!


In [8]:
# Preprocessing for the ImageNet-pretrained backbone:
# - resize the shorter side to 256 and center-crop to 224x224,
# - convert to a tensor,
# - normalize with the ImageNet per-channel mean/std.
normalizer = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transformer = Compose([Resize(256), CenterCrop(224), ToTensor(), normalizer])

# ImageFolder derives class labels from the sub-folder names under each root.
train_data = ImageFolder(root=train_data_path, transform=transformer)
val_data = ImageFolder(root=val_data_path, transform=transformer)

# Builds a loader of each set; only the training split is shuffled.
train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
val_loader = DataLoader(val_data, batch_size=16)

In [9]:
# ResNet-18 with torchvision's DEFAULT pretrained weights for this architecture.
pretrained_weights = ResNet18_Weights.DEFAULT
resnet = resnet18(weights=pretrained_weights)

In [10]:
def freeze_model(model):
    """Turn off gradient tracking for every parameter of ``model``, in place.

    Backpropagation then computes no gradients for these parameters, so the
    module acts as a fixed feature extractor. Returns ``None``.
    """
    # Module.requires_grad_ visits every parameter, just like an explicit loop.
    model.requires_grad_(False)

In [11]:
# Seed torch's global RNG. This makes shuffling in the DataLoaders, when they
# are iterated later, and any other torch randomness reproducible.
SEED = 42
# The trailing semicolon hides the Generator object that manual_seed returns.
torch.manual_seed(SEED);

<torch._C.Generator at 0x1fa114a3df0>

In [12]:
def preprocessed_dataset(model, loader, device=None):
    """Run every batch from ``loader`` through ``model`` and collect the outputs.

    In this notebook it is applied to a frozen backbone whose ``fc`` head was
    replaced by ``nn.Identity``. The stored outputs are therefore the features
    that would have fed that head.

    Args:
        model: ``nn.Module`` applied to each input batch. It is put in eval
            mode (and left in eval mode) so layers like dropout/batch-norm use
            their inference behavior.
        loader: iterable yielding ``(x, y)`` tensor batches, e.g. a ``DataLoader``.
        device: device the inputs are moved to. Defaults to the device of the
            model's first parameter, or CPU if the model has no parameters.

    Returns:
        ``TensorDataset(features, labels)``. Both tensors are on CPU and are
        concatenated along dim 0 in the order the loader produced the batches.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    if device is None:
        # next(..., None) avoids StopIteration for parameterless modules.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # Switch mode once, not on every batch.
    model.eval()

    feature_batches = []
    label_batches = []
    # The outputs are only stored, so no autograd graph is needed. This also
    # saves memory if the model was not frozen.
    with torch.no_grad():
        for x, y in loader:
            output = model(x.to(device))
            feature_batches.append(output.detach().cpu())
            label_batches.append(y.cpu())

    if not feature_batches:
        raise ValueError("loader yielded no batches; cannot build a dataset")

    # Concatenate once at the end. Calling torch.cat inside the loop re-copies
    # all previously collected batches on every step, which is quadratic.
    features = torch.cat(feature_batches)
    labels = torch.cat(label_batches)
    return TensorDataset(features, labels)

In [13]:
# Uses the `preprocessed_dataset` function defined earlier in this notebook.
# Importing a same-named function from the `preprocessed_dataset` module here
# would silently shadow that definition and run code the reader never sees.

# Change the top layer to Identity, so the model outputs the features that
# would otherwise be fed to `fc`.
resnet.fc = nn.Identity()
# Freeze the model: the backbone is only used as a fixed feature extractor.
freeze_model(resnet)

train_preproc = preprocessed_dataset(resnet, train_loader)
val_preproc = preprocessed_dataset(resnet, val_loader)

In [None]:
# Persist the (features, labels) tensor pairs next to the raw data.
train_preproc_file = data_path + "train_preproc.pth"
val_preproc_file = data_path + "val_preproc.pth"
torch.save(train_preproc.tensors, train_preproc_file)
torch.save(val_preproc.tensors, val_preproc_file)

# Mini-batch loaders over the precomputed features; only training is shuffled.
train_preproc_loader = DataLoader(dataset=train_preproc, batch_size=16, shuffle=True)
val_preproc_loader = DataLoader(dataset=val_preproc, batch_size=16)