In [33]:
from transformers import CLIPTextModel, AutoTokenizer
from torch.utils.data import Dataset

import torch
import torch.nn as nn

In [39]:
class TokenDataset(Dataset):
    """Pairs of input-embedding vectors for the same token in two CLIP text models.

    Item ``idx`` is treated as a token id of the first model's tokenizer. The id
    is converted to its token string, the string is converted to an id of the
    second model's tokenizer, and the matching rows of both models' input
    embedding tables are returned.

    Args:
        model_one: Name or path handed to ``from_pretrained`` for the first
            text model and its tokenizer.
        model_two: Name or path handed to ``from_pretrained`` for the second
            text model and its tokenizer.

    Note:
        The attribute names (``*_b32`` / ``*_l14``) mirror the default
        checkpoints only; they hold whatever ``model_one`` / ``model_two``
        point to.
    """

    def __init__(self, model_one: str = "openai/clip-vit-base-patch32", 
                 model_two: str = "openai/clip-vit-large-patch14"):

        # First model: defines the index space of the dataset (see __len__).
        self.model_b32 = CLIPTextModel.from_pretrained(model_one)
        self.tokenizer_b32 = AutoTokenizer.from_pretrained(model_one)

        # Second model: looked up by token string, not by raw id.
        self.model_l14 = CLIPTextModel.from_pretrained(model_two)
        self.tokenizer_l14 = AutoTokenizer.from_pretrained(model_two)

        # Token-embedding tables of both text encoders.
        self.embeds_one = self.model_b32.get_input_embeddings()
        self.embeds_two = self.model_l14.get_input_embeddings()

        # Width of one embedding vector in each model.
        self.embed_dim_one = self.embeds_one.embedding_dim
        self.embed_dim_two = self.embeds_two.embedding_dim

    def __len__(self):
        """Return the number of entries in the first model's tokenizer."""
        return len(self.tokenizer_b32)

    def __getitem__(self, idx):
        """Return ``(embeds_one, embeds_two)`` for first-tokenizer id ``idx``.

        Both results are 1-D tensors (one row of the respective embedding
        table) and are detached from autograd.
        """
        # Translate the id through the token string so both lookups refer to
        # the same token even if the two vocabularies number it differently.
        token = self.tokenizer_b32.convert_ids_to_tokens(idx)
        second_idx = self.tokenizer_l14.convert_tokens_to_ids(token)

        idx = torch.tensor([idx], dtype=torch.int64)
        second_idx = torch.tensor([second_idx], dtype=torch.int64)

        # The embedding weights are trainable parameters, so a plain lookup
        # would attach every item to the autograd graph. The embeddings are
        # used as fixed data here, so look them up without recording history.
        with torch.no_grad():
            embeds_one = self.embeds_one(idx)[0]
            embeds_two = self.embeds_two(second_idx)[0]

        return embeds_one, embeds_two

In [55]:
# Build the paired-embedding dataset: first argument is the model whose token
# ids index the dataset, second is the model matched to it by token string.
dataset = TokenDataset("openai/clip-vit-base-patch32", "openai/clip-vit-large-patch14")

Some weights of the model checkpoint at openai/clip-vit-base-patch32 were not used when initializing CLIPTextModel: ['vision_model.encoder.layers.3.mlp.fc1.weight', 'vision_model.encoder.layers.6.layer_norm2.bias', 'vision_model.encoder.layers.11.mlp.fc1.weight', 'vision_model.encoder.layers.10.layer_norm2.weight', 'logit_scale', 'vision_model.encoder.layers.1.self_attn.q_proj.weight', 'vision_model.encoder.layers.3.self_attn.v_proj.bias', 'vision_model.encoder.layers.2.self_attn.out_proj.weight', 'vision_model.encoder.layers.8.self_attn.q_proj.bias', 'vision_model.encoder.layers.7.layer_norm1.weight', 'vision_model.encoder.layers.5.self_attn.k_proj.bias', 'vision_model.encoder.layers.6.layer_norm2.weight', 'vision_model.encoder.layers.4.self_attn.k_proj.bias', 'vision_model.encoder.layers.11.self_attn.v_proj.bias', 'vision_model.encoder.layers.9.self_attn.q_proj.bias', 'vision_model.encoder.layers.11.layer_norm2.weight', 'vision_model.encoder.layers.11.mlp.fc2.bias', 'vision_model.enc

Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing CLIPTextModel: ['vision_model.encoder.layers.3.mlp.fc1.weight', 'vision_model.encoder.layers.15.layer_norm1.weight', 'vision_model.encoder.layers.6.layer_norm2.bias', 'vision_model.encoder.layers.14.self_attn.q_proj.bias', 'vision_model.encoder.layers.21.self_attn.v_proj.weight', 'vision_model.encoder.layers.1.self_attn.q_proj.weight', 'vision_model.encoder.layers.3.self_attn.v_proj.bias', 'vision_model.encoder.layers.6.layer_norm2.weight', 'vision_model.encoder.layers.23.mlp.fc1.bias', 'vision_model.encoder.layers.2.layer_norm1.bias', 'vision_model.encoder.layers.22.mlp.fc1.weight', 'vision_model.encoder.layers.23.mlp.fc1.weight', 'vision_model.encoder.layers.4.mlp.fc2.bias', 'vision_model.encoder.layers.15.mlp.fc2.weight', 'vision_model.encoder.layers.11.layer_norm2.bias', 'vision_model.encoder.layers.7.mlp.fc1.bias', 'vision_model.encoder.layers.12.layer_norm1.weight', 'vision_model

In [59]:
# Fetch every (first-model, second-model) embedding pair and split the pairs
# into two parallel tuples: A holds first-model rows, B second-model rows.
A, B = zip(*map(dataset.__getitem__, range(len(dataset))))

# Stack each tuple of 1-D vectors into a matrix with one row per token id.
A = torch.stack(A)
B = torch.stack(B)

In [60]:
# Least-squares fit of a linear map X such that B @ X approximates A.
X = torch.linalg.lstsq(
    B,  # inputs: second-model embeddings, one row per token
    A,  # targets: first-model embeddings for the same tokens
).solution

In [61]:
# Mean squared residual of the fit over all tokens and embedding dimensions.
(B @ X - A).pow(2).mean()

tensor(0.0001, grad_fn=<MeanBackward0>)

In [63]:
# Persist the fitted linear map. X is derived from embedding weights and may
# carry autograd history (the error metric above shows a grad_fn); detach so
# the file holds plain data rather than a tensor flagged as requiring grad.
torch.save(X.detach(), "L14-to-B32-linear.pt")