In [16]:
import os
import sys
import math
import random

import blobfile as bf
from mpi4py import MPI
import numpy as np
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn.functional as F
import pandas as pd
import torch.distributed as dist

In [17]:
# Absolute path of the current working directory; every other path below is
# built relative to it.
file_dir = os.path.abspath('')

# Directory holding the generated data files (used to build config.data_dir).
data_dir = os.path.join(file_dir, '..','data_preprocessing','generated_data')

# Make the local train_utils directory importable before importing from it.
# The membership check keeps re-runs of this cell from appending duplicates.
train_utils_path = os.path.join(file_dir,'..','train_utils')
if train_utils_path not in sys.path:
    sys.path.append(train_utils_path)
from utils_data import load_TF_data, SequenceDataset

# The guided_diffusion package is resolved from this directory, so the path
# must be appended before the guided_diffusion imports below.
universal_guide_path = os.path.join(file_dir, '..', '..','re_design', 
                                    'Universal-Guided-Diffusion', 'Guided_Diffusion_Imagenet')
if universal_guide_path not in sys.path:
    sys.path.append(universal_guide_path)

from guided_diffusion import dist_util, logger
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.respace import SpacedDiffusion, space_timesteps
from guided_diffusion import gaussian_diffusion as gd

# Directory providing guided_tools; it is also reused to locate the
# classifier checkpoint in the configuration class.
classifier_guided_path = os.path.join(file_dir, '..', 'classifier_diffusion')
if classifier_guided_path not in sys.path:
    sys.path.append(classifier_guided_path)

from guided_tools import create_model, create_gaussian_diffusion, create_classifier

In [18]:
from dragonnfruit.models import DragoNNFruit, CellStateController

In [19]:
class DiffusionTrainingconfig:
    """Settings namespace for the diffusion model, sampling and classifier.

    Every option is a plain class attribute. An instance of this class is
    handed to the model/diffusion factory and is read directly by the
    data-loading and sampling cells of the notebook.

    `data_dir` and `classifier_checkpoint_path` are built from the
    module-level `data_dir` and `classifier_guided_path` variables, which
    must therefore be defined before this class body is executed.
    """

    # ---- Training ----
    TESTING_MODE = False
    data_dir=f'{data_dir}/tcre_seq_motif_cluster.csv'
    classifier_checkpoint_path = f'{classifier_guided_path}/classifier_checkpoints/model001000.pt'
    schedule_sampler="uniform"
    lr=1e-4
    weight_decay=0.0
    lr_anneal_steps=0
    batch_size=512
    microbatch=-1  # -1 disables microbatches
    ema_rate="0.9999"  # comma-separated list of EMA values
    log_interval=100
    sample_interval=100
    save_interval=5000
    resume_checkpoint=""
    # Defined once here; the same flag was previously repeated (with the same
    # value) in the model section below.
    use_fp16=False
    fp16_scale_growth=1e-3

    # ---- Data ----
    subset = None
    num_workers = 8
    kmer_length = 5
    num_classes = 3

    # ---- Model ----
    seq_length=200
    image_size=seq_length  # the model's "image size" is the sequence length

    num_channels=128
    num_res_blocks=2
    num_heads=4
    num_heads_upsample=-1
    num_head_channels=-1
    attention_resolutions="100,50,25"
    channel_mult=""
    dropout=0.0
    class_cond=True
    use_checkpoint=False
    use_scale_shift_norm=True
    resblock_updown=True
    use_new_attention_order=False

    # ---- Diffusion ----
    learn_sigma=False
    diffusion_steps=100
    noise_schedule="linear"
    timestep_respacing=""
    sigma_small=False
    use_kl=True
    predict_xstart=False
    rescale_timesteps=False
    rescale_learned_sigmas=False

    # ---- Sampling ----
    use_ddim=False
    clip_denoised=True
    num_sampling_to_compare_cells = 10
    sample_bs = 1
    use_classifier = True
    classifier_scale=10 # between 1 and 10 trade off between diversity and fidelity
    run_name = "" # for gimme scan errors when doing several runs simultaneously

    # ---- Classifier ----
    classifier_use_fp16=False
    classifier_width=256
    classifier_depth=3
    classifier_attention_resolutions="100,50,25"  # 16
    classifier_use_scale_shift_norm=True  # False
    classifier_resblock_updown=True  # False
    classifier_pool="spatial"

