In [1]:
import os

import zipfile

import numpy as np

import json

import msgpack

import pandas as pd

from tqdm.auto import tqdm

In [2]:
# Root of the image-pipeline dataset output; the file-system loaders in this
# notebook read the `image/`, `clip/` and `data/` sub-trees beneath it.
ROOT_DIR = '../kcg-ml-image-pipeline/output/dataset/'

# Name of the dataset sub-folder looked up under each of those sub-trees.
DATASET = 'environmental'

In [3]:
# Destination folder for every artifact this notebook writes
# (samples.csv.gz, clip_vision.npy, clip_text.npz).
OUTPUT_DIR = './data/spmi/civitai/'

In [4]:
# Create the output folder up front; exist_ok=True keeps the cell safe to re-run.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# load prompts

In [5]:
# Parallel accumulators: index k of each list describes the same image.
# The loader cells below append to them, so re-run this cell before re-running a loader.
positive_prompts, negative_prompts, image_hashs = [], [], []

## from file system

In [22]:
# Collect (hash, positive prompt, negative prompt) for every image on disk that
# has both a CLIP file and a metadata file next to it.
# NOTE: this appends to the module-level lists; re-initialise them before
# re-running the cell, otherwise every image is recorded twice.
image_root = os.path.join(ROOT_DIR, 'image', DATASET)
clip_root = os.path.join(ROOT_DIR, 'clip', DATASET)
data_root = os.path.join(ROOT_DIR, 'data', DATASET)

for dname in os.listdir(image_root):

    image_dir = os.path.join(image_root, dname)

    # Stray files next to the per-batch folders would make os.listdir raise.
    if not os.path.isdir(image_dir):
        continue

    for file_name in os.listdir(image_dir):

        if not file_name.lower().endswith(('.jpg', '.png', '.jpeg', '.bmp')):
            continue

        file_name = os.path.splitext(file_name)[0]

        clip_path = os.path.join(clip_root, dname, f'{file_name}_clip.msgpack')
        data_path = os.path.join(data_root, dname, f'{file_name}_data.msgpack')

        # Only keep images whose CLIP vector and metadata are both available.
        if not (os.path.exists(clip_path) and os.path.exists(data_path)):
            continue

        # Context manager so the handle is closed immediately instead of being leaked.
        with open(data_path, 'rb') as fp:
            meta = msgpack.load(fp)

        image_hashs.append(meta['file_hash'])
        positive_prompts.append(meta['positive_prompt'])
        negative_prompts.append(meta['negative_prompt'])

## from zip

In [6]:
# Zip archives of generated images to ingest; both zip-reading cells iterate this list.
# Commented-out entries are archives that are currently excluded.
ZIP_PATHs = [
    # './generated/generated-1122.zip',
    
#     './generated/generated-1120.zip',
#     './generated/generated-1123.zip',
#     './generated/generated-1125.zip',
    
    './generated/generated-1126.zip',
]

In [7]:
# Collect (hash, positive prompt, negative prompt) for every image in the zip
# archives that also ships an embedding, a CLIP vector and a metadata file.
# NOTE: this appends to the module-level lists; re-initialise them before
# re-running the cell, otherwise every image is recorded twice.
for zip_path in tqdm(ZIP_PATHs, leave=False):

    # Context manager closes the archive once it has been scanned.
    with zipfile.ZipFile(zip_path) as f:

        files = set(f.namelist())

        for file_path in f.namelist():

            if not (file_path.startswith('generated/image/') and file_path.endswith('.jpg')):
                continue

            embedding_path = file_path.replace('/image/', '/embedding/').replace('.jpg', '.npz')
            clip_path = file_path.replace('/image/', '/clip/').replace('.jpg', '.npy')

            image_hash = os.path.splitext(os.path.split(file_path)[-1])[0]

            # Zip member names always use '/', so build the name directly rather
            # than with os.path.join (which would emit '\\' on Windows).
            meta_path = f'generated/meta/{image_hash}.json'

            # Skip images with any missing companion file instead of crashing on them.
            if embedding_path not in files or clip_path not in files or meta_path not in files:
                continue

            with f.open(meta_path) as meta_file:
                meta = json.load(meta_file)

            image_hashs.append(image_hash)
            positive_prompts.append(meta['positive_prompt'])
            negative_prompts.append(meta['negative_prompt'])

  0%|          | 0/1 [00:00<?, ?it/s]

