In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
#for dirname, _, filenames in os.walk('/kaggle/input'):
#    for filename in filenames:
#        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image
from tqdm.notebook import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import timm

In [None]:
class CFG:
    """Inference settings for the image -> prompt-embedding regressor.

    `model_name` must be the timm architecture the checkpoint at `model_path`
    was trained with. A previously tried alternative pairing was
    'vit_base_patch16_224' with the stable-diffusion-vit-baseline-train
    checkpoint.
    """
    # timm architecture identifier; must match the checkpoint below.
    model_name: str = 'swin_large_patch4_window7_224'
    # Fine-tuned weights, loaded as a state_dict by predict().
    model_path: str = '/kaggle/input/swin-large-finetune-stablediffusion-textimage-pair/swin_large_patch4_window7_224_15_epochs.pth'
    # Resolution (pixels) images are resized to before inference.
    input_size: int = 224
    # Number of images per forward pass.
    batch_size: int = 64

In [None]:
class DiffusionTestDataset(Dataset):
    """Map-style dataset that yields transformed test images (no labels).

    Args:
        images: sequence of image file paths; item ``i`` is read from ``images[i]``.
        transform: callable applied to each RGB PIL image (e.g. a torchvision
            ``transforms.Compose``); its return value is what ``__getitem__`` yields.
    """

    def __init__(self, images, transform):
        self.images = images
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Use a context manager so the file handle is released promptly, and force
        # 3-channel RGB: PNGs can be RGBA / grayscale / palette, which would give a
        # tensor whose channel count does not match the 3-channel Normalize and the
        # pretrained model's expected input.
        with Image.open(self.images[idx]) as img:
            image = img.convert('RGB')
        return self.transform(image)

