In [1]:
import torch
import pickle
import sys
import os
import numpy as np
import pandas as pd
import io
from datasets import load_dataset
from transformers import AutoProcessor, BlipForConditionalGeneration
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from PIL import Image

# Extend the import search path with a local copy of sentence-transformers
# (presumably an offline Kaggle dataset -- confirm). This must run before the
# `sentence_transformers` import on the next line, so keep the order.
sys.path.append('/kaggle/input/sentence-transformers-222/sentence-transformers')
from sentence_transformers import SentenceTransformer, models

# Root directory of the competition data; the `images` folder and
# `sample_submission.csv` are read relative to it further down.
comp_path = Path('/kaggle/input/stable-diffusion-image-to-prompts/')

In [2]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Train Model

In [3]:
# Training data: the 'large_random_5k' configuration of DiffusionDB, train split.
# Each record is later read through its 'image' and 'prompt' fields.
train = load_dataset(
    path='poloclub/diffusiondb',
    name='large_random_5k',
    split='train',
)

Downloading builder script:   0%|          | 0.00/15.0k [00:00<?, ?B/s]

Downloading and preparing dataset diffusion_db/large_random_5k to /root/.cache/huggingface/datasets/poloclub___diffusion_db/large_random_5k/0.9.1/547894e3a57aa647ead68c9faf148324098f47f2bc1ab6705d670721de9d89d1...


Downloading data:   0%|          | 0.00/457M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/472M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/442M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/488M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/441M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/825M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset diffusion_db downloaded and prepared to /root/.cache/huggingface/datasets/poloclub___diffusion_db/large_random_5k/0.9.1/547894e3a57aa647ead68c9faf148324098f47f2bc1ab6705d670721de9d89d1. Subsequent calls will reuse this data.


In [4]:
class ImageCaptioningDataset(Dataset):
    """Torch dataset that turns (image, prompt) records into model inputs.

    Args:
        dataset: Indexable collection whose items expose ``'image'`` and
            ``'prompt'`` entries.
        processor: Callable that accepts ``images=`` and ``text=`` and returns
            a mapping of tensors (a BLIP processor in this notebook).
    """

    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        """Return the number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Encode record ``idx`` and return a dict of per-sample tensors."""
        item = self.dataset[idx]
        encoding = self.processor(images = item['image'],
                                  text = item['prompt'],
                                  padding = 'max_length',
                                  # Without truncation a prompt longer than the
                                  # tokenizer's maximum would yield a longer
                                  # sequence than the padded ones, and the
                                  # DataLoader could not stack the batch.
                                  truncation = True,
                                  return_tensors = 'pt')
        # Drop only the leading batch axis added for the single example.
        # A bare squeeze() would also remove any other size-1 dimension.
        return {key: value.squeeze(0) for key, value in encoding.items()}

In [5]:
# The processor and the captioning model are loaded from the same checkpoint.
BLIP_CHECKPOINT = 'Salesforce/blip-image-captioning-base'

processor = AutoProcessor.from_pretrained(BLIP_CHECKPOINT)
model = BlipForConditionalGeneration.from_pretrained(BLIP_CHECKPOINT)

Downloading (…)rocessor_config.json:   0%|          | 0.00/287 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/438 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/4.56k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/990M [00:00<?, ?B/s]

In [6]:
# Wrap the raw records so every item is run through the processor on access,
# then serve them in shuffled mini-batches of five.
train_dataset = ImageCaptioningDataset(dataset=train, processor=processor)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=5,
    shuffle=True,
)

In [8]:
# Fine-tune the captioning model on the (image, prompt) pairs for 5 epochs.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

model.to(device)

model.train()

for epoch in range(5):
    for idx, batch in enumerate(train_dataloader):
        input_ids = batch.pop('input_ids').to(device)
        pixel_values = batch.pop('pixel_values').to(device)
        # Samples are padded to a fixed length, so the mask is needed to tell
        # real prompt tokens from padding.
        attention_mask = batch.pop('attention_mask').to(device)

        # Exclude padded positions from the loss: -100 is the default
        # ignore_index of torch's cross-entropy loss. Without this the
        # objective is dominated by predicting pad tokens.
        labels = input_ids.masked_fill(attention_mask == 0, -100)

        outputs = model(input_ids = input_ids,
                        pixel_values = pixel_values,
                        attention_mask = attention_mask,
                        labels = labels)

        loss = outputs.loss

        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

In [9]:
# Persist the fine-tuned model and its processor. The context managers make
# sure each file is flushed and closed once the dump finishes.
with open('/kaggle/working/model_5k.pkl', 'wb') as model_file:
    pickle.dump(model, model_file)
with open('/kaggle/working/processor_5k.pkl', 'wb') as processor_file:
    pickle.dump(processor, processor_file)

# Predictions

