In [1]:
# Name of the dataset split. It is passed to TwitterDataset, which also uses it
# as the key for picking the image transform.
split = 'train'

from collect_twitter_data.data_info import data_info
# Value stored under the 'animal' key of data_info. TwitterDataset iterates over
# it and loads one annotation pickle per element (so presumably a collection of
# account names — confirm against data_info).
use_account = data_info['animal']

In [2]:
from img_transform import *

In [29]:
import pickle

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.utils
from torchvision import models
import torchvision.datasets as datasets

import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

from PIL import Image
import os, glob
import numpy as np

def loadPickle(fileName):
    """Return the object deserialized from the pickle file at ``fileName``.

    The file is opened in binary read mode and is closed again before the
    function returns.
    """
    with open(fileName, "rb") as pickle_file:
        loaded = pickle.load(pickle_file)
    return loaded

class TwitterDataset(Dataset):
    """Image/caption dataset built from per-account tweet annotations.

    Every item pairs an image read from ``<data_dir>/images/`` with the tweet
    text encoded as vocabulary ids, wrapped in ``<start>`` / ``<end>`` markers.

    Args:
        split: Name of the split (e.g. ``'train'``). Used in the log message and
            as the lookup key when ``image_transform`` is subscriptable.
        use_account: Iterable of account names; the annotation file
            ``<data_dir>/annos/<name>.pickle`` is loaded for each of them.
        vocab: Callable mapping a token string to its id.
        image_transform: One of

            * a subscriptable object (e.g. a dict) mapping the split name to a
              callable -- the form originally required,
            * a single callable, applied regardless of the split,
            * ``None``, in which case the RGB PIL image is returned as is.
        text_tokenizer: Object exposing ``tokenize(text, return_str=True)``;
            the returned string is split on whitespace to obtain the tokens.
            Required by ``__getitem__``.
        data_dir: Root directory containing ``images/`` and ``annos/``.
    """

    def __init__(self, split, use_account, vocab, image_transform=None, text_tokenizer=None, data_dir='data'):
        self.split = split
        self.image_transform = image_transform
        self.text_tokenizer = text_tokenizer
        self.vocab = vocab

        self.data_dir = data_dir
        # Paths of all PNG files under <data_dir>/images. Not used by the
        # methods of this class; items are looked up through the annotations.
        self.imgs = glob.glob(os.path.join(self.data_dir, "images/*.png"))

        # Load the annotations only for the Twitter accounts that are in use.
        # NOTE: pickle can execute arbitrary code while loading, so only point
        # this at annotation files from a trusted source.
        self.annos = []
        for user in use_account:
            ann_path = os.path.join(self.data_dir, f"annos/{user}.pickle")
            self.annos.extend(loadPickle(ann_path))

        print(f'Created {self.split} Dataset of Len: {len(self.annos)}')

    def _transform_image(self, image):
        """Apply the configured image transform (see the class docstring)."""
        transform = self.image_transform
        if transform is None:
            return image
        # Subscriptable objects keep the original behaviour: pick by split.
        if hasattr(transform, "__getitem__"):
            return transform[self.split](image)
        return transform(image)

    def __getitem__(self, idx):
        """Return ``(image, caption)`` for the annotation at position ``idx``.

        ``caption`` is a 1-D tensor of token ids:
        ``<start>``, one id per token, ``<end>``.
        """
        ann = self.annos[idx]
        image_file = os.path.join(self.data_dir, f'images/{ann["filename"]}')
        orig_text = ann['text']

        # convert() yields an independent image, so the underlying file can be
        # closed right away instead of waiting for garbage collection.
        with Image.open(image_file) as raw_img:
            orig_img = raw_img.convert("RGB")

        img = self._transform_image(orig_img)

        tokens = self.text_tokenizer.tokenize(orig_text, return_str=True).split()
        caption = [self.vocab('<start>')]
        caption.extend(self.vocab(token) for token in tokens)
        caption.append(self.vocab('<end>'))
        # torch.Tensor(...) builds a tensor of the default floating dtype; this
        # is kept unchanged so existing consumers see the same dtype.
        target = torch.Tensor(caption)

        return img, target

    def __len__(self):
        """Number of loaded annotations."""
        return len(self.annos)

In [30]:
from torchvision import transforms

