In [1]:
import os
import json
import argparse
import datetime
import pickle

import numpy as np
import torch
from torch.optim import Adam, SGD, RAdam
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from tensorboardX import SummaryWriter
from torchvision import transforms
from sklearn.model_selection import train_test_split
from warmup_scheduler import GradualWarmupScheduler

# from model import *
from betaVAE import *
from read_data import *
from utils import *

In [15]:
# Simulating argparse functionality in a Jupyter Notebook
class Args:
    def __init__(self):
        self.config = '../configs/betavae_tissues.json'  # Example: replace with your default config file
        self.checkpoint = "../checkpoitns"       # Example: replace with your default checkpoint if any
        self.seed = 99
        self.log = 0
        self.parallel = None

# Instantiate the simulated args
args = Args()

# Access the arguments like this:
print(f"Config file: {args.config}")
print(f"Checkpoint: {args.checkpoint}")
print(f"Seed: {args.seed}")
print(f"Log: {args.log}")
print(f"Parallel: {args.parallel}")


Config file: ../configs/betavae_tissues.json
Checkpoint: ../checkpoitns
Seed: 99
Log: 0
Parallel: None


In [16]:
with open(args.config) as f:
    config = json.load(f)

In [17]:
config

{'path_csv': ['../../RNA/data_for_beta_VAE/GTex_Lung_data_SSL_proteincoding.csv',
  '../../RNA/data_for_beta_VAE/GTex_BrainCortex_proteincoding.csv',
  '../../RNA/data_for_beta_VAE/GTex_Liver_proteincoding.csv',
  '../../RNA/data_for_beta_VAE/GTex_Stomach_proteincoding.csv',
  '../../RNA/data_for_beta_VAE/GTex_Pancreas_proteincoding.csv'],
 'patch_data_path': '../../Histology/Lung_Patches256x256/',
 'img_size': 256,
 'max_patch_per_wsi': 100,
 'rna_features': 19198,
 'weights_decay': 0,
 'lr': 5e-05,
 'num_epochs': 500,
 'n_workers': 4,
 'device': 0,
 'flag': 'betavae_proteincoding_tissues',
 'save_dir': '../checkpoints/betavae_training_tissues/',
 'summary_path': '../summaries_betavae_tissues/',
 'log_interval': 20,
 'bag_size': 40,
 'batch_size': 128,
 'beta': 0.0005,
 'quick': 0,
 'optimizer': 'Adam'}

In [18]:
print(10*'-')
print('Config for this experiment \n')
print(config)
print(10*'-')

if 'flag' in config:
    args.flag = config['flag']
else:
    args.flag = 'train_{date:%Y-%m-%d %H:%M:%S}'.format(date=datetime.datetime.now())

if not os.path.exists(config['save_dir']):
    os.mkdir(config['save_dir'])

----------
Config for this experiment 

{'path_csv': ['../../RNA/data_for_beta_VAE/GTex_Lung_data_SSL_proteincoding.csv', '../../RNA/data_for_beta_VAE/GTex_BrainCortex_proteincoding.csv', '../../RNA/data_for_beta_VAE/GTex_Liver_proteincoding.csv', '../../RNA/data_for_beta_VAE/GTex_Stomach_proteincoding.csv', '../../RNA/data_for_beta_VAE/GTex_Pancreas_proteincoding.csv'], 'patch_data_path': '../../Histology/Lung_Patches256x256/', 'img_size': 256, 'max_patch_per_wsi': 100, 'rna_features': 19198, 'weights_decay': 0, 'lr': 5e-05, 'num_epochs': 500, 'n_workers': 4, 'device': 0, 'flag': 'betavae_proteincoding_tissues', 'save_dir': '../checkpoints/betavae_training_tissues/', 'summary_path': '../summaries_betavae_tissues/', 'log_interval': 20, 'bag_size': 40, 'batch_size': 128, 'beta': 0.0005, 'quick': 0, 'optimizer': 'Adam'}
----------


In [76]:
path_csv = config['path_csv']
rna_features = config['rna_features']
batch_size = config.get('batch_size', 64)
encoder_checkpoint = config.get('encoder_checkpoint', None)
beta = config.get('beta', 2)
quick = config.get('quick', 0)
opt = config.get('optimizer', 'Adam')

print('Loading dataset...')

datasets = {
    'train': [],
    'test': [],
    'val': []
}

