In [1]:
import os
import numpy as np
from typing import List

In [2]:
# Root directory under which each method keeps its "<method>-output" folder.
PARENT_DIR = "/projects/leelab/clin25"

# Datasets that are summarized together under the "base" training mode.
CLASSIFICATION_DATASET_LIST = ["sun397", "caltech-101", "oxford_flowers", "food-101"]

# Method key (as used in the output folder name) -> display name for titles.
FROMATTED_METHOD_DICT = dict(cocoop="CoCoOp", cpl="CPL")

# Seeds to aggregate over; each seed has its own run directory.
SEED_LIST = [1, 2, 3]

# Factor applied to the summed logged batch times before converting to minutes.
# NOTE(review): presumably the training logs emit one timing line every 20
# batches -- confirm against the trainer's print frequency.
BATCH_TIME_MULTIPLIER = 20

In [3]:
def retrieve_runtime(
    method: str, dataset: str, seed: int, train_mode: str = "base"
) -> List[float]:
    """Sum the logged batch times of one training run, grouped by epoch.

    Reads ``<PARENT_DIR>/<method>-output/<dataset>/<train_mode>/<seed>/log.txt``
    and keeps only the lines that start with ``"epoch"``. From each such line
    the epoch number is the integer between the first ``"["`` and the following
    ``"/"``, and the batch time is the number right after ``"time "``.

    Args:
        method: Method key used as the prefix of the output folder name.
        dataset: Dataset folder name.
        seed: Seed whose run directory is read.
        train_mode: Training-mode folder name.

    Returns:
        One entry per epoch from 1 up to the largest epoch number in the log
        (index ``i`` holds epoch ``i + 1``). Each entry is the sum of the
        batch times logged for that epoch, or 0.0 if the epoch has no logged
        line. Values are raw sums; no multiplier or unit conversion is applied.

    Raises:
        FileNotFoundError: If the log file does not exist.
        ValueError: If the log file contains no line starting with "epoch".
    """
    output_dir = os.path.join(
        PARENT_DIR, f"{method}-output", dataset, train_mode, f"{seed}"
    )
    log_file = os.path.join(output_dir, "log.txt")

    # Stream the file and keep only the per-batch progress lines.
    with open(log_file) as handle:
        batch_log_list = [line for line in handle if line.startswith("epoch")]

    # Without this guard, the .max() below fails on an empty array with a
    # NumPy reduction error that does not say which log was at fault.
    if not batch_log_list:
        raise ValueError(
            f"No lines starting with 'epoch' were found in {log_file}"
        )

    batch_epoch_list = np.array(
        [int(batch_log.split("[")[1].split("/")[0]) for batch_log in batch_log_list]
    )
    batch_time_list = np.array(
        [float(batch_log.split("time ")[1].split(" ")[0]) for batch_log in batch_log_list]
    )

    # Epochs are numbered from 1; select each epoch's lines with a boolean mask.
    epoch_time_list = [
        batch_time_list[batch_epoch_list == epoch].sum()
        for epoch in range(1, batch_epoch_list.max() + 1)
    ]

    return epoch_time_list

In [4]:
def print_runtime_summary(
    method: str, dataset_list: List[str], train_mode: str = "base"
) -> None:
    """Print per-dataset runtime statistics aggregated over ``SEED_LIST``.

    For every dataset, the per-epoch times of each seed are fetched with
    ``retrieve_runtime``. Each seed contributes two numbers: the mean and the
    sum of its per-epoch times, both multiplied by ``BATCH_TIME_MULTIPLIER``
    and divided by 60. The mean and standard deviation of those numbers across
    seeds are printed (``np.std`` with its default settings, i.e. the
    population standard deviation).

    Args:
        method: Key into ``FROMATTED_METHOD_DICT``; also passed to
            ``retrieve_runtime``.
        dataset_list: Datasets to report, in print order.
        train_mode: Passed to ``retrieve_runtime``; appended to the title only
            when it contains the substring "shot".
    """
    separator = "-" * 45

    heading = f"{FROMATTED_METHOD_DICT[method]} training GPU runtime"
    heading += (f" with {train_mode}" if "shot" in train_mode else "") + "."
    print(heading)
    print(separator)

    for dataset in dataset_list:
        print(dataset)

        # One list of per-epoch times per seed.
        per_seed_epoch_times = [
            retrieve_runtime(method, dataset, seed, train_mode)
            for seed in SEED_LIST
        ]
        # Keep the multiply-then-divide order so the printed values match exactly.
        epoch_minutes = [
            np.mean(times) * BATCH_TIME_MULTIPLIER / 60
            for times in per_seed_epoch_times
        ]
        total_minutes = [
            np.sum(times) * BATCH_TIME_MULTIPLIER / 60
            for times in per_seed_epoch_times
        ]

        print(f"\tEpoch mean (sd): {np.mean(epoch_minutes):.3f} ({np.std(epoch_minutes):.3f})")
        print(f"\tTotal mean (sd): {np.mean(total_minutes):.3f} ({np.std(total_minutes):.3f})")

        print(separator)

