In [None]:
import os
import torch
import pickle
import argparse
import h5py
from torch.autograd import Variable
from model import NetD, NetG
from PIL import Image, ImageDraw
import torchvision.transforms as transforms
from misc import get_logger, ges_Aonfig
import pandas as pd
# Raise the pandas display limits (rows, columns and column width) to 500 for notebook output.
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.max_colwidth', 500)
import matplotlib.pyplot as plt

In [None]:
from parse_metadata import EcommerceDataParser

In [None]:
def concat_image(images, nrows, ncols, width=128, height=128):
    """Tile PIL images onto one white canvas, filling each row left to right.

    Args:
        images: iterable of PIL images; image ``i`` lands in row ``i // ncols``
            and column ``i % ncols``.
        nrows: number of tile rows on the canvas.
        ncols: number of tile columns on the canvas.
        width: tile width in pixels (default 128, the value that used to be
            hard-coded).
        height: tile height in pixels (default 128).

    Returns:
        A new RGB ``PIL.Image.Image`` of size ``(width * ncols, height * nrows)``.
    """
    background = Image.new('RGB', (width * ncols, height * nrows), (255, 255, 255))

    for i, image in enumerate(images):
        # No ImageDraw handle is created here: nothing is drawn on the tiles, and
        # ImageDraw.Draw(img, "RGB") raises ValueError for tiles that are not RGB.
        if image.size != (width, height):
            # paste() with a 4-tuple box requires the tile to match the box size exactly.
            image = image.resize((width, height))
        left = (i % ncols) * width
        top = (i // ncols) * height
        background.paste(image, (left, top, left + width, top + height))

    return background

In [None]:
# Relative paths of the saved state dicts that are loaded into netD and netG below.
d_checkpoint_path = 'outputs/g_netD__epoch_995.pth'
g_checkpoint_path = 'outputs/netg_checkpoints_0-995/netG__epoch_995.pth'

In [None]:
# Constructor arguments for the generator. NOTE(review): they presumably have to match
# the values used when the checkpoints were trained, or load_state_dict will reject
# the saved tensors - confirm against the training config.
n_z = 100  # forwarded to NetG as n_z; later cells sample noise of shape (1, 100)
n_l = 100  # forwarded to NetG as n_l
n_t = 300  # forwarded to NetG as n_t
n_c = 64   # forwarded to NetG as n_c
netG = NetG(n_z=n_z, n_l=n_l, n_c=n_c, n_t=n_t)
netD = NetD(n_cls=10, n_t=100, n_f=64,docvec_size=300)
# This notebook only runs inference, so put both networks in evaluation mode.
# eval() switches layers such as dropout / batch norm to their inference behaviour
# and changes nothing if the models contain no such layers.
netG.eval()
netD.eval()
# The map_location callback returns each storage unchanged, which keeps the tensors on
# the CPU so that checkpoints saved from a GPU also load on a CPU-only machine.
netD.load_state_dict(torch.load(d_checkpoint_path, map_location=lambda storage, loc: storage))
netG.load_state_dict(torch.load(g_checkpoint_path, map_location=lambda storage, loc: storage))

In [None]:
# Pipeline that turns a tensor into a PIL image; run() applies it to the generator output.
# NOTE(review): ToPILImage presumably expects float tensors in [0, 1] - confirm the
# value range netG produces before trusting the displayed colours.
transform = transforms.Compose([transforms.ToPILImage(),])

In [None]:
# Reset any gradients stored on the parameters of both networks.
# NOTE(review): no backward pass or optimizer step is visible in this notebook, so
# these calls look like leftovers from the training script - confirm they are needed.
netG.zero_grad()
netD.zero_grad()

### train sample

In [None]:
# Take the 'PARSEMETA' entry of whatever ges_Aonfig returns for the YAML file
# (presumably the parsed config - confirm in misc), then display it as the cell output.
config = ges_Aonfig('configs/config-real.yaml')['PARSEMETA']
config

In [None]:
# Override four path entries of the loaded config before the parser is built from it.
# NOTE(review): how each key is consumed is defined in parse_metadata - confirm there.
config['SPM_DIR_PATH'] = 'data/g_spm'
config['SPM_WP_PATH'] = 'data/g_spm/spm.vocab'
config['PARSE_DATA_PATH'] = 'data/datasets/products/g_products.tsv'
config['DOC2VEC_DIR_PATH'] = 'data/g_doc2vec'

In [None]:
# Build the metadata parser from the overridden config; later cells call its
# text2wp, query_doc2vec_topn and text2vec methods.
# NOTE(review): the meaning of use=True is defined in parse_metadata - confirm there.
parser = EcommerceDataParser(config, use=True)

In [None]:
# Read the tab-separated products file without a header row, so columns are addressed
# by integer label. Later cells use column 0 as the image file stem and column 2 as
# the text passed to text2vec.
# This is the same file that config['PARSE_DATA_PATH'] points to.
df = pd.read_csv('./data/datasets/products/g_products.tsv',sep='\t',header=None)
# Print the column dtypes / non-null counts as a quick sanity check of the load.
df.info()

In [None]:
# Example query: pass a product description through parser.text2wp and feed the result
# to parser.query_doc2vec_topn, whose return value is shown as the cell output.
# NOTE(review): presumably text2wp tokenises into word pieces and query_doc2vec_topn
# returns the most similar documents - confirm in parse_metadata.
q = parser.text2wp('yamaha classical nylon string guitars yamaha c40 full size nylon string classical guitar')
parser.query_doc2vec_topn(q)

In [None]:
# Preview columns 1 and 2 for the first five rows.
df[[1,2]].head()

In [None]:
def run():
    """Interactively generate product images from typed-in text.

    Each round prompts for a product name, category, brand and colour, joins
    them as "category brand colour name", embeds the text with
    ``parser.text2vec`` and displays the image ``netG`` produces for it.

    The loop ends when the product name is left empty, or when input is
    interrupted or closed (KeyboardInterrupt / EOFError). Without such an exit
    the cell could only be stopped by interrupting the kernel.

    Relies on the notebook-level ``parser``, ``netG``, ``transform`` and ``n_z``.
    """
    while True:
        try:
            # Prompts ask, in order, for: product name, category name, brand name, colour.
            title = input('1. 상품명을 입력해주세요!:\t') #'solid body schecter electric guitar'
            if not title.strip():
                # An empty product name is the signal to stop.
                break
            category = input('2. 카테고리명을 입력해주세요!:\t') #'electric guitars'
            brand = input('3. 브랜드명을 입력해주세요!:\t') #'yamaha'
            attr = input('4. 색상을 입력해주세요!:\t') #'sea blue'
        except (EOFError, KeyboardInterrupt):
            break
        text = ' '.join([category, brand, attr, title])
        print('\n [title]: ',text)

        vec = parser.text2vec(text)
        caption = Variable(torch.from_numpy(vec.reshape(1,-1)))
        # torch.randn already samples from N(0, 1); the width follows the generator's n_z.
        noise = Variable(torch.randn(1, n_z))
        fake = netG(noise, caption)
        # NOTE(review): the generator output is displayed as-is - confirm its value
        # range matches what the ToPILImage transform expects.
        img = transform(fake[0].data)
        plt.figure()
        plt.imshow(img)
        plt.show()

In [None]:
# Start the interactive prompt loop; this cell blocks while it waits for keyboard input.
run()

### abuse checker

In [None]:
def var_to_numpy(obj, isReal=True):
    """Return a 4-D tensor as a NumPy array with axis 1 moved last and values in [0, 1].

    Args:
        obj: 4-D tensor (presumably batch, channel, height, width - confirm
            with callers); its axes are reordered with ``permute(0, 2, 3, 1)``.
        isReal: when true, values are remapped with ``(x + 1) / 2`` before
            clamping; otherwise the last axis is squeezed, which removes it
            only when its size is 1.

    Returns:
        ``numpy.ndarray`` copy of the data on the CPU, clamped to ``[0, 1]``.
    """
    moved = obj.permute(0, 2, 3, 1)
    prepared = (moved + 1) / 2 if isReal else moved.squeeze(3)
    return prepared.clamp(min=0, max=1).data.cpu().numpy()

In [None]:
# Preprocessing for real product images before they are passed to netD:
# resize to 128x128, then convert the PIL image to a tensor.
trans_img = transforms.Compose([transforms.Resize((128, 128)), #transforms.CenterCrop(image_size),
                                             transforms.ToTensor(),])

In [None]:
# Score ten real product images (rows n .. n+9 of df) with the discriminator,
# pairing each image with the embedding of its own column-2 text.
n = 100
images = []
real_images = []
imgdir = 'data/datasets/products/images'
temp_df = []
for index in range(n,n+10):
    asin = df[0][index] + '.jpg'  # column 0 is used as the image file stem
    title = df[2][index]          # column 2 is the text fed to text2vec
    # Force three channels: grayscale / CMYK / RGBA files would otherwise break the
    # .view(-1,3,128,128) below. The with-block closes the file handle once the
    # pixels have been copied by convert().
    with Image.open(os.path.join(imgdir,asin)) as src:
        real_images.append(src.convert('RGB').resize((128,128)))
    img = Variable(trans_img(real_images[-1]).view(-1,3,128,128))
    vec = parser.text2vec(title)
    caption = Variable(torch.from_numpy(vec.reshape(1,-1)))
    # torch.randn already samples from N(0, 1); the width follows the generator's n_z.
    noise = Variable(torch.randn(1, n_z))
    # NOTE(review): `fake` is not consumed anywhere in this cell; it is kept only so the
    # notebook-level name stays defined - confirm whether it can be dropped.
    fake = netG(noise, caption)

    # netD returns two outputs for the (image, caption) pair; only the first is printed.
    isreal, isclass = netD(img,caption)
    result = isreal.data.numpy()#, isclass.data.numpy()
    print(result)
    temp_df.append(df.loc[[index]])
# Show columns 1 and 2 of the scored rows next to the printed discriminator outputs.
pd.concat(temp_df)[[1,2]]

In [None]:
# Show the ten scored real images side by side in a single row.
concat_image(images=real_images, nrows=1, ncols=10)

In [None]:
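# Time per epoch by GPU setup ("load" is presumably GPU utilisation in % - confirm):
# GTX 1070, 1 GPU: 95s
# DGX, 1 GPU: 0.61 min (36s) at batchsize=128, load=90~; 0.73 min (43s) at batchsize=32, load=85~90
# DGX, 2 GPUs in parallel: 0.56 min (33s) at batchsize=64, load=80~90
# DGX, 3 GPUs in parallel: 0.46 min (27s) at batchsize=96, load=60~70
# DGX, 4 GPUs in parallel: 0.43 min (25s) at batchsize=128, load=40~50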

In [None]:
# plt.rcParams["figure.figsize"] = (20,4)
# x = ['gtx 1070', 
#      'dgx single (batch=128)', 
#      'dgx single (batch=32)', 
#      'dgx multi-2 (batch=64)', 
#      'dgx multi-3 (batch=96)', 
#      'dgx multi-4 (batch=128)']

# energy = [95,  36, 43, 33, 27, 25]

# x_pos = [i for i, _ in enumerate(x)]
# print(x_pos)
# plt.barh(x_pos, energy, color='green')
# plt.ylabel("GPU models",size=15)
# plt.xlabel("Sec per Epoch",size=15)
# plt.title("Performance by gpu model",size=20)

# plt.yticks(x_pos, x)

# plt.show()