<a href="https://colab.research.google.com/github/leeds1219/VQA_retriever/blob/main/MAML.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install higher transformers torch torchvision

Collecting higher
  Downloading higher-0.2.1-py3-none-any.whl.metadata (10 kB)
Downloading higher-0.2.1-py3-none-any.whl (27 kB)
Installing collected packages: higher
Successfully installed higher-0.2.1


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.optim import Adam
from transformers import CLIPProcessor, CLIPModel

In [None]:
class CLIPEmbedder:
    """CLIP-based retriever that matches (question, image) queries to captions.

    Three independent copies of the same CLIP checkpoint are loaded:

    * ``self.model``          -- produces image features.
    * ``self.question_model`` -- produces question (text) features.
    * ``self.caption_model``  -- produces caption (text) features; this is the
      model optimised by :meth:`train_maml`.

    ``self.projection_layer`` maps the concatenated question/image features
    back to the CLIP projection size so they can be compared with caption
    features.
    """

    def __init__(self, model_name="openai/clip-vit-base-patch16", freeze_question=True, freeze_image_model=True):
        """Load the CLIP checkpoints and set which towers are trainable.

        Args:
            model_name: Hugging Face checkpoint name passed to ``from_pretrained``.
            freeze_question: If true, disable gradients for the text tower of
                ``self.question_model``.
            freeze_image_model: If true, disable gradients for the vision tower
                of ``self.model``.
        """
        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)

        # Size the fusion layer from the checkpoint instead of hard-coding
        # 1024 -> 512, so checkpoints with a different projection size work too
        # (for the default checkpoint this is still 1024 -> 512).
        embed_dim = self.model.config.projection_dim
        # NOTE(review): this layer is randomly initialised and is not handed to
        # the optimizer in train_maml, so it keeps its initial weights.
        self.projection_layer = nn.Linear(2 * embed_dim, embed_dim)

        if freeze_image_model:
            self.model.vision_model.requires_grad_(False)
            print("Image model is frozen.")

        self.question_model = CLIPModel.from_pretrained(model_name)
        if freeze_question:
            self.question_model.text_model.requires_grad_(False)
            print("Question model is frozen.")

        self.caption_model = CLIPModel.from_pretrained(model_name)
        self.caption_model.text_model.requires_grad_(True)
        print("Caption model is trainable.")

    def get_question_embedding(self, question):
        """Return text features of ``question`` from ``self.question_model``."""
        # truncation=True keeps over-long text within the tokenizer's maximum length.
        question_inputs = self.processor(text=question, return_tensors="pt", padding=True, truncation=True)
        return self.question_model.get_text_features(**question_inputs)

    def get_caption_embedding(self, caption):
        """Return text features of ``caption`` from ``self.caption_model``."""
        caption_inputs = self.processor(text=caption, return_tensors="pt", padding=True, truncation=True)
        return self.caption_model.get_text_features(**caption_inputs)

    def get_image_embedding(self, images):
        """Return image features of ``images`` from ``self.model``."""
        image_inputs = self.processor(images=images, return_tensors="pt", padding=True)
        return self.model.get_image_features(**image_inputs)

    def cosine_similarity(self, embedding1, embedding2):
        """Return the cosine similarity along the last dimension.

        The inputs are L2-normalised with ``F.normalize`` (which clamps the
        norm by a small epsilon), so an all-zero vector yields a similarity of
        0 instead of NaN.
        """
        embedding1_norm = F.normalize(embedding1, dim=-1)
        embedding2_norm = F.normalize(embedding2, dim=-1)
        return (embedding1_norm * embedding2_norm).sum(dim=-1)

    def compute_query_embedding(self, questions, images):
        """Fuse question and image features into a single query embedding.

        The two feature tensors are concatenated along the last dimension and
        passed through ``self.projection_layer``.
        """
        question_embedding = self.get_question_embedding(questions)
        image_embedding = self.get_image_embedding(images)
        combined_embedding = torch.cat((question_embedding, image_embedding), dim=-1)
        return self.projection_layer(combined_embedding)

    def contrastive_loss(self, query_embeddings, caption_embeddings, margin=0.2):
        """Margin-based contrastive loss between queries and captions.

        Row ``i`` of ``query_embeddings`` is treated as the positive match of
        row ``i`` of ``caption_embeddings``; every other caption in the batch
        is a negative for that query.

        Args:
            query_embeddings: 2-D tensor, one query per row.
            caption_embeddings: 2-D tensor, one caption per row.
            margin: Required gap between a positive similarity and each
                negative similarity before the hinge term becomes zero.

        Returns:
            Scalar tensor: (sum of ``1 - positive similarity`` plus sum of
            hinge violations over negative pairs) divided by the number of
            queries.
        """
        query_embeddings = F.normalize(query_embeddings, dim=-1)
        caption_embeddings = F.normalize(caption_embeddings, dim=-1)

        similarities = torch.matmul(query_embeddings, caption_embeddings.T)
        positive_similarities = torch.diag(similarities)

        violations = F.relu(margin + similarities - positive_similarities.unsqueeze(1))
        # The diagonal compares each positive pair with itself and would add a
        # constant `margin` per row; exclude it so only true negatives count.
        is_positive_pair = torch.eye(*similarities.shape, dtype=torch.bool, device=similarities.device)
        negative_loss = violations.masked_fill(is_positive_pair, 0.0).sum()
        positive_loss = torch.sum(1 - positive_similarities)

        batch_size = similarities.size(0)
        return (positive_loss + negative_loss) / batch_size

    def train_maml(self, tasks, adaptation_steps=1, meta_lr=1e-3, inner_lr=1e-2, num_epochs=10):
        """Meta-train ``self.caption_model`` with first-order MAML.

        For every task the caption model is adapted on the support half with
        plain gradient descent, the query-half loss is evaluated at the adapted
        weights, and its gradient (first-order approximation: no
        differentiation through the inner updates) is accumulated as the
        meta-gradient. The weights are restored to their pre-adaptation values
        before the next task, and a single Adam step is taken per epoch on the
        summed meta-gradient.

        Args:
            tasks: Iterable of tasks. Each task is a sequence of
                ``(question, image, caption)`` tuples; the first half is used
                as the support set and the second half as the query set.
            adaptation_steps: Number of inner-loop gradient steps per task.
            meta_lr: Learning rate of the outer (Adam) optimizer.
            inner_lr: Step size of the inner-loop gradient descent.
            num_epochs: Number of passes over ``tasks``.

        Raises:
            ValueError: If a task has fewer than two samples, so that either
                the support or the query set would be empty.
        """
        meta_optimizer = Adam(self.caption_model.parameters(), lr=meta_lr)
        meta_params = [p for p in self.caption_model.parameters() if p.requires_grad]

        for epoch in range(num_epochs):
            meta_optimizer.zero_grad()
            meta_loss = 0.0

            for task in tasks:
                # Split the task in half: front = support set, back = query set.
                split_idx = len(task) // 2
                support_set = task[:split_idx]
                query_set = task[split_idx:]
                if not support_set or not query_set:
                    raise ValueError("each task needs at least 2 samples (one support, one query)")

                support_questions, support_images, support_captions = zip(*support_set)
                query_questions, query_images, query_captions = zip(*query_set)

                self.caption_model.train()
                # Snapshot the meta-parameters so the task-specific adaptation
                # can be undone once this task has been evaluated.
                initial_params = [p.detach().clone() for p in meta_params]

                try:
                    # Inner loop: adapt the caption model on the support set.
                    for step in range(adaptation_steps):
                        print(f"\nAdaptation step {step+1}")

                        query_embeddings = self.compute_query_embedding(support_questions, support_images)
                        caption_embeddings = self.get_caption_embedding(support_captions)
                        loss = self.contrastive_loss(query_embeddings, caption_embeddings)

                        # autograd.grad (instead of backward) leaves .grad
                        # untouched, so the accumulated meta-gradient is not
                        # polluted by inner-loop gradients.
                        grads = torch.autograd.grad(loss, meta_params, allow_unused=True)
                        with torch.no_grad():
                            for param, grad in zip(meta_params, grads):
                                if grad is not None:
                                    param.sub_(grad, alpha=inner_lr)

                    # Outer objective: query-set loss at the adapted weights.
                    query_embeddings = self.compute_query_embedding(query_questions, query_images)
                    caption_embeddings = self.get_caption_embedding(query_captions)
                    task_loss = self.contrastive_loss(query_embeddings, caption_embeddings)
                    task_grads = torch.autograd.grad(task_loss, meta_params, allow_unused=True)
                finally:
                    # Always roll back to the meta-parameters, even on error,
                    # so adaptation never leaks into the next task or epoch.
                    with torch.no_grad():
                        for param, initial in zip(meta_params, initial_params):
                            param.copy_(initial)

                # Accumulate this task's first-order meta-gradient.
                for param, grad in zip(meta_params, task_grads):
                    if grad is None:
                        continue
                    if param.grad is None:
                        param.grad = grad
                    else:
                        param.grad.add_(grad)
                meta_loss += task_loss.item()

            # Meta-optimization step on the gradient summed over all tasks.
            meta_optimizer.step()

            print(f"Epoch {epoch+1}, Meta Loss: {meta_loss}")



