In [65]:
# encoding=utf-8
import sys
sys.path.append('/workspace/external-libraries/')

import jieba
import os
import json
import numpy as np
import h5py
from scipy.misc import imread,imresize
from tqdm import tqdm
import torch
from random import seed, choice, sample
from PIL import Image

In [66]:
def bulid_data(data_root,filename):  
    entity = {}
    sents = []
    sents_tokenize = []
    with open(os.path.join(data_root,filename),'r') as f:
        data = json.load(f)
    for data_part in data['response']['annotations']:
        if len(data_part['attributes']) > 5:
            seg_list = jieba.cut(data_part['attributes'].strip().replace(u'。',''),cut_all = False)
            sents_tokenize.append(list(seg_list))
            sents.append(data_part['attributes'])
    img_name = annotation.split('_')[0]
    if sents != []:
        entity['image_name'] = img_name
        entity['sents'] = sents
        entity['sents_token'] = sents_tokenize
        return entity
    else:
        return {}

In [87]:
def bulid_vocab(imgs):
    param = {}
    counts = {}
    # bulid_vocab
    for img in imgs:
        #print(img['sents_token'])
        for sent in img['sents_token']:
            for w in sent:
                counts[w] = counts.get(w,0) + 1
                
    cw = sorted([(count,w) for w,count in counts.items()],reverse=True)               
    print('top words and their counts:')
    print('\n'.join(map(str,cw[:100])))
    total_words = sum(counts.values())           
    print('total words:', total_words)
  
    bad_words = [w for w,n in counts.items() if n <= 1]
    vocab = [w for w,n in counts.items() if n > 1]
    bad_count = sum(counts[w] for w in bad_words)
    print('number of bad words: %d/%d = %.2f%%' % (len(bad_words), len(counts), len(bad_words)*100.0/len(counts)))
    print('number of words in vocab would be %d' % (len(vocab), ))
    print('number of UNKs: %d/%d = %.2f%%' % (bad_count, total_words, bad_count*100.0/total_words))
    
    sent_lengths = {}
    for img in imgs:
        for sent in img['sents_token']:
            nw = len(sent)
            sent_lengths[nw] = sent_lengths.get(nw,0) +1
    max_len = max(sent_lengths.keys())
    print('max length sentence in raw data: ', max_len)
    print('sentence length distribution (count, number of words):')
    sum_len = sum(sent_lengths.values())
    for i in range(max_len+1):
        print('%2d: %10d   %f%%' % (i, sent_lengths.get(i,0), sent_lengths.get(i,0)*100.0/sum_len))
    param['max_length'] = max_len
    
       
    vocab.append(u'<unk>')
    vocab.append(u'<start>')
    vocab.append(u'<end>')
    vocab.insert(0,u'<pad>')
    
    #print(vocab)
    for img in imgs:
        img['final_captions'] = []
        for sent in img['sents_token']:
            caption = [w if counts.get(w,0) > 1 else u'UNK' for w in sent]
            caption.append(u'<end>')
            caption.insert(0,u'<start>')
            img['final_captions'].append(caption)
    return vocab,param

