# SSRFNet SVMixer Check
로컬 체크포인트가 정상 로드되는지 확인하는 간단한 예제입니다.

In [13]:
import torch
from pathlib import Path
from experiments.eval_only.test_code.models.svmixer import SVMixer

def _fix_state_dict(state_dict):
    """모델의 state_dict 키를 정리합니다."""
    state_dict = state_dict.get('state_dict', state_dict)
    cleaned_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith("student_model."):
            new_key = key[len("student_model."):]
        elif key.startswith("classifier."):
            new_key = key[len("classifier."):]
        else:
            new_key = key
        if 'total_ops' in new_key or 'total_params' in new_key:
            continue
        cleaned_state_dict[new_key] = value
    return cleaned_state_dict

# Build the model and locate the checkpoint relative to the current working
# directory (presumably the notebook is launched from the repository root;
# the path below fails otherwise).
model = SVMixer(12, 149, 1024)
repo_root = Path.cwd()
state_dict_path = repo_root.joinpath(
    'assets', 'trained_model', 'ssrfnet_eer0.60_sv_mixer_state_dict.pt'
)
print('Using checkpoint:', state_dict_path)

# Load all tensors onto CPU, normalise the key names, then push them into the
# model. PyTorch's load_state_dict is strict by default, so key mismatches raise.
state_dict = torch.load(state_dict_path, map_location='cpu')
state_dict = _fix_state_dict(state_dict)
model.load_state_dict(state_dict)

# Brief sanity report: how many entries were loaded, plus one example entry.
print('Keys:', len(state_dict))
first_key = next(iter(state_dict))
print('Sample key:', first_key)
print('Tensor shape:', state_dict[first_key].shape)


Using checkpoint: /home/koo/code/code/SSLbackend/20250906/SSRFNet_content/assets/trained_model/ssrfnet_eer0.60_sv_mixer_state_dict.pt
Keys: 303
Sample key: conv_layers.0.conv.weight
Tensor shape: torch.Size([512, 1, 10])
