# setup

In [1]:
import torch

from PIL import Image
from torchvision import transforms

from datasets import load_dataset

from tqdm import tqdm
import os

In [2]:
# Number of examples per DataLoader batch.
BATCH_SIZE = 256

In [3]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [4]:
# Deterministic preprocessing pipeline: resize, center-crop to 224x224,
# convert the PIL image to a tensor, then normalize each channel with the
# per-channel mean/std listed below.
preprocess = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

def transform(examples):
    """Replace every image in a batch dict with its preprocessed tensor.

    Each entry of ``examples['image']`` is converted to RGB and passed
    through ``preprocess``. The dict is updated in place and also returned.
    """
    tensors = []
    for img in examples['image']:
        rgb = img.convert("RGB")
        tensors.append(preprocess(rgb))
    examples['image'] = tensors
    return examples

def collate_fn(examples):
    """Collate a list of ``{'image', 'label'}`` examples into one batch dict.

    Images are stacked along a new leading dimension with ``torch.stack``;
    labels are gathered into a single tensor with ``torch.tensor``.
    """
    stacked = torch.stack([item['image'] for item in examples])
    targets = torch.tensor([item['label'] for item in examples])
    return {'image': stacked, 'label': targets}

# data

In [5]:
# Training split of the "evanarlian/imagenet_1k_resized_256" dataset
# (the repo name suggests images pre-resized to 256 px — TODO confirm).
trn = load_dataset(
    "evanarlian/imagenet_1k_resized_256",
    split="train",
)
trn

Resolving data files:   0%|          | 0/52 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/52 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/39 [00:00<?, ?it/s]

Dataset({
    features: ['image', 'label'],
    num_rows: 1281167
})

In [6]:
# Attach the preprocessing function to the dataset; per the datasets docs,
# with_transform applies it on the fly when examples are accessed.
trn = trn.with_transform(transform)
trn_loader = torch.utils.data.DataLoader(
    trn,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
)
# Fractional number of batches needed to cover the dataset once.
trn.num_rows / BATCH_SIZE

2502.279296875

# model

In [7]:
# Pretrained ResNet from the torchvision hub repo pinned at v0.10.0.
# To use a different depth, replace the architecture name with one of:
#   'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'
model = torch.hub.load(
    'pytorch/vision:v0.10.0',
    'resnet34',
    pretrained=True,
)

Using cache found in /home/josegfer/.cache/torch/hub/pytorch_vision_v0.10.0


In [8]:
# Move the model to the selected device before running inference.
model = model.to(device)

In [9]:
# Keep every top-level child module except the last one and chain them into a
# Sequential (for torchvision ResNets the dropped child is presumably the
# final `fc` classifier — confirm against the architecture in use).
# NOTE: this rebinds `model`; re-running the cell strips one more layer.
backbone = [*model.children()]
backbone = backbone[:-1]
model = torch.nn.Sequential(*backbone)

# synthesis

In [10]:
# Run the truncated network over the whole training set and collect its
# outputs on the CPU.
#
# Per-batch outputs are accumulated in a list and concatenated once at the end.
# Calling torch.cat on the growing result inside the loop re-copies every
# previously collected row on each iteration (quadratic total copying and a
# transient second copy of the full result in memory).
feature_chunks = []

model.eval()  # switch layers such as batch norm to inference behaviour
with torch.no_grad():  # no autograd bookkeeping is needed for extraction
    # Iterating the loader directly (instead of enumerate(...)) lets tqdm read
    # len(trn_loader) and display a total and ETA.
    for sample in tqdm(trn_loader):
        x = sample['image'].to(device)
        h = model(x)  # call the module rather than .forward() directly
        feature_chunks.append(h.cpu())

# Preserve the original result for an empty loader: an empty 1-D tensor.
H_trn = torch.cat(feature_chunks) if feature_chunks else torch.empty(size = [0])

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
2305it [47:55,  3.13s/it]

: 

# write

In [None]:
# Ensure the output directory exists. exist_ok=True makes this a single
# idempotent call, with no gap between checking for and creating the directory.
os.makedirs('output', exist_ok=True)

In [None]:
# Persist the collected model outputs to disk.
torch.save(H_trn, 'output/H_trn.pt')
# To reload them later:
# H = torch.load('output/H_trn.pt')