In [1]:
import os

from utils.path_utils import project_root
import torch

from models.adatime.da.models import get_backbone_class
from models.adatime.da.algorithms import get_algorithm_class

from models.adatime.configs.get_configs import Config

import pandas as pd


In [2]:
# setA = pd.read_pickle(os.path.join(project_root(), 'data', 'tl_datasets', 'final_dataset_pretrain_A.pickle'))
# setA_sepsis = pd.read_csv(os.path.join(project_root(), 'data', 'tl_datasets', 'is_sepsis_pretrain_A.txt'),
#                           header=None)
# print(f"Found {len(setA)}")


In [3]:
# type(setA_sepsis)

In [4]:
# with open(os.path.join(project_root(), 'data', 'tl_datasets', 'is_sepsis_pretrain_A.txt')) as f:
#     is_sepsis = [int(is_sep) for is_sep in f.read().splitlines()]
# 
# print(f"Found {len(is_sepsis)}")

In [5]:
# import numpy as np
# 
# sepsis = pd.Series(is_sepsis)
# 
# positive_sepsis_idxs = sepsis[sepsis == 1].index
# negative_sepsis_idxs = sepsis[sepsis == 0].sample(frac=0.50).index
# all_samples = list(positive_sepsis_idxs) + list(negative_sepsis_idxs)
# np.random.shuffle(all_samples)


In [6]:

import numpy as np
import pandas as pd
import os

import tqdm

from utils.path_utils import project_root

def get_subset_of_setA(negative_frac=0.50, seed=2024):
    """Load set A and return a rebalanced, shuffled subset for pre-training.

    Every patient labelled 1 in ``is_sepsis_pretrain_A.txt`` is kept, and a
    random fraction of the patients labelled 0 is sampled. The selected indices
    are then shuffled with a generator seeded from ``seed``, so both the chosen
    subset and its order are reproducible across runs.

    Args:
        negative_frac: fraction of negative (label 0) patients to keep.
        seed: seed used for the negative sampling and for the shuffle.

    Returns:
        ``(subsetA, subsetA_sepsis)``: a list of per-patient DataFrames with the
        ``PatientID`` and ``SepsisLabel`` columns dropped, and the list of
        matching sepsis labels in the same order.

    Raises:
        ValueError: if the label file and the pickled dataset disagree on the
            number of patients (labels would otherwise be silently misaligned).
    """
    data_dir = os.path.join(project_root(), 'data', 'tl_datasets')

    with open(os.path.join(data_dir, 'is_sepsis_pretrain_A.txt')) as f:
        is_sepsis = [int(is_sep) for is_sep in f.read().splitlines()]

    print(f"Found {len(is_sepsis)}")

    # Balance the classes: keep all positives, subsample the negatives.
    sepsis = pd.Series(is_sepsis)
    positive_sepsis_idxs = sepsis[sepsis == 1].index
    negative_sepsis_idxs = sepsis[sepsis == 0].sample(frac=negative_frac, random_state=seed).index
    candidate_idxs = list(positive_sepsis_idxs) + list(negative_sepsis_idxs)
    # Seed the shuffle too; an unseeded np.random.shuffle changed the order on every run.
    order = np.random.default_rng(seed).permutation(len(candidate_idxs))
    all_samples = [candidate_idxs[i] for i in order]

    # NOTE: read_pickle can execute arbitrary code; only load trusted files.
    setA = pd.read_pickle(os.path.join(data_dir, 'final_dataset_pretrain_A.pickle'))
    if len(setA) != len(is_sepsis):
        raise ValueError(
            f"Label file lists {len(is_sepsis)} patients but the dataset holds {len(setA)}; "
            "labels would be misaligned."
        )

    subsetA = []
    for idx in tqdm.tqdm(all_samples, desc='Subset', total=len(all_samples)):
        # Drop the identifier and the per-row label so only features remain.
        subsetA.append(setA[idx].drop(['PatientID', 'SepsisLabel'], axis=1))
    subsetA_sepsis = [sepsis[idx] for idx in all_samples]

    print(f"Total number of samples for pre-training: {len(subsetA)}")

    return subsetA, subsetA_sepsis

# Build the rebalanced pre-training subset: feature frames plus their aligned sepsis labels.
subsetA, subsetA_sepsis = get_subset_of_setA()


In [7]:

import tqdm


