In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import sys
import json
import glob
import torch

import numpy as np
import pandas as pd

from PIL import Image

from tqdm.auto import tqdm

from sklearn.model_selection import train_test_split
from sklearn.metrics.pairwise import cosine_similarity

In [2]:
from matplotlib import pyplot
import seaborn as sns

In [3]:
from optimizers import Adan, Lookahead, AGC
from utils import get_score_from_embs

In [4]:
# Root of the image-pipeline dataset output; the ranking JSON files are globbed
# from ROOT/ranking/<dataset_name>/ further down.
ROOT = '../kcg-ml-image-pipeline/output/dataset/'

# Dataset to train on. Exactly one of the assignments below should be active.
dataset_name = 'environmental'
# dataset_name = 'character'
# dataset_name = 'mech'
# dataset_name = 'icons'
# dataset_name = 'waifu'
# dataset_name = 'propaganda-poster'

# --- Alternative configuration: CLIP text embeddings (currently disabled) ---
# use_positive = True
# use_negative = True
# pooling_method = 'pooler_outputs'

# EMB_PATH = os.path.join('./data', dataset_name, 'clip_text_emb.npz')
# WEIGHT_PATH = os.path.join('weight/004', dataset_name, 'clip_text.pt')
# WEIGHT_PATH = os.path.join('weight/004', dataset_name, 'clip_positive.pt')
# WEIGHT_PATH = os.path.join('weight/004', dataset_name, 'clip_negative.pt')

# --- Active configuration: CLIP vision embeddings ---
# use_positive / use_negative select which feature sets build_feature stacks;
# pooling_method selects which arrays are read from the embedding archive.
use_positive = True
use_negative = False
pooling_method = 'image_embeds'

# Input embedding archive (.npz) and output weight file (.pt) for this dataset.
EMB_PATH = os.path.join('./data', dataset_name, 'clip_vision_emb.npz')
WEIGHT_PATH = os.path.join('weight/004', dataset_name, 'clip_vision.pt')

# load emb

In [5]:
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects, so
# only open embedding archives that come from a trusted source.
data = np.load(EMB_PATH, allow_pickle=True)

# Reduce every stored path to its lookup key: the text before the first
# underscore, with any file extension dropped.
file_paths = []
for stored_path in data['file_paths']:
    prefix = stored_path.split('_')[0]
    file_paths.append(os.path.splitext(prefix)[0])

# Map each key to its position in the archive (on duplicate keys the last
# occurrence wins).
path_to_index = dict(zip(file_paths, range(len(file_paths))))

# load rank data

In [6]:
# One JSON file per human comparison of two images.
paths = sorted(glob.glob(os.path.join(ROOT, 'ranking', dataset_name, '*.json')))

rank_pairs = list()
for path in tqdm(paths):
    # Open through a context manager so every file handle is closed as soon as
    # it has been parsed (there can be tens of thousands of ranking files).
    with open(path) as ranking_file:
        js = json.load(ranking_file)

    # Bring the recorded paths into the same form as the keys of path_to_index:
    # no extension and no 'datasets/' prefix.
    file_path_1 = os.path.splitext(js['image_1_metadata']['file_path'])[0].replace('datasets/', '')
    file_path_2 = os.path.splitext(js['image_2_metadata']['file_path'])[0].replace('datasets/', '')

    # Skip comparisons that involve an image without a loaded embedding.
    if (file_path_1 not in path_to_index) or (file_path_2 not in path_to_index):
        continue
    rank_pairs.append((file_path_1, file_path_2, js['selected_image_index']))

  0%|          | 0/71079 [00:00<?, ?it/s]

# build dataset

In [7]:
rank_pairs = pd.DataFrame(rank_pairs, columns=['image_1', 'image_2', 'selected_image_index'])

# Reorder every comparison so that the selected image sits in 'image_1' and
# the other one in 'image_2': rows with selected_image_index == 0 are kept
# as-is, all remaining rows are swapped.
first_was_selected = rank_pairs['selected_image_index'] == 0
ordered_pairs = pd.DataFrame({
    'image_1': rank_pairs['image_1'].where(first_was_selected, rank_pairs['image_2']),
    'image_2': rank_pairs['image_2'].where(first_was_selected, rank_pairs['image_1']),
})