# Settings object that the following cells read their options from.
config = DiffusionTrainingconfig()

# Initialise the guided_diffusion distributed helper and its logger; the
# logger reports the directory it writes to in the cell output.
dist_util.setup_dist()
logger.configure()

Logging to /tmp/openai-2024-09-15-21-53-44-836788


In [20]:
logger.log("creating data loader...")
# Load the sequence data described by the config. Later cells read the
# 'cell_types' and 'numeric_to_tag' entries of the returned mapping.
# NOTE(review): load_saved_data=True presumably makes the helper reuse the
# previously saved .pkl file instead of re-encoding the CSV — confirm in utils_data.
data = load_TF_data(
    data_path=config.data_dir,
    seqlen=config.seq_length,
    limit_total_sequences=config.subset,
    num_sampling_to_compare_cells=config.num_sampling_to_compare_cells,
    to_save_file_name="cre_encode_data_motif_cluster",
    saved_file_name="cre_encode_data_motif_cluster.pkl",
    load_saved_data=True,
    start_label_number = 0,
)

creating data loader...


KeyboardInterrupt: 

In [7]:
from universal_tools import (
    create_model_and_diffusion,
    OperationArgs, 
    SamplingArgs
)

In [8]:
logger.log("creating model and diffusion...")
# Build the model and the diffusion object from the settings object.
model, diffusion = create_model_and_diffusion(config)

# Move the model to the device selected by dist_util.
model.to(dist_util.dev())

# Timestep sampler, chosen by the name stored in config.schedule_sampler.
schedule_sampler = create_named_schedule_sampler(config.schedule_sampler, diffusion)

creating model and diffusion...
creating model and diffusion...


In [10]:
class RegularizedDynamicBPNetCounts(torch.nn.Module):
    """Dilated residual Conv1d stack with controller-driven per-layer biases.

    The input is passed through a 4-channel input convolution, then through
    `n_layers` residual blocks. Each block adds, to the output of a dilated
    convolution, a bias vector computed by a linear layer from the
    controller's embedding of the cell state. Dropout follows the input
    convolution and every residual block.

    Parameters
    ----------
    controller : callable module exposing an `n_outputs` attribute
        Applied to `cell_states` in `forward`; `n_outputs` sizes the
        per-layer bias projections.
    n_filters : int
        Number of channels in the convolution stack.
    n_layers : int
        Number of dilated residual blocks (dilations 2**1 .. 2**n_layers).
    trimming : int or None
        Number of positions trimmed from each side of the profile output;
        defaults to ``2 ** n_layers + 37`` when None.
    conv_bias : bool
        Whether the dilated and final convolutions carry their own bias.
    n_outputs : int
        Number of channels of the profile head.
    dropout_rate : float
        Dropout probability used throughout.

    NOTE(review): the crop in `forward` is only non-empty when the input is
    longer than roughly ``2 * trimming``; confirm the default trimming
    against the sequence length actually in use.
    """

    def __init__(self, controller, n_filters=128, n_layers=8, trimming=None,
                 conv_bias=False, n_outputs=1, dropout_rate=0.2):
        super().__init__()
        nn = torch.nn

        if trimming is None:
            trimming = 2 ** n_layers + 37
        self.trimming = trimming
        self.n_filters = n_filters
        self.n_layers = n_layers
        self.n_outputs = n_outputs
        self.dropout_rate = dropout_rate

        # Input stage: convolution -> ReLU -> dropout.
        self.iconv = nn.Conv1d(4, n_filters, kernel_size=21, padding=10)
        self.irelu = nn.ReLU()
        self.idropout = nn.Dropout(p=dropout_rate)

        # Profile head: an unpadded convolution over the cropped features.
        self.deconv_kernel_size = 75
        self.fconv = nn.Conv1d(
            n_filters,
            n_outputs,
            kernel_size=self.deconv_kernel_size,
            bias=conv_bias,
        )

        # One bias projection per residual block, fed by the controller output.
        bias_layers = []
        for _ in range(n_layers):
            bias_layers.append(nn.Linear(controller.n_outputs, n_filters))
        self.biases = nn.ModuleList(bias_layers)

        # Dilated convolutions; padding equals dilation so length is preserved.
        conv_layers = []
        for exponent in range(1, n_layers + 1):
            rate = 2 ** exponent
            conv_layers.append(
                nn.Conv1d(n_filters, n_filters, kernel_size=3, stride=1,
                          dilation=rate, padding=rate, bias=conv_bias)
            )
        self.convs = nn.ModuleList(conv_layers)

        self.relus = nn.ModuleList([nn.ReLU() for _ in range(n_layers)])

        # Count head applied to the position-averaged features.
        self.linear = nn.Linear(n_filters, 1)

        self.dropouts = nn.ModuleList(
            [nn.Dropout(p=dropout_rate) for _ in range(n_layers)]
        )

        self.controller = controller

    def forward(self, X, cell_states):
        """Return ``(y_profile, y_counts)`` for inputs `X` and `cell_states`.

        `y_profile` is the output of the final convolution on the cropped
        features; when the crop indices are in range it has
        ``X.shape[2] - 2 * trimming`` positions. `y_counts` is the linear
        head applied to the mean of the cropped features over the last axis.
        """
        left = self.trimming
        right = X.shape[2] - self.trimming
        margin = self.deconv_kernel_size // 2

        state_embedding = self.controller(cell_states)

        hidden = self.idropout(self.irelu(self.iconv(X)))
        for layer in range(self.n_layers):
            conv_out = self.convs[layer](hidden)
            # Broadcast the per-channel bias over the sequence axis.
            shift = self.biases[layer](state_embedding).unsqueeze(-1)
            hidden = self.dropouts[layer](hidden + self.relus[layer](conv_out + shift))

        # Keep an extra half-kernel on each side so the unpadded final
        # convolution yields exactly the trimmed window.
        window = hidden[:, :, left - margin: right + margin]
        y_profile = self.fconv(window)
        y_counts = self.linear(window.mean(dim=2))

        return y_profile, y_counts

