In [1]:
import numpy as np
import os
import torch
import timm
import torch.nn as nn
from transformers import Wav2Vec2Processor, HubertModel
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.io import read_video
import pandas as pd
import time
from tqdm import tqdm
from sklearn.metrics import f1_score
import torchaudio
from moviepy.editor import VideoFileClip, vfx
import logging
logging.getLogger('moviepy').setLevel(logging.ERROR)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torchaudio.set_audio_backend("soundfile")  # 或者 "sox_io"

In [3]:
torch.cuda.is_available()

True

## tmp

In [4]:
def seq_feature_generation(video_feature, audio_feature, seq_len, pooling = "mean"):
    #video_feature : (771, 1, 197, 768)
    #audio_feature : [1, 773, 1024]
    video_feature = torch.tensor(video_feature, dtype=torch.float32)
    audio_feature = torch.tensor(audio_feature, dtype=torch.float32)

    video_feature = video_feature.permute(1,0,2,3)
    
    if pooling == "mean":
        video_feature = torch.mean(video_feature, dim = 2, keepdim=False)
    elif pooling == "max":
        video_feature = torch.max(video_feature, dim = 2, keepdim=False)[0]

    max_seq = min(video_feature.shape[1], audio_feature.shape[1])
    video_feature = video_feature[:, :max_seq, :]
    audio_feature = audio_feature[:, :max_seq, :]
    combined_feature = torch.cat([video_feature, audio_feature], dim = -1)
    #[1, max_seq, 1024 + 768]
    
    if max_seq < seq_len:
        # Pad both features to seq_len along the sequence dimension
        combined_sequences = F.pad(combined_feature, (0, 0, 0, seq_len - max_seq))
    else:
        num_complete_seqs = max_seq // seq_len
        combined_sequences = combined_feature[:,:num_complete_seqs*seq_len, :].view(-1, seq_len, combined_feature.shape[-1])
    #[-1, seq_len, combined_feature_size]
    return combined_sequences

In [4]:
def loglinspace(rate, step, end=None):
    t = 0
    while end is None or t <= end:
        yield t
        t = int(t + 1 + step * (1 - math.exp(-t * rate / step)))

In [8]:
class ViTHuBERTTransformer_prepossed(nn.Module):
    def __init__(self, vit_base_model,
                 hubert_base_model,
                 num_classes,
                 nhead,
                 num_layers,
                small_dataset = True):
        super().__init__()

        self.vit = timm.create_model(vit_base_model, pretrained=True)

        #self.processor = Wav2Vec2Processor.from_pretrained(hubert_base_model)
        self.hubert = HubertModel.from_pretrained(hubert_base_model)

        if small_dataset:
            for param in self.vit.parameters():
                param.requires_grad = False
        
            for param in self.hubert.parameters():
                param.requires_grad = False
            

        encoder_layer = nn.TransformerEncoderLayer(d_model = self.vit.num_features + self.hubert.config.hidden_size,
                                                  nhead = nhead,
                                                  dim_feedforward = (self.vit.num_features + self.hubert.config.hidden_size)//2,
                                                  batch_first = True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers = num_layers)

        # Classifier
        self.classifier = nn.Linear(self.vit.num_features + self.hubert.config.hidden_size, num_classes)
    def forward(self, combined_feature):

        transformer_output = self.transformer_encoder(combined_feature)

        logits = self.classifier(transformer_output.squeeze(1))
        return logits

In [5]:
def evaluate(model, dataloader, loss_fn, device):
    model.eval()
    loss_cumulative = 0.
    start_time = time.time()
    with torch.no_grad():
        for j, d in enumerate(dataloader):
            video_feature, audio_feature, labels = d
            
            video_feature.to(device)
            audio_feature.to(device)
            labels.to(device)

            output = model(video_feature, audio_feature)
            #print(len(output))
            #print(len(d.target))
            loss = loss_fn(output, d.target).cpu()
            loss_cumulative = loss_cumulative + loss.detach().item()
    return loss_cumulative / len(dataloader)