In [8]:
# Hold out 20% of the ordered pairs for validation; random_state pins the split.
train_pairs, val_pairs = train_test_split(ordered_pairs, test_size=0.2, random_state=42)

In [9]:
# Pull only the arrays required by the configured pooling method out of the
# archive. Image embeddings are cast to float32; pooler outputs are used with
# whatever dtype they were stored in.
if pooling_method == 'image_embeds':
    image_embeds = data['image_embeds'].astype('float32')
elif pooling_method == 'pooler_outputs':
    positive_pooler_outputs = data['positive_pooler_outputs']
    negative_pooler_outputs = data['negative_pooler_outputs']

In [10]:
# positive_last_hidden_states = data['positive_last_hidden_states']
# positive_attention_masks = data['positive_attention_masks']
# negative_last_hidden_states = data['negative_last_hidden_states']
# negative_attention_masks = data['negative_attention_masks']

# positive_last_hidden_states_ave_mean = positive_last_hidden_states.mean(axis=1)
# negative_last_hidden_states_ave_mean = negative_last_hidden_states.mean(axis=1)

# positive_last_hidden_states_ave_mean_with_mask = (positive_last_hidden_states * positive_attention_masks[..., None]).sum(axis=1) / positive_attention_masks.sum(axis=-1)[..., None]
# negative_last_hidden_states_ave_mean_with_mask = (negative_last_hidden_states * negative_attention_masks[..., None]).sum(axis=1) / negative_attention_masks.sum(axis=-1)[..., None]

## build feature

In [11]:
def build_feature(index_1, index_2):
    """Build the feature array for one pair of images.

    The behaviour is driven by the module-level settings ``pooling_method``,
    ``use_positive`` and ``use_negative`` and by the feature arrays loaded for
    the chosen pooling method.

    Args:
        index_1: position of the first image along axis 0 of the feature arrays.
        index_2: position of the second image along axis 0 of the feature arrays.

    Returns:
        An array whose last axis has length 2 (entry 0 holds the features of
        ``index_1``, entry 1 those of ``index_2``). When both positive and
        negative features are enabled, the two blocks are concatenated along
        axis 0.

    Raises:
        ValueError: if ``pooling_method`` is not recognised, if negative
            features are requested for a pooling method that has none, or if
            neither ``use_positive`` nor ``use_negative`` is set.
    """
    # Stays None for pooling methods that provide no negative features.
    neg_features = None

    if pooling_method == 'pooler_outputs':
        pos_features = positive_pooler_outputs
        neg_features = negative_pooler_outputs
    elif pooling_method == 'ave_mean':
        # NOTE(review): these arrays are only produced by a cell that is
        # currently commented out; re-enable it before using 'ave_mean'.
        pos_features = positive_last_hidden_states_ave_mean
        neg_features = negative_last_hidden_states_ave_mean
    # elif pooling_method == 'ave_mean_with_mask':
    #     pos_features = positive_last_hidden_states_ave_mean_with_mask
    #     neg_features = negative_last_hidden_states_ave_mean_with_mask
    elif pooling_method == 'image_embeds':
        pos_features = image_embeds
    else:
        raise ValueError(f'unknown pooling_method: {pooling_method!r}')

    # Fail with a clear message instead of an unbound-local error further down.
    if use_negative and neg_features is None:
        raise ValueError(f'pooling_method {pooling_method!r} has no negative features; set use_negative = False')
    if not (use_positive or use_negative):
        raise ValueError('at least one of use_positive / use_negative must be True')

    results = list()
    if use_positive:
        results.append(np.stack([pos_features[index_1], pos_features[index_2]], axis=-1))
    if use_negative:
        results.append(np.stack([neg_features[index_1], neg_features[index_2]], axis=-1))

    if len(results) == 1:
        return results[0]
    else:
        return np.concatenate(results, axis=0)

In [12]:
def _pairs_to_array(pairs):
    """Stack the build_feature output of every (image_1, image_2) row of ``pairs``."""
    per_pair = [
        build_feature(path_to_index[first], path_to_index[second])
        for first, second in zip(pairs['image_1'], pairs['image_2'])
    ]
    return np.stack(per_pair, axis=0)


# The same conversion is applied to both splits.
train_data = _pairs_to_array(train_pairs)
val_data = _pairs_to_array(val_pairs)

train_data.shape, val_data.shape