# build table

In [8]:
# Number of image hashes kept per (positive, negative) prompt pair; pairs with
# fewer images are dropped when the sample table is built.
NUM_SAMPLES = 15

In [9]:
# One row per image: its hash and the prompt pair that produced it.
df = pd.DataFrame(
    list(zip(image_hashs, positive_prompts, negative_prompts)),
    columns=['image_hash', 'positive_prompt', 'negative_prompt'],
)

In [10]:
# How many images exist for each distinct (positive, negative) prompt pair.
counts = (
    df
    .groupby(by=['positive_prompt', 'negative_prompt'])
    .count()
)

In [11]:
# Histogram of images-per-prompt-pair: (distinct counts, number of pairs with that count).
np.unique(counts['image_hash'].to_numpy(), return_counts=True)

(array([15]), array([20000]))

In [12]:
# Wide table: one row per prompt pair, followed by its first NUM_SAMPLES image hashes.
# Prompt pairs with fewer than NUM_SAMPLES images are left out.
sample_columns = ['positive_prompt', 'negative_prompt'] + [
    'image_hash_{}'.format(k) for k in range(NUM_SAMPLES)
]

sample_rows = [
    prompt_pair + tuple(group['image_hash'])[:NUM_SAMPLES]
    for prompt_pair, group in df.groupby(['positive_prompt', 'negative_prompt'])
    if group.shape[0] >= NUM_SAMPLES
]

samples = pd.DataFrame(sample_rows, columns=sample_columns)

In [13]:
# Persist the sample table (gzip-compressed CSV, no index column).
samples_csv_path = os.path.join(OUTPUT_DIR, 'samples.csv.gz')
samples.to_csv(samples_csv_path, index=False)

# build dataset

In [14]:
# Map each image hash to its (row, column) slot in the sample table, where the
# column index counts only the image_hash_* columns.
hash_to_id = {
    image_hash: (row, col)
    for row, row_hashes in enumerate(samples.iloc[:, 2:].itertuples(index=False, name=None))
    for col, image_hash in enumerate(row_hashes)
}

In [15]:
# Zero-initialised buffer with one 768-long float32 vector per (prompt pair, image) slot.
# 768 is presumably the CLIP image-embedding width of the stored vectors — confirm against the files.
num_prompt_pairs, num_images_per_pair = samples.iloc[:, 2:].shape
clip_embs = np.zeros((num_prompt_pairs, num_images_per_pair, 768), dtype=np.float32)

## from file system

In [33]:
# Fill clip_embs from the on-disk msgpack files for every image that made it
# into the sample table.
data_root = os.path.join(ROOT_DIR, 'data', DATASET)
clip_root = os.path.join(ROOT_DIR, 'clip', DATASET)

DATA_SUFFIX = '_data.msgpack'

for dname in os.listdir(data_root):

    data_dir = os.path.join(data_root, dname)

    # Stray files next to the per-batch folders would make os.listdir raise.
    if not os.path.isdir(data_dir):
        continue

    for file_name in os.listdir(data_dir):

        # Ignore anything that is not a metadata file; previously such a file
        # produced a non-existent path and crashed the loop.
        if not file_name.endswith(DATA_SUFFIX):
            continue

        file_name = file_name[:-len(DATA_SUFFIX)]

        # Context managers so handles are closed instead of leaked.
        with open(os.path.join(data_dir, f'{file_name}_data.msgpack'), 'rb') as fp:
            meta = msgpack.load(fp)
        image_hash = meta['file_hash']

        if image_hash not in hash_to_id:
            continue

        with open(os.path.join(clip_root, dname, f'{file_name}_clip.msgpack'), 'rb') as fp:
            clip_meta = msgpack.load(fp)

        # Only the first entry of 'clip-feature-vector' is stored in the slot.
        clip_embs[hash_to_id[image_hash]] = np.array(clip_meta['clip-feature-vector'][0])

## from zip

