In [1]:
## Libraries
import glob
from multiprocessing import cpu_count
import os
import sys

## 3rd party
from gensim.models import Word2Vec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image, ImageFilter
import torch
from torch.utils.data import DataLoader

_path = ".."
if _path not in sys.path:
    sys.path.append(_path)
from lib.dataset import TextArtDataLoader, AlignCollate, ImageBatchSampler
from lib.config import Config
# from lib.preprocess import (pad_image, crop_edges_lr, )

%reload_ext autoreload
%autoreload 2

In [14]:
WORD2VEC_MODEL_FILE = "../models/united_word2vec.model"
BATCH_SIZE = 4
# N_WORKERS = cpu_count() - 1
N_WORKERS = 0
CONFIG = Config()

In [67]:
train_dataset = TextArtDataLoader('united', WORD2VEC_MODEL_FILE, mode='train')
val_dataset = TextArtDataLoader('united', WORD2VEC_MODEL_FILE, mode='val')
test_dataset = TextArtDataLoader('united', WORD2VEC_MODEL_FILE, mode='test')
train_align_collate = AlignCollate('train',
                                   CONFIG.MEAN,
                                   CONFIG.STD,
                                   CONFIG.IMAGE_SIZE_HEIGHT,
                                   CONFIG.IMAGE_SIZE_WIDTH,
                                   horizontal_flipping=CONFIG.HORIZONTAL_FLIPPING,
                                   random_rotation=CONFIG.RANDOM_ROTATION,
                                   color_jittering=CONFIG.COLOR_JITTERING,
                                   random_grayscale=CONFIG.RANDOM_GRAYSCALE,
                                   random_channel_swapping=CONFIG.RANDOM_CHANNEL_SWAPPING,
                                   random_gamma=CONFIG.RANDOM_GAMMA,
                                   random_resolution=CONFIG.RANDOM_RESOLUTION)
val_align_collate = AlignCollate('val',
                                   CONFIG.MEAN,
                                   CONFIG.STD,
                                   CONFIG.IMAGE_SIZE_HEIGHT,
                                   CONFIG.IMAGE_SIZE_WIDTH,
                                   horizontal_flipping=CONFIG.HORIZONTAL_FLIPPING,
                                   random_rotation=CONFIG.RANDOM_ROTATION,
                                   color_jittering=CONFIG.COLOR_JITTERING,
                                   random_grayscale=CONFIG.RANDOM_GRAYSCALE,
                                   random_channel_swapping=CONFIG.RANDOM_CHANNEL_SWAPPING,
                                   random_gamma=CONFIG.RANDOM_GAMMA,
                                   random_resolution=CONFIG.RANDOM_RESOLUTION)

batch_sampler = ImageBatchSampler('united', BATCH_SIZE)

train_loader = DataLoader(train_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=False,
                          num_workers=N_WORKERS,
                          pin_memory=True,
                          collate_fn=train_align_collate,
                          sampler=batch_sampler,
                         )
val_loader = DataLoader(val_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=False,
                          num_workers=N_WORKERS,
                          pin_memory=True,
                          collate_fn=val_align_collate,
                         )
test_loader = DataLoader(test_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=False,
                          num_workers=N_WORKERS,
                          pin_memory=True,
                          collate_fn=None,
                         )

In [70]:
# for image, label, image_file in train_loader:
# for label_sentence, image_file in train_loader:
for image, word_vectors_tensor in train_loader:
#     print(label_sentence, image_file)
    print("IMAGE:", image.shape)
    print("WV:", word_vectors_tensor.shape)