((56863, 768, 2), (14216, 768, 2))

In [13]:
# Keep both splits resident on the GPU in half precision; the training loop
# below feeds each split to the model as one full batch.
train_dataset = torch.tensor(train_data).to(torch.float16).cuda()
val_dataset = torch.tensor(val_data).to(torch.float16).cuda()
# ext_dataset = torch.tensor(ext_data).half().cuda()

# build model

In [14]:
# A single linear layer that maps one image's feature vector (axis 1 of
# train_data) to one scalar score.
# model = torch.nn.Linear(train_data.shape[1], 1, bias=True)
model = torch.nn.Linear(train_data.shape[1], 1, bias=True)

In [15]:
# model.load_state_dict(torch.load(WEIGHT_PATH))

In [16]:
# Move the model to the GPU (the DataParallel variant is left disabled).
# model = torch.nn.DataParallel(model.cuda())
model = model.cuda()

# train model

In [17]:
LR = 1e-3

parameters = list(model.parameters())

# Adan wrapped in Lookahead (both come from the local `optimizers` module).
optimizer = Lookahead(Adan(parameters, lr=LR, weight_decay=1e-3))
# optimizer = AGC(optimizer)


def _warmup_factor(step):
    """Learning-rate multiplier: ramps linearly from 0 to 1 over 100 scheduler steps, then stays at 1."""
    return min(step / 100., 1.)


warmup = torch.optim.lr_scheduler.LambdaLR(optimizer, [_warmup_factor])

In [18]:
# Full-batch training of the scoring model on ordered pairs.
#
# Each sample holds the features of two images on its last axis, with the
# selected image first. The model scores both images, the two scores form a
# 2-way logit vector, and cross-entropy against class 0 pushes the selected
# image's score above the other one. An L1 penalty on the weights is added.
scaler = torch.cuda.amp.GradScaler()

# The target class is always 0, so the label tensors are built once up front
# instead of being re-allocated on every epoch.
train_label = torch.zeros((train_dataset.shape[0],), dtype=torch.long, device='cuda')
val_label = torch.zeros((val_dataset.shape[0],), dtype=torch.long, device='cuda')

# Running training metrics; reset after every report.
bces, accs = list(), list()

for epoch in tqdm(range(1000)):

    model.train()

    x = train_dataset

    optimizer.zero_grad()

    with torch.cuda.amp.autocast(True):

        y0 = model(x[..., 0])
        y1 = model(x[..., 1])

        y = torch.concat([y0, y1], dim=-1)

    # backward

    bce = torch.nn.functional.cross_entropy(y, train_label)

    # A pair counts as correct only if the selected image scores strictly higher.
    acc = (y0 > y1).float().mean()

    l1 = torch.norm(model.weight, p=1)

    loss = bce + l1 * 1e-3

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    # The warmup schedule advances once per (full-batch) epoch.
    warmup.step()

    bces.append(bce.detach().cpu().numpy())
    accs.append(acc.detach().cpu().numpy())

    # Report averaged training metrics and validation metrics every 100 epochs.
    if (epoch + 1) % 100 == 0:

        model.eval()

        with torch.no_grad():

            x = val_dataset

            with torch.cuda.amp.autocast(True):

                y0 = model(x[..., 0])
                y1 = model(x[..., 1])

                y = torch.concat([y0, y1], dim=-1)

            val_bce = torch.nn.functional.cross_entropy(y, val_label)

            # Use the same strict comparison as the training accuracy.
            # argmax(...) == 0 would count tied scores (which do occur in
            # half precision) as correct and overstate validation accuracy.
            val_acc = (y0 > y1).float().mean()

        # Columns: train loss, train accuracy, validation loss, validation accuracy.
        print(f'{np.mean(bces):.4f} {np.mean(accs):.4f} {val_bce.item():.4f} {val_acc.item():.4f}')

        bces, accs = list(), list()

  0%|          | 0/1000 [00:00<?, ?it/s]



0.4983 0.7401 0.2949 0.8734
0.2451 0.8962 0.2277 0.9048
0.2172 0.9087 0.2150 0.9091
0.2075 0.9144 0.2091 0.9144
0.2030 0.9167 0.2056 0.9169
0.1998 0.9181 0.2035 0.9177
0.1976 0.9191 0.2020 0.9181
0.1963 0.9201 0.2010 0.9183
0.1952 0.9203 0.2006 0.9184
0.1947 0.9205 0.2006 0.9177


