In [1]:
import accelerate
import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from PIL import Image

import transformers
from transformers import (
    AutoImageProcessor,
    AutoModel,
    AutoTokenizer,
    HfArgumentParser,
    TrainingArguments,
    set_seed,
    Trainer
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import send_example_telemetry
from transformers.utils.versions import require_version
from transformers import CLIPTokenizer, CLIPConfig, CLIPModel
from transformers.trainer_pt_utils import get_parameter_names
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS

import diffusers
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.utils import is_wandb_available

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.state import AcceleratorState
from accelerate.utils import ProjectConfiguration, set_seed

if is_wandb_available():
    import wandb


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
clip_config = CLIPConfig()

In [3]:
clip_config

CLIPConfig {
  "_commit_hash": null,
  "initializer_factor": 1.0,
  "logit_scale_init_value": 2.6592,
  "model_type": "clip",
  "projection_dim": 512,
  "text_config": {
    "_name_or_path": "",
    "add_cross_attention": false,
    "architectures": null,
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": 49406,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 49407,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "quick_gelu",
    "hidden_size": 512,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "initializer_factor": 1.0,
    "initializer_range": 0.02,
    "intermediate_size": 2048,
    "is

In [4]:
model_save_path = "/share/test/songtianwei/model_save"

In [5]:
clip_model = AutoModel.from_pretrained(
    model_save_path,
)

In [11]:
clip_model.text_model

CLIPTextTransformer(
  (embeddings): CLIPTextEmbeddings(
    (token_embedding): Embedding(49408, 512)
    (position_embedding): Embedding(77, 512)
  )
  (encoder): CLIPEncoder(
    (layers): ModuleList(
      (0-11): 12 x CLIPEncoderLayer(
        (self_attn): CLIPAttention(
          (k_proj): Linear(in_features=512, out_features=512, bias=True)
          (v_proj): Linear(in_features=512, out_features=512, bias=True)
          (q_proj): Linear(in_features=512, out_features=512, bias=True)
          (out_proj): Linear(in_features=512, out_features=512, bias=True)
        )
        (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): CLIPMLP(
          (activation_fn): QuickGELUActivation()
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
        )
        (layer_norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      )
    )
  )
  (final_layer_norm): L

In [10]:
clip_model = AutoModel.from_config(clip_config)

In [16]:
clip_config.vision_config.image_size

224

In [12]:
accelerator = Accelerator(
        gradient_accumulation_steps=1,
        mixed_precision="no",
   
    )

In [13]:
accelerator.use_distributed

False

In [88]:
vae_config = {
    'sample_size': [224,224],  # 512
    'in_channels': 3,
    'out_channels': 3,
    'down_block_types': ['DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D'],
    'up_block_types': ['UpDecoderBlock2D', 'UpDecoderBlock2D', 'UpDecoderBlock2D'],
    'block_out_channels': [128, 256, 512],
    'layers_per_block': 2,
    'act_fn': 'silu',
    'latent_channels': 4,
    'norm_num_groups': 32,
    'scaling_factor': 0.18215,
}
         

In [89]:
vae = AutoencoderKL(**vae_config)
        

In [90]:
vae

AutoencoderKL(
  (encoder): Encoder(
    (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down_blocks): ModuleList(
      (0): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0-1): 2 x ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
        (downsamplers): ModuleList(
          (0): Downsample2D(
            (conv): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(2, 2))
          )
        )
      )
      (1): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0): ResnetBlock2D(
            (norm1): GroupNorm(32, 128, ep

In [91]:
a = torch.ones((4,3,224,224))

In [92]:
a_latent = vae.encode(a).latent_dist.sample()

In [93]:
a_latent.shape

torch.Size([4, 4, 56, 56])

In [94]:
vae_config["out_channels"]

3

In [95]:
unet_config = {
    "in_channels": vae_config["latent_channels"],
    "out_channels": vae_config["latent_channels"],
    "sample_size": 56,
    "act_fn": "silu",
    "attention_head_dim": 8,
    "block_out_channels": [
        320,
        640,
        1280,
        1280
    ],
    "center_input_sample": False,
    "cross_attention_dim": 768,  
    "down_block_types": [
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D"
    ],
    "downsample_padding": 1,
    "flip_sin_to_cos": True,
    "freq_shift": 0,
    "layers_per_block": 2,
    "mid_block_scale_factor": 1,
    "norm_eps": 1e-05,
    "norm_num_groups": 32,
    "up_block_types": [
        "UpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D"
    ]
}


In [56]:
unet = UNet2DConditionModel(**unet_config)

In [66]:
timestep = torch.ones([1])

In [69]:
encoder_hidden_states = torch.ones((4,77,768))

In [80]:
a_latent2 = unet(a_latent,timestep,encoder_hidden_states).sample

In [84]:
a_latent2.shape

torch.Size([4, 4, 28, 28])

In [86]:
res = vae.decoder(a_latent2)

In [87]:
res.shape

torch.Size([4, 3, 224, 224])