In [None]:
# Root folder holding the 'clip' feature files and the 'ranking' json files.
ROOT = '../dataset/'

# Which representation to train on: 'image_embeds', 'text_embeds', 'pos_embeds' or 'neg_embeds'.
EMB_TYPE = 'text_embeds'

# Other datasets: environmental, character, icons, waifu, propaganda-poster
DATASET = 'mech'

In [None]:
import argparse

# Command-line overrides for the config cell above; defaults come from that cell.
parser = argparse.ArgumentParser()

parser.add_argument("--EMB_TYPE", default=EMB_TYPE, type=str, help="EMB_TYPE")
parser.add_argument("--DATASET", default=DATASET, type=str, help="DATASET")

# parse_known_args ignores arguments this script does not define (such as the
# `-f <connection file>` Jupyter passes to the kernel), so the same cell works in a
# notebook and in an exported script without a bare `except` hiding real errors.
# argparse still raises SystemExit for `--help` or malformed values; in that case
# keep the defaults from the config cell, as before.
try:
    args, _unknown_args = parser.parse_known_args()
except SystemExit:
    pass
else:
    EMB_TYPE = args.EMB_TYPE
    DATASET = args.DATASET

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import sys
import json
import glob
import torch

from io import BytesIO

import numpy as np
import pandas as pd

from PIL import Image
from matplotlib import pyplot

import msgpack

from tqdm.auto import tqdm

from sklearn.model_selection import train_test_split
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
# Where the trained scorer is saved: one file per dataset and embedding type.
WEIGHT_PATH = os.path.join('weight/004', DATASET, f'clip_{EMB_TYPE}.pt')

# Per-image metadata (file paths and prompts) for the selected dataset.
DATA_PATH = f'data/{DATASET}/data.json'

In [None]:
from transformers import CLIPTextModel, CLIPTokenizer

# Hub id of the CLIP checkpoint. Not referenced elsewhere in this notebook; the
# weights below come from local directories (presumably exported from this id - confirm).
MODEL_NAME = 'openai/clip-vit-large-patch14'

# local_files_only=True: load strictly from disk, never from the network.
tokenizer = CLIPTokenizer.from_pretrained(
    './input/model/clip/txt_emb_tokenizer',
    local_files_only=True,
)
transformer = CLIPTextModel.from_pretrained(
    './input/model/clip/txt_emb_model',
    local_files_only=True,
)
# Inference only: move to the GPU and switch to eval mode.
transformer = transformer.cuda().eval()

