# 1. set up

## 1.1. libraries

In [None]:
import sys
print("print version")
print(sys.version)

import os
import time

import numpy as np
import pandas as pd
import random
import matplotlib.pyplot as plt

import math # for plotting network
import matplotlib.gridspec as gridspec # for plotting figure 3
import matplotlib.patches as mpatches # for plotting network
from matplotlib.lines import Line2D # for plotting
from matplotlib.patches import Circle # for plotting network
from matplotlib.patches import Patch # for plotting network
from matplotlib.patches import Rectangle # for plotting figure 5
import matplotlib.patches as patches # for plotting figure

import pickle

import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn import preprocessing
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

import cvxpy as cp

from helper_plot import *

## 1.2. load data and compute subpopulation means

In [None]:
# Project-relative input/output locations.
path_project = '../'
path_mat_PC_train = f'{path_project}data/mat_PC/train/'
path_mat_PC_test = f'{path_project}data/mat_PC/test/'
path_label = f'{path_project}data/label/'
path_population = f'{path_project}data/population/'
path_plot = f'{path_project}figure/'

# Sample metadata and a code -> description lookup for the subpopulations.
df_meta = pd.read_csv(path_population + 'meta_merged.csv')
df_desc = pd.read_csv(path_population + 'desc_subpopulation.tsv', delimiter = '\t')
df_desc = df_desc[['Population code', 'Population description']]
df_desc.set_index('Population code', inplace = True)

dict_desc = df_desc['Population description'].to_dict()
n_PC = 600

n_train = 2560
n_test = 641
n_subpop = 26

# Training samples as rows, first n_PC principal components as columns.
mat_GT_train = np.loadtxt(f'{path_mat_PC_train}mat_PC_train.tsv', delimiter = '\t')[:, 0:n_PC]

# Row positions of the train/test samples in the metadata table.
# NOTE(review): the fixed offset of 4 is carried over unchanged; presumably it maps
# the stored indices onto 0-based metadata rows — confirm against the index files.
vec_label_train_index = np.loadtxt(f'{path_label}index_train.csv', delimiter = ',').astype(int) - 4
vec_label_test_index = np.loadtxt(f'{path_label}index_test.csv', delimiter = ',').astype(int) - 4

# Ordered subpopulation codes; column i of mat_mean_GT_subpop belongs to vec_class[i].
vec_class = pd.read_csv(f'{path_population}population.tsv', delimiter = '\t', header = None)
vec_class = np.array(vec_class.values).flatten()
if len(vec_class) != n_subpop:
    # pred_convex solves for one weight per column, so the class list and the
    # mean matrix must agree; fail loudly instead of leaving zero columns behind.
    raise ValueError(f'expected {n_subpop} subpopulations in population.tsv, found {len(vec_class)}')

# Same file as df_meta: reuse it (as an independent copy) instead of re-reading from disk.
df_label = df_meta.copy()

series_sup = df_label['SUP']
vec_label_test = series_sup.loc[vec_label_test_index].tolist()
vec_label_train = series_sup.loc[vec_label_train_index].tolist()

# Mean PC vector of every subpopulation over the first n_train training samples.
vec_label_train_arr = np.asarray(vec_label_train[:n_train])
dict_mean_GT_subpop = {}
mat_mean_GT_subpop = np.zeros([n_PC, n_subpop])
for i, cls in enumerate(vec_class):
    vec_index_subpop = np.flatnonzero(vec_label_train_arr == cls)
    if vec_index_subpop.size == 0:
        raise ValueError(f'no training samples found for subpopulation {cls!r}')
    mat_mean_GT_subpop[:, i] = mat_GT_train[vec_index_subpop, :].mean(axis = 0)
    # Column view, so the dict and the matrix stay in sync.
    dict_mean_GT_subpop[cls] = mat_mean_GT_subpop[:, i]

## 1.3. functions

In [None]:
def _cos_sim(vec_a, vec_b):
    """Return the cosine similarity between two 1-D vectors."""
    return np.dot(vec_a, vec_b) / (np.linalg.norm(vec_a) * np.linalg.norm(vec_b))


