In [None]:
import os
import json
from typing import List, Dict
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from contextlib import nullcontext

# ====== Settings ======
# Path of the input dataset and of the JSONL file the embeddings are written to.
INPUT_JSON = "dataset_fl.json"
OUTPUT_JSONL = "dataset_fl_emd.jsonl"

# Model identifier passed to AutoTokenizer.from_pretrained / AutoModel.from_pretrained.
MODEL_NAME = "bert-base-uncased"
BATCH_SIZE = 64  # DataLoader batch_size
MAX_LENGTH = 256  # tokenizer max_length (inputs are truncated to this)
USE_FP16 = True  # run the forward pass under autocast when the device is CUDA
NUM_WORKERS = 0  # DataLoader num_workers
PIN_MEMORY = False  # DataLoader pin_memory

# ====== 0) Input file loader (auto-detects the format) ======
def load_records_any(path: str) -> List[Dict]:
    """Load records from a JSON list, JSONL, or concatenated-JSON file.

    The formats are tried in this order:

    1. If the first non-whitespace character (within the first 2048
       characters) is ``[``, the whole file is parsed as one JSON document.
    2. Otherwise every non-blank line is parsed as its own JSON document
       (JSONL).
    3. If a line fails to parse, or no record was found, the file is decoded
       as a sequence of JSON values written back to back, separated by any
       amount of whitespace and at most one comma.

    Args:
        path: Path of the UTF-8 encoded input file.

    Returns:
        The parsed records. An empty or whitespace-only file yields ``[]``.

    Raises:
        json.JSONDecodeError: If the file starts with ``[`` but is not valid
            JSON.
        RuntimeError: If none of the three formats can parse the file.
    """
    with open(path, "r", encoding="utf-8") as f:
        head = f.read(2048)

    # Case 1: a single JSON list.
    if head.lstrip().startswith("["):
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    # Case 2: try JSONL, streaming line by line.
    recs = []
    failed_jsonl = False
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                recs.append(json.loads(line))
            except json.JSONDecodeError:
                failed_jsonl = True
                break
    if recs and not failed_jsonl:
        return recs

    # Case 3: concatenated JSON. Decode one value at a time with raw_decode
    # instead of patching the text with str.replace: a textual "}{" -> "},{"
    # rewrite also hits "}{" inside string values and misses records that
    # are separated by blank lines or other whitespace.
    with open(path, "r", encoding="utf-8") as f:
        raw = f.read().strip()

    decoder = json.JSONDecoder()
    recs = []
    pos, end = 0, len(raw)
    try:
        while pos < end:
            obj, pos = decoder.raw_decode(raw, pos)
            recs.append(obj)
            # Step over the gap before the next value: JSON whitespace and
            # at most one comma (comma-separated values were accepted by
            # the previous "[...]" wrapping as well).
            comma_seen = False
            while pos < end:
                ch = raw[pos]
                if ch in " \t\r\n":
                    pos += 1
                elif ch == "," and not comma_seen:
                    comma_seen = True
                    pos += 1
                else:
                    break
    except json.JSONDecodeError as e:
        preview = raw[:800]
        raise RuntimeError(
            "입력 파일 파싱 실패: JSON, JSONL, concat JSON 모두 불가. "
            f"원인: {e}\n파일 앞부분 미리보기:\n{preview}"
        ) from e
    return recs

# ====== 1) Dataset / Collate ======
class ReviewDataset(Dataset):
    """Map-style dataset over a list of review record dicts.

    Each item is a new dict with the keys ``review_id``, ``user_id``,
    ``business_id``, ``rating`` and ``review_text``. ``review_id`` must be
    present in the source record (a missing key raises ``KeyError``); the
    other fields fall back to ``None``, and ``review_text`` falls back to
    ``""``. ``rating`` is read from the record's ``review_stars`` key.
    """

    # (output key, source key) pairs for the optional fields, in output order.
    _OPTIONAL_FIELDS = (
        ("user_id", "user_id"),
        ("business_id", "business_id"),
        ("rating", "review_stars"),
    )

    def __init__(self, records: List[Dict]):
        self.records = records

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        row = self.records[idx]
        sample = {"review_id": row["review_id"]}
        for out_key, src_key in self._OPTIONAL_FIELDS:
            sample[out_key] = row.get(src_key)
        sample["review_text"] = row.get("review_text", "")
        return sample

