In [None]:
import os
import csv
import torch
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from collections import defaultdict

# --- 1. Configuration (keep identical to the training setup) ---
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_NAME = "bert-base-uncased" # must match the model name used during training
# Passed to the tokenizer as max_length; inputs are padded/truncated to it.
MAX_LEN = 256
# Batch size handed to the DataLoader.
BATCH_SIZE = 32
# Passed as num_labels when the classification model is built.
NUM_CLASSES = 531
# Checkpoint (state dict) loaded into the model before prediction.
INPUT_MODEL_PATH = "saved_model/best_model.pt"
# Destination of the generated submission file.
OUTPUT_CSV = "submission.csv"

# Paths (adjust to your environment)
BASE_DIR = "Amazon_products"
TEST_CORPUS_PATH = os.path.join(BASE_DIR, "test/test_corpus.txt") # verify the file name
HIERARCHY_PATH = os.path.join(BASE_DIR, "class_hierarchy.txt")

# --- 2. Helper functions ---
def load_test_corpus(path):
    """Read the test corpus and return parallel lists of ids and texts.

    Each record is expected on its own line as ``id<TAB>text``. Only the
    first tab separates the id from the text, so the text itself may
    contain further tabs.

    Lines that are blank or contain no tab are skipped. A line that has an
    id followed by a tab but an empty (or whitespace-only) text is kept
    with ``""`` as its text, so that every id present in the corpus still
    receives a prediction and a row in the submission file.

    Args:
        path: Path to the UTF-8 encoded corpus file.

    Returns:
        A ``(pids, texts)`` tuple of two equally long lists of strings.
    """
    pids, texts = [], []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            # Strip only the leading side before splitting: stripping the
            # whole line would also eat the separating tab of a record
            # whose text is empty and make the record look malformed.
            pid, sep, text = line.lstrip().partition("\t")
            if not sep:
                # Blank line or no tab at all: not a usable record.
                continue
            pids.append(pid)
            # Trailing whitespace (including the newline) is not content.
            texts.append(text.rstrip())
    return pids, texts

def load_hierarchy(path):
    """Load the class hierarchy as a child -> parents lookup.

    Every line is split on whitespace; the first two fields are parsed as
    integers, the first being treated as the parent id and the second as
    the child id. Lines with fewer than two fields are ignored, and any
    fields after the second are not used.

    Args:
        path: Path to the UTF-8 encoded hierarchy file.

    Returns:
        A ``defaultdict(list)`` mapping each child id to the list of its
        parent ids, in file order.

    Raises:
        ValueError: If one of the first two fields is not an integer.
    """
    child_to_parents = defaultdict(list)
    with open(path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            # split() without arguments already discards surrounding whitespace.
            fields = raw_line.split()
            if len(fields) < 2:
                continue
            parent_id = int(fields[0])
            child_id = int(fields[1])
            child_to_parents[child_id].append(parent_id)
    return child_to_parents

class TestDataset(Dataset):
    """Dataset that tokenizes raw texts on demand for inference.

    Args:
        texts: Indexable collection of documents; each item is converted
            with ``str()`` before tokenization.
        tokenizer: Object exposing ``encode_plus`` (here a Hugging Face
            tokenizer).
        max_len: Value passed to the tokenizer as ``max_length``.
    """

    def __init__(self, texts, tokenizer, max_len):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of documents."""
        return len(self.texts)

    def __getitem__(self, item):
        """Tokenize document ``item`` and return its flattened tensors.

        Returns:
            A dict with the keys ``'input_ids'`` and ``'attention_mask'``,
            each holding the corresponding tokenizer output after
            ``flatten()``.
        """
        tokenizer_options = {
            'add_special_tokens': True,
            'max_length': self.max_len,
            'return_token_type_ids': False,
            'padding': 'max_length',
            'truncation': True,
            'return_attention_mask': True,
            'return_tensors': 'pt',
        }
        encoded = self.tokenizer.encode_plus(
            str(self.texts[item]), **tokenizer_options
        )
        # flatten() removes the extra leading dimension of the 'pt' output
        # so the DataLoader can stack samples into a batch.
        return {
            field: encoded[field].flatten()
            for field in ('input_ids', 'attention_mask')
        }

# --- 3. Run inference ---
def _with_ancestors(labels, parents_map):
    """Return ``labels`` plus every ancestor reachable through ``parents_map``.

    ``parents_map`` maps a class id to an iterable of its parent ids. The
    walk continues until no new id is found, so hierarchies of any depth
    are fully covered, and the visited set keeps it finite even if the
    mapping contains a cycle.
    """
    closed = set(labels)
    pending = list(closed)
    while pending:
        node = pending.pop()
        # .get() is used so that a defaultdict is not grown by lookups.
        for parent in parents_map.get(node, ()):
            if parent not in closed:
                closed.add(parent)
                pending.append(parent)
    return closed


def run_inference():
    """Predict labels for the test corpus and write the submission CSV.

    Steps:
      1. Load the test documents and the class hierarchy.
      2. Build the tokenizer, dataset, loader and model, then load the
         fine-tuned weights from ``INPUT_MODEL_PATH``.
      3. For every document take the top-3 scoring classes and add all of
         their ancestors from the hierarchy.
      4. Write one ``id,labels`` row per document to ``OUTPUT_CSV``, where
         ``labels`` is a comma-joined, ascending list of class ids.
    """
    print("Loading Data & Model...")
    test_pids, test_texts = load_test_corpus(TEST_CORPUS_PATH)
    parents_map = load_hierarchy(HIERARCHY_PATH)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    dataset = TestDataset(test_texts, tokenizer, MAX_LEN)
    # shuffle=False keeps predictions aligned with test_pids.
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    # Load the model. from_pretrained() reports the classifier head as
    # "newly initialized"; that is expected here because the fine-tuned
    # weights are loaded over it on the next line.
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=NUM_CLASSES)
    model.load_state_dict(torch.load(INPUT_MODEL_PATH, map_location=DEVICE))
    model.to(DEVICE)
    model.eval()

    # Top-K strategy instead of a threshold, combined with the hierarchy:
    # the task states that documents carry 2-3 labels, so take the top 3
    # and then add their parents. Clamped so topk never asks for more
    # classes than exist.
    top_k = min(3, NUM_CLASSES)

    all_preds = []

    print("Predicting...")
    with torch.no_grad():
        for batch in tqdm(loader):
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits

            # Sigmoid is monotonic, so ranking the raw logits selects the
            # same classes as ranking the probabilities, without the ties
            # that saturated probabilities can introduce.
            _, top_indices = torch.topk(logits, k=top_k, dim=1)

            # tolist() yields plain Python ints rather than numpy scalars.
            for idx_list in top_indices.cpu().tolist():
                # Hierarchy correction: whenever a child is predicted, all
                # of its ancestors up to the root must be included too.
                final_labels = _with_ancestors(idx_list, parents_map)
                all_preds.append(sorted(final_labels))

    # --- 4. Save CSV ---
    print(f"Saving submission to {OUTPUT_CSV}...")
    with open(OUTPUT_CSV, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["id", "labels"])
        for pid, labels in zip(test_pids, all_preds):
            label_str = ",".join(map(str, labels))
            writer.writerow([pid, label_str])

    print("Done!")

# Entry point: run the full inference pipeline when executed as a script.
if __name__ == "__main__":
    run_inference()

  from .autonotebook import tqdm as notebook_tqdm


Loading Data & Model...


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Predicting...


  0%|          | 0/615 [00:00<?, ?it/s]