# 检索任务Performance可视化结果展示

## 函数定义

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt

def read_csv(csv_file):
    """Load `csv_file` with pandas, returning None instead of raising.

    A missing file, an empty file, or any other ``Exception`` raised by
    ``pd.read_csv`` is reported with a single printed message, and the
    function then returns None. Callers are expected to check for None.

    Args:
        csv_file: Path of the CSV file to read.

    Returns:
        The parsed DataFrame, or None if reading failed.
    """
    try:
        frame = pd.read_csv(csv_file)
    except FileNotFoundError:
        message = f"File not found: {csv_file}"
    except pd.errors.EmptyDataError:
        message = f"File is empty: {csv_file}"
    except Exception as err:
        message = f"An error occurred while reading the file: {err}"
    else:
        return frame
    print(message)
    return None
        
# Plotting helpers: each function draws one metric column against the
# 'epoch' column in a new figure. The accuracy plots additionally mark the
# maximum value with a red dot annotated with that value; the loss plot
# does not.

def _plot_metric(data, column, label, mark_max=True):
    """Draw `data[column]` against `data['epoch']` in a new figure and show it.

    Args:
        data: DataFrame with an 'epoch' column and a `column` column.
        column: Name of the metric column to plot.
        label: Display name used for the line legend, the y-axis label and
            the title ("<label> Over Epochs").
        mark_max: If True, mark the maximum of `data[column]` with a red
            'Max Point' dot annotated with the value to two decimals. The
            marker is skipped when the column has no non-NaN values (e.g. an
            empty DataFrame), since `idxmax` has no valid label to return.

    Raises:
        KeyError: if `column` or 'epoch' is not a column of `data`.
    """
    # Look the column up before creating the figure so a bad column name
    # does not leave an empty figure behind.
    values = data[column]
    plt.figure()
    plt.plot(data['epoch'], values, label=label)
    if mark_max and values.notna().any():
        max_idx = values.idxmax()
        # idxmax returns an index label, so fetch the matching epoch by label.
        max_epoch = data.loc[max_idx, 'epoch']
        max_value = values.loc[max_idx]
        plt.scatter(max_epoch, max_value, color='red', label='Max Point')
        plt.text(max_epoch, max_value, f'{max_value:.2f}', color='red',
                 fontsize=10, ha='right', va='bottom')
    plt.xlabel('Epoch')
    plt.ylabel(label)
    plt.title(f'{label} Over Epochs')
    plt.legend()
    plt.show()

def plot_test_loss(data):
    """Plot the 'test_loss' column against 'epoch' (no max-point marker)."""
    _plot_metric(data, 'test_loss', 'Test Loss', mark_max=False)

def plot_test_accuracy(data):
    """Plot 'test_accuracy' against 'epoch' and mark its maximum."""
    _plot_metric(data, 'test_accuracy', 'Test Accuracy')

def plot_v2_acc(data):
    """Plot 'v2_acc' against 'epoch' and mark its maximum."""
    _plot_metric(data, 'v2_acc', 'V2 Accuracy')

def plot_v4_acc(data):
    """Plot 'v4_acc' against 'epoch' and mark its maximum."""
    _plot_metric(data, 'v4_acc', 'V4 Accuracy')

def plot_v10_acc(data):
    """Plot 'v10_acc' against 'epoch' and mark its maximum."""
    _plot_metric(data, 'v10_acc', 'V10 Accuracy')

def plot_top5_acc(data):
    """Plot 'top5_acc' against 'epoch' and mark its maximum."""
    _plot_metric(data, 'top5_acc', 'Top-5 Accuracy')

def plot_v50_acc(data):
    """Plot 'v50_acc' against 'epoch' and mark its maximum."""
    _plot_metric(data, 'v50_acc', 'V50 Accuracy')

def plot_v100_acc(data):
    """Plot 'v100_acc' against 'epoch' and mark its maximum."""
    _plot_metric(data, 'v100_acc', 'V100 Accuracy')

def plot_v50_top5_acc(data):
    """Plot 'v50_top5_acc' against 'epoch' and mark its maximum."""
    _plot_metric(data, 'v50_top5_acc', 'V50 Top-5 Accuracy')

def plot_v100_top5_acc(data):
    """Plot 'v100_top5_acc' against 'epoch' and mark its maximum."""
    _plot_metric(data, 'v100_top5_acc', 'V100 Top-5 Accuracy')

