In [1]:
%load_ext autoreload
%autoreload 2

import os

import torch
from textual_inversion import TextualInversionWrapper, CLIPTextConfig

In [2]:
# Deliberately tiny text-encoder config (one layer, 32-dim hidden state,
# four heads) so the sanity checks below run quickly.
placeholder_token_id, initializer_token_id = 2, 3
config = CLIPTextConfig(hidden_size=32, num_attention_heads=4, num_hidden_layers=1)
model = TextualInversionWrapper(config, placeholder_token_id, initializer_token_id)

In [3]:
# Initialise the concept embeddings; the next cell asserts they match the
# input-embedding row of `initializer_token_id`.
model.init_concept_embeddings()

In [4]:
# After initialisation the concept embedding should match the input-embedding
# row of the initializer token.
initializer_row = model.get_input_embeddings().weight.data[initializer_token_id]
assert torch.allclose(model.concept_embeddings, initializer_row)

In [5]:
# Freezing must turn off gradients on the token-embedding matrix while the
# concept embedding keeps requiring gradients.
model.freeze_text_model()
token_embedding_weight = model.get_input_embeddings().weight
assert not token_embedding_weight.requires_grad
assert model.concept_embeddings.requires_grad

In [6]:
# Smoke-test the forward pass on one sequence of token ids 1..5 (which
# includes the placeholder id 2), without tracking gradients.
input_ids = torch.arange(1, 6, dtype=torch.long).unsqueeze(0)
with torch.no_grad():
    outputs = model(input_ids)
    out = outputs[0]
out.size()

torch.Size([1, 5, 32])

In [8]:

# Merge the concept embedding into the token-embedding matrix and verify it.
model.merge_concept_embeddings_in_embeddings()
merged_weight = model.get_input_embeddings().weight
# The embedding matrix must stay frozen after the merge.
assert not merged_weight.requires_grad
merged_row = merged_weight.data[placeholder_token_id]
# No training has happened yet, so the merged row still equals the
# initializer embedding ...
assert torch.allclose(merged_row, model.get_initializer_embeddings().data)
# ... but the property the merge is actually responsible for is that the
# placeholder row now holds the concept embedding, so check that directly
# (the same comparison is made on the reloaded model further down).
assert torch.allclose(merged_row, model.get_concept_embeddings().data)

In [9]:
# Save the wrapper under "dummy-clip-text"; a later cell reloads this
# checkpoint as a plain CLIPTextModel.
model.save_pretrained("dummy-clip-text")

In [10]:
from transformers import CLIPTextModel

In [11]:
# Reload the saved checkpoint as a plain CLIPTextModel. The checkpoint's extra
# 'concept_embeddings' entry is reported as unused in the warning below.
text_model = CLIPTextModel.from_pretrained("dummy-clip-text")

Some weights of the model checkpoint at dummy-clip-text were not used when initializing CLIPTextModel: ['concept_embeddings']
- This IS expected if you are initializing CLIPTextModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing CLIPTextModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [14]:
# The reloaded model's placeholder row must agree with both the initializer
# embedding and the concept embedding held by the wrapper.
reloaded_row = text_model.get_input_embeddings().weight.data[placeholder_token_id]
assert torch.allclose(reloaded_row, model.get_initializer_embeddings().data)
assert torch.allclose(reloaded_row, model.get_concept_embeddings().data)

In [26]:
# Squared distance between the concept and initializer embeddings, computed
# with both vectors lifted to row matrices so the result is a 1x1 matrix.
optimized = model.get_concept_embeddings()[None]
coarse = model.get_initializer_embeddings().clone().to(optimized.device)[None]
delta = optimized - coarse
coarse_loss = delta @ delta.t() / 1
coarse_loss

tensor([[0.]], grad_fn=<DivBackward0>)

In [27]:
# Same squared distance as the previous cell, but on the raw vectors so the
# result is a scalar.
optimized = model.get_concept_embeddings()
coarse = model.get_initializer_embeddings().clone().to(optimized.device)
delta = optimized - coarse
# `.T` on a 1-D tensor is deprecated in PyTorch (it only has a defined meaning
# for 2-D tensors); an explicit dot product gives the same scalar without
# relying on that behaviour.
coarse_loss = torch.dot(delta, delta) / 1
coarse_loss

tensor(0., grad_fn=<DivBackward0>)

In [28]:
# Toy check of the scalar loss formula on plain 1-D tensors: every component
# differs by 4, so the expected value is 4 * 4**2 = 64.
optimized = torch.tensor([1, 2, 3, 4])
coarse = torch.tensor([5, 6, 7, 8])
delta = optimized - coarse
# `.T` on a 1-D tensor is deprecated in PyTorch; use an explicit dot product,
# which yields the same 0-dim result.
coarse_loss = torch.dot(delta, delta) / 1
coarse_loss

tensor(64.)

In [29]:
# Toy check of the matrix form of the loss: with (1, 4) row vectors the
# product with the transpose is a 1x1 matrix holding the squared distance.
optimized = torch.tensor([[1, 2, 3, 4]])
coarse = torch.tensor([[5, 6, 7, 8]])
delta = optimized - coarse
coarse_loss = delta @ delta.t() / 1
coarse_loss

tensor([[64.]])

### Data

In [34]:
from textual_inversion import PersonalizedBase, CLIPTokenizer

In [35]:
# Load the tokenizer for the "openai/clip-vit-base-patch32" checkpoint; it is
# passed to the dataset below.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

In [48]:
# Location of the training images. The original absolute path only exists on
# the author's machine, so allow overriding it through the
# STYLEPICS_DATA_ROOT environment variable; the original path is kept as the
# default so existing behaviour is unchanged when the variable is unset.
data_root = os.environ.get(
    "STYLEPICS_DATA_ROOT", "/Users/surajpatil/projects/stylepics"
)
# "style" and size=512 are passed through unchanged; the recorded sample
# shape below is (3, 512, 512).
dataset = PersonalizedBase(data_root, tokenizer, "style", size=512)

In [50]:
# Fetch the first sample; the next cell reads its "pixel_values" entry.
example = dataset[0]

In [51]:
# Shape of the first sample's image tensor: (3, 512, 512) in the recorded
# output, consistent with size=512 passed to the dataset.
example["pixel_values"].shape

torch.Size([3, 512, 512])