In [95]:
import accelerate
import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from PIL import Image

import transformers
from transformers import (
    AutoImageProcessor,
    AutoModel,
    AutoTokenizer,
    HfArgumentParser,
    TrainingArguments,
    set_seed,
    Trainer
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import send_example_telemetry
from transformers.utils.versions import require_version
from transformers import CLIPTokenizer, CLIPConfig, CLIPModel
from transformers.trainer_pt_utils import get_parameter_names
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS

import diffusers
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.utils import is_wandb_available

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.state import AcceleratorState
from accelerate.utils import ProjectConfiguration, set_seed

if is_wandb_available():
    import wandb


In [96]:
# Default CLIP configuration; passing None for both towers selects their default sub-configs.
clip_config1 = CLIPConfig(text_config=None, vision_config=None)

In [97]:
# Show the concrete config class.
clip_config1.__class__

transformers.models.clip.configuration_clip.CLIPConfig

In [98]:
# Text-tower sub-config of the default CLIP config (hidden_size 512, 77 positions per the output below).
clip_config1.text_config

CLIPTextConfig {
  "attention_dropout": 0.0,
  "bos_token_id": 49406,
  "eos_token_id": 49407,
  "hidden_act": "quick_gelu",
  "hidden_size": 512,
  "initializer_factor": 1.0,
  "initializer_range": 0.02,
  "intermediate_size": 2048,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 77,
  "model_type": "clip_text_model",
  "num_attention_heads": 8,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "projection_dim": 512,
  "transformers_version": "4.31.0",
  "vocab_size": 49408
}

In [104]:
# Randomly initialised model built from the default CLIP config (from_config loads no weights).
clip_model_clip = AutoModel.from_config(config=clip_config1)

In [105]:
# Input resolution expected by the CLIP vision tower.
# Uses `clip_config1` (the name bound above); `clip_config` is never defined in this notebook.
clip_config1.vision_config.image_size

224

In [106]:
# CLIP text embeddings expose their token table as `token_embedding`.
# Uses `clip_model_clip` (the model built above); `clip_model1` is never defined.
hasattr(clip_model_clip.text_model.embeddings, "token_embedding")

True

In [107]:
# Token-embedding width of the randomly initialised CLIP text model.
# Uses `clip_model_clip` (the model built above); `clip_model1` is never defined.
clip_model_clip.text_model.embeddings.token_embedding.embedding_dim

512

In [13]:
import os

# Directory of the locally saved pretrained checkpoint. Override with the
# MODEL_SAVE_PATH environment variable instead of editing an absolute user path.
model_save_path = os.environ.get("MODEL_SAVE_PATH", "/share/test/songtianwei/model_save")

In [99]:
# Load the locally saved pretrained checkpoint; its text tower is inspected below.
clip_model_pre = AutoModel.from_pretrained(model_save_path)

In [100]:
# The pretrained text tower names its token table `word_embeddings` (RoBERTa-style), not `token_embedding`.
hasattr(clip_model_pre.text_model.embeddings,"word_embeddings")

True

In [101]:
# Token-embedding width of the pretrained text tower (768 per the output below).
clip_model_pre.text_model.embeddings.word_embeddings.embedding_dim

768

In [29]:
# Number of image channels expected by the CLIP vision tower.
# Uses `clip_config1` (the name bound above); `clip_config` is never defined in this notebook.
clip_config1.vision_config.num_channels

3

In [12]:
# No mixed precision and no gradient accumulation.
accelerator = Accelerator(mixed_precision="no", gradient_accumulation_steps=1)

In [13]:
# Whether the accelerator runs in a distributed setup.
accelerator.use_distributed

False

In [31]:
# Hyper-parameters for a small, randomly initialised AutoencoderKL.
# Three encoder blocks => two 2x downsamplings, so 224x224 images map to 56x56 latents.
vae_config = dict(
    sample_size=[224, 224],  # NOTE(review): original note said 512 — presumably the reference config's size
    in_channels=3,
    out_channels=3,
    down_block_types=["DownEncoderBlock2D"] * 3,
    up_block_types=["UpDecoderBlock2D"] * 3,
    block_out_channels=[128, 256, 512],
    layers_per_block=2,
    act_fn="silu",
    latent_channels=4,
    norm_num_groups=32,
    scaling_factor=0.18215,
)
         

In [32]:
# Randomly initialised VAE built from the dict above (constructed directly, no pretrained weights loaded).
vae = AutoencoderKL(**vae_config)

In [33]:
# NOTE: rebinds `vae_config` from the plain dict to the model's FrozenDict config, which also
# carries filled-in defaults such as `force_upcast` (see output below). Later cells index this FrozenDict.
vae_config = vae.config

In [34]:
# Inspect the resolved VAE config.
vae_config

FrozenDict([('in_channels', 3),
            ('out_channels', 3),
            ('down_block_types',
             ['DownEncoderBlock2D',
              'DownEncoderBlock2D',
              'DownEncoderBlock2D']),
            ('up_block_types',
             ['UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D']),
            ('block_out_channels', [128, 256, 512]),
            ('layers_per_block', 2),
            ('act_fn', 'silu'),
            ('latent_channels', 4),
            ('norm_num_groups', 32),
            ('sample_size', [224, 224]),
            ('scaling_factor', 0.18215),
            ('force_upcast', True),
            ('_use_default_values', ['force_upcast'])])

In [35]:
# Dummy image batch: 4 images, 3 channels, 224x224, all ones (float32).
a = torch.ones(4, 3, 224, 224)

In [36]:
# Encode and draw a sample from the latent distribution; the draw is random (no seed is set in this cell).
a_latent = vae.encode(a).latent_dist.sample()

In [37]:
# Latent shape: latent_channels=4, spatial 224 -> 56 with the 3-block VAE.
a_latent.shape

torch.Size([4, 4, 56, 56])

In [38]:
# Channel count of images produced by the VAE decoder.
vae_config["out_channels"]

3

In [46]:
# Conditional UNet that operates on the VAE latents.
unet_config = {
    "in_channels": vae_config["latent_channels"],
    "out_channels": vae_config["latent_channels"],
    # Latent side length = image size / VAE downsampling factor. Every VAE encoder
    # block except the last halves the resolution (224 -> 56 for 3 blocks). Deriving
    # it keeps this in sync with the VAE config instead of hard-coding 56.
    "sample_size": vae_config["sample_size"][0] // 2 ** (len(vae_config["block_out_channels"]) - 1),
    "act_fn": "silu",
    "attention_head_dim": 8,
    "block_out_channels": [
        320,
        640,
        1280,
        1280
    ],
    "center_input_sample": False,
    # Must equal the last dim of encoder_hidden_states: 512 matches the CLIP text model,
    # not the 768-wide pretrained RoBERTa text tower.
    "cross_attention_dim": 512,
    "down_block_types": [
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D"
    ],
    "downsample_padding": 1,
    "flip_sin_to_cos": True,
    "freq_shift": 0,
    "layers_per_block": 2,
    "mid_block_scale_factor": 1,
    "norm_eps": 1e-05,
    "norm_num_groups": 32,
    "up_block_types": [
        "UpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D"
    ]
}


In [47]:
# Randomly initialised conditional UNet built from the config above.
unet = UNet2DConditionModel(**unet_config)

In [48]:
# A single diffusion timestep (value 1.0) used for the smoke test below.
timestep = torch.ones(1)

In [49]:
# 768 = hidden size of the pretrained RoBERTa text tower. It does not match the UNet's
# cross_attention_dim (512), and this tensor is replaced in the next cell.
encoder_hidden_states = torch.ones(4, 77, 768)

In [50]:
# Text conditioning must have last dim == the UNet's cross_attention_dim (512), which works.
# A 768-wide tensor (RoBERTa hidden size) does not fit the cross-attention projections; that
# is what goes wrong, not this line.
encoder_hidden_states = torch.ones((4, 77, unet_config["cross_attention_dim"]))

In [51]:
# One conditional UNet pass over the VAE latents.
a_latent2 = unet(sample=a_latent, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

In [52]:
# The UNet output keeps the latent shape.
a_latent2.shape

torch.Size([4, 4, 56, 56])

In [53]:
# Decode latents back to images. decode() applies post_quant_conv before the decoder,
# mirroring the quant_conv that encode() applied; calling vae.decoder directly skips it.
res = vae.decode(a_latent2).sample

In [54]:
# Decoded images are back at input resolution.
res.shape

torch.Size([4, 3, 224, 224])

In [56]:
import torch.nn as nn

In [57]:
class Generator(nn.Module):
    """Text-conditioned image-to-image generator: VAE encode -> UNet -> VAE decode.

    Both the VAE (``AutoencoderKL``) and the UNet (``UNet2DConditionModel``) are
    constructed directly from the configs below, i.e. randomly initialised; no
    pretrained weights are loaded here.

    Args:
        image_channel: Number of image channels (input and output).
        image_shape: Image size as an int (square) or a ``[height, width]``
            list/tuple.
        text_embedding_dim: Size of the last dimension of
            ``encoder_hidden_states``; used as the UNet ``cross_attention_dim``.
        num_train_timesteps: Exclusive upper bound of the random timestep drawn
            in ``forward``.
    """

    def __init__(self, image_channel=3, image_shape=(224, 224), text_embedding_dim=512,
                 num_train_timesteps=1000):
        super().__init__()
        self.image_channel = image_channel
        # Always copy into a fresh list so the instance (and the VAE config that
        # stores it) never aliases a caller's or the default sequence; tuples are
        # accepted as well as lists.
        if isinstance(image_shape, (list, tuple)):
            self.image_shape = list(image_shape)
        else:
            self.image_shape = [image_shape, image_shape]
        self.text_embedding_dim = text_embedding_dim
        self.num_train_timesteps = num_train_timesteps

        self.vae_config = {
            'sample_size': self.image_shape,
            'in_channels': self.image_channel,
            'out_channels': self.image_channel,
            'down_block_types': ['DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D'],
            'up_block_types': ['UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D'],
            'block_out_channels': [128, 256, 512, 512],
            'layers_per_block': 2,
            'act_fn': 'silu',
            'latent_channels': 4,
            'norm_num_groups': 32,
            'scaling_factor': 0.18215,
        }

        self.vae = AutoencoderKL(**self.vae_config)

        # Every VAE encoder block except the last halves the resolution, so the
        # latent size is image_shape / 2**(n_blocks - 1) (224 -> 28 with 4 blocks).
        vae_downsample = 2 ** (len(self.vae_config["block_out_channels"]) - 1)
        latent_h, latent_w = (side // vae_downsample for side in self.image_shape)
        latent_size = latent_h if latent_h == latent_w else (latent_h, latent_w)

        self.unet_config = {
            "in_channels": self.vae_config["latent_channels"],
            "out_channels": self.vae_config["latent_channels"],
            "sample_size": latent_size,
            "act_fn": "silu",
            "attention_head_dim": 8,
            "block_out_channels": [
                320,
                640,
                1280,
                1280
            ],
            "center_input_sample": False,
            "cross_attention_dim": self.text_embedding_dim,
            "down_block_types": [
                "CrossAttnDownBlock2D",
                "CrossAttnDownBlock2D",
                "CrossAttnDownBlock2D",
                "DownBlock2D"
            ],
            "downsample_padding": 1,
            "flip_sin_to_cos": True,
            "freq_shift": 0,
            "layers_per_block": 2,
            "mid_block_scale_factor": 1,
            "norm_eps": 1e-05,
            "norm_num_groups": 32,
            "up_block_types": [
                "UpBlock2D",
                "CrossAttnUpBlock2D",
                "CrossAttnUpBlock2D",
                "CrossAttnUpBlock2D"
            ]
        }

        self.unet = UNet2DConditionModel(**self.unet_config)

    def forward(self, img_pixel_values, encoder_hidden_states):
        """Encode images, run one UNet pass at a random timestep, decode back.

        Args:
            img_pixel_values: Image batch passed to ``self.vae.encode``; presumably
                (batch, image_channel, H, W) — TODO confirm against the caller.
            encoder_hidden_states: Text conditioning whose last dimension must
                equal ``text_embedding_dim``.

        Returns:
            The decoded image tensor from ``self.vae.decode(...).sample``.
        """
        latent = self.vae.encode(img_pixel_values).latent_dist.sample()
        # A single timestep of shape (1,) for the whole batch; randint already
        # returns int64.
        timesteps = torch.randint(0, self.num_train_timesteps, (1,), device=latent.device)
        unet_pred = self.unet(latent, timesteps, encoder_hidden_states).sample
        # Use decode() rather than self.vae.decoder: decode() applies
        # post_quant_conv first, mirroring the quant_conv applied by encode().
        # Calling the decoder directly skips that layer, leaving its parameters
        # without gradients (DistributedDataParallel then errors on unused
        # parameters unless find_unused_parameters=True).
        return self.vae.decode(unet_pred).sample

    def enable_xformers_memory_efficient_attention(self):
        """Enable xFormers memory-efficient attention on both the UNet and the VAE."""
        self.unet.enable_xformers_memory_efficient_attention()
        self.vae.enable_xformers_memory_efficient_attention()

In [108]:
# Dummy token ids: batch of 4 sequences of length 77, all ones (int64).
batch_input_ids = torch.ones(4, 77, dtype=torch.long)

In [109]:
# Attention mask that attends to every position (4 x 77, int64 ones).
batch_attention_mask = torch.ones(4, 77, dtype=torch.long)

In [111]:
# Text tower of the pretrained checkpoint.
text_encoder_pre = clip_model_pre.text_model

In [113]:
# Inspect the pretrained text tower's architecture.
text_encoder_pre

RobertaModel(
  (embeddings): RobertaEmbeddings(
    (word_embeddings): Embedding(50265, 768, padding_idx=1)
    (position_embeddings): Embedding(514, 768, padding_idx=1)
    (token_type_embeddings): Embedding(1, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): RobertaEncoder(
    (layer): ModuleList(
      (0-11): 12 x RobertaLayer(
        (attention): RobertaAttention(
          (self): RobertaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): RobertaSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (dropou

In [120]:
# Shape of the pretrained tower's per-token hidden states.
text_encoder_pre(input_ids=batch_input_ids, attention_mask=batch_attention_mask).last_hidden_state.shape

torch.Size([4, 77, 768])

In [115]:
# Text tower of the randomly initialised CLIP model.
text_encoder_clip = clip_model_clip.text_model

In [116]:
# Inspect the CLIP text tower's architecture.
text_encoder_clip

CLIPTextTransformer(
  (embeddings): CLIPTextEmbeddings(
    (token_embedding): Embedding(49408, 512)
    (position_embedding): Embedding(77, 512)
  )
  (encoder): CLIPEncoder(
    (layers): ModuleList(
      (0-11): 12 x CLIPEncoderLayer(
        (self_attn): CLIPAttention(
          (k_proj): Linear(in_features=512, out_features=512, bias=True)
          (v_proj): Linear(in_features=512, out_features=512, bias=True)
          (q_proj): Linear(in_features=512, out_features=512, bias=True)
          (out_proj): Linear(in_features=512, out_features=512, bias=True)
        )
        (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): CLIPMLP(
          (activation_fn): QuickGELUActivation()
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
        )
        (layer_norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      )
    )
  )
  (final_layer_norm): L

In [119]:
# Shape of the CLIP text tower's per-token hidden states (width 512 matches the UNet's cross_attention_dim).
text_encoder_clip(input_ids=batch_input_ids, attention_mask=batch_attention_mask).last_hidden_state.shape

torch.Size([4, 77, 512])