In [11]:
class scBPnetCounts(DragoNNFruit):
    """DragoNNFruit subclass that forwards directly to an accessibility model.

    Parameters
    ----------
    accessibility : torch.nn.Module
        Model called as ``accessibility(X, cell_states)``; its return value
        is the return value of `forward`.
    name : str or None
        Stored on the instance as `name`.
    alpha : float
        Stored on the instance as `alpha`.
    scale_log_rd : bool
        Stored on the instance as `scale_log_rd`.
    n_outputs : int
        Number of profile channels assumed by `log_softmax`.
    """

    def __init__(self, accessibility, name, alpha=1, scale_log_rd=False, n_outputs=2):
        # Only torch.nn.Module is initialised here; DragoNNFruit.__init__ is
        # not called, and no logger is created.
        torch.nn.Module.__init__(self)
        self.accessibility = accessibility
        self.name = name
        self.alpha = alpha
        self.n_outputs = n_outputs
        self.scale_log_rd = scale_log_rd

    def forward(self, X, cell_states):
        """Delegate to the wrapped accessibility model."""
        return self.accessibility(X, cell_states)

    def log_softmax(self, y_profile):
        """Log-softmax taken jointly over all channels and positions.

        Each example is flattened, normalised with a single log-softmax and
        reshaped to ``(batch, n_outputs, -1)``.
        """
        n_batch = y_profile.shape[0]
        flat = y_profile.reshape(n_batch, -1)
        flat = torch.nn.functional.log_softmax(flat, dim=-1)
        return flat.reshape(n_batch, self.n_outputs, -1)

    def predict(self, X, cell_states, batch_size=64, logits = False):
        """Run the model batch by batch without gradients.

        Returns a pair of NumPy arrays ``(y_profiles, y_counts)``
        concatenated over batches. Unless `logits` is true, the profile
        output of every batch is passed through `log_softmax` first.
        """
        profile_chunks = []
        count_chunks = []
        with torch.no_grad():
            for lo in np.arange(0, X.shape[0], batch_size):
                hi = lo + batch_size
                profiles, counts = self(X[lo:hi], cell_states[lo:hi])
                if not logits:  # apply softmax
                    profiles = self.log_softmax(profiles)
                profile_chunks.append(profiles.detach().cpu().numpy())
                count_chunks.append(counts.detach().cpu().numpy())

        return np.concatenate(profile_chunks), np.concatenate(count_chunks)

