In [33]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np

In [34]:

# Load a single split of the dataset
dataset = load_dataset('nlphuji/flickr30k', split='test')

# 1. Prepare the caption data.
# Assuming each row's 'caption' field is a list of captions.
# Read only the 'caption' column: iterating over whole rows (`for item in dataset`)
# would also decode every image just to throw it away.
all_captions = [caption for captions in dataset['caption'] for caption in captions]

# Write captions to a file, one caption per line (the input format expected by
# the SentencePiece trainer).
with open('captions.txt', 'w', encoding='utf-8') as f:
    f.writelines(caption + '\n' for caption in all_captions)
        

Repo card metadata block was not found. Setting CardData to empty.


In [35]:

# 2. Train a SentencePiece BPE tokenizer on the caption corpus.
# Training produces two files: flickr_captions.model (the trained tokenizer)
# and flickr_captions.vocab (its vocabulary).
spm.SentencePieceTrainer.train(
    input='captions.txt',
    model_prefix='flickr_captions',
    vocab_size=8000,
    model_type='bpe',
)


sentencepiece_trainer.cc(78) LOG(INFO) Starts training with : 
trainer_spec {
  input: captions.txt
  input_format: 
  model_prefix: flickr_captions
  model_type: BPE
  vocab_size: 8000
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  pretokenization_delimiter: 
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  seed_sentencepieces_file: 
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 0
  bos_id: 1
  eos_id: 2
  pad_id: -1
  unk_piece: <unk>
  bos_piece: <s>
  eos_piece: </s>
  pad_piece: <pad>
  unk_surface:  ⁇ 
  enable_differential_privacy: 0
  di

In [58]:

class Flickr30kDataset(Dataset):
    """Flickr30k (image, caption) pairs prepared for caption-model training.

    Every caption of every image becomes one sample. Each sample is a tuple
    ``(visual, caption_input, caption_label)`` where ``caption_input`` is the
    tokenized caption prefixed with the BOS id and ``caption_label`` is the
    same caption suffixed with the EOS id (i.e. the input shifted by one).

    Args:
        split: Name of the ``nlphuji/flickr30k`` split to load (e.g. ``'test'``).
        tokenizer_model: Path to a trained SentencePiece ``.model`` file.
            Falls back to ``'flickr_captions.model'`` when empty/``None``.
        patch_size: Side length (in pixels) of the square patches produced by
            :meth:`image_to_patches`.
        return_patches: When ``True``, ``__getitem__`` returns the unrolled
            patches instead of the ``(C, H, W)`` image tensor. Defaults to
            ``False``, which keeps the original return value.
    """

    def __init__(self, split, tokenizer_model, patch_size=16, return_patches=False):
        # Use the instance's own copy of the split; never rely on a notebook
        # global that happens to be called `dataset`.
        self.dataset = load_dataset("nlphuji/flickr30k", split=split)

        # Load the SentencePiece tokenizer from the path the caller supplied.
        self.sp = spm.SentencePieceProcessor(
            model_file=tokenizer_model or 'flickr_captions.model'
        )

        # Special token ids used to build the shifted input/label sequences.
        self.sos_token = self.sp.bos_id()  # Start of sentence token
        self.eos_token = self.sp.eos_id()  # End of sentence token

        self.patch_size = patch_size
        self.return_patches = return_patches

        # Flatten the dataset: one (image_index, caption) entry per caption.
        # Only the index is stored so images are decoded lazily in
        # __getitem__ instead of being held in memory all at once.
        self.items = [
            (image_index, caption)
            for image_index, captions in enumerate(self.dataset['caption'])
            for caption in captions
        ]

        # Define image transformations
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),  # Resize images to a fixed size
            transforms.ToTensor(),  # Convert images to tensor
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize images
        ])

    def __len__(self):
        """Return the number of (image, caption) pairs."""
        return len(self.items)

    def __getitem__(self, idx):
        """Return ``(visual, caption_input, caption_label)`` for sample ``idx``.

        ``visual`` is the transformed image tensor, or its unrolled patches
        when the dataset was built with ``return_patches=True``. Both caption
        tensors hold integer token ids (``torch.long``).
        """
        image_index, caption = self.items[idx]
        raw_image = self.dataset[image_index]['image']

        # The dataset normally yields an already-decoded image object; only
        # fall back to opening it when a path / file object is stored instead.
        if hasattr(raw_image, 'convert'):
            rgb_image = raw_image.convert('RGB')
        else:
            with Image.open(raw_image) as opened:
                rgb_image = opened.convert('RGB')
        image = self.transform(rgb_image)

        # Patches are only computed when they are actually returned.
        visual = self.image_to_patches(image) if self.return_patches else image

        # Tokenize the caption and build the teacher-forcing pair.
        tokenized_caption = self.sp.encode(caption, out_type=int)
        caption_input = [self.sos_token] + tokenized_caption
        caption_label = tokenized_caption + [self.eos_token]

        return (
            visual,
            torch.tensor(caption_input, dtype=torch.long),
            torch.tensor(caption_label, dtype=torch.long),
        )

    def image_to_patches(self, image):
        """Split a ``(C, H, W)`` image into square patches and flatten each one.

        Returns a tensor of shape ``(num_patches, C * patch_size * patch_size)``
        with patches in row-major order and, inside each patch vector, all
        pixels of channel 0 first, then channel 1, and so on. Rows/columns that
        do not fill a complete patch are dropped by ``unfold``.
        """
        C, H, W = image.shape
        patch_size = self.patch_size
        # (C, H // p, W // p, p, p): one p x p window per grid cell and channel.
        windows = image.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
        # Bring the grid dimensions to the front, then flatten channel + pixels.
        return windows.permute(1, 2, 0, 3, 4).reshape(-1, C * patch_size * patch_size)


In [62]:
# Sanity check (run after `flickr_dataset` has been built):
# image, caption_inputs, caption_labels = flickr_dataset[0]  # Get the first item


In [42]:
dataset_split = 'test'
patch_size = 16
# SentencePiece model written by the tokenizer-training cell
# (model_prefix='flickr_captions').
tokenizer_model = 'flickr_captions.model'

# The tokenizer was trained with pad_id=-1 (no padding piece), so padded input
# positions reuse id 0 and padded label positions use -100, the default
# `ignore_index` of torch.nn.CrossEntropyLoss, so they do not contribute to the loss.
INPUT_PAD_ID = 0
LABEL_IGNORE_INDEX = -100


def pad_caption_batch(batch):
    """Collate (image, caption_input, caption_label) samples into one batch.

    Captions have different lengths, so the default collate function cannot
    stack them. Images are stacked; caption tensors are right-padded to the
    longest caption in the batch.

    Returns:
        Tuple ``(images, caption_inputs, caption_labels)`` whose caption
        tensors have shape ``(batch, max_caption_length)``.
    """
    images, caption_inputs, caption_labels = zip(*batch)
    padded_inputs = torch.nn.utils.rnn.pad_sequence(
        caption_inputs, batch_first=True, padding_value=INPUT_PAD_ID
    )
    padded_labels = torch.nn.utils.rnn.pad_sequence(
        caption_labels, batch_first=True, padding_value=LABEL_IGNORE_INDEX
    )
    return torch.stack(images), padded_inputs, padded_labels


flickr_dataset = Flickr30kDataset(dataset_split, tokenizer_model, patch_size=patch_size)
dataloader = DataLoader(
    flickr_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=pad_caption_batch,
)



Repo card metadata block was not found. Setting CardData to empty.
