In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from functools import partial
import torch

In [3]:

def unflatten_elem_subf(x, f=None, numargs=1):
    """Consume the first ``numargs`` items of ``x`` and rebuild one element from them.

    Args:
        x: flat sequence of leaves that still has to be consumed.
        f: unflatten function for this element; called with the consumed slice.
        numargs: how many leading items of ``x`` belong to this element.

    Returns:
        ``(rebuilt_element, remaining_items)``.
    """
    args, rest = x[:numargs], x[numargs:]
    return f(args), rest


# Non-tensor leaves that commonly appear in checkpointed call signatures (flags,
# sizes, optional ``None`` masks). They contribute no items to the flat list and
# are returned unchanged on reconstruction.
_CONSTANT_LEAF_TYPES = (bool, int, float, complex, str, bytes)


def _flatten_children(children):
    """Flatten each child in order; return all leaves plus one slice-consumer per child."""
    flat_out = []
    consumers = []
    for child in children:
        flat_child, unflatten_child = flatten_inputs(child)
        # Each consumer knows how many leaves its child owns, so rebuilding can
        # walk the flat list left to right.
        consumers.append(partial(unflatten_elem_subf, f=unflatten_child, numargs=len(flat_child)))
        flat_out += flat_child
    return flat_out, consumers


def flatten_inputs(inputs):
    """Flatten a nested structure of tensors into a list plus a function that rebuilds it.

    Supported inputs:
      * ``torch.Tensor`` -> a single leaf.
      * ``None`` and plain Python scalars / strings / bytes -> no leaves; the value is
        returned unchanged on rebuild.
      * ``list`` / ``tuple`` -> children flattened in order; rebuilt as ``list`` or plain
        ``tuple`` (tuple subclasses such as namedtuples come back as plain tuples).
      * ``dict`` -> values flattened in iteration order; rebuilt as a plain ``dict`` with
        the same keys in the same order.
      * objects with a ``flatten_inputs_for_gradient_checkpoint()`` method -> delegated to
        that method, which must return ``(flat_list, unflatten_fn)``.

    Returns:
        ``(flat, unflatten)``: ``flat`` is the list of tensor leaves (depth-first, left to
        right); ``unflatten(new_flat)`` rebuilds the structure, substituting the items of
        ``new_flat`` for the leaves in the same order.

    Raises:
        TypeError: if ``inputs`` (or any nested element) has an unsupported type.
    """
    if isinstance(inputs, torch.Tensor):
        return [inputs], lambda x: x[0]

    if inputs is None or isinstance(inputs, _CONSTANT_LEAF_TYPES):
        # Constants are captured in the closure rather than routed through the flat list.
        return [], lambda x: inputs

    if isinstance(inputs, (list, tuple)):
        istuple = isinstance(inputs, tuple)
        flat_out, consumers = _flatten_children(inputs)

        def unflatten_list(x):
            ret = []
            for consume in consumers:
                elem, x = consume(x)
                ret.append(elem)
            return tuple(ret) if istuple else ret

        return flat_out, unflatten_list

    if isinstance(inputs, dict):
        keys = list(inputs.keys())
        flat_out, consumers = _flatten_children(inputs.values())

        def unflatten_dict(x):
            ret = {}
            for k, consume in zip(keys, consumers):
                ret[k], x = consume(x)
            return ret

        return flat_out, unflatten_dict

    if hasattr(inputs, "flatten_inputs_for_gradient_checkpoint"):
        return inputs.flatten_inputs_for_gradient_checkpoint()

    raise TypeError(
        f"flatten_inputs: unsupported input type {type(inputs).__name__!r}; expected a "
        "torch.Tensor, None, a Python scalar/string, list, tuple, dict, or an object "
        "implementing flatten_inputs_for_gradient_checkpoint()"
    )

