
### 🧠 **Learning Joint Image–Text Embeddings with Contrastive Loss (Flickr8k + GloVe)**

#### **Goal**

We want to train a model that learns a **shared embedding space** where:

* The embedding of an **image** and the embedding of its **caption** are **close**.
* Embeddings of **non-matching** image–caption pairs are **far apart**.

This enables tasks like:

* Retrieving images given text (and vice versa)
* Cross-modal understanding
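
To make the retrieval use case concrete, here is a minimal sketch of text-to-image lookup in a shared space. The function name `retrieve_top_k` and the toy 2-D vectors are assumptions for illustration; in practice the embeddings would come from the trained encoders.

```python
import torch
import torch.nn.functional as F

def retrieve_top_k(query_embed, gallery_embeds, k=3):
    """Return indices of the k gallery items most similar to the query."""
    query = F.normalize(query_embed, dim=-1)
    gallery = F.normalize(gallery_embeds, dim=-1)
    scores = gallery @ query              # cosine similarity, shape (num_items,)
    return torch.topk(scores, k=k).indices

# Toy gallery of four 2-D "image" embeddings (made-up values)
gallery = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7], [-1.0, 0.0]])
query = torch.tensor([0.9, 0.1])          # a "caption" embedding pointing mostly along x

top = retrieve_top_k(query, gallery, k=2)
print(top.tolist())
```

Image-to-text retrieval works the same way with the roles of query and gallery swapped.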

---

### 🧮 **Mathematical Objective**

Let:

* $x_i \in \mathbb{R}^d$: embedding of the **$i$th caption**
* $y_i \in \mathbb{R}^d$: embedding of the **$i$th image**

We want to **maximize similarity** between $x_i$ and $y_i$
and **minimize similarity** between $x_i$ and all $y_j$, $j \ne i$.

We define a **similarity matrix** $S \in \mathbb{R}^{N \times N}$ for a batch of size $N$:

$$
S_{ij} = \langle x_i, y_j \rangle
$$

This is the dot product between caption $i$ and image $j$. Because both embeddings are L2-normalized (see Model Components), it equals their cosine similarity and lies in $[-1, 1]$.
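
As a quick sanity check of this claim, the sketch below builds $S$ from random stand-in features (not real Flickr8k data) and compares it with PyTorch's pairwise cosine similarity. The tensor sizes are arbitrary assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
captions = torch.randn(4, 8)   # stand-in caption features, N=4, d=8
images = torch.randn(4, 8)     # stand-in image features

x = F.normalize(captions, dim=-1)
y = F.normalize(images, dim=-1)
S = x @ y.t()                  # S[i, j] = <x_i, y_j>

# Pairwise cosine similarity of the unnormalized features, via broadcasting
cos = F.cosine_similarity(captions.unsqueeze(1), images.unsqueeze(0), dim=-1)

print(S.shape)
print(torch.allclose(S, cos, atol=1e-5))
```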

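We then treat row $i$ of $S$ as the logits of an $N$-way classification problem whose correct class is $i$, and use **cross-entropy loss** to encourage: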

* $S_{ii}$ to be high (correct match)
* $S_{ij}$ for $j \ne i$ to be low (distractors)

$$
\mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} -\log \left( \frac{\exp(S_{ii})}{\sum_{j=1}^{N} \exp(S_{ij})} \right)
$$
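
The formula can be checked against PyTorch's built-in cross-entropy. This is a small sketch with a made-up $3 \times 3$ similarity matrix; the values are illustrative only.

```python
import torch
import torch.nn.functional as F

# Made-up similarity matrix: rows are captions, columns are images
S = torch.tensor([[0.9, 0.1, -0.2],
                  [0.0, 0.8, 0.3],
                  [-0.1, 0.2, 0.7]])

# Direct transcription of the formula: -log softmax of the diagonal entry per row
manual = -(torch.diag(S) - torch.logsumexp(S, dim=1)).mean()

# Same quantity via cross-entropy with targets 0..N-1
labels = torch.arange(S.size(0))
builtin = F.cross_entropy(S, labels)

print(manual.item(), builtin.item())
```

The two numbers agree, which is why the training code can pass the similarity matrix straight to `nn.CrossEntropyLoss` with `torch.arange(N)` as targets.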

---

### 🧩 **Model Components**

* **Text Encoder**: Takes the GloVe vectors of all words in a caption and averages them to a 300-dim vector, then projects to 128-dim.
* **Image Encoder**: A pretrained ResNet-50 with its final layer replaced to output 128-dim embeddings.
* Both outputs are **L2-normalized** to lie on a unit hypersphere.
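
The text branch can be sketched as a small module. This version assumes captions have been padded to a common length and come with a 0/1 mask, which is one way (not the only one) to batch variable-length captions; the class name `MeanPoolTextEncoder` is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MeanPoolTextEncoder(nn.Module):
    """Average word vectors over real tokens, project, then L2-normalize."""

    def __init__(self, word_dim=300, embed_dim=128):
        super().__init__()
        self.proj = nn.Linear(word_dim, embed_dim)

    def forward(self, word_vectors, mask):
        # word_vectors: (B, T, word_dim); mask: (B, T), 1 for real tokens, 0 for padding
        mask = mask.unsqueeze(-1).to(word_vectors.dtype)
        summed = (word_vectors * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1.0)   # avoid division by zero
        pooled = summed / counts                  # (B, word_dim)
        return F.normalize(self.proj(pooled), dim=-1)

encoder = MeanPoolTextEncoder()
words = torch.randn(2, 5, 300)                    # random stand-ins for GloVe vectors
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
out = encoder(words, mask)
print(out.shape, out.norm(dim=-1))
```

Masking keeps padded positions out of the average, so a short caption is not pulled toward zero by its padding.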

---

### ⚙️ **Training Summary**

* Data: Flickr8k images and first caption per image
* Loss: Contrastive loss over dot-product similarity
* Optimizer: Adam
* Output: A model that embeds matching image–caption pairs closely
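
Selecting the first caption per image depends on the layout of `Flickr8k.token.txt`. The sketch below assumes each line has the form `<image file>#<caption index>`, a tab, then the caption; the file names and captions in `sample` are made up.

```python
def first_captions(lines):
    """Map each image file name to its first caption."""
    captions = {}
    for line in lines:
        line = line.strip()
        if not line:
            continue
        key, caption = line.split('\t', 1)
        image_name = key.split('#', 1)[0]   # drop the "#<caption index>" suffix
        captions.setdefault(image_name, caption)
    return captions

sample = [
    "img_001.jpg#0\tA dog runs across the grass .\n",
    "img_001.jpg#1\tA brown dog is running outside .\n",
    "img_002.jpg#0\tTwo children play on a beach .\n",
]
pairs = first_captions(sample)
print(pairs)
```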

---

### ✅ **Why This Works**

This method is inspired by **CLIP** (Contrastive Language–Image Pre-training), where aligning images and natural language in a shared vector space enables powerful multimodal tasks, even with simple linear classifiers or nearest-neighbor search.
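
For comparison, here is a sketch of a CLIP-style variant that this notebook does not use: the loss is averaged over both directions (caption to image and image to caption) and the similarities are divided by a temperature. The motivation is that unit-normalized embeddings keep every $S_{ij}$ in $[-1, 1]$, so the one-directional loss above cannot fall below $\log(1 + (N-1)e^{-2})$, about 1.65 for $N = 32$. The function name and the fixed temperature of 0.07 are assumptions for illustration.

```python
import math
import torch
import torch.nn.functional as F

def symmetric_contrastive_loss(text_embed, image_embed, temperature=0.07):
    """Average of caption->image and image->caption cross-entropy."""
    logits = text_embed @ image_embed.t() / temperature
    labels = torch.arange(logits.size(0), device=logits.device)
    loss_t2i = F.cross_entropy(logits, labels)       # rows: captions
    loss_i2t = F.cross_entropy(logits.t(), labels)   # rows: images
    return 0.5 * (loss_t2i + loss_i2t)

# Perfectly aligned, mutually orthogonal unit embeddings
N = 4
embeds = torch.eye(N)
loss_no_temp = symmetric_contrastive_loss(embeds, embeds, temperature=1.0)
loss_temp = symmetric_contrastive_loss(embeds, embeds, temperature=0.07)
print(loss_no_temp.item(), loss_temp.item())
```

Even with a perfect alignment, the loss without temperature stays well above zero, while the scaled version gets close to zero.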



In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image
import numpy as np
import os
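import nltk
from nltk.tokenize import word_tokenize
from torch.utils.data import DataLoader, Dataset

# word_tokenize needs the Punkt tokenizer data ('punkt_tab' on newer NLTK releases)
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)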

# Constants
EMBED_DIM = 128
BATCH_SIZE = 32
NUM_EPOCHS = 10

# Download GloVe and Flickr8k datasets
!wget http://nlp.stanford.edu/data/glove.6B.zip
!unzip glove.6B.zip
!wget https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip
!wget https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip
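# Extract into named folders so the dataset paths used later resolve
!unzip -q Flickr8k_Dataset.zip -d Flickr8k_Dataset
!unzip -q Flickr8k_text.zip -d Flickr8k_text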

# Load GloVe embeddings
def load_glove_embeddings(filepath):
    embeddings_index = {}
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            values = line.split()
            word = values[0]
            coefs = np.asarray(values[1:], dtype='float32')
            embeddings_index[word] = coefs
    return embeddings_index

glove_path = 'glove.6B.300d.txt'
embeddings_index = load_glove_embeddings(glove_path)

# Data preprocessing
def preprocess_caption(caption):
    tokens = word_tokenize(caption.lower())
    return [embeddings_index[token] for token in tokens if token in embeddings_index]

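class FlickrDataset(Dataset):
    def __init__(self, image_folder, captions_file, transform=None):
        self.image_folder = image_folder
        self.captions = self.load_captions(captions_file)
        # Fix the image order once instead of rebuilding the key list per item
        self.image_ids = list(self.captions.keys())
        self.transform = transform

    def load_captions(self, file_path):
        captions = {}
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                # Each line looks like "<image>.jpg#<k>\t<caption>"
                key, caption = line.split('\t', 1)
                image_id = key.split('#')[0]
                # Skip entries whose image file is missing on disk
                if not os.path.exists(os.path.join(self.image_folder, image_id)):
                    continue
                captions.setdefault(image_id, []).append(caption)
        return captions

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        caption = self.captions[image_id][0]  # Use the first caption
        image_path = os.path.join(self.image_folder, image_id)
        image = Image.open(image_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        word_vectors = preprocess_caption(caption)
        # Average the word vectors so every item has shape (300,) and the
        # default collate function can stack a batch; fall back to zeros
        # when no token is in the GloVe vocabulary
        if word_vectors:
            caption_embedding = np.mean(word_vectors, axis=0)
        else:
            caption_embedding = np.zeros(300, dtype='float32')
        return image, torch.tensor(caption_embedding, dtype=torch.float)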

# Define image transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Create dataset and dataloader
image_folder = 'Flickr8k_Dataset/Flicker8k_Dataset'
captions_file = 'Flickr8k_text/Flickr8k.token.txt'
dataset = FlickrDataset(image_folder, captions_file, transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Define the Model
class SimpleAlignModel(nn.Module):
    def __init__(self, embed_dim):
        super(SimpleAlignModel, self).__init__()
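        # `pretrained=True` is deprecated in torchvision >= 0.13; this loads the same ImageNet weights
        self.resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)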
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, embed_dim)
        self.text_fc = nn.Linear(300, embed_dim)

    def forward(self, text, image):
        text_embed = F.normalize(self.text_fc(text), dim=-1)
        image_embed = F.normalize(self.resnet(image), dim=-1)
        return text_embed, image_embed

# Initialize Model
model = SimpleAlignModel(embed_dim=EMBED_DIM)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Training Loop
for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0

    for images, captions in dataloader:
        optimizer.zero_grad()
        batch_size = images.size(0)

        # Mean-pool the word vectors over the token axis to get (batch_size, 300)
        captions = captions.view(batch_size, -1, 300).mean(dim=1)

        text_embed, image_embed = model(captions, images)

        # Compute similarity matrix
        similarity_matrix = torch.matmul(text_embed, image_embed.t())

        # Compute loss: the matching image for caption i sits in column i,
        # so the targets are simply 0..N-1
        labels = torch.arange(similarity_matrix.size(0), device=similarity_matrix.device)
        loss = criterion(similarity_matrix, labels)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f'Epoch [{epoch + 1}/{NUM_EPOCHS}], Loss: {total_loss / len(dataloader):.4f}')

print("Training completed.")


### Explanation:

1. **Data Preparation:**
   - **GloVe Embeddings:** We download and load the GloVe word embeddings.
   - **Flickr8k Dataset:** We download and unzip the Flickr8k dataset.
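   - **Preprocessing:** Captions are lowercased, tokenized, and mapped to GloVe vectors (tokens missing from the GloVe vocabulary are dropped), then averaged into one 300-dim vector per caption. Images are resized to 224×224 and normalized with the ImageNet mean and standard deviation.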

2. **Model Definition:**
   - **ResNet:** We use a pretrained ResNet50 model to extract image embeddings, modifying the final layer to output the desired embedding dimension.
   - **Text Embeddings:** A simple linear layer maps the GloVe embeddings to the common embedding space.

3. **Training Loop:**
   - **Forward Pass:** We compute embeddings for both text and images.
   - **Similarity Matrix:** A similarity matrix is computed using the dot product of text and image embeddings.
   - **Loss Calculation:** We apply cross-entropy loss to each row of the similarity matrix, with the diagonal (matching) entry as the target, to align the embeddings.
   - **Backpropagation:** The loss is backpropagated, and the model parameters are updated.
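
The training-loop steps above can be tried without downloading any data. The sketch below swaps the ResNet and GloVe pipeline for two small linear layers on random stand-in features, so all sizes and names here are assumptions; only the loop structure mirrors the notebook.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
N, TEXT_DIM, IMAGE_DIM, EMBED_DIM = 8, 300, 64, 16

# Random stand-in features, for illustration only
text_feats = torch.randn(N, TEXT_DIM)
image_feats = torch.randn(N, IMAGE_DIM)

text_proj = nn.Linear(TEXT_DIM, EMBED_DIM)
image_proj = nn.Linear(IMAGE_DIM, EMBED_DIM)
params = list(text_proj.parameters()) + list(image_proj.parameters())
optimizer = torch.optim.Adam(params, lr=0.01)

losses = []
for step in range(50):
    optimizer.zero_grad()
    # Forward pass: project and L2-normalize both modalities
    text_embed = F.normalize(text_proj(text_feats), dim=-1)
    image_embed = F.normalize(image_proj(image_feats), dim=-1)
    # Similarity matrix and cross-entropy with diagonal targets
    similarity = text_embed @ image_embed.t()
    labels = torch.arange(N)
    loss = F.cross_entropy(similarity, labels)
    # Backpropagation and parameter update
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```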

