In [1]:
import os
import sys
import math
import random

from PIL import Image
import blobfile as bf
from mpi4py import MPI
import numpy as np
from torch.utils.data import DataLoader, Dataset
import torch
import pandas as pd
import torch.distributed as dist

In [2]:
file_dir = os.path.abspath('')

data_dir = os.path.join(file_dir, '..','data_preprocessing','generated_data')

guided_diff_path = os.path.join(file_dir, '..', '..','re_design', 'guided-diffusion')
if guided_diff_path not in sys.path:
    sys.path.append(guided_diff_path)

from guided_diffusion import dist_util, logger
from guided_diffusion.resample import create_named_schedule_sampler

train_utils_path = os.path.join(file_dir,'..','train_utils')
if train_utils_path not in sys.path:
    sys.path.append(train_utils_path)
from utils_data import load_TF_data, SequenceDataset
from guided_tools import (
    get_data_generator,
    TrainLoop,
    create_model,
    create_gaussian_diffusion,
)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class DiffusionTrainingconfig:
    data_dir=f'{data_dir}/tcre_seq_motif_cluster.csv'
    schedule_sampler="uniform"
    lr=1e-4
    weight_decay=0.0
    lr_anneal_steps=0
    batch_size=16
    microbatch=-1  # -1 disables microbatches
    ema_rate="0.9999"  # comma-separated list of EMA values
    log_interval=10
    sample_interval=10000
    save_interval=10000
    resume_checkpoint=""
    use_fp16=False
    fp16_scale_growth=1e-3

    subset = None
    num_workers = 1
    num_sampling_to_compare_cells = 1000
    kmer_length = 5
    num_classes = 3
    

    # model
    seq_length=200
    image_size=seq_length
    
    num_channels=128
    num_res_blocks=2
    num_heads=4
    num_heads_upsample=-1
    num_head_channels=-1
    attention_resolutions="16,8"
    channel_mult=""
    dropout=0.0
    class_cond=True
    use_checkpoint=False
    use_scale_shift_norm=True
    resblock_updown=True
    use_fp16=False
    use_new_attention_order=False

    # Diffusion
    learn_sigma=False
    diffusion_steps=1000
    noise_schedule="linear"
    timestep_respacing=""
    use_kl=False
    predict_xstart=False
    rescale_timesteps=False
    rescale_learned_sigmas=False

    # Sampling
    use_ddim=False
    clip_denoised=True
    sample_bs = 100


config = DiffusionTrainingconfig()

dist_util.setup_dist()
logger.configure()


Logging to /tmp/openai-2024-07-10-20-59-08-351589


In [4]:
logger.log("creating data loader...")
data = load_TF_data(
    data_path=config.data_dir,
    seqlen=config.seq_length,
    limit_total_sequences=config.subset,
    num_sampling_to_compare_cells=config.num_sampling_to_compare_cells,
    to_save_file_name="cre_encode_data_motif_cluster",
    saved_file_name="cre_encode_data_motif_cluster.pkl",
    load_saved_data=True,
    start_label_number = 0,
)



creating data loader...


In [10]:
logger.log("creating model and diffusion...")
model = create_model(
        config.image_size,
        config.num_channels,
        config.num_res_blocks,
        channel_mult=config.channel_mult,
        learn_sigma=config.learn_sigma,
        class_cond=config.class_cond,
        use_checkpoint=config.use_checkpoint,
        attention_resolutions=config.attention_resolutions,
        num_heads=config.num_heads,
        num_head_channels=config.num_head_channels,
        num_heads_upsample=config.num_heads_upsample,
        use_scale_shift_norm=config.use_scale_shift_norm,
        dropout=config.dropout,
        resblock_updown=config.resblock_updown,
        use_fp16=config.use_fp16,
        use_new_attention_order=config.use_new_attention_order,
    )
diffusion = create_gaussian_diffusion(
        steps=config.diffusion_steps,
        learn_sigma=config.learn_sigma,
        noise_schedule=config.noise_schedule,
        use_kl=config.use_kl,
        predict_xstart=config.predict_xstart,
        rescale_timesteps=config.rescale_timesteps,
        rescale_learned_sigmas=config.rescale_learned_sigmas,
        timestep_respacing=config.timestep_respacing,
    )
model.to(dist_util.dev())

schedule_sampler = create_named_schedule_sampler(config.schedule_sampler, diffusion)

creating model and diffusion...


In [None]:
logger.log("training...")
TrainLoop(
    model=model,
    diffusion=diffusion,
    data=data,
    wandb_logging=False,
    batch_size=config.batch_size,
    microbatch=config.microbatch,
    lr=config.lr,
    ema_rate=config.ema_rate,
    log_interval=config.log_interval,
    sample_interval = config.sample_interval,
    save_interval=config.save_interval,
    resume_checkpoint=config.resume_checkpoint,
    use_fp16=config.use_fp16,
    fp16_scale_growth=config.fp16_scale_growth,
    schedule_sampler=schedule_sampler,
    weight_decay=config.weight_decay,
    lr_anneal_steps=config.lr_anneal_steps,
).run_loop()

