In [29]:
import os
from typing import List, Optional

import numpy as np

In [54]:
# Root directory that holds one "<method>-output" tree per method.
PARENT_DIR = "/projects/leelab/clin25"

# Datasets to report on, in the order they are printed.
DATASET_LIST = ["sun397", "caltech-101", "oxford_flowers", "food-101"]

# Seeds whose runs are aggregated for every dataset.
SEED_LIST = list(range(1, 4))

# Factor applied to the summed logged batch times before reporting
# (presumably the logging interval in batches -- TODO confirm).
BATCH_TIME_MULTIPLIER = 20

In [38]:
def retrieve_runtime(
    method: str, dataset: str, seed: int, parent_dir: Optional[str] = None
) -> List[float]:
    """Sum the logged batch times of one training run, epoch by epoch.

    Reads ``<parent_dir>/<method>-output/<dataset>/base/<seed>/log.txt`` and
    keeps every line that starts with ``"epoch"``. For each kept line the
    epoch index is the integer between the first ``"["`` and the following
    ``"/"``, and the time is the number right after the first ``"time "``.

    Args:
        method: Method name; prefix of the ``<method>-output`` directory.
        dataset: Name of the dataset sub-directory.
        seed: Seed of the run; used as the innermost sub-directory name.
        parent_dir: Root directory containing the output trees. When omitted,
            the notebook-level ``PARENT_DIR`` is used.

    Returns:
        One float per epoch ``1..max_epoch`` (in order): the sum of the times
        logged for that epoch. An epoch number for which no line was logged
        contributes ``0.0``.

    Raises:
        FileNotFoundError: If the log file does not exist.
        ValueError: If the log contains no line starting with ``"epoch"``.
    """
    if parent_dir is None:
        parent_dir = PARENT_DIR

    log_file = os.path.join(
        parent_dir, f"{method}-output", dataset, "base", f"{seed}", "log.txt"
    )
    # Stream the file; only the per-batch progress lines are needed.
    with open(log_file) as handle:
        batch_log_list = [line for line in handle if line.startswith("epoch")]

    # Without this guard an empty selection fails later inside ``np.max`` with
    # an opaque "zero-size array" message that does not name the file.
    if not batch_log_list:
        raise ValueError(
            f"No batch log lines (starting with 'epoch') found in {log_file}"
        )

    batch_epochs = np.array(
        [int(line.split("[")[1].split("/")[0]) for line in batch_log_list]
    )
    batch_times = np.array(
        [float(line.split("time ")[1].split(" ")[0]) for line in batch_log_list]
    )

    # Epochs are 1-based in the log; a boolean mask selects each epoch's rows.
    return [
        float(batch_times[batch_epochs == epoch].sum())
        for epoch in range(1, batch_epochs.max() + 1)
    ]

## CoCoOp runtime

In [57]:
# Report, per dataset, the CoCoOp runtime averaged over all seeds.
print("CoCoOp training GPU runtime.")
print("-" * 45)

for dataset in DATASET_LIST:
    print(dataset)

    # One entry per seed: mean-per-epoch and whole-run time, scaled by
    # BATCH_TIME_MULTIPLIER and divided by 60 (presumably seconds -> minutes).
    epoch_runtime_list, total_runtime_list = [], []
    for seed in SEED_LIST:
        epoch_time_list = retrieve_runtime("cocoop", dataset, seed)
        epoch_runtime_list += [np.mean(epoch_time_list) * BATCH_TIME_MULTIPLIER / 60]
        total_runtime_list += [np.sum(epoch_time_list) * BATCH_TIME_MULTIPLIER / 60]

    # np.std keeps its default ddof=0, i.e. the population SD across seeds.
    epoch_runtime_mean, epoch_runtime_std = (
        np.mean(epoch_runtime_list),
        np.std(epoch_runtime_list),
    )
    total_runtime_mean, total_runtime_std = (
        np.mean(total_runtime_list),
        np.std(total_runtime_list),
    )

    print(
        f"\tEpoch mean (sd): {epoch_runtime_mean:.3f} ({epoch_runtime_std:.3f})",
        f"\tTotal mean (sd): {total_runtime_mean:.3f} ({total_runtime_std:.3f})",
        sep="\n",
    )

    print("-" * 45)

CoCoOp training GPU runtime.
---------------------------------------------
sun397
	Epoch mean (sd): 8.796 (0.162)
	Total mean (sd): 87.957 (1.619)
---------------------------------------------
caltech-101
	Epoch mean (sd): 0.703 (0.013)
	Total mean (sd): 7.032 (0.135)
---------------------------------------------
oxford_flowers
	Epoch mean (sd): 0.705 (0.004)
	Total mean (sd): 7.053 (0.039)
---------------------------------------------
food-101
	Epoch mean (sd): 2.283 (0.079)
	Total mean (sd): 22.834 (0.792)
---------------------------------------------


## CPL runtime

In [58]:
# Report, per dataset, the CPL runtime averaged over all seeds.
print("CPL training GPU runtime.")
print("-" * 45)

for dataset in DATASET_LIST:
    print(dataset)

    # One entry per seed: mean-per-epoch and whole-run time, scaled by
    # BATCH_TIME_MULTIPLIER and divided by 60 (presumably seconds -> minutes).
    epoch_runtime_list, total_runtime_list = [], []
    for seed in SEED_LIST:
        epoch_time_list = retrieve_runtime("cpl", dataset, seed)
        epoch_runtime_list += [np.mean(epoch_time_list) * BATCH_TIME_MULTIPLIER / 60]
        total_runtime_list += [np.sum(epoch_time_list) * BATCH_TIME_MULTIPLIER / 60]

    # np.std keeps its default ddof=0, i.e. the population SD across seeds.
    epoch_runtime_mean, epoch_runtime_std = (
        np.mean(epoch_runtime_list),
        np.std(epoch_runtime_list),
    )
    total_runtime_mean, total_runtime_std = (
        np.mean(total_runtime_list),
        np.std(total_runtime_list),
    )

    print(
        f"\tEpoch mean (sd): {epoch_runtime_mean:.3f} ({epoch_runtime_std:.3f})",
        f"\tTotal mean (sd): {total_runtime_mean:.3f} ({total_runtime_std:.3f})",
        sep="\n",
    )

    print("-" * 45)

CPL training GPU runtime.
---------------------------------------------
sun397
	Epoch mean (sd): 142.338 (3.518)
	Total mean (sd): 1423.381 (35.182)
---------------------------------------------
caltech-101
	Epoch mean (sd): 35.782 (1.086)
	Total mean (sd): 357.818 (10.858)
---------------------------------------------
oxford_flowers
	Epoch mean (sd): 38.803 (2.138)
	Total mean (sd): 388.032 (21.380)
---------------------------------------------
food-101
	Epoch mean (sd): 41.302 (1.085)
	Total mean (sd): 413.021 (10.850)
---------------------------------------------