In [83]:
def create_input_files(imgs,split, params, word_map,image_root):
    output_folder = './process_data_3/'
    max_len = param['max_length'] + 2 
    image_name_data = os.listdir(image_root)
    with h5py.File(os.path.join(output_folder,'{}_IMAGE.hdf5'.format(split)),'a') as h:
        h.attrs['captions_per_image']  = 2
        images = h.create_dataset('images',(len(imgs),3,256,256), dtype='uint8')
        print('\nReading images and captions, storing to file...\n"')
        enc_captions = []
        caplens = []

        for i, img in enumerate(tqdm(imgs)):
            if len(img['final_captions']) < 2:
                print(img['image_path'])
                captions = img['final_captions'] + [choice(img['final_captions']) for _ in range(2 - len(img['final_captions']))]
            else:
                captions = sample(img['final_captions'], k=2)

            # Sanity check
            assert len(captions) == 2
            #read image
            for image_name in image_name_data:
                #print(img['image_path'].split('/')[-1], image_name.split('.')[0])
                if img['image_path'].split('/')[-1] == image_name.split('.')[0]:
                    img_path = img['image_path'] +'.'+ image_name.split('.')[-1]
            #print(img_path)
            IMG = np.asarray(Image.open(img_path))
            
            if len(IMG.shape) == 2:
                IMG = IMG[:, :, np.newaxis]
                IMG = np.concatenate([IMG, IMG, IMG], axis=2)
            
            if IMG.shape[-1] == 4:
                #print(IMG.shape)
                IMG = Image.open(img_path).convert("RGB") 
           
            IMG = imresize(IMG, (256, 256))
            IMG = IMG.transpose(2, 0, 1)
            assert IMG.shape == (3, 256, 256)
            assert np.max(IMG) <= 255
            
            images[i] = IMG
            
            #encode caption
            for sent in captions:
                # word --> word_map['word'] = 1, [1,2,3,4,5] + [0,0,0] 
                enc_c = [word_map.get(word, word_map['<unk>']) for word in sent] + [word_map['<pad>']] * (max_len - len(sent))
                # Find caption lengths
                c_len = len(sent)
                
                enc_captions.append(enc_c)
                caplens.append(c_len)
        # Sanity check
        print(images.shape[0])
        print(len(enc_captions), len(caplens))
        assert images.shape[0] * 2 == len(enc_captions) == len(caplens)

        # Save encoded captions and their lengths to JSON files
        with open(os.path.join(output_folder,  '{}_CAPTIONS'.format(split) + '.json'), 'w') as j:
            json.dump(enc_captions, j)

        with open(os.path.join(output_folder,  '{}_CAPLENS'.format(split) + '.json'), 'w') as j:
            json.dump(caplens, j)
            

In [None]:
['img_rgb','img2','img3']
['img_an','img2_an',]

In [71]:
annotation_file_list = ['./raw_data/新闻描述二期第一批数据/','./raw_data/新闻描述二期第二批数据/']
image_root = './raw_data/images_all'

In [77]:
an_all = []
for i in annotation_file_list:
    annotation_file = [os.path.join(i,file) for file in os.listdir(i) if file[-4:] == '.txt']
    #print(annotation_file)
    an_all.extend(annotation_file)
an_all    