# Sampling

In [27]:
NUM_CLASSES = 3
model.to(dist_util.dev())
if config.use_fp16:
    model.convert_to_fp16()
model.eval()

logger.log("sampling...")
all_images = []
all_labels = []
while len(all_images) * config.sample_bs < config.num_sampling_to_compare_cells:
    model_kwargs = {}
    if config.class_cond:
        classes = torch.randint(
            low=0, high=NUM_CLASSES, size=(config.sample_bs,), device=dist_util.dev()
        )
        model_kwargs["y"] = classes
    sample_fn = (
        diffusion.p_sample_loop if not config.use_ddim else diffusion.ddim_sample_loop
    )
    sample = sample_fn(
        model,
        (config.sample_bs, 4, config.image_size),
        clip_denoised=config.clip_denoised,
        model_kwargs=model_kwargs,
    )
    sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    sample = sample.permute(0, 2, 3, 1)
    sample = sample.contiguous()

    gathered_samples = [torch.zeros_like(sample) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered_samples, sample)  # gather not supported with NCCL
    all_images.extend([sample.cpu().numpy() for sample in gathered_samples])
    if config.class_cond:
        gathered_labels = [
            torch.zeros_like(classes) for _ in range(dist.get_world_size())
        ]
        dist.all_gather(gathered_labels, classes)
        all_labels.extend([labels.cpu().numpy() for labels in gathered_labels])
    logger.log(f"created {len(all_images) * config.sample_bs} samples")

arr = np.concatenate(all_images, axis=0)
arr = arr[: config.num_sampling_to_compare_cells]
if config.class_cond:
    label_arr = np.concatenate(all_labels, axis=0)
    label_arr = label_arr[: config.num_sampling_to_compare_cells]
if dist.get_rank() == 0:
    shape_str = "x".join([str(x) for x in arr.shape])
    out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz")
    logger.log(f"saving to {out_path}")
    if config.class_cond:
        np.savez(out_path, arr, label_arr)
    else:
        np.savez(out_path, arr)

dist.barrier()
logger.log("sampling complete")

sampling...


RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 3 is not equal to len(dims) = 4

In [30]:
model_kwargs = {}
if config.class_cond:
    classes = torch.randint(
        low=0, high=NUM_CLASSES, size=(config.sample_bs,), device=dist_util.dev()
    )
    model_kwargs["y"] = classes
sample_fn = (
    diffusion.p_sample_loop if not config.use_ddim else diffusion.ddim_sample_loop
)
sample = sample_fn(
    model,
    (config.sample_bs, 4, config.image_size),
    clip_denoised=config.clip_denoised,
    model_kwargs=model_kwargs,
)

In [None]:
from collections import Counter

cell_types=data["cell_types"]
# count cell types in train
cell_dict_temp = Counter(data['x_train_cell_type'].tolist())
# reoder by cell_types list
cell_dict = {k:cell_dict_temp[k] for k in data['cell_types']}
# take only counts
cell_type_counts = list(cell_dict.values())
cell_type_probabilities = [x / sum(cell_type_counts) for x in cell_type_counts]

In [None]:
nucleotides = ["A", "C", "G", "T"]
final_sequences = []
plain_generated_sequences = []
for n_a in range(int(config.num_sampling_to_compare_cells/ config.sample_bs)):
    model_kwargs = {}
    if config.class_cond:
        classes = torch.randint(
            low=0, high=NUM_CLASSES, size=(config.sample_bs,), device=dist_util.dev()
        )
        model_kwargs["y"] = classes
    sample_fn = (
        diffusion.p_sample_loop if not config.use_ddim else diffusion.ddim_sample_loop
    )
    sampled_images = sample_fn(
        model,
        (config.sample_bs, 4, config.image_size),
        clip_denoised=config.clip_denoised,
        model_kwargs=model_kwargs,
    )
    
    for n_b, x in enumerate(sampled_images):
        sequence = "".join([nucleotides[s] for s in np.argmax(x.detach().cpu(), axis=0)])
        plain_generated_sequences.append(sequence)
        seq_final = f">seq_test_{n_a}_{n_b}\n" + sequence
        final_sequences.append(seq_final)


In [47]:
nucleotides = ["A", "C", "G", "T"]
final_sequences = []
plain_generated_sequences = []
for n_b, x in enumerate(sample):
    sequence = "".join([nucleotides[s] for s in np.argmax(x.detach().cpu(), axis=0)])
    plain_generated_sequences.append(sequence)
    seq_final = f">seq_test_{n_b}\n" + sequence
    final_sequences.append(seq_final)