In [1]:
import cv2
import numpy as np
import os
import torch
import timm
import torch.nn as nn
from transformers import Wav2Vec2Processor, HubertModel
import timm
import torch.nn.functional as F

In [9]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [44]:
audio_feature = torch.load("103-30-384x480_audio.pt").to(device)
video_feature = torch.tensor(np.load("103-30-384x480_video.npy")).to(device)

In [40]:
def seq_feature_generation(video_feature, audio_feature, seq_len, pooling = "mean"):
    #video_feature : (771, 1, 197, 768)
    #audio_feature : [1, 773, 1024]
    video_feature = torch.tensor(video_feature, dtype=torch.float32)
    audio_feature = torch.tensor(audio_feature, dtype=torch.float32)

    video_feature = video_feature.permute(1,0,2,3)
    
    if pooling == "mean":
        video_feature = torch.mean(video_feature, dim = 2, keepdim=False)
    elif pooling == "max":
        video_feature = torch.max(video_feature, dim = 2, keepdim=False)[0]

    max_seq = min(video_feature.shape[1], audio_feature.shape[1])
    video_feature = video_feature[:, :max_seq, :]
    audio_feature = audio_feature[:, :max_seq, :]
    combined_feature = torch.cat([video_feature, audio_feature], dim = -1)
    #[1, max_seq, 1024 + 768]
    
    if max_seq < seq_len:
        # Pad both features to seq_len along the sequence dimension
        combined_sequences = F.pad(combined_feature, (0, 0, 0, seq_len - max_seq))
    else:
        num_complete_seqs = max_seq // seq_len
        combined_sequences = combined_feature[:,:num_complete_seqs*seq_len, :].view(-1, seq_len, combined_feature.shape[-1])
    #[-1, seq_len, combined_feature_size]
    return combined_sequences

In [45]:
sequence = seq_feature_generation(video_feature, audio_feature, seq_len = 10)

  video_feature = torch.tensor(video_feature, dtype=torch.float32)
  audio_feature = torch.tensor(audio_feature, dtype=torch.float32)


In [46]:
sequence.shape

torch.Size([77, 10, 1792])

In [4]:
class ViTHuBERTTransformer(nn.Module):
    def __init__(self, vit_base_model,
                 hubert_base_model,
                 num_classes,
                 nhead,
                 num_layers,
                small_dataset = True):
        super().__init__()

        self.vit = timm.create_model(vit_base_model, pretrained=True)

        #self.processor = Wav2Vec2Processor.from_pretrained(hubert_base_model)
        self.hubert = HubertModel.from_pretrained(hubert_base_model)

        if small_dataset:
            for param in self.vit.parameters():
                param.requires_grad = False
        
            for param in self.hubert.parameters():
                param.requires_grad = False
            

        encoder_layer = nn.TransformerEncoderLayer(d_model = self.vit.num_features + self.hubert.config.hidden_size,
                                                  nhead = nhead,
                                                  dim_feedforward = (self.vit.num_features + self.hubert.config.hidden_size)//2,
                                                  batch_first = True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers = num_layers)

        # Classifier
        self.classifier = nn.Linear(self.vit.num_features + self.hubert.config.hidden_size, num_classes)
    def forward(self, video_feature_raw, audio_feature_raw):

        vit_feature = self.vit.forward_features(video_feature_raw)
        audio_feature = self.hubert(audio_feature_raw).last_hidden_state
        
        # Combine features
        combined_features = torch.cat((vit_feature, audio_feature), dim=1)

        transformer_output = self.transformer_encoder(combined_features)

        logits = self.classifier(transformer_output.squeeze(1))
        return logits

