In [1]:
import catboost as cb
import pandas as pd
from sklearn.model_selection import train_test_split


# 데이터셋 로드
df_grouped = pd.read_csv('/opt/ml/recipe_project/model/train_data.csv')

# # X와 y로 데이터 분할
X = df_grouped.drop("recipeid", axis=1)
y = df_grouped["recipeid"]


# 훈련 세트와 테스트 세트로 데이터 분할
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 모델 로드
loaded_model = cb.CatBoostClassifier()
loaded_model.load_model('catboost_model.bin')

<catboost.core.CatBoostClassifier at 0x7fe78b80ab90>

In [2]:
import catboost as cb
import numpy as np
import ast

loaded_model = cb.CatBoostClassifier()
loaded_model.load_model('catboost_model.bin')
label_mapping = {idx: label for idx, label in enumerate(loaded_model.classes_)}

In [3]:
def recommend_category(input_list, top_k = 10):
    input_list = np.array(ast.literal_eval(input_list))
    input_list = input_list / np.max(input_list)
    input_list = np.round(input_list, decimals=2)
    probabilities = loaded_model.predict_proba(input_list)
    top_k_classes = probabilities.argsort()[-top_k:].tolist()
    result = [label_mapping[label] for label in top_k_classes]
    return str(result)

# test
input_li = "[2, 2, 0, 2, 0, 1, 0, 3, 0, 0, 0, 0, 0, 0, 6, 1, 0, 0, 0, 2, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 2, 1, 0, 2, 0, 0, 0, 0, 1, 0, 2, 2, 0, 0, 0, 4, 1, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0]"
recommend_category(input_li, 10)

'[6874103, 6871728, 6879215, 6902125, 6899335, 6881099, 6906655, 6868389, 6894096, 6880798]'

# Accuracy test

In [None]:
TOP_K = 5
# 테스트 세트에 대한 예측 확률 계산
probabilities = loaded_model.predict_proba(X_test)
# 상위 top_k개 클래스 레이블 반환
top_k_classes = probabilities.argsort(axis=1)[:, -TOP_K:].tolist()
result = {}
for index, i in enumerate(top_k_classes):
    result[index] = i

# top_10_classes 리스트의 각 요소를 label_mapping을 사용하여 변경
converted_top_k_classes = []
for classes in top_k_classes:
    converted_classes = [label_mapping[label] for label in classes]
    converted_top_k_classes.append(converted_classes)

y_test_li = y_test.to_list()

cnt = 0
for y_test_value, top_k_value_list in zip(y_test_li, converted_top_k_classes):
    if y_test_value in top_k_value_list:
        cnt += 1


combined_list = []
for sublist in converted_top_k_classes:
    combined_list.extend(sublist)


print(f'top_{TOP_K}')
print(f'output_unique: {len(set(combined_list))}, origin_unique: {len(loaded_model.classes_)}')
print(f'correct: {cnt}, incorrect: {len(y_test) - cnt}')
print(f'accuracy: {round(cnt / len(y_test) * 100, 2)}%')

top_5
output_unique: 2096, origin_unique: 2361
correct: 4687, incorrect: 13869
accuracy: 25.26%


```
top_1
output_unique: 1336, origin_unique: 2361
correct: 1919, incorrect: 16637
accuracy: 10.34%
```

```
top_2
output_unique: 1693, origin_unique: 2361
correct: 2922, incorrect: 15634
accuracy: 15.75%
```

```
top_3
output_unique: 1885, origin_unique: 2361
correct: 3662, incorrect: 14894
accuracy: 19.73%
```

```
top_4
output_unique: 2000, origin_unique: 2361
correct: 4207, incorrect: 14349
accuracy: 22.67%
```

```
top_5
output_unique: 2096, origin_unique: 2361
correct: 4687, incorrect: 13869
accuracy: 25.26%
```

```
top_10
output_unique: 2280, origin_unique: 2361
correct: 6318, incorrect: 12238
accuracy: 34.05%
```

```
top_15
output_unique: 2326, origin_unique: 2361
correct: 7389, incorrect: 11167
accuracy: 39.82%
```

```
top_20
output_unique: 2346, origin_unique: 2361
correct: 8229, incorrect: 10327
accuracy: 44.35%
```

```
top_25
output_unique: 2351, origin_unique: 2361
correct: 8896, incorrect: 9660
accuracy: 47.94%
```

```
top_30
output_unique: 2356, origin_unique: 2361
correct: 9441, incorrect: 9115
accuracy: 50.88%
```

```
top_35
output_unique: 2358, origin_unique: 2361
correct: 9862, incorrect: 8694
accuracy: 53.15%
```

```
top_40
output_unique: 2360, origin_unique: 2361
correct: 10245, incorrect: 8311
accuracy: 55.21%
```

```
top_45
output_unique: 2360, origin_unique: 2361
correct: 10588, incorrect: 7968
accuracy: 57.06%
```

```
top_50
output_unique: 2361, origin_unique: 2361
correct: 10919, incorrect: 7637
accuracy: 58.84%
```

```
top_100
output_unique: 2361, origin_unique: 2361
correct: 12988, incorrect: 5568
accuracy: 69.99%
```

```
top_200
output_unique: 2361, origin_unique: 2361
correct: 14805, incorrect: 3751
accuracy: 79.79% 
```