['./raw_data/新闻描述二期第一批数据/MBM0W5J4M8XP_cm新闻标注二期—游行.txt',
 './raw_data/新闻描述二期第一批数据/V2X4HARUIC28_cm新闻标注二期—火灾.txt',
 './raw_data/新闻描述二期第一批数据/W60JBBZY0433_cm新闻标注二期—地震.txt',
 './raw_data/新闻描述二期第一批数据/7CBDH771T2XH_cm新闻标注二期—空难.txt',
 './raw_data/新闻描述二期第一批数据/NY5SRX7Y3THH_cm新闻标注二期—暴乱.txt',
 './raw_data/新闻描述二期第二批数据/TV5LJFQ99YDV_cm新闻标注二期—火灾2.txt',
 './raw_data/新闻描述二期第二批数据/PY4KBCL7GIWI_cm新闻标注二期—交通事故2.txt',
 './raw_data/新闻描述二期第二批数据/11XT4VZXXG4I_cm新闻标注二期—洪水2.txt',
 './raw_data/新闻描述二期第二批数据/ENEA042DXMPS_cm新闻标注二期—暴乱2.txt',
 './raw_data/新闻描述二期第二批数据/11LL6SPDHRZP_cm新闻标注二期—坍塌2.txt',
 './raw_data/新闻描述二期第二批数据/FNOTA94NT9TZ_cm新闻标注二期—矿难2.txt',
 './raw_data/新闻描述二期第二批数据/2I4X0LCU0507_cm新闻标注二期—游行2.txt',
 './raw_data/新闻描述二期第二批数据/JVLA5FM8MZMG_cm新闻标注二期—海啸2.txt',
 './raw_data/新闻描述二期第二批数据/RYZNDDUE4RIZ_cm新闻标注二期—泥石流2.txt',
 './raw_data/新闻描述二期第二批数据/PNBZDL3U82TJ_cm新闻标注二期—空难2.txt',
 './raw_data/新闻描述二期第二批数据/G5LXGRCH8TGX_cm新闻标注二期—爆炸2.txt',
 './raw_data/新闻描述二期第二批数据/40YRCTD7WHVD_cm新闻标注二期—山体滑坡2.txt',
 './raw_data/新闻描述二期第二批数据/KFW9XX

In [82]:
#构造数据集
data_annotations = []
for files in an_all:
    for annotation in os.listdir(files):
        #print(os.path.join(files,annotation))
        #print(files)
        entity = bulid_data(files,annotation)
        if entity:
            data_annotations.append(entity)
len(data_annotations)

10201

['我'，‘是’，‘下’，‘x’]

In [84]:
# 构造单词表        
vocab,param = bulid_vocab(data_annotations)
# 单词索引表
itow = {i : w for i,w in enumerate(vocab)} # a 1-indexed vocab translation table
wtoi = {w:i  for i,w in enumerate(vocab)} # inverse table
# 保存单词表
with open('./process_data_3/word_map.json','w') as f:
    json.dump(wtoi,f)

for img in data_annotations:
    img['image_path'] = image_root +'/'+ img['image_name']

train_data = data_annotations[:int(0.8*len(data_annotations))]
val_data = data_annotations[int(0.8*len(data_annotations)):]
#输出网络输入文件：编码caption
create_input_files(train_data,'train',param, wtoi,image_root)
create_input_files(val_data, 'val', param, wtoi,image_root) 
#保存data_annotations
with open('./process_data_3/data_annotations.json','w') as f:
    json.dump(data_annotations,f)

number of vocab is 3769


`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.
  0%|          | 2/8160 [00:00<06:51, 19.83it/s]


Reading images and captions, storing to file...
"


 82%|████████▏ | 6664/8160 [03:45<00:36, 40.85it/s]

./raw_data/images_all/139169


 86%|████████▌ | 7031/8160 [03:55<00:23, 49.03it/s]

./raw_data/images_all/117395


100%|██████████| 8160/8160 [04:23<00:00, 31.01it/s]


8160
16320 16320


  0%|          | 4/2041 [00:00<01:01, 33.31it/s]


Reading images and captions, storing to file...
"


 72%|███████▏  | 1479/2041 [00:38<00:11, 50.79it/s]

./raw_data/images_all/139060


 83%|████████▎ | 1696/2041 [00:42<00:07, 48.10it/s]

./raw_data/images_all/111227


100%|██████████| 2041/2041 [00:50<00:00, 40.07it/s]


2041
4082 4082


In [88]:
vocab,param = bulid_vocab(data_annotations)

top words and their counts:
(28920, '，')
(14971, '发生')
(14747, '某地')
(8760, '现场')
(6183, '在')
(5914, '救援')
(5482, '的')
(3718, '正在')
(3510, '某')
(3088, '火灾')
(3027, '人员')
(2958, '群众')
(2952, '坍塌')
(2342, '地震')
(2317, '被')
(2284, '地区')
(2260, '游行')
(2127, '房屋')
(2023, '上')
(1866, '消防员')
(1862, '汽车')
(1844, '严重')
(1811, '事故')
(1782, '进行')
(1728, '暴乱')
(1623, '滑坡')
(1594, '山体')
(1509, '事故现场')
(1495, '大量')
(1468, '交通事故')
(1467, '着')
(1355, '浓烟')
(1324, '警察')
(1281, '废墟')
(1201, '建筑')
(1194, '多名')
(1193, '散落')
(1174, '一名')
(1117, '一辆')
(1078, '抗议')
(1074, '事件')
(952, '大火')
(915, '了')
(902, '飞机')
(892, '倒塌')
(860, '发生爆炸')
(857, '举行')
(843, '矿难')
(841, '工作')
(833, '围观')
(770, '街头')
(758, '残骸')
(738, '受损')
(701, '一')
(694, '滚滚')
(692, '清理')
(687, '有')
(686, '后')
(665, '车祸')
(645, '和')
(638, '搜救')
(634, '海啸')
(634, '活动')
(632, '工作人员')
(630, '损毁')
(625, '中')
(613, '袭击')
(608, '地面')
(602, '建筑物')
(570, '实施')
(562, '浓烟滚滚')
(562, '展开')
(545, '一片')
(540, '道路')
(537, '洪水')
(528, '一片狼藉')
(524, '查看')
(52