In [21]:
# Hyper-parameters used to build the cell-state controller and the
# accessibility (guide) network in the following cells.
guide_config = dict(
    cell_state_dim=50,
    ctrl_nodes=256,
    ctrl_layers=1,  # 0
    ctrl_outputs=64,
    bp_n_filters=128,
    conv_layers=8,
    trimming=None,
    dropout_rate=0.2,
)

In [22]:
# Cell-state controller sized from guide_config; its output feeds the
# per-layer bias projections of the accessibility network.
controller = CellStateController(
    n_inputs=guide_config['cell_state_dim'],
    n_nodes=guide_config['ctrl_nodes'],
    n_layers=guide_config['ctrl_layers'],
    n_outputs=guide_config['ctrl_outputs'],
)

In [23]:
# Accessibility network built from guide_config, with a two-channel profile head.
accessibility_model = RegularizedDynamicBPNetCounts(
    controller=controller,
    n_filters=guide_config['bp_n_filters'],
    n_layers=guide_config['conv_layers'],
    trimming=guide_config['trimming'],
    dropout_rate=guide_config['dropout_rate'],
    n_outputs=2,
)

In [24]:
# Wrap the accessibility network in the scBPnetCounts interface; this object
# is what the sampling cell binds into cond_fn as `guide_function`.
guide_function = scBPnetCounts(accessibility_model, name=None)

In [7]:
def model_fn(x, t, y=None, args=None, model=None):
    """Call `model(x, t, labels)`.

    `labels` is `y` when `args.class_cond` is truthy and None otherwise.
    `args` and `model` are expected to be bound beforehand (e.g. with
    functools.partial).
    """
    # assert y is not None
    labels = y if args.class_cond else None
    return model(x, t, labels)

def cond_fn(x, t, y=None, args=None, guide_function=None):
    """Gradient, with respect to `x`, of the summed log-probability of labels `y`.

    `guide_function(x, t)` is treated as returning logits; a log-softmax is
    taken over the last axis and the entry for each example's label is
    selected. The gradient is computed on a detached copy of `x`, so the
    caller's tensor is left untouched. `args` is accepted but not used.

    NOTE(review): the gradient is returned as-is; no guidance scale is
    applied inside this function.
    """
    assert y is not None
    with torch.enable_grad():
        x_in = x.detach().requires_grad_(True)
        logits = guide_function(x_in, t)
        batch_index = range(len(logits))
        target_log_prob = F.log_softmax(logits, dim=-1)[batch_index, y.view(-1)]
        (grad,) = torch.autograd.grad(target_log_prob.sum(), x_in)
    return grad

In [8]:
from functools import partial

In [10]:
# Default-constructed option objects from universal_tools; the sampling cell
# passes operation_config as `operation` and reads sampling_config.progressive.
operation_config = OperationArgs()
sampling_config = SamplingArgs()

In [12]:
# Generate config.num_sampling_to_compare_cells sequences per cell type with
# guided sampling and collect them as FASTA-style records.
cell_num_list = data['cell_types']
cell_list = list(data["numeric_to_tag"].values())
nucleotides = ["A", "C", "G", "T"]
sample_bs = config.sample_bs
num_samples = round(config.num_sampling_to_compare_cells/sample_bs)
generated_celltype_motif = {}
seqlen = config.seq_length
model.eval()

# Use the same device as the model and the class labels instead of a
# hard-coded 'cuda', so the cell also runs where dist_util selects another device.
sampling_device = dist_util.dev()

# These callables do not change between iterations, so build them once.
denoise_fn = partial(model_fn, model=model, args=config)
guidance_fn = partial(cond_fn, guide_function=guide_function)

