In [1]:
import sys 
sys.path.append('/data/conghao001/diffusion_model/latent-diffusion')

In [2]:
from ldm.models.diffusion.ddpm import DDPM, DiffusionWrapper, disabled_train
from ldm.models.diffusion.ddim import DDIMSampler


In [3]:
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config, get_obj_from_str
from pytorch_lightning.utilities.distributed import rank_zero_only


In [4]:
import numpy as np 
import torch

In [5]:
__conditioning_keys__ = {'concat': 'c_concat',
                         'crossattn': 'c_crossattn',
                         'adm': 'y'}

In [6]:
# class DiffusionWrapper(pl.LightningModule):
#     def __init__(self, diff_model_config, conditioning_key):
#         super().__init__()
#         self.diffusion_model = instantiate_from_config(diff_model_config)
#         self.conditioning_key = conditioning_key
#         assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']

#     def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
#         if self.conditioning_key is None:
#             out = self.diffusion_model(x, t)
#         elif self.conditioning_key == 'concat':
#             xc = torch.cat([x] + c_concat, dim=1)
#             out = self.diffusion_model(xc, t)
#         elif self.conditioning_key == 'crossattn':
#             cc = torch.cat(c_crossattn, 1)
#             out = self.diffusion_model(x, t, context=cc)
#         elif self.conditioning_key == 'hybrid':
#             xc = torch.cat([x] + c_concat, dim=1)
#             cc = torch.cat(c_crossattn, 1)
#             out = self.diffusion_model(xc, t, context=cc)
#         elif self.conditioning_key == 'adm':
#             cc = c_crossattn[0]
#             out = self.diffusion_model(x, t, y=cc)
#         else:
#             raise NotImplementedError()

#         return out

# Rewrite the ldm model

~~1. get_input function can be removed: it's supposed to get the input from batch with regard to key (first stage model input or cond stage model input), since the original model can have different cond stage keys, but in our case the cond input is always gene expr~~


In [7]:
class LatentDiffusion(DDPM):
    def __init__(self,
                 first_stage_config,
                 cond_stage_config,
                 dataset, 
                 batch_size,
                 first_stage_params,
                 first_stage_ckpt,
                 num_timesteps_cond=None,
                 cond_stage_key="gene_expressions",    
                 cond_stage_trainable=False,    # we can either use pretrained MLP or train a new one
                 concat_mode=True,
                 cond_stage_forward=None,
                 conditioning_key=None,    # by default, concat mode is used
                 scale_factor=1.0,
                 scale_by_std=False,
                 *args, **kwargs):
        self.num_timesteps_cond = default(num_timesteps_cond, 1)
        self.scale_by_std = scale_by_std
        assert self.num_timesteps_cond <= kwargs['timesteps']
        
        # to init the first stage model
#         self.dataset = dataset
        self.batch_size = batch_size
#         self.first_stage_params = first_stage_params
#         self.first_stage_ckpt = first_stage_ckpt
        
        # for backwards compatibility after implementation of DiffusionWrapper
        if conditioning_key is None:
            conditioning_key = 'concat' if concat_mode else 'crossattn'
        if cond_stage_config == '__is_unconditional__':    # to train unconditional diff model
            conditioning_key = None
            
        ckpt_path = kwargs.pop("ckpt_path", None)    # this is for the diff model ckpt, not vaes
        ignore_keys = kwargs.pop("ignore_keys", [])
        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)    # DiffusionWrapper is called here
        
        self.concat_mode = concat_mode
        self.cond_stage_trainable = cond_stage_trainable
        self.cond_stage_key = cond_stage_key
        try:
            self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
        except:
            self.num_downs = 0
        if not scale_by_std:
            self.scale_factor = scale_factor
        else:
            self.register_buffer('scale_factor', torch.tensor(scale_factor))
        self.instantiate_first_stage(first_stage_config, first_stage_params, dataset, first_stage_ckpt)    # first stage model is initiated here
#         self.instantiate_cond_stage(cond_stage_config)    # let's remove this and directly input the gene expr to unet
        self.cond_stage_forward = cond_stage_forward
        self.clip_denoised = False
        self.bbox_tokenizer = None  

        self.restarted_from_ckpt = False
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys)
            self.restarted_from_ckpt = True
            
    def make_cond_schedule(self, ):
        self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
        ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
        self.cond_ids[:self.num_timesteps_cond] = ids
        
    @rank_zero_only
    @torch.no_grad()
    def on_train_batch_start(self, batch, batch_idx):
        # only for very first batch
        # this function should be called at the beginning of trainer.fit_loop, not sure if it will be called in trainer.fit
        if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
            assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
            # set rescale weight to 1./std of encodings
            print("### USING STD-RESCALING ###")
