In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import vgg16
from torchvision import transforms
from PIL import Image

In [2]:
# Vocabulary: token id -> word. Ids are simply the positions in the word list.
vocab = dict(enumerate([
    "<pad>", "<start>", "<end>",
    "a", "dog", "is", "sitting", "on", "grass",
]))
# Reverse lookup: word -> token id.
inv_vocab = {word: token_id for token_id, word in vocab.items()}

In [11]:
# Image preprocessing
# The pretrained VGG16 weights were fit on ImageNet images that were scaled to
# [0, 1] and then standardized per channel, so the same mean/std normalization
# is applied here; feeding raw [0, 1] pixels gives the network inputs from a
# different distribution than the one it was trained on.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Open through a context manager so the file handle is released right away;
# convert() returns a separate, fully loaded image that outlives the `with`.
with Image.open('shin.jpg') as image_file:
  image = image_file.convert('RGB')
# Add a leading batch axis: (3, 224, 224) -> (1, 3, 224, 224).
# NOTE: the name `image_tansor` (sic) is kept because the VGG cell reads it.
image_tansor = transform(image).unsqueeze(0)

In [13]:
# VGG16 convolutional stack, used as a frozen image feature extractor.
vgg = vgg16(pretrained=True).features
vgg.requires_grad_(False)  # switch off gradients for every parameter at once

with torch.no_grad():
  # Flatten everything after the batch axis into one vector per image, then
  # insert a length-1 axis at dim 1 so the result is (batch, 1, flat_features).
  features = vgg(image_tansor).flatten(start_dim=1).unsqueeze(1)



In [14]:
# Token ids of the reference caption, from the <start> id (1) to the <end> id (2).
caption = [1, 3, 4, 5, 6, 7, 8, 2]
# Input and target are the same sequence shifted by one position: the input
# drops the final token and the target drops the first. Each gets a batch axis.
input_seq = torch.tensor(caption[:-1]).unsqueeze(0)
target_seq = torch.tensor(caption[1:]).unsqueeze(0)

In [15]:
class captionGenerator(nn.Module):
  """LSTM caption decoder whose input sequence starts with a projected image feature.

  Args:
    feature_dim: size of the last dimension of the image features passed to forward().
    embed_dim: size of the word embeddings and of the projected image feature.
    hidden_dim: hidden size of the LSTM.
    vocab_size: number of tokens (rows of the embedding table, width of the logits).
  """

  def __init__(self, feature_dim, embed_dim, hidden_dim, vocab_size):
    super().__init__()

    # Token ids -> embedding vectors.
    self.embed = nn.Embedding(vocab_size, embed_dim)
    # Sequence model over [image step, word steps]; tensors are batch-first.
    self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
    # Hidden state -> vocabulary logits.
    self.decoder = nn.Linear(hidden_dim, vocab_size)
    # Image feature -> embedding space, so it can be concatenated with word embeddings.
    self.init_linear = nn.Linear(feature_dim, embed_dim)

  def forward(self, features, captions):
    """Return vocabulary logits for every position of `captions`.

    The projected image feature is concatenated in front of the embedded
    caption tokens along dim 1 and the whole sequence is run through the LSTM.
    The output of the leading image step is dropped, so the returned tensor has
    as many steps along dim 1 as `captions` has tokens.
    """
    image_step = self.init_linear(features)
    word_steps = self.embed(captions)
    sequence = torch.cat([image_step, word_steps], dim=1)
    hidden_states = self.lstm(sequence)[0]
    logits = self.decoder(hidden_states)
    # Skip position 0, which corresponds to the image feature step.
    return logits[:, 1:]


In [17]:
# Training
# Seed the RNG so weight initialization, and with it the loss curve, is
# repeatable after a kernel restart.
torch.manual_seed(0)

# Take the feature size from the extracted tensor instead of hard-coding 25088,
# so the model stays in sync with whatever the feature-extraction cell produced.
model = captionGenerator(feature_dim=features.shape[-1], embed_dim=256, hidden_dim=512, vocab_size=len(vocab))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Nothing inside the loop changes the mode, so training mode is set once.
model.train()
for epoch in range(20):
  optimizer.zero_grad()

  outputs = model(features, input_seq)  # logits shaped (batch, steps, vocab)
  # Flatten batch and step axes so CrossEntropyLoss receives (N, vocab) logits
  # and (N,) targets. squeeze(0) only produced that layout for batch size 1;
  # for larger batches it is a no-op and the class axis ends up in the wrong place.
  loss = criterion(outputs.reshape(-1, outputs.size(-1)), target_seq.reshape(-1))
  loss.backward()
  optimizer.step()

  print(f"Epoch: {epoch}, Loss: {loss.item()}")

Epoch: 0, Loss: 2.198516607284546
Epoch: 1, Loss: 1.7331022024154663
Epoch: 2, Loss: 1.3086729049682617
Epoch: 3, Loss: 0.9348776936531067
Epoch: 4, Loss: 0.6206761598587036
Epoch: 5, Loss: 0.3842798173427582
Epoch: 6, Loss: 0.22651979327201843
Epoch: 7, Loss: 0.130711168050766
Epoch: 8, Loss: 0.07572486251592636
Epoch: 9, Loss: 0.045013971626758575
Epoch: 10, Loss: 0.027807502076029778
Epoch: 11, Loss: 0.017946090549230576
Epoch: 12, Loss: 0.012095452286303043
Epoch: 13, Loss: 0.008492927066981792
Epoch: 14, Loss: 0.006190059240907431
Epoch: 15, Loss: 0.004663849715143442
Epoch: 16, Loss: 0.003618306014686823
Epoch: 17, Loss: 0.002880186541005969
Epoch: 18, Loss: 0.0023449421860277653
Epoch: 19, Loss: 0.0019473283318802714


In [18]:
# Greedy decoding: feed the image feature plus <start>, then repeatedly feed
# back the most likely token until <end> is predicted or 10 steps have run.
model.eval()
with torch.no_grad():
  generated = []
  hidden = None

  # First LSTM input: projected image feature followed by the <start> token (id 1).
  start_embed = model.embed(torch.tensor([[1]]))
  lstm_input = torch.cat((model.init_linear(features), start_embed), dim=1)

  for _ in range(10):
    out, hidden = model.lstm(lstm_input, hidden)
    # Only the last time step decides the next token.
    pred_id = model.decoder(out[:, -1, :]).argmax(dim=-1).item()

    # Stop at the <end> token (id 2) without adding it to the caption.
    if pred_id == 2:
      break

    generated.append(pred_id)
    # Later steps feed only the token just predicted; the LSTM state carries the rest.
    lstm_input = model.embed(torch.tensor([[pred_id]]))

  sentence = ' '.join(vocab[idx] for idx in generated)
  print('생성된 켑션: ', sentence)

생성된 켑션:  a dog is sitting on grass
