In [1]:
import numpy as np
import os
import torch
from torchvision import transforms
from tqdm import tqdm

import pandas as pd
import PIL

import matplotlib.pyplot as plt


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
sd_version = '2_1'
pretrained_model = 'laion2b_s34b_b79k'
path_image_tensors = '../artstation_sd_' + sd_version + '_' + pretrained_model + '_ViT-B-32.pt'

In [3]:
import open_clip
from open_clip import tokenizer
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained=pretrained_model)

In [4]:
device = torch.device('cuda')
model.to(device)

CLIP(
  (visual): VisionTransformer(
    (patchnorm_pre_ln): Identity()
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (patch_dropout): Identity()
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): ModuleList(
        (0): ResidualAttentionBlock(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ls_1): Identity()
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): GELU(approximate='none')
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ls_2): Identity()
        )
        (1): ResidualAttentionBlock(
          (l

In [5]:
image_features_torch = torch.load(path_image_tensors)

In [6]:
# read csv file
import pandas as pd
df = pd.read_csv('../gen_images_artstation_1_5.csv')
df.head()

print(df.shape)

(3960, 5)


In [7]:
artists = df['artist'].unique()
text_features = []

with torch.no_grad():
    for artist in artists:
        prompt = "The following work is done in the style of " + artist
        text_tokens = tokenizer.tokenize(prompt)
        text_tokens = text_tokens.to(device)
        txt_feat = model.encode_text(text_tokens).float()
        text_features.append(txt_feat)
text_features_torch = torch.concatenate(text_features).cpu()

In [8]:
idx2artist = {}
artist2idx = {}
for i, artist in enumerate(artists):
    idx2artist[i] = artist
    artist2idx[artist] = i

In [9]:
image_features_torch /= image_features_torch.norm(dim=-1, keepdim=True)
text_features_torch /= text_features_torch.norm(dim=-1, keepdim=True)

text_probs = (100.0 * image_features_torch @ text_features_torch.T).softmax(dim=-1)
top_probs, top_k_labels = text_probs.cpu().topk(5, dim=-1)

In [10]:
gt_labels = torch.tensor([artist2idx[x] for x in df['artist'].to_list()])

top_one_labels = top_k_labels[:, 0]

print(gt_labels.shape, top_one_labels.shape)

correct = (gt_labels == top_one_labels).sum()
print(f"Top 1 score is {round((correct / gt_labels.shape[0]).item() * 100, 2)}")


topk_correct = 0
for i in range(5):
    top_one_labels = top_k_labels[:, i]
    correct = (gt_labels == top_one_labels).sum()
    topk_correct += correct
print(f"Top 5 score is {round((topk_correct / gt_labels.shape[0]).item() * 100, 2)}")
# correct = (gt_labels.repeat() == top_k_labels).sum()
# print(correct)

torch.Size([3960]) torch.Size([3960])
Top 1 score is 0.96
Top 5 score is 3.81


In [11]:
#compute the artist accuracy
artist_accuracy = {}
for i in range(len(artists)):
    artist_accuracy[artists[i]] = 0
for i in range(len(gt_labels)):
    if gt_labels[i] == top_one_labels[i]:
        artist_accuracy[artists[gt_labels[i]]] += 1
for i in range(len(artists)):
    artist_accuracy[artists[i]] /= len(df[df['artist'] == artists[i]])
    #multiply by 100 to get percentage
    artist_accuracy[artists[i]] *= 100
print(artist_accuracy)


{'WLOP': 0.0, 'Dao Trong Le': 0.0, 'Zeronis': 0.0, 'Chengwei Pan': 0.0, 'Wenjun Lin': 0.0, 'Grafit Studio': 0.0, 'Sylvain Sarrailh': 6.666666666666667, 'Greg Rutkowski': 0.0, 'Bayard Wu': 0.0, 'Bo Chen': 0.0, 'Tooth Wu': 0.0, '翼次方CG': 0.0, 'Anato Finnstark': 0.0, 'Qi Sheng Luo': 0.0, 'Raf Grassetti': 0.0, 'sparth': 0.0, 'Hicham Habchi': 0.0, 'Zhelong Xu': 0.0, 'Anthony Chong Jones': 0.0, 'Rudy Siswanto': 0.0, 'Nurzhan Bekkaliyev': 0.0, 'Suke ∷': 0.0, 'Jama Jurabaev': 6.666666666666667, 'Evan Lee': 0.0, 'Anatomy For Sculptors': 0.0, 'Christophe Young': 0.0, 'Eytan Zana': 0.0, 'Darek Zabrocki': 0.0, 'Jakub Rozalski': 0.0, 'Maria Panfilova': 0.0, 'Hou China': 0.0, 'Ismail Inceoglu': 6.666666666666667, 'Andreas Rocha': 6.666666666666667, 'Johnson Ting': 0.0, 'Ching Yeh': 0.0, 'Paul Chadeisson': 0.0, 'Igor Sid': 0.0, 'Jonas Ronnegard': 0.0, 'Nivanh Chanthara': 0.0, 'Thomas Chamberlain-Keen': 0.0, 'Steve Zheng': 0.0, 'Hue Teo': 0.0, 'Mauro Belfiore': 0.0, 'Raphael Lacoste': 6.666666666666667

In [12]:
#sort
sorted_accuracy = sorted(artist_accuracy.items(), key=lambda x: x[1], reverse=True)
print(sorted_accuracy)

[('Bastien Grivet', 13.333333333333334), ('Sylvain Sarrailh', 6.666666666666667), ('Jama Jurabaev', 6.666666666666667), ('Ismail Inceoglu', 6.666666666666667), ('Andreas Rocha', 6.666666666666667), ('Raphael Lacoste', 6.666666666666667), ('Anna Podedworna', 6.666666666666667), ('Victor Titov', 6.666666666666667), ('Marco Plouffe (Keos Masons)', 6.666666666666667), ('Anton Fadeev', 6.666666666666667), ('Finnian MacManus', 6.666666666666667), ('Krenz Cushart', 6.666666666666667), ('Hong SoonSang', 6.666666666666667), ('Wojtek Fus', 6.666666666666667), ('Ramón Nuñez', 6.666666666666667), ('Alejandro Burdisio', 6.666666666666667), ('Cedric Peyravernay', 6.666666666666667), ('Alex Konstad', 6.666666666666667), ('Nikolai Lockertsen', 6.666666666666667), ('Michal Lisowski', 6.666666666666667), ('Kittew', 6.666666666666667), ('jungmin jin /dospi', 6.666666666666667), ('WLOP', 0.0), ('Dao Trong Le', 0.0), ('Zeronis', 0.0), ('Chengwei Pan', 0.0), ('Wenjun Lin', 0.0), ('Grafit Studio', 0.0), ('Gr

In [13]:
#higher 10
print("Top 10")
for name, value in sorted_accuracy[:10]:
    print(name, round(value, 2))

Top 10
Bastien Grivet 13.33
Sylvain Sarrailh 6.67
Jama Jurabaev 6.67
Ismail Inceoglu 6.67
Andreas Rocha 6.67
Raphael Lacoste 6.67
Anna Podedworna 6.67
Victor Titov 6.67
Marco Plouffe (Keos Masons) 6.67
Anton Fadeev 6.67


In [14]:
#higher 10
print("Top 10")
for name, value in sorted_accuracy[:10]:
    print('|'+name+'|' + str(round(value, 2))+'%|')

Top 10
|Bastien Grivet|13.33%|
|Sylvain Sarrailh|6.67%|
|Jama Jurabaev|6.67%|
|Ismail Inceoglu|6.67%|
|Andreas Rocha|6.67%|
|Raphael Lacoste|6.67%|
|Anna Podedworna|6.67%|
|Victor Titov|6.67%|
|Marco Plouffe (Keos Masons)|6.67%|
|Anton Fadeev|6.67%|


In [15]:
#compute the prompts accuracy
prompts = df['prompt'].unique()
prompt_accuracy = {}
for i in range(len(prompts)):
    prompt_accuracy[prompts[i]] = 0
for i in range(len(gt_labels)):
    if gt_labels[i] == top_one_labels[i]:
        prompt_accuracy[df['prompt'].to_list()[i]] += 1
for i in range(len(prompts)):
    prompt_accuracy[prompts[i]] /= len(df[df['prompt'] == prompts[i]])
    #multiply by 100 to get percentage
    prompt_accuracy[prompts[i]] *= 100

In [16]:
#sort
sorted_accuracy = sorted(prompt_accuracy.items(), key=lambda x: x[1], reverse=True)
print(sorted_accuracy)

[('An exciting and adrenaline-fueled scene of a hot air balloon race over a vast canyon.', 1.1363636363636365), ('A spooky and eerie abandoned carnival scene with empty rides, broken lights, and a creepy clown lurking in the shadows.', 1.1363636363636365), ('A mystical and spiritual scene of a meditating monk in a temple surrounded by cherry blossoms.', 1.1363636363636365), ('A vibrant and colorful scene of a street market in Marrakech, Morocco with spices, textiles, and street performers.', 0.7575757575757576), ('A romantic and picturesque scene of a couple stargazing in a field at night, with shooting stars and the Milky Way overhead.', 0.7575757575757576), ('A mystical forest with a majestic unicorn standing in a clearing, surrounded by glowing flowers.', 0.7575757575757576), ('A surreal cityscape with floating buildings, a rainbow bridge, and a giant clocktower in the center.', 0.7575757575757576), ('A dynamic and exciting skateboard park with jumps, ramps, and half-pipes, surround

In [17]:
#get prompts with their mean accuracy
print("Top 5")
for name, value in sorted_accuracy[:5]:
    print(name, round(value, 2))
print("Bottom 5")
for name, value in sorted_accuracy[-5:]:
    print(name, round(value, 2))

Top 5
An exciting and adrenaline-fueled scene of a hot air balloon race over a vast canyon. 1.14
A spooky and eerie abandoned carnival scene with empty rides, broken lights, and a creepy clown lurking in the shadows. 1.14
A mystical and spiritual scene of a meditating monk in a temple surrounded by cherry blossoms. 1.14
A vibrant and colorful scene of a street market in Marrakech, Morocco with spices, textiles, and street performers. 0.76
A romantic and picturesque scene of a couple stargazing in a field at night, with shooting stars and the Milky Way overhead. 0.76
Bottom 5
A peaceful and romantic scene of a couple sitting in a gondola, being serenaded on a moonlit canal in Venice, Italy. 0.38
A cozy and festive holiday scene with a fireplace, stockings, and a Christmas tree decorated with lights and ornaments. 0.38
A romantic and dreamy scene of a couple in a hot air balloon, floating over a picturesque countryside. 0.38
A dreamy and surreal cloudscape scene with fluffy white clouds,

In [18]:
#get prompts with their mean accuracy
for name, value in sorted_accuracy[:5]:
    print('|'+name+'|' + str(round(value, 2))+'%|')
print("|...|...|")
for name, value in sorted_accuracy[-5:]:
    print('|'+name+'|' + str(round(value, 2))+'%|')

|An exciting and adrenaline-fueled scene of a hot air balloon race over a vast canyon.|1.14%|
|A spooky and eerie abandoned carnival scene with empty rides, broken lights, and a creepy clown lurking in the shadows.|1.14%|
|A mystical and spiritual scene of a meditating monk in a temple surrounded by cherry blossoms.|1.14%|
|A vibrant and colorful scene of a street market in Marrakech, Morocco with spices, textiles, and street performers.|0.76%|
|A romantic and picturesque scene of a couple stargazing in a field at night, with shooting stars and the Milky Way overhead.|0.76%|
|...|...|
|A peaceful and romantic scene of a couple sitting in a gondola, being serenaded on a moonlit canal in Venice, Italy.|0.38%|
|A cozy and festive holiday scene with a fireplace, stockings, and a Christmas tree decorated with lights and ornaments.|0.38%|
|A romantic and dreamy scene of a couple in a hot air balloon, floating over a picturesque countryside.|0.38%|
|A dreamy and surreal cloudscape scene with 

In [19]:
df = pd.DataFrame.from_dict(artist_accuracy, orient='index', columns=['accuracy'])
df.to_csv('../error_analysis/artists_artstation_error_analysis_sd_'+sd_version+'_'+pretrained_model+'_ViT-B-32.csv')

In [20]:
df = pd.DataFrame.from_dict(prompt_accuracy, orient='index', columns=['accuracy'])
df = df.sort_values(by=['accuracy'], ascending=False)
df.to_csv('../error_analysis/prompt_artstation_error_analysis_sd_'+sd_version+'_'+pretrained_model+'_ViT-B-32.csv')