In [10]:
from easydict import EasyDict as edict
import torch.nn as nn
import torch
import numpy as np
import os
import pickle
from torchvision import transforms
import nltk
from PIL import Image
from pycocotools.coco import COCO
import torch.utils.data as Data

In [11]:
device = torch.device('cuda: 0' if torch.cuda.is_available() else 'cpu')

In [12]:
args = edict({
    'model_path': './models/',
    'crop_size': 224,
    'vocab_path': './data/vocab.pkl',
    'image_dir': './data/resized2014',
    'caption_path': './data/annotations/captions_train2014.json',
    'log_step': 10,
    'save_step': 1000,
    'embed_size': 256,
    'hidden_size': 512,
    'num_layers': 1,
    'num_epochs': 5,
    'batch_size': 128,
    'num_works': 2,
    'learning_rate': 0.001,
})

In [13]:
transform = transforms.Compose([
    transforms.RandomCrop(args.crop_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406),
                        (0.229, 0.224, 0.225))
])

mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
是来自于imagenet数据集计算出来的，所有预训练的模型的期望输入图像相同的归一化，  
即小批量形状通道的RGB图像（3 x H x W），其中H和W预计将至少224。  
这些图像必须被加载到[ 0, 1 ]的范围内，  
然后使用平均= [ 0.485，0.456，0.406 ]和STD＝[ 0.229，0.224，0.225 ]进行归一化。
这里应该是ResNet-152的预训练模型，所以统一了标准，如果完全是重新训练的，
自己根据数据集重新计算会更好

## Load data

In [14]:
from build_vocab import Vocabulary
# 必须要导入这个，下面pickle.load才不会出错，为何？？？

# Load vocabulary wrapper
with open(args.vocab_path, 'rb') as f:
    vocab = pickle.load(f)

In [15]:
class CocoDataset(Data.Dataset):
    """COCO Custom Dataset compatible with torch.utils.data.DataLoader."""
    def __init__(self, root, json, vocab, transform=None):
        """Set the path for images, captions and vocabulary wrapper.
        
        Args:
            root: image directory.
            json: coco annotation file path.
            vocab: vocabulary wrapper.
            transform: image transformer.
        """
        self.root = root
        self.coco = COCO(json)
        self.ids = list(self.coco.anns.keys())
        self.vocab = vocab
        self.transform = transform

    def __getitem__(self, index):
        """Returns one data pair (image and caption)."""
        coco = self.coco
        vocab = self.vocab
        ann_id = self.ids[index]
        caption = coco.anns[ann_id]['caption']
        img_id = coco.anns[ann_id]['image_id']
        path = coco.loadImgs(img_id)[0]['file_name']

        image = Image.open(os.path.join(self.root, path)).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        # Convert caption (string) to word ids.
        tokens = nltk.tokenize.word_tokenize(str(caption).lower())
        caption = []
        caption.append(vocab('<start>'))
        caption.extend([vocab(token) for token in tokens])
        caption.append(vocab('<end>'))
        target = torch.Tensor(caption)
        return image, target

    def __len__(self):
        return len(self.ids)


def collate_fn(data):
    """Creates mini-batch tensors from the list of tuples (image, caption).
    
    We should build custom collate_fn rather than using default collate_fn, 
    because merging caption (including padding) is not supported in default.

    Args:
        data: list of tuple (image, caption). 
            - image: torch tensor of shape (3, 256, 256).
            - caption: torch tensor of shape (?); variable length.

    Returns:
        images: torch tensor of shape (batch_size, 3, 256, 256).
        targets: torch tensor of shape (batch_size, padded_length).
        lengths: list; valid length for each padded caption.
    """
    # Sort a data list by caption length (descending order).
    data.sort(key=lambda x: len(x[1]), reverse=True)
    images, captions = zip(*data)

    # Merge images (from tuple of 3D tensor to 4D tensor).
    images = torch.stack(images, 0)

    # Merge captions (from tuple of 1D tensor to 2D tensor).
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for i, cap in enumerate(captions):
        end = lengths[i]
        targets[i, :end] = cap[:end]        
    return images, targets, lengths

collate_fn是用于将一个batch的数据stack起来的，但是因为没有能对caption起作用的，所以需要重新定义。

In [16]:
# COCO caption dataset
coco = CocoDataset(
    root=args.image_dir,
    json=args.caption_path,
    vocab=vocab,
    transform=transform
)

# Build data loader
data_loader = Data.DataLoader(
    dataset=coco, 
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.num_works,
    collate_fn=collate_fn
)


loading annotations into memory...
Done (t=0.99s)
creating index...
index created!


In [18]:
for i, (images, captions, lengths) in enumerate(data_loader):
    print(images.shape)
    print(ca.shape)
    print(len(lengths))
    
    break

torch.Size([128, 3, 224, 224])
torch.Size([128, 20])
128


In [27]:
emb = nn.Embedding(len(vocab), 256)

