In [1]:
import torch
from torch import nn
import torchaudio
import torchvision

from sklearn.model_selection import train_test_split

from tqdm import tqdm
from sklearn import metrics

import pickle

#from trainer import TorchSupervisedTrainer

import os

import pandas as pd
import json
import glob

import random
import os
import numpy as np

import argparse

from torchvision.transforms import v2

from datasets import MultimodalPhysVerbDataset, AggrBatchSampler, AppendZeroValues, MultimodalDataset
from trainer import MultimodalTrainer
from models import PhysVerbClassifierConcatFeatures, AveragedFeaturesTransformerFusion, PhysVerbClassifierAddFeatures, PhysVerbClassifier, PhysVerbModel, TransformerSequenceProcessor, EqualSizedTransformerModalitiesFusion, MultimodalModel, CNN1D, Swin3d_T_extractor, OutputClassifier, MultiModalCrossEntropyLoss, Wav2vec2Extractor, Wav2vecExtractor, AudioCnn1DExtractorWrapper


In [2]:
class MultimodalFeatureGenDataset(MultimodalPhysVerbDataset):
    def __getitem__(self, idx):
        output_data_tuple, output_labels_tuple = super().__getitem__(idx)
        data_entry = self.time_intervals_df.loc[idx]
        aggr_type = data_entry['aggr_type']
        cluster_id = data_entry['cluster_id']
        video_id = data_entry['video_id']
        phys_t1 = data_entry['phys_t1']
        phys_t2 = data_entry['phys_t2']
        verb_t1 = data_entry['verb_t1']
        verb_t2 = data_entry['verb_t2']
        person_id = data_entry['person_id']
        phys_label = data_entry['phys_aggr_label']
        verb_label = data_entry['verb_aggr_label']
        verb_name = f'c-{cluster_id}_{video_id}_{person_id}_{verb_t1/1000}-{verb_t2/1000}_{verb_label}'
        phys_name = f'c-{cluster_id}_{video_id}_{person_id}_{phys_t1/1000}-{phys_t2/1000}_{phys_label}'
        return output_data_tuple, output_labels_tuple, (verb_name, phys_name)
    
class PhysVerbModelFeat(PhysVerbModel):
    def forward(self, input_data):
        # извлечение признаков
        modalities_features_dict = self.extract_features(input_data)
        
        # выполняем слияние модальностей
        modalities_features_dict = self.modality_fusion_module(modalities_features_dict)
        #return modalities_features_dict
        # Выполнение классификации
        '''
        output_dict = {}
        for aggr_type in self.classifiers:
            output_dict[aggr_type] = self.classifiers[aggr_type](modalities_features_dict[aggr_type])
        '''
        output_dict = self.classifiers(modalities_features_dict)
        return output_dict, modalities_features_dict

In [3]:
parser = argparse.ArgumentParser()
parser.add_argument('--path_to_dataset',  required=True)
parser.add_argument('--path_to_intersections_csv')
parser.add_argument('--path_to_train_test_split_json')
parser.add_argument('--gpu_device_idx', type=int)
parser.add_argument('--class_num', type=int)
parser.add_argument('--resume_training', action='store_true')
parser.add_argument('--path_to_checkpoint')
parser.add_argument('--batch_size', required=True, type=int)
parser.add_argument('--nn_input_size', nargs='+', type=int)
parser.add_argument('--epoch_num', type=int)
parser.add_argument('--test_size', type=float)
parser.add_argument('--max_audio_len', type=int)
parser.add_argument('--max_embeddings_len', type=int)
parser.add_argument('--video_frames_num', type=int)
parser.add_argument('--video_window_size', type=int)