In [6]:
def train(model, optimizer, dataloader_train, dataloader_valid, loss_fn,
             max_iter=101, scheduler=None, device="cpu"):
    model.to(device = device, dtype=torch.float32)
    print(device)
    checkpoint_generator = loglinspace(0.3, 5)
    checkpoint = next(checkpoint_generator)
    start_time = time.time()
    run_name = "vithubertformer"
    try:
        model.load_state_dict(torch.load(run_name + '.torch')['state'])
    except:
        results = {}
        history = []
        s0 = 0
    else:
        results = torch.load(run_name + '.torch')
        history = results['history']
        s0 = history[-1]['step'] + 1

    for step in range(max_iter):
        model.train()
        loss_cumulative = 0.

        for j, d in tqdm(enumerate(dataloader_train), total=len(dataloader_train)):
            video_feature, audio_feature, labels = d
            
            video_feature = video_feature.squeeze(0).to(device)
            audio_feature = audio_feature.squeeze(0).to(device)
            labels = labels.squeeze(0).to(device)

            print(video_feature.shape)
            print(audio_feature.shape)
            print(labels.shape)
            
            output = model(video_feature, audio_feature)
            loss = loss_fn(output, labels).cpu()
            loss_cumulative = loss_cumulative + loss.detach().item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        end_time = time.time()
        wall = end_time - start_time

        if step == checkpoint:
            checkpoint = next(checkpoint_generator)
            assert checkpoint > step

            valid_avg_loss = evaluate(model, dataloader_valid, loss_fn, device)
            train_avg_loss = evaluate(model, dataloader_train, loss_fn, device)

            history.append({
                'step': s0 + step,
                'wall': wall,
                'batch': {
                    'loss': loss.item(),
                },
                'valid': {
                    'loss': valid_avg_loss,
                },
                'train': {
                    'loss': train_avg_loss,
                },
            })

            results = {
                'history': history,
                'state': model.state_dict()
            }

            print(f"epoch {step + 1:4d}   " +
                  f"abs = {train_avg_loss:8.4f}   " +
                  f"valid loss mse= {valid_avg_loss[0]:8.4f}   " +
                  f"wall = {time.strftime('%H:%M:%S', time.gmtime(wall))}")

            with open(run_name + '.torch', 'wb') as f:
                torch.save(results, f)

        if scheduler is not None:
            scheduler.step()

## Training

