In [7]:
import cv2
import numpy as np
import os
import torch
import timm

In [12]:
# Use the first CUDA device when one is visible to PyTorch; otherwise run on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [14]:
import torch
import torch.nn as nn
from transformers import Wav2Vec2Processor, HubertModel
import timm


In [27]:
class ViTHuBERTTransformer(nn.Module):
    """Audio-visual classifier that fuses a timm ViT with a HuBERT encoder.

    Each modality is reduced to one vector per sample, the two vectors are
    concatenated along the feature axis, and the result is passed through a
    Transformer encoder as a length-1 sequence before a linear classifier.

    Args:
        vit_base_model: Model name handed to ``timm.create_model``.
        hubert_base_model: Checkpoint name/path handed to
            ``HubertModel.from_pretrained``.
        num_classes: Number of output logits.
        nhead: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        small_dataset: When true, ``requires_grad`` is switched off on every
            ViT and HuBERT parameter so only the fusion encoder and the
            classifier are trainable.
    """

    def __init__(self, vit_base_model,
                 hubert_base_model,
                 num_classes,
                 nhead,
                 num_layers,
                 small_dataset = True):
        super().__init__()

        self.vit = timm.create_model(vit_base_model, pretrained=True)

        # No Wav2Vec2Processor is created here, so forward() feeds
        # audio_feature_raw to HuBERT exactly as it is received.
        self.hubert = HubertModel.from_pretrained(hubert_base_model)

        if small_dataset:
            for param in self.vit.parameters():
                param.requires_grad = False

            for param in self.hubert.parameters():
                param.requires_grad = False

        # Width of the fused token: ViT feature width + HuBERT hidden size.
        fused_dim = self.vit.num_features + self.hubert.config.hidden_size

        encoder_layer = nn.TransformerEncoderLayer(d_model = fused_dim,
                                                   nhead = nhead,
                                                   dim_feedforward = fused_dim // 2,
                                                   batch_first = True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers = num_layers)

        # Classifier
        self.classifier = nn.Linear(fused_dim, num_classes)

    def forward(self, video_feature_raw, audio_feature_raw):
        """Return class logits for a batch of raw video/audio inputs.

        Args:
            video_feature_raw: Input forwarded to ``self.vit.forward_features``
                (assumed to be an image batch the ViT accepts - TODO confirm).
            audio_feature_raw: Input forwarded to ``self.hubert``
                (assumed to be already-preprocessed waveforms - TODO confirm).

        Returns:
            Tensor of logits, one row of ``num_classes`` values per sample.
        """
        vit_feature = self.vit.forward_features(video_feature_raw)
        if vit_feature.dim() == 3:
            # Token sequence (batch, tokens, width): let the backbone apply its
            # own configured pooling without the classification layer.
            vit_feature = self.vit.forward_head(vit_feature, pre_logits=True)

        audio_feature = self.hubert(audio_feature_raw).last_hidden_state
        # Average over the time axis so audio is also one vector per sample.
        # NOTE(review): mean pooling is a design choice - confirm it suits the task.
        audio_feature = audio_feature.mean(dim=1)

        # Combine features along the feature axis so the width equals the
        # d_model the encoder and classifier were built with. Concatenating on
        # dim=1 of the unpooled outputs joined the token axes instead, which
        # requires both backbones to share a feature width.
        combined_features = torch.cat((vit_feature, audio_feature), dim=-1)

        # (batch, width) -> (batch, 1, width): one fused token per sample.
        combined_features = combined_features.unsqueeze(1)

        transformer_output = self.transformer_encoder(combined_features)

        logits = self.classifier(transformer_output.squeeze(1))
        return logits

In [34]:
class ViTHuBERTTransformer_prepossed(nn.Module):
    """Fusion classifier that consumes pre-extracted video and audio features.

    ``forward`` never calls the ViT or HuBERT backbones. They are still
    instantiated in ``__init__`` so the fused feature width can be read from
    ``vit.num_features`` and ``hubert.config.hidden_size``, and they remain
    available as the ``vit`` / ``hubert`` attributes.

    Args:
        vit_base_model: Model name handed to ``timm.create_model``.
        hubert_base_model: Checkpoint name/path handed to
            ``HubertModel.from_pretrained``.
        num_classes: Number of output logits.
        nhead: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        small_dataset: When true, ``requires_grad`` is switched off on every
            ViT and HuBERT parameter.
    """

    def __init__(self, vit_base_model,
                 hubert_base_model,
                 num_classes,
                 nhead,
                 num_layers,
                 small_dataset = True):
        super().__init__()

        self.vit = timm.create_model(vit_base_model, pretrained=True)

        self.hubert = HubertModel.from_pretrained(hubert_base_model)

        if small_dataset:
            for param in self.vit.parameters():
                param.requires_grad = False

            for param in self.hubert.parameters():
                param.requires_grad = False

        # Width of the fused token: ViT feature width + HuBERT hidden size.
        fused_dim = self.vit.num_features + self.hubert.config.hidden_size

        encoder_layer = nn.TransformerEncoderLayer(d_model = fused_dim,
                                                   nhead = nhead,
                                                   dim_feedforward = fused_dim // 2,
                                                   batch_first = True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers = num_layers)

        # Classifier
        self.classifier = nn.Linear(fused_dim, num_classes)

    def forward(self, video_feature, audio_feature):
        """Return class logits for already-extracted features.

        Args:
            video_feature: Video features whose last axis is the feature width
                (assumed (batch, width) or (batch, 1, width) - TODO confirm).
            audio_feature: Audio features with the same leading axes as
                ``video_feature`` and its own feature width on the last axis.

        Returns:
            Tensor of logits with ``num_classes`` values on the last axis.
        """
        # Combine features along the feature axis. The previous code read an
        # undefined name (`vit_feature`) here and raised NameError on every call.
        combined_features = torch.cat((video_feature, audio_feature), dim=-1)

        if combined_features.dim() == 2:
            # A 2-D tensor would be read by the encoder as one unbatched
            # sequence, letting samples attend to each other; make each sample
            # its own length-1 sequence instead.
            combined_features = combined_features.unsqueeze(1)

        transformer_output = self.transformer_encoder(combined_features)

        logits = self.classifier(transformer_output.squeeze(1))
        return logits

In [35]:
# Instantiate the pre-extracted-feature variant: 27 output classes, a 6-layer
# fusion encoder with 8 heads, and both pretrained backbones frozen.
model = ViTHuBERTTransformer_prepossed(
    "vit_base_patch16_224",
    "facebook/hubert-large-ls960-ft",
    num_classes=27,
    nhead=8,
    num_layers=6,
    small_dataset=True,
)

Some weights of HubertModel were not initialized from the model checkpoint at facebook/hubert-large-ls960-ft and are newly initialized: ['hubert.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'hubert.encoder.pos_conv_embed.conv.parametrizations.weight.original0']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [30]:
# Show the module tree of the current `model` object.
# NOTE(review): the saved output below is headed `ViTHuBERTTransformer(` and this
# cell's execution count (30) precedes the cell that builds
# `ViTHuBERTTransformer_prepossed` (35), so the output looks stale - re-run to refresh.
model

ViTHuBERTTransformer(
  (vit): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (patch_drop): Identity()
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate='