In [1]:
import os
import json
from PIL import Image

import torch
from transformers import CLIPProcessor, CLIPModel

import numpy as np
import pandas as pd
from tqdm.auto import tqdm
tqdm.pandas()

RANDOM_STATE = 0

# Data

In [2]:
with open("../data/href2details.json", encoding="utf-8") as f:
    href2details = json.load(f)
with open("../data/num2details.json", encoding="utf-8") as f:
    num2details = json.load(f)

In [16]:
data = pd.DataFrame()

In [17]:
root = "../data/posters"
EXT = ".jpg"

data["rel_path"] = [f"{root}/{fname}" for fname in os.listdir(root) if fname.endswith(EXT)]

data["filename"] = data["rel_path"].apply(os.path.basename)
data["href"] = data.filename.apply(
    lambda x: x.split("_")[0]
)
data["Country"] = data["href"].apply(
    lambda x: href2details[f"/{x}"]["Country:"][len("Country:"):].strip()
)

In [35]:
data["Country"].value_counts()

South Korea    5825
Japan          5590
China          4061
Thailand       1744
Taiwan          682
Hong Kong       678
Philippines     460
Name: Country, dtype: int64

In [75]:
sample = data.groupby("Country", group_keys=False).apply(
    lambda x: x.sample(
        data["Country"].value_counts().to_list()[-1]
        , random_state=RANDOM_STATE
    )
).reset_index(drop=True)

sample.head(1)

Unnamed: 0,rel_path,filename,href,Country
0,../data/posters/27370-transmission_JxZqXc.jpg,27370-transmission_JxZqXc.jpg,27370-transmission,China


In [76]:
sample.shape, sample["Country"].value_counts()

((3220, 4),
 China          460
 Hong Kong      460
 Japan          460
 Philippines    460
 South Korea    460
 Taiwan         460
 Thailand       460
 Name: Country, dtype: int64)

# Zero-shot

In [94]:
# labels = ['China', 'Hong Kong', 'Japan', 'Philippines', 'South Korea', 'Taiwan', 'Thailand']
labels = sorted(data["Country"].unique())
labels

['China',
 'Hong Kong',
 'Japan',
 'Philippines',
 'South Korea',
 'Taiwan',
 'Thailand']

In [85]:
model_id = "openai/clip-vit-base-patch32"
cache_dir = "../models/"

model = CLIPModel.from_pretrained(model_id, cache_dir=cache_dir)
processor = CLIPProcessor.from_pretrained(model_id, cache_dir=cache_dir)

In [108]:
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    model.to(device)
model.device

device(type='cuda', index=0)

In [110]:
def predict(path2image, labels, processor, model, device):
    image = Image.open(path2image)
    with torch.no_grad():
        
        inputs = processor(
            text=labels, images=image, return_tensors="pt", padding=True
        ).to(device)
        
        outputs = model(**inputs)
        logits_per_image = outputs.logits_per_image # this is the image-text similarity score
        probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
        return labels[probs.argmax()]
    
def embed()

In [127]:
inputs["pixel_values"].squeeze(0).mean(axis=0).

torch.Size([50176])

In [119]:
from sklearn.model_selection import StratifiedKFold

X = sample["rel_path"]
y = sample["Country"]

skf = StratifiedKFold(n_splits=3, shuffle=True, random_state=RANDOM_STATE)
# skf.get_n_splits(X, y)

for i, (train_index, test_index) in enumerate(skf.split(X, y)):
    
    fold_test = sample.loc[test_index]
    
    fold_test["pred"] = fold_test["rel_path"].progress_apply(
        lambda x: predict(x, labels, processor, model, device)
    )
    
    print(classification_report(fold_test["Country"], fold_test["pred"]))

  0%|          | 0/1074 [00:00<?, ?it/s]

              precision    recall  f1-score   support

       China       0.35      0.47      0.40       154
   Hong Kong       0.35      0.05      0.09       153
       Japan       0.63      0.75      0.68       153
 Philippines       0.87      0.75      0.80       154
 South Korea       0.59      0.91      0.72       153
      Taiwan       0.43      0.58      0.49       153
    Thailand       0.92      0.52      0.66       154

    accuracy                           0.57      1074
   macro avg       0.59      0.57      0.55      1074