INDEX: 32
LABEL SENTENCE: ['portrait', 'surrealism', 'female', 'portraits']
INDEX: 57
LABEL SENTENCE: ['flor', 'de', 'pascua', 'illustration', 'expressionism']
INDEX: 70
LABEL SENTENCE: ['stamp']
INDEX: 124
LABEL SENTENCE: ['portrait', 'rococo', 'male', 'portraits']
IMAGE: torch.Size([4, 3, 196, 256])
WV: torch.Size([4, 5, 2000])
INDEX: 131
LABEL SENTENCE: ['abstract', 'spatialism', 'monochrome']
INDEX: 186
LABEL SENTENCE: ['icon', 'pixel']
INDEX: 223
LABEL SENTENCE: ['landscape', 'realism', 'autumn', 'forests', 'trees']
INDEX: 246
LABEL SENTENCE: ['abstract', 'minimalism', 'monochrome']
IMAGE: torch.Size([4, 3, 196, 256])
WV: torch.Size([4, 3, 2000])
INDEX: 249
LABEL SENTENCE: ['stamp']
INDEX: 306
LABEL SENTENCE: ['landscape', 'symbolism', 'mountains']
INDEX: 338
LABEL SENTENCE: ['portrait', 'post', 'impressionism', 'female', 'portraits']
INDEX: 345
LABEL SENTENCE: ['portrait', 'impressionism', 'family', 'portraits']
IMAGE: torch.Size([4, 3, 196, 256])
WV: torch.Size([4, 4, 2000])
IND

