In [None]:
import argparse
import numpy as np
import torch
# import wandb
from SyntheticTreeDataset import *
from OdorDataset import OdorMonoDataset
from utils.helpers import *
from methods import *
from optimizers import *
from torch.utils.data import DataLoader
from utils.visualization import *
import uuid
from utils.helpers import set_seeds
from sklearn.neighbors import kneighbors_graph
import scipy
from torch.optim import Adam

from distances import (
    distance_matrix,
    euclidean_distance,
    poincare_distance,
    knn_geodesic_distance_matrix,
    knn_graph_weighted_adjacency_matrix,
    # hamming_distance_matrix
)

from sklearn.manifold import TSNE

from sklearn.decomposition import PCA


# --- Jupyter compatibility ---------------------------------------------------
# When the script is launched through an IPython kernel, argv[0] contains
# 'ipykernel_launcher' and the remaining entries are the kernel's own launch
# arguments (presumably e.g. a connection-file flag). argparse would reject
# them, so only the program name is kept.
import sys
_running_in_notebook = 'ipykernel_launcher' in sys.argv[0]
if _running_in_notebook:
    sys.argv = sys.argv[:1]
# -----------------------------------------------------------------------------




if __name__ == "__main__":

    def _str2bool(value):
        """Parse a boolean command-line value.

        ``type=bool`` must not be used with argparse: ``bool('False')`` is
        ``True``, so ``--normalize False`` would silently *enable* the option.
        This converter accepts the usual spellings and rejects anything else.

        Args:
            value: Raw string from the command line. An actual ``bool`` is
                returned unchanged.

        Returns:
            ``True`` for true/t/yes/y/1 and ``False`` for false/f/no/n/0
            (case-insensitive, surrounding whitespace ignored).

        Raises:
            argparse.ArgumentTypeError: If ``value`` is not a recognised
                boolean spelling (argparse reports it as a usage error).
        """
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        if text in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    # Choices are lists (not sets) so --help and "invalid choice" messages
    # list them in a stable order.
    parser = argparse.ArgumentParser('Hyperbolic Smell')
    parser.add_argument('--data_type', type=str, default='labels', choices=["representation", "labels"])  # representation or labels
    parser.add_argument('--representation_name', type=str, default='pom', choices=["molformer", "pom"])
    parser.add_argument('--batch_size', type=int, default=195)  # 195
    parser.add_argument('--num_epochs', type=int, default=5001)  # 100
    # parser.add_argument('--min_dist', type=float, default=1.)
    parser.add_argument('--latent_dim', type=int, default=2)
    parser.add_argument('--lr', type=float, default=0.1)
    # parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--base_dir', type=str, default='./data/')
    # NOTE(review): the 'tree' and 'random' datasets handled further down are
    # not valid choices here, so those branches cannot be reached from the CLI.
    parser.add_argument('--dataset_name', type=str, default='gslf',
                        choices=["gslf", "ravia", "keller", "sagar", "sagarfmri"])  # tree for synthetic, gslf for real
    parser.add_argument('--normalize', type=_str2bool, default=False)  #* # only for Hyperbolic embeddings
    parser.add_argument('--optimizer', type=str, default='standard', choices=['standard', 'poincare', 'Adam'])  #*
    parser.add_argument('--model_name', type=str, default='mds', choices=['isomap', 'mds', 'contrastive'])
    parser.add_argument('--latent_dist_fun', type=str, default='euclidean', choices=['euclidean', 'poincare'])  #*
    parser.add_argument('--distr', type=str, default='gaussian', choices=['gaussian', 'hypergaussian'])  #*
    parser.add_argument('--distance_method', type=str, default='euclidean',
                        choices=['geo', 'graph', 'hamming', 'euclidean', 'similarity'])  # 'euclidean' for sagar/keller, 'similarity' for ravia
    parser.add_argument('--n_samples', type=int, default=4000)
    parser.add_argument('--dim', type=int, default=768)
    parser.add_argument('--depth', type=int, default=5)  # int depth of the synthetic tree
    parser.add_argument('--temperature', type=float, default=0.1)  # 0.1 #100
    parser.add_argument('--n_neighbors', type=int, default=20)  # 20 #10
    parser.add_argument('--epsilon', type=float, default=10.0)
    parser.add_argument('--roi', type=str, default=None, choices=["OFC", "PirF", "PirT", "AMY", None])
    # NOTE(review): parsed as float, so `--subject 1` yields 1.0; it still
    # passes the int choices check because 1.0 == 1.
    parser.add_argument('--subject', type=float, default=None, choices=[1, 2, 3, None])
    parser.add_argument('--filter_dragon', type=_str2bool, default=False)  # for chemical data
    # args = argparse.Namespace()
    args = parser.parse_args()

    # Choose the compute device once and record it on ``args`` for later use.
    if not torch.cuda.is_available():
        args.device = torch.device('cpu')
        # NOTE(review): gpu_index is only set on this CPU path; the CUDA path
        # below never defines it.
        args.gpu_index = -1
        print('Using CPU')
    else:
        args.device = torch.device('cuda')
        # Request deterministic cuDNN kernels and turn off autotuning so GPU
        # runs are repeatable.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        print('Using GPU')

    # Tag this run with a random hex identifier.
    args.random_string = uuid.uuid4().hex

    # Mirror the parsed options as plain names for the code below; each name
    # matches the argparse option of the same name.
    data_type, dataset_name, representation_name, base_dir = (
        args.data_type,
        args.dataset_name,
        args.representation_name,
        args.base_dir,
    )
    num_epochs, batch_size, lr, seed, optimizer = (
        args.num_epochs,
        args.batch_size,
        args.lr,
        args.seed,
        args.optimizer,
    )
    model_name, latent_dim, latent_dist_fun, distr, normalize = (
        args.model_name,
        args.latent_dim,
        args.latent_dist_fun,
        args.distr,
        args.normalize,
    )
    distance_method, n_neighbors, epsilon, temperature = (
        args.distance_method,
        args.n_neighbors,
        args.epsilon,
        args.temperature,
    )
    subject, roi, depth, filter_dragon = (
        args.subject,
        args.roi,
        args.depth,
        args.filter_dragon,
    )
    # Disabled override: for the synthetic tree a full batch is 2 ** depth - 1.
    # If re-enabled, ``batch_size`` above must be rebound as well.
    # args.batch_size = 2 ** args.depth - 1  # to get full batch
    # NOTE(review): ``seed`` is read but never applied because the call below
    # is commented out, so runs are not seeded.
    # set_seeds(seed)

    # The 'similarity' distance method is only supported for the Ravia dataset.
    if distance_method == 'similarity' and dataset_name != 'ravia':
        raise ValueError('Similarity distance method can only be used with Ravia dataset')

    if dataset_name == 'tree':
        # Synthetic binary-tree data (get_tree_data presumably comes from the
        # SyntheticTreeDataset star import); convert both outputs to tensors.
        embeddings, labels = get_tree_data(depth)
        labels, embeddings = torch.tensor(labels), torch.tensor(embeddings)
        # Author's notes on the tree data: nodes are binary sequences with node
        # 0 as the root. A node's ground-truth distance to the root is the
        # Hamming distance between its sequence and node 0's, which is useful
        # for colouring nodes in visualizations.
    elif dataset_name == 'random':
        # todo: do we need this?
        # NOTE(review): 'random' is not among the --dataset_name choices, so
        # this branch is unreachable from the CLI, and it sets no ``labels``.
        # Sizes come straight from ``args``: ``n_samples``/``dim`` are not
        # unpacked into plain names above, so the bare names raised NameError.
        embeddings = torch.randn(args.n_samples, args.dim)
    elif dataset_name in ['gslf', 'keller' , 'sagar']: ### If multiple subjects, to average among them put grand_avg=True. If individual subjects then put grand_avg=False and below use select_subjects function
        input_embeddings = f'embeddings/{representation_name}/{dataset_name}_{representation_name}_embeddings_13_Apr17.csv'
        # grand_avg=True averages over subjects. Use it for Keller, and for Sagar
        # when no single --subject was requested; otherwise pass False so the
        # individual subjects are kept. (The dataset name was previously
        # misspelled 'saagar', so the Sagar case never averaged.)
        embeddings, labels, subjects, CIDs = read_embeddings(
            base_dir, select_descriptors(dataset_name), input_embeddings,
            grand_avg=(dataset_name == 'keller'
                       or (dataset_name == 'sagar' and subject is None)))
        # embeddings, labels,subjects,CIDs = read_embeddings(base_dir, select_descriptors(dataset_name), input_embeddings,
        #                                      grand_avg=True if dataset_name == 'keller' or dataset_name=='sagar' else False) #grand_avg averages among subjects so put false for analyzing each subject individually
        # embeddings, labels,subjects,CIDs = read_embeddings(base_dir, select_descriptors(dataset_name), input_embeddings,
        #                                      grand_avg=True if dataset_name=='sagar' else False) 
 
        if filter_dragon:
            # Swap in the chemical features returned by read_dragon_features.
            # It presumably also filters the molecules down to those with
            # DRAGON features, since it returns labels/subjects/CIDs again —
            # verify.
            (embeddings, labels, subjects, CIDs,
             embeddings_chemical) = read_dragon_features(embeddings, labels, subjects, CIDs)
            embeddings = embeddings_chemical
            # NOTE(review): only args.representation_name is updated; the local
            # ``representation_name`` bound earlier keeps its original value.
            args.representation_name = 'chemical'
        #embeddings = 100000 * torch.randn(4983, 20)
        
        # To run PCA or t-SNE on the MolFormer or POM embeddings:
        # X_embedded = TSNE(n_components=2, learning_rate='auto',
        #          init='random', perplexity=300).fit_transform(embeddings)
        #X_embedded = PCA(n_components=2).fit_transform(embeddings)

        #Embed labels:
        # X_embedded = TSNE(n_components=2, learning_rate='auto',
        #          init='random', perplexity=1000).fit_transform(labels)
        #X_embedded = PCA(n_components=2).fit_transform(labels)


        # X_embedded = PCA(n_components=20).fit_transform(embeddings)
        # embeddings = torch.tensor(X_embedded, dtype=torch.float32)  # Convert to a PyTorch tensor

        # print('embeddings after PCA', embeddings.shape)