
# Average multiple LoRA checkpoints (simple FedAvg)

이 노트북은 디렉토리에 모아둔 **LoRA 전용 ckpt(state_dict)**들을 불러와 **단순 평균**(equal weight)한 뒤 `local_only_ensemble.ckpt`와 같은 하나의 파일로 저장합니다.

## 사용 방법
1. 아래 **설정(경로/패턴)** 셀에서 경로를 본인 환경에 맞게 수정합니다.  
2. **실행** 셀을 실행하면, 평균 결과가 `out_path`에 저장됩니다.


In [1]:

import os, glob, math
from pathlib import Path
from typing import List, Dict, Tuple
import torch

def average_lora_ckpts(
    in_dir: str,
    out_path: str,
    pattern: str = "local_only_tldr_choice_qwen_client_*.ckpt",
    strict_keys: bool = False,
    save_dtype = torch.float32,
) -> Dict[str, torch.Tensor]:
    """Average checkpoint state_dicts with equal weights and save the result.

    Every file in ``in_dir`` matching the glob ``pattern`` is loaded on CPU.
    The first file (in sorted path order) is the *base*: the output contains
    exactly the base's keys, in the base's order.

    * Floating-point tensors are averaged element-wise and stored as
      ``save_dtype``.
    * Non-floating tensors are cloned from the base unchanged.
    * Non-tensor values are copied from the base unchanged (non-strict mode
      only; strict mode rejects them).

    Args:
        in_dir: Directory containing the checkpoints.
        out_path: Destination file; parent directories are created.
        pattern: Glob pattern (relative to ``in_dir``) selecting the inputs.
        strict_keys: If True, every checkpoint must have exactly the same key
            set, every value must be a tensor, and shapes (and float-ness of
            averaged tensors) must match; otherwise ``RuntimeError`` is
            raised. If False, a checkpoint that lacks a key, or holds a value
            with a different shape / non-float dtype for it, is simply left
            out of that key's average. Keys absent from the base are ignored.
        save_dtype: dtype of the averaged tensors in the output. ``None`` is
            treated as ``torch.float32``.

    Returns:
        The averaged state_dict (the same object that was saved).

    Raises:
        FileNotFoundError: No file matches ``pattern`` in ``in_dir``.
        RuntimeError: A checkpoint is not a dict, or a strict-mode check fails.
    """
    in_dir = str(in_dir)
    out_path = str(out_path)
    if save_dtype is None:
        save_dtype = torch.float32
    # Sum in at least float32 so that a half-precision save_dtype neither
    # overflows nor loses precision while many checkpoints are accumulated.
    acc_dtype = torch.promote_types(save_dtype, torch.float32)

    search = os.path.join(in_dir, pattern)
    paths = sorted(glob.glob(search))
    if not paths:
        raise FileNotFoundError(f"No ckpt matches: {search}")

    # Load all state_dicts on CPU.
    # NOTE(security): torch.load unpickles the file; only point this at
    # checkpoints from a trusted source.
    state_dicts: List[Dict[str, torch.Tensor]] = []
    for p in paths:
        sd = torch.load(p, map_location="cpu")
        if not isinstance(sd, dict):
            raise RuntimeError(f"Checkpoint is not a state_dict: {p}")
        state_dicts.append(sd)

    base = state_dicts[0]
    num_ckpts = len(state_dicts)

    if strict_keys:
        # Strict mode: identical key sets, tensor-only values, equal shapes.
        base_keys = set(base)
        for p, sd in zip(paths[1:], state_dicts[1:]):
            diff = base_keys ^ set(sd)
            if diff:
                sample = sorted(map(str, diff))[:5]
                raise RuntimeError(
                    f"Key mismatch in strict mode: {p} differs from {paths[0]} "
                    f"on {len(diff)} key(s), e.g. {sample}"
                )
        for k, v0 in base.items():
            if not torch.is_tensor(v0):
                raise RuntimeError(f"Non-tensor value for key={k} in strict mode.")
            shape0 = tuple(v0.shape)
            for sd in state_dicts[1:]:
                v = sd[k]
                if not torch.is_tensor(v):
                    raise RuntimeError(f"Non-tensor value for key={k} in strict mode.")
                if tuple(v.shape) != shape0:
                    raise RuntimeError(f"Shape mismatch for key={k}: {shape0} vs {tuple(v.shape)}")

    # Iterate the base dict directly so the output key order is deterministic.
    out_sd: Dict[str, torch.Tensor] = {}
    for k, v0 in base.items():
        if not torch.is_tensor(v0):
            # Metadata (ints, strings, None, ...) cannot be averaged or cloned.
            out_sd[k] = v0
            continue
        if not v0.dtype.is_floating_point:
            # Integer / bool buffers are taken from the base unchanged.
            out_sd[k] = v0.clone()
            continue

        shape0 = tuple(v0.shape)
        acc = None
        count = 0
        for sd in state_dicts:
            v = sd.get(k, None)
            usable = (
                torch.is_tensor(v)
                and tuple(v.shape) == shape0
                and v.dtype.is_floating_point
            )
            if not usable:
                if strict_keys:
                    raise RuntimeError(f"Key {k} missing or shape/dtype mismatch in strict mode.")
                continue  # non-strict: this checkpoint does not contribute
            if acc is None:
                acc = v.to(dtype=acc_dtype, copy=True)
            else:
                acc.add_(v.to(dtype=acc_dtype))
            count += 1

        # The base tensor always qualifies, so count >= 1 and acc is set.
        out_sd[k] = acc.div_(float(count)).to(dtype=save_dtype)

    # Save
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(out_sd, out_path)
    print(f"[OK] Averaged {num_ckpts} ckpts → {out_path}")
    return out_sd


In [2]:

# === Settings (paths / pattern) ===
in_dir   = "./checkpoints_1.0_local_only"   # directory holding the LoRA ckpts
pattern  = "local_only_tldr_choice_qwen_client_*.ckpt"  # glob pattern selecting the input files
out_path = "./checkpoints_1.0_local_only/local_only_ensemble.ckpt"     # where the averaged ckpt is written

# strict_keys=True asks the function to require matching keys/shapes across
# all checkpoints (safer, but stricter).
strict_keys = False
save_dtype  = None  # None -> the run cell substitutes torch.float32; set a torch dtype here to override


In [3]:

# === Run ===
import torch

# Average the checkpoints; a save_dtype of None falls back to float32.
sd = average_lora_ckpts(
    in_dir=in_dir,
    out_path=out_path,
    pattern=pattern,
    strict_keys=strict_keys,
    save_dtype=(torch.float32 if save_dtype is None else save_dtype),
)

# Short summary: number of tensor entries and total element count.
num_params = sum(t.numel() for t in filter(torch.is_tensor, sd.values()))
num_tensors = len([t for t in sd.values() if torch.is_tensor(t)])
print(f"Averaged state_dict: tensors={num_tensors}, total_params={num_params}")
print(f"Saved to: {out_path}")


[OK] Averaged 53 ckpts → ./checkpoints_1.0_local_only/local_only_ensemble.ckpt
Averaged state_dict: tensors=336, total_params=4399104
Saved to: ./checkpoints_1.0_local_only/local_only_ensemble.ckpt
