<a href="https://colab.research.google.com/github/leeds1219/VQA_retriever/blob/main/MAML.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from transformers import BertModel, BertTokenizer

# 1. BERT-based retriever model definition
class BertRetriever(nn.Module):
    """Sentence encoder that returns the [CLS] embedding of a pretrained BERT.

    Args:
        pretrained_model: Name or path handed to ``BertModel.from_pretrained``.
    """

    # Candidate backbones noted by the author: BERT, RoBERTa, DistilBERT, CLIP.
    def __init__(self, pretrained_model="bert-base-uncased"):
        super().__init__()
        # Pretrained BERT backbone.
        self.bert = BertModel.from_pretrained(pretrained_model)

    def forward(self, input_ids, attention_mask):
        """Encode the inputs and return the hidden state at sequence position 0.

        Args:
            input_ids: Token ids forwarded to the BERT backbone.
            attention_mask: Attention mask forwarded to the BERT backbone.

        Returns:
            The slice ``last_hidden_state[:, 0, :]`` of the backbone output,
            i.e. the [CLS] token embedding (768-dimensional for the default
            ``bert-base-uncased`` checkpoint).
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = encoded.last_hidden_state
        # Position 0 on the sequence axis holds the [CLS] token.
        return hidden_states[:, 0, :]

In [None]:
# 2. Triplet loss
def triplet_loss(anchor, positive, negative, margin=1.0):
    """Return the batch-mean hinge triplet loss under Euclidean distance.

    Computes ``relu(d(anchor, positive) - d(anchor, negative) + margin)`` per
    row, where ``d`` is ``F.pairwise_distance`` with ``p=2``, and averages it.

    Args:
        anchor: Embeddings of the anchor items.
        positive: Embeddings that should be close to the anchors.
        negative: Embeddings that should be far from the anchors.
        margin: Required gap between the negative and positive distances.

    Returns:
        A scalar tensor holding the mean loss.
    """
    dist_to_pos, dist_to_neg = (
        F.pairwise_distance(anchor, other, p=2) for other in (positive, negative)
    )
    hinge = F.relu(dist_to_pos - dist_to_neg + margin)
    return hinge.mean()

In [None]:
# 3. Meta-learning training step (first-order MAML)
def meta_learning_training(tasks, model, meta_optimizer, inner_lr, outer_lr):
    """Run one meta-update of ``model`` over a batch of tasks.

    For every task, a copy of ``model`` (from ``clone_model``) is adapted on
    the support set with plain SGD. The triplet loss of the adapted copy on
    the query set is then back-propagated, and the resulting gradients are
    summed into the ``.grad`` of the matching parameters of ``model``
    (first-order approximation). Finally ``meta_optimizer`` takes one step.

    Accumulating the gradients onto ``model`` is required because the adapted
    copy owns its own parameters: calling ``backward`` on its loss alone never
    populates ``model``'s gradients, so the meta-optimizer would have nothing
    to apply.

    Args:
        tasks: Sequence of ``(support_set, query_set)`` pairs. Each set is an
            iterable of ``(anchor, positive, negative)`` triples, where every
            element is a mapping unpacked as keyword arguments into the model.
        model: The meta-model to update in place.
        meta_optimizer: Optimizer stepped once per call; assumed to have been
            built over ``model.parameters()``.
        inner_lr: Learning rate of the per-task SGD adaptation.
        outer_lr: Unused. Kept for interface compatibility; the outer step
            size is whatever ``meta_optimizer`` was configured with.

    Returns:
        ``model`` (the same object that was passed in).
    """
    if not tasks:
        # Nothing to learn from; avoid stepping the optimizer on no gradients.
        return model

    meta_params = list(model.parameters())
    meta_optimizer.zero_grad()

    for support_set, query_set in tasks:
        # Inner loop: adapt a private copy of the model on the support set.
        adapted_model = clone_model(model)
        task_optimizer = optim.SGD(adapted_model.parameters(), lr=inner_lr)

        for anchor, positive, negative in support_set:
            loss = triplet_loss(
                adapted_model(**anchor),
                adapted_model(**positive),
                adapted_model(**negative),
            )
            task_optimizer.zero_grad()
            loss.backward()
            task_optimizer.step()

        # Outer loop: gradients of the query loss w.r.t. the adapted weights.
        # Drop the leftovers of the last inner step first.
        task_optimizer.zero_grad()
        for anchor, positive, negative in query_set:
            loss = triplet_loss(
                adapted_model(**anchor),
                adapted_model(**positive),
                adapted_model(**negative),
            )
            # Backward per triple accumulates into .grad and lets each graph
            # be freed immediately instead of holding all of them alive.
            loss.backward()

        # Transfer the query gradients onto the meta-model (summed over tasks).
        for meta_param, adapted_param in zip(meta_params, adapted_model.parameters()):
            grad = adapted_param.grad
            if grad is None:
                # Parameter did not take part in the query loss.
                continue
            grad = grad.detach().to(meta_param.device)
            if meta_param.grad is None:
                meta_param.grad = grad.clone()
            else:
                meta_param.grad.add_(grad)

    meta_optimizer.step()

    return model

In [None]:
def clone_model(model):
    """Return an independent deep copy of ``model``.

    The copy has its own parameters and buffers, so updating it does not touch
    ``model``. Copying the object directly (instead of constructing a fresh
    ``BertRetriever()`` and loading the state dict) avoids reloading the
    pretrained checkpoint on every call and works for a model built from a
    non-default checkpoint.

    Args:
        model: The module to duplicate.

    Returns:
        A new module of the same type with the same weights.
    """
    import copy  # local import: this notebook cell has no import block of its own

    return copy.deepcopy(model)

In [2]:
import random

# 5. Task sampling
def sample_tasks(data, tokenizer, num_tasks=10, samples_per_task=32):
    """Sample meta-learning tasks made of tokenized triplets.

    Each task draws up to ``samples_per_task`` entries from ``data`` without
    replacement, tokenizes their query / positive / negative texts, and splits
    the draw into a support set (first half, rounded up) and a disjoint query
    set (the rest), so the query set is held out from inner-loop adaptation.
    If a draw contains a single example, the query set reuses it rather than
    being empty.

    Args:
        data: Sequence of dicts with the keys ``'query'``,
            ``'positive_document'`` and ``'negative_document'``.
        tokenizer: Callable invoked as
            ``tokenizer(text, return_tensors="pt", padding=True, truncation=True)``.
        num_tasks: Number of tasks to sample.
        samples_per_task: Examples drawn per task; capped at ``len(data)`` so
            small datasets do not make ``random.sample`` fail.

    Returns:
        A list of ``(support_set, query_set)`` tuples; each set is a list of
        ``(anchor, positive, negative)`` tokenizer outputs.

    Raises:
        ValueError: If ``data`` is empty or ``samples_per_task`` is negative.
    """
    if len(data) == 0:
        raise ValueError("data must contain at least one example")

    def encode(text):
        return tokenizer(text, return_tensors="pt", padding=True, truncation=True)

    draw_size = min(samples_per_task, len(data))

    tasks = []
    for _ in range(num_tasks):  # sample num_tasks tasks
        examples = [
            (
                encode(entry['query']),
                encode(entry['positive_document']),
                encode(entry['negative_document']),
            )
            for entry in random.sample(data, draw_size)  # draw examples from the data
        ]

        # Disjoint split: first half adapts the model, second half evaluates it.
        split_at = (len(examples) + 1) // 2
        support_set = examples[:split_at]
        query_set = examples[split_at:] or list(support_set)

        tasks.append((support_set, query_set))
    return tasks

In [None]:
def train_model(model, meta_optimizer, tokenizer, num_epochs, inner_lr, outer_lr, data=None):
    """Meta-train ``model`` for ``num_epochs`` epochs.

    Every epoch samples a fresh batch of tasks with ``sample_tasks`` and runs
    one ``meta_learning_training`` update on them.

    Args:
        model: The model to train.
        meta_optimizer: Optimizer handed to ``meta_learning_training``.
        tokenizer: Tokenizer handed to ``sample_tasks``.
        num_epochs: Number of sample-and-update rounds.
        inner_lr: Inner-loop learning rate, forwarded unchanged.
        outer_lr: Outer-loop learning rate, forwarded unchanged.
        data: Training examples handed to ``sample_tasks``. Required whenever
            ``num_epochs`` is positive; it defaults to ``None`` only so that
            existing positional call sites keep their signature.

    Returns:
        The model returned by the last ``meta_learning_training`` call, or
        ``model`` itself when ``num_epochs`` is zero.

    Raises:
        ValueError: If training is requested without ``data``.
    """
    if num_epochs > 0 and data is None:
        # sample_tasks needs (data, tokenizer); without a dataset there is
        # nothing to sample from, so fail early with an actionable message.
        raise ValueError("train_model requires `data` (the training examples) to sample tasks")

    for epoch in range(num_epochs):
        tasks = sample_tasks(data, tokenizer)  # build this epoch's tasks
        model = meta_learning_training(tasks, model, meta_optimizer, inner_lr, outer_lr)
        print(f"Epoch {epoch + 1}: Training complete")

    return model

In [None]:
# Build the tokenizer, the retriever and its meta-optimizer, then meta-train.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertRetriever()
meta_optimizer = optim.Adam(model.parameters(), lr=0.001)

# NOTE(review): no training dataset is supplied in this cell; confirm how the
# examples are meant to reach the task sampler before running it.
trained_model = train_model(
    model=model,
    meta_optimizer=meta_optimizer,
    tokenizer=tokenizer,
    num_epochs=5,
    inner_lr=0.01,
    outer_lr=0.001,
)


Open question: is the query text or an image?

Open question: what should be retrieved? Text?

Open question: where should it be retrieved from? A Wiki corpus?