<a href="https://colab.research.google.com/github/kiitaamuuraa/Asobiba/blob/main/MyFirstLXMERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# LXMERT × object Referral

In [1]:
!pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/2c/d8/5144b0712f7f82229a8da5983a8fbb8d30cec5fbd5f8d12ffe1854dcea67/transformers-4.4.1-py3-none-any.whl (2.1MB)
[K     |████████████████████████████████| 2.1MB 5.7MB/s 
[?25hCollecting tokenizers<0.11,>=0.10.1
[?25l  Downloading https://files.pythonhosted.org/packages/71/23/2ddc317b2121117bf34dd00f5b0de194158f2a44ee2bf5e47c7166878a97/tokenizers-0.10.1-cp37-cp37m-manylinux2010_x86_64.whl (3.2MB)
[K     |████████████████████████████████| 3.2MB 36.4MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
[K     |████████████████████████████████| 890kB 51.3MB/s 
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Created wheel for sacremoses: filename=sacremoses-0.0.43-cp37-none-any.whl size=893262 sha256=c39f9

In [3]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, Dataset
from torch.utils.tensorboard import SummaryWriter

from transformers import LxmertTokenizer, LxmertModel

## トークナイザとモデルの定義

In [16]:
# Load the tokenizer and the model from the same pretrained checkpoint.
checkpoint = 'unc-nlp/lxmert-base-uncased'
tokenizer = LxmertTokenizer.from_pretrained(checkpoint)
model = LxmertModel.from_pretrained(checkpoint)

tensor([[[0.2955, 0.3331, 0.6159, 0.5383],
         [0.7964, 0.1018, 0.9522, 0.5114],
         [0.9440, 0.6369, 0.7073, 0.3295]]])

## 入力（ダミー）の定義  
**注意:** 画像に関する情報を定義しないと、モデルはエラーを返す

In [50]:
# Example sentences; each list element is one item of the text batch.
sent = [
    'President Joe Biden sat down with ABC News George Stephanopoulos for a wide-ranging interview Tuesday in which he said his message to migrants was to not come to the border and that New York Gov.',
    'Andrew Cuomo should resign if allegations he committed sexual harassment are confirmed.',
    'And he said it would be "tough" to withdraw all American troops from Afghanistan by May 1, a deadline set out in a deal former President Donald Trump\'s administration made with the Taliban.',
]

# Tokenize the whole batch as PyTorch tensors with padding and truncation enabled.
inputs = tokenizer(sent, return_tensors="pt", padding=True, truncation=True)
print(inputs['input_ids'].shape)

# Take the batch size from the sentence list so the dummy visual inputs
# always line up with the text batch, even if `sent` is edited.
batch_size = len(sent)

# Create dummy image features (standard-normal noise).
# batch x num_images x dim
visual_feats = torch.randn([batch_size, 64, 2048])
inputs['visual_feats'] = visual_feats

# Create dummy coordinates for the image features (uniform in [0, 1)).
# batch x num_images x 4
visual_pos = torch.rand([batch_size, 64, 4])
inputs['visual_pos'] = visual_pos

torch.Size([3, 43])


## モデルに入力して出力を取得

In [117]:
# Forward pass: the tokenizer output, together with the 'visual_feats' and
# 'visual_pos' entries stored in it, is unpacked into keyword arguments.
outputs = model(**inputs)

In [118]:
# Check sizes: print the name and tensor shape of every entry in the output.
for name, value in outputs.items():
    print(name, value.shape)

language_output torch.Size([3, 43, 768])
vision_output torch.Size([3, 64, 768])
pooled_output torch.Size([3, 768])


## ターゲットの画像のImage Feature から16個をランダムサンプル

In [147]:
import random

def get_idx(num_sample=16, all_options=63):
    """Draw distinct random indices and return them as a 1-D index tensor.

    Args:
        num_sample: Number of distinct indices to draw. Values <= 0 yield an
            empty tensor.
        all_options: Largest index that may be drawn (inclusive), i.e. the
            candidate pool is ``0 .. all_options``.

    Returns:
        torch.Tensor: ``num_sample`` unique int64 indices in random order.

    Raises:
        ValueError: If ``num_sample`` is larger than the ``all_options + 1``
            available indices, since that many distinct values cannot exist.
    """
    pool_size = all_options + 1
    if num_sample > pool_size:
        raise ValueError(
            f"cannot draw {num_sample} distinct indices from {pool_size} options"
        )
    # random.sample draws without replacement in one call, so no retry loop
    # or membership check is needed.
    chosen = random.sample(range(pool_size), max(num_sample, 0))
    # Force int64 so that even an empty result is usable as an index tensor.
    return torch.tensor(chosen, dtype=torch.long)

# For every item in the batch, keep a random subset of its image features,
# using an independent draw of indices per item.
sampled_feats = [feats[get_idx()] for feats in visual_feats]

# Stack once, after all items have been collected, into a single batch tensor.
random_samples = torch.stack(sampled_feats, dim=0)

random_samples.shape

torch.Size([3, 16, 2048])

## ランダムにラベルを作成

In [208]:
import numpy as np
# One random target label in the inclusive range 0..15 for each of the 3 batch items.
labels = np.array([random.randint(0, 15) for _ in range(3)])

## 内積を取る

In [170]:
# Linear projection from 768-dim features to the 2048-dim image-feature size.
fc = nn.Linear(in_features=768, out_features=2048)

In [195]:
# Cross-modal feature vector of each batch item.
pooled_output = outputs['pooled_output']
print(pooled_output.shape)

# Project to the size of the raw image features, then insert a singleton
# axis at dim 1 so the result can serve as the left operand of bmm.
cross_modal_features = fc(pooled_output).unsqueeze(dim=1)
print(cross_modal_features.shape)

torch.Size([3, 768])
torch.Size([3, 1, 2048])


In [202]:
# Batched dot product between each projected cross-modal vector (b x 1 x dim)
# and its sampled image features transposed to (b x dim x target).
# Squeeze only the singleton dim 1: a bare squeeze() would also drop the batch
# axis when b == 1 (or the target axis when only one candidate is sampled),
# leaving a tensor the loss cannot interpret as b x target scores.
batch_cos_sim = torch.bmm(cross_modal_features, random_samples.permute(0, 2, 1)).squeeze(1)  # b(3) x target(16)

In [210]:
# Cross-entropy over the candidate scores, with one target index per batch item.
criterion = nn.CrossEntropyLoss()
# CrossEntropyLoss requires int64 class indices; NumPy's default integer type
# can be 32-bit on some platforms, so cast explicitly instead of inheriting it.
criterion(batch_cos_sim, torch.as_tensor(labels, dtype=torch.long))