# An Exploration of the Data

In [1]:
from collections import defaultdict
import json

import numpy as np
import torch

import loader as L

# Seed NumPy's and PyTorch's global RNGs up front. SubsetRandomSampler draws
# its permutations from torch's global RNG when no generator is passed.
SEED = 42
np.random.seed(SEED)
# The trailing ';' stops the notebook from echoing the returned Generator object.
torch.manual_seed(SEED);


<torch._C.Generator at 0x7f86137f19b0>

In [2]:
def load_labels(path):
    """Read and parse one JSON labels file.

    The file handle is closed deterministically via ``with``; the original
    ``json.loads(open(path).read())`` left it open. JSON is UTF-8 by spec,
    so the encoding is pinned rather than left to the platform locale.
    """
    with open(path, encoding="utf-8") as fh:
        return json.load(fh)


punks = load_labels(L.PUNK_LABELS)
train = load_labels(L.TRAIN_LABELS)
test = load_labels(L.TEST_LABELS)
# (whole set, train split, test split) -- the argument order key_splits expects
datasets = (punks, train, test)

In [3]:
def key_splits(key, wholeset, trainset, testset):
    """Count the entries in each split whose ``key`` label is truthy.

    Args:
        key: label name looked up in every entry's label dict (e.g. 'alien').
        wholeset, trainset, testset: mappings whose values are per-entry label
            dicts, such as the dicts loaded from the label JSON files.

    Returns:
        A ``(whole, train, test)`` tuple of counts.

    Raises:
        KeyError: if some entry's label dict has no ``key``.
    """
    return tuple(
        sum(1 for labels in split.values() if labels[key])
        for split in (wholeset, trainset, testset)
    )


def print_splits(key, *datasets):
    """Print ``key: (whole, train, test)`` counts for a single label."""
    splits = key_splits(key, *datasets)
    print(f'{key}: {splits}')


# Report how the rarer punk types are distributed across whole/train/test.
for label_key in ('alien', 'ape', 'zombie'):
    print_splits(label_key, *datasets)


alien: (9, 9, 0)
ape: (24, 22, 2)
zombie: (88, 80, 8)


In [4]:
punks_df = L.make_punks_df(L.PUNK_LABELS)
# The second argument appears to be the held-out size (this cell reports
# 9000 train / 1000 test indices).
train_idx, test_idx = L.df_split(punks_df, 1000)
print(len(train_idx), len(test_idx), sep='\n')

9000
1000


In [5]:
from torch.utils.data import SubsetRandomSampler, DataLoader

# Build the dataset with a 2000-item held-out split and one random sampler per split.
punks_ds = L.PunksDataset(L.ALL_FILTERS, test_size=2000)
train_sampler, test_sampler = (
    SubsetRandomSampler(indices)
    for indices in (punks_ds.train_idx, punks_ds.test_idx)
)

batch_size = 32

# Shuffling comes from the samplers. DataLoader rejects shuffle=True when a
# sampler is supplied, so shuffle stays False.
train_loader, test_loader = (
    DataLoader(dataset=punks_ds, batch_size=batch_size, shuffle=False, sampler=sampler)
    for sampler in (train_sampler, test_sampler)
)


In [6]:
# Batches per epoch, and batches x batch_size. That product is only an upper
# bound on the sample count: the test split gives 63 x 32 = 2016, yet it holds
# 2000 items because its final batch is partial. len(loader) is used directly;
# len(iter(loader)) returns the same value but builds an iterator, which draws
# a seed from torch's global RNG.
print(len(train_loader))                # 250
print(len(train_loader) * batch_size)   # 8000
print(len(test_loader))                 # 63
print(len(test_loader) * batch_size)    # 2016 (upper bound; 2000 real items)


def last_batch_size(loader):
    """Iterate ``loader`` for one epoch and return the size of its final batch.

    Only the current batch is referenced, so memory stays bounded by one
    batch. The previous version stored every batch in a dict.
    Returns 0 if the loader yields nothing.
    """
    size = 0
    for batch, _labels in loader:
        size = len(batch)
    return size


print(last_batch_size(train_loader))  # 32: 8000 items fill 250 full batches
print(last_batch_size(test_loader))   # 16: 2000 = 62 x 32 + 16


250
8000
63
2016
32
16


In [7]:
# Rebuild the dataset with other held-out sizes and report train/test index
# counts. With test_size=0, every example lands in the training split.
for size in (1000, 0):
    punks_ds = L.PunksDataset(L.ALL_FILTERS, test_size=size)
    print(len(punks_ds.train_idx), len(punks_ds.test_idx), sep='\n')

9000
1000
10000
0