sample_args = [
    '--path_to_dataset',
    #r'/home/ubuntu/mikhail_u/DATASET_V0',
    #r'/home/aggr/mikhail_u/DATA/DATSET_V0',
    #r'C:\Users\admin\python_programming\DATA\AVABOS\DATSET_V0',
    r'I:\AVABOS\DATSET_V0',
    '--path_to_intersections_csv',
    #r'/home/ubuntu/mikhail_u/DATASET_V0/time_intervals_combinations_table.csv',
    #r'/home/aggr/mikhail_u/DATA/DATSET_V0/time_intervals_combinations_table.csv',
    #r'C:\Users\admin\python_programming\DATA\AVABOS\DATSET_V0\time_intervals_combinations_table.csv',
    r'i:\AVABOS\DATSET_V0\time_intervals_combinations_table.csv',
    '--path_to_train_test_split_json',
    r'train_test_split.json',
    '--gpu_device_idx', '0',
    '--class_num', '2',
    '--epoch_num', '100',
    '--batch_size', '1',
    '--max_audio_len', '80000',
    '--max_embeddings_len', '48',
    '--video_frames_num', '128',
    '--video_window_size', '8'
    ]

args = parser.parse_args(sample_args)

path_to_dataset_root = args.path_to_dataset
resume_training = args.resume_training
path_to_intersections_csv = args.path_to_intersections_csv
path_to_train_test_split_json = args.path_to_train_test_split_json

path_to_checkpoint = args.path_to_checkpoint
epoch_num = int(args.epoch_num)
class_num = args.class_num
batch_size = int(args.batch_size)
max_audio_len = args.max_audio_len
max_embeddings_len = args.max_embeddings_len
video_frames_num = args.video_frames_num
video_window_size = args.video_window_size
gpu_device_idx = args.gpu_device_idx

# имя модели соответствует имени экстрактора признаков
phys_gamma_val = 2
verb_gamma_val = 2
model_name = 'A+T(ce)+fusion1L'
modality2aggr = {'video':'phys', 'text':'verb', 'audio':'verb'}
#modality2aggr = {'video':'verb', 'text':'verb', 'audio':'phys'}
modalities_list = [
    'audio',
    'text',
    'video'
    ]
aggr_types_list = set()
for m in modalities_list:
    aggr_types_list.add(modality2aggr[m])

aggr_types_list = list(aggr_types_list)

time_interval_combinations_df = pd.read_csv(path_to_intersections_csv)

with open(path_to_train_test_split_json) as fd:
    combinations_indices_dict = json.load(fd)

train_time_interval_combinations_df =  []
for cluster_id in combinations_indices_dict['train_clusters']:
    df = time_interval_combinations_df[time_interval_combinations_df['cluster_id']==cluster_id]
    train_time_interval_combinations_df.append(df)
train_time_interval_combinations_df = pd.concat(train_time_interval_combinations_df, ignore_index=True)
# для выравнивания баланса классов (баланс смещен в сторону не агрессивного поведения)
# удалим не агрессивные интервалы физ. поведения, которые не пересекаются с вербальным поведением
#drop_no_aggr_filter = (train_time_interval_combinations_df['aggr_type']=='phys')&(train_time_interval_combinations_df['phys_aggr_label']=='NOAGGR')
#train_time_interval_combinations_df = train_time_interval_combinations_df[~drop_no_aggr_filter]
# DEBUG
#train_time_interval_combinations_df = train_time_interval_combinations_df.loc[0:500]

test_time_interval_combinations_df =  []
for cluster_id in combinations_indices_dict['test_clusters']:
    df = time_interval_combinations_df[time_interval_combinations_df['cluster_id']==cluster_id]
    test_time_interval_combinations_df.append(df)
test_time_interval_combinations_df = pd.concat(test_time_interval_combinations_df, ignore_index=True)

device = torch.device(f'cuda:{gpu_device_idx}')