weighted avg       0.59      0.57      0.55      1074



  0%|          | 0/1073 [00:00<?, ?it/s]

              precision    recall  f1-score   support

       China       0.32      0.46      0.38       153
   Hong Kong       0.34      0.07      0.12       154
       Japan       0.65      0.74      0.69       153
 Philippines       0.88      0.58      0.70       153
 South Korea       0.53      0.91      0.67       154
      Taiwan       0.40      0.49      0.44       153
    Thailand       0.89      0.54      0.67       153

    accuracy                           0.54      1073
   macro avg       0.57      0.54      0.52      1073
weighted avg       0.57      0.54      0.52      1073



  0%|          | 0/1073 [00:00<?, ?it/s]

              precision    recall  f1-score   support

       China       0.32      0.46      0.37       153
   Hong Kong       0.21      0.05      0.08       153
       Japan       0.59      0.70      0.64       154
 Philippines       0.89      0.73      0.80       153
 South Korea       0.57      0.84      0.68       153
      Taiwan       0.33      0.40      0.36       154
    Thailand       0.87      0.52      0.65       153

    accuracy                           0.53      1073
   macro avg       0.54      0.53      0.51      1073
weighted avg       0.54      0.53      0.51      1073



In [128]:
# posterpath = "../data/posters/not_released_yet/27319-the-golden-hairpin_k644m_4c.jpg"
posterpath = "../data/posters/714369-pepero-was-taken-away-on-pepero-day_e4J7K_4c.jpg"
posterpath = "../data/posters/711027-enhypen-en-log_BLYz5_4c.jpg"
posterpath = "../data/posters/693057-siwon-s-fortune-cookie_XND8q_4c.jpg"
posterpath = "../data/posters/77567-fight-for-love_k36Nw_4c.jpg" # China
posterpath = "../data/posters/680631-ka-sunscreen_Rg3oo_4c.jpg" # Thailand
posterpath = "../data/posters/79321-the-serpents-song_BNXKb_4c.jpg" # Thailand
image = Image.open(posterpath)
# display(image)

In [None]:
with torch.no_grad():
    
    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
    outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image # this is the image-text similarity score
    probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities

In [88]:
sorted(zip(probs.flatten(), labels), reverse=True)

[(tensor(0.4053), 'South Korea'),
 (tensor(0.2192), 'Japan'),
 (tensor(0.1207), 'China'),
 (tensor(0.1014), 'Hong Kong'),
 (tensor(0.0665), 'Thailand'),
 (tensor(0.0563), 'Philippines'),
 (tensor(0.0305), 'Taiwan')]

# Not released yet

In [4]:
for k, v in num2details.items():
    href2details[f"/{k}"] = v

In [5]:
len(href2details)

31454

In [3]:
data = pd.DataFrame()

In [7]:
root = "../data/posters/not_released_yet"
EXT = ".jpg"

data["rel_path"] = [f"{root}/{fname}" for fname in os.listdir(root) if fname.endswith(EXT)]

data["filename"] = data["rel_path"].apply(os.path.basename)
data["href"] = data.filename.apply(
    lambda x: x.split("_")[0]
)



data["Country"] = data["href"].apply(
    lambda x: href2details[f"/{x}"]["Country:"][len("Country:"):].strip()
)

In [8]:
data["Country"].value_counts()

South Korea    250
China           36
Thailand        34
Japan           24
Taiwan           4
Hong Kong        2
Name: Country, dtype: int64

In [11]:
data["Country"].value_counts()

South Korea    280
China           41
Thailand        35
Japan           31
Hong Kong        4
Taiwan           4
Name: Country, dtype: int64

In [14]:
data = data[data["Country"].apply(lambda x: x not in {'Hong Kong', 'Taiwan'})]

In [15]:
data["Country"].value_counts()

South Korea    280
China           41
Thailand        35
Japan           31
Name: Country, dtype: int64

In [16]:
sample = data.groupby("Country", group_keys=False).apply(
    lambda x: x.sample(
        data["Country"].value_counts().to_list()[-1]
        , random_state=RANDOM_STATE
    )
).reset_index(drop=True)

sample.head(1)

Unnamed: 0,rel_path,filename,href,Country
0,../data/posters/not_released_yet/737225-jie-zi...,737225-jie-zi-gui-cheng_r4qqZ_4c.jpg,737225-jie-zi-gui-cheng,China


In [17]:
sample.shape, sample["Country"].value_counts()

((124, 4),
 China          31
 Japan          31
 South Korea    31
 Thailand       31
 Name: Country, dtype: int64)