In [1]:
from torchvision import datasets
from torchvision.transforms import ToTensor
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
os.environ["OMP_NUM_THREADS"] = "1"

In [2]:
train_data = datasets.MNIST(
    root="data",
    train=True,
    transform=ToTensor(),
    download=True,
)

test_data = datasets.MNIST(
    root="data",
    train=False,
    transform=ToTensor(),
    download=True,
)

In [3]:
train_data

Dataset MNIST
    Number of datapoints: 60000
    Root location: data
    Split: Train
    StandardTransform
Transform: ToTensor()

In [4]:
test_data

Dataset MNIST
    Number of datapoints: 10000
    Root location: data
    Split: Test
    StandardTransform
Transform: ToTensor()

In [5]:
from torch.utils.data import DataLoader

loaders = {
    'train':DataLoader(train_data,
                         batch_size = 100,
                         shuffle=True,
                         num_workers=1),
    'test': DataLoader(test_data,
                        batch_size = 100,
                        shuffle=False,
                        num_workers=1),
}

In [6]:
loaders

{'train': <torch.utils.data.dataloader.DataLoader at 0x113980f40>,
 'test': <torch.utils.data.dataloader.DataLoader at 0x113980e50>}

In [7]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class CNN(nn.Module):

    def __init__(self):
        super(CNN, self).__init__()

        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)


    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)

        return F.log_softmax(x, dim=1)

In [8]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CNN().to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)

loss_fn = nn.CrossEntropyLoss()

def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(loaders['train']):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 20 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(loaders["train"].dataset)} ({100. * batch_idx / len(loaders["train"]):.0f}%)]\t{loss.item():.6f}')