def csv_to_pt(patient_files, is_sepsis, desc, max_time_step=336):
    """Zero-pad per-patient feature tables to a fixed length and pack them into tensors.

    Args:
        patient_files: sequence of 2-D array-likes (e.g. DataFrames), one per
            patient, laid out as (time_steps, n_features). All must share the
            same number of feature columns.
        is_sepsis: sequence of per-patient labels, aligned with ``patient_files``.
        desc: label shown on the progress bar.
        max_time_step: length every sample is zero-padded to along the time axis.

    Returns:
        ``({'samples': samples, 'labels': labels}, is_sepsis)`` where ``samples``
        is a float32 tensor of shape (N, max_time_step, n_features), ``labels``
        is a float32 tensor with one entry per patient, and ``is_sepsis`` is
        returned unchanged.

    Raises:
        ValueError: if the inputs are empty, their lengths differ, or a patient
            has more than ``max_time_step`` rows.
    """
    # zip() would silently drop the tail of the longer sequence.
    if len(patient_files) != len(is_sepsis):
        raise ValueError(
            f"Got {len(patient_files)} patient files but {len(is_sepsis)} labels; "
            "they must be aligned one-to-one."
        )
    if len(patient_files) == 0:
        raise ValueError("csv_to_pt needs at least one patient file.")

    padded_samples = []
    for idx, file in tqdm.tqdm(enumerate(patient_files),
                               desc=f"{desc}",
                               total=len(patient_files)):
        n_steps = len(file)
        if n_steps > max_time_step:
            # A negative pad width would otherwise surface as an opaque np.pad error.
            raise ValueError(
                f"Patient {idx} has {n_steps} time steps, more than max_time_step={max_time_step}."
            )
        # Pad only the end of the time axis; the feature axis is left untouched.
        pad_width = ((0, max_time_step - n_steps), (0, 0))
        padded_samples.append(np.pad(file, pad_width=pad_width, mode='constant').astype(np.float32))

    samples = torch.from_numpy(np.stack(padded_samples, axis=0))
    labels = torch.stack([torch.as_tensor(sepsis, dtype=torch.float32) for sepsis in is_sepsis])

    return {'samples': samples, 'labels': labels}, is_sepsis

# Pad each patient to a fixed number of time steps and pack features/labels into tensors.
# Note: this rebinds `is_sepsis` to the subset's labels (csv_to_pt returns its input unchanged).
all_patients, is_sepsis = csv_to_pt(subsetA, subsetA_sepsis, desc='SubsetA')


In [8]:
# torch.save does not create missing parent directories, so make sure the output folder exists.
pretrain_dir = os.path.join(project_root(), 'data', 'tl_datasets', 'pretrain')
os.makedirs(pretrain_dir, exist_ok=True)
torch.save(all_patients, os.path.join(pretrain_dir, 'pretrain_subset.pt'))

In [9]:
# Empty cell (a lone "-" here was a SyntaxError that stopped Run All).

In [None]:
da_name = 'DIRT'
model_path = os.path.join(project_root(), 'results', 'adatime', f'{da_name}',
                          'pretrain_finetune', f'{da_name}_{da_name}_gtn', '0_to_1_run_0',
                          'checkpoint.pt')
# map_location='cpu' lets a checkpoint saved with CUDA tensors load on a CPU-only machine;
# the weights are only copied into a model via load_state_dict afterwards.
pretrained_dict = torch.load(model_path, map_location='cpu')
pretrained_model = pretrained_dict['best']


In [None]:
# Display the 'best' entry of the DIRT checkpoint (a state dict consumed by load_state_dict).
pretrained_model

In [None]:
from models.adatime.da.models import (
    GTN,
    classifier,
    codats_classifier,
    get_backbone_class,
)
import torch.nn as nn
from models.adatime.configs.get_configs import Config

# Recreate the fine-tuned architecture: a GTN feature extractor followed by the
# classifier head, both built from the default Config.
config = Config()
original_feature_extactor = GTN(config)
original_classifier = classifier(config)
model = nn.Sequential(original_feature_extactor, original_classifier)

# Restore the DIRT weights loaded from the checkpoint into `pretrained_model`.
model.load_state_dict(pretrained_model)


In [None]:
# Inspect the weights now held by `model` after loading the DIRT checkpoint.
model.state_dict()


In [None]:
da_name = 'Deep_Coral'
model_path = os.path.join(project_root(), 'results', 'adatime', f'{da_name}',
                          'pretrain_finetune', f'{da_name}_{da_name}_gtn', '0_to_1_run_0',
                          'checkpoint.pt')
# map_location='cpu' lets a checkpoint saved with CUDA tensors load on a CPU-only machine;
# the 'best' weights are only copied into a model via load_state_dict afterwards.
checkpoint = torch.load(model_path, map_location='cpu')


In [None]:

def initialize_algorithm(da_method, backbone, configs, device='cuda'):
    """Construct a domain-adaptation algorithm and move it to ``device``.

    Args:
        da_method: name resolved through ``get_algorithm_class``.
        backbone: name resolved through ``get_backbone_class``.
        configs: configuration object forwarded to the algorithm constructor.
        device: device passed to the constructor and to ``.to()`` (default ``'cuda'``).

    Returns:
        The constructed algorithm instance. The return value of ``.to()`` is
        ignored; the instance itself is returned.
    """
    algo_cls = get_algorithm_class(da_method)
    backbone_cls = get_backbone_class(backbone)

    instance = algo_cls(backbone_cls, configs, device)
    instance.to(device)
    return instance


In [None]:
# original_model = initialize_algorithm('Deep_Coral', 'GTN', Config())

In [None]:
from models.adatime.da.models import GTN
from models.adatime.da.models import classifier

import torch.nn as nn

# Rebuild the Deep_Coral architecture: GTN feature extractor followed by the classifier head.
# The head instance gets its own name so the imported `classifier` class is not shadowed
# (previously `classifier = classifier(...)` turned the class name into an instance).
feature_extactor = GTN(Config())
classifier_head = classifier(Config())

original_model = nn.Sequential(feature_extactor, classifier_head)


In [None]:
# Copy the Deep_Coral checkpoint's 'best' weights into the rebuilt model; the call's result is the cell output.
original_model.load_state_dict(checkpoint['best'])

In [None]:
# Display the rebuilt nn.Sequential (feature extractor + classifier head).
original_model