The first image-description dataset consists of 10 descriptions for each of 56 Genshin Impact characters.
For each character, we select the central-pose image (`zero.png`) and 9 random images.

Therefore, in total, the dataset will have 560 images with 10 descriptions each, resulting in 5600 samples.


To obtain this dataset, we first cropped a reference image for each character manually.

For the other images, we then search for a similar crop using CLIP: we compare the CLIP features of candidate crops with those of the reference image of the same character.

In [1]:
import torch
import clip
import os
import pandas as pd
import random
from PIL import Image
from torchvision import transforms
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
import matplotlib.pyplot as plt
from einops import rearrange
import torch.nn.functional as F
%matplotlib inline


# Run CLIP on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# clip.load hands back the model plus the preprocessing callable that is
# applied to the reference crops further down.
model, preprocess = clip.load("ViT-B/32", device=device)

  return torch._C._cuda_getDeviceCount() > 0


In [2]:
# Root folder of the dataset and the sub-folders used by the cells below:
# `original_ims` holds one folder of source images per character and
# `dataset_path` receives the generated crops.
mpath = 'genshin_dataset'
original_ims = os.path.join(mpath, 'character_imgs')
dataset_path = os.path.join(mpath, 'dataset1')
# Create the output folder up front. exist_ok=True keeps the cell idempotent,
# so it can stay enabled across re-runs instead of being commented out.
os.makedirs(dataset_path, exist_ok=True)

### Main functions

In [4]:
def _convert_image_to_rgb(image):
    """Return ``image`` converted to RGB mode through its ``convert`` method."""
    rgb_image = image.convert("RGB")
    return rgb_image


