In [1]:
import pandas as pd
import numpy as np

from my_modules.transform import get_transform
from my_modules.model import MyModel
from my_modules.dataset import ClusterTestDataset

import torch
from torch.utils.data import DataLoader

from sklearn.metrics import f1_score, accuracy_score
from tqdm import tqdm
from PIL import Image

In [2]:
import pickle

# Cached outputs of an earlier run; the ensemble cell below reads
# result['y_pred'] and result['y_true'] from it.
# NOTE: pickle.load can execute arbitrary code - only load files this project produced.
RESULT_PKL = 'pkl/result.pkl'
with open(RESULT_PKL, mode='rb') as result_file:
    result = pickle.load(result_file)

In [3]:
# Load the trained model and restore its checkpoint.
# Use the GPU when one is available; otherwise fall back to CPU so the notebook
# still runs (slowly) on a CPU-only machine instead of crashing.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
CHECKPOINT_PATH = '/opt/ml/code/save/cluster/log9_3.pt'

model = MyModel('efficientnet-b0').to(device)
# map_location remaps tensors saved on another device (e.g. a GPU-saved checkpoint)
# onto `device`, so loading works regardless of where the weights were written.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
model.eval()  # switch dropout / batch-norm layers to inference behaviour
print('model ready')

Loaded pretrained weights for efficientnet-b0
model ready


In [4]:
# get transform
# Evaluation-time preprocessing from the project helper: augmentation and
# cutout are disabled so every image is processed deterministically.
transform = get_transform(
    augment=False,
    crop=350,
    resize=224,
    cutout=None,
)

In [5]:
# get dataframe
# Key (reference) set - presumably the labelled validation split; the ensemble
# cell reads its 'path' column to open each image.
KEY_DF_PATH = 'df/df_labled_valid.csv'
df_key = pd.read_csv(KEY_DF_PATH)
# Optional separate query set (currently disabled):
# df_query = pd.read_csv('/opt/ml/code/df/df_age_valid_20.csv')

In [6]:
# get dataset
# Wrap the key dataframe with the evaluation transform. ClusterTestDataset is a
# project class; the format of the items it yields is not visible here.
ds_key = ClusterTestDataset(df_key, transform)
# Optional dataset for a separate query set (disabled, see the dataframe cell):
#ds_query = ClusterTestDataset(df_query, transform)

In [7]:
# get dataloader
# shuffle=False keeps batches in dataset order - presumably so that key indices
# returned later by model.query line up with df_key rows / result['y_pred'].
dataloader_key = DataLoader(
    dataset=ds_key,
    batch_size=64,
    shuffle=False,
    num_workers=3,
)
# Optional loader for a separate query set (disabled):
# dataloader_query = DataLoader(ds_query, batch_size=64, shuffle=False, num_workers=3)

In [8]:
# create keys
# Build the model's key/retrieval store from every batch of the key loader.
# create_keys is project code; presumably it embeds each key image on `device`
# so that model.query() in the ensemble cell can search against it - confirm.
model.create_keys(dataloader_key, device)

100%|██████████| 60/60 [00:07<00:00,  7.50it/s]


In [10]:
# ensemble prediction
#
# For each image in `df`, embed it with `model.eff`, retrieve its `n_neighbors`
# nearest keys via `model.query`, and replace its label with a majority vote over
# the cached predictions `result['y_pred']` of those neighbours.
# NOTE(review): `df` is the key set itself, so each image is presumably retrieved
# as one of its own neighbours and votes for its own prediction - confirm intent.


def _knn_vote(neighbor_idx, labels, n_classes):
    """Majority vote over ``labels[j]`` for every ``j`` in ``neighbor_idx``.

    Returns the winning label index. Ties resolve to the smallest label,
    because ``np.argmax`` returns the first maximum.
    """
    votes = [0] * n_classes
    for j in neighbor_idx:
        votes[int(labels[j])] += 1
    return np.argmax(votes)


df = df_key
n_neighbors = 5
y_pred = result['y_pred']
y_true = result['y_true']
# Size the vote table from the predicted labels instead of hard-coding 3 classes,
# which raised IndexError for any label >= 3.
n_classes = max(int(label) for label in y_pred) + 1
# Fail fast, before the slow per-image loop, if the stored ground truth does not
# line up row-for-row with df.
assert len(y_true) == len(df), f'y_true has {len(y_true)} labels but df has {len(df)} rows'
ensembled_y_pred = []

for path in tqdm(df['path']):
    # The context manager closes the image file once the tensor is built; the
    # transform runs inside it because PIL loads pixel data lazily.
    with Image.open(path) as pil_img:
        img = transform(pil_img).unsqueeze(0).to(device)

    with torch.no_grad():
        queries = model.eff(img).cpu().numpy()

    # Indices into the key set; [0] presumably selects the row for this single
    # query (batch size 1) - confirm against MyModel.query.
    groups_idx = model.query(queries, n_neighbors)[0]
    ensembled_y_pred.append(_knn_vote(groups_idx, y_pred, n_classes))

f1 = f1_score(y_true, ensembled_y_pred, average='macro')
acc = accuracy_score(y_true, ensembled_y_pred)

print(f'acc:{acc:.4f}, f1:{f1:.4f}')

100%|██████████| 3780/3780 [01:19<00:00, 47.34it/s]

acc:0.9156, f1:0.8112





In [None]:
# Persist the kNN-ensembled labels (one per row of df) for downstream use.
# NOTE(review): the file is named gender_label.pkl, but the vote above was
# written for 3 classes (the disabled y_true line read the 'age' column) -
# confirm which target these labels are.
LABEL_PKL = 'pkl/gender_label.pkl'
with open(LABEL_PKL, mode='wb') as label_file:
    pickle.dump(ensembled_y_pred, label_file)

In [None]:
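# Visualise the last query image from the ensemble loop next to its retrieved neighbours.
# Requires `import matplotlib.pyplot as plt`; reuses `df`, `n_neighbors` and `groups_idx`
# left over from the ensemble-prediction cell.
# fig, axes = plt.subplots(1, n_neighbors + 1, figsize=(15, 5))
# axes[0].imshow(Image.open(df.iloc[-1]['path']))
# axes[0].axis('off')
# axes[0].set_title('query')

# for ax, i in zip(axes[1:], groups_idx):
#     img = Image.open(df_key.iloc[i]['path'])
#     ax.imshow(img)
#     ax.axis('off')
#     ax.set_title(df_key.iloc[i]['gender'])

# plt.show()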