In [1]:
# 파이썬 ≥3.5 필수
import sys
assert sys.version_info >= (3, 5)

# 사이킷런 ≥0.20 필수
import sklearn
assert sklearn.__version__ >= "0.20"

try:
    # %tensorflow_version은 코랩 명령입니다.
    %tensorflow_version 2.x
    !pip install -q -U tfx
    print("패키지 호환 에러는 무시해도 괜찮습니다.")
except Exception:
    pass

# 텐서플로 ≥2.0 필수
import tensorflow as tf
from tensorflow import keras
assert tf.__version__ >= "2.0"

# 공통 모듈 임포트
import numpy as np
import os

# 노트북 실행 결과를 동일하게 유지하기 위해
np.random.seed(42)

# 깔끔한 그래프 출력을 위해
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rc('axes', labelsize=14)
mpl.rc('xtick', labelsize=12)
mpl.rc('ytick', labelsize=12)

# 그림을 저장할 위치
PROJECT_ROOT_DIR = "."
CHAPTER_ID = "data"
IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, "images", CHAPTER_ID)
os.makedirs(IMAGES_PATH, exist_ok=True)

def save_fig(fig_id, tight_layout=True, fig_extension="png", resolution=300):
    """Save the current matplotlib figure as IMAGES_PATH/<fig_id>.<fig_extension>.

    Args:
        fig_id: base file name of the image, without extension.
        tight_layout: when true, ``plt.tight_layout()`` is applied before saving.
        fig_extension: file extension, also passed to ``plt.savefig`` as ``format``.
        resolution: dots per inch passed to ``plt.savefig``.
    """
    target_file = os.path.join(IMAGES_PATH, ".".join((fig_id, fig_extension)))
    print("그림 저장:", fig_id)
    if tight_layout:
        plt.tight_layout()
    plt.savefig(target_file, format=fig_extension, dpi=resolution)

In [4]:
# Known category strings; each maps to its position in this list (0..len(vocab)-1).
vocab = ['<1H OCEAN', 'INLAND', 'ISLAND', 'NEAR BAY', 'NEAR OCEAN']
# Number of extra buckets that unknown (out-of-vocabulary) strings are hashed into;
# their ids start at len(vocab).
num_oov_buckets = 2
indices = tf.range(len(vocab), dtype=tf.int64)
table_init = tf.lookup.KeyValueTensorInitializer(keys=vocab, values=indices)
table = tf.lookup.StaticVocabularyTable(initializer=table_init,
                                        num_oov_buckets=num_oov_buckets)
table

<tensorflow.python.ops.lookup_ops.StaticVocabularyTable at 0x210342c2cd0>

In [8]:
# "DESERT" is not in `vocab`, so the table maps it to an OOV bucket id (>= len(vocab)).
categories = tf.constant(["NEAR BAY", "DESERT", "INLAND", "INLAND"])
cat_indices = table.lookup(keys=categories)
table, categories, cat_indices

(<tensorflow.python.ops.lookup_ops.StaticVocabularyTable at 0x210342c2cd0>,
 <tf.Tensor: shape=(4,), dtype=string, numpy=array([b'NEAR BAY', b'DESERT', b'INLAND', b'INLAND'], dtype=object)>,
 <tf.Tensor: shape=(4,), dtype=int64, numpy=array([3, 5, 1, 1], dtype=int64)>)

In [7]:
# One-hot width: one slot per known category plus one per OOV bucket.
one_hot_depth = len(vocab) + num_oov_buckets
cat_one_hot = tf.one_hot(indices=cat_indices, depth=one_hot_depth)
cat_one_hot, one_hot_depth

(<tf.Tensor: shape=(4, 7), dtype=float32, numpy=
 array([[0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.]], dtype=float32)>,
 7)