## CoCoOp runtime (in minutes)

In [5]:
# Classification datasets under the default "base" training mode.
print_runtime_summary(
    method="cocoop", dataset_list=CLASSIFICATION_DATASET_LIST, train_mode="base"
)
print("")

# flickr30k runs trained with 145 shots.
print_runtime_summary(
    method="cocoop", dataset_list=["flickr30k"], train_mode="145-shots"
)
print("")

# flickr30k runs trained with 290 shots.
print_runtime_summary(
    method="cocoop", dataset_list=["flickr30k"], train_mode="290-shots"
)
print("")

CoCoOp training GPU runtime.
---------------------------------------------
sun397
	Epoch mean (sd): 8.796 (0.162)
	Total mean (sd): 87.957 (1.619)
---------------------------------------------
caltech-101
	Epoch mean (sd): 0.703 (0.013)
	Total mean (sd): 7.032 (0.135)
---------------------------------------------
oxford_flowers
	Epoch mean (sd): 0.705 (0.004)
	Total mean (sd): 7.053 (0.039)
---------------------------------------------
food-101
	Epoch mean (sd): 2.283 (0.079)
	Total mean (sd): 22.834 (0.792)
---------------------------------------------

CoCoOp training GPU runtime with 145-shots.
---------------------------------------------
flickr30k
	Epoch mean (sd): 0.771 (0.001)
	Total mean (sd): 7.709 (0.015)
---------------------------------------------

CoCoOp training GPU runtime with 290-shots.
---------------------------------------------
flickr30k
	Epoch mean (sd): 1.023 (0.005)
	Total mean (sd): 10.227 (0.045)
---------------------------------------------



## CPL runtime (in minutes)

In [6]:
# Classification datasets under the default "base" training mode.
print_runtime_summary(
    method="cpl", dataset_list=CLASSIFICATION_DATASET_LIST, train_mode="base"
)
print("")

# flickr30k runs trained with 145 shots.
print_runtime_summary(
    method="cpl", dataset_list=["flickr30k"], train_mode="145-shots"
)
print("")

# flickr30k runs trained with 290 shots.
print_runtime_summary(
    method="cpl", dataset_list=["flickr30k"], train_mode="290-shots"
)
print("")

CPL training GPU runtime.
---------------------------------------------
sun397
	Epoch mean (sd): 142.338 (3.518)
	Total mean (sd): 1423.381 (35.182)
---------------------------------------------
caltech-101
	Epoch mean (sd): 35.782 (1.086)
	Total mean (sd): 357.818 (10.858)
---------------------------------------------
oxford_flowers
	Epoch mean (sd): 38.803 (2.138)
	Total mean (sd): 388.032 (21.380)
---------------------------------------------
food-101
	Epoch mean (sd): 41.302 (1.085)
	Total mean (sd): 413.021 (10.850)
---------------------------------------------

CPL training GPU runtime with 145-shots.
---------------------------------------------
flickr30k
	Epoch mean (sd): 5.683 (0.232)
	Total mean (sd): 56.831 (2.323)
---------------------------------------------

CPL training GPU runtime with 290-shots.
---------------------------------------------
flickr30k
	Epoch mean (sd): 11.465 (0.656)
	Total mean (sd): 114.651 (6.562)
---------------------------------------------