def plot_average_test_accuracy(file_paths, chance_level=0.005):
    """Plot the across-file mean test accuracy per epoch with a min-max band.

    Each path is loaded with `read_csv` and its 'test_accuracy' column is
    collected. Files that cannot be read or that lack that column are
    skipped with a printed message. The collected columns are aligned on
    their row index (outer join), so files with fewer rows contribute NaN for
    the missing rows; pandas' mean/min/max skip NaN by default.

    Args:
        file_paths: Iterable of CSV file paths.
        chance_level: y-value of the dashed 'Chance Level' reference line.
            The default 0.005 equals 1/200, presumably the chance rate of a
            200-way task; pass another value for other settings.

    Returns:
        None. Prints "No valid data to process." and returns early when no
        file yielded a 'test_accuracy' column; otherwise shows the figure.
    """
    accuracy_columns = []
    for file_path in file_paths:
        data = read_csv(file_path)
        if data is None:
            # read_csv has already printed the reason the file failed to load.
            print(f"Skipping file {file_path}: it could not be read.")
        elif 'test_accuracy' not in data.columns:
            print(f"Skipping file {file_path} due to missing 'test_accuracy' column.")
        else:
            accuracy_columns.append(data['test_accuracy'])

    if not accuracy_columns:
        print("No valid data to process.")
        return

    # One column per file, rows aligned on the CSV row index.
    per_file = pd.concat(accuracy_columns, axis=1, ignore_index=True)

    # Compute every statistic from the per-file columns only, so no derived
    # column can leak into another statistic.
    mean_accuracy = per_file.mean(axis=1)
    min_accuracy = per_file.min(axis=1)
    max_accuracy = per_file.max(axis=1)

    # Epochs are numbered 1..N from row position, not read from the CSVs.
    # NOTE(review): the per-subject plots use the CSV 'epoch' column; confirm
    # both numberings agree (i.e. that the CSV epochs start at 1).
    epochs = range(1, len(per_file) + 1)

    plt.figure(figsize=(8, 5))
    plt.plot(epochs, mean_accuracy, color='blue', label='Average Test Accuracy')

    # Shaded band spanning the lowest and highest per-file accuracy.
    plt.fill_between(
        epochs,
        min_accuracy,
        max_accuracy,
        color='blue', alpha=0.2, label='Min-Max Range'
    )

    plt.axhline(y=chance_level, color='gray', linestyle='--', label='Chance Level')

    plt.xlabel('Epoch')
    plt.ylabel('Test Accuracy')
    plt.title('Test Accuracy with Min-Max Range Over Epochs')
    plt.legend()
    plt.show()

In [None]:
# Subject whose individual curves are plotted in the cells below.
sub = "sub-08"

# Root directory that holds one results folder per subject.
base_dir = os.path.join("outputs", "contrast", "ATMS")

# Run folder name shared by every subject (looks like a MM-DD_HH-MM stamp).
date = "12-01_15-53"

csv_path = os.path.join(base_dir, sub, date, f"ATMS_{sub}.csv")

# Result CSVs for sub-01 ... sub-NN (two-digit zero padding), used for the
# across-subject average plot.
n_subjects = 10
subjects = [f"sub-{i:02d}" for i in range(1, n_subjects + 1)]
csv_paths = [
    os.path.join(base_dir, subject, date, f"ATMS_{subject}.csv")
    for subject in subjects
]

In [None]:
# Echo the resolved configuration so the paths can be checked by eye.
(sub, base_dir, csv_path, csv_paths)

In [None]:
# Load the selected subject's metrics. read_csv returns None on failure,
# which would otherwise surface later as an obscure TypeError in every plot
# cell; stop here with a clear message instead.
data = read_csv(csv_path)
if data is None:
    raise RuntimeError(f"Could not load {csv_path}; see the message printed above.")

In [None]:
# Mean test accuracy across all subjects, with min-max band.
plot_average_test_accuracy(file_paths=csv_paths)

In [None]:

# Test loss curve for the selected subject.
plot_test_loss(data=data)

In [None]:
# 200-way test accuracy for the selected subject, maximum highlighted.
plot_test_accuracy(data=data)

In [None]:
# V2 accuracy curve, maximum highlighted.
plot_v2_acc(data=data)

In [None]:
# V4 accuracy curve, maximum highlighted.
plot_v4_acc(data=data)

In [None]:
# V10 accuracy curve, maximum highlighted.
plot_v10_acc(data=data)

In [None]:
# Top-5 accuracy curve, maximum highlighted.
plot_top5_acc(data=data)

In [None]:
# V50 accuracy curve, maximum highlighted.
plot_v50_acc(data=data)

In [None]:
# V100 accuracy curve, maximum highlighted.
plot_v100_acc(data=data)

In [None]:
# V50 top-5 accuracy curve, maximum highlighted.
plot_v50_top5_acc(data=data)

In [None]:
# V100 top-5 accuracy curve, maximum highlighted.
plot_v100_top5_acc(data=data)