In [36]:
emb(captions).shape

torch.Size([128, 20, 256])

In [26]:
vocab_size

NameError: name 'vocab_size' is not defined

Note: caption的信息长度是按照一组（128条）中最长的哪个caption来确定的，因为经过了排序，所以最长的肯定是第一条，然后其他的补0到相同的长度，length记录了原caption的长度。注意，起始标签1、2也算在内

In [56]:
for n in captions[0]:
    print(vocab.idx2word[int(n.numpy())])

<start>
the
surfer
's
rides
the
crest
of
a
beautiful
breaking
wave
,
but
he
's
doing
down
!
<end>


## Model

In [81]:
import torchvision.models as models

In [205]:
class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        """Load the pretrained ResNet-152 and replace top fc layer."""
        super(EncoderCNN, self).__init__()
        resnet = models.resnet152(pretrained=True)
        modules = list(resnet.children())[:-1]      # delete the last fc layer.
        self.resnet = nn.Sequential(*modules)
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)
        
    def forward(self, images):
        """Extract feature vectors from input images."""
        with torch.no_grad():
            features = self.resnet(images)
        print("encoder features: ", features.shape)
        features = features.reshape(features.size(0), -1)
        print("encoder features reshape: ", features.shape)
        features = self.bn(self.linear(features))
        return features


class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=20):
        """Set the hyper-parameters and build the layers."""
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.max_seg_length = max_seq_length
        
    def forward(self, features, captions, lengths):
        """Decode image feature vectors and generates captions."""
        print('-----------------------------------')
        print("D_caption: ", captions.shape)
        embeddings = self.embed(captions)
        print("D_embeddings: ", embeddings.shape)
        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)
        print("D_embeddings: ", embeddings.shape)
        packed = pack_padded_sequence(embeddings, lengths, batch_first=True)
        print("D_packed: ", packed[0].shape)
        hiddens, _ = self.lstm(packed)
        print("D_hiddens: ", hiddens[0].shape)
        outputs = self.linear(hiddens[0])
        print("D_outputs: ", outputs.shape)
        print('-----------------------------------')
        return outputs
    
    def sample(self, features, states=None):
        """Generate captions for given image features using greedy search."""
        sampled_ids = []
        inputs = features.unsqueeze(1)
        
        for i in range(self.max_seg_length):             # self.max_seg_length == 20
            hiddens, states = self.lstm(inputs, states)  # hiddens: (batch_size, 1, hidden_size)
            outputs = self.linear(hiddens.squeeze(1))    # outputs:  (batch_size, vocab_size)

            _, predicted = outputs.max(1)                # predicted: (batch_size)
            sampled_ids.append(predicted)
            
            inputs = self.embed(predicted)               # inputs: (batch_size, embed_size)
            inputs = inputs.unsqueeze(1)                 # inputs: (batch_size, 1, embed_size)
            
        sampled_ids = torch.stack(sampled_ids, 1)        # sampled_ids: (batch_size, max_seq_length)
        return sampled_ids

In [211]:
rnn = nn.LSTM(10, 20, 2)
input_ = torch.randn(5, 3, 10)

output, (hn, cn) = rnn(input_)

In [206]:
# Build the models
encoder = EncoderCNN(args.embed_size).to(device)
decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab), args.num_layers).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
params = list(decoder.parameters()) + list(encoder.linear.parameters()) + list(encoder.bn.parameters())
optimizer = torch.optim.Adam(params, lr=args.learning_rate)

In [189]:
len(lengths)

128

In [207]:
for i, (images, captions, lengths) in enumerate(data_loader):

    # Set mini-batch dataset
    images = images.to(device)
    captions = captions.to(device)
    targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]

    # Forward, backward and optimize
    features = encoder(images)  # [128, 256]
    print(features.shape)
    outputs = decoder(features, captions, lengths) # [x, 9956]
    loss = criterion(outputs, targets)
    
    break

encoder features:  torch.Size([128, 2048, 1, 1])
encoder features reshape:  torch.Size([128, 2048])
torch.Size([128, 256])
-----------------------------------
D_caption:  torch.Size([128, 25])
D_embeddings:  torch.Size([128, 25, 256])
D_embeddings:  torch.Size([128, 26, 256])
D_packed:  torch.Size([1687, 256])
D_hiddens:  torch.Size([1687, 512])
D_outputs:  torch.Size([1687, 9956])
-----------------------------------


In [9]:
decoder

NameError: name 'decoder' is not defined

In [184]:
decoder.sample(features).shape

torch.Size([128, 20])

In [171]:
cap = captions[0:5]
leng = lengths[0:5]

In [177]:
ran = torch.rand((128, 24, 256))

In [195]:
tar = pack_padded_sequence(ran, lengths, batch_first=True)

In [197]:
tar[0].shape

torch.Size([1669, 256])

In [198]:
sum(lengths)

1669