In [7]:
class AudioVideoDataset(Dataset):
    def __init__(self, video_dir, label_dir, device, seq_len):
        self.video_dir = video_dir
        self.label_dir = label_dir
        self.device = device
        self.seq_len = seq_len
        self.transform = self.create_transform()

        possible_extensions = ['.mp4', '.avi']
        # Collect all label files, and construct corresponding video and audio file paths
        self.entries = []
        for label_file in sorted(os.listdir(label_dir)):
            if label_file.endswith('.txt'):
                base_name = os.path.splitext(label_file)[0]
                video_file = None
                for ext in possible_extensions:
                    video_path = os.path.join(video_dir, f"{base_name}{ext}")
                    if os.path.exists(video_path):
                        video_file = video_path
                        break
                label_file = os.path.join(label_dir, label_file)
                
                if os.path.exists(video_file):
                    self.entries.append((video_file, label_file))
                else:
                    print(f"Missing video or audio file for {label_file}")

    def __len__(self):
        return len(self.entries)

    def __getitem__(self, idx):
        video_file_path, label_file_path = self.entries[idx]
        video, audio, info = read_video(video_file_path)
        #print(video.shape) #[6893, 360, 640, 3] 
        #print(audio.shape) [1,18480840]
        #print(info)
        labels = torch.tensor(np.loadtxt(label_file_path, skiprows=1, delimiter=','))

        video_feature = torch.stack([self.transform(frame.permute(2,0,1)) for frame in video])
        #print(video_feature.shape)

        audio_feature = self.audio_pre(video_file_path)
        #audio_feature.shape
        #video_feature 此时 torch.Size([6286, 3, 224, 224])， audio_feature还为raw， 经过hubert后，转为 torch.Size([1, 6291, 1024])， 需要在后续模型中进行对齐
        return video_feature, audio_feature, labels

    def create_transform(self):
        #transform image to image feature
        return transforms.Compose([
            transforms.ToPILImage(),  # 将 numpy 数组或 tensor 转换为 PIL 图像
            transforms.Resize((224, 224)),  # 调整图像大小
            transforms.ToTensor(),  # 将 PIL 图像转换为 tensor，并归一化至 [0,1]
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 标准化
        ])

    def audio_pre(self,video_path):
        #output audio_preprocessed feature before goes in hubert
        processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
        #CHUNK_SIZE = 60 * 16000
        def load_audio(video_path):
            video_clip = VideoFileClip(video_path)
            audio_clip = video_clip.audio
            # print(video_clip.reader.nframes)
            new_audio = audio_clip.fx(vfx.speedx, 1 / (0.02002 * video_clip.fps))
            new_audio.write_audiofile('output.wav')
            return video_clip.reader.nframes

        frames_number = load_audio(video_path) 
        #print(frames) #6286
        audio_input, sample_rate = torchaudio.load("./output.wav")
        #print(audio_input.shape) torch.Size([2, 5549707])
        #print(sample_rate) 44100
        # Check if the audio is stereo and convert to mono if necessary
        if audio_input.shape[0] > 1:  # More than one channel
            audio_input = torch.mean(audio_input, dim=0, keepdim=True)
    
        # Resample the audio file if necessary
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            audio_input = resampler(audio_input)

        audio_input = audio_input.squeeze()
        features = processor(audio_input, return_tensors="pt", sampling_rate=16000).input_values

        return features
        


In [8]:
# test
train_label_dir = "/home/yifan/Desktop/deep_learning/multimodal/Aff-Wild2/Aff-Wild2/labels/AU_Detection_Challenge/small_dataset_train"
val_label_dir = "/home/yifan/Desktop/deep_learning/multimodal/Aff-Wild2/Aff-Wild2/labels/AU_Detection_Challenge/test"
video_dir = "/home/yifan/Desktop/deep_learning/multimodal/Aff-Wild2/Aff-Wild2/video"

train_dataset = AudioVideoDataset(video_dir, train_label_dir, device, 10)
val_dataset = AudioVideoDataset(video_dir, val_label_dir, device, 10)

In [9]:
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

