# Data Sampling Utility

이 노트북은 전체 데이터셋에서 샘플 데이터를 생성합니다.

**목적:**
- 빠른 프로토타이핑 및 테스트
- DMF, MLP, CFNet 모두 동일한 샘플 데이터 사용
- 공정한 성능 비교

**생성 파일:**
- `{dataset}-sample{N}.train.rating`
- `{dataset}-sample{N}.test.rating`
- `{dataset}-sample{N}.test.negative`

**샘플링 전략:**
1. 전체 유저 중 랜덤하게 N명 선택
2. 선택된 유저의 모든 상호작용 포함
3. Test set은 샘플된 유저의 케이스만 유지 (샘플 train에 등장하지 않는 아이템의 테스트 케이스는 제외)
4. User/Item ID 재매핑 (0부터 연속적으로)
5. Negative samples는 원본 데이터 형식 유지 (테스트 케이스당 99개; 샘플에 없는 아이템은 제외하고 부족분은 랜덤 아이템으로 채움)

## 1. 설정 (Configuration)

In [None]:
import os
import random
from time import time

# ------------------------------------------------------------
# Dataset location
# ------------------------------------------------------------
DATA_PATH = 'datasets/'  # directory holding the original and the sample files
DATASET = 'ml-1m'        # file-name prefix of the original dataset

# ------------------------------------------------------------
# Sampling
# ------------------------------------------------------------
SAMPLE_SIZES = [100, 500, 1000]  # one sample is generated per user count listed here
SEED = 42                        # random seed, for reproducible samples

# ------------------------------------------------------------
# Misc
# ------------------------------------------------------------
OVERWRITE = False  # True: regenerate existing sample files; False: skip samples that already exist

# Seed the global RNG for this notebook session.
random.seed(SEED)

# Echo the effective configuration, one "label: value" line per setting.
print("\n".join(
    f"{label}: {value}"
    for label, value in (
        ("Dataset", DATASET),
        ("Sample sizes", SAMPLE_SIZES),
        ("Random seed", SEED),
        ("Overwrite existing", OVERWRITE),
    )
))

## 2. 샘플링 함수 정의

