Skip to content
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
Cannot retrieve contributors at this time
"""Run the images in RMNIST through a truncated version of ResNet-18 and
save the features from the final layer. Based on the transfer learning
tutorial and code by Sasank Chilamkurthy, published in the official
PyTorch tutorials.

Note that we ignore the test data. A more thorough treatment would
consider validation and test data separately.
"""
# Standard library
from __future__ import print_function, division
import cPickle
import gzip

# My libraries
import data_loader

# Third-party libraries
import numpy as np
from PIL import Image
import torch
from torch.autograd import Variable
from torch.utils.data import Dataset
from torchvision import datasets, models, transforms
# Configuration: use expanded training data or not
expanded = False

# Define the truncated model: a pretrained ResNet-18 that is only ever used
# as a fixed feature extractor (see forward_partial).
net = models.resnet18(pretrained=True)

# Freeze every weight; nothing in this script trains the network.
for param in net.parameters():
    param.requires_grad = False

# Put the network in inference mode.  A freshly constructed module is in
# training mode, in which batch-norm layers normalise with the statistics of
# the current batch (and update their running averages).  The validation
# features are computed in batches of 100 while each training set goes
# through as a single batch of a different size, so without this call the
# two sets of features would be normalised inconsistently.
net.eval()
def forward_partial(model, x):
    """Return the flattened output of ``model``'s average-pooling stage.

    ``x`` is passed through ``conv1``, ``bn1``, ``relu``, ``maxpool``,
    ``layer1`` .. ``layer4`` and ``avgpool`` of ``model``, in that order.
    No later stage of the model (e.g. the ``fc`` classifier of a
    torchvision ResNet) is applied, so the result is the feature vector
    that would normally feed the classifier.

    Args:
        model: object exposing the nine stages listed above as callables.
        x: input batch; its first dimension is treated as the batch size.

    Returns:
        ``x`` after the last stage, reshaped to ``(x.size(0), -1)``, i.e.
        one flat feature row per input example.
    """
    stage_names = (
        "conv1", "bn1", "relu", "maxpool",
        "layer1", "layer2", "layer3", "layer4",
        "avgpool",
    )
    # Look each stage up immediately before applying it, so the stages are
    # resolved and executed strictly in the order listed.
    for stage_name in stage_names:
        x = getattr(model, stage_name)(x)
    # Collapse everything except the batch dimension.
    return x.view(x.size(0), -1)
class RMNIST(Dataset):
    """One split of the RMNIST data, served as RGB PIL images.

    The split is obtained from ``data_loader.load_data`` and kept in
    ``self.data``.  It is indexed as ``self.data[0][idx]`` for the pixels
    of example ``idx`` and ``self.data[1][idx]`` for the matching value
    (presumably the digit label -- confirm against ``data_loader``).
    """

    def __init__(self, n=0, train=True, transform=None, expanded=False):
        """Load the requested split.

        Args:
            n: forwarded to ``data_loader.load_data`` as its first argument.
            train: if true keep the first split returned by ``load_data``
                (training data), otherwise the second (validation data).
            transform: optional callable applied to each PIL image before
                it is returned by ``__getitem__``.
            expanded: forwarded to ``data_loader.load_data``.
        """
        self.n = n
        self.transform = transform
        # load_data returns three splits; the third (test data) is unused.
        td, vd, _ = data_loader.load_data(n, expanded=expanded)
        self.data = td if train else vd

    def __len__(self):
        """Return the number of examples in the stored split."""
        return len(self.data[0])

    def __getitem__(self, idx):
        """Return ``(image, value)`` for example ``idx``.

        The image is a 28x28 RGB PIL image whose three channels all hold
        the same greyscale data, passed through ``self.transform`` when one
        was supplied.
        """
        pixels = self.data[0][idx]
        # Scale to byte range and lay the flat pixel vector out as 28x28.
        # NOTE(review): assumes pixel values lie in [0, 1) so the scaled
        # values fit in uint8 -- confirm against data_loader.
        gray = (pixels * 256).reshape(28, 28)
        # Replicate the single grey channel into R, G and B; the assignment
        # into the uint8 array performs the float -> byte conversion.
        rgb = np.zeros((28, 28, 3), 'uint8')
        rgb[...] = gray[..., np.newaxis]
        image = Image.fromarray(rgb, mode="RGB")
        if self.transform:
            image = self.transform(image)
        value = self.data[1][idx]
        return (image, value)
# Compute the abstract features for the validation data
#
# NOTE(review): parts of this section were missing from the file (the
# opening of the Compose list, the DataLoader call and the expressions
# building vd2) and have been reconstructed; please confirm.
data_transform = transforms.Compose([
    # NOTE(review): the resize step is reconstructed.  224 is the input size
    # the ImageNet-pretrained ResNet-18 was trained on -- confirm.
    transforms.Resize(224),
    # Normalize operates on tensors, so the PIL image produced by RMNIST
    # has to be converted first.
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
data_loader_val = torch.utils.data.DataLoader(
    RMNIST(10, train=False, transform=data_transform), batch_size=100)
data_val = list(data_loader_val)
print("\nComputing features for validation data")
val_features = []
val_labels = []
for j, (inputs, labels) in enumerate(data_val):
    print("Computing features for batch {} of 100".format(j))
    inputs, labels = Variable(inputs), Variable(labels)
    outputs = forward_partial(net, inputs)
    val_features.append(outputs.data.numpy())
    val_labels.append(labels.data.numpy())
# Join the per-batch arrays once at the end instead of re-concatenating the
# growing result on every iteration.
vd2 = (np.concatenate(val_features), np.concatenate(val_labels))
# Compute the abstract features for the RMNIST training data
#
# NOTE(review): the DataLoader call, the td2 expressions, the `else` branch
# of the file-name selection and the gzip.open call were missing from the
# file and have been reconstructed; please confirm.
for n in [1, 5, 10]:
    print("\nComputing features for n = {}".format(n))
    # Do everything in one batch
    batch_size = 9*n*10 if expanded else n*10
    data_loader_train = torch.utils.data.DataLoader(
        RMNIST(n, transform=data_transform, expanded=expanded),
        batch_size=batch_size)
    # The batch size is chosen to cover the whole training set, so the
    # first batch is the only one.
    inputs, labels = next(iter(data_loader_train))
    inputs, labels = Variable(inputs), Variable(labels)
    outputs = forward_partial(net, inputs)
    td2 = (outputs.data.numpy(), labels.data.numpy())
    if expanded:
        name = "data/rmnist_abstract_features_expanded_{}.pkl.gz".format(n)
    else:
        name = "data/rmnist_abstract_features_{}.pkl.gz".format(n)
    # Saved in the same (training, validation, test) layout the loader
    # returns; the test slot is a (0, 0) placeholder because the test data
    # is ignored.  The `with` block guarantees the gzip stream is flushed
    # and closed even if pickling fails.
    with gzip.open(name, 'wb') as f:
        cPickle.dump((td2, vd2, (0,0)), f)