# Define CLIPTextDeprojector

In [None]:
# Working directory for everything this notebook produces: the LAION shard and texts.csv,
# init_model/ (converted CLIP weights) and one model_part_XX/ checkpoint per training part.
# NOTE(review): assumes Google Drive is already mounted at /content/drive; no mount cell is shown.
base_dir = "/content/drive/MyDrive/sd/deprojector"

In [None]:
#@title Install library
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.28.1-py3-none-any.whl (7.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m58.0 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m36.4 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.11.0
  Downloading huggingface_hub-0.13.4-py3-none-any.whl (200 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m200.1/200.1 kB[0m [31m24.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.13.4 tokenizers-0.13.3 transformers-4.28.1


In [None]:
#@title CLIPTextDeprojector
import torch
import torch.nn as nn
from IPython.display import display
from transformers import CLIPPreTrainedModel, CLIPTextConfig
from transformers.models.clip.modeling_clip import CLIPEncoderLayer

class CLIPTextDeprojector(CLIPPreTrainedModel):
    """Reconstructs a CLIP text-encoder last-hidden-state sequence from a text embedding.

    Architecture: an optional frozen linear ``projection`` applied to position 0,
    learned position embeddings, a single ``CLIPEncoderLayer`` run under a causal
    mask, and a final ``LayerNorm``. Position 0 of every output is replaced by the
    ``sos_embed`` buffer.

    In this notebook the model is trained on inputs ``[text_embeds, target[:-1]]``
    with ``target`` as the regression target; ``Inference`` decodes autoregressively
    from the embedding alone.
    """

    config_class = CLIPTextConfig
    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPTextConfig):
        super().__init__(config)
        self.config = config
        embed_dim = config.hidden_size

        # Maps the projection_dim-sized vector at position 0 to hidden_size.
        # Frozen (requires_grad=False), so training never updates it.
        self.to_use_projection = True
        self.projection = nn.Linear(config.projection_dim, embed_dim, bias=False)
        for param in self.projection.parameters():
            param.requires_grad = False

        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
        )
        self.position_embedding = nn.Embedding(
            config.max_position_embeddings, embed_dim
        )
        self.encoder_layer = CLIPEncoderLayer(config)
        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        # Emitted verbatim at position 0 of every output; zeros until assigned externally.
        self.register_buffer("sos_embed", torch.zeros([embed_dim]))

    def use_projection(self, to_use: bool = True):
        """Enable (default) or disable the projection applied to input position 0."""
        self.to_use_projection = to_use

    def dont_use_projection(self):
        """Disable the projection; position 0 is then used as-is."""
        self.use_projection(False)

    def forward(self, hidden_state):
        """Run one causal pass over a (possibly partially zero-padded) sequence.

        Args:
            hidden_state: tensor of shape (batch, seq_len, dim) with
                seq_len <= config.max_position_embeddings. When the projection is
                enabled, position 0 is passed through ``self.projection`` first. That
                requires dim == config.projection_dim, and since positions 1.. are
                concatenated back unchanged, also dim == config.hidden_size.

        Returns:
            Tensor of shape (batch, seq_len, hidden_size). Position 0 is
            ``sos_embed``; position k > 0 is the prediction from inputs 0..k.
        """
        bsz, seq_len, _ = hidden_state.size()
        causal_attention_mask = self._build_causal_attention_mask(
            bsz, seq_len, hidden_state.dtype, device=hidden_state.device
        )
        attention_mask = None  # no padding mask: only the causal mask is applied

        if self.to_use_projection:
            embeds = self.projection(hidden_state[:, 0, :])
            hidden_state = torch.cat(
                [embeds.unsqueeze(1), hidden_state[:, 1:, :]], dim=1
            )

        # Slice to seq_len so inputs shorter than max_position_embeddings broadcast.
        position_embeddings = self.position_embedding(self.position_ids[:, :seq_len])
        hidden_state = hidden_state + position_embeddings

        layer_outputs = self.encoder_layer(
            hidden_state,
            attention_mask,
            causal_attention_mask,
        )
        output = self.final_layer_norm(layer_outputs[0])
        # (dim,) -> (bsz, 1, dim) view; replaces whatever was predicted at position 0.
        sos_embed = self.sos_embed.expand(bsz, 1, -1)
        return torch.cat([sos_embed, output[:, 1:]], dim=1)

    def _build_causal_attention_mask(self, bsz, seq_len, dtype, device=None):
        """Build an additive causal mask of shape (bsz, 1, seq_len, seq_len).

        Entries strictly above the diagonal hold the most negative value of
        ``dtype``, which blocks attention to later positions. The diagonal and
        below are 0. The mask is created directly on ``device`` (CPU if None).
        """
        mask = torch.full(
            (bsz, seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype, device=device
        )
        mask.triu_(1)  # zero the diagonal and everything below it
        return mask.unsqueeze(1)  # add the head dimension

    def Inference(self, embeds):
        """Autoregressively decode a full sequence from text embeddings.

        Args:
            embeds: tensor of shape (batch, dim). This is the same kind of vector the
                model sees at input position 0 during training.

        Returns:
            Tensor of shape (batch, config.max_position_embeddings, hidden_size).

        Each step runs one full-length forward pass (max_position_embeddings
        passes in total). The input is the embedding, then the positions decoded
        so far, then zero padding; each pass fixes one more position.
        """
        max_len = self.config.max_position_embeddings
        bsz, dim = embeds.shape[0], embeds.shape[-1]
        prefix = embeds.unsqueeze(1)
        # new_zeros follows embeds' dtype and device (torch.zeros would always be float32).
        result = embeds.new_zeros([bsz, 0, dim])
        while True:
            result_len = result.shape[1]
            remaining_len = max_len - result_len - 1
            padding = embeds.new_zeros([bsz, remaining_len, dim])
            output = self(torch.cat([prefix, result, padding], dim=1))
            if remaining_len == 0:
                return output
            # Keep the positions decoded so far; the next pass extends them by one.
            result = output[:, : result_len + 1, :]


# Create the initial model weights from the pretrained CLIP model

In [None]:
#@title Download LAION-400M and extract texts
!wget https://deploy.laion.ai/8f83b608504d46bb81708ec86e912220/dataset/part-00000-5b54c5d5-bbcf-484d-a2ce-0d6f73df1a36-c000.snappy.parquet \
  -P $base_dir/

import pyarrow as pa
import pyarrow.parquet
import pyarrow.csv as csv
table = pa.parquet.read_table(f'{base_dir}/part-00000-5b54c5d5-bbcf-484d-a2ce-0d6f73df1a36-c000.snappy.parquet', columns=['TEXT'])
csv.write_csv(table, f"{base_dir}/texts.csv")

In [None]:
!mkdir $base_dir/init_model
!wget https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/pytorch_model.bin
!wget https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/text_encoder/config.json \
  -P $base_dir/init_model/

import torch
model_data = torch.load("pytorch_model.bin")

new_model_data = {}
new_model_data["position_embedding.weight"]             = model_data["text_model.embeddings.position_embedding.weight"]
new_model_data["encoder_layer.self_attn.k_proj.weight"] = model_data["text_model.encoder.layers.11.self_attn.k_proj.weight"]
new_model_data["encoder_layer.self_attn.k_proj.bias"]   = model_data["text_model.encoder.layers.11.self_attn.k_proj.bias"]
new_model_data["encoder_layer.self_attn.v_proj.weight"] = model_data["text_model.encoder.layers.11.self_attn.v_proj.weight"]
new_model_data["encoder_layer.self_attn.v_proj.bias"]   = model_data["text_model.encoder.layers.11.self_attn.v_proj.bias"]
new_model_data["encoder_layer.self_attn.q_proj.weight"] = model_data["text_model.encoder.layers.11.self_attn.q_proj.weight"]
new_model_data["encoder_layer.self_attn.q_proj.bias"]   = model_data["text_model.encoder.layers.11.self_attn.q_proj.bias"]
new_model_data["encoder_layer.self_attn.out_proj.weight"] = model_data["text_model.encoder.layers.11.self_attn.out_proj.weight"]
new_model_data["encoder_layer.self_attn.out_proj.bias"] = model_data["text_model.encoder.layers.11.self_attn.out_proj.bias"]
new_model_data["encoder_layer.layer_norm1.weight"]      = model_data["text_model.encoder.layers.11.layer_norm1.weight"]
new_model_data["encoder_layer.layer_norm1.bias"]        = model_data["text_model.encoder.layers.11.layer_norm1.bias"]
new_model_data["encoder_layer.mlp.fc1.weight"]          = model_data["text_model.encoder.layers.11.mlp.fc1.weight"]
new_model_data["encoder_layer.mlp.fc1.bias"]            = model_data["text_model.encoder.layers.11.mlp.fc1.bias"]
new_model_data["encoder_layer.mlp.fc2.weight"]          = model_data["text_model.encoder.layers.11.mlp.fc2.weight"]
new_model_data["encoder_layer.mlp.fc2.bias"]            = model_data["text_model.encoder.layers.11.mlp.fc2.bias"]
new_model_data["encoder_layer.layer_norm2.weight"]      = model_data["text_model.encoder.layers.11.layer_norm2.weight"]
new_model_data["encoder_layer.layer_norm2.bias"]        = model_data["text_model.encoder.layers.11.layer_norm2.bias"]
new_model_data["final_layer_norm.weight"]               = model_data["text_model.final_layer_norm.weight"]
new_model_data["final_layer_norm.bias"]                 = model_data["text_model.final_layer_norm.bias"]
new_model_data["projection.weight"] = torch.linalg.inv(model_data["text_projection.weight"])
torch.save(new_model_data, f"{base_dir}/init_model/pytorch_model.bin")

# Train the model

In [None]:
#@title Set the previous part number.
# Index of the last training part that was already completed. The data-generation
# cell increments it before use, so:
#   -1 -> start from part 0, initialised from init_model/
#    N -> resume with part N+1, initialised from model_part_N/
part = -1

In [None]:
#@title Load CLIP models
import torch
from transformers import CLIPTokenizer, CLIPTextModelWithProjection

# Use the GPU when one is attached; otherwise everything still runs (slowly) on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

CLIP_MODEL_NAME = "openai/clip-vit-large-patch14"
# Every caption is padded/truncated to exactly this many tokens (CLIP's text length).
MAX_TOKENS = 77

# Loading is slow, so skip it when the encoder is already in memory (e.g. on re-run).
if "encoder" not in globals():
  tokenizer = CLIPTokenizer.from_pretrained(CLIP_MODEL_NAME)
  encoder = CLIPTextModelWithProjection.from_pretrained(CLIP_MODEL_NAME).to(device)

def tokenize(text):
  """Tokenize `text` with the CLIP tokenizer, padded/truncated to MAX_TOKENS, as PyTorch tensors."""
  return tokenizer(
      text,
      padding="max_length",
      max_length=MAX_TOKENS,
      truncation=True,
      return_tensors="pt",
  )

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/961k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/905 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/4.52k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.71G [00:00<?, ?B/s]

Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing CLIPTextModelWithProjection: ['vision_model.encoder.layers.3.self_attn.out_proj.bias', ... 'vision_model.encoder.layers.3.layer_norm1.bias']
- This IS expected if you are initializing CLIPTextModelWithProjection from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing CLIPTextModelWithProjection from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
#@title Trainer
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler

class DataSet:
    """(target, shifted input) pairs built from rows [start, end) of `data`.

    `data` is a dict with "embeds" and "last_state" tensors indexed by row. The
    selected rows are moved to the global `device` once, up front.
    """

    def __init__(self, data, start, end):
        self.embeds_ls = data["embeds"][start:end].to(device)
        self.last_state_ls = data["last_state"][start:end, :].to(device)

    def __len__(self):
        return self.embeds_ls.shape[0]

    def __getitem__(self, index):
        target = self.last_state_ls[index]
        # Shift right by one: the input is the embedding followed by every target
        # state except the last, so input position k lines up with target[k].
        shifted = torch.cat([self.embeds_ls[index][None, :], target[:-1]], dim=0)
        return target, shifted

class Trainer:
  """Trains and evaluates the global `model` on one part of the generated data.

  Rows [start, end - test_size) form the training split and the last `test_size`
  rows [end - test_size, end) the held-out test split.
  """

  def __init__(self, data, start, end, test_size=100, batch_size=100):
    self.ds = DataSet(data, start, end - test_size)
    self.test_ds = DataSet(data, end - test_size, end)
    self.batch_size = batch_size

    self.criterion = nn.MSELoss()
    # Adam with its default learning rate (1e-3). Parameters with
    # requires_grad=False (the frozen projection) never get a gradient, so Adam skips them.
    self.optimizer = torch.optim.Adam(model.parameters())

  def training(self):
    """Run one epoch over the training split and increment the global `epoch`."""
    global epoch
    epoch += 1
    print(f"Epoch: {epoch}")

    model.train()  # only matters if the config enables dropout
    dl = DataLoader(self.ds, batch_size=self.batch_size, shuffle=True)
    for i, (target, inputs) in enumerate(dl):
      self.optimizer.zero_grad()
      pred = model(inputs)
      loss = self.criterion(pred, target)
      loss.backward()
      self.optimizer.step()
      if (i+1) % 10 == 0:
        print(f"Batch {i+1}, loss {loss.item():.4f}, target: {target.shape}, input: {inputs.shape}")

  def _evaluate(self, predict, label):
    """Return the MSE of `predict(inputs)` over the whole test split, or None if it is empty."""
    model.eval()
    total_loss, total_items = 0.0, 0
    with torch.no_grad():
      dl = DataLoader(self.test_ds, batch_size=self.batch_size)
      for target, inputs in dl:
        loss = self.criterion(predict(inputs), target)
        # Weight by batch size so a smaller final batch does not skew the mean.
        total_loss += loss.item() * len(target)
        total_items += len(target)
    if total_items == 0:
      return None
    mean_loss = total_loss / total_items
    print(f"{label}, loss {mean_loss:.4f}, target: {target.shape}, input: {inputs.shape}")
    return mean_loss

  def test(self):
    """Test loss with teacher forcing: the model sees the true previous states."""
    return self._evaluate(model, "Test")

  def test_inference(self):
    """Test loss when decoding autoregressively from the text embedding (input position 0) alone."""
    return self._evaluate(lambda inputs: model.Inference(inputs[:, 0, :]), "Inference test")

## Repeat the cells below once per data part

In [None]:
#@title Training Data Generation
# Advances `part` and encodes its captions with CLIP. Part p covers CSV rows
# p*num_data+1 .. (p+1)*num_data (row 0 is the header). Each run moves to the next part.

part += 1
print(f"Start part {part}")

import csv
import torch

# Release the previous part's tensors before encoding new ones. The old `trainer`
# holds slices (views) of the old `data`, so it must be deleted as well; otherwise
# `del data` frees no GPU memory. (Tensor.to() is not in-place, so moving `data`
# to CPU first would only make throw-away copies.)
if "trainer" in globals():
  del trainer
if "data" in globals():
  del data
torch.cuda.empty_cache()

num_data = 10000  # captions per part (500000 appears as an alternative setting)
first_row = part * num_data  # rows up to here belong to earlier parts
embed_ls = []
last_state_ls = []
with open(f'{base_dir}/texts.csv', newline='') as f:
  with torch.no_grad():
    for orig_i, row in enumerate(csv.reader(f)):
      if orig_i == 0:
        print(f"{orig_i:08n} => {row}")  # CSV header
        continue
      i = orig_i - first_row  # 1-based position within this part
      if i <= 0: continue
      if i > num_data: break
      if len(row) == 0: continue
      try:
        tokens = tokenize(row[0]).input_ids.to(device)
        encoded = encoder(tokens)
      except Exception:
        # Report only the row: `tokens` is undefined (or left over from the
        # previous row) when tokenize() itself is what failed.
        print(f"error at: {orig_i:08n} => {row}")
        raise
      embed_ls.append(encoded.text_embeds)
      last_state_ls.append(encoded.last_hidden_state)
      if i % 1000 == 0:
        print(f"{orig_i:08n} => {row}")

if not embed_ls:
  raise ValueError(f"No captions found for part {part}: texts.csv has no non-empty rows after row {first_row}.")

data = {
    "embeds": torch.cat(embed_ls, dim=0),
    "last_state": torch.cat(last_state_ls, dim=0),
}
# Optional on-disk cache of this part's encodings (disabled):
#torch.save(data, f"{base_dir}/dataset_{part:05n}.pt")

Start part 9
00000000 => ['TEXT']
00091000 => ['bankruptcy order cartoon']
00092000 => ['Australia-NSW-Comboyne Plateau and Beach Ride']
00093000 => ['Tormenter 4X4 Ocean Fade Board Shorts']
00094000 => ['I Love FLAGSTAFF Arizona Coffee Mug']
00095000 => ['Derrick Rose #1 - MVP - Chicago Bulls (cshimala) Tags: chicago basketball court 1 crowd bulls playoffs fans adidas nba aroundtown basketballcourt unitedcenter game1 mvp drose chicagobulls pacers chicagoist nbaplayoffs indianapacers dabulls derrickrose']
00096000 => ['10 fun facts about coral reefs']
00097000 => ['miriam keloy in Hadrian Gala After-Party 2014']
00098000 => ['S5 Julien Double Wall Sconce']
00099000 => ['Alice In Wonerland Rare White Rabbit And Cheshire Cat Sugar And Creamer Set']
00100000 => ['Monastery of Ostrog in Montenegro - Stock Photo']


In [None]:
#@title Load Base Model and reset epoch
# Drop the previous part's model (if any) before loading the next checkpoint.
if "model" in globals():
  model.to("cpu")
  del model

# Part 0 starts from the converted CLIP weights; part N continues from part N-1's checkpoint.
model_path = (
  f"{base_dir}/init_model/"
  if part == 0
  else f"{base_dir}/model_part_{part - 1:02n}/"
)
print(f"part {part} based on {model_path}")
model = CLIPTextDeprojector.from_pretrained(model_path)
model.to(device)

if part == 0:
  # init_model/ contains no sos_embed, so take it from position 0 of the first
  # encoded caption's last hidden state.
  print(f"Set SOS state")
  model.sos_embed = data["last_state"][0, 0].squeeze().detach().clone()
  print(model.sos_embed.shape)

epoch = 0

part 9 based on /content/drive/MyDrive/sd/deprojector/model_part_08/


In [None]:
#@title Initialize Training
test_results = []
inference_test_results = []
# Split on the number of captions actually encoded rather than a hard-coded 10000:
# empty CSV rows are skipped during generation, and num_data is configurable.
trainer = Trainer(data, 0, len(data["embeds"]))
test_results.append(trainer.test())
display(test_results)
inference_test_results.append(trainer.test_inference())
display(inference_test_results)

Test, loss 0.1147, target: torch.Size([100, 77, 768]), input: torch.Size([100, 77, 768])


[0.1146882027387619]

Inference test, loss 0.8182, target: torch.Size([100, 77, 768]), input: torch.Size([100, 77, 768])


[0.818166971206665]

In [None]:
# Train one more epoch on the current part, then record both test losses and
# show the history of the last 10 evaluations for each.
print(f"part {part}")
trainer.training()
for history, evaluate in (
  (test_results, trainer.test),
  (inference_test_results, trainer.test_inference),
):
  history.append(evaluate())
  display(history[-10:])

part 9
Epoch: 3
Batch 10, loss 0.0968, target: torch.Size([100, 77, 768]), input: torch.Size([100, 77, 768])
Batch 20, loss 0.0986, target: torch.Size([100, 77, 768]), input: torch.Size([100, 77, 768])
Batch 30, loss 0.1103, target: torch.Size([100, 77, 768]), input: torch.Size([100, 77, 768])
Batch 40, loss 0.1152, target: torch.Size([100, 77, 768]), input: torch.Size([100, 77, 768])
Batch 50, loss 0.1012, target: torch.Size([100, 77, 768]), input: torch.Size([100, 77, 768])
Batch 60, loss 0.1028, target: torch.Size([100, 77, 768]), input: torch.Size([100, 77, 768])
Batch 70, loss 0.0918, target: torch.Size([100, 77, 768]), input: torch.Size([100, 77, 768])
Batch 80, loss 0.1122, target: torch.Size([100, 77, 768]), input: torch.Size([100, 77, 768])
Batch 90, loss 0.0921, target: torch.Size([100, 77, 768]), input: torch.Size([100, 77, 768])
Test, loss 0.1151, target: torch.Size([100, 77, 768]), input: torch.Size([100, 77, 768])


[0.1146882027387619,
 0.1144610121846199,
 0.114711232483387,
 0.11509108543395996]

Inference test, loss 0.8553, target: torch.Size([100, 77, 768]), input: torch.Size([100, 77, 768])


[0.818166971206665, 0.8375610709190369, 0.8436762094497681, 0.8552973866462708]

In [None]:
model_to_save = f"model_part_{part:02n}"
model.save_pretrained(f"{base_dir}/{model_to_save}/")
print(model_to_save)
!ls -lh $base_dir/$model_to_save/
summary = {
    "model": model_to_save,
    "part": part,
    "epoch": epoch,
    "test_loss": test_results[-1],
    "inference_test_loss": inference_test_results[-1]
}
torch.save(summary, f"{base_dir}/{model_to_save}/epoch.pt")
ep = torch.load(f"{base_dir}/{model_to_save}/epoch.pt")
ep

model_part_09
total 30M
-rw------- 1 root root 641 Apr 22 11:06 config.json
-rw------- 1 root root 30M Apr 22 11:06 pytorch_model.bin


{'model': 'model_part_09',
 'part': 9,
 'epoch': 1,
 'test_loss': 0.1144610121846199,
 'inference_test_loss': 0.8375610709190369}

In [None]:
# Show the saved summary of every part trained so far (parts 0..part).
for p in range(part + 1):
  model_name = f"model_part_{p:02n}"
  ep = torch.load(f"{base_dir}/{model_name}/epoch.pt")
  print(model_name)
  display(ep)

model_part_00


{'model': 'model_part_00',
 'part': 0,
 'epoch': 7,
 'test_loss': 0.1280847191810608,
 'inference_test_loss': 0.8230238556861877}

model_part_01


{'model': 'model_part_01',
 'part': 1,
 'epoch': 4,
 'test_loss': 0.11513874679803848,
 'inference_test_loss': 0.8857951760292053}

model_part_02


{'model': 'model_part_02',
 'part': 2,
 'epoch': 3,
 'test_loss': 0.1050417423248291,
 'inference_test_loss': 0.8039863109588623}

model_part_03


{'model': 'model_part_03',
 'part': 3,
 'epoch': 2,
 'test_loss': 0.10449869185686111,
 'inference_test_loss': 0.8036409020423889}

model_part_04


{'model': 'model_part_04',
 'part': 4,
 'epoch': 2,
 'test_loss': 0.11442147940397263,
 'inference_test_loss': 0.8488426208496094}

model_part_05


{'model': 'model_part_05',
 'part': 5,
 'epoch': 2,
 'test_loss': 0.11371234804391861,
 'inference_test_loss': 0.9070765376091003}

model_part_06


{'model': 'model_part_06',
 'part': 6,
 'epoch': 3,
 'test_loss': 0.10346761345863342,
 'inference_test_loss': 0.7673318982124329}

model_part_07


{'model': 'model_part_07',
 'part': 7,
 'epoch': 2,
 'test_loss': 0.10422901809215546,
 'inference_test_loss': 0.829045295715332}

model_part_08


{'model': 'model_part_08',
 'part': 8,
 'epoch': 2,
 'test_loss': 0.11055289208889008,
 'inference_test_loss': 0.7834908962249756}

model_part_09


{'model': 'model_part_09',
 'part': 9,
 'epoch': 1,
 'test_loss': 0.1144610121846199,
 'inference_test_loss': 0.8375610709190369}