#             x = super().get_input(batch, self.first_stage_key)    # this is not necessary as the whole batch is always passed together
            x = batch
            x = x.to(self.device)
            encoder_posterior, partial_reprs, node_reprs = self.encode_first_stage(x)
            self.partial_reprs = partial_reprs
            self.node_reprs = node_reprs
            
            z = self.get_first_stage_encoding(encoder_posterior).detach()
            del self.scale_factor
            self.register_buffer('scale_factor', 1. / z.flatten().std())
            print(f"setting self.scale_factor to {self.scale_factor}")
            print("### USING STD-RESCALING ###")
            
    def register_schedule(self,
                          given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)

        self.shorten_cond_schedule = self.num_timesteps_cond > 1
        if self.shorten_cond_schedule:
            self.make_cond_schedule()
            
    def init_first_stage_model(self, config, ckpt, **kwargs):
        # vae kwargs: params, dataset, using_lincs, include_predict_gene_exp_mlp = False, num_train_batches=1, batch_size=1, use_clamp_log_var = False
        if not "target" in config:
            if config == '__is_first_stage__':
                return None
            elif config == "__is_unconditional__":
                return None
            raise KeyError("Expected key `target` to instantiate.")
        model = get_obj_from_str(config["target"])
        model = model.load_from_checkpoint(ckpt, **kwargs)

        return model
            
    def instantiate_first_stage(self, config, params, dataset, ckpt):
#         model = instantiate_from_config(config)    # TODO: rewrite with our model init function

        if config['model_type'] == 'vae':
            model = self.init_first_stage_model(config, ckpt, params=params, dataset=dataset, using_lincs=config['using_lincs'])
        elif config['model_type'] == 'aae':
            model = self.init_first_stage_model(
                config, 
                ckpt, 
                params=params,
                dataset=dataset,
                using_lincs=config['using_lincs'],
                using_wasserstein_loss=False,
                using_gp=False,
            )
        elif config['model_type'] == 'wae':
            model = self.init_first_stage_model(
                config, 
                ckpt, 
                params=params,
                dataset=dataset,
                using_lincs=config['using_lincs'],
                using_wasserstein_loss=True,
                using_gp=True,
            )
        else: 
            raise NotImplementedError('first stage model type is not supported')
        
        self.first_stage_model = model.eval()
        self.first_stage_model.train = disabled_train    # overwrite model's train function with the empty disabled_train
        for param in self.first_stage_model.parameters():
            param.requires_grad = False
            
    def instantiate_cond_stage(self, config):
        # TODO: adapt this function
        if not self.cond_stage_trainable:
            if config == "__is_first_stage__":
                print("Using first stage also as cond stage.")
                self.cond_stage_model = self.first_stage_model
            elif config == "__is_unconditional__":
                print(f"Training {self.__class__.__name__} as an unconditional model.")
                self.cond_stage_model = None
                # self.be_unconditional = True
            else:
                model = instantiate_from_config(config)
                self.cond_stage_model = model.eval()
                self.cond_stage_model.train = disabled_train
                for param in self.cond_stage_model.parameters():
                    param.requires_grad = False
        else:
            assert config != '__is_first_stage__'
            assert config != '__is_unconditional__'
            model = instantiate_from_config(config)
            self.cond_stage_model = model
            
#     def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
#         # we may not need the denoising row?
#         denoise_row = []
#         for zd in tqdm(samples, desc=desc):
#             denoise_row.append(self.decode_first_stage(zd.to(self.device),
#                                                             force_not_quantize=force_no_decoder_quantization))
#         n_imgs_per_row = len(denoise_row)
#         denoise_row = torch.stack(denoise_row)  # n_log_step, n_row, C, H, W
#         denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
#         denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
#         denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
#         return denoise_grid

    def get_first_stage_encoding(self, encoder_posterior):
        # we don't have the diagonal gaussian dist
#         if isinstance(encoder_posterior, DiagonalGaussianDistribution):
#             z = encoder_posterior.sample()
        if isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
        return self.scale_factor * z
    
    def get_learned_conditioning(self, c):
        if self.cond_stage_forward is None:
            if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
                c = self.cond_stage_model.encode(c)
                if isinstance(c, DiagonalGaussianDistribution):
                    c = c.mode()
            else:
                c = self.cond_stage_model(c)
        else:
            # TODO: if our cond model is a GenericMLP, set self.cond_stage_forward as the forward function of self.cond_stage_model
            assert hasattr(self.cond_stage_model, self.cond_stage_forward)
            c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
        return c
    
    @torch.no_grad()
    def get_input(self, batch, return_first_stage_outputs=False, force_c_encode=False,
                  cond_key=None, return_original_cond=False, bs=None):
        # TODO: rewrite this function to get the mol input and the gene expr input
        # TODO: this function is called in shared_step (used in training/val_step for lightening model). adapt them as well
        '''
        this function is to get inputs for the diff model. 
        It should only return the latent repr to get diffused and the cond vector to control the diffusion
        
        output of this function: 
            z: latent repr of the first stage model
            partial_repr?? no need to have this
            c: output of the cond stage model (or it can be the gene expr directly??)
        '''
        
