# setup

In [1]:
import torch

from PIL import Image


from datasets import load_dataset
from utils import transform, collate_fn, model_select

from tqdm import tqdm
import os

In [2]:
BATCH_SIZE = 256

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

# data

In [22]:
trn = load_dataset("evanarlian/imagenet_1k_resized_256", split = 'train')
val = load_dataset("evanarlian/imagenet_1k_resized_256", split = 'val')
trn, val

Resolving data files:   0%|          | 0/52 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/52 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/39 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/52 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/52 [00:00<?, ?it/s]

(Dataset({
     features: ['image', 'label'],
     num_rows: 1281167
 }),
 Dataset({
     features: ['image', 'label'],
     num_rows: 50000
 }))

In [23]:
trn = trn.with_transform(transform)
trn_loader = torch.utils.data.DataLoader(trn, collate_fn = collate_fn, batch_size = BATCH_SIZE)

val = val.with_transform(transform)
val_loader = torch.utils.data.DataLoader(val, collate_fn = collate_fn, batch_size = BATCH_SIZE)

trn.num_rows / BATCH_SIZE, val.num_rows / BATCH_SIZE

(5004.55859375, 195.3125)

# model

In [15]:
model_label = 'rn34'
model = model_select(model_label)

Using cache found in /home/josegfer/.cache/torch/hub/pytorch_vision_v0.10.0


In [16]:
model = model.to(device)

In [17]:
backbone = list(model.children())[:-1]
model = torch.nn.Sequential(*backbone)

# synthesis

In [18]:
H_trn = torch.empty(size = [0])

model.eval()
with torch.no_grad():
    for i, sample in tqdm(enumerate(trn_loader)):
        x = sample['image'].to(device)
        h = model.forward(x)

        H_trn = torch.cat((H_trn, h.cpu()))

In [19]:
H_val = torch.empty(size = [0])

model.eval()
with torch.no_grad():
    for i, sample in tqdm(enumerate(val_loader)):
        x = sample['image'].to(device)
        h = model.forward(x)

        H_val = torch.cat((H_val, h.cpu()))

0it [00:00, ?it/s]

196it [01:36,  2.03it/s]


# write

In [21]:
if not os.path.exists('output'):
    os.makedirs('output')

torch.save(H_trn, 'output/{}_H_train.pt'.format(model_label))
torch.save(H_val, 'output/{}_H_val.pt'.format(model_label))