In [None]:
from PIL import Image

def load_image(path):
    """Open the image at ``path`` and return a fully loaded RGB copy.

    ``Image.open`` is lazy and keeps the file handle open; opening inside a
    ``with`` block and converting forces the pixel data to be read and lets
    the underlying file be closed immediately.
    """
    with Image.open(path) as image:
        return image.convert("RGB")


# Example batch of tasks. Each task is a list of (question, image, caption)
# tuples; the first half is the support set and the second half the query set.
tasks = [
    # First task
    [
        # Support set: first two samples
        ("What is in the image?", load_image("/content/image1.jpg"), "A cat on a sofa."),
        ("What is in the image?", load_image("/content/image2.jpg"), "A dog in the park."),
        # Query set: remaining two samples
        ("What is in the image?", load_image("/content/image3.jpg"), "A bird flying."),
        ("What is in the image?", load_image("/content/image4.jpg"), "A car on the road.")
    ],
    # Second task
    [
        # Support set: first two samples
        ("Describe the scene.", load_image("/content/image5.jpg"), "A lake at sunset."),
        ("Describe the scene.", load_image("/content/image6.jpg"), "A mountain covered with snow."),
        # Query set: remaining two samples
        ("Describe the scene.", load_image("/content/image7.jpg"), "A city skyline."),
        ("Describe the scene.", load_image("/content/image8.jpg"), "A desert with a camel.")
    ]
]