In [None]:
def predict(
    images,
    model_path,
    model_name,
    input_size,
    batch_size,
    num_classes=384,
    num_workers=2,
):
    """Predict a flattened prompt-embedding vector for every image.

    Builds the timm model `model_name`, loads the state_dict at `model_path`,
    and scores `images` with horizontal-flip test-time augmentation: every
    batch is run once as-is and once mirrored left-right, and the two outputs
    are averaged. The flip is applied explicitly (not via RandomHorizontalFlip),
    so results are deterministic and equal the expectation of the random-flip
    average.

    Args:
        images: sequence of image file paths, in the desired output order.
        model_path: path to a state_dict compatible with `model_name`.
        model_name: timm model identifier.
        input_size: side length in pixels; images are resized to a square.
        batch_size: images per forward pass.
        num_classes: width of the model head, i.e. embedding length (default 384).
        num_workers: DataLoader worker processes.

    Returns:
        1-D numpy array of length ``len(images) * num_classes``: the
        (n_images, num_classes) model output flattened row-major, i.e. all
        dimensions of image 0, then image 1, and so on. Empty if `images` is empty.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # A (size, size) tuple forces a square output even for non-square images;
    # an int would only fix the shorter side, breaking fixed-resolution models
    # (e.g. Swin) and batch collation.
    transform = transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    dataset = DiffusionTestDataset(images, transform)
    dataloader = DataLoader(
        dataset=dataset,
        shuffle=False,  # output order must follow `images`
        batch_size=batch_size,
        pin_memory=device.type == 'cuda',  # pinning only helps host->GPU copies
        num_workers=num_workers,
        drop_last=False,
    )

    model = timm.create_model(
        model_name,
        pretrained=False,
        num_classes=num_classes,
    )
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    preds = []
    with torch.no_grad():
        for X in tqdm(dataloader, leave=False):
            X = X.to(device)
            # Flip-TTA in a single pass: average the original view and its mirror
            # along the last (width) dimension of the image tensor.
            X_out = (model(X) + model(torch.flip(X, dims=[-1]))) / 2
            preds.append(X_out.cpu().numpy())

    if not preds:
        return np.empty(0, dtype=np.float32)
    return np.vstack(preds).flatten()

In [None]:
# Predict an embedding per test image and write one CSV row per
# (image, embedding dimension) pair.
EMBEDDING_LENGTH = 384  # must equal the num_classes=384 model head built in predict()
# Global multiplier applied to every predicted value before writing.
# NOTE(review): a positive constant scale cannot change the cosine similarity
# between two vectors; if the metric is per-image cosine similarity, this factor
# has no effect on the score. Confirm before tuning it.
EMBEDDING_SCALE = 0.45

# Sorted so the submission row order is reproducible across runs/filesystems.
images = sorted(Path('/kaggle/input/stable-diffusion-image-to-prompts/images').glob('*.png'))
imgIds = [i.stem for i in images]
# Image-major ids ("<imgId>_0" ... "<imgId>_383", then the next image), matching
# the row-major flattening of predict()'s per-image outputs.
imgId_eId = [
    f'{img_id}_{e_id}'
    for img_id in imgIds
    for e_id in range(EMBEDDING_LENGTH)
]

prompt_embeddings = predict(images, CFG.model_path, CFG.model_name, CFG.input_size, CFG.batch_size)
assert prompt_embeddings.shape == (len(imgId_eId),), (
    f'expected {len(imgId_eId)} predicted values, got shape {prompt_embeddings.shape}'
)
prompt_embeddings = EMBEDDING_SCALE * prompt_embeddings
submission = pd.DataFrame(
    index=imgId_eId,
    data=prompt_embeddings,
    columns=['val']
).rename_axis('imgId_eId')
submission.to_csv('submission.csv')

In [None]:
# Spot-check the first rows of the written submission.
submission.head(n=5)

In [None]:
#prompt_embeddings.shape

# Sentence Transformer

In [None]:
"""
import sys
import os
from pathlib import Path
sys.path.append('../input/sentence-transformers-222/sentence-transformers')
from sentence_transformers import SentenceTransformer, models
sentence_model_path = "/kaggle/input/sentence-transformers-222/all-MiniLM-L6-v2"
st_model = SentenceTransformer(sentence_model_path)
"""

# BLIP

In [None]:
"""
transform = transforms.Compose([
        transforms.Resize(224),
        transforms.RandomHorizontalFlip(p=0.5),
#         transforms.RandomRotation(degrees=10),

        # transforms.RandomVerticalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

dataset = DiffusionTestDataset(images, transform)
dataloader = DataLoader(
        dataset=dataset,
        shuffle=False,
        batch_size=64,
        pin_memory=True,
        num_workers=2,
        drop_last=False
    )
"""

In [None]:
#import matplotlib.pyplot as plt
#img = Image.open("/kaggle/input/stable-diffusion-image-to-prompts/images/20057f34d.png")
#plt.imshow(img)

In [None]:
"""
from PIL import Image
import requests
from transformers import Blip2Processor, Blip2ForConditionalGeneration
import torch

#device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"

processor = Blip2Processor.from_pretrained("/kaggle/input/image-caption-models/blip2-opt-2.7b")
"""

In [None]:
#model = Blip2ForConditionalGeneration.from_pretrained("/kaggle/input/image-caption-models/blip2-opt-2.7b")

In [None]:
"""
for X in tqdm(dataloader, leave=False):
    #inputs = processor(images=image, return_tensors="pt").to(device)
    generated_ids = model.generate(X, min_length=20, max_length = 100)
    #generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    prompts = [output.strip() for output in processor.batch_decode(generated_ids, skip_special_tokens=True)]
    #prompts.append(generated_text)
    #out = model.generate(X)
    #for emb in out:
    #    prompts.append(processor.decode(emb, skip_special_tokens=True))

prompts
"""

In [None]:
#blip_embeddings = st_model.encode(prompts).flatten()

In [None]:
#blip_embeddings.shape

# Ensemble

In [None]:
#ratio_swin_large         = 0.5
#ratio_blip2 = 0.5

In [None]:
#final_embeddings = ratio_swin_large*prompt_embeddings + ratio_blip2*blip_embeddings

In [None]:
"""
submission = pd.DataFrame(
    index=imgId_eId,
    data=final_embeddings,
    columns=['val']
).rename_axis('imgId_eId')
"""

In [None]:
#submission.to_csv('submission.csv')

In [None]:
#submission.head()