def make_collate_fn(tokenizer):
    """Return a collate function that tokenizes a batch with ``tokenizer``.

    The returned callable takes a list of item dicts and produces a dict with:

    * ``"meta"``: one dict per item holding every field except
      ``review_text``;
    * ``"inputs"``: the tokenizer output for the batch texts (padded,
      truncated to ``MAX_LENGTH``, requested as ``"pt"`` tensors), with each
      value passed through ``.cpu()``.

    A missing, ``None`` or empty ``review_text`` is tokenized as ``""``.
    """
    def collate(batch: List[Dict]):
        texts = []
        meta = []
        for item in batch:
            texts.append(item.get("review_text") or "")
            meta.append(
                {key: value for key, value in item.items() if key != "review_text"}
            )

        encoded = tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH,
        )
        inputs = {name: tensor.cpu() for name, tensor in encoded.items()}
        return {"meta": meta, "inputs": inputs}

    return collate

# ====== 2) Load data ======
data = load_records_any(INPUT_JSON)
print(f"로드 완료: {len(data)}개 레코드")

# ====== 3) Prepare model ======
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# eval() so the forward pass runs in inference mode for the model's layers.
model = AutoModel.from_pretrained(MODEL_NAME).to(device).eval()

ds = ReviewDataset(data)
dl = DataLoader(
    ds,
    batch_size=BATCH_SIZE,
    shuffle=False,  # keep output lines in the same order as the input records
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    collate_fn=make_collate_fn(tokenizer),
)

# ====== 4) Inference & save (JSONL) ======
# Mixed precision is only enabled on CUDA. torch.autocast(device_type="cuda")
# replaces the deprecated torch.cuda.amp.autocast entry point.
if device.type == "cuda" and USE_FP16:
    def autocast_ctx():
        return torch.autocast(device_type="cuda", dtype=torch.float16)
else:
    autocast_ctx = nullcontext

# Mode "w" truncates a previous output file, which is what the former
# "remove the file, then open in append mode" sequence achieved.
with torch.no_grad(), open(OUTPUT_JSONL, "w", encoding="utf-8") as wf:
    pbar = tqdm(dl, desc="Encoding (CLS)", unit="batch", dynamic_ncols=True)
    for batch in pbar:
        inputs = {k: v.to(device) for k, v in batch["inputs"].items()}

        with autocast_ctx():
            out = model(**inputs)
            # Hidden state at sequence position 0 (the [CLS] token for BERT).
            cls = out.last_hidden_state[:, 0, :]

        cls = cls.detach().cpu().tolist()
        # One JSON line per review: identifiers, star rating and embedding.
        for meta, vec in zip(batch["meta"], cls):
            rec = {
                "user_id": meta.get("user_id"),
                "business_id": meta.get("business_id"),
                "stars": meta.get("rating"),
                "bert_embedding": vec,
            }
            wf.write(json.dumps(rec, ensure_ascii=False) + "\n")

print(f"✅ 완료: {OUTPUT_JSONL}")

In [None]:
import json

N = 5  # number of records to preview

# Single pass over the file: count the non-blank lines and remember the first
# N of them (with their 1-based line numbers) for the preview below.
# Blank lines hold no record, so they are neither counted nor previewed.
record_count = 0
preview_lines = []
with open("dataset_fl.json", "r", encoding="utf-8") as f:
    for line_no, line in enumerate(f, 1):
        if not line.strip():
            continue
        record_count += 1
        if len(preview_lines) < N:
            preview_lines.append((line_no, line))

print(f"✅ 전체 리뷰 개수: {record_count}개")
print("-" * 30)
print(f"📄 처음 {N}개 레코드 미리보기:")

# Print the first N records, pretty-printed; report lines that are not JSON.
for line_no, line in preview_lines:
    try:
        record = json.loads(line)
    except json.JSONDecodeError:
        print(f"오류: {line_no}번째 줄은 올바른 JSON 형식이 아닙니다.")
        print(line.strip())
    else:
        print(json.dumps(record, indent=2, ensure_ascii=False))