test_labels = []

Loading dataset...


In [122]:
def normalize_dfs(train_df, val_df, test_df, labels=False, norm_type='standard'):
    def _get_log(x):
        # trick to take into account zeros
        x = np.log(x.replace(0, np.nan))
        return x.replace(np.nan, 0)
    # get list of columns to scale
    rna_columns = [x for x in train_df.columns if 'rna_' in x]
    
    
    # log transform
    train_df[rna_columns] = train_df[rna_columns].apply(_get_log)
    val_df[rna_columns] = val_df[rna_columns].apply(_get_log)
    test_df[rna_columns] = test_df[rna_columns].apply(_get_log)
    
    
    train_df = train_df[rna_columns+['wsi_file_name']]
    val_df = val_df[rna_columns+['wsi_file_name']]
    test_df = test_df[rna_columns+['wsi_file_name']]
    
    rna_values = train_df[rna_columns].values

    if norm_type == 'standard':
        scaler = StandardScaler()
    elif norm_type == 'minmax':
        scaler = MinMaxScaler(feature_range=(0,1))
    rna_values = scaler.fit_transform(rna_values)

    train_df[rna_columns] = rna_values
    test_df[rna_columns] = scaler.transform(test_df[rna_columns].values)
    val_df[rna_columns] = scaler.transform(val_df[rna_columns].values)

    return train_df, val_df, test_df, scaler

In [123]:
for id, dataset in enumerate(path_csv):
    print(dataset)
    df = pd.read_csv(dataset)

    df_transposed = df.set_index('gene_id').transpose().reset_index()
    df_transposed.rename(columns={'index': 'wsi_file_name'}, inplace=True)

    train_df, test_df = train_test_split(df_transposed, test_size=0.2)

    train_df, val_df = train_test_split(train_df, test_size=0.2)

    train_df, val_df, test_df, scaler = normalize_dfs(train_df, val_df, test_df, norm_type='minmax')

    datasets['train'].append(train_df)
    datasets['test'].append(test_df)
    datasets['val'].append(val_df)
    
    test_labels = test_labels + ([id] * test_df.shape[0])

../../RNA/data_for_beta_VAE/GTex_Lung_data_SSL_proteincoding.csv
../../RNA/data_for_beta_VAE/GTex_BrainCortex_proteincoding.csv
../../RNA/data_for_beta_VAE/GTex_Liver_proteincoding.csv
../../RNA/data_for_beta_VAE/GTex_Stomach_proteincoding.csv
../../RNA/data_for_beta_VAE/GTex_Pancreas_proteincoding.csv


In [124]:
if(len(datasets['train']) >=2):
    train_df = pd.concat([datasets['train'][0], datasets['train'][1]])
    val_df = pd.concat([datasets['val'][0], datasets['val'][1]])
    test_df = pd.concat([datasets['test'][0], datasets['test'][1]])
    for i in range(2, len(datasets['train'])):
        train_df = pd.concat([train_df, datasets['train'][i]])
        val_df = pd.concat([val_df, datasets['val'][i]])
        test_df = pd.concat([test_df, datasets['test'][i]])
else:
    train_df = datasets['train'][0]
    val_df = datasets['val'][0]
    test_df = datasets['test'][0]

print('Train shape {}'.format(train_df.shape))
print('Val shape {}'.format(val_df.shape))
print('Test shape {}'.format(test_df.shape))
train_df, val_df, test_df, scaler = normalize_dfs(train_df, val_df, test_df, norm_type='standard')

train_dataset = RNADataset([train_df], quick=quick)
val_dataset = RNADataset([val_df])
test_dataset = RNADataset([test_df])

train_dataloader = DataLoader(train_dataset,batch_size=batch_size, 
               num_workers=4, shuffle=True)
val_dataloader = DataLoader(val_dataset,batch_size=batch_size, 
               num_workers=4, 
               shuffle=False)
test_dataloader = DataLoader(test_dataset,batch_size=1, 
               num_workers=4, 
               shuffle=False)

Train shape (62514, 20899)
Val shape (15634, 20899)
Test shape (19541, 20899)


  result = getattr(ufunc, method)(*inputs, **kwargs)
  result = getattr(ufunc, method)(*inputs, **kwargs)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  train_df[rna_columns] = rna_values
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  test_df[rna_columns] = scaler.transform(test_df[rna_columns].values)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-co