In [70]:
from datasets import load_dataset
from torch import nn
import torch
from transformers import AutoImageProcessor, AutoModel, AutoFeatureExtractor, AutoModelForPreTraining

import IPython.display as ipd

from utils.img_utils import show_img, show_imgs

# torchaudio reads .wav files via the soundfile backend: install it with `pip install soundfile` (pip, not conda)
import torchaudio
from PIL.Image import Image



In [71]:
# Load the MAE-pretrained ViT checkpoint: its image processor and its encoder.
img_ckpt = 'facebook/vit-mae-base'

# Preprocessing for this checkpoint: resize to 224x224, rescale by 1/255 and
# normalize with the per-channel mean/std shown in the printed config.
image_processor = AutoImageProcessor.from_pretrained(img_ckpt)
print(image_processor)

# AutoModel resolves to ViTMAEModel, which has only the embeddings and the encoder, no decoder.
# The checkpoint's decoder.* weights are therefore skipped, which is what the
# "Some weights ... were not used when initializing ViTMAEModel" warning reports.
image_model = AutoModel.from_pretrained(img_ckpt)

print(image_model)



ViTImageProcessor {
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.485,
    0.456,
    0.406
  ],
  "image_processor_type": "ViTImageProcessor",
  "image_std": [
    0.229,
    0.224,
    0.225
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}



