In [1]:
import os
import sys

from accelerate import DataLoaderConfiguration
from accelerate import Accelerator, DistributedDataParallelKwargs
import wandb
import torch
import datetime

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
file_dir = os.path.abspath('')

save_dir =  os.path.join(file_dir,'train_output')
data_dir = os.path.join(file_dir, '..','data_preprocessing','generated_data')

train_utils_path = os.path.join(file_dir,'..','train_utils')
if train_utils_path not in sys.path:
    sys.path.append(train_utils_path)
from utils import (
    plot_training_loss,
    plot_training_validation,
    compare_motif_list
)
from utils_data import load_TF_data, load_TF_data_bidir, call_motif_scan

universal_guide_path = os.path.join(file_dir, '..', '..','re_design', 
                                    'Universal-Guided-Diffusion', 'Guided_Diffusion_Imagenet')
if universal_guide_path not in sys.path:
    sys.path.append(universal_guide_path)

from guided_diffusion import logger
from guided_diffusion.resample import create_named_schedule_sampler

from universal_models import (
    create_model_and_diffusion,
    scBPnetGuide,
)

from universal_tools import (
    OperationArgs,
    TrainLoop
)

In [20]:
class DiffusionTrainingConfig:
    datafile=f'{data_dir}/sampled_df_columns_renamed_bidir.csv'
    device = 'cuda'
    pikle_filename = "cre_expr_bidir_512seqlength_classcond_3clusters"
    run_name = "512seqlength_bidir_activecre" # for gimme scan erros when doing several runs simultaneously 
    guide_checkpoint = "guide_checkpoints/model_e9_b9901.pth"
    load_data = True
    subset = None
    num_workers = 4
    seq_length=512
    
    batch_size=256
    train_cond = True # whether to train conditionally
    lr=1e-4
    weight_decay=0.1
    lr_anneal_steps=2000
    
    log_interval=20
    sample_interval=20
    save_interval=1000000
    resume_checkpoint=""
    # use_fp16=False
    # fp16_scale_growth=1e-3

    # Sampling
    # use_ddim=False
    clip_denoised=True
    
    sampling_subset_random = 50 # number of cell types to subset for faster sampling
    num_cre_counts_per_cell_type = 200
    parallel_generating_bs = 256 # used for parallel sampling, adjust based on GPU memory capabilities
    get_seq_metrics = False
    get_kmer_metrics_bulk = False
    get_kmer_metrics_labelwise = False # not implemented, takes too long
    kmer_length = 5   

    # model
    class_cond=False
    use_checkpoint=False # not implemented
    image_size=seq_length
    ema_rate=0.995
    num_channels=256
    num_res_blocks=2
    num_heads=2
    num_heads_upsample=-1
    num_head_channels=64
    attention_resolutions=""
    # channel_mult=""
    dropout=0.1
    
    use_scale_shift_norm=True
    resblock_updown=True
    use_new_attention_order=False

    # Diffusion
    schedule_sampler="uniform" # or "loss-second-moment"
    learn_sigma=True
    diffusion_steps=100
    noise_schedule="linear" # or "cosine"
    timestep_respacing=""
    sigma_small=False
    use_kl=True
    predict_xstart=False
    rescale_timesteps=True
    rescale_learned_sigmas=False

    # to send config to wandb
    @classmethod
    def to_dict(cls):
        return {k: v for k, v in cls.__dict__.items() if not k.startswith('__') and not callable(v) and k != 'to_dict'}


config = DiffusionTrainingConfig()

In [21]:
data = load_TF_data_bidir(
    data_path=config.datafile,
    seqlen=config.seq_length,
    limit_total_sequences=config.subset,
    to_save_file_name=config.pikle_filename,
    saved_file_name=config.pikle_filename + ".pkl",
    load_saved_data=config.load_data,
    train_cond = config.train_cond,
    run_name = config.run_name,
)

In [36]:
import numpy as np
_, arr = np.unique(data['x_train_cell_type'], return_inverse=True)

In [37]:
arr

array([1, 0, 1, ..., 1, 0, 2])

In [38]:
embed = torch.nn.Embedding(3,10)
embed(torch.tensor(arr))

tensor([[ 2.3654,  0.4050,  1.9861,  ..., -0.3083,  0.1713,  1.1925],
        [-0.3819, -0.6687, -0.4859,  ...,  2.8117,  1.1389,  1.2058],
        [ 2.3654,  0.4050,  1.9861,  ..., -0.3083,  0.1713,  1.1925],
        ...,
        [ 2.3654,  0.4050,  1.9861,  ..., -0.3083,  0.1713,  1.1925],
        [-0.3819, -0.6687, -0.4859,  ...,  2.8117,  1.1389,  1.2058],
        [-0.8090, -0.2480,  0.6853,  ...,  0.5333, -0.1603, -0.7974]],
       grad_fn=<EmbeddingBackward0>)

In [5]:
compare_motif_list(data['train_motifs'],data['shuffle_motifs'])

0.26935696149822763

run_config:
{'cell_state_dim': 50,
 'ctrl_nodes': 256,
 'ctrl_layers': 1,
 'ctrl_outputs': 64,
 'bp_n_filters': 512,
 'conv_layers': 10,
 'trimming': 512,
 'dropout_rate': 0.0}

In [None]:
# import pickle
# example_output_profile = "matt_code/example_y.pkl"
# example_cell_states = "matt_code/example_c.pkl"
# example_input_sequences = "matt_code/example_X.pkl"

