<a href="https://colab.research.google.com/github/leeds1219/VQA_retriever/blob/main/MAML.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
!pip install higher transformers torch torchvision

Collecting higher
  Downloading higher-0.2.1-py3-none-any.whl.metadata (10 kB)
Downloading higher-0.2.1-py3-none-any.whl (27 kB)
Installing collected packages: higher
Successfully installed higher-0.2.1


In [72]:
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel
from torch.optim import Adam

In [None]:
image = Image.open("/content/sample.jpg")

In [None]:
inputs = processor(images=image, return_tensors="pt", padding=True)
image_features = model.get_image_features(**inputs)

# print(image_features)

In [30]:
question = "Why do dogs wag their tails?"

In [31]:
text_inputs = processor(text=[question], return_tensors="pt", padding=True)
text_features = model.get_text_features(**text_inputs)

# print(text_features)

In [None]:
captions = [
    "Dogs wag their tails to show happiness, excitement, or to get attention.",
    "The weather today is absolutely beautiful.",
    "I love ending the day with a good book.",
    "I long for peaceful moments in nature."
]

In [33]:
text_inputs = caption_processor(text=captions, return_tensors="pt", padding=True)
text_features = caption_model.get_text_features(**text_inputs)

# print(text_features)

In [71]:
class CLIPEmbedder:
    def __init__(self, model_name="openai/clip-vit-base-patch16", freeze_question=True, freeze_image_model=True):
        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.projection_layer = nn.Linear(1024, 512)
        if freeze_image_model:
            for param in self.model.vision_model.parameters():
                param.requires_grad = False
            print("Image model is frozen.")

        self.question_model = CLIPModel.from_pretrained(model_name)
        if freeze_question:
            for param in self.question_model.text_model.parameters():
                param.requires_grad = False
            print("Question model is frozen.")


        self.caption_model = CLIPModel.from_pretrained(model_name)
        for param in self.caption_model.text_model.parameters():
            param.requires_grad = True
        print("Caption model is trainable.")

    def get_question_embedding(self, question):

        question_inputs = self.processor(text=question, return_tensors="pt", padding=True)
        question_features = self.question_model.get_text_features(**question_inputs)
        return question_features

    def get_caption_embedding(self, caption):

        caption_inputs = self.processor(text=caption, return_tensors="pt", padding=True)
        caption_features = self.caption_model.get_text_features(**caption_inputs)
        return caption_features

    def get_image_embedding(self, images):

        image_inputs = self.processor(images=images, return_tensors="pt", padding=True)
        image_features = self.model.get_image_features(**image_inputs)
        return image_features

    def cosine_similarity(self, embedding1, embedding2):

        embedding1_norm = embedding1 / embedding1.norm(dim=-1, keepdim=True)
        embedding2_norm = embedding2 / embedding2.norm(dim=-1, keepdim=True)

        similarity = (embedding1_norm * embedding2_norm).sum(dim=-1)
        return similarity

    def compute_query_embedding(self, questions, images):

        question_embedding = self.get_question_embedding(questions)
        image_embedding = self.get_image_embedding(images)

        combined_embedding = torch.cat((question_embedding, image_embedding), dim=-1)

        query_embedding = self.projection_layer(combined_embedding)
        return query_embedding

    def contrastive_loss(self, query_embeddings, caption_embeddings, margin=0.2):

        query_embeddings = query_embeddings / query_embeddings.norm(dim=-1, keepdim=True)
        caption_embeddings = caption_embeddings / caption_embeddings.norm(dim=-1, keepdim=True)

        similarities = torch.matmul(query_embeddings, caption_embeddings.T)

        positive_similarities = torch.diag(similarities)

        batch_size = similarities.size(0)
        negative_loss = torch.sum(F.relu(margin + similarities - positive_similarities.unsqueeze(1)))
        positive_loss = torch.sum(1 - positive_similarities)
        loss = (positive_loss + negative_loss) / batch_size

        return loss

    def train_step(self, questions, images, captions):

        query_embeddings = self.compute_query_embedding(questions, images)

        caption_embeddings = self.get_caption_embedding(captions)

        loss = self.contrastive_loss(query_embeddings, caption_embeddings)
        return loss

    def train_maml(self, tasks, adaptation_steps=1, meta_lr=1e-3, inner_lr=1e-2):

        meta_optimizer = Adam(self.parameters(), lr=meta_lr)

        for epoch in range(10):
            meta_loss = 0.0
            for task in tasks:

                questions, images, captions = task

                adapted_params = {name: param.clone() for name, param in self.named_parameters()}

                for _ in range(adaptation_steps):
                    query_embeddings = self.compute_query_embedding(questions, images)
                    caption_embeddings = self.get_caption_embedding(captions)
                    loss = self.contrastive_loss(query_embeddings, caption_embeddings)

                    grads = torch.autograd.grad(loss, adapted_params.values(), create_graph=True)
                    adapted_params = {name: param - inner_lr * grad for (name, param), grad in zip(adapted_params.items(), grads)}

                query_embeddings = self.compute_query_embedding(questions, images)
                caption_embeddings = self.get_caption_embedding(captions)
                task_loss = self.contrastive_loss(query_embeddings, caption_embeddings)
                meta_loss += task_loss

            meta_optimizer.zero_grad()
            meta_loss.backward()
            meta_optimizer.step()

            print(f"Epoch {epoch+1}, Meta Loss: {meta_loss.item()}")

