# Create CIFAR10 shards

In [None]:
!pip install webdataset

In [None]:
import torchvision
import webdataset as wds
from sklearn.model_selection import train_test_split
from pathlib import Path
import sys

### Download train and test set
* Repeat the train set (50,000 images) 26 times to get 1,300,000 train samples
* Repeat the test set (10,000 images) 10 times to get 100,000 test samples

In [None]:
# Number of times the CIFAR-10 train split is repeated in train_list.
TRAIN_REPEATS = 26

# Build the dataset object once. Its construction does not depend on the loop,
# so rebuilding it on every pass only repeats the download check and data load.
trainset = torchvision.datasets.CIFAR10(root="./", train=True, download=True)

train_list = []
for _repeat in range(TRAIN_REPEATS):
    # extend() iterates the dataset and appends every sample it yields.
    train_list.extend(trainset)

# NOTE: sys.getsizeof is shallow; it measures the list object itself, not the
# samples the list refers to.
print("Size of train_list:",sys.getsizeof(train_list))
print("Length of train_list:",len(train_list))

In [None]:
# Number of times the CIFAR-10 test split is repeated in test_list.
TEST_REPEATS = 10

# Build the dataset object once. Its construction does not depend on the loop,
# so rebuilding it on every pass only repeats the download check and data load.
testset = torchvision.datasets.CIFAR10(root="./", train=False, download=True)

test_list = []
for _repeat in range(TEST_REPEATS):
    # extend() iterates the dataset and appends every sample it yields.
    test_list.extend(testset)

# NOTE: sys.getsizeof is shallow; it measures the list object itself, not the
# samples the list refers to.
print("Size of test_list:",sys.getsizeof(test_list))
print("Length of test_list:",len(test_list))

## Create Tar Shards

Create the local directories for storing the shards

In [None]:
output_pth = "cifar-shards"

# Create the shard root first, then one sub-directory per split.
# exist_ok=True makes the cell safe to re-run.
shard_root = Path(output_pth)
shard_root.mkdir(parents=True, exist_ok=True)
for split_dir in ("train", "val"):
    Path(output_pth + "/" + split_dir).mkdir(parents=True, exist_ok=True)

Write sharded tar files; 2,000 samples per shard

In [None]:
output_pth = "cifar-shards"

# Pair each in-memory sample list with the split name used in its shard paths.
splits = [(train_list, "train"), (test_list, "val")]

for samples, split in splits:
    # The %06d placeholder is left for ShardWriter to fill in per shard;
    # maxcount caps each tar file at 2000 samples.
    shard_pattern = f"{output_pth}/{split}/cifar-{split}-%06d.tar"
    with wds.ShardWriter(shard_pattern, maxcount=2000) as sink:
        for index, (image, cls) in enumerate(samples):
            # One record per sample: zero-padded key, image, and class label.
            record = {"__key__": "%07d" % index, "ppm": image, "cls": cls}
            sink.write(record)

Copy shards to your GCS bucket

In [None]:
# Recursively copy (-r) the local validation shards to GCS; -m runs the transfers in parallel.
# The destination below is a placeholder: replace gs:// with your bucket URL before running.
!gsutil -m cp -r cifar-shards/val gs:// # TODO: Add your GCS bucket location

In [None]:
# Recursively copy (-r) the local training shards to GCS; -m runs the transfers in parallel.
# The destination below is a placeholder: replace gs:// with your bucket URL before running.
!gsutil -m cp -r cifar-shards/train gs:// # TODO: Add your GCS bucket location