Some weights of the model checkpoint at facebook/vit-mae-base were not used when initializing ViTMAEModel: ['decoder.decoder_layers.7.attention.output.dense.weight', 'decoder.decoder_pred.bias', 'decoder.decoder_layers.0.layernorm_before.bias', 'decoder.decoder_layers.1.attention.attention.key.weight', 'decoder.decoder_layers.6.attention.output.dense.bias', 'decoder.decoder_norm.weight', 'decoder.decoder_layers.2.output.dense.weight', 'decoder.decoder_layers.3.attention.attention.query.weight', 'decoder.decoder_layers.3.layernorm_after.bias', 'decoder.decoder_layers.6.attention.output.dense.weight', 'decoder.decoder_layers.5.layernorm_after.weight', 'decoder.decoder_layers.4.attention.attention.key.weight', 'decoder.decoder_layers.1.attention.output.dense.bias', 'decoder.decoder_layers.3.attention.output.dense.weight', 'decoder.decoder_layers.0.intermediate.dense.weight', 'decoder.decoder_layers.2.attention.attention.key.weight', 'decoder.decoder_layers.4.layernorm_after.bias', 'decode

ViTMAEModel(
  (embeddings): ViTMAEEmbeddings(
    (patch_embeddings): ViTMAEPatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
  )
  (encoder): ViTMAEEncoder(
    (layer): ModuleList(
      (0-11): 12 x ViTMAELayer(
        (attention): ViTMAEAttention(
          (attention): ViTMAESelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): ViTMAESelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTMAEIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
      

In [72]:

# Encode a small batch of Fashion-MNIST test images with the ViT-MAE encoder.
ds = load_dataset("fashion_mnist", split="test[:30]")
# Convert to 3-channel RGB: the processor normalizes with a 3-channel mean/std.
images = [img.convert('RGB') for img in ds['image'][:25]]

# Resize to 224x224, rescale and normalize -> pixel_values of shape (batch, 3, 224, 224).
inputs = image_processor(images=images, return_tensors='pt')
print(inputs.keys())
print(inputs['pixel_values'].shape)

# Inference only: no_grad avoids building an autograd graph over the whole batch.
with torch.no_grad():
    img_encode = image_model(inputs['pixel_values'])
print(img_encode.keys())
# NOTE: ViTMAEModel masks patches inside its forward pass. Of the 14x14 = 196
# patches, only 49 (plus the CLS token) are encoded, hence (25, 50, 768); the
# extra 'mask' / 'ids_restore' outputs record which patches were kept.
# NOTE(review): the kept patches are presumably sampled at random, so features
# may differ between runs. Load with mask_ratio=0.0 if full-image features
# are wanted (TODO confirm).
print(img_encode['last_hidden_state'].shape)


Found cached dataset fashion_mnist (C:/Users/82716/.cache/huggingface/datasets/fashion_mnist/fashion_mnist/1.0.0/8d6c32399aa01613d96e2cbc9b13638f359ef62bb33612b077b4c247f6ef99c1)


dict_keys(['pixel_values'])
torch.Size([25, 3, 224, 224])
odict_keys(['last_hidden_state', 'mask', 'ids_restore'])
torch.Size([25, 50, 768])


In [73]:
# Alternative audio checkpoint (the Wav2Lip class below defaults to it):
# audio_ckpt = 'facebook/wav2vec2-base'
audio_ckpt = 'facebook/hubert-base-ls960'

# Feature extractor for the checkpoint: single-channel (feature_size=1) input at
# 16 kHz, with per-utterance normalization enabled (do_normalize=true in the printed config).
audio_feature_extractor = AutoFeatureExtractor.from_pretrained(audio_ckpt)
print(audio_feature_extractor)

# AutoModel resolves to HubertModel: a convolutional feature encoder on the raw
# waveform followed by a transformer encoder.
audio_encoder = AutoModel.from_pretrained(audio_ckpt)
print(audio_encoder)


Wav2Vec2FeatureExtractor {
  "do_normalize": true,
  "feature_extractor_type": "Wav2Vec2FeatureExtractor",
  "feature_size": 1,
  "padding_side": "right",
  "padding_value": 0,
  "return_attention_mask": false,
  "sampling_rate": 16000
}

HubertModel(
  (feature_extractor): HubertFeatureEncoder(
    (conv_layers): ModuleList(
      (0): HubertGroupNormConvLayer(
        (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
        (activation): GELUActivation()
        (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
      )
      (1-4): 4 x HubertNoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
      (5-6): 2 x HubertNoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
    )
  )
  (feature_projection): HubertFeatureProjection(
    (layer_norm): LayerNorm((512,), eps=1e-05, eleme

In [74]:
# Load a speech clip and encode it with the HuBERT encoder.
audio, sample_rate = torchaudio.load('../data/audio/audio.wav')  # (channels, samples)

print(audio.shape, sample_rate)
# The encoder expects one waveform per batch row; downmix multi-channel audio to
# mono so that channels are not mistaken for batch entries.
if audio.shape[0] > 1:
    audio = audio.mean(dim=0, keepdim=True)

target_sr = audio_feature_extractor.sampling_rate
if sample_rate != target_sr:
    audio = torchaudio.functional.resample(audio, sample_rate, target_sr)

ipd.display(ipd.Audio(audio.numpy(), rate=target_sr))

# Apply the checkpoint's feature extractor rather than feeding raw samples: its
# config has do_normalize=true, so the model expects normalized input_values.
audio_inputs = audio_feature_extractor(
    audio.squeeze(0).numpy(), sampling_rate=target_sr, return_tensors='pt'
)
# Inference only: no autograd graph needed.
with torch.no_grad():
    audio_encode = audio_encoder(audio_inputs['input_values'])
print(audio_encode.keys())
# The output length scales with the clip duration. The conv feature encoder has a
# total stride of 5 * 2**6 = 320 samples (20 ms at 16 kHz), giving roughly one frame
# per 320 samples (e.g. 22528 samples -> 70 frames).
print(audio_encode['last_hidden_state'].shape)



torch.Size([1, 22528]) 16000


odict_keys(['last_hidden_state'])
torch.Size([1, 70, 768])


In [75]:
class Wav2Lip(nn.Module):
    """Skeleton of a Wav2Lip-style model built from pretrained Hugging Face encoders.

    Two image encoders (faces and poses) and one audio encoder feed a decoder
    that has not been chosen yet.

    Args:
        img_ckpt: checkpoint loaded (twice, as independent copies) for the face
            and pose encoders.
        audio_ckpt: checkpoint for the audio encoder.
    """

    def __init__(self, img_ckpt: str = 'facebook/vit-mae-base', audio_ckpt: str = 'facebook/wav2vec2-base'):
        super().__init__()
        # Two separate from_pretrained calls -> face and pose encoders do not share weights.
        # NOTE(review): with the default ViT-MAE checkpoint, ViTMAEModel masks most
        # patches inside forward (a 224x224 image yields only 50 tokens). Consider
        # loading with mask_ratio=0.0 for full-image features (TODO confirm).
        self.face_encoder = AutoModel.from_pretrained(img_ckpt)
        self.pose_encoder = AutoModel.from_pretrained(img_ckpt)
        self.audio_encoder = AutoModel.from_pretrained(audio_ckpt)
        # TODO: find a suitable decoder. Until one is assigned, forward() raises
        # NotImplementedError.
        self.decoder = None

    def forward(self, faces, poses, audios):
        """Encode faces, poses and audio, then decode them jointly.

        Args:
            faces: input for the face encoder (presumably pixel_values of shape
                (batch, 3, H, W) - TODO confirm).
            poses: input for the pose encoder, presumably in the same format as ``faces``.
            audios: input for the audio encoder (presumably (batch, samples)).

        Returns:
            Whatever ``self.decoder`` returns for the three encoder outputs.

        Raises:
            NotImplementedError: if no decoder has been assigned.
        """
        # Fail fast, before running three encoders, instead of the opaque
        # "'NoneType' object is not callable" from calling a None decoder.
        if self.decoder is None:
            raise NotImplementedError(
                'Wav2Lip.decoder is not set; assign a decoder before calling forward().'
            )

        face_feats = self.face_encoder(faces)
        pose_feats = self.pose_encoder(poses)
        audio_feats = self.audio_encoder(audios)
        # The encoder outputs are passed through unchanged; for Hugging Face models
        # these are ModelOutput objects (token embeddings in .last_hidden_state).
        return self.decoder(face_feats, pose_feats, audio_feats)