In [16]:
# Fill clip_embs from the .npy members of the zip archives for every image
# that made it into the sample table.
for zip_path in tqdm(ZIP_PATHs, leave=False):

    # Context manager closes the archive once it has been scanned.
    with zipfile.ZipFile(zip_path) as f:

        for file_path in f.namelist():

            if not (file_path.startswith('generated/clip/') and file_path.endswith('.npy')):
                continue

            image_hash = os.path.splitext(os.path.split(file_path)[-1])[0]

            if image_hash not in hash_to_id:
                continue

            # Close each member handle after the array has been read.
            with f.open(file_path) as clip_file:
                clip_embs[hash_to_id[image_hash]] = np.load(clip_file)

  0%|          | 0/1 [00:00<?, ?it/s]

In [17]:
# Sanity check: number of slots whose smallest component is exactly 0
# (an untouched, all-zero slot would be counted here); expected to be 0.
np.sum(np.min(clip_embs, axis=-1) == 0)

0

## save

In [18]:
# Persist the image-side CLIP vectors.
clip_vision_path = os.path.join(OUTPUT_DIR, 'clip_vision.npy')
np.save(clip_vision_path, clip_embs)

# build prompt embeddings

In [19]:
import torch
from transformers import AutoTokenizer, AutoModel

In [20]:
# Hugging Face model id used for both the tokenizer and the text encoder.
MODEL_NAME = 'openai/clip-vit-large-patch14'
# Number of prompts encoded per forward pass.
BATCH_SIZE = 64
# Token length every prompt is truncated / padded to by the tokenizer call in `worker`.
MAX_LENGTH = 77

In [21]:
# Tokenizer matching MODEL_NAME; local_files_only=True asks transformers to use
# files already on disk rather than downloading them.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, local_files_only=True)

In [22]:
# Load the full model but keep only its `text_model` sub-module, moved to the
# GPU and switched to eval mode.
transformer = AutoModel.from_pretrained(MODEL_NAME, local_files_only=True).text_model.cuda().eval()

`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.


In [23]:
def worker(texts):
    """Encode one batch of prompts and return the text model's pooled output.

    Args:
        texts: list of prompt strings forming a single batch.

    Returns:
        numpy array with the `pooler_output` of the text model, one row per prompt.
    """

    # Every prompt is truncated / padded to exactly MAX_LENGTH tokens.
    batch_encoding = tokenizer(
        texts,
        truncation=True, max_length=MAX_LENGTH,
        return_overflowing_tokens=False, padding="max_length", return_tensors="pt"
    )

    # Follow whatever device the model lives on instead of hard-coding CUDA.
    device = next(transformer.parameters()).device
    tokens = batch_encoding["input_ids"].to(device)

    # Inference only: no autograd graph, and no per-layer hidden states since
    # only the pooled output is used.
    with torch.no_grad():
        clip_text_opt = transformer(input_ids=tokens, return_dict=True)

    return clip_text_opt.pooler_output.detach().cpu().numpy()

In [32]:
# Reload the persisted sample table and take the prompt columns from it.
# keep_default_na=False: an empty prompt is written as an empty CSV field and
# would otherwise be read back as float NaN (as would literal "NA"/"null"),
# which the tokenizer cannot encode. dtype=str keeps every column textual.
samples = pd.read_csv(
    os.path.join(OUTPUT_DIR, 'samples.csv.gz'),
    dtype=str,
    keep_default_na=False,
)

positive_prompts = samples['positive_prompt'].tolist()
negative_prompts = samples['negative_prompt'].tolist()

In [33]:
# Encode the positive and then the negative prompts in batches of BATCH_SIZE,
# and stack the per-batch results into one array per prompt kind.
with torch.no_grad():

    positive_batches = [
        worker(positive_prompts[start:start + BATCH_SIZE])
        for start in tqdm(range(0, len(positive_prompts), BATCH_SIZE), leave=False)
    ]

    negative_batches = [
        worker(negative_prompts[start:start + BATCH_SIZE])
        for start in tqdm(range(0, len(negative_prompts), BATCH_SIZE), leave=False)
    ]

positive_pooler_outputs = np.concatenate(positive_batches, axis=0)
negative_pooler_outputs = np.concatenate(negative_batches, axis=0)

  0%|          | 0/119 [00:00<?, ?it/s]

  0%|          | 0/119 [00:00<?, ?it/s]

In [34]:
# Persist both text-embedding arrays in a single .npz archive, keyed by name.
clip_text_path = os.path.join(OUTPUT_DIR, 'clip_text.npz')
np.savez(
    clip_text_path,
    positive_pooler_outputs=positive_pooler_outputs,
    negative_pooler_outputs=negative_pooler_outputs,
)