<a href="https://colab.research.google.com/github/jonyghosh444/transformer-res-ger/blob/master/convert_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive so the dataset and output paths under /content/drive are reachable.
from google.colab import drive

MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

Mounted at /content/drive


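## Feature extraction for CIFAR-10 and CIFAR-100

Encode every image with an ImageNet-pretrained ResNet-18 whose classification head has been replaced by an identity, then save the train/test features and labels to an `.npz` file.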

In [None]:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10, CIFAR100
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np

def extract_resnet_features_cifar(dataset_class, root, out_file, num_classes,
                                  batch_size=128, image_size=224, device=None):
    """Encode a CIFAR-style dataset with an ImageNet-pretrained ResNet-18 and save to .npz.

    Each image is resized to ``image_size`` x ``image_size``, normalized with the
    ImageNet channel statistics the pretrained weights expect, and passed through
    ResNet-18 with its final ``fc`` layer replaced by an identity. The saved
    features are therefore the activations that would have fed the classifier.

    Args:
        dataset_class: torchvision dataset class (e.g. ``CIFAR10`` / ``CIFAR100``)
            constructed with ``root``, ``train``, ``transform`` and ``download``.
        root: Directory where the dataset is stored / downloaded.
        out_file: Destination path passed to ``np.savez``.
        num_classes: Number of classes in the dataset; labels are checked to lie
            in ``[0, num_classes)`` before anything is written.
        batch_size: Images per forward pass (default 128).
        image_size: Side length images are resized to (default 224).
        device: Torch device for inference; defaults to CUDA when available.

    Raises:
        ValueError: if any label falls outside ``[0, num_classes)``.

    The archive contains the keys ``traindata``, ``trainlabel``, ``testdata`` and
    ``label_test``. The asymmetric last key is kept for existing consumers.
    """
    # ImageNet channel statistics used to train the torchvision ResNet weights.
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # torchvision >= 0.13 replaced `pretrained=True` with the `weights=` enum;
    # IMAGENET1K_V1 is the checkpoint the legacy flag loaded.
    weights_enum = getattr(models, "ResNet18_Weights", None)
    if weights_enum is not None:
        model = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
    else:
        model = models.resnet18(pretrained=True)
    model.fc = torch.nn.Identity()  # drop the ImageNet classifier head
    model = model.to(device).eval()

    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        # Without this the pretrained network sees out-of-distribution inputs
        # (raw [0, 1] pixels) and the extracted features are degraded.
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])

    dataset_train = dataset_class(root=root, train=True, transform=transform, download=True)
    dataset_test = dataset_class(root=root, train=False, transform=transform, download=True)

    pin = device.type == "cuda"
    # shuffle=False keeps features aligned with the dataset's native ordering.
    train_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=False, pin_memory=pin)
    test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, pin_memory=pin)

    def encode(loader, split):
        """Run the frozen model over ``loader``; return (features, labels) arrays."""
        features, labels = [], []
        with torch.no_grad():
            for images, targets in tqdm(loader, desc=f"Extracting {split} features"):
                feats = model(images.to(device, non_blocking=pin))
                features.append(feats.cpu().numpy())
                labels.append(np.asarray(targets))
        return np.concatenate(features), np.concatenate(labels)

    feats_train, labels_train = encode(train_loader, "train")
    feats_test, labels_test = encode(test_loader, "test")

    # Fail before writing if the labels do not match the declared class count.
    for split, split_labels in (("train", labels_train), ("test", labels_test)):
        if split_labels.size and (split_labels.min() < 0 or split_labels.max() >= num_classes):
            raise ValueError(
                f"{split} labels span [{split_labels.min()}, {split_labels.max()}], "
                f"expected [0, {num_classes})"
            )

    np.savez(out_file,
             traindata=feats_train,
             trainlabel=labels_train,
             testdata=feats_test,
             label_test=labels_test)

    print(f"Saved encoded dataset to {out_file}")

In [None]:
# Encode both CIFAR variants; every input and output lives under one Drive folder.
DATA_DIR = "/content/drive/MyDrive/transformer-r&d/dataset/data"

# (dataset class, file tag, number of classes). The raw data goes in a lower-case
# folder and the features go in an upper-case .npz next to it.
CIFAR_JOBS = (
    (CIFAR10, "CIFAR10", 10),
    (CIFAR100, "CIFAR100", 100),
)

for dataset_cls, tag, n_classes in CIFAR_JOBS:
    extract_resnet_features_cifar(
        dataset_cls,
        f"{DATA_DIR}/{tag.lower()}",
        f"{DATA_DIR}/{tag}_resnet18_224.npz",
        n_classes,
    )

Saved encoded dataset to /content/drive/MyDrive/transformer-r&d/dataset/data/CIFAR10_resnet18_224.npz


100%|██████████| 169M/169M [00:01<00:00, 101MB/s]


Saved encoded dataset to /content/drive/MyDrive/transformer-r&d/dataset/data/CIFAR100_resnet18_224.npz
