add an open clip wrapper that reduces even more work
lucidrains committed Sep 3, 2023
1 parent 2a6b92c commit be4f634
Showing 4 changed files with 65 additions and 1 deletion.
23 changes: 23 additions & 0 deletions README.md
@@ -103,6 +103,29 @@ embeds_with_new_concept, embeds_with_superclass, embed_mask, concept_indices = w
# the embed_mask needs to be passed to the cross attention as key padding mask
```
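
For illustration only, here is a standalone sketch (not part of this repository) of how such a key padding mask is typically consumed, using dummy tensors with arbitrary example shapes and a vanilla `nn.MultiheadAttention` in place of the diffusion model's cross attention

```python
import torch
from torch import nn

# dummy stand-ins for the wrapper outputs, with arbitrary example shapes
text_embeds = torch.randn(3, 77, 512)
embed_mask = torch.ones(3, 77, dtype = torch.bool)   # True wherever a real token is present

# a plain multihead attention as a stand-in for the unet's cross attention
cross_attn = nn.MultiheadAttention(embed_dim = 512, num_heads = 8, batch_first = True)

image_queries = torch.randn(3, 64, 512)              # stand-in for the spatial queries

# torch expects True at positions that should be ignored, so the mask is inverted
attended, _ = cross_attn(
    image_queries,
    text_embeds,
    text_embeds,
    key_padding_mask = ~embed_mask
)
```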

If you can identify the `CLIP` instance within the stable diffusion instance, you can also pass it directly to the `OpenClipEmbedWrapper`, which returns everything needed for the cross attention layers on forward.

ex.

```python
from perfusion_pytorch import OpenClipEmbedWrapper

texts = [
    'a portrait of dog',
    'dog running through a green field',
    'a man walking his dog'
]

wrapped_clip_with_new_concept = OpenClipEmbedWrapper(
    text_encoder.clip,
    superclass_string = 'dog'
)

enc, superclass_enc, mask, indices = wrapped_clip_with_new_concept(texts)

# (3, 77, 512), (3, 77, 512), (3, 77), (3,)
```
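
If you are not starting from a stable diffusion checkpoint, the wrapper only expects an `open_clip` `CLIP` instance, so a model loaded directly with `open_clip` should also work. A minimal sketch, with the model name and pretrained tag chosen arbitrarily for illustration

```python
import open_clip
from perfusion_pytorch import OpenClipEmbedWrapper

# any open_clip CLIP exposes token_embedding, transformer and ln_final
clip, _, _ = open_clip.create_model_and_transforms('ViT-B-32', pretrained = 'laion2b_s34b_b79k')

wrapped_clip_with_new_concept = OpenClipEmbedWrapper(
    clip,
    superclass_string = 'dog'
)

# it can then be called on raw texts just as in the example above
```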

## Todo

- [ ] wire up with SD 1.5, starting with xiao's dreambooth-sd
1 change: 1 addition & 0 deletions perfusion_pytorch/__init__.py
@@ -8,6 +8,7 @@

from perfusion_pytorch.embedding import (
    EmbeddingWrapper,
    OpenClipEmbedWrapper,
    merge_embedding_wrappers
)

40 changes: 40 additions & 0 deletions perfusion_pytorch/embedding.py
@@ -221,6 +221,46 @@ def forward(

        return EmbeddingReturn(embeds, superclass_embeds, embed_mask, concept_indices)

# a wrapper for open clip's CLIP
# that automatically wraps its token embedding with the new concept
# on forward, the concept embeddings and the superclass concept embeddings are passed through the text transformer + final layernorm
# in effect, the ids and superclass ids make two passes through the modified text encoder (it will attempt to substitute the nn.Embedding with an nn.Identity)

from open_clip import CLIP

class OpenClipEmbedWrapper(Module):
    @beartype
    def __init__(
        self,
        clip: CLIP,
        **embedding_wrapper_kwargs
    ):
        super().__init__()
        self.wrapped_embed = EmbeddingWrapper(clip.token_embedding, **embedding_wrapper_kwargs)

        self.text_transformer = nn.Sequential(
            clip.transformer,
            clip.ln_final
        )

    def forward(
        self,
        x,
        **kwargs
    ) -> EmbeddingReturn:
        # embed the ids, getting back both the concept and superclass token embeddings
        text_embeds, superclass_text_embeds, text_mask, concept_indices = self.wrapped_embed(x, **kwargs)

        # pass the concept embeddings through the text transformer + final layernorm
        text_enc = self.text_transformer(text_embeds)

        # do the same for the superclass embeddings, if present
        superclass_text_enc = None

        if exists(superclass_text_embeds):
            superclass_text_enc = self.text_transformer(superclass_text_embeds)

        return EmbeddingReturn(text_enc, superclass_text_enc, text_mask, concept_indices)

# merging multiple embedding wrappers (each with one concept) into a merged embedding wrapper with multiple concepts

@beartype
def merge_embedding_wrappers(
    *embeds: EmbeddingWrapper
2 changes: 1 addition & 1 deletion setup.py
@@ -19,7 +19,7 @@
install_requires=[
    'beartype',
    'einops>=0.6.1',
    'open-clip-torch>=2.0.0,<3.0.0',
    'open-clip-torch',
    'opt-einsum',
    'torch>=2.0'
],