LABEL SENTENCE: ['horse', 'equine', 'unicorn', 'tarot', 'card']
INDEX: 2748
LABEL SENTENCE: ['portrait', 'impressionism', 'male', 'portraits']
INDEX: 2788
LABEL SENTENCE: ['landscape', 'realism', 'gardens', 'parks']
IMAGE: torch.Size([4, 3, 196, 256])
WV: torch.Size([4, 5, 2000])
INDEX: 2802
LABEL SENTENCE: ['abstract', 'minimalism', 'monochrome']
INDEX: 2806
LABEL SENTENCE: ['persona']
INDEX: 2842
LABEL SENTENCE: ['portrait', 'post', 'impressionism', 'male', 'portraits']
INDEX: 2872
LABEL SENTENCE: ['portrait', 'realism', 'male', 'portraits']
IMAGE: torch.Size([4, 3, 196, 256])
WV: torch.Size([4, 4, 2000])
INDEX: 2926
LABEL SENTENCE: ['severe', 'storm', 'weather']
INDEX: 2977
LABEL SENTENCE: ['f2u', 'ftu', 'divider']
INDEX: 2979
LABEL SENTENCE: ['sketch', 'study', 'realism', 'female', 'portraits']
INDEX: 3003
LABEL SENTENCE: ['sketch', 'study', 'cubism', 'streets', 'squares']
IMAGE: torch.Size([4, 3, 196, 256])
WV: torch.Size([4, 4, 2000])
INDEX: 3036
LABEL SENTENCE: ['flower', 'paint

KeyboardInterrupt: 

In [37]:
t = torch.zeros(1, 2000)

In [40]:
t.shape

torch.Size([1, 2000])

In [38]:
mod = train_dataset.word2vec_model

In [31]:
mod.vector_size

2000

In [None]:
img = np.array(image)[1].transpose(1, 2, 0)

In [None]:
plt.imshow(img)

In [None]:
np.all(img <= 1.0)

In [None]:
np.all(img >= -1.0)

In [None]:
img

In [57]:
batch_sampler = ImageBatchSampler('united', BATCH_SIZE)

In [None]:
df = batch_sampler.df

In [None]:
## Group batches
df_n_labels_grouped = df.groupby(by=pd.cut(df['n_labels'], [0, 5, 7, 11, 1000]))
for key1, item1 in df_n_labels_grouped:
    print("\nN LABEL GROUP", key1)
    df1 = df_n_labels_grouped.get_group(key1)
    df_width_grouped = df1.groupby(by=pd.cut(df1['width'], [0, 500, 700, 1000, 10000]))
    for key2, item2 in df_width_grouped:
        print("\tWIDTH GROUP", key2)
        df2 = df_width_grouped.get_group(key2)
        df_height_grouped = df2.groupby(by=pd.cut(df2['height'], [0, 590, 10000]))
        for key3, item3 in df_height_grouped:
            print("\t\tHEIGHT GROUP", key3)
#             print(df_height_grouped.get_group(key3), "\n\n")
            print('\t\t\t', len(df_height_grouped.get_group(key3).index), 'samples')

In [60]:
batches = []
for batch in batch_sampler:
#     batches.extend(batch)
    batches.append(batch)

In [64]:
batches[-5000:-2000]

[[21981, 22059, 22083, 22097],
 [22141, 22187, 22199, 22207],
 [22209, 22212, 22220, 22239],
 [22245, 22279, 22304, 22335],
 [22346, 22359, 22362, 22459],
 [22481, 22655, 22787, 22803],
 [22816, 22870, 22915, 22917],
 [22949, 22966, 23001, 23017],
 [23078, 23081, 23088, 23185],
 [23252, 23349, 23368, 23408],
 [23428, 23434, 23527, 23554],
 [23590, 23657, 23673, 23756],
 [23764, 23832, 23928, 23932],
 [23952, 24003, 24016, 24043],
 [24082, 24091, 24150, 24154],
 [24168, 24198, 24203, 24247],
 [24276, 24338, 24357, 24372],
 [24419, 24426, 24428, 24447],
 [24452, 24468, 24523, 24540],
 [24580, 24640, 24673, 24691],
 [24739, 24760, 24795, 24826],
 [24930, 24942, 24956, 25005],
 [25017, 25070, 25091, 25136],
 [25251, 25259, 25304, 25335],
 [25361, 25432, 25452, 25465],
 [25519, 25521, 25524, 25552],
 [25603, 25652, 25690, 25724],
 [25727, 25758, 25821, 25837],
 [25855, 25887, 25913, 25945],
 [25955, 25960, 25998, 26003],
 [26079, 26121, 26161, 26167],
 [26168, 26180, 26238, 26247],
 [26250,

In [56]:
batches

[32,
 57,
 70,
 124,
 131,
 186,
 223,
 246,
 249,
 306,
 338,
 345,
 428,
 468,
 499,
 538,
 551,
 560,
 574,
 576,
 584,
 614,
 662,
 684,
 685,
 708,
 772,
 788,
 809,
 823,
 856,
 863,
 897,
 909,
 926,
 954,
 984,
 1034,
 1050,
 1058,
 1092,
 1126,
 1132,
 1141,
 1187,
 1190,
 1196,
 1226,
 1233,
 1256,
 1278,
 1287,
 1299,
 1304,
 1311,
 1341,
 1342,
 1351,
 1382,
 1383,
 1423,
 1427,
 1443,
 1516,
 1522,
 1531,
 1563,
 1585,
 1626,
 1629,
 1665,
 1682,
 1710,
 1732,
 1736,
 1756,
 1796,
 1828,
 1862,
 1896,
 1938,
 1959,
 1977,
 2002,
 2040,
 2058,
 2059,
 2061,
 2074,
 2098,
 2103,
 2147,
 2265,
 2292,
 2304,
 2322,
 2417,
 2469,
 2475,
 2491,
 2669,
 2688,
 2748,
 2788,
 2802,
 2806,
 2842,
 2872,
 2926,
 2977,
 2979,
 3003,
 3036,
 3055,
 3079,
 3089,
 3123,
 3168,
 3169,
 3180,
 3188,
 3224,
 3261,
 3286,
 3300,
 3346,
 3352,
 3353,
 3355,
 3382,
 3383,
 3389,
 3419,
 3427,
 3519,
 3534,
 3567,
 3577,
 3580,
 3604,
 3651,
 3705,
 3720,
 3749,
 3776,
 3818,
 3833,
 3861,
 387

In [47]:
np.array(batches)

array([array([ 32,  57,  70, 124]), array([131, 186, 223, 246]),
       array([249, 306, 338, 345]), ...,
       array([63655, 63666, 63682, 63720]),
       array([63743, 63780, 63791, 63808]), array([63823, 63829])],
      dtype=object)

In [11]:
index = 10
batches[index]

[1092, 1126, 1132, 1141]

In [None]:
df[df['index'].isin(batches[index])]

In [None]:
len(train_dataset)

In [None]:
12786 * 4

In [None]:
63832 // 4

In [None]:
s = 0
for group in batch_sampler.groups:
    s += len(group)

In [None]:
s