In [10]:
class ViTHuBERTTransformer(nn.Module):
    def __init__(self, vit_base_model,
                 hubert_base_model,
                 num_classes,
                 nhead,
                 num_layers,
                 seq_len,
                small_dataset = False):
        super().__init__()

        self.seq_len = seq_len
        self.vit = timm.create_model(vit_base_model, pretrained=True)

        #self.processor = Wav2Vec2Processor.from_pretrained(hubert_base_model)
        self.hubert = HubertModel.from_pretrained(hubert_base_model)

        if small_dataset:
            for param in self.vit.parameters():
                param.requires_grad = False
        
            for param in self.hubert.parameters():
                param.requires_grad = False
            

        encoder_layer = nn.TransformerEncoderLayer(d_model = self.vit.num_features + self.hubert.config.hidden_size,
                                                  nhead = nhead,
                                                  dim_feedforward = (self.vit.num_features + self.hubert.config.hidden_size)//2,
                                                  batch_first = True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers = num_layers)

        # Classifier
        self.classifier = nn.Linear(self.vit.num_features + self.hubert.config.hidden_size, num_classes)
    def forward(self, video_feature_raw, audio_feature_raw):

        audio_feature = self.hubert(audio_feature_raw).last_hidden_state
        vit_feature = self.vit.forward_features(video_feature_raw)

        combined_features = self.seq_feature_generation(vit_feature, audio_feature, self.seq_len)
        #[batch, seq_len, combined_feature_size]

        transformer_output = self.transformer_encoder(combined_features)

        logits = self.classifier(transformer_output.squeeze(1))
        return logits

    def seq_feature_generation(self, video_feature, audio_feature, seq_len, pooling = "mean"):
        #video_feature : (771, 1, 197, 768)
        #audio_feature : [1, 773, 1024]
        video_feature = torch.tensor(video_feature, dtype=torch.float32)
        audio_feature = torch.tensor(audio_feature, dtype=torch.float32)
    
        video_feature = video_feature.permute(1,0,2,3)
        
        if pooling == "mean":
            video_feature = torch.mean(video_feature, dim = 2, keepdim=False)
        elif pooling == "max":
            video_feature = torch.max(video_feature, dim = 2, keepdim=False)[0]
    
        max_seq = min(video_feature.shape[1], audio_feature.shape[1])
        video_feature = video_feature[:, :max_seq, :]
        audio_feature = audio_feature[:, :max_seq, :]
        combined_feature = torch.cat([video_feature, audio_feature], dim = -1)
        #[1, max_seq, 1024 + 768]
        
        if max_seq < seq_len:
            # Pad both features to seq_len along the sequence dimension
            combined_sequences = F.pad(combined_feature, (0, 0, 0, seq_len - max_seq))
        else:
            num_complete_seqs = max_seq // seq_len
            combined_sequences = combined_feature[:,:num_complete_seqs*seq_len, :].view(-1, seq_len, combined_feature.shape[-1])
        #[batch, seq_len, combined_feature_size]
        return combined_sequences

In [11]:
model = ViTHuBERTTransformer(
    vit_base_model = 'vit_base_patch16_224',
    hubert_base_model = "facebook/hubert-large-ls960-ft",
    num_classes = 27,
    nhead = 8,
    num_layers = 6,
    seq_len = 10,
    small_dataset = False
)

Some weights of HubertModel were not initialized from the model checkpoint at facebook/hubert-large-ls960-ft and are newly initialized: ['hubert.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'hubert.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
loss_function = torch.nn.CrossEntropyLoss()
opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)
scheduler = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.96)


In [13]:
device

'cuda:0'

In [14]:
train(model, opt, train_loader, val_loader, loss_function,max_iter=1, scheduler=scheduler, device=device)

cuda:0


  0%|                                                                                              | 0/29 [00:10<?, ?it/s]

MoviePy - Writing audio in output.wav



chunk:   0%|                                                                           | 0/2275 [00:00<?, ?it/s, now=None][A
chunk:  20%|████████████▍                                                  | 451/2275 [00:00<00:00, 4509.34it/s, now=None][A
chunk:  41%|█████████████████████████▊                                     | 930/2275 [00:00<00:00, 4627.54it/s, now=None][A
chunk:  61%|██████████████████████████████████████                        | 1395/2275 [00:00<00:00, 4611.03it/s, now=None][A
chunk:  82%|██████████████████████████████████████████████████▋           | 1860/2275 [00:00<00:00, 4613.53it/s, now=None][A
  0%|                                                                                              | 0/29 [00:11<?, ?it/s][A

MoviePy - Done.
torch.Size([5153, 3, 224, 224])
torch.Size([1, 1650624])
torch.Size([5153, 12])


  0%|                                                                                              | 0/29 [00:12<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 1.59 GiB. GPU 0 has a total capacity of 11.74 GiB of which 541.62 MiB is free. Including non-PyTorch memory, this process has 10.71 GiB memory in use. Of the allocated memory 10.31 GiB is allocated by PyTorch, and 105.05 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)