In [9]:
import os
import numpy as np
import pandas as pd
import csv

In [10]:
model = 'mmit'  # method name: selects model/<model>/predictions/ and labels rows in loss_csvs/

In [11]:
# Directory holding this model's per-fold prediction CSVs; the trailing '/' matters
# because file names are appended to it by plain string concatenation.
prediction_folder_path = '/'.join(['model', model, 'predictions', ''])

In [12]:
folder_path = 'data'
# Every sub-directory of data/ is one dataset. Sorted so the processing order is
# reproducible (os.listdir order is arbitrary); hidden directories such as a
# `.ipynb_checkpoints` folder created by Jupyter are not datasets and are skipped.
datasets = sorted(
    name for name in os.listdir(folder_path)
    if not name.startswith('.') and os.path.isdir(os.path.join(folder_path, name))
)

In [13]:
# Squared hinge loss
def squared_hinge_loss(predicted, y, margin=1):
    """Mean squared hinge loss of point predictions against target intervals.

    Parameters
    ----------
    predicted : array-like
        Predicted values, one per row of ``y``. Any shape with that many
        elements is accepted: a column vector of shape ``(n, 1)`` (as returned
        by ``DataFrame.to_numpy()``) is flattened to ``(n,)``. A single value
        is broadcast against every interval.
    y : array-like, shape (n, 2)
        Target intervals: column 0 is the lower bound, column 1 the upper
        bound. Infinite bounds contribute zero loss on that side.
    margin : float, default 1
        How far inside each bound a prediction must lie to incur no loss.

    Returns
    -------
    float
        Mean over rows of
        ``(max(low - pred + margin, 0) + max(pred - high + margin, 0)) ** 2``.
    """
    # Flatten so an (n, 1) column lines up element-wise with the (n,) bounds
    # instead of broadcasting into an (n, n) matrix.
    predicted = np.ravel(predicted)
    y = np.asarray(y)
    low, high = y[:, 0], y[:, 1]
    loss_low = np.maximum(low - predicted + margin, 0)    # prediction below low + margin
    loss_high = np.maximum(predicted - high + margin, 0)  # prediction above high - margin
    # NOTE(review): this squares the *sum* of the two hinge terms. The usual
    # interval-regression squared hinge sums the *squares*. The two differ only
    # when both terms are positive (interval narrower than 2 * margin). Confirm intent.
    loss = loss_low + loss_high
    return np.mean(np.square(loss))

In [14]:
# add row to csv
def add_row_to_csv(path, head, row):
    """Append one row to a CSV file, writing a header first if the file is new.

    The header is written when the file does not exist yet or exists but is
    empty, so an empty placeholder file still gets a header. Missing parent
    directories (e.g. ``loss_csvs/``) are created.

    Parameters
    ----------
    path : str or os.PathLike
        Destination CSV file.
    head : iterable
        Header row (column names); written only for a new or empty file.
    row : iterable
        Data row to append.
    """
    parent = os.path.dirname(path)
    if parent:                                      # '' means the current directory
        os.makedirs(parent, exist_ok=True)
    needs_header = not os.path.exists(path) or os.path.getsize(path) == 0
    with open(path, 'a', newline='') as csvfile:    # newline='' as required by the csv module
        writer = csv.writer(csvfile)
        if needs_header:
            writer.writerow(head)
        writer.writerow(row)

In [15]:
def get_loss(dataset, test_fold):
    """Squared hinge loss of the saved predictions for one test fold of a dataset.

    Reads ``data/<dataset>/folds.csv`` and ``data/<dataset>/targets.csv``,
    joins them column-wise (so their rows must be in the same order), keeps
    the rows whose ``fold`` equals ``test_fold`` and compares their
    ``[min.log.penalty, max.log.penalty]`` intervals with the predictions in
    ``<prediction_folder_path><dataset>.<test_fold>.csv``.

    Parameters
    ----------
    dataset : str
        Sub-directory name under ``data/``.
    test_fold : int
        Fold id to evaluate.

    Returns
    -------
    float
        Mean squared hinge loss, as computed by ``squared_hinge_loss``.

    Raises
    ------
    ValueError
        If the prediction file does not hold exactly one value per test row.
    """
    # Get target intervals for the rows in the test fold
    fold_df = pd.read_csv('data/' + dataset + '/folds.csv')
    target_df = pd.read_csv('data/' + dataset + '/targets.csv')
    target_fold_df = pd.concat([target_df, fold_df], axis=1)
    true_value = target_fold_df.loc[target_fold_df['fold'] == test_fold,
                                    ['min.log.penalty', 'max.log.penalty']].to_numpy()

    # Get prediction.
    # NOTE(review): assumes the file has a header row and a single prediction
    # column in the same row order as the test rows; confirm against the writer.
    pred_df = pd.read_csv(prediction_folder_path + dataset + '.' + str(test_fold) + '.csv')
    # Flatten (n, 1) -> (n,) so predictions line up row-by-row with the interval
    # bounds instead of broadcasting into an (n, n) matrix.
    pred = pred_df.to_numpy().ravel()
    if len(pred) != len(true_value):
        raise ValueError(
            f"{dataset} fold {test_fold}: {len(pred)} predictions "
            f"for {len(true_value)} target intervals"
        )

    # Compute loss
    return squared_hinge_loss(pred, true_value)

In [16]:
for dataset in datasets:
    # Evaluate every fold id actually present in folds.csv, rather than assuming
    # the ids are exactly 1..k. int() keeps the ids usable in prediction file
    # names even if the column was read as float.
    fold_df = pd.read_csv('data/' + dataset + '/folds.csv')
    test_folds = sorted(int(f) for f in fold_df['fold'].dropna().unique())

    # NOTE: rows are appended, so re-running this cell duplicates them in
    # loss_csvs/<dataset>.csv; clear those files first for a clean re-run.
    for test_fold in test_folds:
        loss = get_loss(dataset, test_fold)
        add_row_to_csv('loss_csvs/' + dataset + '.csv',
                       ['method', 'fold', 'loss'],
                       [model, test_fold, loss])