#         x = super().get_input(batch, k)
        x = batch
        if bs is not None:
            x = x[:bs]
        x = x.to(self.device)
        encoder_posterior, partial_reprs, node_reprs = self.encode_first_stage(x)
        z = self.get_first_stage_encoding(encoder_posterior).detach()
        self.partial_reprs = partial_reprs
        self.node_reprs = node_reprs

        if self.model.conditioning_key is not None:    # should be concat or crossattn
            if cond_key is None:
                cond_key = self.cond_stage_key

            if cond_key == 'gene_expressions':
                xc = torch.cat((batch[cond_key], batch['dose'].unsqueeze(-1)), dim=-1)
            else:
                xc = None
                raise NotImplementedError('condition key is not supported')
                
            # to train the cond model (currently let's input gene expr directly)
            '''
            if not self.cond_stage_trainable or force_c_encode:
                if isinstance(xc, dict) or isinstance(xc, list):
                    # import pudb; pudb.set_trace()
                    c = self.get_learned_conditioning(xc)
                else:
                    c = self.get_learned_conditioning(xc.to(self.device))
            else:
                c = xc
            '''
            c = xc
                
            if bs is not None:
                c = c[:bs]

        else:
            c = None
            xc = None
            
        # reshape z and c
#         print('original z size:', z.size())
        z = z.view((self.batch_size, -1, z.size(-1)))
#         print('after resizing', z.size())
        
        out = [z, c]
        if return_first_stage_outputs:
            xrec = self.decode_first_stage(z, batch)    # this shouldn't happen, decoder will return a lot of things
            out.extend([x, xrec])
        if return_original_cond:
            out.append(xc)
        return out
    
    @torch.no_grad()
    def decode_first_stage(self, z, batch, predict_cids=False, force_not_quantize=False):
        z = 1. / self.scale_factor * z
        (
            first_node_type_logits,
            node_type_logits,
            edge_candidate_logits,
            edge_type_logits,
            attachment_point_selection_logits,
        ) = self.first_stage_model.decoder(
            input_molecule_representations=z,
            graph_representations=self.partial_reprs,
            graphs_requiring_node_choices=batch.correct_node_type_choices_batch.unique(),
            # edge selection
            node_representations=self.node_reprs,
            num_graphs_in_batch=len(batch.ptr) - 1,
            focus_node_idx_in_batch=batch.focus_node,
            node_to_graph_map=batch.batch,
            candidate_edge_targets=batch.valid_edge_choices[:, 1].long(),
            candidate_edge_features=batch.edge_features,
            # attachment selection
            candidate_attachment_points=batch.valid_attachment_point_choices.long(),
        )
        return [first_node_type_logits, node_type_logits, edge_candidate_logits, edge_type_logits, attachment_point_selection_logits]
    
    @torch.no_grad()
    def encode_first_stage(self, batch):
        # other repr from partial encoder should also be done and passed to self, so that decoder can access to them
        input_molecule_representations = self.first_stage_model.full_graph_encoder(
            original_graph_node_categorical_features=batch.original_graph_node_categorical_features,
            node_features=batch.original_graph_x.float(),
            edge_index=batch.original_graph_edge_index,
            edge_features=batch.original_graph_edge_features,  # can be edge_type or edge_attr
            batch_index=batch.original_graph_x_batch,
        )
        partial_graph_representations, node_representations = self.first_stage_model.partial_graph_encoder(
            partial_graph_node_categorical_features=batch.partial_node_categorical_features,
            node_features=batch.x,
            edge_index=batch.edge_index.long(),
            edge_features=batch.partial_graph_edge_features,
            graph_to_focus_node_map=batch.focus_node,
            candidate_attachment_points=batch.valid_attachment_point_choices,
            batch_index=batch.batch,
        )
        return input_molecule_representations, partial_graph_representations, node_representations
    
    def shared_step(self, batch, **kwargs):
        print(batch)
        # skip a weird batch
        if batch['dose'].size(0) != 1000:
            raise ValueError('channel number is not 1000!')
            
        # pass the batch data to decoder via self
        self.batch = batch

        # x here is actually latent repr z. it's written as x to be consistent with the DDPM theory.
        x, c = self.get_input(batch)
        loss = self(x, c)
        return loss
    
    def forward(self, x, c, *args, **kwargs):
        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
        if self.model.conditioning_key is not None:
            assert c is not None
            if self.cond_stage_trainable:
                c = self.get_learned_conditioning(c)
            if self.shorten_cond_schedule:  # TODO: drop this option
                tc = self.cond_ids[t].to(self.device)
                c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
        return self.p_losses(x, c, t, *args, **kwargs)
    
    def apply_model(self, x_noisy, t, cond, return_ids=False):

        if isinstance(cond, dict):
            # hybrid case, cond is exptected to be a dict
            pass
        else:
            if not isinstance(cond, list):
                cond = [cond]
            key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
            cond = {key: cond}

        # pass the cond (gene expr) as c_concat in DiffusionWrapper model forward function
        # then x_noisy and cond will be concatenated and passed to U-Net model forward
        
        x_recon = self.model(x_noisy, t, **cond)

        if isinstance(x_recon, tuple) and not return_ids:
            return x_recon[0]
        else:
            return x_recon
        
    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
               extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
    
    def _prior_bpd(self, x_start):
        """
        Get the prior KL term for the variational lower-bound, measured in
        bits-per-dim.
        This term can't be optimized, as it only depends on the encoder.
        :param x_start: the [N x C x ...] tensor of inputs.
        :return: a batch of [N] KL values (in bits), one per batch element.
        """
        batch_size = x_start.shape[0]
        t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
        kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
        return mean_flat(kl_prior) / np.log(2.0)
    
    def p_losses(self, x_start, cond, t, noise=None):