def pred_convex(vec_test, label_test, mat_mean_GT_subpop):
    """Classify one test sample as a convex mixture of subpopulation means.

    Solves  min_x ||M @ x - v||  subject to  x >= 0, sum(x) == 1,
    where M is ``mat_mean_GT_subpop`` (one subpopulation mean per column, in the
    order of the global ``vec_class``) and v is ``vec_test``.

    Parameters
    ----------
    vec_test : 1-D array with one entry per row of ``mat_mean_GT_subpop``.
    label_test : true subpopulation code of the sample; must occur in ``vec_class``.
    mat_mean_GT_subpop : 2-D array, one column per subpopulation.

    Returns
    -------
    result_pred : mixture weights, one per subpopulation.
    result_pred_subpop : code of the subpopulation with the largest weight.
    cos_sim : cosine similarity between the fitted mixture ``M @ result_pred`` and ``vec_test``.
    cos_sim_subpop : cosine similarity between the true subpopulation's mean and ``vec_test``.
    bool_correct : whether ``result_pred_subpop`` equals ``label_test``.

    Raises
    ------
    RuntimeError : if the solver returns no solution.
    KeyError : if ``label_test`` is not one of ``vec_class``.
    """
    # One mixture weight per column of the supplied matrix.
    X = cp.Variable(mat_mean_GT_subpop.shape[1])

    # For a vector, the 'fro' norm is the Euclidean norm; minimizing it has the
    # same minimizer as the least-squares objective.
    objective = cp.Minimize(cp.norm(mat_mean_GT_subpop @ X - vec_test, 'fro'))

    # Weights lie on the probability simplex.
    constraints = [X >= 0, cp.sum(X) == 1]

    problem = cp.Problem(objective, constraints)
    problem.solve()
    if X.value is None:
        raise RuntimeError(f'convex solver returned no solution (status: {problem.status})')

    result_pred = X.value
    result_pred_subpop = vec_class[np.argmax(result_pred)]

    # Reconstruct the sample from the fitted mixture of subpopulation means.
    result_pred_PC = mat_mean_GT_subpop @ result_pred

    # Take the true subpopulation's mean from the matrix actually passed in (its
    # columns follow vec_class), so both similarities use the same reference means.
    index_true = np.flatnonzero(np.asarray(vec_class) == label_test)
    if index_true.size == 0:
        raise KeyError(label_test)
    mean_PC_subpop = mat_mean_GT_subpop[:, index_true[0]]

    cos_sim = _cos_sim(result_pred_PC, vec_test)
    cos_sim_subpop = _cos_sim(mean_PC_subpop, vec_test)

    bool_correct = (result_pred_subpop == label_test)

    return result_pred, result_pred_subpop, cos_sim, cos_sim_subpop, bool_correct

# 2. confusion matrix

## 2.1. train-test split: predict every test sample for each scenario (full, 500k, 50k, 5k)

In [None]:
# Run pred_convex on every test sample for each test-matrix scenario.
# Takes about 6.5 minutes.
vec_name_supp_fig1 = ['full', '500k', '50k', '5k']

dict_list_pred = {}
for test_i in vec_name_supp_fig1:
    print(f'starting with iteration {test_i}')

    # The test file is read with PCs as rows: keep the first n_PC rows, then
    # transpose so that samples are rows (matching mat_GT_train).
    mat_GT_test = np.loadtxt(f'{path_mat_PC_test}mat_PC_test_{test_i}.tsv', delimiter = '\t')[0:n_PC, :]
    mat_GT_test = np.transpose(mat_GT_test)

    # Collect plain rows and build the DataFrame once. Writing strings/bools into a
    # preallocated float frame via iloc relies on pandas' deprecated dtype upcasting.
    list_rows = []
    for index_test in range(n_test):
        label_true = vec_label_test[index_test]
        result_pred, result_pred_subpop, cos_sim, cos_sim_subpop, bool_correct = pred_convex(mat_GT_test[index_test, :], label_true, mat_mean_GT_subpop)
        list_rows.append([label_true, result_pred_subpop, bool_correct, cos_sim, cos_sim_subpop])

    # Columns: 0 true label, 1 predicted label, 2 correct?, 3 cos_sim of fitted mixture,
    # 4 cos_sim of the true subpopulation mean.
    df_result = pd.DataFrame(list_rows)

    dict_list_pred[test_i] = df_result[1]

## 2.2. report accuracies per scenario

In [None]:
# Accuracy per scenario, scored directly against the true test labels. This does not
# depend on whichever df_result the prediction loop left behind.
for name_scenario in vec_name_supp_fig1:
    print(f'{name_scenario}: {accuracy_score(vec_label_test, dict_list_pred[name_scenario])}')

## 2.3. plot

In [None]:
name = 'train_test'
n_col = 2
n_plot = 2
size_font_small = 5
# Title wording for each scenario, in the same order as vec_name_supp_fig1.
vec_name_fig5 = ['No', '90%', '99%', '99.9%']

# Subpopulation codes before which a dashed separator line is drawn.
set_label_separator = {'CLM', 'CDX', 'CEU', 'BEB'}

# Convert mm to inches for figsize
width_in_inches = 170 / 25.4
height_in_inches = 220 / 25.4