In [69]:
clip_embedder = CLIPEmbedder(freeze_question=True, freeze_image_model=True)

Image model is frozen.
Question model is frozen.
Caption model is trainable.


In [36]:
question_embedding = clip_embedder.get_question_embedding(question)

In [38]:
caption = captions[0]

In [41]:
caption_embedding = clip_embedder.get_caption_embedding(caption)

In [40]:
image_embedding = clip_embedder.get_image_embedding(image)

In [43]:
question_embedding.shape

torch.Size([1, 512])

In [42]:
caption_embedding.shape

torch.Size([1, 512])

In [44]:
image_embedding.shape

torch.Size([1, 512])

In [45]:
query_embedding = torch.cat((question_embedding, image_embedding), dim=1)

In [47]:
query_embedding.shape

torch.Size([1, 1024])

In [48]:
projection_layer = nn.Linear(1024, 512)
query_embedding_projected = projection_layer(query_embedding)

In [49]:
query_embedding_projected.shape

torch.Size([1, 512])

In [55]:
import torch
import torch.nn.functional as F

query_embedding_projected = torch.rand(1, 512)  # (1, 512)
caption_embedding = torch.rand(1, 512)  # (1, 512)

cosine_similarity = clip_embedder.cosine_similarity(query_embedding_projected, caption_embedding)

print(f"Cosine Similarity: {cosine_similarity.item()}")

Cosine Similarity: 0.7582837343215942


In [56]:
negative_caption_embedding = clip_embedder.get_caption_embedding(captions[-1])

In [57]:
negative_cosine_similarity = clip_embedder.cosine_similarity(query_embedding_projected, negative_caption_embedding)

In [59]:
print(f"Negative Cosine Similarity: {negative_cosine_similarity.item()}")

Negative Cosine Similarity: -0.002463156823068857


In [64]:
from PIL import Image

In [70]:
# Initialize the model
embedder = CLIPEmbedder()

# Example data
questions = ["What is in the image?", "Describe the scene."]
images = [Image.open("/content/sample_dog.jpg"), Image.open("/content/sample_cat.jpeg")]
captions = ["A cat sitting on a sofa.", "A dog running in a park."]

# Perform a training step
loss = embedder.train_step(questions, images, captions)
print(f"Loss: {loss.item()}")

Image model is frozen.
Question model is frozen.
Caption model is trainable.
Loss: 1.3966864347457886