# NOTE(review): final_sequences is rebuilt for every cell type, so after the
# loop only the last cell type's sequences remain, and generated_celltype_motif
# is never filled in this cell — confirm whether per-cell-type results should be kept.
# iterate over cell types
for cell_num in cell_num_list:
    cell_type = data['numeric_to_tag'][cell_num]
    print(f"Generating {config.num_sampling_to_compare_cells} samples for cell_type {cell_type}")
    final_sequences = []
    for n_a in range(num_samples):
        # Condition every sample in the batch on the current cell type.
        sampled_cell_types = np.array([cell_num] * sample_bs)
        classes = torch.from_numpy(sampled_cell_types).to(sampling_device)
        model_kwargs = {"y": classes}
        sample = diffusion.ddim_sample_loop_operation(
            denoise_fn,
            (config.sample_bs, 1, 4, config.image_size),
            operated_image=None,
            operation=operation_config,
            clip_denoised=config.clip_denoised,
            model_kwargs=model_kwargs,
            cond_fn=guidance_fn,
            device=sampling_device,
            progress=sampling_config.progressive
        ).squeeze(1)
        for n_b, x in enumerate(sample):
            # x has one row per nucleotide; the highest-scoring row at each
            # position gives the decoded base.
            base_indices = x.detach().cpu().argmax(dim=0).tolist()
            sequence = "".join(nucleotides[s] for s in base_indices)
            seq_final = f">seq_test_{n_a}_{n_b}\n" + sequence
            final_sequences.append(seq_final)

Generating 10 samples for cell_type ct2


100%|██████████| 100/100 [00:02<00:00, 44.73it/s]
100%|██████████| 100/100 [00:02<00:00, 45.00it/s]
100%|██████████| 100/100 [00:02<00:00, 46.51it/s]
100%|██████████| 100/100 [00:02<00:00, 47.82it/s]
100%|██████████| 100/100 [00:02<00:00, 49.01it/s]
100%|██████████| 100/100 [00:02<00:00, 47.71it/s]
100%|██████████| 100/100 [00:02<00:00, 46.99it/s]
100%|██████████| 100/100 [00:01<00:00, 50.29it/s]
100%|██████████| 100/100 [00:02<00:00, 46.41it/s]
100%|██████████| 100/100 [00:02<00:00, 43.75it/s]


NameError: name 'extract_motifs' is not defined

In [13]:
final_sequences


['>seq_test_0_0\nAAGAAAAAAAAACAGACCAAATAAAAACGGCCAAGAAATAAACACTACCAATCGAGGCAGCTAAGATACACCAGCCATATAACAACAAAGCGACATCAAAAGACCAAAAACACAAAACAACCAAGGAAATGAAAAACGACAAACCACCAGGAATCCACCAAATCCACCACATCAAGCCACCATCCGAACACAACTACGCA',
 '>seq_test_1_0\nACAACATACCTGCCACAAAAACAAACCAGCACAACAGAATGACAAACAAAAAACCTAAACACACACGAAGCAATGCCCAAAAACTAAGCAGCAACCCCAAAGGAGAAAAAAAGCAAACGCACAACAAACCGAATCCAGAAAACCATCACTGCAGCAGCGACAAAAGGGAACCTGACAATCCCCAAGAACGAAAGACAGGA',
 '>seq_test_2_0\nAATGAGAATGCCAACAGAAAGACGAACCACACAATAACGCGCAGCATGAATGAACAGACAACAGAAAAAACACAAACAGAACAGCAAACCTCAGAAAAACAGAACAAAAAAACAACACAACAAAAAGACTAAAAAAGACAAAGAGAACTGGAAAAACCAAAAAAAACGATCCCCCAGACAAAAGTCAACAAAACAAAAAA',
 '>seq_test_3_0\nAAATCAAGTAGCAGGACGAAGATAAAGCCAATAAAAACAAAAGCCAAGAAAAAAAAATGCCCTAAACTCGCGCACAAGGAGACAACAAATTCCGAAAAGAAACAAAGAGACAAAATCAAAGAAATCGTGACACACAACAAAGTACAGAAATACAAAACAAATCCAAACAAAACCAGAAAGAATGGAAAACTACGACTGAT',
 '>seq_test_4_0\nCCGGCCCACAGAAAGCACGAACACCCATATCACTAACGAAACAGGCATACACCACAAGCTCACACTTGAGGCAAAGAAACCAAAAACAACCAGAACCACAACA