In [1]:
## Libraries
import glob
from multiprocessing import cpu_count
import os
import sys

## 3rd party
from gensim.models import Word2Vec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image, ImageFilter
import torch
from torch.utils.data import DataLoader

_path = ".."
if _path not in sys.path:
    sys.path.append(_path)
from lib.dataset import TextArtDataLoader, AlignCollate, ImageBatchSampler
from lib.config import Config
# from lib.preprocess import (pad_image, crop_edges_lr, )

%reload_ext autoreload
%autoreload 2

In [2]:
WORD2VEC_MODEL_FILE = "../models/deviant_wiki_word2vec.model"
BATCH_SIZE = 4
# N_WORKERS = cpu_count() - 1
N_WORKERS = 1
CONFIG = Config()

In [106]:
train_dataset = TextArtDataLoader('united', WORD2VEC_MODEL_FILE, mode='train')
val_dataset = TextArtDataLoader('united', WORD2VEC_MODEL_FILE, mode='val')
test_dataset = TextArtDataLoader('united', WORD2VEC_MODEL_FILE, mode='test')
train_align_collate = AlignCollate('train',
                                   CONFIG.MEAN,
                                   CONFIG.STD,
                                   CONFIG.IMAGE_SIZE_HEIGHT,
                                   CONFIG.IMAGE_SIZE_WIDTH,
                                   horizontal_flipping=CONFIG.HORIZONTAL_FLIPPING,
                                   random_rotation=CONFIG.RANDOM_ROTATION,
                                   color_jittering=CONFIG.COLOR_JITTERING,
                                   random_grayscale=CONFIG.RANDOM_GRAYSCALE,
                                   random_channel_swapping=CONFIG.RANDOM_CHANNEL_SWAPPING,
                                   random_gamma=CONFIG.RANDOM_GAMMA,
                                   random_resolution=CONFIG.RANDOM_RESOLUTION)
val_align_collate = AlignCollate('val',
                                   CONFIG.MEAN,
                                   CONFIG.STD,
                                   CONFIG.IMAGE_SIZE_HEIGHT,
                                   CONFIG.IMAGE_SIZE_WIDTH,
                                   horizontal_flipping=CONFIG.HORIZONTAL_FLIPPING,
                                   random_rotation=CONFIG.RANDOM_ROTATION,
                                   color_jittering=CONFIG.COLOR_JITTERING,
                                   random_grayscale=CONFIG.RANDOM_GRAYSCALE,
                                   random_channel_swapping=CONFIG.RANDOM_CHANNEL_SWAPPING,
                                   random_gamma=CONFIG.RANDOM_GAMMA,
                                   random_resolution=CONFIG.RANDOM_RESOLUTION)

train_loader = DataLoader(train_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          num_workers=N_WORKERS,
                          pin_memory=True,
                          collate_fn=train_align_collate,
                          sampler=None,
                         )
val_loader = DataLoader(val_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          num_workers=N_WORKERS,
                          pin_memory=True,
                          collate_fn=val_align_collate,
                         )
test_loader = DataLoader(test_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          num_workers=N_WORKERS,
                          pin_memory=True,
                          collate_fn=None,
                         )

In [None]:
DataLoader?

In [None]:
# for image, label, image_file in train_loader:
# for label_sentence, image_file in train_loader:
for image, word_vectors_tensor in train_loader:
#     print(label_sentence, image_file)
    print("IMAGE:", image.shape)
    print("WV:", word_vectors_tensor.shape)
    break

In [None]:
img = np.array(image)[1].transpose(1, 2, 0)

In [None]:
plt.imshow(img)

In [None]:
np.all(img <= 1.0)

In [None]:
np.all(img >= -1.0)

In [20]:
img

NameError: name 'img' is not defined

In [142]:
batch_sampler = ImageBatchSampler('united', BATCH_SIZE)

In [143]:
df = batch_sampler.df

In [97]:
## Group batches
df_n_labels_grouped = df.groupby(by=pd.cut(df['n_labels'], [0, 5, 7, 11, 1000]))
for key1, item1 in df_n_labels_grouped:
    print("\nN LABEL GROUP", key1)
    df1 = df_n_labels_grouped.get_group(key1)
    df_width_grouped = df1.groupby(by=pd.cut(df1['width'], [0, 500, 700, 1000, 10000]))
    for key2, item2 in df_width_grouped:
        print("\tWIDTH GROUP", key2)
        df2 = df_width_grouped.get_group(key2)
        df_height_grouped = df2.groupby(by=pd.cut(df2['height'], [0, 590, 10000]))
        for key3, item3 in df_height_grouped:
            print("\t\tHEIGHT GROUP", key3)
#             print(df_height_grouped.get_group(key3), "\n\n")
            print('\t\t\t', len(df_height_grouped.get_group(key3).index), 'samples')


N LABEL GROUP (0, 5]
	WIDTH GROUP (0, 500]
		HEIGHT GROUP (0, 590]
			 2072 samples
		HEIGHT GROUP (590, 10000]
			 7162 samples
	WIDTH GROUP (500, 700]
		HEIGHT GROUP (0, 590]
			 750 samples
		HEIGHT GROUP (590, 10000]
			 1827 samples
	WIDTH GROUP (700, 1000]
		HEIGHT GROUP (0, 590]
			 4112 samples
		HEIGHT GROUP (590, 10000]
			 1633 samples
	WIDTH GROUP (1000, 10000]
		HEIGHT GROUP (0, 590]
			 191 samples
		HEIGHT GROUP (590, 10000]
			 1540 samples

N LABEL GROUP (5, 7]
	WIDTH GROUP (0, 500]
		HEIGHT GROUP (0, 590]
			 1328 samples
		HEIGHT GROUP (590, 10000]
			 6970 samples
	WIDTH GROUP (500, 700]
		HEIGHT GROUP (0, 590]
			 777 samples
		HEIGHT GROUP (590, 10000]
			 1763 samples
	WIDTH GROUP (700, 1000]
		HEIGHT GROUP (0, 590]
			 4706 samples
		HEIGHT GROUP (590, 10000]
			 1694 samples
	WIDTH GROUP (1000, 10000]
		HEIGHT GROUP (0, 590]
			 101 samples
		HEIGHT GROUP (590, 10000]
			 769 samples

N LABEL GROUP (7, 11]
	WIDTH GROUP (0, 500]
		HEIGHT GROUP (0, 590]
			 1088

In [147]:
batches = []
for batch in batch_sampler:
    batches.append(batch)

In [161]:
index = 10
batches[index]

[1092, 1126, 1132, 1141]

In [162]:
df[df['index'].isin(batches[index])]

Unnamed: 0,image_file,width,height,n_labels,index
1092,data/wikiart/images/0063588.jpeg,495,512,4,1092
1126,data/wikiart/images/0049695.jpeg,305,500,4,1126
1132,data/wikiart/images/0001607.jpeg,300,364,4,1132
1141,data/wikiart/images/0023579.jpeg,306,467,5,1141


In [107]:
len(train_dataset)

63832

In [117]:
12786 * 4

51144

In [119]:
63832 // 4

15958

In [120]:
s = 0
for group in batch_sampler.groups:
    s += len(group)

In [122]:
s

63831