train_video_transform = v2.Compose([
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomAffine(degrees=20, translate=(0.1, 0.3), scale=(0.6, 1.1), shear=10),
    v2.RandomPerspective(distortion_scale=0.2),
    v2.Resize((112, 112), antialias=True),
    AppendZeroValues(target_size=[video_frames_num, 3, 112, 112]),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

test_video_transform = v2.Compose([
    v2.Resize((112, 112), antialias=True),
    AppendZeroValues(target_size=[video_frames_num, 3, 112, 112]),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_audio_transform = v2.Compose([AppendZeroValues(target_size=[max_audio_len])])
test_audio_transform = v2.Compose([AppendZeroValues(target_size=[max_audio_len])])

train_text_transform = v2.Compose([AppendZeroValues(target_size=[max_embeddings_len, 768])])
test_text_transform = v2.Compose([AppendZeroValues(target_size=[max_embeddings_len, 768])])

train_transforms_dict = {
    'audio': train_audio_transform,
    'text': train_text_transform,
    'video': train_video_transform
}
train_transforms_dict = {k:v for k,v in train_transforms_dict.items()if k in modalities_list}
test_transforms_dict = {
    'audio': test_audio_transform,
    'text': test_text_transform,
    'video': test_video_transform
}
test_transforms_dict = {k:v for k,v in test_transforms_dict.items()if k in modalities_list}


train_dataset = MultimodalFeatureGenDataset(
    time_intervals_df=train_time_interval_combinations_df,
    path_to_dataset=path_to_dataset_root,
    modality_augmentation_dict=train_transforms_dict,
    modality2aggr=modality2aggr,
    actual_modalities_list=modalities_list,
    device=device,
    text_embedding_type='ru_conversational_cased_L-12_H-768_A-12_pt_v1_tokens'
    )
test_dataset = MultimodalFeatureGenDataset(
    time_intervals_df=test_time_interval_combinations_df,
    path_to_dataset=path_to_dataset_root,
    modality_augmentation_dict=train_transforms_dict,
    modality2aggr=modality2aggr,
    actual_modalities_list=modalities_list,
    device=device,
    text_embedding_type='ru_conversational_cased_L-12_H-768_A-12_pt_v1_tokens'
    )

train_batch_sampler = AggrBatchSampler(train_time_interval_combinations_df, batch_size=batch_size, shuffle=True)
test_batch_sampler = AggrBatchSampler(test_time_interval_combinations_df, batch_size=batch_size, shuffle=False)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_sampler=train_batch_sampler,
    num_workers=0
    #pin_memory=True
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_sampler=test_batch_sampler,
    num_workers=0
    #pin_memory=True
)

path_to_model = r'I:\AVABOS\saving_dir2\saving_dir\12.11.2024, 15-19-48 (V(focal,g=1)+A+T(ce)+fusion1L)\verb_best_ep-3.pt'
model = torch.load(path_to_model)
model = model.to(device)
model.eval()
class PidorEbaniy(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
    def forward(self, input_data):
        # извлечение признаков
        modalities_features_dict = self.model.extract_features(input_data)
        
        # выполняем слияние модальностей
        modalities_features_dict = self.model.modality_fusion_module(modalities_features_dict)
        #return modalities_features_dict
        # Выполнение классификации
        '''
        output_dict = {}
        for aggr_type in self.classifiers:
            output_dict[aggr_type] = self.classifiers[aggr_type](modalities_features_dict[aggr_type])
        '''
        output_dict = self.model.classifiers(modalities_features_dict)
        concat_features = []
        for modality in ['audio', 'text']:
            features = modalities_features_dict[modality]
            features = features.mean(dim=1)
            concat_features.append(features)
        concat_features = torch.cat(concat_features, dim=1)
        return output_dict, concat_features

pedota_ebanaya = PidorEbaniy(model)
#with torch.no_grad():
files_lst = []
preds_list = []
true_list = []
for data, labels, (verb_name, phys_name) in tqdm(test_dataloader):

    ret_dict, multimodal_features = pedota_ebanaya(data)
    
    verb_name = verb_name[0]
    filtered_labels = [v for v in labels if v[0][0]=='verb']
    if len(filtered_labels) > 0:
        aggr_type, label = filtered_labels[0]
        label = label.item()
        true_list.append(label)
        one_hot_label = np.zeros((2,), dtype=int)
        one_hot_label[label] = 1
        res = ret_dict['verb']
        res = res.detach()
        _, pred_label = res.max(dim=1)
        pred_label = pred_label.item()
        preds_list.append(pred_label)
        one_hot_pred_label = np.zeros((2,), dtype=int)
        one_hot_pred_label[pred_label] = 1
        multimodal_features = multimodal_features.squeeze(0).detach().cpu().numpy()
        d = {'filename':verb_name,'features':multimodal_features, 'targets':one_hot_label, 'predictions':one_hot_pred_label}
        files_lst.append(d)
with open('test_set.pkl', 'wb') as fd:
    pickle.dump(files_lst, fd)
        
for data, labels, (verb_name, phys_name) in tqdm(train_dataloader):

    ret_dict, multimodal_features = pedota_ebanaya(data)
    
    verb_name = verb_name[0]
    filtered_labels = [v for v in labels if v[0][0]=='verb']
    if len(filtered_labels) > 0:
        aggr_type, label = filtered_labels[0]
        label = label.item()
        true_list.append(label)
        one_hot_label = np.zeros((2,), dtype=int)
        one_hot_label[label] = 1
        res = ret_dict['verb']
        res = res.detach()
        _, pred_label = res.max(dim=1)
        pred_label = pred_label.item()
        preds_list.append(pred_label)
        one_hot_pred_label = np.zeros((2,), dtype=int)
        one_hot_pred_label[pred_label] = 1
        multimodal_features = multimodal_features.squeeze(0).detach().cpu().numpy()
        d = {'filename':verb_name,'features':multimodal_features, 'targets':one_hot_label, 'predictions':one_hot_pred_label}
        files_lst.append(d)
        

with open('train_set.pkl', 'wb') as fd:
    pickle.dump(files_lst, fd)

  model = torch.load(path_to_model)
  video = torch.load(path_to_video)
  audio = torch.load(path_to_audio).to(self.device)
  video = torch.load(path_to_video)
  audio = torch.load(path_to_audio).to(self.device)
100%|██████████| 4112/4112 [27:56<00:00,  2.45it/s]
100%|██████████| 9516/9516 [1:10:30<00:00,  2.25it/s]


In [77]:
print(metrics.classification_report(true_list, preds_list, digits=3))

              precision    recall  f1-score   support

           0      0.706     0.806     0.753       732
           1      0.876     0.804     0.838      1253

    accuracy                          0.805      1985
   macro avg      0.791     0.805     0.796      1985
weighted avg      0.813     0.805     0.807      1985



In [66]:
multimodal_features.shape

torch.Size([1, 1536])

In [48]:
one_hot_label

array([0, 1])

In [11]:
for d in test_dataloader:
    break
len(d)

3

In [None]:
parser = argparse.ArgumentParser()
parser.add_argument('--path_to_dataset',  required=True)
parser.add_argument('--path_to_intersections_csv')
parser.add_argument('--path_to_train_test_split_json')
parser.add_argument('--gpu_device_idx', type=int)
parser.add_argument('--class_num', type=int)
parser.add_argument('--resume_training', action='store_true')
parser.add_argument('--path_to_checkpoint')
parser.add_argument('--batch_size', required=True, type=int)
parser.add_argument('--nn_input_size', nargs='+', type=int)
parser.add_argument('--epoch_num', type=int)
parser.add_argument('--test_size', type=float)
parser.add_argument('--max_audio_len', type=int)
parser.add_argument('--max_embeddings_len', type=int)
parser.add_argument('--video_frames_num', type=int)
parser.add_argument('--video_window_size', type=int)

sample_args = [
    '--path_to_dataset',
    #r'/home/ubuntu/mikhail_u/DATASET_V0',
    #r'/home/aggr/mikhail_u/DATA/DATSET_V0',
    #r'C:\Users\admin\python_programming\DATA\AVABOS\DATSET_V0',
    r'I:\AVABOS\DATSET_V0',
    '--path_to_intersections_csv',
    #r'/home/ubuntu/mikhail_u/DATASET_V0/time_intervals_combinations_table.csv',
    #r'/home/aggr/mikhail_u/DATA/DATSET_V0/time_intervals_combinations_table.csv',
    #r'C:\Users\admin\python_programming\DATA\AVABOS\DATSET_V0\time_intervals_combinations_table.csv',
    r'i:\AVABOS\DATSET_V0\time_intervals_combinations_table.csv',
    '--path_to_train_test_split_json',
    r'train_test_split.json',
    '--gpu_device_idx', '0',
    '--class_num', '2',
    '--epoch_num', '100',
    '--batch_size', '1',
    '--max_audio_len', '80000',
    '--max_embeddings_len', '48',
    '--video_frames_num', '128',
    '--video_window_size', '8'
    ]

args = parser.parse_args(sample_args)

path_to_dataset_root = args.path_to_dataset
resume_training = args.resume_training
path_to_intersections_csv = args.path_to_intersections_csv
path_to_train_test_split_json = args.path_to_train_test_split_json

path_to_checkpoint = args.path_to_checkpoint
epoch_num = int(args.epoch_num)
class_num = args.class_num
batch_size = int(args.batch_size)
max_audio_len = args.max_audio_len
max_embeddings_len = args.max_embeddings_len
video_frames_num = args.video_frames_num
video_window_size = args.video_window_size
gpu_device_idx = args.gpu_device_idx

# имя модели соответствует имени экстрактора признаков
phys_gamma_val = 2
verb_gamma_val = 2
model_name = 'A+T(ce)+fusion1L'
modality2aggr = {'video':'phys', 'text':'verb', 'audio':'verb'}
#modality2aggr = {'video':'verb', 'text':'verb', 'audio':'phys'}
modalities_list = [
    'audio',
    'text',
    'video'
    ]
aggr_types_list = set()
for m in modalities_list:
    aggr_types_list.add(modality2aggr[m])

aggr_types_list = list(aggr_types_list)

time_interval_combinations_df = pd.read_csv(path_to_intersections_csv)

with open(path_to_train_test_split_json) as fd:
    combinations_indices_dict = json.load(fd)

train_time_interval_combinations_df =  []
for cluster_id in combinations_indices_dict['train_clusters']:
    df = time_interval_combinations_df[time_interval_combinations_df['cluster_id']==cluster_id]
    train_time_interval_combinations_df.append(df)
train_time_interval_combinations_df = pd.concat(train_time_interval_combinations_df, ignore_index=True)
# для выравнивания баланса классов (баланс смещен в сторону не агрессивного поведения)
# удалим не агрессивные интервалы физ. поведения, которые не пересекаются с вербальным поведением
#drop_no_aggr_filter = (train_time_interval_combinations_df['aggr_type']=='phys')&(train_time_interval_combinations_df['phys_aggr_label']=='NOAGGR')
#train_time_interval_combinations_df = train_time_interval_combinations_df[~drop_no_aggr_filter]
# DEBUG
#train_time_interval_combinations_df = train_time_interval_combinations_df.loc[0:500]

test_time_interval_combinations_df =  []
for cluster_id in combinations_indices_dict['test_clusters']:
    df = time_interval_combinations_df[time_interval_combinations_df['cluster_id']==cluster_id]
    test_time_interval_combinations_df.append(df)
test_time_interval_combinations_df = pd.concat(test_time_interval_combinations_df, ignore_index=True)


device = torch.device(f'cuda:{gpu_device_idx}')

train_video_transform = v2.Compose([
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomAffine(degrees=20, translate=(0.1, 0.3), scale=(0.6, 1.1), shear=10),
    v2.RandomPerspective(distortion_scale=0.2),
    v2.Resize((112, 112), antialias=True),
    AppendZeroValues(target_size=[video_frames_num, 3, 112, 112]),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

test_video_transform = v2.Compose([
    v2.Resize((112, 112), antialias=True),
    AppendZeroValues(target_size=[video_frames_num, 3, 112, 112]),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_audio_transform = v2.Compose([AppendZeroValues(target_size=[max_audio_len])])
test_audio_transform = v2.Compose([AppendZeroValues(target_size=[max_audio_len])])



train_text_transform = v2.Compose([AppendZeroValues(target_size=[max_embeddings_len, 768])])
test_text_transform = v2.Compose([AppendZeroValues(target_size=[max_embeddings_len, 768])])

train_transforms_dict = {
    'audio': train_audio_transform,
    'text': train_text_transform,
    'video': train_video_transform
}
train_transforms_dict = {k:v for k,v in train_transforms_dict.items()if k in modalities_list}
test_transforms_dict = {
    'audio': test_audio_transform,
    'text': test_text_transform,
    'video': test_video_transform
}
test_transforms_dict = {k:v for k,v in test_transforms_dict.items()if k in modalities_list}


train_dataset = MultimodalFeatureGenDataset(
    time_intervals_df=train_time_interval_combinations_df,
    path_to_dataset=path_to_dataset_root,
    modality_augmentation_dict=train_transforms_dict,
    modality2aggr=modality2aggr,
    actual_modalities_list=modalities_list,
    device=device,
    text_embedding_type='ru_conversational_cased_L-12_H-768_A-12_pt_v1_tokens'
    )
test_dataset = MultimodalFeatureGenDataset(
    time_intervals_df=test_time_interval_combinations_df,
    path_to_dataset=path_to_dataset_root,
    modality_augmentation_dict=train_transforms_dict,
    modality2aggr=modality2aggr,
    actual_modalities_list=modalities_list,
    device=device,
    text_embedding_type='ru_conversational_cased_L-12_H-768_A-12_pt_v1_tokens'
    )


train_batch_sampler = AggrBatchSampler(train_time_interval_combinations_df, batch_size=batch_size, shuffle=True)
test_batch_sampler = AggrBatchSampler(test_time_interval_combinations_df, batch_size=batch_size, shuffle=False)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_sampler=train_batch_sampler,
    num_workers=0
    #pin_memory=True
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_sampler=test_batch_sampler,
    num_workers=0
    #pin_memory=True
)

# вычисляем веса классов для физичекой и вербальной агрессии
phys_aggr_filter = (train_time_interval_combinations_df['aggr_type'] == 'phys')
verb_aggr_filter = (train_time_interval_combinations_df['aggr_type'] == 'verb')
phys_verb_agr_filter = (train_time_interval_combinations_df['aggr_type'] == 'phys&verb')
phys_aggr_df = train_time_interval_combinations_df[phys_aggr_filter]
verb_aggr_df = train_time_interval_combinations_df[verb_aggr_filter]
phys_verb_aggr_df = train_time_interval_combinations_df[phys_verb_agr_filter]
all_phys_aggr_df = train_time_interval_combinations_df[phys_aggr_filter|phys_verb_agr_filter]
all_verb_aggr_df = train_time_interval_combinations_df[verb_aggr_filter|phys_verb_agr_filter]

verb_weights_series = 1-all_verb_aggr_df['verb_aggr_label'].value_counts()/len(all_verb_aggr_df)
phys_weights_series = 1-all_phys_aggr_df['phys_aggr_label'].value_counts()/len(all_phys_aggr_df)

verb_weights = torch.zeros([class_num], device=device)
phys_weights = torch.zeros([class_num], device=device)

verb_weights[0] = verb_weights_series['NOAGGR']
verb_weights[1] = verb_weights_series['AGGR']

phys_weights[0] = phys_weights_series['NOAGGR']
phys_weights[1] = phys_weights_series['AGGR']

phys_focal_loss = torch.hub.load(
        'adeelh/pytorch-multi-class-focal-loss',
        model='FocalLoss',
        alpha=phys_weights,
        gamma=phys_gamma_val,
        reduction='mean',
        force_reload=False
    )

audio_extractor = AudioCnn1DExtractorWrapper(hidden_size=768)


audio_model = TransformerSequenceProcessor(
    extractor_model=audio_extractor,
    transformer_layer_num=1,
    transformer_head_num=8,
    hidden_size=768,
    class_num=class_num
    )
text_model = TransformerSequenceProcessor(
    extractor_model=nn.Sequential(),
    transformer_layer_num=1,
    transformer_head_num=8,
    hidden_size=768,
    class_num=2
)

# dummy class
class E(nn.Module):
    def __init__(self):
        super().__init__()
        self.e = Swin3d_T_extractor(frame_num=video_frames_num, window_size=video_window_size)
    def forward(self, x, ret_type='PIDOR EPTA'):
        return self.e(x)

video_extractor = Swin3d_T_extractor(frame_num=video_frames_num, window_size=video_window_size)

video_model = TransformerSequenceProcessor(
    extractor_model=video_extractor,
    transformer_layer_num=1,
    transformer_head_num=8,
    hidden_size=768,
    class_num=class_num
    )
    
# определяем размерности векторов признаков для многомодальной обработки
video_features_shape = video_model(torch.zeros([1, 3, video_frames_num, 112, 112])).shape
audio_features_shape = audio_extractor(torch.zeros([1, max_audio_len])).shape
text_features_shape = text_model(torch.zeros([1, max_embeddings_len, 768])).shape
modality_features_shapes_dict = {
    'audio':list(audio_features_shape)[1:],
    'text':list(text_features_shape)[1:],
    'video':list(video_features_shape)[1:]
}

'''
print(f'video_Shape={torch.zeros([1, 3, video_frames_num, 112, 112]).shape}')
print(f'audio_Shape={torch.zeros([1, max_audio_len]).shape}')
print(f'text_Shape= {torch.zeros([1, max_embeddings_len, 768]).shape}')
print(modality_features_shapes_dict)
exit()
'''
modality_features_shapes_dict = {k:v for k,v in modality_features_shapes_dict.items() if k in modalities_list}
modality_extractors_dict = {
    'audio':audio_extractor,
    'text':nn.Sequential(),
    #'text':text_model,
    #'text':text_model,
    'video':video_extractor
    #'video':video_model
}
modality_extractors_dict = {k:v for k,v in modality_extractors_dict.items() if k in modalities_list}
modality_extractors_dict = nn.ModuleDict(modality_extractors_dict)

modality_fusion_module = EqualSizedTransformerModalitiesFusion(fusion_transformer_layer_num=1, fusion_transformer_hidden_size=768, fusion_transformer_head_num=8)

modalities_adaptors_inout_sizes_dict = {'video':[video_features_shape[-1], 768], 'audio':[audio_features_shape[-1], 768], 'text':[text_features_shape[-1], 768]}
aggr_classifiers = PhysVerbClassifierConcatFeatures(
    modalities_list=modalities_list,
    class_num=2,
    modalities_adaptors_inout_sizes_dict=modalities_adaptors_inout_sizes_dict,
    modality2aggr=modality2aggr
    )
model = PhysVerbModel(
    modality_extractors_dict=modality_extractors_dict,
    modality_fusion_module=modality_fusion_module,
    #modality_fusion_module=nn.Sequential(),
    classifiers=aggr_classifiers,
    modality_features_shapes_dict=modality_features_shapes_dict,
    modality2aggr=modality2aggr,
    hidden_size=768,
    class_num=2)

model = model.to(device)

for data, labels, (verb_name, phys_name) in tqdm(test_dataloader):

    ret_dict = model(data)
    
    verb_name = verb_name[0]
    filtered_labels = [v for v in labels if v[0]=='verb']
    if len(filtered_labels) > 0:
        aggr_type, label = filtered_labels[0]
        label = label.item()
        one_hot_label = np.zeros((2,), dtype=int)
        one_hot_label[label] = 1
        res = ret['verb']


In [32]:
ret_dict

{'phys': tensor([[0.0926, 0.0431]], device='cuda:0', grad_fn=<AddmmBackward0>),
 'verb': tensor([[ 0.0345, -0.1720]], device='cuda:0', grad_fn=<AddmmBackward0>)}

In [51]:
aggr_type, label = labels[0]
label = label.item()
one_hot_label = np.zeros((2,), dtype=int)
one_hot_label[label] = 1
one_hot_label

array([0, 1])

In [24]:
verb_name

('c-2_v-2-0_F-7-0_-0.001--0.001_-1',)

In [27]:
labels

[[('verb_EMPTY',), tensor([-1], device='cuda:0')],
 [('phys',), tensor([0], device='cuda:0')]]

In [52]:
ret

{'phys': tensor([[-0.0234,  0.0085]], device='cuda:0', grad_fn=<AddmmBackward0>),
 'verb': tensor([[0.0880, 0.0727]], device='cuda:0', grad_fn=<AddmmBackward0>)}

In [None]:
path_to_checkpoint = r'saving_dir\12.11.2024, 00-37-51 (A+T(ce)+fus1L)\A+T(ce)+fus1L_current_ep-51.pt'
trainer = torch.load(path_to_checkpoint)
saved_model = trainer.model
#state_dict = saved_model.state_dict()
#model.load_state_dict(state_dict=state_dict)
saved_model

In [None]:
saved_model

In [25]:
data

[[('audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_EMPTY',
   'audio_