# with open(example_input_sequences, 'rb') as f:
#     X = pickle.load(f)
# with open(example_output_profile, 'rb') as f:
#     y = pickle.load(f)
# with open(example_cell_states, 'rb') as f:
#     c = pickle.load(f)


In [7]:
label_vector = torch.rand((128,52))

In [13]:
X = torch.rand((128,1,4,512))

In [16]:
guide_model = scBPnetGuide(config.guide_checkpoint)

In [4]:
model, diffusion = create_model_and_diffusion(config)
schedule_sampler = create_named_schedule_sampler(config.schedule_sampler, diffusion)
# model.to(dist_util.dev())


In [9]:
operation_config = OperationArgs()

In [13]:
dataloader_config = DataLoaderConfiguration(split_batches=False)
accelerator = Accelerator(
            # kwargs_handlers=[ddp_kwargs],
            dataloader_config = dataloader_config, 
            cpu= (config.device == "cpu"), 
            mixed_precision= None, 
            log_with=['wandb'])

trainloop = TrainLoop(
    config=config,
    operation_config=operation_config,
    model=model,
    diffusion=diffusion,
    guide_model=guide_model,
    accelerator=accelerator,
    data=data,
    batch_size=config.batch_size,
    lr=config.lr,
    log_interval=config.log_interval,
    sample_interval = config.sample_interval,
    save_interval=config.save_interval,
    resume_checkpoint=config.resume_checkpoint,
    schedule_sampler=schedule_sampler,
    lr_anneal_steps=config.lr_anneal_steps,
    run_name = config.run_name,
)
# os.environ['WANDB_SILENT']="true"
# os.environ["WANDB_MODE"] = "offline"
# wandb_config = {"learning_rate": config.lr, "num_sampling_to_compare_cells": config.num_sampling_to_compare_cells, "batch_size": config.batch_size}
# wandb.init(project="universal_diffusion", config=wandb_config)

# trainloop.run_loop()


In [14]:
trainloop.initialize_sampling()
a = trainloop.process_label_array

Process 0 handles 231 samples.


In [15]:
a

[(331, 0.0, 0.0),
 (16, 0.0, 0.0),
 (252, 0.0, 0.0),
 (252, 0.0, 0.0),
 (285, 0.0, 0.0),
 (285, 0.0, 0.0),
 (285, 0.0, 0.0),
 (285, 0.0, 0.0),
 (334, 0.0, 0.0),
 (334, 0.0, 0.0),
 (7, 0.0, 0.0),
 (7, 0.0, 0.0),
 (116, 0.0, 0.0),
 (273, 0.0, 0.0),
 (273, 0.0, 0.0),
 (289, 0.0, 0.0),
 (162, 0.0, 0.0),
 (162, 0.0, 0.0),
 (162, 0.0, 0.0),
 (375, 0.0, 0.0),
 (375, 0.0, 0.0),
 (90, 0.0, 0.0),
 (17, 0.0, 0.0),
 (304, 0.0, 0.0),
 (288, 0.0, 0.0),
 (288, 0.0, 0.0),
 (190, 0.0, 0.0),
 (190, 0.0, 0.0),
 (190, 0.0, 0.0),
 (190, 0.0, 0.0),
 (40, 0.0, 0.0),
 (144, 0.0, 0.0),
 (431, 0.0, 0.0),
 (431, 0.0, 0.0),
 (431, 0.0, 0.0),
 (431, 0.0, 0.0),
 (431, 0.0, 0.0),
 (210, 0.6856283387397418, 0.4357586612602581),
 (22, 0.0, 0.0),
 (247, 0.0, 0.0),
 (247, 0.0, 0.0),
 (79, 0.0, 0.0),
 (79, 0.0, 0.0),
 (79, 0.0, 0.0),
 (348, 0.0, 0.0),
 (348, 0.0, 0.0),
 (348, 0.0, 0.0),
 (207, 0.0, 0.0),
 (207, 0.0, 0.0),
 (207, 0.0, 0.0),
 (224, 0.2373365954795822, 1.8384822045204177),
 (446, 0.0, 0.0),
 (446, 0.8698639

In [11]:
data_loader = trainloop.data_loader
x,y = next(iter(data_loader))
x = x.to('cuda')
y =y.to('cuda')
guide_model(x,y)

tensor([5327486.0000, 5806657.5000, 8008878.0000, 4447922.0000, 4476182.0000,
        4148849.7500, 5517874.5000, 4093083.7500, 5503929.0000, 6678124.5000,
        6169298.0000, 6016802.0000, 3776196.7500, 4889560.5000, 9019005.0000,
        5473608.5000, 6333264.0000, 5981333.5000, 5266115.0000, 5558643.0000,
        3591988.7500, 4942362.0000, 7957057.5000, 5530094.0000, 4311635.5000,
        7956750.5000, 4305153.0000, 4951033.5000, 4091515.0000, 4544701.5000,
        9660567.0000, 4609254.5000], device='cuda:0')

In [12]:
trainloop.create_sample_labelwise()

generating:   0%|          | 0/328 [00:00<?, ?it/s]

 12%|█▏        | 120/1000 [23:20<2:51:10, 11.67s/it]
generating:   0%|          | 0/328 [23:20<?, ?it/s]


KeyboardInterrupt: 

In [14]:
model_output = torch.ones((30, 1, 4, 1024))


1

In [None]:
model_output,_ = torch.split(model_output, 1, dim=1)