In [None]:
# Seed torch before building the embedder: its projection layer is randomly
# initialised, so without a fixed seed the losses differ from run to run.
torch.manual_seed(42)

embedder = CLIPEmbedder()
embedder.train_maml(tasks, adaptation_steps=5, meta_lr=1e-3, inner_lr=1e-2)

Image model is frozen.
Question model is frozen.
Caption model is trainable.

Adaptation step 1

Adaptation step 2

Adaptation step 3

Adaptation step 4

Adaptation step 5

Adaptation step 1

Adaptation step 2

Adaptation step 3

Adaptation step 4

Adaptation step 5
Epoch 1, Meta Loss: 1.6934982538223267

Adaptation step 1

Adaptation step 2

Adaptation step 3

Adaptation step 4

Adaptation step 5

Adaptation step 1

Adaptation step 2

Adaptation step 3

Adaptation step 4

Adaptation step 5
Epoch 2, Meta Loss: 1.0984865427017212

Adaptation step 1

Adaptation step 2

Adaptation step 3

Adaptation step 4

Adaptation step 5

Adaptation step 1

Adaptation step 2

Adaptation step 3

Adaptation step 4

Adaptation step 5
Epoch 3, Meta Loss: 1.0045394897460938

Adaptation step 1

Adaptation step 2

Adaptation step 3

Adaptation step 4

Adaptation step 5

Adaptation step 1

Adaptation step 2

Adaptation step 3

Adaptation step 4

Adaptation step 5
Epoch 4, Meta Loss: 0.9734621047973633

Adapta