In [4]:
class CustomTextConditioning():
    """Text embeddings plus per-token metadata, with a hook that biases cross-attention scores.

    The cross-attention module passes its raw scores to :meth:`cross_attention_control`.
    The additive bias comes from ``cross_attn_masks``, which must be assigned after
    construction, scaled by ``strength``.
    """
    def __init__(self, embs, layer_ids=None, token_ids=None, global_prompt_mask=None, global_bos_eos_mask=None):
        """
        embs:       (batsize, seqlen, embdim)
        layer_ids:  (batsize, seqlen) integers, with 0 for no-layer global tokens
        token_ids:  (batsize, seqlen) integers for tokens from tokenizer
        global_prompt_mask:  (batsize, seqlen) bool that is 1 where the global prompt is and 0 where the local regional prompts are
        global_bos_eos_mask: (batsize, seqlen) bool that is 1 where the global bos and eos tokens are and 0 elsewhere
        """
        self.embs = embs
        # Cached device of embs; refreshed when embs is replaced via the checkpoint recon.
        self.device = self.embs.device
        self.layer_ids = layer_ids
        self.token_ids = token_ids
        self.global_prompt_mask = global_prompt_mask
        self.global_bos_eos_mask = global_bos_eos_mask
        # Indexed by query length in weight_func; None until the caller assigns it.
        self.cross_attn_masks = None
        # NOTE(review): progress, threshold, softness, controlonly and controlledonly are not
        # read anywhere in this class; presumably consumed by other code — verify.
        self.progress = None
        self.strength = 10      # multiplier applied to the mask bias in weight_func
        self.threshold = None
        self.softness = 0.2
        self.controlonly = False
        self.controlledonly = False

    def cross_attention_control(self, sim, numheads=1):
        """Takes the unscaled, unnormalized attention scores computed by the cross-attention
        module and returns adapted attention scores.

        Args:
            sim: attention scores of shape (batsize * numheads, querylen, keylen), with heads
                folded into the leading dimension so that row ``b * numheads + h`` belongs to
                example ``b`` and head ``h``.
            numheads: number of heads folded into the leading dimension of ``sim``.

        Returns:
            ``sim`` plus the per-example bias from :meth:`weight_func`, shared by every head.
        """
        # assumes weight_func returns (batsize, querylen, keylen) — TODO confirm mask shape
        wf = self.weight_func(sim)
        querylen, keylen = sim.shape[-2], sim.shape[-1]
        # Broadcast the bias over a head axis instead of materializing a per-head copy
        # with .repeat(); row (b * numheads + h) still receives wf[b].
        sim = sim.reshape(-1, numheads, querylen, keylen) + wf[:, None]
        return sim.reshape(-1, querylen, keylen)

    def weight_func(self, sim):
        """Return the additive attention bias for scores ``sim``.

        The mask is looked up in ``self.cross_attn_masks`` by ``sim.shape[1]`` (the query
        length), cast to ``sim``'s dtype and device, and scaled by ``sim.std()`` times
        ``self.strength``.

        Raises:
            RuntimeError: if ``cross_attn_masks`` has not been assigned.
        """
        if self.cross_attn_masks is None:
            raise RuntimeError(
                "CustomTextConditioning.cross_attn_masks must be set before cross-attention control is used"
            )
        mask = self.cross_attn_masks[sim.shape[1]].to(device=sim.device, dtype=sim.dtype)
        return mask * sim.std() * self.strength

    def flatten_inputs_for_gradient_checkpoint(self):
        """Expose ``embs`` as this object's only tensor leaf for gradient checkpointing.

        Returns:
            ``([self.embs], recon_f)``. ``recon_f(x)`` installs ``x[0]`` as ``self.embs``
            on this same instance (mutating it in place, so object identity is preserved)
            and returns ``self``. Other tensor attributes (ids, masks) are not flattened.
        """
        def recon_f(x: list):
            self.embs = x[0]
            # Keep the cached device consistent with the tensor just installed.
            self.device = self.embs.device
            return self
        return [self.embs], recon_f

In [5]:
# Nested example mixing tensors, lists, dicts and a CustomTextConditioning
# (which flattens via its flatten_inputs_for_gradient_checkpoint method).
x = [
    CustomTextConditioning(torch.tensor([1.00])),
    {"b": [[torch.tensor([2])]]},
    [torch.tensor([3]), torch.tensor([4]), torch.tensor([5]), [torch.tensor([6]), torch.tensor([7])]],
    {"a": torch.tensor([8])},
]
x


[<__main__.CustomTextConditioning object at 0x7fb1b71da640>, {'b': [[tensor([2])]]}, [tensor([3]), tensor([4]), tensor([5]), [tensor([6]), tensor([7])]], {'a': tensor([8])}]


In [6]:
# Flatten the nested example: tensor leaves come out depth-first, left to right,
# and the CustomTextConditioning contributes its `embs` tensor as one leaf.
flat_x, unflatten_x = flatten_inputs(x)
flat_x

[tensor([1.]),
 tensor([2]),
 tensor([3]),
 tensor([4]),
 tensor([5]),
 tensor([6]),
 tensor([7]),
 tensor([8])]

In [7]:
# Rebuild the nested structure from the flat leaves. The CustomTextConditioning comes
# back as the same object (same address), because its recon_f sets embs and returns self.
unflatten_x(flat_x)

[<__main__.CustomTextConditioning at 0x7fb1b71da640>,
 {'b': [[tensor([2])]]},
 [tensor([3]), tensor([4]), tensor([5]), [tensor([6]), tensor([7])]],
 {'a': tensor([8])}]

In [8]:
# The original structure still shows the same objects: the round trip above
# reinstalled the very same leaves that flatten_inputs produced.
x

[<__main__.CustomTextConditioning at 0x7fb1b71da640>,
 {'b': [[tensor([2])]]},
 [tensor([3]), tensor([4]), tensor([5]), [tensor([6]), tensor([7])]],
 {'a': tensor([8])}]

In [9]:
# `embs` is still the original tensor, since `flat_x` held the original leaves.
x[0].embs

tensor([1.])

In [10]:
# A plain tuple of tensors. It gets its own name instead of rebinding `x`, which the
# cells above use for the nested list example, so re-running those cells stays valid.
tensor_pair = (torch.zeros(5, 3), torch.ones(5))
flat_x, unflatten_x = flatten_inputs(tensor_pair)

In [11]:
# Both tensors of the tuple appear as leaves, in order.
flat_x

[tensor([[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]),
 tensor([1., 1., 1., 1., 1.])]

In [12]:
# The rebuilt value is a tuple again: the container type is preserved on unflatten.
unflatten_x(flat_x)

(tensor([[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]),
 tensor([1., 1., 1., 1., 1.]))