In [7]:
import torch
from torch.utils import data
# from ..src.data import StarmenDataset
from torch.autograd import Variable
import matplotlib.colors as mcolors
import matplotlib
import matplotlib.pyplot as plt
import random
import numpy as np
import logging
import time
import sys
import os
import pandas as pd
from torch.utils.data import Dataset


In [None]:
class StarmenDataset(Dataset):
    def __init__(self, csv_path, nb_subject = None, test_split = 0.1, val_split = 0.1):
        self.p = 0.2
        self.csv_path = csv_path
        self.test_split = test_split
        self.val_split = val_split
        if nb_subject:
            self.nb_subject = nb_subject
        else: 
            self.nb_subject = 1_000
        self.get_infos()


    def __len__(self):
        return len(self.ids)
    

    def get_infos(self):
        self.datas = pd.read_csv(self.csv_path)
        self.datas = self.datas.rename(columns={'t': 'age', 'tau': 'baseline_age', 'path': 'img_path'})

        # Change relative path
        relative_path = self.csv_path.split("df.csv")[0]
        self.datas['img_path'] = self.datas['img_path'].apply(lambda x: os.path.join(relative_path, "images", x.split("/images/")[-1]))

        # Get subject id
        def extract_subject_id(str):
            return str.split("__")[3].split("subject_s")[-1]
        self.datas["id"] = self.datas["id"].apply(extract_subject_id)
        
        self.sanity_check_infos()
        
        # Get list of subject_ids
        self.ids = self.datas["id"].unique()
        
    
    def sanity_check_infos(self):
        """
        Check completeness of the data infos. 
        """
        self.sanity_check = dict()
        check_null = self.datas.isnull().sum()

        check_duplicates = self.datas.duplicated().sum()
        if check_duplicates: 
            self.sanity_check['duplicates'] = self.datas[self.datas.duplicated()]  
        else: self.sanity_check['duplicates'] = None

        # Check img path exists

        self.datas['path_exists'] = self.datas['img_path'].apply(os.path.exists)
        missing_paths = self.datas[~self.datas['path_exists']]
        if missing_paths.empty:
            self.sanity_check["missing_paths"] = None
        else: 
            self.sanity_check['missing_paths'] = missing_paths['img_path'].tolist()
            self.datas = self.datas[self.datas['path_exists']].reset_index(drop=True)
        self.datas.drop(columns=['path_exists'], inplace=True)


        # self.datas = self.datas.loc[:, self.datas.columns != 'id']
        # tau_list = self.datas.iloc[:, 0]
        # alpha_list = self.datas.iloc[:, 1]
        # age_list = self.datas.iloc[:, 2]

        # npy_path_list = self.datas.iloc[:, 3]
        # # change relative path


        # first_age_list = pd.DataFrame(data=[age_list[int(i / 10) * 10] for i in range(10 * subject)], columns=['first_age'])
        # subject_list = pd.DataFrame(data=[int(i / 10) for i in range(10 * subject)], columns=['subject'])
        # timepoint_list = pd.DataFrame(data=[i % 10 for i in range(10 * subject)], columns=['timepoint'])

        # self.datas = pd.concat(
        #     [npy_path_list, subject_list, tau_list, age_list, timepoint_list, first_age_list, alpha_list], axis=1)
        # self.datas = self.datas.rename(columns={'t': 'age', 'tau': 'baseline_age'})



    

    def __getitem__(self, index):
        """Generates one sample of data"""
        # Select sample
        x = self.image_path[index]
        y = self.subject[index]
        z = self.baseline_age[index]
        u = self.age[index]
        v = self.timepoint[index]
        w = self.first_age[index]
        a = self.alpha[index]
        return x, y, z, u, v, w, a



In [51]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Operate on device: ", device)

starmen_info_path = "../data/starmen/output_random_noacc/df.csv"
starmen_ds = StarmenDataset(starmen_info_path)

starmen_ds.datas


Operate on device:  cpu


Unnamed: 0,baseline_age,alpha,age,img_path,id
0,-1.635638,1.000129,-4.29,../data/starmen/output_random_noacc/images/Sim...,0
1,-1.635638,1.000129,-3.24,../data/starmen/output_random_noacc/images/Sim...,0
2,-1.635638,1.000129,-2.20,../data/starmen/output_random_noacc/images/Sim...,0
3,-1.635638,1.000129,-1.15,../data/starmen/output_random_noacc/images/Sim...,0
4,-1.635638,1.000129,-0.10,../data/starmen/output_random_noacc/images/Sim...,0
...,...,...,...,...,...
9995,1.541414,0.999906,-1.64,../data/starmen/output_random_noacc/images/Sim...,999
9996,1.541414,0.999906,-0.68,../data/starmen/output_random_noacc/images/Sim...,999
9997,1.541414,0.999906,0.27,../data/starmen/output_random_noacc/images/Sim...,999
9998,1.541414,0.999906,1.22,../data/starmen/output_random_noacc/images/Sim...,999


In [None]:
Dataset = Dataset_starmen
test = Dataset(test_data['path'], test_data['subject'], test_data['baseline_age'], test_data['age'],
                test_data['timepoint'], test_data['first_age'], test_data['alpha'])

test_loader = torch.utils.data.DataLoader(test, batch_size=10, shuffle=False,
                                            num_workers=0, drop_last=False, pin_memory=True)

In [None]:
# calculate psi
psi = test_data['alpha'] * (test_data['age'] - test_data['baseline_age'])
psi_array = np.linspace(min(psi), max(psi), num=5)
index = [np.nonzero(np.abs(np.array(psi) - p) < 0.05)[0][:2] for p in psi_array]
index = [j for i in index for j in i]


# individual trajectory
subject = [i // 10 for i in index]
subject_img = []
for s in subject:
    subject_img += list(np.arange(s * 10, (s + 1) * 10))
path = test_data.iloc[subject_img, 0]
image = torch.tensor([[np.load(p)] for p in path], device=device).float()
# # indiv_tra, _, _, _ = autoencoder.forward(image)

In [None]:
index_test = [1]

fig, axes = plt.subplots(2 * len(index), 10, figsize=(20, 2 * 2 * len(index)))
plt.subplots_adjust(wspace=0, hspace=0)
for i in range(len(index)):
    for j in range(10):
        axes[2 * i][j].matshow(image[10 * i + j][0].cpu().detach().numpy())
        # axes[2 * i + 1][j].matshow(255 * indiv_tra[10 * i + j][0].cpu().detach().numpy())
for axe in axes:
    for ax in axe:
        ax.set_xticks([])
        ax.set_yticks([])
for axe in axes:
    for ax in axe:
        ax.set_xticks([])
        ax.set_yticks([])