def test():
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in loaders['test']:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += loss_fn(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(loaders['test'].dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(loaders["test"].dataset)} ({100. * correct / len(loaders["test"].dataset):.0f}%)\n')

In [9]:
for epoch in range(1, 11):
    train(epoch)
    test()  


Test set: Average loss: 0.0013, Accuracy: 9603/10000 (96%)


Test set: Average loss: 0.0009, Accuracy: 9719/10000 (97%)


Test set: Average loss: 0.0007, Accuracy: 9779/10000 (98%)


Test set: Average loss: 0.0006, Accuracy: 9804/10000 (98%)


Test set: Average loss: 0.0006, Accuracy: 9807/10000 (98%)


Test set: Average loss: 0.0005, Accuracy: 9843/10000 (98%)


Test set: Average loss: 0.0005, Accuracy: 9853/10000 (99%)


Test set: Average loss: 0.0005, Accuracy: 9852/10000 (99%)


Test set: Average loss: 0.0004, Accuracy: 9876/10000 (99%)


Test set: Average loss: 0.0004, Accuracy: 9884/10000 (99%)



In [10]:
# After training is complete, extract features and build FAISS index
# (Run this cell after all epochs are finished)

import faiss
import numpy as np
from torch.utils.data import DataLoader, Subset

# Optimized feature extraction: minimal memory, clear variables

def extract_features(model, loader, device):
    model.eval()
    features = []
    labels = []
    with torch.no_grad():
        for data, target in loader:
            data = data.to(device)
            # Forward pass up to the penultimate layer
            x = F.relu(F.max_pool2d(model.conv1(data), 2))
            x = F.relu(F.max_pool2d(model.conv2_drop(model.conv2(x)), 2))
            x = x.view(-1, 320)
            x = F.relu(model.fc1(x))
            features.append(x.cpu().numpy())
            labels.append(target.cpu().numpy())
    features = np.concatenate(features, axis=0)
    labels = np.concatenate(labels, axis=0)
    return features, labels

# Use the first 20 test images, batch size 2
larger_test_loader = DataLoader(Subset(test_data, range(20)), batch_size=2, shuffle=False)

# Clear previous variables to free memory
import gc
gc.collect()

features, labels = extract_features(model, larger_test_loader, device)

index = faiss.IndexFlatL2(features.shape[1])
index.add(features)

query = features[0].reshape(1, -1)
D, I = index.search(query, k=5)
print('Query label:', labels[0])
print('Nearest labels:', labels[I[0]])

Query label: 7
Nearest labels: [7 7 9 9 9]


## Step 1: Attach Captions to Images
For each image, we will create a simple caption (e.g., "Digit 7"). This will be used as context for the generative model.

In [11]:
# Create captions for the first N test images (e.g., 20)
num_caption_images = 20
test_captions = [f"Handwritten digit: {test_data[i][1]}" for i in range(num_caption_images)]
print(test_captions[:5])  # Show a few example captions

['Handwritten digit: 7', 'Handwritten digit: 2', 'Handwritten digit: 1', 'Handwritten digit: 0', 'Handwritten digit: 4']


## Step 1 (Multilingual): Attach Multilingual Captions to Images
We will add captions in multiple languages for each image.

In [12]:
# Example: Add multilingual captions for the first N test images
num_caption_images = 20
test_captions_multilingual = []
for i in range(num_caption_images):
    label = test_data[i][1]
    captions = {
        'en': f'Handwritten digit: {label}',
        'hi': f'हस्तलिखित अंक: {label}',
        'es': f'Dígito escrito a mano: {label}'
    }
    test_captions_multilingual.append(captions)
print(test_captions_multilingual[:2])

[{'en': 'Handwritten digit: 7', 'hi': 'हस्तलिखित अंक: 7', 'es': 'Dígito escrito a mano: 7'}, {'en': 'Handwritten digit: 2', 'hi': 'हस्तलिखित अंक: 2', 'es': 'Dígito escrito a mano: 2'}]


### Output Example: Multilingual Captions
The cell above prints the first two multilingual caption dictionaries, showing the format for English, Hindi, and Spanish. This helps verify that captions are correctly generated for each language.

## Step 2: Retrieval Pipeline
We will use the FAISS index to retrieve the top-k similar images for a given query image, and collect their captions.

In [13]:
# Example: Retrieve top-k similar images and their captions for a query image
k = 5
query_idx = 0  # Use the first image as the query
query_feature = features[query_idx].reshape(1, -1)
D, I = index.search(query_feature, k)
retrieved_indices = I[0]
retrieved_captions = [test_captions[i] for i in retrieved_indices]
print("Query caption:", test_captions[query_idx])
print("Retrieved captions:", retrieved_captions)

Query caption: Handwritten digit: 7
Retrieved captions: ['Handwritten digit: 7', 'Handwritten digit: 7', 'Handwritten digit: 9', 'Handwritten digit: 9', 'Handwritten digit: 9']


## Step 2 (Multilingual): Retrieval Pipeline with Multilingual Captions
Retrieve top-k similar images and their captions in the desired language.

In [14]:
# Set the desired language for retrieval and generation
lang = 'hi'  # 'en', 'hi', or 'es'

# Retrieve captions in the selected language
retrieved_captions_multilingual = [test_captions_multilingual[i][lang] for i in retrieved_indices]
print("Query caption:", test_captions_multilingual[query_idx][lang])
print("Retrieved captions:", retrieved_captions_multilingual)

Query caption: हस्तलिखित अंक: 7
Retrieved captions: ['हस्तलिखित अंक: 7', 'हस्तलिखित अंक: 7', 'हस्तलिखित अंक: 9', 'हस्तलिखित अंक: 9', 'हस्तलिखित अंक: 9']


## Step 3: Generative Model Integration
We will use a small Hugging Face model (e.g., T5-small) to generate a response based on the retrieved captions.

## Step 3 (Multilingual): Multilingual Generative Model Integration
Use a multilingual generative model (e.g., mT5 or mBART) to generate a response in the selected language.

In [None]:
# Install missing dependencies for transformers
%pip install tiktoken protobuf --quiet

# Install and use a multilingual text generation model (e.g., mT5-small)
try:
    from transformers import pipeline
except ImportError:
    %pip install transformers --quiet
    from transformers import pipeline

# Use a multilingual model (mT5-small) for text2text-generation
generator_multi = pipeline("text2text-generation", model="google/flan-t5-small")

# Prepare the prompt in the selected language
prompt_multi = f"प्रश्न: {test_captions_multilingual[query_idx][lang]}\nसमान उदाहरण: " + ", ".join(retrieved_captions_multilingual) + "\nसारांश:" if lang == 'hi' else (f"Pregunta: {test_captions_multilingual[query_idx][lang]}\nEjemplos similares: " + ", ".join(retrieved_captions_multilingual) + "\nResumen:" if lang == 'es' else f"Question: {test_captions_multilingual[query_idx][lang]}\nSimilar examples: " + ", ".join(retrieved_captions_multilingual) + "\nSummary:")

# Generate a response in the selected language
response_multi = generator_multi(prompt_multi, max_length=50, num_return_sequences=1)[0]['generated_text']
print(f"Generated response in {lang}:\n", response_multi)

Note: you may need to restart the kernel to use updated packages.


  from .autonotebook import tqdm as notebook_tqdm


ValueError: Converting from SentencePiece and Tiktoken failed, if a converter for SentencePiece is available, provide a model path with a SentencePiece tokenizer.model file.Currently available slow->fast converters: ['AlbertTokenizer', 'BartTokenizer', 'BarthezTokenizer', 'BertTokenizer', 'BigBirdTokenizer', 'BlenderbotTokenizer', 'CamembertTokenizer', 'CLIPTokenizer', 'CodeGenTokenizer', 'ConvBertTokenizer', 'DebertaTokenizer', 'DebertaV2Tokenizer', 'DistilBertTokenizer', 'DPRReaderTokenizer', 'DPRQuestionEncoderTokenizer', 'DPRContextEncoderTokenizer', 'ElectraTokenizer', 'FNetTokenizer', 'FunnelTokenizer', 'GPT2Tokenizer', 'HerbertTokenizer', 'LayoutLMTokenizer', 'LayoutLMv2Tokenizer', 'LayoutLMv3Tokenizer', 'LayoutXLMTokenizer', 'LongformerTokenizer', 'LEDTokenizer', 'LxmertTokenizer', 'MarkupLMTokenizer', 'MBartTokenizer', 'MBart50Tokenizer', 'MPNetTokenizer', 'MobileBertTokenizer', 'MvpTokenizer', 'NllbTokenizer', 'OpenAIGPTTokenizer', 'PegasusTokenizer', 'Qwen2Tokenizer', 'RealmTokenizer', 'ReformerTokenizer', 'RemBertTokenizer', 'RetriBertTokenizer', 'RobertaTokenizer', 'RoFormerTokenizer', 'SeamlessM4TTokenizer', 'SqueezeBertTokenizer', 'T5Tokenizer', 'UdopTokenizer', 'WhisperTokenizer', 'XLMRobertaTokenizer', 'XLNetTokenizer', 'SplinterTokenizer', 'XGLMTokenizer', 'LlamaTokenizer', 'CodeLlamaTokenizer', 'GemmaTokenizer', 'Phi3Tokenizer']

In [None]:
# Install transformers if not already installed
try:
    from transformers import pipeline
except ImportError:
    %pip install transformers --quiet
    from transformers import pipeline

# Use a small text generation model (T5-small or DistilGPT-2)
generator = pipeline('text-generation', model='distilgpt2')

# Prepare the prompt for the generative model
prompt = f"Query: {test_captions[query_idx]}\nSimilar examples: " + ", ".join(retrieved_captions) + "\nSummary:"

# Generate a response
response = generator(prompt, max_length=50, num_return_sequences=1)[0]['generated_text']
print("Generated response:\n", response)

In [None]:
import matplotlib.pyplot as plt

model.eval()

data, target = test_data[1]

data = data.unsqueeze(0).to(device)

output = model(data)
pred = output.argmax(dim=1, keepdim=True)

print(f'Predicted Label: {pred}')

image = data.squeeze().cpu().numpy()

plt.imshow(image, cmap='gray')
plt.show()