# 2 x 2 grid: one confusion matrix per scenario.
fig, axs = plt.subplots(ncols = n_col, nrows = 2,
                        figsize=(width_in_inches, height_in_inches), sharex=True, sharey=True)

for i, name_scenario in enumerate(vec_name_supp_fig1):

    # Row-major placement: 0 top-left, 1 top-right, 2 bottom-left, 3 bottom-right.
    ax_i = axs[i // n_col, i % n_col]

    # Predicted labels for the test set in this scenario
    y_pred = dict_list_pred[name_scenario]

    # Accuracy in percent, rounded for the title
    score_acc = round(accuracy_score(vec_label_test, y_pred)*100, 2)

    # Confusion matrix with rows/columns in vec_label_sorted order
    cm = confusion_matrix(vec_label_test, y_pred, labels = list(vec_label_sorted))
    df_cm = pd.DataFrame(cm, columns = vec_label_sorted, index = vec_label_sorted)

    # Annotate non-zero counts only, to keep the heatmap readable
    annot_array = np.array([['' if count == 0 else str(count) for count in row] for row in df_cm.values])

    ax_i = sns.heatmap(df_cm,
                annot = annot_array,
                annot_kws = {"size": size_font_small},
                fmt='',
                cmap='Greys',  # grayscale
                ax = ax_i,
                cbar = True,
                vmin = 0,
                vmax = 45)

    # Set the font size for colorbar labels
    cbar = ax_i.collections[0].colorbar
    cbar.ax.tick_params(labelsize=size_font_small)

    ax_i.set_aspect('equal', 'box')
    ax_i.set_title(f"Test Data with {vec_name_fig5[i]} Missing (accuracy: {score_acc}%)",
                   fontsize = size_font_small)

    # Tick at the centre of each cell (cells span [k, k+1])
    ax_i.set_xticks([x - 0.5 for x in range(1, len(vec_label_sorted) + 1)])
    ax_i.set_xticklabels(vec_label_sorted, fontsize=size_font_small, ha = 'center')

    ax_i.set_yticks([y - 0.5 for y in range(1, len(vec_label_sorted) + 1)])
    ax_i.set_yticklabels(vec_label_sorted, fontsize=size_font_small)

    ax_i.tick_params(length = 0) # remove tick marks

    # Black frame around the plot area
    for _, spine in ax_i.spines.items():
        spine.set_visible(True)
        spine.set_edgecolor('black')
        spine.set_linewidth(1)

    # Dashed grey lines on the left/top cell edge of each separator subpopulation
    line_positions = [pos for pos, label in enumerate(ax_i.get_xticklabels())
                      if label.get_text() in set_label_separator]
    for pos in line_positions:
        ax_i.axvline(pos, color='grey', linestyle='--', linewidth=0.5)
        ax_i.axhline(pos, color='grey', linestyle='--', linewidth=0.5)

    # Shade each tick label with the colour looked up for its subpopulation code
    for label in (ax_i.get_xticklabels() + ax_i.get_yticklabels()):
        pop_temp = label.get_text()
        color_temp = dict_color_super[df_meta_unique[df_meta_unique['SUP'] == pop_temp]['POP'].values[0]]
        label.set_bbox(dict(facecolor = color_temp,
                            edgecolor='None',
                            alpha = 0.5, pad = 1))

    # Diagonal rectangles outlining each group in vec_sup
    index_rect = 0
    for sup_i in vec_sup:
        ax_i.add_patch(Rectangle((index_rect, index_rect),
                                 dict_sup_total[sup_i], dict_sup_total[sup_i],
                                 fill = False, edgecolor = dict_color_super[sup_i], lw = 2))
        index_rect += dict_sup_total[sup_i]

# Custom legend handles (one line per entry of dict_color_super); built once, not per panel
legend_elements = [Line2D([0], [0], linestyle='-', color=color, label=continent)
                   for continent, color in dict_color_super.items()]

plt.tight_layout()

fig.text(0.5, -0.01, 'Group Classification (Model)', ha = 'center', va = 'center', fontsize = size_font_small)
fig.text(-0.01, 0.5, 'Group Classification (Truth)', ha = 'center', va = 'center', rotation = 'vertical', fontsize = size_font_small)
plt.subplots_adjust(top = 0.82)

path_fig = fn_ensure_slash(path_plot) + 'sfigure1/'
os.makedirs(path_fig, exist_ok = True)  # savefig does not create missing directories
# bbox_inches='tight' keeps the two axis texts above, which sit just outside the canvas,
# in the saved PDF instead of clipping them.
fig.savefig(path_fig + 'figure_supp_confusion_matrix_four.pdf', format = 'pdf', dpi = 1200, bbox_inches = 'tight')
plt.show()