In [None]:
def _read_rating_file(path):
    """Read a tab-separated ``.rating`` file.

    Every non-blank line must contain at least four tab-separated fields:
    user id, item id, rating, timestamp. User/item ids are parsed to int;
    rating and timestamp are kept as the original strings so they can be
    written back unchanged.

    Args:
        path (str): Path of the rating file.

    Returns:
        list[tuple[int, int, str, str]]: ``(user, item, rating, timestamp)``
        records in file order.
    """
    records = []
    with open(path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Skip blank lines (e.g. a trailing newline) instead of failing on int('').
                continue
            arr = line.split('\t')
            records.append((int(arr[0]), int(arr[1]), arr[2], arr[3]))
    return records


def _read_negative_file(path):
    """Read a ``.test.negative`` file into a lookup table.

    Each non-blank line has the form ``(user,item)<TAB>neg1<TAB>neg2...``.

    Args:
        path (str): Path of the negative file.

    Returns:
        dict[tuple[int, int], list[int]]: negative item ids keyed by the
        ``(user, item)`` test pair, all in the file's original id space.
    """
    negatives = {}
    with open(path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            arr = line.split('\t')
            user_str, item_str = arr[0].strip('()').split(',')
            negatives[(int(user_str), int(item_str))] = [int(x) for x in arr[1:]]
    return negatives


def _pick_negatives(candidates, positives, item_pool, num_negatives, rng):
    """Choose up to ``num_negatives`` distinct negative item ids (remapped id space).

    ``candidates`` (e.g. the surviving original negatives) are taken first, in
    order; any shortfall is filled by sampling items the user never interacted
    with. If fewer than ``num_negatives`` such items exist, all of them are
    returned (so this never loops forever on small samples).

    Args:
        candidates (list[int]): Preferred negatives, already remapped.
        positives (set[int]): Remapped items the user interacted with (train
            items plus the test item); these never appear in the result.
        item_pool (list[int]): All remapped item ids, in a fixed order so the
            random fill is reproducible for a given ``rng`` state.
        num_negatives (int): Target number of negatives.
        rng (random.Random): Random source for the fill.

    Returns:
        list[int]: Distinct negative item ids.
    """
    chosen = []
    seen = set()
    for item in candidates:
        if len(chosen) >= num_negatives:
            break
        if item in positives or item in seen:
            continue
        chosen.append(item)
        seen.add(item)

    shortfall = num_negatives - len(chosen)
    if shortfall > 0:
        remaining = [i for i in item_pool if i not in positives and i not in seen]
        chosen.extend(rng.sample(remaining, min(shortfall, len(remaining))))
    return chosen


def create_sample(dataset, sample_users, seed, data_path='datasets/', overwrite=False,
                  num_negatives=99):
    """
    Create a user-sampled copy of a dataset with contiguous user/item ids.

    Reads ``{dataset}.train.rating``, ``{dataset}.test.rating`` and
    ``{dataset}.test.negative`` from ``data_path``, keeps ``sample_users``
    randomly chosen train users, remaps user and item ids to ``0..n-1`` and
    writes ``{dataset}-sample{sample_users}.*`` files into ``data_path``.

    Test cases whose item never appears in the sampled train data are dropped.
    For each remaining test case, original negatives that survive the item
    remapping are kept; the shortfall up to ``num_negatives`` is filled with
    random items the user has not interacted with (neither a sampled train
    item nor the test item). If the sample has fewer such items than
    ``num_negatives``, all of them are used and a warning is printed.

    Args:
        dataset (str): Dataset name (e.g., 'ml-1m')
        sample_users (int): Number of users to sample (capped at the number of train users)
        seed (int): Random seed; the same seed and inputs produce the same files
        data_path (str): Directory containing the dataset files
        overwrite (bool): If True, overwrite existing files
        num_negatives (int): Negatives per test case (default 99)

    Returns:
        dict | None: Statistics (num_users, num_items, num_train, num_test), or
        None if all sample files already exist and ``overwrite`` is False.
    """
    # Private RNG: reproducible per seed without resetting the global `random`
    # state. Random(seed).sample(...) draws the same users as
    # random.seed(seed); random.sample(...).
    rng = random.Random(seed)

    # Sample file paths (os.path.join also works when data_path has no trailing slash)
    sample_prefix = os.path.join(data_path, f"{dataset}-sample{sample_users}")
    sample_train = f"{sample_prefix}.train.rating"
    sample_test = f"{sample_prefix}.test.rating"
    sample_neg = f"{sample_prefix}.test.negative"

    # Check if files already exist
    if not overwrite and all(os.path.exists(p) for p in (sample_train, sample_test, sample_neg)):
        print(f"  ✓ Sample files already exist (sample{sample_users}). Skipping.")
        return None

    print(f"\nCreating sample data with {sample_users} users...")
    t_start = time()

    # Original file paths
    orig_prefix = os.path.join(data_path, dataset)
    orig_train = f"{orig_prefix}.train.rating"
    orig_test = f"{orig_prefix}.test.rating"
    orig_neg = f"{orig_prefix}.test.negative"

    # 1. Read train file and extract all users
    print("  [1/6] Reading train file...")
    train_data = _read_rating_file(orig_train)
    all_users = {u for u, _, _, _ in train_data}
    print(f"        Found {len(all_users)} total users, {len(train_data)} interactions")

    # 2. Sample users; sort first so the draw depends only on the seed, not on set order
    print(f"  [2/6] Sampling {sample_users} users...")
    sampled_users = set(rng.sample(sorted(all_users), min(sample_users, len(all_users))))

    # 3. Filter train data and remap IDs to contiguous 0..n-1 ranges
    print("  [3/6] Filtering and remapping IDs...")
    sampled_train = [rec for rec in train_data if rec[0] in sampled_users]
    unique_users = sorted({u for u, _, _, _ in sampled_train})
    unique_items = sorted({i for _, i, _, _ in sampled_train})
    user_map = {old_id: new_id for new_id, old_id in enumerate(unique_users)}
    item_map = {old_id: new_id for new_id, old_id in enumerate(unique_items)}
    print(f"        Sampled: {len(unique_users)} users, {len(unique_items)} items, {len(sampled_train)} interactions")

    # Remapped train positives per user; these must never be written as negatives.
    train_pos = {}
    for u, i, _, _ in sampled_train:
        train_pos.setdefault(user_map[u], set()).add(item_map[i])

    # 4. Write train file
    print("  [4/6] Writing sample train file...")
    with open(sample_train, 'w') as f:
        f.writelines(f"{user_map[u]}\t{item_map[i]}\t{r}\t{t}\n" for u, i, r, t in sampled_train)

    # 5. Filter and write test file: keep sampled users whose test item exists in the sampled train set
    print("  [5/6] Writing sample test files...")
    test_data = [
        rec for rec in _read_rating_file(orig_test)
        if rec[0] in sampled_users and rec[1] in item_map
    ]
    with open(sample_test, 'w') as f:
        f.writelines(f"{user_map[u]}\t{item_map[i]}\t{r}\t{t}\n" for u, i, r, t in test_data)

    # 6. Write test negative file (same order as test_data)
    print("  [6/6] Writing sample test negative file...")
    orig_neg_dict = _read_negative_file(orig_neg)
    item_pool = list(range(len(unique_items)))  # every remapped item id
    short_cases = 0
    with open(sample_neg, 'w') as f_out:
        for u_orig, i_orig, _, _ in test_data:
            u_new, i_new = user_map[u_orig], item_map[i_orig]
            orig_negs = orig_neg_dict.get((u_orig, i_orig))
            if orig_negs is not None:
                # Drop original negatives whose item is not part of the sample, then remap.
                candidates = [item_map[x] for x in orig_negs if x in item_map]
            else:
                print(f"        Warning: No negatives found for user {u_orig}, item {i_orig}. Generating random.")
                candidates = []
            positives = train_pos.get(u_new, set()) | {i_new}
            negs = _pick_negatives(candidates, positives, item_pool, num_negatives, rng)
            if len(negs) < num_negatives:
                short_cases += 1
            f_out.write(f"({u_new},{i_new})\t" + '\t'.join(map(str, negs)) + '\n')

    if short_cases:
        print(f"        Warning: {short_cases} test case(s) have fewer than {num_negatives} negatives "
              f"(not enough non-interacted items in the sample).")

    t_end = time()
    print(f"\n  ✓ Sample files created in {t_end - t_start:.1f}s")
    print(f"    - {sample_train}")
    print(f"    - {sample_test}")
    print(f"    - {sample_neg}")

    return {
        'num_users': len(unique_users),
        'num_items': len(unique_items),
        'num_train': len(sampled_train),
        'num_test': len(test_data)
    }


print("✓ Sampling function defined.")

## 3. 샘플 데이터 생성

In [None]:
# Generate one sample per configured size. `results` keeps the statistics of
# samples created in this run; samples skipped because their files already
# exist get no entry.
banner = "=" * 80

print(banner)
print("SAMPLE DATA GENERATION")
print(banner)

results = {}

for sample_size in SAMPLE_SIZES:
    print("\n" + banner)
    print(f"Generating sample: {DATASET}-sample{sample_size}")
    print(banner)

    stats = create_sample(DATASET, sample_size, SEED, data_path=DATA_PATH, overwrite=OVERWRITE)
    # create_sample returns None when it skipped existing files
    if stats:
        results[sample_size] = stats

print("\n" + banner)
print("GENERATION COMPLETE")
print(banner)

## 4. 검증 및 요약

In [None]:
import os

# Summarise each requested sample: distinct users/items and row counts of the
# train/test files, plus the on-disk size of all three generated files.
print("\nSample Data Summary:")
print("=" * 100)
print(f"{'Sample':<20} {'Users':<10} {'Items':<10} {'Train':<15} {'Test':<10} {'File Sizes':<30}")
print("=" * 100)

missing_samples = []
for sample_size in SAMPLE_SIZES:
    sample_name = f"{DATASET}-sample{sample_size}"
    train_file = os.path.join(DATA_PATH, f"{sample_name}.train.rating")
    test_file = os.path.join(DATA_PATH, f"{sample_name}.test.rating")
    neg_file = os.path.join(DATA_PATH, f"{sample_name}.test.negative")

    # A sample is usable only if all three files exist; checking all of them
    # also keeps os.path.getsize() from raising on a partially written sample.
    if not all(os.path.exists(p) for p in (train_file, test_file, neg_file)):
        print(f"{sample_name:<20} {'NOT FOUND':<60}")
        missing_samples.append(sample_name)
        continue

    # File sizes in KB
    train_size = os.path.getsize(train_file) / 1024
    test_size = os.path.getsize(test_file) / 1024
    neg_size = os.path.getsize(neg_file) / 1024

    # Count rows and distinct user/item ids directly from the train file, so the
    # table is complete even when generation was skipped in this run (skipped
    # samples have no entry in `results`).
    train_users, train_items = set(), set()
    train_lines = 0
    with open(train_file) as f:
        for line in f:
            train_lines += 1
            fields = line.split('\t')
            if len(fields) >= 2:
                train_users.add(fields[0])
                train_items.add(fields[1])
    with open(test_file) as f:
        test_lines = sum(1 for _ in f)

    file_sizes = f"{train_size:.1f}K / {test_size:.1f}K / {neg_size:.1f}K"
    print(f"{sample_name:<20} {len(train_users):<10} {len(train_items):<10} {train_lines:<15} {test_lines:<10} {file_sizes:<30}")

print("=" * 100)
if missing_samples:
    print(f"\n⚠ Missing sample datasets: {', '.join(missing_samples)}")
else:
    print("\n✓ All sample datasets are ready for use.")

# Usage hint built from the actual configuration instead of hard-coded names.
sample_names = [f"'{DATASET}-sample{s}'" for s in SAMPLE_SIZES]
if sample_names:
    print("\nUsage in notebooks:")
    hint = f"  DATASET = {sample_names[0]}"
    if len(sample_names) > 1:
        hint += f"  # or {', '.join(sample_names[1:])}"
    print(hint)