# save model

In [19]:
# Create the target directory on demand, then persist the trained weights.
weight_dir = os.path.dirname(WEIGHT_PATH)
os.makedirs(weight_dir, exist_ok=True)
torch.save(model.state_dict(), WEIGHT_PATH)

# analysis

## score distribution

In [21]:
# Score every image embedding with the trained model, in batches of 1024.
# NOTE(review): image_embeds only exists when pooling_method == 'image_embeds';
# the commented-out calls are the counterparts for the text-embedding setups.
score = get_score_from_embs(image_embeds, model, batch_size=1024)
# score = get_score_from_embs(np.concatenate([positive_pooler_outputs, negative_pooler_outputs], axis=-1), model, batch_size=1024)
# score = get_score_from_embs(positive_pooler_outputs, model, batch_size=1024)
# score = get_score_from_embs(negative_pooler_outputs, model, batch_size=1024)

  0%|          | 0/104 [00:00<?, ?it/s]

In [22]:
# Store the mean and standard deviation of the scores next to the weight file
# (same path, with '.pt' replaced by '_stats.json').
stats_path = WEIGHT_PATH.replace('.pt', '_stats.json')

# Write through a context manager so the file is flushed and closed
# deterministically; a bare open() leaves that to garbage collection.
with open(stats_path, 'wt') as stats_file:
    json.dump({
        'mean': float(score.mean()),
        'std': float(score.std()),
    }, stats_file)

In [None]:
# Density-normalised histogram of all scores; assigning to `_` hides the return value.
_ = pyplot.hist(score, bins=100, density=True)

## val set distribution

In [None]:
# Score the two images of every validation pair separately (slot 0 is the
# selected image, slot 1 the other one).
# NOTE(review): unlike the call on image_embeds, these calls pass CUDA tensors —
# confirm get_score_from_embs accepts both and what type it returns.
y0 = get_score_from_embs(val_dataset[..., 0].float(), model, 1024)
y1 = get_score_from_embs(val_dataset[..., 1].float(), model, 1024)

# Score margin per pair; positive means the selected image was scored higher.
delta = y0 - y1

In [None]:
# A pair is "correct" when the selected image (y0) received the higher score.
correct = (y0 > y1)

pyplot.figure(figsize=(10, 4))

# Left panel: correctly ordered pairs; right panel: incorrectly ordered pairs.
# In both panels red is the selected image's score and blue the other one's.
panels = [
    (correct, 'score distribution of correct pairs'),
    (~correct, 'score distribution of incorrect pairs'),
]

for panel_number, (mask, panel_title) in enumerate(panels, start=1):

    pyplot.subplot(1, 2, panel_number)

    _ = pyplot.hist(y0[mask], bins=100, color='r', alpha=0.5)
    _ = pyplot.hist(y1[mask], bins=100, color='b', alpha=0.5)

    pyplot.title(panel_title)

## prepare for visualization

In [None]:
def select_samples(indices, n_select):
    """Randomly pick images and tile them into one 512x512 preview.

    Args:
        indices: candidate positions into the module-level ``file_paths`` list.
        n_select: number of images to draw (without replacement). The grid is
            laid out with ``int(n_select ** 0.5)`` columns.

    Returns:
        A tuple ``(selected, mosaic)``: the drawn indices and a PIL image of
        the tiled pictures resized to 512x512.
    """
    image_dir = '../kcg-ml-image-pipeline/output/dataset/image/'
    n_columns = int(n_select ** 0.5)

    selected = np.random.choice(indices, n_select, False)

    # Image files are the part of the key before the first underscore, plus '.jpg'.
    jpg_names = [file_paths[position].split('_')[0] + '.jpg' for position in selected]

    tiles = []
    for jpg_name in jpg_names:
        tiles.append(np.array(Image.open(os.path.join(image_dir, jpg_name))))
    stacked = np.stack(tiles)

    # (rows, columns, H, W, C): join the rows vertically first, then join the
    # resulting column strips horizontally into a single picture.
    grid = stacked.reshape(-1, n_columns, *stacked.shape[-3:])
    column_strips = np.concatenate(grid, axis=-3)
    mosaic = np.concatenate(column_strips, axis=-2)

    return selected, Image.fromarray(mosaic).resize((512, 512))

