데이터는 [COCO](https://cocodataset.org/) 공식 사이트에서 다운받을 수 있습니다.

In [None]:
# !pip install pycocotools==2.0.0
# !conda install -c conda-forge pycocotools 

In [32]:
import nltk
import sys, os, pickle 
import numpy as np 

import torch 
import torch.nn as nn 
from torch.utils.data import Dataset, DataLoader 
from pycocotools.coco import COCO 
import torchvision.models as model 
import torchvision.transforms as transform 
from torch.nn.utils.rnn import pack_padded_sequence 
# from settings import * 

import matplotlib.pyplot as plt 

sys.path.append('/workspace')

In [3]:
# Fetch the 'punkt' resource; the cells below tokenize captions with
# nltk.tokenize.word_tokenize.
# NOTE(review): presumably word_tokenize needs this resource — confirm for the installed nltk version.
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [4]:
class Vocab(object):
    """Simple vocabulary wrapper mapping tokens to integer ids and back.

    Ids are handed out sequentially, starting at 0, in the order tokens are
    first added.
    """

    def __init__(self):
        # token -> id
        self.w2i = {}
        # id -> token
        self.i2w = {}
        # Next id to hand out.
        self.index = 0

    def __call__(self, token):
        """Return the id of ``token``, or the id of ``'<unk>'`` if unknown.

        Raises:
            KeyError: if ``token`` is unknown and ``'<unk>'`` has not been
                registered with ``add_token``.
        """
        if token not in self.w2i:
            # Out-of-vocabulary tokens collapse onto the unknown-word marker.
            return self.w2i['<unk>']
        return self.w2i[token]

    def __len__(self):
        """Return the number of distinct tokens registered."""
        return len(self.w2i)

    def add_token(self, token):
        """Register ``token`` under the next free id; known tokens are ignored."""
        if token not in self.w2i:
            self.w2i[token] = self.index
            self.i2w[self.index] = token
            self.index += 1

In [7]:
def build_vocaburary(json, threshold):
    """Build a ``Vocab`` from the captions of a COCO annotation file.

    Every caption is lower-cased and tokenized with
    ``nltk.tokenize.word_tokenize``. Tokens seen at least ``threshold`` times
    are added to the vocabulary after the four special tokens
    ``<pad>``, ``<start>``, ``<end>`` and ``<unk>`` (ids 0-3, in that order).

    Args:
        json: Path to the COCO captions annotation file (passed to ``COCO``).
        threshold: Minimum number of occurrences for a token to be kept.

    Returns:
        The populated ``Vocab`` instance.
    """
    # Use the stdlib Counter directly instead of ``nltk.Counter``, which
    # appears to be an incidental re-export rather than documented nltk API.
    from collections import Counter

    coco = COCO(json)
    counter = Counter()
    annotation_ids = coco.anns.keys()
    total = len(annotation_ids)
    for i, annotation_id in enumerate(annotation_ids):
        caption = str(coco.anns[annotation_id]['caption'])
        counter.update(nltk.tokenize.word_tokenize(caption.lower()))

        # Progress report every 1000 captions.
        if (i + 1) % 1000 == 0:
            print(f'[{i+1}/{total}] Tokenized the captions.')

    vocab = Vocab()
    # Special tokens go first so that '<pad>' gets id 0.
    for special in ('<pad>', '<start>', '<end>', '<unk>'):
        vocab.add_token(special)

    for token, cnt in counter.items():
        if cnt >= threshold:
            vocab.add_token(token)

    return vocab

In [6]:
# COCO split name and the path of the matching "instances" annotation file.
# NOTE(review): annFile is not referenced by any other cell visible here — confirm it is still needed.
dataType = 'val2017'
annFile = f'../annotations/instances_{dataType}.json'

In [26]:
# Build the vocabulary from the train2017 captions, keeping tokens that occur at least 4 times.
# NOTE(review): later cells read a pickled vocabulary from 'data/vocabulary.pkl', but no visible
# cell writes `vocab` to disk — confirm where that file is produced.
vocab = build_vocaburary(json='workspace/annotations/captions_train2017.json', threshold=4)

loading annotations into memory...
Done (t=1.27s)
creating index...
index created!
[1000/591753] Tokenized the captions.
[2000/591753] Tokenized the captions.
[3000/591753] Tokenized the captions.
[4000/591753] Tokenized the captions.
[5000/591753] Tokenized the captions.
[6000/591753] Tokenized the captions.
[7000/591753] Tokenized the captions.
[8000/591753] Tokenized the captions.
[9000/591753] Tokenized the captions.
[10000/591753] Tokenized the captions.
[11000/591753] Tokenized the captions.
[12000/591753] Tokenized the captions.
[13000/591753] Tokenized the captions.
[14000/591753] Tokenized the captions.
[15000/591753] Tokenized the captions.
[16000/591753] Tokenized the captions.
[17000/591753] Tokenized the captions.
[18000/591753] Tokenized the captions.
[19000/591753] Tokenized the captions.
[20000/591753] Tokenized the captions.
[21000/591753] Tokenized the captions.
[22000/591753] Tokenized the captions.
[23000/591753] Tokenized the captions.
[24000/591753] Tokenized the 

In [None]:
def reshape_image(image, shape):
    """Return a copy of ``image`` resized to ``shape`` using Lanczos resampling.

    Args:
        image: Object exposing ``resize(size, resample)``; the callers here
            pass an image opened with ``Image.open``.
        shape: Target size, forwarded unchanged to ``image.resize``.

    NOTE(review): ``Image`` is not imported in the visible import cell;
    presumably ``from PIL import Image`` is required — confirm.
    """
    # Image.ANTIALIAS was an alias of LANCZOS and was removed in Pillow 10;
    # LANCZOS is the same filter and matches load_image in this file.
    return image.resize(shape, Image.LANCZOS)

def reshape_images(image_path, output_path, shape):
    """Resize every file in ``image_path`` and save it under ``output_path``.

    Each output file keeps the name of its source file. Progress is printed
    every 100 images.

    Args:
        image_path: Directory whose entries are all opened as images.
        output_path: Destination directory; created if it does not exist.
        shape: Target size forwarded to ``reshape_image``.

    NOTE(review): ``Image`` is not imported in the visible import cell;
    presumably ``from PIL import Image`` is required — confirm.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_path, exist_ok=True)

    images = os.listdir(image_path)
    num_im = len(images)
    for i, im in enumerate(images):
        # Sources are only read, so open them read-only.
        with open(os.path.join(image_path, im), 'rb') as f:
            with Image.open(f) as image:
                # Capture the format before resizing: an image produced by
                # resize() reports format None.
                source_format = image.format
                resized = reshape_image(image, shape)
                resized.save(os.path.join(output_path, im), source_format)

        if (i + 1) % 100 == 0:
            print(f"[{i+1}/{num_im}] Resized the images and saved into : '{output_path}'")

# Resize every file in the train2017 folder to 256x256 and write the results to
# ./data/resized_images/ (runs immediately when this cell executes).
image_path = './data/train2017/'
output_path = './data/resized_images/'
image_shape = [256, 256]
reshape_images(image_path, output_path, image_shape)


In [30]:
class CustomCocoDataset(Dataset):
    """Dataset yielding one ``(image, caption_ids)`` pair per COCO caption annotation.

    NOTE(review): ``Image`` is not imported in the visible import cell;
    presumably ``from PIL import Image`` is required — confirm.
    """

    def __init__(self, data_path, coco_json_path, vocabulary, transform=None):
        """Index the annotation file.

        Args:
            data_path: Directory containing the image files.
            coco_json_path: Path to the COCO captions annotation file.
            vocabulary: Callable mapping a token string to an integer id.
            transform: Optional callable applied to the loaded RGB image.
        """
        self.root = data_path
        self.coco_data = COCO(coco_json_path)
        # One dataset item per annotation id.
        self.indices = list(self.coco_data.anns.keys())
        self.vocabulary = vocabulary
        self.transform = transform

    def __getitem__(self, idx):
        """Return ``(image, ground_truth)`` for the ``idx``-th annotation.

        ``ground_truth`` is a 1-D tensor holding the ids of
        ``<start>``, the caption tokens, then ``<end>``.
        """
        coco_data = self.coco_data
        vocabulary = self.vocabulary
        annotation = coco_data.anns[self.indices[idx]]
        caption = annotation['caption']
        image_id = annotation['image_id']
        # loadImgs returns a list of image records; take the single match.
        image_path = coco_data.loadImgs(image_id)[0]['file_name']

        image = Image.open(os.path.join(self.root, image_path)).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        # Convert caption to word ids, wrapped in start/end markers.
        word_tokens = nltk.tokenize.word_tokenize(str(caption).lower())
        caption = [vocabulary('<start>')]
        caption.extend(vocabulary(token) for token in word_tokens)
        caption.append(vocabulary('<end>'))
        # torch.Tensor yields a float tensor; collate_function copies the
        # values into a long tensor when it builds the padded batch.
        ground_truth = torch.Tensor(caption)
        return image, ground_truth  # x, y

    def __len__(self):
        """Return the number of caption annotations."""
        return len(self.indices)

In [31]:
def collate_function(data_batch):
    """Merge a list of ``(image, caption)`` samples into one padded batch.

    The DataLoader calls this once per batch. Captions have different
    lengths, so they cannot be stacked directly: they are sorted by length
    (longest first) and right-padded with 0 to the longest caption.

    Args:
        data_batch: List of ``(image, caption)`` pairs where every image
            tensor has the same shape and each caption is a 1-D tensor.

    Returns:
        Tuple ``(imgs, targets, cap_lens)``: the stacked images, a long
        tensor of shape ``(batch, max_len)`` with the padded captions, and
        the list of true caption lengths in descending order.
    """
    # sorted() leaves the caller's list untouched.
    ordered = sorted(data_batch, key=lambda sample: len(sample[1]), reverse=True)
    imgs, caps = zip(*ordered)

    imgs = torch.stack(imgs, 0)

    cap_lens = [len(cap) for cap in caps]
    targets = torch.zeros(len(caps), max(cap_lens)).long()
    for i, cap in enumerate(caps):
        end = cap_lens[i]
        # Fill the first `end` positions; the rest of the row stays 0 (padding).
        targets[i, :end] = cap[:end]
    return imgs, targets, cap_lens

In [None]:
def get_loader(data_path, coco_json_path, vocabulary, transform, batch_size, shuffle, num_workers):
    """Create a DataLoader over ``CustomCocoDataset`` batched with ``collate_function``.

    Args:
        data_path: Image directory handed to the dataset.
        coco_json_path: COCO captions annotation file handed to the dataset.
        vocabulary: Token-to-id callable handed to the dataset.
        transform: Image transform handed to the dataset.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader shuffles the samples.
        num_workers: Number of loader worker processes.

    Returns:
        The configured ``DataLoader``.
    """
    dataset = CustomCocoDataset(data_path, coco_json_path, vocabulary, transform)
    loader_options = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_workers': num_workers,
        'collate_fn': collate_function,
    }
    return DataLoader(dataset, **loader_options)

In [33]:
class CNNModel(nn.Module):
    """Image encoder: a pretrained ResNet-152 backbone followed by a trainable projection."""

    def __init__(self, embedding_size):
        """Build the encoder.

        Args:
            embedding_size: Size of the feature vector produced per image.
        """
        super(CNNModel, self).__init__()
        resnet = model.resnet152(pretrained=True)
        module_list = list(resnet.children())[:-1]  # remove fully-connect layer
        self.resnet_module = nn.Sequential(*module_list)
        self.prediction_layer = nn.Linear(resnet.fc.in_features, embedding_size)
        self.batch_norm = nn.BatchNorm1d(embedding_size, momentum=0.01)

    def forward(self, input_images):
        """Return one ``embedding_size`` feature vector per input image."""
        # Only the backbone runs without gradient tracking. The projection and
        # batch-norm layers must stay outside this context, otherwise they
        # never receive gradients and cannot be trained.
        with torch.no_grad():
            resnet_features = self.resnet_module(input_images)
        # Flatten everything after the batch dimension.
        resnet_features = resnet_features.reshape(resnet_features.size(0), -1)
        final_features = self.batch_norm(self.prediction_layer(resnet_features))
        return final_features


class LSTMModel(nn.Module):
    """Caption decoder: token embedding, LSTM, and a linear layer over the vocabulary."""

    def __init__(self, embedding_size, hidden_size, vocabulary_size, num_layers, max_seq_len=20):
        """Build the decoder.

        Args:
            embedding_size: Size of the token embeddings; image features fed
                to this model must have the same size.
            hidden_size: LSTM hidden state size.
            vocabulary_size: Number of tokens in the vocabulary.
            num_layers: Number of stacked LSTM layers.
            max_seq_len: Number of tokens generated by ``sample``.
        """
        super(LSTMModel, self).__init__()
        self.embedding_layer = nn.Embedding(vocabulary_size, embedding_size)
        self.lstm_layer = nn.LSTM(embedding_size, hidden_size, num_layers, batch_first=True)
        self.prediction_layer = nn.Linear(hidden_size, vocabulary_size)
        self.max_seq_len = max_seq_len

    def forward(self, input_features, caps, lens):
        """Return vocabulary scores for every non-padded step of the batch.

        Args:
            input_features: Image features, one row per sample.
            caps: Padded caption ids, one row per sample.
            lens: Caption lengths, as required by ``pack_padded_sequence``.

        Returns:
            A 2-D tensor with one row of vocabulary scores per packed step.
        """
        embeddings = self.embedding_layer(caps)
        # The image feature is fed as the first step of the sequence.
        embeddings = torch.cat((input_features.unsqueeze(1), embeddings), 1)
        lstm_input = pack_padded_sequence(embeddings, lens, batch_first=True)
        hidden_variables, _ = self.lstm_layer(lstm_input)
        # Element 0 of a PackedSequence is its flattened data tensor.
        model_outputs = self.prediction_layer(hidden_variables[0])
        return model_outputs

    def sample(self, input_features, lstm_states=None):
        """Greedily generate ``max_seq_len`` token ids per sample.

        Args:
            input_features: Image features, one row per sample.
            lstm_states: Optional initial LSTM state.

        Returns:
            A tensor of shape ``(batch, max_seq_len)`` with the chosen ids.
        """
        sampled_indices = []
        lstm_inputs = input_features.unsqueeze(1)

        for _ in range(self.max_seq_len):
            hidden_variables, lstm_states = self.lstm_layer(lstm_inputs, lstm_states)
            model_outputs = self.prediction_layer(hidden_variables.squeeze(1))
            # Greedy choice: highest-scoring token at this step.
            _, predicted_outputs = model_outputs.max(1)
            sampled_indices.append(predicted_outputs)
            # The predicted token becomes the next step's input.
            lstm_inputs = self.embedding_layer(predicted_outputs)
            lstm_inputs = lstm_inputs.unsqueeze(1)
        sampled_indices = torch.stack(sampled_indices, 1)
        return sampled_indices


In [None]:
# Training setup: device, augmentation pipeline, vocabulary, data loader, models, loss and optimizer.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Random 224x224 crops (the resize cell above writes 256x256 images) plus random horizontal flips.
# NOTE(review): the Normalize constants are presumably the ImageNet channel mean/std — confirm.
transforms = transform.Compose([ 
                                transform.RandomCrop(224),
                                transform.RandomHorizontalFlip(), 
                                transform.ToTensor(), 
                                transform.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

# NOTE(review): pickle.load executes arbitrary code from the file — only load a vocabulary you
# created yourself. No visible cell writes this file; confirm where it is produced.
with open('data/vocabulary.pkl', 'rb') as f:
    vocabulary = pickle.load(f)


# NOTE(review): the resize cell reads from './data/train2017/' while the captions here are
# val2017 — presumably the image folder and the caption split should match; confirm.
custom_data_loader = get_loader(
    'data/resized_images', 
    'annotations/captions_val2017.json', 
    vocabulary, 
    transforms,
    128, 
    shuffle=True, 
    num_workers=2
)

# The encoder output size (256) equals the decoder embedding size (first argument) because
# LSTMModel.forward concatenates image features with token embeddings.
encoder_model = CNNModel(256).to(device)
decoder_model = LSTMModel(256, 512, len(vocabulary), 1).to(device)

criterion = nn.CrossEntropyLoss()
# Only the decoder and the encoder's prediction_layer / batch_norm are optimized;
# the encoder's resnet_module parameters are left out of the optimizer.
parameters = list(decoder_model.parameters()) + list(encoder_model.prediction_layer.parameters()) + list(encoder_model.batch_norm.parameters())
optimizer = torch.optim.Adam(parameters, lr=1e-3)

# Number of batches per epoch, used in the progress log.
total_num_steps = len(custom_data_loader)

In [None]:
# Train for a fixed number of epochs, logging every 10 steps and
# checkpointing both models every 1000 steps.
num_epochs = 5
checkpoint_dir = 'models_dir/'
# torch.save does not create missing directories.
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(num_epochs):
    for i, (imgs, caps, lens) in enumerate(custom_data_loader):
        # Batches come from the loader on the CPU; move them to the device
        # the models live on. `lens` stays a plain Python list.
        imgs = imgs.to(device)
        caps = caps.to(device)
        # Pack the targets the same way the decoder packs its inputs so that
        # outputs and targets line up row by row.
        targets = pack_padded_sequence(caps, lens, batch_first=True)[0]

        features = encoder_model(imgs)
        outputs = decoder_model(features, caps, lens)
        loss = criterion(outputs, targets)
        decoder_model.zero_grad()
        encoder_model.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 10 == 0:
            print(f'Epoch [{epoch}/{num_epochs}], step [{i}/{total_num_steps}], Loss: {loss.item():.4f}, Perplexity: {np.exp(loss.item()):5.4f}')

        if (i + 1) % 1000 == 0:
            torch.save(decoder_model.state_dict(), os.path.join(checkpoint_dir, f'decoder-{epoch+1}-{i+1}.ckpt'))
            torch.save(encoder_model.state_dict(), os.path.join(checkpoint_dir, f'encoder-{epoch+1}-{i+1}.ckpt'))

In [None]:
# Image to caption at inference time.
image_file_path = 'sample.jpg'

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def load_image(image_file_path, transform=None):
    """Open an image as RGB, resize it to 224x224 and optionally transform it.

    Args:
        image_file_path: Path of the image file to open.
        transform: Optional callable applied to the resized image.

    Returns:
        The resized image when ``transform`` is None; otherwise the
        transformed result with a new leading dimension added by
        ``unsqueeze(0)``.
    """
    resized = Image.open(image_file_path).convert('RGB').resize([224, 224], Image.LANCZOS)
    if transform is None:
        return resized
    return transform(resized).unsqueeze(0)

# Inference-time preprocessing: tensor conversion and normalization only (no random
# augmentation). This rebinds the `transforms` name used by the training cells.
transforms = transform.Compose([ 
                                transform.ToTensor(), 
                                transform.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

In [None]:
# Load the vocabulary used for training (same path as the training cell).
# NOTE(review): pickle.load executes arbitrary code from the file — only
# load a vocabulary you created yourself.
with open('data/vocabulary.pkl', 'rb') as f:
    vocabulary = pickle.load(f)

# Rebuild the models with the same sizes as in training.
encoder_model = CNNModel(256).eval()
decoder_model = LSTMModel(256, 512, len(vocabulary), 1)

encoder_model = encoder_model.to(device)
decoder_model = decoder_model.to(device)

# Checkpoints are written to 'models_dir/' by the training loop.
# map_location lets weights saved on a GPU load on a CPU-only machine.
encoder_model.load_state_dict(torch.load('models_dir/encoder-2-3000.ckpt', map_location=device))
decoder_model.load_state_dict(torch.load('models_dir/decoder-2-3000.ckpt', map_location=device))

In [None]:
# Preprocess with the `transforms` pipeline. `transform` is the
# torchvision.transforms module itself and is not callable.
img = load_image(image_file_path, transforms)
img_tensor = img.to(device)

# No gradients are needed to generate a caption.
with torch.no_grad():
    features = encoder_model(img_tensor)
    sampled_indices = decoder_model.sample(features)

# Take the single sample in the batch and move it to a NumPy array.
sampled_indices = sampled_indices[0].cpu().numpy()

In [None]:
# Map the sampled ids back to words, stopping after the first '<end>' marker
# (the marker itself is kept in the output).
predicted_caption = []

for token_index in sampled_indices:
    word = vocabulary.i2w[token_index]
    predicted_caption += [word]
    if word != '<end>':
        continue
    break

predicted_sentence = ' '.join(predicted_caption)

In [None]:
%matplotlib inline 

# Print the generated caption and display the source image.
# NOTE(review): `Image` is not imported in the visible import cell; presumably
# `from PIL import Image` is required — confirm.
print(predicted_sentence)
img = Image.open(image_file_path)
plt.imshow(np.asarray(img))