In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Dict, Union, Optional
from dataclasses import dataclass

In [None]:
class SelfAttention(nn.Module):
    """
    Multi-head scaled dot-product self-attention for sequence data.

    Attributes:
        scale (float): Scaling factor ``d_head ** -0.5`` applied to the attention scores.
        n_heads (int): Number of attention heads.
        d_head (int): Dimension of each attention head (``d_embed // n_heads``).
        QKV (nn.Linear): Fused projection producing Query, Key and Value (output chunked in that order).
        O (nn.Linear): Linear output projection.
    """

    def __init__(self, d_embed, n_heads: int = 4, qkv_bias=True, out_bias=True) -> None:
        """
        Initializes the SelfAttention module.

        Args:
            d_embed (int): Dimension of the embedding; must be divisible by ``n_heads``.
            n_heads (int): Number of attention heads. Defaults to 4.
            qkv_bias (bool): If True, adds bias to the QKV linear layer. Defaults to True.
            out_bias (bool): If True, adds bias to the O linear layer. Defaults to True.

        Raises:
            ValueError: If ``d_embed`` is not divisible by ``n_heads``.
        """
        super().__init__()

        # Fail at construction time instead of with an opaque `view` error in forward().
        if d_embed % n_heads != 0:
            raise ValueError(f"d_embed ({d_embed}) must be divisible by n_heads ({n_heads})")

        self.n_heads = n_heads
        self.d_head = d_embed // n_heads
        self.scale = self.d_head ** -0.5
        self.QKV = nn.Linear(d_embed, d_embed * 3, bias=qkv_bias)
        self.O = nn.Linear(d_embed, d_embed, bias=out_bias)

    def forward(self, x: torch.Tensor, mask: bool = False) -> torch.Tensor:
        """
        Forward pass for the SelfAttention mechanism.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_embed).
            mask (bool): If True, applies a causal mask so that position i only attends to
                positions <= i. Defaults to False (full bidirectional attention).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_len, d_embed).
        """
        bs, seq_len, d_embed = x.shape

        # (batch_size, seq_len, d_embed) -> (batch_size, seq_len, 3*d_embed) -> 3 x (batch_size, seq_len, d_embed)
        q, k, v = self.QKV(x).chunk(3, dim=-1)

        # (batch_size, seq_len, d_embed) -> (batch_size, seq_len, n_heads, d_head) -> (batch_size, n_heads, seq_len, d_head)
        q = q.view(bs, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        k = k.view(bs, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        v = v.view(bs, seq_len, self.n_heads, self.d_head).transpose(1, 2)

        attn_scores = (q @ k.transpose(-2, -1)) * self.scale  # (batch_size, n_heads, seq_len, seq_len)
        if mask:
            # Strict upper triangle = future positions. A (seq_len, seq_len) mask broadcasts
            # over batch and heads, so there is no need to allocate a full-size tensor.
            causal = torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device).triu(1)
            # -inf instead of -1e9: -1e9 is not representable in float16. The diagonal is never
            # masked, so every row keeps at least one finite score and softmax stays well-defined.
            attn_scores = attn_scores.masked_fill(causal, float("-inf"))
        weights = F.softmax(attn_scores, dim=-1)

        output = weights @ v  # (batch_size, n_heads, seq_len, d_head)
        # (batch_size, n_heads, seq_len, d_head) -> (batch_size, seq_len, n_heads, d_head) -> (batch_size, seq_len, d_embed)
        output = output.transpose(1, 2).contiguous().view(bs, seq_len, d_embed)
        return self.O(output)  # (batch_size, seq_len, d_embed)


In [None]:
@dataclass()
class Config:
    """Hyper-parameters of the CLIP text encoder built in this notebook."""

    # number of rows in the token-embedding table
    vocab_size: int = 49_408
    # embedding width / residual-stream dimension shared by every layer
    hidden_size: int = 768
    # number of rows in the learned position-embedding table
    seq_len: int = 77
    # attention heads per block (hidden_size must be divisible by this)
    n_heads: int = 12
    # number of stacked transformer blocks
    n_layers: int = 12
    # epsilon passed to every nn.LayerNorm
    layer_norm_eps: float = 1e-05

In [None]:
class CLIPEmbedding(nn.Module):
    """
    Converts input token ids to embeddings: token embedding plus a learned
    absolute position embedding.
    """

    def __init__(self, config: "Config") -> None:
        """
        Args:
            config: Object providing ``vocab_size``, ``hidden_size`` and ``seq_len``.
        """
        super().__init__()

        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        # One learned row per position: (seq_len, hidden_size). Zero-initialised here; in this
        # notebook it is filled from the pretrained checkpoint via load_state_dict.
        self.position_embedding = nn.Parameter(torch.zeros(config.seq_len, config.hidden_size))

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input_ids: Integer token ids of shape (batch_size, seq_len) with
                ``seq_len <= config.seq_len``.

        Returns:
            Embeddings of shape (batch_size, seq_len, hidden_size).
        """
        seq_len = input_ids.shape[-1]
        # Slice the position table to the actual input length so shorter sequences work.
        # Inputs longer than the table still fail with a broadcasting RuntimeError, as before.
        return self.token_embedding(input_ids) + self.position_embedding[:seq_len]

In [None]:
class CLIPBlock(nn.Module):
    """
    One pre-LayerNorm transformer layer of the CLIP text encoder: self-attention
    (requested with ``mask=True``) followed by a 4x-wide MLP, each wrapped in a
    residual connection.
    """

    class _QuickGELU(nn.Module):
        """QuickGELU activation ``x * sigmoid(1.702 * x)``, the MLP activation of OpenAI's CLIP."""

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * torch.sigmoid(1.702 * x)

    def __init__(self, config: "Config") -> None:
        """
        Args:
            config: Object providing ``hidden_size``, ``n_heads`` and ``layer_norm_eps``.
        """
        super().__init__()

        self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attention = SelfAttention(config.hidden_size, config.n_heads)
        # Keep the index layout (0: fc1, 1: activation, 2: fc2): the checkpoint conversion
        # maps fc1/fc2 onto `mlp.0` / `mlp.2`. The pretrained SD v1.5 text encoder
        # (OpenAI CLIP ViT-L/14) uses QuickGELU, not exact GELU, between them.
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size * 4),
            self._QuickGELU(),
            nn.Linear(config.hidden_size * 4, config.hidden_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Hidden states of shape (batch_size, seq_len, hidden_size).

        Returns:
            Hidden states of the same shape.
        """
        x = x + self.attention(self.layer_norm1(x), mask=True)
        return x + self.mlp(self.layer_norm2(x))  # (batch_size, seq_len, hidden_size)

In [None]:
class CLIP(nn.Module):
    """
    A basic implementation of OpenAI's CLIP model
    """
    
    def __init__(self, config: Config) -> None:
        super().__init__()
        
        self.embeddings = CLIPEmbedding(config)
        self.layers = nn.ModuleList([CLIPBlock(config) for _ in range(config.n_layers)])
        # self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        
    def forward(self, x: torch.LongTensor) -> torch.FloatTensor:
        token_embeddings = self.embeddings(x.type(torch.long)) # (batch_size, seq_len) -> (batch_size, seq_len, hidden_size)
        for layer in self.layers:
            token_embeddings = layer(token_embeddings)
        return token_embeddings # (batch_size, seq_len, hidden_size)

In [None]:
# generate a test code
input_ids = torch.randint(0, 49408, (1, 77))
config = Config()
clip = CLIP(config)

In [None]:
def print_parameters(model):
    for name, param in model.state_dict().items():
        print(name)
        
# Assuming 'model' is your Pytorch model
# print_parameters(clip)

In [None]:
input_file = '../data/checkpoints/v1-5-pruned-emaonly.ckpt'
device = 'mps'
original_model = torch.load(input_file, map_location=device, weights_only = False)["state_dict"]
converted = {}

In [None]:
converted['clip'] = {}

converted['clip']['embeddings.token_embedding.weight'] = original_model['cond_stage_model.transformer.text_model.embeddings.token_embedding.weight']
converted['clip']['embeddings.position_embedding'] = original_model['cond_stage_model.transformer.text_model.embeddings.position_embedding.weight']
converted['clip']['layers.0.attention.O.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight']
converted['clip']['layers.0.attention.O.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias']
converted['clip']['layers.0.layer_norm1.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight']
converted['clip']['layers.0.layer_norm1.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias']
converted['clip']['layers.0.mlp.0.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight']
converted['clip']['layers.0.mlp.0.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias']
converted['clip']['layers.0.mlp.2.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight']
converted['clip']['layers.0.mlp.2.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias']
converted['clip']['layers.0.layer_norm2.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight']
converted['clip']['layers.0.layer_norm2.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias']
converted['clip']['layers.1.attention.O.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight']
converted['clip']['layers.1.attention.O.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias']
converted['clip']['layers.1.layer_norm1.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight']
converted['clip']['layers.1.layer_norm1.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias']
converted['clip']['layers.1.mlp.0.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight']
converted['clip']['layers.1.mlp.0.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias']
converted['clip']['layers.1.mlp.2.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight']
converted['clip']['layers.1.mlp.2.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias']
converted['clip']['layers.1.layer_norm2.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight']
converted['clip']['layers.1.layer_norm2.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias']
converted['clip']['layers.2.attention.O.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight']
converted['clip']['layers.2.attention.O.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias']
converted['clip']['layers.2.layer_norm1.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight']
converted['clip']['layers.2.layer_norm1.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias']
converted['clip']['layers.2.mlp.0.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight']
converted['clip']['layers.2.mlp.0.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias']
converted['clip']['layers.2.mlp.2.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight']
converted['clip']['layers.2.mlp.2.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias']
converted['clip']['layers.2.layer_norm2.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight']
converted['clip']['layers.2.layer_norm2.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias']
converted['clip']['layers.3.attention.O.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight']
converted['clip']['layers.3.attention.O.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias']
converted['clip']['layers.3.layer_norm1.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight']
converted['clip']['layers.3.layer_norm1.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias']
converted['clip']['layers.3.mlp.0.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight']
converted['clip']['layers.3.mlp.0.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias']
converted['clip']['layers.3.mlp.2.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight']
converted['clip']['layers.3.mlp.2.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias']
converted['clip']['layers.3.layer_norm2.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight']
converted['clip']['layers.3.layer_norm2.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias']
converted['clip']['layers.4.attention.O.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight']
converted['clip']['layers.4.attention.O.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias']
converted['clip']['layers.4.layer_norm1.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight']
converted['clip']['layers.4.layer_norm1.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias']
converted['clip']['layers.4.mlp.0.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight']
converted['clip']['layers.4.mlp.0.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias']
converted['clip']['layers.4.mlp.2.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight']
converted['clip']['layers.4.mlp.2.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias']
converted['clip']['layers.4.layer_norm2.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight']
converted['clip']['layers.4.layer_norm2.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias']
converted['clip']['layers.5.attention.O.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight']
converted['clip']['layers.5.attention.O.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias']
converted['clip']['layers.5.layer_norm1.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight']
converted['clip']['layers.5.layer_norm1.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias']
converted['clip']['layers.5.mlp.0.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight']
converted['clip']['layers.5.mlp.0.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias']
converted['clip']['layers.5.mlp.2.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight']
converted['clip']['layers.5.mlp.2.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias']
converted['clip']['layers.5.layer_norm2.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight']
converted['clip']['layers.5.layer_norm2.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias']
converted['clip']['layers.6.attention.O.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight']
converted['clip']['layers.6.attention.O.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias']
converted['clip']['layers.6.layer_norm1.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight']
converted['clip']['layers.6.layer_norm1.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias']
converted['clip']['layers.6.mlp.0.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight']
converted['clip']['layers.6.mlp.0.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias']
converted['clip']['layers.6.mlp.2.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight']
converted['clip']['layers.6.mlp.2.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias']
converted['clip']['layers.6.layer_norm2.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight']
converted['clip']['layers.6.layer_norm2.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias']
converted['clip']['layers.7.attention.O.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight']
converted['clip']['layers.7.attention.O.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias']
converted['clip']['layers.7.layer_norm1.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight']
converted['clip']['layers.7.layer_norm1.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias']
converted['clip']['layers.7.mlp.0.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight']
converted['clip']['layers.7.mlp.0.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias']
converted['clip']['layers.7.mlp.2.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight']
converted['clip']['layers.7.mlp.2.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias']
converted['clip']['layers.7.layer_norm2.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight']
converted['clip']['layers.7.layer_norm2.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias']
converted['clip']['layers.8.attention.O.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight']
converted['clip']['layers.8.attention.O.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias']
converted['clip']['layers.8.layer_norm1.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight']
converted['clip']['layers.8.layer_norm1.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias']
converted['clip']['layers.8.mlp.0.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight']
converted['clip']['layers.8.mlp.0.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias']
converted['clip']['layers.8.mlp.2.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight']
converted['clip']['layers.8.mlp.2.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias']
converted['clip']['layers.8.layer_norm2.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight']
converted['clip']['layers.8.layer_norm2.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias']
converted['clip']['layers.9.attention.O.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight']
converted['clip']['layers.9.attention.O.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias']
converted['clip']['layers.9.layer_norm1.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight']
converted['clip']['layers.9.layer_norm1.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias']
converted['clip']['layers.9.mlp.0.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight']
converted['clip']['layers.9.mlp.0.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias']
converted['clip']['layers.9.mlp.2.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight']
converted['clip']['layers.9.mlp.2.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias']
converted['clip']['layers.9.layer_norm2.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight']
converted['clip']['layers.9.layer_norm2.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias']
converted['clip']['layers.10.attention.O.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight']
converted['clip']['layers.10.attention.O.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias']
converted['clip']['layers.10.layer_norm1.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight']
converted['clip']['layers.10.layer_norm1.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias']
converted['clip']['layers.10.mlp.0.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight']
converted['clip']['layers.10.mlp.0.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias']
converted['clip']['layers.10.mlp.2.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight']
converted['clip']['layers.10.mlp.2.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias']
converted['clip']['layers.10.layer_norm2.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight']
converted['clip']['layers.10.layer_norm2.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias']
converted['clip']['layers.11.attention.O.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight']
converted['clip']['layers.11.attention.O.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias']
converted['clip']['layers.11.layer_norm1.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight']
converted['clip']['layers.11.layer_norm1.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias']
converted['clip']['layers.11.mlp.0.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight']
converted['clip']['layers.11.mlp.0.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias']
converted['clip']['layers.11.mlp.2.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight']
converted['clip']['layers.11.mlp.2.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias']
converted['clip']['layers.11.layer_norm2.weight'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight']
converted['clip']['layers.11.layer_norm2.bias'] = original_model['cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias']
converted['clip']['layers.11.layer_norm2.weight'] = original_model['cond_stage_model.transformer.text_model.final_layer_norm.weight']
converted['clip']['layers.11.layer_norm2.bias'] = original_model['cond_stage_model.transformer.text_model.final_layer_norm.bias']
converted['clip']['layers.0.attention.QKV.weight'] = torch.cat((original_model['cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight'], original_model['cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight'], original_model['cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight']), 0)
converted['clip']['layers.0.attention.QKV.bias'] = torch.cat((original_model['cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias'], original_model['cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias'], original_model['cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias']), 0)
converted['clip']['layers.1.attention.QKV.weight'] = torch.cat((original_model['cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight'], original_model['cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight'], original_model['cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight']), 0)
converted['clip']['layers.1.attention.QKV.bias'] = torch.cat((original_model['cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias'], original_model['cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias'], original_model['cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias']), 0)
converted['clip']['layers.2.attention.QKV.weight'] = torch.cat((original_model['cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight'], original_model['cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight'], original_model['cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight']), 0)
converted['clip']['layers.2.attention.QKV.bias'] = torch.cat((original_model['cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias'], original_model['cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias'], original_model['cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias']), 0)
converted['clip']['layers.3.attention.QKV.weight'] = torch.cat((original_model['cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight'], original_model['cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight'], original_model['cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight']), 0)
converted['clip']['layers.3.attention.QKV.bias'] = torch.cat((original_model['cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias'], original_model['cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias'], original_model['cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias']), 0)
converted['clip']['layers.4.attention.QKV.weight'] = torch.cat((original_model['cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight'], original_model['cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight'], original_model['cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight']), 0)
converted['clip']['layers.4.attention.QKV.bias'] = torch.cat((original_model['cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias'], original_model['cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias'], original_model['cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias']), 0)
converted['clip']['layers.5.attention.QKV.weight'] = torch.cat((original_model['cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight'], original_model['cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight'], original_model['cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight']), 0)
converted['clip']['layers.5.attention.QKV.bias'] = torch.cat((original_model['cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias'], original_model['cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias'], original_model['cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias']), 0)
converted['clip']['layers.6.attention.QKV.weight'] = torch.cat((original_model['cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight'], original_model['cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight'], original_model['cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight']), 0)
converted['clip']['layers.6.attention.QKV.bias'] = torch.cat((original_model['cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias'], original_model['cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias'], original_model['cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias']), 0)
converted['clip']['layers.7.attention.QKV.weight'] = torch.cat((original_model['cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight'], original_model['cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight'], original_model['cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight']), 0)
converted['clip']['layers.7.attention.QKV.bias'] = torch.cat((original_model['cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias'], original_model['cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias'], original_model['cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias']), 0)
converted['clip']['layers.8.attention.QKV.weight'] = torch.cat((original_model['cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight'], original_model['cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight'], original_model['cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight']), 0)
converted['clip']['layers.8.attention.QKV.bias'] = torch.cat((original_model['cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias'], original_model['cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias'], original_model['cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias']), 0)
converted['clip']['layers.9.attention.QKV.weight'] = torch.cat((original_model['cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight'], original_model['cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight'], original_model['cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight']), 0)
converted['clip']['layers.9.attention.QKV.bias'] = torch.cat((original_model['cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias'], original_model['cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias'], original_model['cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias']), 0)
converted['clip']['layers.10.attention.QKV.weight'] = torch.cat((original_model['cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight'], original_model['cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight'], original_model['cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight']), 0)
converted['clip']['layers.10.attention.QKV.bias'] = torch.cat((original_model['cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias'], original_model['cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias'], original_model['cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias']), 0)
converted['clip']['layers.11.attention.QKV.weight'] = torch.cat((original_model['cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight'], original_model['cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight'], original_model['cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight']), 0)
converted['clip']['layers.11.attention.QKV.bias'] = torch.cat((original_model['cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias'], original_model['cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias'], original_model['cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias']), 0)

In [None]:
# Copy the converted weights into the model; strict=True fails loudly on any missing or unexpected key.
clip.load_state_dict(state_dict=converted['clip'], strict=True)

: 

In [None]:
# Summarise the converted weights as name -> shape rather than dumping every tensor's values.
{name: tuple(tensor.shape) for name, tensor in converted['clip'].items()}