In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms
import torch.optim as optim
import torch.nn.functional as F
from PIL import Image
import random

In [None]:
# Pick the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

# Load the pretrained ResNet-50 and switch it to inference mode. eval() is
# applied recursively, so the child layers reused below are already in eval mode.
backbone = models.resnet50(pretrained=True)
backbone.eval()

# Drop the last child module of the network and wrap the remaining layers
# into a feature extractor living on `device`.
modules = list(backbone.children())[:-1]
model = nn.Sequential(*modules).to(device)

In [None]:
# Image pipeline applied before the feature extractor: resize, take the
# central 224x224 crop, convert to a tensor, then normalise each channel with
# the mean/std values below.
preprocess = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])
def get_img_embed(img_path):
    """Return the feature-extractor embedding of the image stored at ``img_path``.

    The image is loaded as RGB, run through ``preprocess`` and then through
    ``model`` with gradient tracking disabled.

    Args:
        img_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The model output with the batch axis removed and the remaining axes
        reordered by ``permute(1, 2, 0)`` (first axis moved last). Presumably
        shape ``(1, 1, 2048)`` for the truncated ResNet-50 -- TODO confirm.
        The tensor lives on ``device``.
    """
    # The context manager guarantees the underlying file handle is released
    # even if decoding or conversion fails.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert('RGB')

    # unsqueeze(0) adds the batch axis the model expects.
    img_tensor = preprocess(img).unsqueeze(0).to(device)
    with torch.no_grad():
        img_emb = model(img_tensor)

    return img_emb.squeeze(0).permute(1, 2, 0)

In [None]:
# Sanity check of the embedding shape on the sample image. A path relative to
# the notebook is used (the same file the captioning cells load) instead of a
# user-specific absolute path, so the cell runs on any machine.
img_path = 'Designer.jpeg'
print(get_img_embed(img_path).shape)

In [None]:
# Map each image filename to the list of its raw caption strings.
# Every non-blank line of captions.txt is split at its first comma into
# "<filename>,<caption>".
# NOTE(review): if the file starts with a header row it is ingested like any
# other line -- confirm against the actual file.
captions = {}
with open('captions.txt', 'r') as f:
    for raw_line in f:
        record = raw_line.strip()
        if not record:
            continue
        filename, caption = record.split(',', 1)
        captions.setdefault(filename, []).append(caption)

In [None]:
# Tokenise every caption on whitespace, collect all distinct words, and
# replace each caption string in `captions` by its list of words.
vocabulary = set()
for filename in list(captions):
    tokenized = [raw_caption.strip().split() for raw_caption in captions[filename]]
    for words in tokenized:
        vocabulary.update(words)
    captions[filename] = tokenized

# Special tokens: sequence start, sequence end and padding.
for special_token in ('<start>', '<end>', '<pad>'):
    vocabulary.add(special_token)


In [None]:
# Token <-> index lookup tables. The vocabulary is a set, whose iteration order
# changes between interpreter sessions; sorting makes the mapping reproducible
# so that saved decoder weights stay consistent with the indices after a
# kernel restart.
word_to_idx = {word: idx for idx, word in enumerate(sorted(vocabulary))}
idx_to_word = {idx: word for word, idx in word_to_idx.items()}

# Only the first MAX_IMAGES images are embedded.
MAX_IMAGES = 200

# One [image_embedding, caption_index_tensor] pair per (image, caption).
data_loader = []
count = 0  # number of images processed

# The loop variable is deliberately NOT called `captions`: reusing that name
# would overwrite the filename -> captions dict with a single image's list.
for filename, caption_list in list(captions.items())[:MAX_IMAGES]:
    # The embedding is computed once per image and shared by all its captions.
    img_embed = get_img_embed(f'Images/{filename}')

    for caption in caption_list:
        tokens = ['<start>'] + caption + ['<end>']
        caption_idx = [word_to_idx[word] for word in tokens]
        data_loader.append([img_embed, torch.tensor(caption_idx)])
    count += 1

