In [1]:
import shutil
from pathlib import Path

import os
import glob
import json
import torch
import collections


# Default locations, used when the caller does not pass base_dir / log_search_dir.
_SWA_DEFAULT_MODEL_ROOT = '/data4/huangweigang/gh/csiro-biomass/ÂéÜÂè≤Ê®°Âûã/'
_SWA_DEFAULT_LOG_ROOT = '/data4/huangweigang/gh'
# Suffixes of the run artefacts that are copied next to the averaged weights.
_SWA_COPIED_SUFFIXES = ('.py', '.xlsx', '.json', '.csv')


def _collect_fold_checkpoints(model_info, model_dir):
    """Return ``(paths, epochs, best_cv)`` for one fold's top-3 JSON record.

    ``paths`` lists the checkpoint files (resolved inside ``model_dir`` by
    basename) of every distinct epoch found in ``cv_top3`` and then
    ``loss_top3``, in JSON order. ``best_cv`` is the largest ``cv`` value in
    ``cv_top3``, or ``None`` when that list is empty or absent.
    """
    paths = []
    epochs = set()
    best_cv = None

    for section in ('cv_top3', 'loss_top3'):
        for entry in model_info.get(section, []):
            epoch = entry['epoch']
            path = model_dir / os.path.basename(entry['path'])
            if section == 'cv_top3':
                cv = entry['cv']
                if best_cv is None or cv > best_cv:
                    best_cv = cv
            if epoch in epochs:
                continue
            epochs.add(epoch)
            # Two epochs may point at the same file; keep each file once.
            if path not in paths:
                paths.append(path)

    return paths, epochs, best_cv


def _average_state_dicts(state_dicts):
    """Return the element-wise mean of ``state_dicts`` as an ``OrderedDict``.

    Keys are taken from the first dict; every other dict must contain them
    (a missing key raises ``KeyError``). Tensor entries keep their original
    dtype: half-precision tensors are accumulated in float32 and integer/bool
    tensors in float64 before being cast back.
    """
    count = len(state_dicts)
    averaged = collections.OrderedDict()

    for key, first in state_dicts[0].items():
        if not torch.is_tensor(first):
            # Non-tensor entries: plain arithmetic mean.
            total = 0
            for state in state_dicts:
                total += state[key]
            averaged[key] = total / count
            continue

        if first.is_floating_point():
            low_precision = first.dtype in (torch.float16, torch.bfloat16)
            acc_dtype = torch.float32 if low_precision else first.dtype
        elif first.is_complex():
            acc_dtype = first.dtype
        else:
            acc_dtype = torch.float64

        # ``.to`` may return the same tensor, so clone before accumulating in place.
        total = first.to(acc_dtype).clone()
        for state in state_dicts[1:]:
            total += state[key].to(acc_dtype)
        averaged[key] = (total / count).to(first.dtype)

    return averaged


def SWA_process(file_name, only_show=False, base_dir=None, log_search_dir=None):
    """Average the top checkpoints of every fold of one training run.

    For each fold 1..5 the file ``model_top3_fold{fold}.json`` inside the run
    directory is read. The checkpoints listed under ``cv_top3`` and
    ``loss_top3`` (deduplicated by epoch) are loaded, averaged key by key and
    saved as ``fold{fold}_swa.pt`` in ``<run>/<run>_swa_models``. The run's
    ``.py``/``.xlsx``/``.json``/``.csv`` files and its ``<run>.py.log`` file
    are copied into the same output directory, and a one-line summary of the
    best CV score per fold is printed and written to
    ``folds_best_CV_results.txt``.

    Args:
        file_name: Run directory relative to ``base_dir``; its last component
            is used as the run name.
        only_show: When true, skip loading/averaging/saving checkpoints. The
            output directory, the copied files and the summary file are still
            produced.
        base_dir: Root that contains the run directories. Defaults to the
            original hard-coded model root.
        log_search_dir: Directory searched recursively for ``<run>.py.log``
            when it is not already in the output directory. Defaults to the
            original hard-coded search root.

    Raises:
        FileNotFoundError: If the run directory does not exist. Without this
            check a mistyped name would silently create the directory tree and
            write a meaningless summary.
    """
    model_root = Path(_SWA_DEFAULT_MODEL_ROOT if base_dir is None else base_dir)
    model_dir = model_root / file_name
    if not model_dir.is_dir():
        raise FileNotFoundError(f"Model directory not found: {model_dir}")

    scripts_name = model_dir.name
    out_dir = model_dir / (scripts_name + '_swa_models')
    out_dir.mkdir(parents=True, exist_ok=True)

    # Keep the training script and its metadata next to the averaged weights.
    for entry in model_dir.iterdir():
        if entry.is_file() and entry.suffix in _SWA_COPIED_SUFFIXES:
            shutil.copy(entry, out_dir)

    # The training log is not covered by the suffix list above, so fetch it
    # from the search root unless a previous call already copied it.
    log_name = f"{scripts_name}.py.log"
    if not (out_dir / log_name).exists():
        print(f"‚ùå Ê≤°Êúâ .log Êñá‰ª∂")
        log_root = _SWA_DEFAULT_LOG_ROOT if log_search_dir is None else log_search_dir
        # Escape both parts so special characters in names are matched literally.
        log_pattern = os.path.join(glob.escape(str(log_root)), "**", glob.escape(log_name))
        found_logs = glob.glob(log_pattern, recursive=True)
        if found_logs:
            for log_path in found_logs:
                print(f"‚úÖ ÁßªÂä® {log_path}")
                shutil.copy(log_path, out_dir)
        else:
            print(f"‚ùå Ê≤°ÊúâÊâæÂà∞ .log Êñá‰ª∂ ‚ùå‚ùå")

    folds = [1, 2, 3, 4, 5]
    # None marks "no score available"; a numeric sentinel would be mistaken
    # for a real score and pollute the average.
    folds_best_CV = [None] * len(folds)

    for fold_index, fold in enumerate(folds):
        json_path = os.path.join(model_dir, f'model_top3_fold{fold}.json')
        if not os.path.exists(json_path):
            print(f"‚ùå JSON file not found for fold {fold}: {json_path}")
            continue

        with open(json_path, 'r', encoding='utf-8') as json_file:
            model_info = json.load(json_file)

        model_paths, epoch_set, best_cv_value = _collect_fold_checkpoints(model_info, model_dir)
        folds_best_CV[fold_index] = best_cv_value

        if only_show:
            continue

        print(f"üîÑ Fold {fold} Found {len(model_paths)} unique models  {sorted(epoch_set)}")

        if len(model_paths) < 2:
            print(f"‚ö†Ô∏è  Fold {fold} Only {len(model_paths)} unique models, skipping SWA...")
            continue

        models = []
        for module_path in model_paths:
            if os.path.exists(module_path):
                models.append(torch.load(module_path, map_location='cpu', weights_only=True))
            else:
                print(f"‚ùå Model file not found: {module_path}")

        if len(models) < 2:
            print(f"‚ö†Ô∏è  Fold {fold} Only {len(models)} models loaded, skipping SWA...")
            continue

        fed_state_dict = _average_state_dicts(models)

        output_path = os.path.join(out_dir, f'fold{fold}_swa.pt')
        torch.save(fed_state_dict, output_path)
        print(f"‚úÖ Fold {fold} averaging complete. Saved to: {os.path.basename(output_path)}")

    # Average only over folds that actually reported a CV score.
    valid_scores = [score for score in folds_best_CV if score is not None]
    avg_text = f"{sum(valid_scores) / len(valid_scores):0.4f}" if valid_scores else "n/a"
    per_fold_text = ', '.join("n/a" if score is None else f"{score:0.4f}" for score in folds_best_CV)
    str_ = f"‚úÖ {scripts_name:<30} avg: {avg_text} per: {per_fold_text}"
    print(str_)

    # Save the summary line to a text file; explicit UTF-8 because it holds non-ASCII characters.
    output_file = os.path.join(out_dir, "folds_best_CV_results.txt")
    with open(output_file, 'w', encoding='utf-8') as results_file:
        results_file.write(str_)



