In [1]:
import zipfile
import os, io
import numpy as np
from glob import glob
import pickle
import matplotlib.pyplot as plt

def unpickle(file):
    """
        Load one pickled batch file of the downsampled 64x64 ImageNet data.
        Dataset reference: https://patrykchrabaszcz.github.io/Imagenet32/

        Args:
            file: path of the pickled batch file; it is opened in binary mode.

        Returns:
            The unpickled object. For training batches this is a dict with the
            keys ['data', 'labels', 'mean']; the validation file ("val_data")
            only has 'data' and 'labels' (there is no 'mean' field).

        Note: pickle can execute arbitrary code while loading, so only use
        this on files from a trusted source.
    """
    with open(file, mode='rb') as batch_stream:
        batch = pickle.load(batch_stream)
        return batch

In [3]:
# Load the mean image stored in the training data and rescale it by 1/255.
# Nothing is computed here: the value is read from the 'mean' field of the
# first training batch file (in sorted order).
train_batches = sorted(glob('/data/ljc/datasets/imagenet64/train' + '/*'))
if not train_batches:
    # Without this guard an empty folder would only surface later as a
    # NameError on `mean_val`.
    raise FileNotFoundError('No training batch files found in /data/ljc/datasets/imagenet64/train')

f = train_batches[0]
print(f)
dic = unpickle(f)
# Divide into a new array instead of in place, so re-running the cell never
# rescales an already-rescaled mean and the loaded dict stays untouched.
mean_val = dic['mean'] / np.float32(255)

/data/ljc/datasets/imagenet64/train/train_data_batch_1


In [4]:
# Show the rescaled mean vector (mean_val was divided by 255 in the cell that loaded it).
print(mean_val)

[0.46454885 0.46888116 0.47197611 ... 0.39271004 0.39108653 0.38879286]


In [14]:
# Sanity check: compare transform_to_rgb() against a plain reshape on one sample row.
# NOTE(review): transform_to_rgb is defined in a later cell of this notebook, so this
# cell fails on a fresh kernel unless that cell is run first.
# sorted() orders the paths as strings, so index 3 is not necessarily the fourth batch
# numerically (the recorded output shows train_data_batch_3).
fname = sorted(glob('/data/ljc/datasets/imagenet64/train' + '/*'))[3]
print(fname)
dic_file = unpickle(fname)
# Row 12 of the batch's 'data' array, i.e. a single flattened image.
data = dic_file['data'][12]

# a: direct channel-first reshape; b: result of the helper under test.
a = data.reshape(3,64,64)
b = transform_to_rgb(data)
print(a)
print(b)

# Prints True when both approaches produce identical arrays.
print(np.array_equal(a, b))



/data/ljc/datasets/imagenet64/train/train_data_batch_3
[[[130 136 141 ... 150 138 128]
  [134 142 144 ... 159 145 133]
  [137 143 146 ... 158 148 135]
  ...
  [152 151 149 ... 169 167 166]
  [151 156 156 ... 170 167 165]
  [155 159 161 ... 169 165 162]]

 [[136 144 151 ... 161 147 136]
  [143 152 157 ... 168 154 140]
  [146 154 158 ... 171 160 144]
  ...
  [159 158 159 ... 171 170 168]
  [162 162 163 ... 171 168 165]
  [163 164 167 ... 172 168 165]]

 [[127 146 156 ... 170 145 126]
  [137 154 162 ... 179 158 132]
  [143 158 163 ... 182 165 140]
  ...
  [159 157 156 ... 170 168 164]
  [161 161 162 ... 171 167 166]
  [160 163 166 ... 173 169 165]]]