In [None]:
class Decoder(nn.Module):
    """LSTM caption decoder whose initial hidden state is the image feature.

    Args:
        input_size: Width of the word embeddings fed to the LSTM.
        hidden_size: LSTM hidden width; must equal the last dimension of the
            ``features`` passed to ``forward``.
        vocab_size: Number of tokens (embedding rows and output logits).
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_size, vocab_size, num_layers):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, input_size)
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        """Score the next token at every position of ``captions``.

        Args:
            features: Initial LSTM hidden state, shape
                ``(num_layers, batch, hidden_size)``.
            captions: Token indices, either ``(seq_len,)`` for a single caption
                or ``(batch, seq_len)`` for a batch.

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)``; ``batch`` is 1
            when a single 1-D caption is given.
        """
        # Inputs are moved to wherever the weights live so that callers do not
        # have to keep embeddings, captions and the decoder on the same device.
        param_device = self.embed.weight.device
        captions = captions.to(param_device)
        if captions.dim() == 1:
            # Single caption: add the batch axis.
            captions = captions.unsqueeze(0)
        embeddings = self.embed(captions)

        h0 = features.to(device=param_device, dtype=embeddings.dtype).contiguous()
        # The cell state starts at zero with the same shape/device/dtype as h0.
        c0 = torch.zeros_like(h0)

        outputs, _ = self.lstm(embeddings, (h0, c0))
        return self.linear(outputs)

In [None]:
# --- Decoder dimensions ---
input_size = 256                  # word-embedding width fed to the LSTM
# LSTM hidden width. The image embedding is used as the initial hidden state,
# so this presumably has to match the feature size of the extractor -- confirm.
hidden_size = 2048
num_layers = 1
vocab_size = len(vocabulary)

# --- Optimisation settings ---
num_epochs, batch_size = 5, 32
learning_rate = 0.001
grad_clip = 5.0                   # max gradient norm used for clipping

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Caption decoder together with its optimiser and training criterion.
decoder = Decoder(
    input_size=input_size,
    hidden_size=hidden_size,
    vocab_size=vocab_size,
    num_layers=num_layers,
)
optimizer = optim.Adam(params=decoder.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

In [None]:
# Report the number of distinct tokens in the vocabulary (special tokens included).
print(len(vocabulary))

In [None]:
# The batches do not change between epochs, so they are built once.
batches = [data_loader[i:i + batch_size] for i in range(0, len(data_loader), batch_size)]

# Device holding the decoder weights; every sample is moved there before use.
decoder_device = next(decoder.parameters()).device

for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    total_loss = 0.0
    for batch in batches:
        optimizer.zero_grad()
        batch_loss = 0.0

        # Every (image, caption) pair of the batch contributes to the update.
        # Captions are fed one at a time at their true length, so no padding
        # (and therefore no pad masking) is required. Gradients accumulate
        # across the samples and a single optimiser step is taken per batch.
        for img_embed, caption_ids in batch:
            img_embed = img_embed.to(decoder_device)
            caption_ids = caption_ids.to(decoder_device)

            outputs = decoder(img_embed, caption_ids)

            # Teacher forcing: the logits at position t are scored against
            # the token at position t + 1.
            logits = outputs[:, :-1, :].reshape(-1, outputs.size(-1))
            targets = caption_ids[1:].reshape(-1)

            # criterion averages over the caption's tokens; dividing by the
            # batch size makes the accumulated gradient a batch mean.
            loss = criterion(logits, targets) / len(batch)
            loss.backward()
            batch_loss += loss.item()

        nn.utils.clip_grad_norm_(decoder.parameters(), grad_clip)
        optimizer.step()
        total_loss += batch_loss

    # total_loss sums one mean loss per batch, so it is averaged over batches.
    average_loss = total_loss / max(len(batches), 1)
    print(f'Average Loss: {average_loss:.4f}')


torch.save(decoder.state_dict(), 'imgcap.pth')

In [None]:
def generate_caption(image_path, max_length=50):
    """Greedily decode a caption for the image at ``image_path``.

    The image embedding initialises the LSTM hidden state (the cell state
    starts at zero). Starting from ``<start>``, the most likely next token is
    chosen at every step and fed back in, until ``<end>`` is predicted or
    ``max_length`` tokens have been produced.

    Args:
        image_path: Path of the image to caption.
        max_length: Maximum number of tokens to generate.

    Returns:
        The generated words joined by single spaces (without ``<start>`` /
        ``<end>``).
    """
    decoder_device = next(decoder.parameters()).device
    end_idx = word_to_idx['<end>']

    # Recurrent state: hidden = image embedding, cell = zeros of the same shape.
    hidden = get_img_embed(image_path).to(decoder_device).contiguous()
    cell = torch.zeros_like(hidden)

    next_input = torch.tensor([word_to_idx['<start>']], device=decoder_device)
    caption = []

    with torch.no_grad():
        for _ in range(max_length):
            # One token with a batch axis added in front of it.
            embedded = decoder.embed(next_input).unsqueeze(0)
            output, (hidden, cell) = decoder.lstm(embedded, (hidden, cell))

            # The LSTM output is a hidden-state vector; it has to go through
            # the linear layer to become vocabulary logits before argmax.
            logits = decoder.linear(output[:, -1])
            word_idx = logits.argmax(dim=1).item()

            if word_idx == end_idx:
                break

            caption.append(idx_to_word[word_idx])
            next_input = torch.tensor([word_idx], device=decoder_device)

    return ' '.join(caption)

In [None]:
# Caption the sample image with generate_caption and print the result.
print(generate_caption('Designer.jpeg'))

In [None]:
def beam_search(image_path, beam_width=3, max_length=50):
    """Decode a caption for the image at ``image_path`` with beam search.

    Each hypothesis is a ``(token_indices, cumulative_log_probability)`` pair.
    At every step each unfinished hypothesis is extended by its ``beam_width``
    most likely next tokens and the best ``beam_width`` hypotheses overall are
    kept. Hypotheses that already end in ``<end>`` are carried over unchanged.

    Args:
        image_path: Path of the image to caption.
        beam_width: Number of hypotheses kept after every step (>= 1).
        max_length: Maximum number of decoding steps.

    Returns:
        The words of the best hypothesis joined by single spaces (without
        ``<start>`` / ``<end>``).

    Raises:
        ValueError: If ``beam_width`` is smaller than 1.
    """
    if beam_width < 1:
        raise ValueError('beam_width must be at least 1')

    start_token = word_to_idx['<start>']
    end_token = word_to_idx['<end>']

    decoder_device = next(decoder.parameters()).device
    image_embedding = get_img_embed(image_path).to(decoder_device)

    beam = [([start_token], 0.0)]

    with torch.no_grad():
        for _ in range(max_length):
            candidates = []
            for seq, score in beam:
                if seq[-1] == end_token:
                    # Finished hypothesis: keep it as is, never extend past <end>.
                    candidates.append((seq, score))
                    continue

                input_seq = torch.tensor(seq, device=decoder_device)
                # decoder(...) applies the embedding, the LSTM and the linear
                # projection, so these are vocabulary logits for the position
                # after the last token of the prefix.
                logits = decoder(image_embedding, input_seq)[0, -1]
                log_probs = F.log_softmax(logits, dim=-1)

                k = min(beam_width, log_probs.numel())
                top_scores, top_indices = log_probs.topk(k, dim=-1)
                for step_score, token in zip(top_scores.tolist(), top_indices.tolist()):
                    candidates.append((seq + [token], score + step_score))

            beam = sorted(candidates, key=lambda item: item[1], reverse=True)[:beam_width]

            # Scores only decrease as sequences grow, so once the best
            # hypothesis is finished no other one can overtake it.
            if beam[0][0][-1] == end_token:
                break

    best_seq = beam[0][0][1:]  # drop <start>
    if best_seq and best_seq[-1] == end_token:
        # Drop <end> only when it is really there (max_length may be hit first).
        best_seq = best_seq[:-1]
    caption = ' '.join(idx_to_word[idx] for idx in best_seq)
    return caption

# Caption the sample image with beam_search (default beam width) and print the result.
print(beam_search('Designer.jpeg'))