In [2]:
# Build the averaged checkpoints for each run of the 0119 experiment group.
for run_name in ('0119/single_0119', '0119/single_0119_head_ratio'):
    SWA_process(run_name)
# Disabled calls for other runs, kept for reference:
# SWA_process('single_1209/single_1209_Huge_2')
# SWA_process('single_1209/single_1209_Huge_2_Freeze1')

‚ùå Ê≤°Êúâ .log Êñá‰ª∂
‚úÖ ÁßªÂä® /data4/huangweigang/gh/method/single_0119.py.log
üîÑ Fold 1 Found 6 unique models  [54, 58, 61, 87, 102, 128]
‚úÖ Fold 1 averaging complete. Saved to: fold1_swa.pt
üîÑ Fold 2 Found 5 unique models  [85, 90, 104, 115, 117]
‚úÖ Fold 2 averaging complete. Saved to: fold2_swa.pt
üîÑ Fold 3 Found 5 unique models  [124, 129, 130, 141, 160]
‚úÖ Fold 3 averaging complete. Saved to: fold3_swa.pt
üîÑ Fold 4 Found 6 unique models  [34, 61, 103, 104, 133, 147]
‚úÖ Fold 4 averaging complete. Saved to: fold4_swa.pt
üîÑ Fold 5 Found 4 unique models  [42, 60, 64, 89]
‚úÖ Fold 5 averaging complete. Saved to: fold5_swa.pt
‚úÖ single_0119                    avg: 0.8096 per: 0.8088, 0.8605, 0.7756, 0.8049, 0.7979
‚ùå Ê≤°Êúâ .log Êñá‰ª∂
‚úÖ ÁßªÂä® /data4/huangweigang/gh/method/single_0119_head_ratio.py.log
üîÑ Fold 1 Found 5 unique models  [41, 42, 44, 56, 72]
‚úÖ Fold 1 averaging complete. Saved to: fold1_swa.pt
üîÑ Fold 2 Found 6 unique models  [63, 84, 117, 121, 

In [4]:
# only_show=True skips loading and averaging checkpoints; each call still
# prints (and rewrites) the per-fold best-CV summary line for the run.
for run_name in ('0119/single_0119', '0119/single_0119_head_ratio'):
    SWA_process(run_name, only_show=True)

‚úÖ single_0119                    avg: 0.8096 per: 0.8088, 0.8605, 0.7756, 0.8049, 0.7979
‚úÖ single_0119_head_ratio         avg: 0.8361 per: 0.8355, 0.8972, 0.8030, 0.8230, 0.8220


single_1209_Huge_1
vit_huge_plus_patch16_dinov3.lvd1689m   fr = 0.8 

single_1209_Huge_2
vit_huge_plus_patch16_dinov3_qkvb.lvd1689m  fr = 0.8 

single_1209_Huge_2_Freeze1
vit_huge_plus_patch16_dinov3_qkvb.lvd1689m   fr = 1.0 => fr = 0.8 (epoch = 20)