In [50]:
class ViTHuBERTTransformer_prepossed(nn.Module):
    def __init__(self, vit_base_model,
                 hubert_base_model,
                 num_classes,
                 nhead,
                 num_layers,
                small_dataset = True):
        super().__init__()

        self.vit = timm.create_model(vit_base_model, pretrained=True)

        #self.processor = Wav2Vec2Processor.from_pretrained(hubert_base_model)
        self.hubert = HubertModel.from_pretrained(hubert_base_model)

        if small_dataset:
            for param in self.vit.parameters():
                param.requires_grad = False
        
            for param in self.hubert.parameters():
                param.requires_grad = False
            

        encoder_layer = nn.TransformerEncoderLayer(d_model = self.vit.num_features + self.hubert.config.hidden_size,
                                                  nhead = nhead,
                                                  dim_feedforward = (self.vit.num_features + self.hubert.config.hidden_size)//2,
                                                  batch_first = True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers = num_layers)

        # Classifier
        self.classifier = nn.Linear(self.vit.num_features + self.hubert.config.hidden_size, num_classes)
    def forward(self, combined_feature):

        transformer_output = self.transformer_encoder(combined_feature)

        logits = self.classifier(transformer_output.squeeze(1))
        return logits

In [51]:
model = ViTHuBERTTransformer_prepossed(
    vit_base_model = 'vit_base_patch16_224',
    hubert_base_model = "facebook/hubert-large-ls960-ft",
    num_classes = 27,
    nhead = 8,
    num_layers = 6,
    small_dataset = True
)

Some weights of HubertModel were not initialized from the model checkpoint at facebook/hubert-large-ls960-ft and are newly initialized: ['hubert.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'hubert.encoder.pos_conv_embed.conv.parametrizations.weight.original0']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [54]:
model.to(device)

ViTHuBERTTransformer_prepossed(
  (vit): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (patch_drop): Identity()
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(app

In [56]:
model(sequence).shape

torch.Size([77, 10, 27])

In [None]:
def evaluate(model, dataloader, loss_fn, device):
    model.eval()
    loss_cumulative = 0.
    start_time = time.time()
    with torch.no_grad():
        for j, d in enumerate(dataloader):
            d.to(device)
            output = model(d)
            #print(len(output))
            #print(len(d.target))
            loss = loss_fn(output, d.target).cpu()
            loss_cumulative = loss_cumulative + loss.detach().item()
    return loss_cumulative / len(dataloader)

In [None]:
def train(model, optimizer, dataloader_train, dataloader_valid, loss_fn,
             max_iter=101, scheduler=None, device="cpu"):
    model.to(device = device, dtype=torch.float32)
    print(device)
    checkpoint_generator = loglinspace(0.3, 5)
    checkpoint = next(checkpoint_generator)
    start_time = time.time()
    run_name = "vithubertformer"
    try:
        model.load_state_dict(torch.load(run_name + '.torch')['state'])
    except:
        results = {}
        history = []
        s0 = 0
    else:
        results = torch.load(run_name + '.torch')
        history = results['history']
        s0 = history[-1]['step'] + 1

    for step in range(max_iter):
        model.train()
        loss_cumulative = 0.

        for j, d in tqdm(enumerate(dataloader_train), total=len(dataloader_train), bar_format=bar_format):
            d.to(device)
            output = model(d)
            loss = loss_fn(output, d.target).cpu()
            loss_cumulative = loss_cumulative + loss.detach().item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        end_time = time.time()
        wall = end_time - start_time

        if step == checkpoint:
            checkpoint = next(checkpoint_generator)
            assert checkpoint > step

            valid_avg_loss = evaluate(model, dataloader_valid, loss_fn, device)
            train_avg_loss = evaluate(model, dataloader_train, loss_fn, device)

            history.append({
                'step': s0 + step,
                'wall': wall,
                'batch': {
                    'loss': loss.item(),
                },
                'valid': {
                    'loss': valid_avg_loss,
                },
                'train': {
                    'loss': train_avg_loss,
                },
            })

            results = {
                'history': history,
                'state': model.state_dict()
            }

            print(f"epoch {step + 1:4d}   " +
                  f"abs = {train_avg_loss:8.4f}   " +
                  f"valid loss mse= {valid_avg_loss[0]:8.4f}   " +
                  f"wall = {time.strftime('%H:%M:%S', time.gmtime(wall))}")

            with open(run_name + '.torch', 'wb') as f:
                torch.save(results, f)

        if scheduler is not None:
            scheduler.step()

In [None]:
loss_function = torch.nn.CrossEntropyLoss()
opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)
scheduler = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.96)
train_e3(model, opt, dataloader_train, dataloader_valid, loss_fn,
             max_iter=50, scheduler=scheduler, device=device)