# Preprocessing for the full-size candidate images: RGB conversion, tensor
# conversion and channel-wise normalisation. No resize or crop happens here.
candidate_preprocess = Compose(
    [
        _convert_image_to_rgb,
        ToTensor(),
        Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)

# Default similarity cut-off used by get_winners.
THRESHOLD = 0.8
# Side length (pixels) of the sliding-window crops passed to generate_crops.
CROP_RES = 256

In [5]:
def generate_crops(image_tensor, crop_res, overlap=0.25):
    """Slide a square window over a batched image and collect every crop.

    Args:
        image_tensor: Tensor indexed as ``(batch, channels, height, width)``.
        crop_res: Side length, in pixels, of each square crop.
        overlap: Fraction of ``crop_res`` shared by neighbouring windows; the
            stride is ``int(crop_res * (1 - overlap))``.

    Returns:
        A ``(crops_tensor, centers)`` pair. ``crops_tensor`` stacks the crops
        along a new leading dimension, giving shape
        ``(n_crops, batch, channels, crop_res, crop_res)``. ``centers[i]`` is
        the ``(y, x)`` centre of crop ``i`` in image coordinates. When the
        image is smaller than ``crop_res`` in either direction, an empty stack
        and an empty list are returned.

    Raises:
        ValueError: If ``overlap`` leaves a stride smaller than one pixel.
    """
    # Read the size from the argument itself, never from a notebook global.
    height, width = image_tensor.size(2), image_tensor.size(3)

    stride = int(crop_res * (1 - overlap))
    if stride < 1:
        raise ValueError(
            f"overlap={overlap} with crop_res={crop_res} gives a stride of "
            f"{stride}; the stride must be at least 1 pixel"
        )

    # Number of window positions along each axis (<= 0 when the image is too small).
    n_crops_y = (height - crop_res) // stride + 1
    n_crops_x = (width - crop_res) // stride + 1

    crops, centers = [], []
    for y in range(n_crops_y):
        for x in range(n_crops_x):
            start_y = y * stride
            start_x = x * stride
            crops.append(
                image_tensor[:, :, start_y:start_y + crop_res, start_x:start_x + crop_res]
            )
            centers.append((start_y + crop_res // 2, start_x + crop_res // 2))

    if not crops:
        # torch.stack rejects an empty list; return an empty stack with the
        # same trailing layout so callers can still read ``shape[0] == 0``.
        empty = image_tensor.new_empty(
            (0, image_tensor.size(0), image_tensor.size(1), crop_res, crop_res)
        )
        return empty, centers

    return torch.stack(crops), centers

In [6]:
def get_winners(ref_features, crops, thres=THRESHOLD):
    """Keep the crops whose CLIP embedding is close enough to the reference.

    Every ``crops[n]`` is center-cropped to 224, encoded with the global
    ``model`` and compared with ``ref_features`` through the dot product of
    the L2-normalised feature vectors (cosine similarity).

    Args:
        ref_features: Image features of the reference crop.
        crops: Stack of candidate crops, indexed along the first dimension.
        thres: Minimum similarity a candidate needs to be kept.

    Returns:
        A list of ``(sim, crop, n)`` tuples, one per kept candidate, where
        ``sim`` is the similarity tensor, ``crop`` is ``crops[n].squeeze(0)``
        and ``n`` is the candidate's index in ``crops``.
    """
    to_clip_res = CenterCrop(224)
    selected = []
    with torch.no_grad():
        for n in range(crops.shape[0]):
            candidate = crops[n]
            cand_features = model.encode_image(to_clip_res(candidate))
            ref_unit = F.normalize(ref_features, dim=-1)
            cand_unit = F.normalize(cand_features, dim=-1)
            sim = ref_unit @ cand_unit.T
            if sim >= thres:
                selected.append((sim, candidate.squeeze(0), n))
    return selected

### Create dataset

In [12]:
# Build the image/description table.
# For every character listed in the reference CSV:
#   1. encode its manually cropped reference image with CLIP,
#   2. pick the random images plus 'zero.png',
#   3. for each picked image, save a CROP_SIZE window centred on the
#      sliding-window crop most similar to the reference,
#   4. emit one row per (saved image, description) pair.
ref_path = os.path.join(mpath, 'face_crop')
df = pd.read_csv(os.path.join(mpath, 'face_crop_256.csv'))

CROP_SIZE = 400      # side (pixels) of the square saved around the winning crop centre
N_RANDOM_IMS = 9     # random images per character, in addition to 'zero.png'
N_DESCRIPTIONS = 10  # description columns desc0..desc9 in the CSV
SAMPLE_SEED = 0

# The output folder must exist before the first save.
os.makedirs(dataset_path, exist_ok=True)

# Crops are cached on disk under an index-based name ({character}_{i}.png).
# A sorted listing plus a dedicated seeded RNG makes a re-run pick the same
# image for the same index, so a cached crop stays paired with its `orig_path`.
# Clear `dataset_path` after changing the seed or the source images.
sample_rng = random.Random(SAMPLE_SEED)

train_dict = {}
ii = 0

for _, row in df.iterrows():
    character = row["character"]

    # The reference embedding depends only on the character: compute it once here.
    with Image.open(os.path.join(ref_path, row['cropped_img_256'])) as ref_img:
        crop_ref = preprocess(ref_img).unsqueeze(0).to(device)
    with torch.no_grad():
        ref_features = model.encode_image(crop_ref)

    # os.listdir order is arbitrary, so sort before sampling.
    character_ims = sorted(
        f for f in os.listdir(os.path.join(original_ims, character))
        if 'zero' not in f and 'ood' not in f
    )
    # Cap the sample size so a character with fewer images does not raise.
    n_random = min(N_RANDOM_IMS, len(character_ims))
    character_ims = sample_rng.sample(character_ims, n_random) + ['zero.png']

    for i_f, f in enumerate(character_ims):
        im_path = os.path.join(dataset_path, f"{character}_{i_f}.png")
        orig_im_path = os.path.join(original_ims, character, f)

        if not os.path.exists(im_path):
            with Image.open(orig_im_path) as oimg:
                img = candidate_preprocess(oimg).unsqueeze(0).to(device)
                crops, centers = generate_crops(img, CROP_RES, overlap=0.75)
                winners = get_winners(ref_features, crops)
                if not winners:
                    # Only this case is skipped; I/O errors now surface
                    # instead of being reported as "No winners".
                    print(f"No winners for {character}, image {orig_im_path}")
                    continue

                # Select by score only. Comparing the whole tuples would fall
                # through to comparing crop tensors when two scores tie.
                win_crop = max(winners, key=lambda w: float(w[0]))
                center_y, center_x = centers[win_crop[2]]
                left = center_x - CROP_SIZE // 2
                top = center_y - CROP_SIZE // 2
                # NOTE(review): the box can extend past the image border
                # (CROP_SIZE > CROP_RES); presumably Pillow pads that area.
                # Confirm whether clamping to the image is preferred.
                im1 = oimg.crop((left, top, left + CROP_SIZE, top + CROP_SIZE))
                im1.save(im_path)

        # Reaching this point means the crop exists (cached or just saved).
        for desc in range(N_DESCRIPTIONS):
            train_dict[ii] = [character, im_path, orig_im_path, row[f"desc{desc}"]]
            ii += 1

print(len(train_dict))
train_df = pd.DataFrame.from_dict(
    train_dict,
    orient='index',
    columns=['character', 'im_path', 'orig_path', 'description'],
)

No winners for xinyan, image genshin_dataset/character_imgs/xinyan/Captura de pantalla (699).png
No winners for sayu, image genshin_dataset/character_imgs/sayu/Captura de pantalla (493).png
No winners for kaeya, image genshin_dataset/character_imgs/kaeya/Captura de pantalla (1187).png
5570


In [15]:
# Row count first, then the frame itself as the cell's rich output.
print(train_df.shape[0])
train_df

5570


Unnamed: 0,character,im_path,orig_path,description
0,yoimiya,genshin_dataset/dataset1/yoimiya_0.png,genshin_dataset/character_imgs/yoimiya/Captura...,A young woman with light blonde hair tied in a...
1,yoimiya,genshin_dataset/dataset1/yoimiya_0.png,genshin_dataset/character_imgs/yoimiya/Captura...,"A character with blonde hair, dressed in a red..."
2,yoimiya,genshin_dataset/dataset1/yoimiya_0.png,genshin_dataset/character_imgs/yoimiya/Captura...,A woman with light blonde hair tied in a ponyt...
3,yoimiya,genshin_dataset/dataset1/yoimiya_0.png,genshin_dataset/character_imgs/yoimiya/Captura...,"A character with blonde hair, dressed in a red..."
4,yoimiya,genshin_dataset/dataset1/yoimiya_0.png,genshin_dataset/character_imgs/yoimiya/Captura...,A young woman with blonde hair tied in a high ...
...,...,...,...,...
5565,ganyu,genshin_dataset/dataset1/ganyu_9.png,genshin_dataset/character_imgs/ganyu/zero.png,A mysterious female character with short blue ...
5566,ganyu,genshin_dataset/dataset1/ganyu_9.png,genshin_dataset/character_imgs/ganyu/zero.png,A blue-haired character with large purple eyes...
5567,ganyu,genshin_dataset/dataset1/ganyu_9.png,genshin_dataset/character_imgs/ganyu/zero.png,"A girl with pastel blue hair in soft waves, re..."
5568,ganyu,genshin_dataset/dataset1/ganyu_9.png,genshin_dataset/character_imgs/ganyu/zero.png,A young woman with ethereal blue hair and lumi...


In [16]:
# Write the sample table to the dataset root, leaving the index column out.
dataset_csv = os.path.join(mpath, 'dataset1.csv')
train_df.to_csv(dataset_csv, index=False)