# Image pipeline: random 224x224 crop, random horizontal flip, conversion to a
# tensor, then per-channel normalization with the mean/std given below.
# NOTE(review): the dataset cell passes `image_transform`, not this `transform`
# object — confirm whether this cell is still needed.
transform = transforms.Compose(
    [
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)

In [37]:
from preprocess.japanese_tokenizer import JapaneseTokenizer
from build_vocab import Vocabulary

In [34]:
# Path handed to JapaneseTokenizer as `model` (a MeCab dictionary directory,
# judging by the variable name — confirm).
# NOTE(review): absolute, user-specific path; it has to be adjusted before the
# notebook can run on another machine.
mecab_dict_path = "/home/smg/nishikawa/src/lib/mecab/dic/ipadic"
# Tokenizer instance that is later passed to TwitterDataset as `text_tokenizer`.
text_tokenizer = JapaneseTokenizer(splitter="MeCab", model=mecab_dict_path)

In [38]:
# Earlier variant that read the vocabulary from a repository-relative path:
# with open('data/vocab_ja.pkl', 'rb') as f:
#     vocab = pickle.load(f)

# Load the pickled vocabulary object; TwitterDataset calls it to turn tokens
# into ids.
# NOTE(review): absolute, user-specific path — adjust when running elsewhere.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("/home/smg/nishikawa/pytorch-tutorial/tutorials/03-advanced/image_captioning/data/vocab_ja.pkl", 'rb') as f:
    vocab = pickle.load(f)

In [39]:
# Build the dataset for the configured split and accounts.
# NOTE(review): `image_transform` is not defined in any visible cell; presumably
# it is provided by `from img_transform import *` — confirm.
dataset = TwitterDataset(
    split=split,
    use_account=use_account,
    vocab=vocab,
    image_transform=image_transform,
    text_tokenizer=text_tokenizer,
    data_dir='data',
)

Created train Dataset of Len: 343


In [40]:
# Inspect the first loaded annotation record.
dataset.annos[0]

{'screen_name': 'mofumofu_cn',
 'text': 'えへへ〜どうだ〜w ',
 'media_url': 'http://pbs.twimg.com/media/B2jQH86CIAAfwJ-.jpg',
 'media_id': 533905390870601728,
 'filename': 'mofumofu_cn_533905390870601728.png'}

In [47]:
# Batches of 16 shuffled samples, assembled by `collate_fn`.
# NOTE(review): `collate_fn` is defined in the cell that follows this one, so on
# a fresh kernel that cell has to be executed before this one.
data_loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=1,
    collate_fn=collate_fn,
)

In [48]:
def collate_fn(data):
    """Merge a sequence of ``(image, caption)`` samples into one padded batch.

    The default collate function cannot batch captions of different lengths,
    so the samples are ordered by caption length (longest first) and every
    caption is zero-padded on the right to the longest length in the batch.
    The input sequence itself is left untouched.

    Args:
        data: non-empty sequence of ``(image, caption)`` tuples.
            - image: torch tensor; all images of a batch must share one shape
              (e.g. ``(3, 224, 224)`` after a 224x224 crop).
            - caption: 1-D torch tensor of token ids; variable length.

    Returns:
        images: tensor of shape ``(batch_size, *image_shape)``.
        targets: long tensor of shape ``(batch_size, max_caption_length)``.
        lengths: list with the unpadded length of each caption, in the same
            (descending) order as the rows of ``targets``.

    Raises:
        ValueError: if ``data`` is empty.
    """
    if len(data) == 0:
        raise ValueError("collate_fn received an empty batch")

    # sorted() is stable like list.sort() but does not reorder the caller's
    # list and also accepts tuples or other sequences.
    ordered = sorted(data, key=lambda sample: len(sample[1]), reverse=True)
    images, captions = zip(*ordered)

    # Tuple of N-D tensors -> one (N+1)-D tensor.
    images = torch.stack(images, 0)

    # Tuple of 1-D tensors -> one zero-padded 2-D long tensor.
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for row, (cap, length) in enumerate(zip(captions, lengths)):
        targets[row, :length] = cap[:length]
    return images, targets, lengths


In [49]:
# NOTE(review): `batch` is first assigned by the `for batch in data_loader` cell
# further down, so this cell only works after that one has been run. The
# recorded output (2) does not match the 3-tuple returned by the collate_fn
# defined in this notebook and is probably stale.
len(batch)

2

In [50]:
# Shape of the first batch element (the stacked images when collate_fn is used).
# NOTE(review): depends on `batch`, which is assigned in a later cell.
batch[0].shape

torch.Size([16, 3, 224, 224])

In [51]:
# Fetch a single batch: the loop stops right after the first iteration, leaving
# that first batch bound to `batch`.
for batch in data_loader:
    break

In [56]:
# Third batch element: the list of caption lengths returned by collate_fn.
batch[2]

[34, 19, 16, 11, 10, 9, 9, 9, 9, 9, 8, 7, 6, 6, 5, 4]

In [58]:
import heapq                        # import the heapq module
hq = []                             # empty list that serves as the heap


In [59]:
heapq.heappush(hq, (1, 2))     # push an element onto the heap

In [60]:
# Remove the smallest item from the heap and keep it in `y`.
y = heapq.heappop(hq)  

In [84]:
# Push a tuple whose second field is a list.
# NOTE(review): heap entries are ordered by tuple comparison, so if the first
# fields of two entries tie, comparing a list second field with an int one
# (as pushed in other cells) raises TypeError.
heapq.heappush(hq, (1, [2,3]))

In [85]:
# Display the current heap contents.
hq

[(1, [2, 3]), (9, 2)]

In [77]:
# Remove the smallest item from the heap and keep it in `y`.
y = heapq.heappop(hq)  

In [78]:
# Display the item that was popped into `y`.
y

(1, 2)

In [83]:
# Pop the smallest remaining item and display it as the cell output.
heapq.heappop(hq)  

(5, 2)