In [10]:
class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that redirects torch's byte-storage loader to the CPU.

    Any lookup of ``torch.storage._load_from_bytes`` made while unpickling is
    answered with a replacement that calls ``torch.load`` with
    ``map_location='cpu'``. Every other lookup is resolved normally.
    """

    def find_class(self, module, name):
        """Resolve ``module.name``, swapping in the CPU loader when needed."""
        is_storage_loader = (module, name) == ('torch.storage', '_load_from_bytes')
        if not is_storage_loader:
            return super().find_class(module, name)

        def load_on_cpu(b):
            # Same call the stock loader would make, but pinned to the CPU.
            return torch.load(io.BytesIO(b), map_location='cpu')

        return load_on_cpu

# Load the fine-tuned model and processor through CPU_Unpickler.
# SECURITY NOTE: unpickling executes arbitrary code from the file, so only
# load pickles that were produced by a trusted run of this notebook.
# The context managers close the file handles once loading is done.
with open('/kaggle/input/blip-image-captioning-5k/model_5k.pkl', 'rb') as model_file:
    model = CPU_Unpickler(model_file).load()
with open('/kaggle/input/blip-image-captioning-5k/processor_5k.pkl', 'rb') as processor_file:
    processor = CPU_Unpickler(processor_file).load()

In [11]:
# Move the model to the target device and switch it to inference mode.
# The model was pickled after `model.train()`, so the restored object would
# otherwise still be in training mode. `.eval()` returns the module, so the
# cell still displays the model.
model.to(device).eval()

BlipForConditionalGeneration(
  (vision_model): BlipVisionModel(
    (embeddings): BlipVisionEmbeddings(
      (patch_embedding): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (encoder): BlipEncoder(
      (layers): ModuleList(
        (0): BlipEncoderLayer(
          (self_attn): BlipAttention(
            (dropout): Dropout(p=0.0, inplace=False)
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (projection): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): BlipMLP(
            (activation_fn): GELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (1): BlipEncoderLayer(
          (self_attn): BlipAttention(
 

In [12]:
# File names of the competition images. os.listdir returns entries in
# arbitrary order, so sort them to make the run reproducible.
images = sorted(os.listdir(comp_path / 'images'))
# Image id = file name without its final extension. Path.stem strips only the
# last suffix, so ids that themselves contain a dot are not truncated.
imgIds = [Path(i).stem for i in images]

In [13]:
# Generate one caption per competition image, in the order of `images`.
prompts = []
# Same directory that `images` was listed from. Kept as a string with a
# trailing separator so file names can simply be appended to it.
images_path = os.path.join(comp_path, 'images', '')

for image_name in images:
    # Close the file handle as soon as the pixels are converted; convert()
    # returns a new, fully loaded image.
    with Image.open(images_path + image_name) as image_file:
        image = image_file.convert('RGB')
    inputs = processor(images = image, return_tensors = 'pt').to(device)
    pixel_values = inputs.pixel_values
    generated_ids = model.generate(pixel_values = pixel_values, max_length = 50)
    # One image goes in, so the first decoded string is its caption.
    generated_caption = processor.batch_decode(generated_ids, skip_special_tokens = True)[0]
    prompts.append(generated_caption)

In [14]:
# Embed the generated captions with the local MiniLM sentence-transformer.
st_model = SentenceTransformer('/kaggle/input/sentence-transformers-222/all-MiniLM-L6-v2')

# Encode all captions at once, then flatten to 1-D so that the values can be
# paired one-to-one with the flat `imgId_eId` index built further down.
sentence_vectors = st_model.encode(prompts)
prompt_embeddings = sentence_vectors.flatten()

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [15]:
# Sample submission, keyed by the 'imgId_eId' column. Only its index is used
# later, to validate the ids generated by this notebook.
sample_submission_path = comp_path / 'sample_submission.csv'
df_submission = pd.read_csv(sample_submission_path).set_index('imgId_eId')

In [16]:
# Number of values written per image.
EMBEDDING_LENGTH = 384
eIds = list(range(EMBEDDING_LENGTH))

# One row id per (image, embedding position) pair, image-major, formatted as
# '<imgId>_<eId>'.
imgId_eId = [
    f'{img_id}_{e_id}'
    for img_id in imgIds
    for e_id in eIds
]

# The generated ids must match the sample submission's ids, ignoring order.
expected_ids = sorted(df_submission.index)
assert sorted(imgId_eId) == expected_ids

In [17]:
# Single-column frame: one 'val' per '<imgId>_<eId>' row, with the index named
# 'imgId_eId'.
submission_index = pd.Index(imgId_eId, name='imgId_eId')
submission = pd.DataFrame({'val': prompt_embeddings}, index=submission_index)

In [None]:
# Write the submission (index included) to the working directory.
submission.to_csv(path_or_buf='submission.csv')