In [1]:
import numpy as np
# PyTorch
import torch
import torchvision

In [2]:
import sys
sys.path.append('../src/')

%load_ext autoreload
%autoreload 2
# Importing our custom module(s)
import utils

In [3]:
dataset_directory = '/cluster/tufts/hugheslab/eharve06/CIFAR-10'
n = 100
random_state = 1001
_, train_dataset, val_dataset = utils.get_cifar10_datasets(dataset_directory, n, True, random_state)
_, train_and_val_dataset, test_dataset = utils.get_cifar10_datasets(dataset_directory, n, False, random_state)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [4]:
batch_size = 128
num_workers = 0
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers)
train_and_val_loader = torch.utils.data.DataLoader(train_and_val_dataset, batch_size=batch_size, num_workers=num_workers)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers)

In [5]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


In [6]:
checkpoint = torch.load('/cluster/tufts/hugheslab/eharve06/resnet50_torchvision/resnet50_torchvision_model.pt', map_location=torch.device('cpu'), weights_only=False)
model = torchvision.models.resnet50()
model.fc = torch.nn.Identity()
model.load_state_dict(checkpoint)
model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [7]:
train_metrics = utils.encode_images(model, train_loader)
print(np.stack(train_metrics['encoded_images']).shape)
print(np.stack(train_metrics['labels']).shape)

val_metrics = utils.encode_images(model, val_loader)
print(np.stack(val_metrics['encoded_images']).shape)
print(np.stack(val_metrics['labels']).shape)

train_and_val_metrics = utils.encode_images(model, train_and_val_loader)
print(np.stack(train_and_val_metrics['encoded_images']).shape)
print(np.stack(train_and_val_metrics['labels']).shape)

test_metrics = utils.encode_images(model, test_loader)
print(np.stack(test_metrics['encoded_images']).shape)
print(np.stack(test_metrics['labels']).shape)

np.savez(
    file=f'/cluster/tufts/hugheslab/eharve06/understanding-SNGP/data/CIFAR-10_n={n}_random_state={random_state}.npz',
    X_train=np.stack(train_metrics['encoded_images']), 
    y_train=np.stack(train_metrics['labels']), 
    X_val=np.stack(val_metrics['encoded_images']), 
    y_val=np.stack(val_metrics['labels']),
    X_train_and_val=np.stack(train_and_val_metrics['encoded_images']), 
    y_train_and_val=np.stack(train_and_val_metrics['labels']),
    X_test=np.stack(test_metrics['encoded_images']), 
    y_test=np.stack(test_metrics['labels']),
)

(80, 2048)
(80,)
(20, 2048)
(20,)
(100, 2048)
(100,)
(10000, 2048)
(10000,)