#         print('x_start size:', x_start.size())
        noise = default(noise, lambda: torch.randn_like(x_start))
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
#         print('x noisy size:', x_noisy.size())
        model_output = self.apply_model(x_noisy, t, cond)

        loss_dict = {}
        prefix = 'train' if self.training else 'val'

        if self.parameterization == "x0":
            target = x_start
        elif self.parameterization == "eps":
            target = noise
        else:
            raise NotImplementedError()

#         print('size of output and target:', model_output.size(), target.size())
        loss_simple0 = self.get_loss(model_output, target, mean=False)
        # change the mean dim from [1, 2, 3] to [1, 2] as our latent repr is 1D
        loss_simple = loss_simple0.mean([1, 2])
        loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})

#         print('self device:', self.device)
#         print('original t device:', t.device)
#         t = t.to(self.device)
        
#         print('logvar device:', self.logvar.device)
        self.logvar = self.logvar.to(self.device)
        logvar_t = self.logvar[t].to(self.device)
        loss = loss_simple / torch.exp(logvar_t) + logvar_t
        # loss = loss_simple / torch.exp(self.logvar) + self.logvar
        if self.learn_logvar:
            loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
            loss_dict.update({'logvar': self.logvar.data.mean()})

        loss = self.l_simple_weight * loss.mean()

        # change the mean dim from [1, 2, 3] to [1, 2] as our latent repr is 1D
        loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2))
        loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
        loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
        loss += (self.original_elbo_weight * loss_vlb)
        loss_dict.update({f'{prefix}/loss': loss})

        return loss, loss_dict
    
    def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
                        return_x0=False, score_corrector=None, corrector_kwargs=None):
        t_in = t
        model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)

        if score_corrector is not None:
            assert self.parameterization == "eps"
            model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)

        if return_codebook_ids:
            model_out, logits = model_out

        if self.parameterization == "eps":
            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
        elif self.parameterization == "x0":
            x_recon = model_out
        else:
            raise NotImplementedError()

        if clip_denoised:
            x_recon.clamp_(-1., 1.)
        if quantize_denoised:
            x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        if return_codebook_ids:
            return model_mean, posterior_variance, posterior_log_variance, logits
        elif return_x0:
            return model_mean, posterior_variance, posterior_log_variance, x_recon
        else:
            return model_mean, posterior_variance, posterior_log_variance
        
    @torch.no_grad()
    def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
                 return_codebook_ids=False, quantize_denoised=False, return_x0=False,
                 temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
        b, *_, device = *x.shape, x.device
        outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
                                       return_codebook_ids=return_codebook_ids,
                                       quantize_denoised=quantize_denoised,
                                       return_x0=return_x0,
                                       score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
        if return_codebook_ids:
            raise DeprecationWarning("Support dropped.")
            model_mean, _, model_log_variance, logits = outputs
        elif return_x0:
            model_mean, _, model_log_variance, x0 = outputs
        else:
            model_mean, _, model_log_variance = outputs

        noise = noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))

        if return_codebook_ids:
            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
        if return_x0:
            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
        else:
            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
        
    @torch.no_grad()
    def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
                              img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
                              score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
                              log_every_t=None):
        if not log_every_t:
            log_every_t = self.log_every_t
        timesteps = self.num_timesteps
        if batch_size is not None:
            b = batch_size if batch_size is not None else shape[0]
            shape = [batch_size] + list(shape)
        else:
            b = batch_size = shape[0]
        
        # img is standard normal dist or the last timestep x_T
        if x_T is None:
            img = torch.randn(shape, device=self.device)
        else:
            img = x_T
        intermediates = []
        if cond is not None:
            if isinstance(cond, dict):
                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
                list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
            else:
                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]

        if start_T is not None:
            timesteps = min(timesteps, start_T)
        iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
                        total=timesteps) if verbose else reversed(
            range(0, timesteps))
        if type(temperature) == float:
            temperature = [temperature] * timesteps

        for i in iterator:
            ts = torch.full((b,), i, device=self.device, dtype=torch.long)
            if self.shorten_cond_schedule:
                assert self.model.conditioning_key != 'hybrid'
                tc = self.cond_ids[ts].to(cond.device)
                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))    # add noise to the cond?

            img, x0_partial = self.p_sample(img, cond, ts,
                                            clip_denoised=self.clip_denoised,
                                            quantize_denoised=quantize_denoised, return_x0=True,
                                            temperature=temperature[i], noise_dropout=noise_dropout,
                                            score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
            if mask is not None:
                assert x0 is not None
                img_orig = self.q_sample(x0, ts)
                img = img_orig * mask + (1. - mask) * img

            if i % log_every_t == 0 or i == timesteps - 1:
                intermediates.append(x0_partial)
            if callback: callback(i)
            if img_callback: img_callback(img, i)
        return img, intermediates
    
    @torch.no_grad()
    def p_sample_loop(self, cond, shape, return_intermediates=False,
                      x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, start_T=None,
                      log_every_t=None):

        if not log_every_t:
            log_every_t = self.log_every_t
        device = self.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        intermediates = [img]
        if timesteps is None:
            timesteps = self.num_timesteps

        if start_T is not None:
            timesteps = min(timesteps, start_T)
        iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
            range(0, timesteps))

        if mask is not None:
            assert x0 is not None
            assert x0.shape[2:3] == mask.shape[2:3]  # spatial size has to match

        for i in iterator:
            ts = torch.full((b,), i, device=device, dtype=torch.long)
            if self.shorten_cond_schedule:
                assert self.model.conditioning_key != 'hybrid'
                tc = self.cond_ids[ts].to(cond.device)
                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

            img = self.p_sample(img, cond, ts,
                                clip_denoised=self.clip_denoised,
                                quantize_denoised=quantize_denoised)
            if mask is not None:
                img_orig = self.q_sample(x0, ts)
                img = img_orig * mask + (1. - mask) * img

            if i % log_every_t == 0 or i == timesteps - 1:
                intermediates.append(img)
            if callback: callback(i)
            if img_callback: img_callback(img, i)

        if return_intermediates:
            return img, intermediates
        return img
    
    @torch.no_grad()
    def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
               verbose=True, timesteps=None, quantize_denoised=False,
               mask=None, x0=None, shape=None,**kwargs):
        if shape is None:
            shape = (batch_size, self.channels, self.image_size, self.image_size)
        if cond is not None:
            if isinstance(cond, dict):
                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
                list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
            else:
                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
        return self.p_sample_loop(cond,
                                  shape,
                                  return_intermediates=return_intermediates, x_T=x_T,
                                  verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
                                  mask=mask, x0=x0)

    @torch.no_grad()
    def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
        # we can call this function outside to get the sampled latent reprs
        # x_T can be none, it will be set as randn automatically
        # x0 can also none, it's only applicable when mask is not none. (guess mask is for some inpainting function?)
        if ddim:
            ddim_sampler = DDIMSampler(self)
            shape = (self.channels, self.image_size, self.image_size)
            samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
                                                        shape,cond,verbose=False,**kwargs)

        else:
            samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
                                                 return_intermediates=True,**kwargs)

        return samples, intermediates
    
    @torch.no_grad()
    def log_mol(self, batch):
        # adapt log_images to record the generated molecules
        return 
    
    def configure_optimizers(self):
        lr = self.learning_rate
        params = list(self.model.parameters())
        if self.cond_stage_trainable:
            print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
            params = params + list(self.cond_stage_model.parameters())
        if self.learn_logvar:
            print('Diffusion model optimizing logvar')
            params.append(self.logvar)
        opt = torch.optim.AdamW(params, lr=lr)
        if self.use_scheduler:
            assert 'target' in self.scheduler_config
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                }]
            return [opt], scheduler
        return opt

  rank_zero_deprecation(


# Direct changes to the LDM source code 

- ~~in `latent-diffusion/ldm/modules/diffusionmodules/util.py` normalization function, change GroupNorm from dividing channels into 32 groups to 1 group~~
- add `batch_size = self.batch_size` in each *log* and *log_dict* function in the DDPM module, to get rid of the lightning warning of infering batch size
- in attention.py-*SpatialTransformer*, change from 2D into 1D case. See `/data/conghao001/diffusion_model/latent-diffusion/ldm/modules/attention.py`

# Try the model

## Import data

In [8]:
import sys 
sys.path.append('../')

In [9]:
from dataset import LincsDataset
from torch_geometric.loader import DataLoader
from omegaconf import OmegaConf
from model_utils import get_params
from tqdm import tqdm
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import LearningRateMonitor
from datetime import datetime

2023-05-04 01:06:44.564462: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [10]:
config_file = 'config/ddim_vae_con.yml'
config = OmegaConf.load(config_file)
config

{'model': {'base_learning_rate': 5e-05, 'target': 'ldm.models.diffusion.ddpm.LatentDiffusion', 'params': {'linear_start': 0.00085, 'linear_end': 0.012, 'num_timesteps_cond': 1, 'log_every_t': 50, 'timesteps': 200, 'first_stage_key': 'image', 'cond_stage_key': 'gene_expressions', 'image_size': 512, 'channels': 1, 'cond_stage_trainable': False, 'conditioning_key': 'concat', 'monitor': 'val/loss_simple_ema', 'scale_factor': 1, 'use_ema': False, 'parameterization': 'eps'}, 'first_stage_config': {'target': 'model.BaseModel', 'model_type': 'vae', 'using_lincs': True, 'ckpt_path': '/data/conghao001/FYP/DrugDiscovery/first_stage_models/2023-03-11_23_33_36.921147/epoch=07-val_loss=0.60.ckpt'}, 'cond_stage_config': {'params': {'dim': 979, 'key': 'gene_expressions'}}, 'unet_config': {'target': 'ldm.modules.diffusionmodules.openaimodel.UNetModel', 'params': {'image_size': 512, 'in_channels': 1, 'out_channels': 1, 'model_channels': 64, 'dims': 1, 'attention_resolutions': [4, 2], 'num_res_blocks': 1

In [11]:
ldm_params = config['model']['params']
ldm_params

{'linear_start': 0.00085, 'linear_end': 0.012, 'num_timesteps_cond': 1, 'log_every_t': 50, 'timesteps': 200, 'first_stage_key': 'image', 'cond_stage_key': 'gene_expressions', 'image_size': 512, 'channels': 1, 'cond_stage_trainable': False, 'conditioning_key': 'concat', 'monitor': 'val/loss_simple_ema', 'scale_factor': 1, 'use_ema': False, 'parameterization': 'eps'}

*TODO:* write this into config as well

In [12]:
# args
batch_size = 1
NUM_WORKERS = 4
train_split1 = "train_0"
valid_split = "valid_0"

layer_type = "FiLMConv"
model_architecture = 'vae'
gradient_clip_val = 1.0
max_lr = 1e-5
gen_step_drop_probability = 0.5
use_oclr_scheduler = True
using_cyclical_anneal = False
use_clamp_log_var = False

raw_moler_trace_dataset_parent_folder = "/data/ongh0068/guacamol/trace_dir"
# raw_moler_trace_dataset_parent_folder = "/data/ongh0068/l1000/TRACE_DIR"
output_pyg_trace_dataset_parent_folder = (
    "/data/ongh0068/l1000/already_batched"
)

In [13]:
train_dataset = LincsDataset(
    root="/data/ongh0068",
    raw_moler_trace_dataset_parent_folder=raw_moler_trace_dataset_parent_folder,  # "/data/ongh0068/l1000/trace_playground",
    output_pyg_trace_dataset_parent_folder=output_pyg_trace_dataset_parent_folder,
    gene_exp_controls_file_path="/data/ongh0068/l1000/lincs/robust_normalized_controls.npz",
    gene_exp_tumour_file_path="/data/ongh0068/l1000/lincs/robust_normalized_tumors.npz",
    lincs_csv_file_path="/data/ongh0068/l1000/lincs/experiments_filtered.csv",
    split=train_split1,
    gen_step_drop_probability=gen_step_drop_probability,
)
train_dataset

Loading controls gene expression...
Loading tumour gene expression...
Loading csv...


LincsDataset(794)

In [14]:
train_dataset[791]

MolerDataBatch(x=[3796, 59], edge_index=[2, 7642], original_graph_edge_features=[14754], original_graph_node_categorical_features=[6757], focus_node=[253], partial_graph_edge_features=[7642], edge_features=[1597, 3], correct_edge_choices=[1597], correct_edge_choices_batch=[1597], correct_edge_choices_ptr=[254], num_correct_edge_choices=[253], stop_node_label=[253], valid_edge_choices=[1597, 2], valid_edge_choices_batch=[1597], valid_edge_choices_ptr=[254], valid_edge_types=[131, 3], correct_edge_types=[131, 3], correct_edge_types_batch=[131], correct_edge_types_ptr=[254], partial_node_categorical_features=[3796], correct_attachment_point_choice=[17], correct_attachment_point_choice_batch=[17], correct_attachment_point_choice_ptr=[254], correct_node_type_choices=[116, 166], correct_node_type_choices_batch=[116], correct_node_type_choices_ptr=[254], correct_first_node_type_choices=[253, 166], correct_first_node_type_choices_batch=[253], correct_first_node_type_choices_ptr=[254], sa_score

In [17]:
# can we remove this data point?
mask = np.ones(len(train_dataset), dtype=bool)
mask[791] = False
mask

array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,

In [18]:
train_dataset = train_dataset[mask]
train_dataset

LincsDataset(793)

In [19]:
len(train_dataset)

793

In [16]:
train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=False,
        # sampler=train_sampler,
        follow_batch=[
            "correct_edge_choices",
            "correct_edge_types",
            "valid_edge_choices",
            "valid_attachment_point_choices",
            "correct_attachment_point_choice",
            "correct_node_type_choices",
            "original_graph_x",
            "correct_first_node_type_choices",
        ],
        num_workers=NUM_WORKERS,
        # prefetch_factor=0,
    )

In [19]:
for id, batch in enumerate(train_dataloader):
    if batch.dose.size(0) == 1000:
        print(id)
    else:
        print(batch.dose.size())

torch.Size([472])
torch.Size([493])
torch.Size([504])
torch.Size([512])
torch.Size([499])
torch.Size([495])
torch.Size([481])
torch.Size([509])
torch.Size([465])
torch.Size([492])
torch.Size([499])
torch.Size([499])
torch.Size([529])
torch.Size([500])
torch.Size([517])
torch.Size([508])
torch.Size([496])
torch.Size([502])
torch.Size([477])
torch.Size([524])
torch.Size([499])
torch.Size([521])
torch.Size([486])
torch.Size([530])
torch.Size([502])
torch.Size([485])
torch.Size([485])
torch.Size([517])
torch.Size([499])
torch.Size([507])
torch.Size([466])
torch.Size([499])
torch.Size([506])
torch.Size([484])
torch.Size([487])
torch.Size([495])
torch.Size([519])
torch.Size([527])
torch.Size([471])
torch.Size([483])
torch.Size([502])
torch.Size([478])
torch.Size([516])
torch.Size([492])
torch.Size([478])
torch.Size([482])
torch.Size([509])
torch.Size([503])
torch.Size([507])
torch.Size([483])
torch.Size([495])
torch.Size([532])
torch.Size([488])
torch.Size([495])
torch.Size([479])
torch.Size

torch.Size([498])
torch.Size([486])
torch.Size([518])
torch.Size([517])
torch.Size([513])
torch.Size([501])
torch.Size([514])
torch.Size([496])
torch.Size([491])
torch.Size([503])
torch.Size([516])
torch.Size([515])
torch.Size([500])
torch.Size([516])
torch.Size([484])
torch.Size([482])
torch.Size([536])
torch.Size([491])
torch.Size([495])
torch.Size([510])
torch.Size([492])
torch.Size([497])
torch.Size([498])
torch.Size([494])
torch.Size([506])
torch.Size([506])
torch.Size([495])
torch.Size([511])
torch.Size([529])
torch.Size([483])
torch.Size([483])
torch.Size([501])
torch.Size([496])
torch.Size([496])
torch.Size([534])
torch.Size([503])
torch.Size([506])
torch.Size([529])
torch.Size([486])
torch.Size([516])
torch.Size([479])
torch.Size([508])
torch.Size([492])
torch.Size([522])
torch.Size([533])
torch.Size([504])
torch.Size([511])
torch.Size([484])
torch.Size([496])
torch.Size([509])
torch.Size([460])
torch.Size([493])
torch.Size([487])
torch.Size([522])
torch.Size([491])
torch.Size

In [20]:
valid_dataset = LincsDataset(
    root="/data/ongh0068",
    raw_moler_trace_dataset_parent_folder=raw_moler_trace_dataset_parent_folder,  # "/data/ongh0068/l1000/trace_playground",
    output_pyg_trace_dataset_parent_folder=output_pyg_trace_dataset_parent_folder,
    gene_exp_controls_file_path="/data/ongh0068/l1000/lincs/robust_normalized_controls.npz",
    gene_exp_tumour_file_path="/data/ongh0068/l1000/lincs/robust_normalized_tumors.npz",
    lincs_csv_file_path="/data/ongh0068/l1000/lincs/experiments_filtered.csv",
    split=valid_split,
    gen_step_drop_probability=gen_step_drop_probability,

)

Loading controls gene expression...
Loading tumour gene expression...
Loading csv...


In [22]:
# use the same mask to filter val dataset
valid_dataset = valid_dataset[mask]
len(valid_dataset)

793

In [21]:
valid_dataloader = DataLoader(
        valid_dataset,
        batch_size=batch_size,
        shuffle=False,
        # sampler=valid_sampler,
        follow_batch=[
            "correct_edge_choices",
            "correct_edge_types",
            "valid_edge_choices",
            "valid_attachment_point_choices",
            "correct_attachment_point_choice",
            "correct_node_type_choices",
            "original_graph_x",
            "correct_first_node_type_choices",
        ],
        num_workers=NUM_WORKERS,
        # prefetch_factor=0,
    )

In [22]:
for id, batch in enumerate(valid_dataloader):
    if batch.dose.size(0) == 1000:
        print(id)
    else:
        print(batch.dose.size())

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
27

In [24]:
for batch_id, batch in enumerate(valid_dataloader):
    if batch['dose'].size(0) == 1000:
        
        continue
    else:
        print(batch_id)
        print(batch)

In [25]:
first_stage_params = get_params(train_dataset)
first_stage_params

{'full_graph_encoder': {'input_feature_dim': 59,
  'atom_or_motif_vocab_size': 166,
  'aggr_layer_type': 'MoLeRAggregation',
  'total_num_moler_aggr_heads': 32,
  'layer_type': 'FiLMConv'},
 'partial_graph_encoder': {'input_feature_dim': 59,
  'atom_or_motif_vocab_size': 166,
  'aggr_layer_type': 'MoLeRAggregation',
  'total_num_moler_aggr_heads': 16,
  'layer_type': 'FiLMConv'},
 'mean_log_var_mlp': {'input_feature_dim': 832,
  'output_size': 1024,
  'hidden_layer_dims': [],
  'use_bias': False},
 'decoder': {'node_type_selector': {'input_feature_dim': 1344,
   'output_size': 167},
  'use_node_type_loss_weights': True,
  'node_type_loss_weights': tensor([10.0000,  0.1000,  3.6015,  0.1000,  0.1000,  0.4439,  0.7549,  0.4416,
          10.0000,  2.7939,  3.3916, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000,
          10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000,
          10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000,
          10.0000,

In [26]:
first_stage_config = config['model']['first_stage_config']
first_stage_config

{'target': 'model.BaseModel', 'model_type': 'vae', 'using_lincs': True, 'ckpt_path': '/data/conghao001/FYP/DrugDiscovery/first_stage_models/2023-03-11_23_33_36.921147/epoch=07-val_loss=0.60.ckpt'}

In [27]:
ldm_model = LatentDiffusion(
    first_stage_config,
    config['model']['cond_stage_config'],
    train_dataset, 
    batch_size,
    first_stage_params,
    first_stage_config['ckpt_path'],
    unet_config = config['model']['unet_config'],
    **ldm_params
)
ldm_model

LatentDiffusion: Running in x0-prediction mode
DiffusionWrapper has 1303.74 M params.


LatentDiffusion(
  (model): DiffusionWrapper(
    (diffusion_model): UNetModel(
      (time_embed): Sequential(
        (0): Linear(in_features=1024, out_features=4096, bias=True)
        (1): SiLU()
        (2): Linear(in_features=4096, out_features=4096, bias=True)
      )
      (input_blocks): ModuleList(
        (0): TimestepEmbedSequential(
          (0): Conv1d(1000, 1024, kernel_size=(3,), stride=(1,), padding=(1,))
        )
        (1): TimestepEmbedSequential(
          (0): ResBlock(
            (in_layers): Sequential(
              (0): GroupNorm32(32, 1024, eps=1e-05, affine=True)
              (1): SiLU()
              (2): Conv1d(1024, 1024, kernel_size=(3,), stride=(1,), padding=(1,))
            )
            (h_upd): Identity()
            (x_upd): Identity()
            (emb_layers): Sequential(
              (0): SiLU()
              (1): Linear(in_features=4096, out_features=1024, bias=True)
            )
            (out_layers): Sequential(
              (0): Gr

In [28]:
lr = config.model.base_learning_rate
ldm_model.learning_rate = lr

In [29]:
device = torch.device('cuda:1')
device

device(type='cuda', index=1)

### pytorch lightning config

In [30]:
# Get current time for folder path.
now = str(datetime.now()).replace(" ", "_").replace(":", "_")

# Callbacks
lr_monitor = LearningRateMonitor(logging_interval="step")
tensorboard_logger = TensorBoardLogger(save_dir=f"lightning_logs/{now}", name=f"logs_{now}")
early_stopping = EarlyStopping(monitor="val/loss", patience=3)
if model_architecture == "vae":
    checkpoint_callback = ModelCheckpoint(
        save_top_k=1,
        monitor="val/loss",
        dirpath=f"lightning_logs/{now}",
        mode="min",
        filename="{epoch:02d}-{val_loss:.2f}",
    )
elif model_architecture == "aae":
    checkpoint_callback = ModelCheckpoint(
        dirpath=f"lightning_logs/{now}",
        filename="{epoch:02d}-{train_loss:.2f}",
        monitor="epoch",
        every_n_epochs=3,
        save_on_train_epoch_end=True,
        save_top_k=-1,
    )

callbacks = (
    [checkpoint_callback, lr_monitor, early_stopping]
    if model_architecture == "vae"
    else [checkpoint_callback, lr_monitor]
)

In [31]:
trainer = Trainer(accelerator='gpu', max_epochs=1, devices=[1], callbacks=callbacks, logger=tensorboard_logger)
trainer.fit(ldm_model, train_dataloaders=train_dataloader, val_dataloaders=valid_dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: lightning_logs/2023-05-03_18_49_58.215810/logs_2023-05-03_18_49_58.215810
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name              | Type             | Params
-------------------------------------------------------
0 | model             | DiffusionWrapper | 1.3 B 
1 | first_stage_model | BaseModel        | 6.1 M 
-------------------------------------------------------
1.3 B     Trainable params
6.1 M     Non-trainable params
1.3 B     Total params
5,239.375 Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


In [30]:
import importlib
import ldm.modules.diffusionmodules.openaimodel
from ldm.modules.diffusionmodules.openaimodel import UNetModel
from ldm.modules.diffusionmodules.openaimodel import TimestepEmbedSequential

string = 'ldm.modules.diffusionmodules.openaimodel.UNetModel'
pkg, cls = string.rsplit(".", 1)

module_imp = importlib.import_module(pkg, package=None)
unet_module = getattr(module_imp, cls)
unet_module

# importlib.import_module('UNetModel')

# importlib.reload(unet_module)

ldm.modules.diffusionmodules.openaimodel.UNetModel

In [31]:
%load_ext autoreload
%autoreload 2