In [1]:
import torch
import torchvision.transforms as T

from utility.get_data import S3ImageFolder

In [2]:
# Dataset roots handed to get_data() as `img_root`.
# NOTE(review): judging by the directory names these are ImageNet-A and the
# ImageNetV2 "matched-frequency" validation set -- confirm against the bucket layout.
imagenet_a_path, imagenet_b_path = (
    "imagenet-a",
    "imagenetv2-matched-frequency-format-val/",
)

In [3]:
def get_data(batch_size, img_root, seed=None):
    """Build a shuffled train loader and an ordered test loader from an image folder.

    Images under ``img_root`` are loaded through ``S3ImageFolder``, passed
    through a resize / random-crop / normalize pipeline, and split 80/20 into
    a training subset and a test subset.

    Args:
        batch_size: Number of samples per mini-batch for both loaders.
        img_root: Dataset root handed to ``S3ImageFolder``.
        seed: Optional integer seed for the train/test split. When given, the
            split permutation is drawn from a dedicated ``torch.Generator``
            seeded with this value; when ``None`` (default) the global torch
            RNG is used.

    Returns:
        A ``(train_loader, test_loader)`` tuple of ``DataLoader`` objects.

    Raises:
        ValueError: If the dataset contains no samples.
    """
    # Preprocessing shared by both splits.
    # NOTE(review): the test subset goes through the same RandomCrop as the
    # training subset, so evaluation crops are not deterministic -- consider a
    # separate deterministic transform for the test data.
    transform = T.Compose([
        T.Resize((256, 256)),        # Resize each PIL image to 256 x 256
        T.RandomCrop((224, 224)),    # Randomly crop a 224 x 224 patch
        T.ToTensor(),                # Convert the image to a PyTorch tensor
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Per-channel ImageNet mean / std
    ])

    # Load data
    dataset = S3ImageFolder(root=img_root, transform=transform)

    num_samples = len(dataset)
    if num_samples == 0:
        raise ValueError(
            f"No samples found under {img_root!r}; cannot create a train/test split."
        )

    # 80/20 split. ceil(0.8 * n) is computed in integer arithmetic so that a
    # sample count divisible by 5 is not over-allocated to training by one
    # (the previous int(n * 0.8 + 1) gave 5/0 for n=5 and 9/1 for n=10).
    training_samples = (4 * num_samples + 4) // 5
    test_samples = num_samples - training_samples

    # Only pass a generator when a seed was requested, so the default call
    # keeps using the global RNG exactly as before.
    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)

    training_data, test_data = torch.utils.data.random_split(
        dataset, [training_samples, test_samples], **split_kwargs
    )

    # Initialize dataloaders: shuffle only the training split.
    train_loader = torch.utils.data.DataLoader(training_data, batch_size, shuffle=True, num_workers=4)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size, shuffle=False, num_workers=4)

    return train_loader, test_loader

In [10]:
# Build the train / test loaders over the `imagenet_a_path` root, 32 images per batch.
train_loader, test_loader = get_data(
    batch_size=32,
    img_root=imagenet_a_path,
)



tensor([[[ 0.9132,  1.0159,  1.2557,  ...,  0.5193,  0.2282,  0.0569],
         [ 0.7591,  0.9988,  1.1358,  ...,  0.8789,  0.7419,  0.5364],
         [ 1.2214,  1.4098,  1.4440,  ...,  0.7933,  0.8618,  0.8789],
         ...,
         [-0.1143, -0.1143, -0.0458,  ..., -1.1418, -0.6623, -0.1486],
         [ 0.5364,  0.5878,  0.5878,  ..., -0.3027, -0.3541, -0.1486],
         [ 0.4166,  0.4508,  0.4166,  ...,  0.0569, -0.1999, -0.1314]],

        [[ 0.6954,  0.8179,  1.0630,  ...,  0.3978,  0.1001, -0.0749],
         [ 0.5728,  0.8354,  0.9755,  ...,  0.8004,  0.6429,  0.3978],
         [ 1.0105,  1.2206,  1.2731,  ...,  0.7304,  0.7829,  0.8004],
         ...,
         [-0.1450, -0.1800, -0.0749,  ..., -1.3354, -0.8978, -0.3901],
         [ 0.2752,  0.3102,  0.2927,  ..., -0.6702, -0.9328, -0.7752],
         [ 0.0826,  0.1176,  0.0826,  ..., -0.3375, -0.7577, -0.7752]],

        [[ 0.9494,  1.0714,  1.3154,  ...,  0.7751,  0.4265,  0.2522],
         [ 0.7402,  1.0017,  1.1411,  ...,  1