In [1]:
import os
import numpy as np

In [2]:
PARENT_DIR = "/projects/leelab/clin25"
FROMATTED_METHOD_DICT = {
    "cocoop": "CoCoOp",
    "cpl": "CPL",
}
DATASET_LIST = ["flickr30k"]
SEED_LIST = [1, 2, 3]

PAPER_RECALL_145SHOTS_DICT = {
    "cocoop": {"flickr30k": 82.40},
    "cpl": {"flickr30k": 85.64},
}

PAPER_RECALL_290SHOTS_DICT = {
    "cocoop": {"flickr30k": 84.80},
    "cpl": {"flickr30k": 86.91},
}

PAPER_RECALL_DICT = {
    145: PAPER_RECALL_145SHOTS_DICT,
    290: PAPER_RECALL_290SHOTS_DICT,
}

In [3]:
def retrieve_recall_at_1(method: str, dataset: str, seed: int, num_shots: int) -> float:
    output_dir = os.path.join(
        PARENT_DIR,
        f"{method}-output",
        dataset,
        f"{num_shots}-shots",
        f"{seed}",
        "eval-unseen",
    )
    with open(os.path.join(output_dir, "log.txt")) as handle:
        log_text = handle.readlines()
        
    accuracy_line = log_text[-3].strip().replace("* ", "").replace(":", "")
    assert accuracy_line.split(" ")[0] == "accuracy"
    return float(accuracy_line.split(" ")[1].replace("%", ""))

In [4]:
def print_result_summary(method: str, num_shots: int) -> None:
    print(f"{FROMATTED_METHOD_DICT[method]} recall@1 with {num_shots} shots.")
    print("-" * 45)
    for dataset in DATASET_LIST:
        print(dataset)
        recall_list = [
            retrieve_recall_at_1(method, dataset, seed, num_shots)
            for seed in SEED_LIST
        ]
        paper_recall = PAPER_RECALL_DICT[num_shots][method][dataset]

        for i in range(len(recall_list)):
            if i == 0:
                relative_diff = (recall_list[i] - paper_recall) / paper_recall * 100
                relative_diff = f" ({relative_diff:.2f})"
            else:
                relative_diff = ""
            print(f"\tRecall@1 with seed {SEED_LIST[i]}: {recall_list[i]}{relative_diff}")

        recall_mean = np.mean(recall_list)
        recall_se = np.var(recall_list) / np.sqrt(len(recall_list))
        recall_ci = 1.96 * recall_se

        mean_relative_diff = (recall_mean - paper_recall) / paper_recall * 100
        ci_contains_paper_recall = (
            (paper_recall <= recall_mean + recall_ci)
            & (paper_recall >= recall_mean - recall_ci)
        )

        print(
            f"\tRecall@1 mean (95% CI): {recall_mean:.2f} ({recall_ci:.2f})"
        )
        print("")
        print(f"\tMean relative difference with paper recall@1: {mean_relative_diff:.2f}")
        print(f"\t95% CI contains paper recall@1: {ci_contains_paper_recall}")
        print("-" * 45)

## Performance 145 shots

In [5]:
print_result_summary("cocoop", 145)

CoCoOp recall@1 with 145 shots.
---------------------------------------------
flickr30k
	Recall@1 with seed 1: 82.4 (0.00)
	Recall@1 with seed 2: 82.5
	Recall@1 with seed 3: 83.4
	Recall@1 mean (95% CI): 82.77 (0.23)

	Mean relative difference with paper recall@1: 0.44
	95% CI contains paper recall@1: False
---------------------------------------------


In [6]:
print_result_summary("cpl", 145)

CPL recall@1 with 145 shots.
---------------------------------------------
flickr30k
	Recall@1 with seed 1: 83.2 (-2.85)
	Recall@1 with seed 2: 83.2
	Recall@1 with seed 3: 83.0
	Recall@1 mean (95% CI): 83.13 (0.01)

	Mean relative difference with paper recall@1: -2.93
	95% CI contains paper recall@1: False
---------------------------------------------


## Performance with 290 shots

In [7]:
print_result_summary("cocoop", 290)

CoCoOp recall@1 with 290 shots.
---------------------------------------------
flickr30k
	Recall@1 with seed 1: 85.2 (0.47)
	Recall@1 with seed 2: 83.0
	Recall@1 with seed 3: 84.7
	Recall@1 mean (95% CI): 84.30 (1.00)

	Mean relative difference with paper recall@1: -0.59
	95% CI contains paper recall@1: True
---------------------------------------------


In [8]:
print_result_summary("cpl", 290)

CPL recall@1 with 290 shots.
---------------------------------------------
flickr30k
	Recall@1 with seed 1: 85.1 (-2.08)
	Recall@1 with seed 2: 83.0
	Recall@1 with seed 3: 84.4
	Recall@1 mean (95% CI): 84.17 (0.86)

	Mean relative difference with paper recall@1: -3.16
	95% CI contains paper recall@1: False
---------------------------------------------
