In [9]:
import os
import json
from google.colab import drive

!pip install kaggle

drive.mount('/content/gdrive')

f = open("/content/gdrive/My Drive/Colab Notebooks/Kaggle/kaggle.json", 'r')
json_data = json.load(f) 
os.environ['KAGGLE_USERNAME'] = json_data['username']
os.environ['KAGGLE_KEY'] = json_data['key']

!kaggle competitions download -c stable-diffusion-image-to-prompts
!mkdir /content/input
!mv /content/stable-diffusion-image-to-prompts.zip /content/input/stable-diffusion-image-to-prompts.zip
!unzip /content/input/stable-diffusion-image-to-prompts.zip -d /content/input/

Archive:  /content/input/stable-diffusion-image-to-prompts.zip
  inflating: /content/input/images/20057f34d.png  
  inflating: /content/input/images/227ef0887.png  
  inflating: /content/input/images/92e911621.png  
  inflating: /content/input/images/a4e1c55a9.png  
  inflating: /content/input/images/c98f79f71.png  
  inflating: /content/input/images/d8edf2e40.png  
  inflating: /content/input/images/f27825b2c.png  
  inflating: /content/input/prompts.csv  
  inflating: /content/input/sample_submission.csv  


# Library

In [None]:
import os
import random
import numpy as np
import pandas as pd
from PIL import Image
from tqdm.notebook import tqdm
from scipy import spatial
from sklearn.model_selection import train_test_split
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import transforms
import timm
from timm.utils import AverageMeter
import sys
sys.path.append('../input/sentence-transformers-222/sentence-transformers')
from sentence_transformers import SentenceTransformer
import warnings
warnings.filterwarnings('ignore')

# Config

In [None]:
class CFG:
    """Hyper-parameters for the run, read elsewhere as ``CFG.<name>``."""

    model_name = 'vit_base_patch16_224'  # timm architecture id; also the checkpoint file stem
    input_size = 224                     # size handed to transforms.Resize
    batch_size = 64                      # samples per batch (train and val)
    num_epochs = 6                       # full passes over the training split
    lr = 1e-4                            # initial AdamW learning rate, cosine-annealed per step
    seed = 42                            # RNG seed for seeding and the train/val split

In [None]:
def seed_everything(seed):
    """Seed every RNG used by the pipeline so runs are repeatable.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU plus all CUDA
    devices), exports PYTHONHASHSEED, and puts cuDNN into deterministic mode.

    Args:
        seed: integer seed applied to every generator.
    """
    # Only affects Python processes started after this call; the hash seed of
    # the running interpreter is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Silently ignored when CUDA is unavailable; seeds every visible GPU otherwise.
    torch.cuda.manual_seed_all(seed)
    # deterministic=True alone is not sufficient: with benchmark=True cuDNN
    # auto-tunes and may select different algorithms from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(seed=CFG.seed)  # seed once, before any model or DataLoader is built

# Dataset

In [None]:
class DiffusionDataset(Dataset):
    """Map-style dataset yielding ``(image, prompt)`` pairs.

    Args:
        df: DataFrame with a 'filepath' column (path to an image file) and a
            'prompt' column (the text that is later embedded as the target).
        transform: callable applied to the RGB PIL image, e.g. a torchvision
            ``Compose`` ending in ToTensor/Normalize.
    """

    def __init__(self, df, transform):
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # Force 3-channel RGB. PNGs can be RGBA, palette or greyscale, which would
        # give a channel count that does not match the 3-value Normalize. The
        # context manager closes the file once the converted copy exists.
        with Image.open(row['filepath']) as img:
            image = img.convert('RGB')
        image = self.transform(image)
        prompt = row['prompt']
        return image, prompt


class DiffusionCollator:
    """DataLoader ``collate_fn`` that batches images and embeds their prompts.

    Each batch item is an ``(image_tensor, prompt)`` pair, as produced by
    DiffusionDataset. Images are stacked into one tensor. Prompts are encoded
    by a SentenceTransformer into the embedding tensor used as the target.

    Args:
        model_path: SentenceTransformer model directory or name. Defaults to
            the all-MiniLM-L6-v2 copy in the Kaggle input dataset.
        device: device the sentence encoder runs on (default 'cpu'). The
            collate_fn runs inside DataLoader worker processes when
            num_workers > 0, so CPU avoids initialising CUDA there.
    """

    def __init__(
        self,
        model_path='/kaggle/input/sentence-transformers-222/all-MiniLM-L6-v2',
        device='cpu',
    ):
        self.st_model = SentenceTransformer(model_path, device=device)

    def __call__(self, batch):
        """Return ``(stacked_images, prompt_embeddings)`` for one batch."""
        images, prompts = zip(*batch)
        images = torch.stack(images)
        prompt_embeddings = self.st_model.encode(
            list(prompts),
            show_progress_bar=False,
            convert_to_tensor=True
        )
        return images, prompt_embeddings
    
    
def get_dataloaders(
    trn_df,
    val_df,
    input_size,
    batch_size,
    num_workers=2
):
    """Build the train/val DataLoaders for image -> prompt-embedding regression.

    Args:
        trn_df: training DataFrame with 'filepath' and 'prompt' columns.
        val_df: validation DataFrame with the same columns.
        input_size: images are resized to ``input_size x input_size``.
        batch_size: samples per batch for both loaders.
        num_workers: worker processes per loader (default 2).

    Returns:
        dict with 'train' (shuffled, last partial batch dropped) and
        'val' (in order, every sample kept).
    """
    transform = transforms.Compose([
        # An (h, w) tuple gives an exact square. An int would only set the shorter
        # edge and keep the aspect ratio. Non-square images would then produce
        # tensors that torch.stack cannot batch and that a fixed-size ViT rejects.
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        # NOTE(review): ImageNet mean/std. Some timm ViT weights were trained with
        # mean=std=0.5; confirm against model.pretrained_cfg.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    trn_dataset = DiffusionDataset(trn_df, transform)
    val_dataset = DiffusionDataset(val_df, transform)
    # A single collator (and so a single sentence encoder) is shared by both loaders.
    collator = DiffusionCollator()

    def _make_loader(dataset, is_train):
        # Train: shuffle and drop the ragged last batch. Val: keep order and every sample.
        return DataLoader(
            dataset=dataset,
            shuffle=is_train,
            batch_size=batch_size,
            pin_memory=True,
            num_workers=num_workers,
            drop_last=is_train,
            collate_fn=collator
        )

    return {
        'train': _make_loader(trn_dataset, True),
        'val': _make_loader(val_dataset, False),
    }

# Train

In [None]:
def cosine_similarity(y_trues, y_preds):
    """Mean row-wise cosine similarity between two equally long sequences of vectors.

    Each pair contributes ``1 - scipy.spatial.distance.cosine(a, b)``. The result
    is the unweighted NumPy mean over all pairs.
    """
    pair_scores = list(
        map(lambda pair: 1 - spatial.distance.cosine(pair[0], pair[1]),
            zip(y_trues, y_preds))
    )
    return np.mean(pair_scores)

In [None]:
def _run_epoch(model, loader, criterion, device, optimizer=None, scheduler=None):
    """Make one pass over `loader` and return (mean loss, mean cosine similarity).

    If `optimizer` is given, the model is trained: forward, backward, then an
    optimizer step and a per-batch scheduler step. Otherwise the model is
    evaluated under no_grad. Both metrics are sample-weighted averages over batches.
    """
    is_train = optimizer is not None
    meters = {
        'loss': AverageMeter(),
        'cos': AverageMeter(),
    }
    model.train(is_train)  # train(False) is what model.eval() does
    for X, y in tqdm(loader, leave=False):
        X, y = X.to(device), y.to(device)

        if is_train:
            optimizer.zero_grad()
        with torch.set_grad_enabled(is_train):
            X_out = model(X)
            # Target +1 for every pair. CosineEmbeddingLoss then averages
            # 1 - cos(X_out, y), pulling predictions toward the prompt embedding.
            target = torch.ones(X.size(0), device=device)
            loss = criterion(X_out, y, target)
        if is_train:
            loss.backward()
            optimizer.step()
            if scheduler is not None:
                scheduler.step()  # per batch: T_max is counted in iterations

        batch_cos = cosine_similarity(
            X_out.detach().cpu().numpy(),
            y.detach().cpu().numpy()
        )
        meters['loss'].update(loss.item(), n=X.size(0))
        meters['cos'].update(batch_cos, n=X.size(0))
    return meters['loss'].avg, meters['cos'].avg


def train(
    trn_df,
    val_df,
    model_name,
    input_size,
    batch_size,
    num_epochs,
    lr
):
    """Fine-tune a timm image model to regress prompt sentence embeddings.

    Args:
        trn_df: training DataFrame ('filepath', 'prompt').
        val_df: validation DataFrame ('filepath', 'prompt').
        model_name: timm model id. The best checkpoint is written to f'{model_name}.pth'.
        input_size: image side length fed to the model.
        batch_size: samples per batch.
        num_epochs: number of train+validate epochs.
        lr: initial AdamW learning rate, cosine-annealed to 1e-6 over all iterations.

    Returns:
        The best validation cosine similarity seen, or -1.0 if no epoch ran.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataloaders = get_dataloaders(trn_df, val_df, input_size, batch_size)

    # 384-d regression head. It must equal the sentence-embedding size the
    # collator produces, since CosineEmbeddingLoss compares the two directly.
    model = timm.create_model(model_name, pretrained=True, num_classes=384)
    model.set_grad_checkpointing()  # trade recomputation for activation memory
    model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    # The scheduler is stepped once per training batch, so T_max spans every
    # iteration of every epoch.
    ttl_iters = num_epochs * len(dataloaders['train'])
    scheduler = CosineAnnealingLR(optimizer, T_max=ttl_iters, eta_min=1e-6)
    criterion = nn.CosineEmbeddingLoss()

    best_score = -1.0
    for epoch in range(num_epochs):
        trn_loss, trn_cos = _run_epoch(
            model, dataloaders['train'], criterion, device, optimizer, scheduler
        )
        print('Epoch {:d} / trn/loss={:.4f}, trn/cos={:.4f}'.format(
            epoch + 1,
            trn_loss,
            trn_cos))

        val_loss, val_cos = _run_epoch(model, dataloaders['val'], criterion, device)
        print('Epoch {:d} / val/loss={:.4f}, val/cos={:.4f}'.format(
            epoch + 1,
            val_loss,
            val_cos))

        # Keep only the weights from the epoch with the best validation cosine.
        if val_cos > best_score:
            best_score = val_cos
            torch.save(model.state_dict(), f'{model_name}.pth')
    return best_score

In [None]:
df = pd.read_csv('/kaggle/input/diffusiondb-data-cleansing/diffusiondb.csv')
# Fail fast, before the long training run, if the CSV lacks the columns that
# DiffusionDataset reads.
missing_cols = {'filepath', 'prompt'} - set(df.columns)
assert not missing_cols, f'diffusiondb.csv is missing columns: {sorted(missing_cols)}'
# 90/10 split, seeded by CFG.seed so it is identical across runs.
trn_df, val_df = train_test_split(df, test_size=0.1, random_state=CFG.seed)
assert len(trn_df) + len(val_df) == len(df)

In [None]:
# Smoke-test override: uncomment to train and validate on only a few rows.
# trn_df = df[:2]
# val_df = df[2:4]
# val_df

In [None]:
train(
    trn_df=trn_df,
    val_df=val_df,
    model_name=CFG.model_name,
    input_size=CFG.input_size,
    batch_size=CFG.batch_size,
    num_epochs=CFG.num_epochs,
    lr=CFG.lr,
)

  0%|          | 0/2170 [00:00<?, ?it/s]

Epoch 1 / trn/loss=0.5024, trn/cos=0.4976


  0%|          | 0/242 [00:00<?, ?it/s]

Epoch 1 / val/loss=0.4674, val/cos=0.5326


  0%|          | 0/2170 [00:00<?, ?it/s]

Epoch 2 / trn/loss=0.4364, trn/cos=0.5636


  0%|          | 0/242 [00:00<?, ?it/s]

Epoch 2 / val/loss=0.4468, val/cos=0.5532


  0%|          | 0/2170 [00:00<?, ?it/s]

Epoch 3 / trn/loss=0.3949, trn/cos=0.6051


  0%|          | 0/242 [00:00<?, ?it/s]

Epoch 3 / val/loss=0.4381, val/cos=0.5619


  0%|          | 0/2170 [00:00<?, ?it/s]

Epoch 4 / trn/loss=0.3519, trn/cos=0.6481


  0%|          | 0/242 [00:00<?, ?it/s]

Epoch 4 / val/loss=0.4394, val/cos=0.5606


  0%|          | 0/2170 [00:00<?, ?it/s]

Epoch 5 / trn/loss=0.3116, trn/cos=0.6884


  0%|          | 0/242 [00:00<?, ?it/s]

Epoch 5 / val/loss=0.4456, val/cos=0.5544


  0%|          | 0/2170 [00:00<?, ?it/s]

Epoch 6 / trn/loss=0.2856, trn/cos=0.7144


  0%|          | 0/242 [00:00<?, ?it/s]

Epoch 6 / val/loss=0.4529, val/cos=0.5471


In [None]:
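# Reload the best checkpoint written by train() and re-run validation.
# NOTE: `dataloaders`, `device` and `criterion` are locals of train(); recreate
# them here (and call model.to(device)) before uncommenting.
# model = timm.create_model(
#     CFG.model_name,
#     pretrained=True,
#     num_classes=384
# )
# model.load_state_dict(torch.load(f'{CFG.model_name}.pth'))

# model.eval()
# for X, y in tqdm(dataloaders['val'], leave=False):
#     X, y = X.to(device), y.to(device)

#     with torch.no_grad():
#         X_out = model(X)
#         target = torch.ones(X.size(0)).to(device)
#         loss = criterion(X_out, y, target)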