[[[130 136 141 ... 150 138 128]
  [134 142 144 ... 159 145 133]
  [137 143 146 ... 158 148 135]
  ...
  [152 151 149 ... 169 167 166]
  [151 156 156 ... 170 167 165]
  [155 159 161 ... 169 165 162]]

 [[136 144 151 ... 161 147 136]
  [143 152 157 ... 168 154 140]
  [146 154 158 ... 171 160 144]
  ...
  [159 158 159 ... 171 170 168]
  [162 162 

In [15]:
"""
    Read batch data and save each image to .npy files
"""

# read a line from the file and transform it into RGB image
def transform_to_rgb(x, img_size=64):
    """
    Turn one flattened image row into a channel-first array.

    The row is read as three consecutive planes of ``img_size * img_size``
    values each (the red, green and blue channels, in that order). Plane ``c``
    becomes ``out[c]``, filled row by row.

    Args:
        x: array-like holding exactly ``3 * img_size * img_size`` values.
        img_size: side length of the square image. Defaults to 64, which
            matches the previously hard-coded size.

    Returns:
        A new C-contiguous ``numpy.ndarray`` of shape
        ``(3, img_size, img_size)`` with the same dtype as the input.

    Raises:
        ValueError: if ``x`` does not contain ``3 * img_size * img_size`` values.
    """
    flat = np.asarray(x)
    expected = 3 * img_size * img_size
    if flat.size != expected:
        raise ValueError(
            f'expected {expected} values for a 3x{img_size}x{img_size} image, got {flat.size}'
        )

    # Splitting the row into three planes, stacking them depth-wise, reshaping
    # to (H, W, 3) and transposing back to (3, H, W) yields element-for-element
    # the same array as a direct reshape. The copy keeps the old guarantee that
    # the result does not share memory with the input.
    return flat.reshape(3, img_size, img_size).copy()
    
split = ['train', 'val']
sp = split[1]  # index 1 selects 'val'; switch to split[0] to process 'train'

folder = '/data/ljc/datasets/imagenet64/' + sp
out_folder = '/data/ljc/datasets/imagenet64/processed/' + sp
# Create the output directory up front so that opening labels.txt and saving
# the .npy files cannot fail just because the folder does not exist yet.
os.makedirs(out_folder, exist_ok=True)

cnt = 0  # running image index across all batch files; also used as the .npy file name
# The context manager closes labels.txt even when loading or saving fails midway.
with open(out_folder + '/labels.txt', 'w') as label_file:
    for f in sorted(glob(folder + '/*')):
        print('Processing', f)

        dic = unpickle(f)

        for i in range(len(dic['data'])):
            # scale to [0,1]. remove mean value, as suggested by https://patrykchrabaszcz.github.io/Imagenet32/
            # (mean_val is the rescaled mean loaded in an earlier cell)
            data = dic['data'][i] / np.float32(255) - mean_val
            rgb_array = transform_to_rgb(data)
            np.save(f'{out_folder}/{cnt}.npy', rgb_array)

            # One label per line, written in the same order as the saved {cnt}.npy files.
            label_file.write(str(dic['labels'][i]) + '\n')

            cnt += 1

Processing /data/ljc/datasets/imagenet64/val/val_data


In [20]:
# NOTE(review): the wildcard import hides where `ImageNet64` and `config` come from;
# presumably both are provided by `datahelpers` — confirm and import them explicitly.
from datahelpers import *

# Build the train and val datasets from the matching sub-directories of config.imagenet_root.
train_dataset = ImageNet64(data_dir=config.imagenet_root + '/train')
val_dataset = ImageNet64(data_dir=config.imagenet_root + '/val')

In [33]:
# Smallest label in the training dataset. The recorded output is 1, so labels do not start at 0.
# NOTE(review): a 1000-way classifier usually expects targets in 0..999 — confirm whether
# the labels are shifted before they reach the model.
min(train_dataset.labels)

1

In [22]:
# Build a Swin image classifier with 1000 output labels. The model is constructed
# directly from a config object here, not loaded from a saved checkpoint.
from transformers import SwinForImageClassification, SwinConfig
import torch
model = SwinForImageClassification(SwinConfig(num_labels=1000))

In [23]:

# Smoke test: run one random tensor of shape (1, 3, 224, 224) through the model and print the output object.
# NOTE(review): the images prepared earlier in this notebook are 3x64x64, while this input
# is 224x224 — how dataset images are resized for the model is not shown here; confirm.
print(model(torch.randn((1, 3, 224, 224))))

SwinImageClassifierOutput(loss=None, logits=tensor([[-3.7007e-02, -9.4032e-02, -2.8452e-01, -1.8496e-01, -2.3617e-01,
          1.6389e-01,  3.0949e-01, -3.0664e-01,  1.1930e-02, -1.2461e-02,
          2.4920e-02, -4.4661e-02,  2.3746e-01, -2.6010e-01, -3.9160e-02,
         -1.3676e-01, -2.8632e-01,  7.8793e-02,  2.1298e-01, -1.9918e-01,
         -6.6813e-03, -1.9647e-01, -2.7243e-01,  1.6768e-01, -6.2374e-02,
          2.4330e-01, -5.9319e-02,  2.9655e-02,  1.4096e-02,  5.3957e-01,
          1.7502e-01,  9.8141e-02,  3.8639e-02,  2.5275e-02,  5.9951e-02,
          1.5369e-01, -3.3170e-02,  2.6567e-01,  2.3810e-01, -2.0118e-01,
         -8.3134e-02,  1.9961e-01,  3.0514e-01,  4.0529e-02, -1.2756e-01,
          2.5514e-02, -3.2710e-02,  2.4419e-01, -1.5964e-01, -3.7872e-01,
          1.6504e-02,  1.2596e-01,  3.7921e-02,  7.3245e-02,  1.0696e-01,
          2.0552e-01,  5.5373e-02,  1.5895e-01,  2.9111e-01,  3.0961e-02,
          2.1013e-01,  6.9665e-02, -1.1492e-01,  8.3009e-02,  1.6287

In [20]:
# Display a random tensor of shape (3, 224, 224), i.e. without a batch dimension; the result is not assigned.
torch.randn((3, 224, 224))

tensor([[[-0.4063,  0.4352,  0.4018,  ...,  0.4083, -0.0632,  0.0698],
         [-0.5786,  0.8937, -0.4722,  ...,  0.0998, -1.7388,  0.0179],
         [ 1.0482,  0.1388, -2.9439,  ...,  0.3666,  0.4443,  1.4896],
         ...,
         [-1.2764,  0.3628, -0.5293,  ..., -0.4802,  0.0855, -1.0949],
         [-0.0820, -0.3424,  0.7744,  ..., -0.1909, -0.0242, -1.4338],
         [ 0.1420, -0.7519, -0.8286,  ...,  0.8351, -0.2846,  1.8606]],

        [[ 0.7406, -1.0296,  0.4085,  ...,  1.5133, -0.1217,  0.0426],
         [-1.4924,  0.9613,  0.5370,  ...,  0.6186, -1.2316, -0.8629],
         [ 0.4451, -0.2484, -1.7573,  ..., -0.0112, -1.1742,  0.5534],
         ...,
         [ 0.5740, -0.6405,  1.1363,  ..., -0.3477,  0.1725,  0.4684],
         [ 0.0635, -0.5757,  0.4703,  ..., -0.3974,  0.4564, -0.1175],
         [ 0.8520,  0.0939, -0.0535,  ..., -0.3939, -1.1417, -0.0738]],

        [[-0.1748, -1.3678,  0.1449,  ...,  0.7008, -1.1654,  1.4726],
         [-0.1947, -0.9391, -1.2697,  ..., -0