## show top score samples

In [None]:
# Draw 9 random images from those scoring above the 95th percentile.
threshold = np.quantile(score, q=0.95)
selected, images = select_samples(np.arange(score.shape[0])[score > threshold], n_select=9)

# Show the cut-off score.
threshold

In [None]:
# Display the mosaic of top-scoring samples.
images

## show lowest score samples

In [None]:
# Draw 9 random images from those scoring below the 5th percentile.
threshold = np.quantile(score, q=0.05)
selected, images = select_samples(np.arange(score.shape[0])[score < threshold], n_select=9)

# Show the cut-off score.
threshold

In [None]:
# Display the mosaic of lowest-scoring samples.
images

## show lowest delta samples

In [None]:
# Number of validation pairs to display.
n_select = 6

# threshold = np.quantile(val_delta, q=0.05)
# indices = val_pairs.index[np.arange(val_delta.shape[0])[val_delta < threshold]]
# selected = np.random.choice(indices, n_select, False)

# Index labels of the n_select validation pairs with the smallest delta
# (y0 - y1), i.e. the pairs where the model disagrees most with the ranking.
# NOTE(review): assumes delta is ordered like the rows of val_pairs — confirm.
selected = val_pairs.index[np.argsort(delta)[:n_select]]

# Embedding indices of the first and second image of every selected pair.
indices_1 = [path_to_index[i] for i in val_pairs.loc[selected, 'image_1']]
indices_2 = [path_to_index[i] for i in val_pairs.loc[selected, 'image_2']]

In [None]:
def _load_image_strip(emb_indices):
    """Load the images behind ``emb_indices`` and join them side by side into one row."""
    jpg_names = [file_paths[position].split('_')[0] + '.jpg' for position in emb_indices]
    tiles = np.stack([
        np.array(Image.open(os.path.join('../kcg-ml-image-pipeline/output/dataset/image/', jpg_name)))
        for jpg_name in jpg_names
    ])
    return np.concatenate(tiles, axis=-2)


# Top row: the selected image of each pair; bottom row: the other image.
images_1 = _load_image_strip(indices_1)
images_2 = _load_image_strip(indices_2)

images = np.concatenate([images_1, images_2], axis=-3)

# n_select tiles wide and 2 tiles high, scaled to a height of 512 pixels.
images = Image.fromarray(images).resize((512 * n_select // 2, 512))

In [None]:
# Display the two-row mosaic of the lowest-delta pairs.
images

# check conflicts

In [None]:
import networkx

In [None]:
# Directed graph of the ranking decisions; nodes and edges are added in the next cell.
graph = networkx.DiGraph()

In [None]:
# Add one edge per comparison, pointing from the rejected image to the
# selected one.
#
# rank_pairs is a DataFrame at this point (it was rebound from a list of tuples
# earlier in the notebook). Iterating a DataFrame directly yields its column
# labels, not its rows, so the rows are walked with itertuples() instead.
for img_1, img_2, sel_id in rank_pairs.itertuples(index=False, name=None):
    if sel_id == 0:
        graph.add_edge(img_2, img_1)
    else:
        graph.add_edge(img_1, img_2)

In [None]:
# Number of distinct images (nodes) and distinct directed comparisons (edges).
len(graph.nodes), len(graph.edges)

In [None]:
# Every simple cycle is a group of rankings that contradict one another.
# NOTE(review): enumerating all simple cycles can be very slow on a large,
# densely connected graph.
cycles = list(networkx.simple_cycles(graph))
len(cycles)

In [None]:
# Groups of images linked by at least one chain of comparisons, ignoring edge direction.
subgraphs = list(networkx.weakly_connected_components(graph))
len(subgraphs)

# check transitive relationship

In [None]:
# Collect every (target, source, distance) where the target is reachable from
# the source only through a directed path of two or more edges; self-distances
# (0) and direct comparisons (1) are left out.
trans_pairs = pd.DataFrame(
    [
        (reached, start, hops)
        for start, hops_by_node in networkx.all_pairs_shortest_path_length(graph)
        for reached, hops in hops_by_node.items()
        if hops > 1
    ],
    columns=['image_1', 'image_2', 'dist'],
)
trans_pairs