def get_prompt_embeds(texts, batch_size):
    """Encode prompts with the CLIP text model and return their pooled embeddings.

    Args:
        texts: sequence of prompt strings (list or tuple), processed in slices of
            ``batch_size``.
        batch_size: number of prompts per forward pass; must be >= 1.

    Returns:
        np.ndarray with one row of ``pooler_output`` per prompt, in input order.
        An empty ``texts`` yields an array of shape (0, hidden_size).

    Raises:
        ValueError: if ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')

    def worker(batch):
        # Pad/truncate every prompt to 77 tokens so each batch is a fixed-shape tensor.
        batch_encoding = tokenizer(
            list(batch),
            truncation=True, max_length=77,
            return_overflowing_tokens=False, padding="max_length", return_tensors="pt"
        )
        tokens = batch_encoding["input_ids"].cuda()

        with torch.no_grad():
            clip_text_opt = transformer(input_ids=tokens, return_dict=True)

        # Only the pooled output is used, so hidden states are not requested.
        return clip_text_opt.pooler_output.cpu().numpy()

    pooler_outputs = [
        worker(texts[i:i + batch_size])
        for i in tqdm(range(0, len(texts), batch_size), leave=False)
    ]
    if not pooler_outputs:
        # np.concatenate([]) raises; return an empty, correctly shaped array instead.
        return np.empty((0, transformer.config.hidden_size), dtype=np.float32)

    return np.concatenate(pooler_outputs, axis=0)

# Load embeddings

In [None]:
# Load per-image metadata and collect the raw inputs for the chosen EMB_TYPE.
# Each record read below carries 'file_path', 'positive_prompt' and 'negative_prompt';
# the dict keys of data.json themselves are not used.
VALID_EMB_TYPES = ('image_embeds', 'text_embeds', 'pos_embeds', 'neg_embeds')
if EMB_TYPE not in VALID_EMB_TYPES:
    # Fail fast: an unknown type would otherwise leave sample_embeds empty and break much later.
    raise ValueError(f'EMB_TYPE must be one of {VALID_EMB_TYPES}, got {EMB_TYPE!r}')

with open(DATA_PATH) as f:
    js = json.load(f)

file_paths = list()
sample_embeds = list()

for info in tqdm(js.values(), total=len(js), leave=False):

    # Sample key: the file path cut at its first '_' with any extension removed.
    # Ranking files are later matched against these keys through path_to_index.
    file_path = os.path.splitext(info['file_path'].split('_')[0])[0]
    file_paths.append(file_path)

    if EMB_TYPE == 'image_embeds':
        # Precomputed CLIP image feature read from ROOT/clip/{key}_clip.msgpack.
        path = os.path.join(ROOT, 'clip', f'{file_path}_clip.msgpack')
        with open(path, 'rb') as f:
            mp = msgpack.load(f)
        sample_embeds.append(np.array(mp['clip-feature-vector']))
    elif EMB_TYPE == 'text_embeds':
        sample_embeds.append((info['positive_prompt'], info['negative_prompt']))
    elif EMB_TYPE == 'pos_embeds':
        sample_embeds.append(info['positive_prompt'])
    else:  # 'neg_embeds' (validated above)
        sample_embeds.append(info['negative_prompt'])

file_paths = np.array(file_paths)
# Row index of each sample in sample_embeds; if a key repeats, the last occurrence wins.
path_to_index = {file_path: i for i, file_path in enumerate(file_paths)}

In [None]:
# Turn the per-sample raw inputs collected above into one embedding matrix (one row per sample).
# NOTE: this rebinds sample_embeds, so re-run the loading cell before re-running this one.
PROMPT_BATCH_SIZE = 1024

if EMB_TYPE == 'image_embeds':
    # Whether each msgpack vector is (dim,) or (1, dim) is not visible here. vstack matches
    # concatenate(axis=0) for (1, dim) rows and also stacks plain (dim,) vectors as rows,
    # where concatenate would flatten them into one long vector.
    sample_embeds = np.vstack(sample_embeds)
elif EMB_TYPE == 'text_embeds':
    pos_prompts, neg_prompts = zip(*sample_embeds)
    # Positive and negative prompt embeddings side by side along the feature axis.
    sample_embeds = np.concatenate(
        [
            get_prompt_embeds(pos_prompts, PROMPT_BATCH_SIZE),
            get_prompt_embeds(neg_prompts, PROMPT_BATCH_SIZE),
        ],
        axis=1,
    )
elif EMB_TYPE in ('pos_embeds', 'neg_embeds'):
    sample_embeds = get_prompt_embeds(sample_embeds, PROMPT_BATCH_SIZE)
else:
    raise ValueError(f'Unknown EMB_TYPE: {EMB_TYPE!r}')

# Load ranking data

In [None]:
def _rank_key(image_metadata):
    """Map an image entry of a ranking file to the sample key used in path_to_index.

    Drops the extension and a *leading* 'datasets/' only; str.replace would also
    strip 'datasets/' occurring later in the path.
    """
    key = os.path.splitext(image_metadata['file_path'])[0]
    prefix = 'datasets/'
    return key[len(prefix):] if key.startswith(prefix) else key


paths = sorted(glob.glob(os.path.join(ROOT, 'ranking', DATASET, '*.json')))

# rank_file_paths and rank_pairs stay parallel: entry i of each describes the same comparison.
rank_file_paths = list()
rank_pairs = list()

for path in tqdm(paths):
    with open(path) as f:
        js = json.load(f)

    # Only 'selection' comparisons are used.
    if js['task'] != 'selection':
        continue

    file_path_1 = _rank_key(js['image_1_metadata'])
    file_path_2 = _rank_key(js['image_2_metadata'])

    # Skip comparisons involving an image with no loaded embedding.
    if (file_path_1 not in path_to_index) or (file_path_2 not in path_to_index):
        continue
    rank_file_paths.append(path)
    rank_pairs.append((file_path_1, file_path_2, js['selected_image_index']))

# build dataset

In [None]:
# One row per kept comparison, in file order.
RANK_COLUMNS = ['image_1', 'image_2', 'selected_image_index']
rank_pairs = pd.DataFrame(rank_pairs, columns=RANK_COLUMNS)

In [None]:
# Put the preferred image first in every pair: image_1 = chosen, image_2 = rejected
# (selected_image_index == 0 means the first image of the comparison was chosen).
ordered_pairs = pd.DataFrame(
    [
        (image_1, image_2) if selected_image_index == 0 else (image_2, image_1)
        for image_1, image_2, selected_image_index in rank_pairs.itertuples(index=False, name=None)
    ],
    columns=['image_1', 'image_2'],
)

# Row of each image in sample_embeds (pairs with unknown paths were dropped while loading).
ordered_pairs['index_1'] = ordered_pairs['image_1'].apply(path_to_index.get)
ordered_pairs['index_2'] = ordered_pairs['image_2'].apply(path_to_index.get)

# Source ranking file of each pair (rank_file_paths is row-aligned with rank_pairs), as a
# 'datasets/...' object path. Building it from the basename does not depend on how ROOT
# is spelled; the old f'{ROOT}ranking/{DATASET}/' replace silently did nothing when ROOT
# had no trailing '/' or the OS path separator was not '/'.
ordered_pairs['file_path'] = [
    f'datasets/{DATASET}/data/ranking/aggregate/{os.path.basename(p)}' for p in rank_file_paths
]

## Build features

In [None]:
# Hold out 20% of the ranked pairs for validation; the fixed seed keeps the split reproducible.
train_indices, val_indices = train_test_split(
    ordered_pairs.index,
    random_state=42,
    test_size=0.2,
)

In [None]:
def _gather_pair_features(pair_rows):
    """Stack the embeddings of each (preferred, rejected) pair.

    Args:
        pair_rows: index labels of ordered_pairs to gather.

    Returns:
        Array of shape (len(pair_rows), dim, 2): [..., 0] is the preferred image,
        [..., 1] the rejected one. Empty selections give shape (0, dim, 2).
    """
    index_1 = ordered_pairs.loc[pair_rows, 'index_1'].to_numpy()
    index_2 = ordered_pairs.loc[pair_rows, 'index_2'].to_numpy()
    # Vectorized gather; same result as stacking pair by pair in a Python loop.
    return np.stack([sample_embeds[index_1], sample_embeds[index_2]], axis=-1)


train_data = _gather_pair_features(train_indices)
val_data = _gather_pair_features(val_indices)

train_data.shape, val_data.shape

In [None]:
# Move the pair tensors to the GPU as float32 once, up front; training below is full-batch.
train_dataset = torch.as_tensor(train_data, dtype=torch.float32, device='cuda')
val_dataset = torch.as_tensor(val_data, dtype=torch.float32, device='cuda')

# build model

In [None]:
# Seed torch so the layer's random initialisation, and hence the trained scorer, is reproducible.
torch.manual_seed(42)

# Linear scorer: maps one embedding (train_data.shape[1] features) to a single scalar score.
model = torch.nn.Linear(train_data.shape[1], 1, bias=True).cuda()

# train model

In [None]:
LR = 1e-3
WEIGHT_DECAY = 1e-3  # Adam's L2-style decay, on top of the explicit L1 penalty in the loss

optimizer = torch.optim.Adam(params=model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

In [None]:
# Full-batch training. Each pair holds (preferred, rejected) along the last axis, so a
# 2-way cross-entropy with target class 0 equals a logistic loss on
# score(preferred) - score(rejected).
N_EPOCHS = 1000
EVAL_EVERY = 100  # report running train metrics and validation metrics every N epochs
L1_COEF = 1e-3    # weight of the L1 penalty on the linear weights


def _pair_logits(pairs):
    """Score both members of each pair; returns (num_pairs, 2) logits, column 0 = preferred."""
    return torch.cat([model(pairs[..., 0]), model(pairs[..., 1])], dim=-1)


def _pair_accuracy(logits):
    """Fraction of pairs where the preferred image scores strictly higher (ties count as wrong)."""
    return (logits[:, 0] > logits[:, 1]).float().mean()


# The preferred image is always in slot 0, so the targets are constant: build them once.
train_labels = torch.zeros(train_dataset.shape[0], dtype=torch.long, device='cuda')
val_labels = torch.zeros(val_dataset.shape[0], dtype=torch.long, device='cuda')

bces, accs = list(), list()

for epoch in tqdm(range(N_EPOCHS)):

    model.train()
    optimizer.zero_grad()

    logits = _pair_logits(train_dataset)
    bce = torch.nn.functional.cross_entropy(logits, train_labels)
    acc = _pair_accuracy(logits)

    loss = bce + L1_COEF * torch.norm(model.weight, p=1)
    loss.backward()
    optimizer.step()

    # Running values since the last report.
    bces.append(bce.item())
    accs.append(acc.item())

    if (epoch + 1) % EVAL_EVERY == 0:

        model.eval()

        # Validation runs in fp32 like training (it used fp16 autocast before), and uses
        # the same strict-inequality accuracy as training (argmax counted ties as correct).
        with torch.no_grad():
            val_logits = _pair_logits(val_dataset)
            val_bce = torch.nn.functional.cross_entropy(val_logits, val_labels).item()
            val_acc = _pair_accuracy(val_logits).item()

        print(f'{np.mean(bces):.4f} {np.mean(accs):.4f} {val_bce:.4f} {val_acc:.4f}')

        bces, accs = list(), list()

## Calculate scores for all samples

In [None]:
# Score every sample with the trained linear head. This runs in fp32 to match training;
# the previous fp16 cast under the deprecated torch.cuda.amp.autocast rounded the scores
# that the saved normalisation statistics are computed from.
with torch.no_grad():
    all_embeds = torch.as_tensor(sample_embeds, dtype=torch.float32, device='cuda')
    score = model(all_embeds)[:, 0].cpu().numpy()

In [None]:
# Standardise: sigma_score is how many standard deviations each score lies from the mean.
score_mean = score.mean(axis=0)
score_std = score.std(axis=0)
sigma_score = (score - score_mean) / score_std

In [None]:
# Look up each pair member's normalised score by its row in sample_embeds.
preferred_rows = ordered_pairs['index_1'].to_numpy()
rejected_rows = ordered_pairs['index_2'].to_numpy()
ordered_pairs['sigma_score_1'] = sigma_score[preferred_rows]
ordered_pairs['sigma_score_2'] = sigma_score[rejected_rows]

In [None]:
# Left: raw score distribution over all samples.
# Right: normalised scores of preferred vs rejected images in the ranked pairs.
fig, (ax_all, ax_pairs) = pyplot.subplots(1, 2, figsize=(12, 4))

ax_all.hist(score, bins=100, density=True)
ax_all.set(title='Raw score (all samples)', xlabel='score', ylabel='density')

# The histograms need labels, otherwise legend() warns "No artists with labels" and draws nothing.
ax_pairs.hist(ordered_pairs['sigma_score_1'].values, bins=100, density=True, alpha=0.5, color='r',
              label='preferred (image_1)')
ax_pairs.hist(ordered_pairs['sigma_score_2'].values, bins=100, density=True, alpha=0.5, color='b',
              label='rejected (image_2)')
ax_pairs.set(title='Sigma score of ranked pairs', xlabel='sigma score', ylabel='density')
ax_pairs.legend();

# save model

In [None]:
# Persist the linear head and the score normalisation statistics used for sigma_score.
weight_dir = os.path.dirname(WEIGHT_PATH)
if weight_dir:  # os.makedirs('') raises, e.g. for a bare file name
    os.makedirs(weight_dir, exist_ok=True)
torch.save(model.state_dict(), WEIGHT_PATH)

# splitext swaps only the final extension; str.replace('.pt', ...) would also rewrite
# any '.pt' occurring earlier in the path.
np.savez(
    os.path.splitext(WEIGHT_PATH)[0] + '.npz',
    mean=score_mean,
    std=score_std,
)

In [None]:
# Stop here: everything below is manual inspection (MinIO image previews and
# preference-graph checks), not training. A bare `raise` outside an except block
# produced "RuntimeError: No active exception to reraise" and made a script run exit
# non-zero after a successful save. SystemExit(0) still halts Jupyter's "Run All" and
# ends a script run cleanly.
sys.exit(0)

# Manual checks (not part of training)

In [None]:
import getpass

# MinIO connection settings for the inspection cells below. Credentials are read from
# the environment (or prompted for) instead of being stored in the notebook.
# NOTE(review): the key pair that used to be hard-coded here has presumably been shared
# with every copy of this notebook - rotate it.
MINIO_ADDRESS = os.environ.get('MINIO_ADDRESS', '192.168.3.5:9000')
access_key = os.environ.get('MINIO_ACCESS_KEY') or getpass.getpass('MinIO access key: ')
secret_key = os.environ.get('MINIO_SECRET_KEY') or getpass.getpass('MinIO secret key: ')
bucket_name = 'datasets'

In [None]:
# connect_to_minio_client lives in the sibling kcg-ml-image-pipeline directory.
PIPELINE_REPO_DIR = os.path.abspath('../kcg-ml-image-pipeline/')
sys.path.append(PIPELINE_REPO_DIR)

from utility.minio.cmd import connect_to_minio_client

In [None]:
# MinIO client used by get_image_by_path to fetch preview images.
client = connect_to_minio_client(MINIO_ADDRESS, access_key, secret_key)

In [None]:
def get_image_by_path(file_path):
    """Fetch ``{file_path}.jpg`` from ``bucket_name`` and open it as a PIL image.

    Args:
        file_path: object key without the '.jpg' extension.

    Returns:
        PIL.Image.Image reading from an in-memory copy of the object bytes.
    """
    data = client.get_object(bucket_name=bucket_name, object_name=f'{file_path}.jpg')
    try:
        payload = data.data
    finally:
        # Assumes `client` is a minio.Minio (the keyword names match Minio.get_object).
        # Its response holds a pooled HTTP connection that must be closed and released,
        # otherwise every preview leaks one. TODO confirm in connect_to_minio_client.
        data.close()
        data.release_conn()

    return Image.open(BytesIO(payload))

def show_images(file_paths):
    """Tile images into a square grid roughly 1024 px on a side.

    Args:
        file_paths: sequence of object keys (without '.jpg'). Only the first
            floor(sqrt(n))**2 are used, so the grid is square; filled row by row.

    Returns:
        PIL.Image.Image of the tiled grid.

    Raises:
        ValueError: if file_paths is empty.
    """
    file_paths = list(file_paths)
    if not file_paths:
        raise ValueError('show_images needs at least one file path')

    num_rows = max(1, int(np.floor(len(file_paths) ** 0.5)))
    file_paths = file_paths[:num_rows * num_rows]

    target_size = 1024 // num_rows

    # convert('RGB') guarantees 3 channels; grayscale/RGBA/palette images would break the reshape.
    images = np.stack([
        np.array(get_image_by_path(file_path).resize((target_size, target_size)).convert('RGB'))
        for file_path in tqdm(file_paths, leave=False)
    ])

    images = images.reshape(num_rows, num_rows, target_size, target_size, 3)
    # Build each grid column by stacking its images vertically, then lay the columns side by side.
    images = np.concatenate(np.concatenate(images, axis=-3), axis=-2)
    return Image.fromarray(images)

def show_pairs(file_paths_1, file_paths_2):
    """Render two aligned rows of thumbnails: file_paths_1 on top, file_paths_2 below.

    Args:
        file_paths_1: object keys (without '.jpg') for the top row.
        file_paths_2: object keys for the bottom row; must have the same length.

    Returns:
        PIL.Image.Image about 1024 px wide (1024 // n px per square thumbnail).

    Raises:
        ValueError: if the rows are empty or have different lengths.
    """
    file_paths_1 = list(file_paths_1)
    file_paths_2 = list(file_paths_2)
    if not file_paths_1 or len(file_paths_1) != len(file_paths_2):
        raise ValueError(
            'show_pairs needs two non-empty rows of equal length, '
            f'got {len(file_paths_1)} and {len(file_paths_2)}'
        )

    n = len(file_paths_1)

    target_size = 1024 // n

    def load_row(row_paths):
        # Square RGB thumbnails placed side by side along the width axis.
        return np.concatenate([
            np.array(get_image_by_path(file_path).resize((target_size, target_size)).convert('RGB'))
            for file_path in tqdm(row_paths, leave=False)
        ], axis=-2)

    # Stack the two rows vertically (height axis).
    images = np.concatenate([load_row(file_paths_1), load_row(file_paths_2)], axis=-3)

    return Image.fromarray(images)

## Check pairs the model strongly disagrees with

In [None]:
# Pairs where the human-preferred image (image_1) scores more than 2 sigma below the
# rejected one, and the rejected one itself scores above +1 sigma.
score_gap = ordered_pairs['sigma_score_1'] - ordered_pairs['sigma_score_2']
selected = ordered_pairs[(score_gap < -2) & (ordered_pairs['sigma_score_2'] > 1)]
selected.head()

In [None]:
# Preferred images on top, rejected ones below; .iloc makes "first N rows" explicitly positional.
N_SHOW = 8
show_pairs(selected['image_1'].iloc[:N_SHOW], selected['image_2'].iloc[:N_SHOW])

# Check for conflicting preferences (cycles in the preference graph)

In [None]:
import networkx

In [None]:
# Directed preference graph: one node per image path, edges added from rank_pairs.
graph = networkx.DiGraph()

In [None]:
# Edges point from the rejected image to the preferred one (loser -> winner).
graph.add_edges_from(
    (img_2, img_1) if sel_id == 0 else (img_1, img_2)
    for img_1, img_2, sel_id in rank_pairs.itertuples(index=False, name=None)
)

In [None]:
# (images, distinct preference edges)
graph.number_of_nodes(), graph.number_of_edges()

In [None]:
# A preference conflict is a directed cycle (A beats B ... beats A). Every cycle through
# two or more images lies inside a strongly connected component of size > 1, so SCCs
# find all conflicting images in linear time. networkx.simple_cycles enumerates every
# cycle individually, and their number can grow exponentially on dense graphs.
conflict_components = [c for c in networkx.strongly_connected_components(graph) if len(c) > 1]
# Self-loops (an image compared with itself) are cycles too, but form singleton components.
n_self_loops = networkx.number_of_selfloops(graph)

# (conflicting components, images involved in conflicts, self-loops)
len(conflict_components), sum(len(c) for c in conflict_components), n_self_loops

In [None]:
# Clusters of images linked by at least one comparison (edge direction ignored).
subgraphs = [*networkx.weakly_connected_components(graph)]
len(subgraphs)

# Check transitive relationships (preferences implied through chains of comparisons)

In [None]:
# Edges run loser -> winner, so every node reachable from `loser` is transitively
# preferred over it. Keep only indirect relations (distance >= 2); distance 0 is the
# node itself and distance 1 is a direct comparison.
trans_pairs = pd.DataFrame(
    [
        (winner, loser, hops)
        for loser, reachable in networkx.all_pairs_shortest_path_length(graph)
        for winner, hops in reachable.items()
        if hops > 1
    ],
    columns=['image_1', 'image_2', 'dist'],
)
trans_pairs