In [10]:
from datasets.ve8 import *
from core.utils import ModelEMA
from torch.utils.data import DataLoader
import os
# Route HTTP(S) requests made from this kernel through a fixed proxy by setting
# the conventional environment variables.
# NOTE(review): the proxy address is hard-coded; confirm it is reachable from the
# machine that runs this notebook.
proxy = 'http://10.16.106.234:13390'
os.environ['http_proxy'] = proxy
os.environ['https_proxy'] = proxy
import warnings
# Silence every Python warning for the remainder of the session.
warnings.filterwarnings('ignore')


In [11]:
import argparse


def parse_opts(argv=None):
    """Build the experiment configuration as an ``argparse.Namespace``.

    Args:
        argv: Optional sequence of command-line style tokens to parse, e.g.
            ``['--batch_size', '4']``. When ``None`` (the default) an empty
            argument list is parsed so that every option takes its default
            value; ``sys.argv`` is deliberately not consulted because inside a
            notebook it holds the kernel's own arguments.

    Returns:
        The parsed ``argparse.Namespace`` with one attribute per option.
    """
    parser = argparse.ArgumentParser()
    # Options are grouped only for readability; the groups are flattened below.
    arguments = {
        'coefficients': [
            dict(name='--lambda_0',
                 default=0.5,
                 type=float,
                 help='Penalty Coefficient that Controls the Penalty Extent in PCCE'),
        ],
        'paths': [
            dict(name='--exp_name',
                 default='test',
                 type=str,
                 help='exp_name'),
            dict(name='--resnet101_pretrained',
                 default='/mnt/d/MART-main/resnet-101-kinetics.pth',
                 type=str,
                 help='Global path of pretrained 3d resnet101 model (.pth)'),
            dict(name='--root_path',
                 default="/mnt/d/MART-main/datasets/ve8",
                 type=str,
                 help='Global path of root directory'),
            dict(name="--video_path",
                 default="imgs",
                 type=str,
                 help='Local path of videos', ),
            dict(name="--annotation_path",
                 default='ve8_01.json',
                 type=str,
                 help='Local path of annotation file'),
            dict(name="--result_path",
                 default='results',
                 type=str,
                 help="Local path of result directory"),
            dict(name='--expr_name',
                 type=str,
                 default=''),
            dict(name='--audio_path',
                 type=str,
                 default='mp3',
                 help='Local path of audios'),
            dict(name='--alg',
                 type=str,
                 default='MBT',
                 help='Algorithm name (VAANet, MBT, MBT_w_language, TFN or MSAF)'),
            dict(name='--srt_path',
                 type=str,
                 default='srt',
                 help='Local path of text of audios'),
        ],
        'core': [
            dict(name='--batch_size',
                 default=1,
                 type=int,
                 help='Batch Size'),
            dict(name='--accu_step',
                 default=36,
                 type=int,
                 help='Number of gradient accumulation steps'),
            dict(name='--snippet_duration',
                 default=16,
                 type=int),
            dict(name='--sample_size',
                 default=112,
                 type=int,
                 help='Heights and width of inputs'),
            dict(name='--n_classes',
                 default=8,
                 type=int,
                 help='Number of classes'),
            dict(name='--seq_len',
                 default=12,
                 type=int),
            dict(name='--val_len',
                 default=16,
                 type=int),
            dict(name='--r_act',
                 default=12,
                 type=int),
            dict(name='--loss_func',
                 default='pcce_ve8_av',
                 type=str,
                 help='Name of the loss function'),
            dict(name='--learning_rate',
                 default=1e-4,
                 type=float,
                 help='Initial learning rate', ),
            dict(name='--weight_decay',
                 default=0.001,
                 type=float,
                 help='Weight Decay'),
            dict(name='--fps',
                 default=30,
                 type=int,
                 help='fps')

        ],
        'network': [
            {
                'name': '--audio_embed_size',
                'default': 256,
                'type': int,
            },
            {
                'name': '--audio_n_segments',
                'default': 12,
                'type': int,
            },
            {
                'name': '--audio_time',
                'default': 100,
                'type': int,
            }
        ],

        'common': [
            dict(name='--dataset',
                 type=str,
                 default='ve8',
                 ),
            dict(name='--use_cuda',
                 action='store_true',
                 default=False,
                 help='only cuda supported!'
                 ),
            dict(name='--debug',
                 default=False,
                 action='store_true'),
            dict(name='--dl',
                 action='store_false',
                 default=True,
                 help='drop last'),
            dict(
                name='--n_threads',
                default=10,
                type=int,
                help='Number of threads for multi-thread loading',
            ),
            dict(
                name='--n_epochs',
                default=100,
                type=int,
                help='Number of total epochs to run',
            ),
            dict(
                name='--device',
                default=1,
                type=int,
                help='device to run',
            )
        ]
    }

    for group in arguments.values():
        for argument in group:
            # Work on a copy so the specification dicts are left untouched.
            options = dict(argument)
            name = options.pop('name')
            parser.add_argument(name, **options)

    args = parser.parse_args(args=[] if argv is None else list(argv))
    return args
opt = parse_opts()

In [12]:
def load_annotation_data(data_file_path):
    """Read the JSON annotation file at ``data_file_path`` and return the parsed object."""
    with open(data_file_path, 'r') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)
def get_video_names_and_annotations(data, subset):
    """Select the entries of ``data['database']`` that belong to ``subset``.

    Returns:
        ``(video_names, annotations)``: two parallel lists. Each name has the form
        ``'<label>/<key>'`` and each annotation is the entry's ``'annotations'`` dict.
    """
    selected = [
        (video_key, entry['annotations'])
        for video_key, entry in data['database'].items()
        if entry['subset'] == subset
    ]
    video_names = ['{}/{}'.format(annotation['label'], video_key)
                   for video_key, annotation in selected]
    annotations = [annotation for _, annotation in selected]
    return video_names, annotations
def make_dataset(video_root_path, annotation_path, audio_root_path, srt_root_path, subset, fps=30, need_audio=True,
                 ORIGINAL_FPS=30):
    """Enumerate the samples of one subset and describe each as a dict.

    Args:
        video_root_path: Root directory holding one frame folder per ``<label>/<video>``.
        annotation_path: Path of the annotation JSON file.
        audio_root_path: Root directory of the ``<label>/<video>.mp3`` files.
        srt_root_path: Root directory of the ``<label>/<video>.srt`` files.
        subset: Subset name to select from the annotation file.
        fps: Frame rate to sub-sample the frames to.
        need_audio: When true the audio file must exist; otherwise ``'audio'`` is ``None``.
        ORIGINAL_FPS: Frame rate at which the frames on disk were extracted.

    Returns:
        ``(dataset, idx_to_class)``: ``dataset`` is a list of sample dicts with the keys
        ``video``, ``segment``, ``n_frames``, ``video_id``, ``srt``, ``audio``, ``label``
        and ``frame_indices``; ``idx_to_class`` maps class index to class name.

    Raises:
        ValueError: If ``fps`` is not positive.
        AssertionError: If a required video, audio or subtitle path does not exist.
    """
    if fps <= 0:
        raise ValueError('fps must be positive, got {}'.format(fps))

    data = load_annotation_data(annotation_path)
    # video_names entries look like '<label>/<video>'; annotations carry the 'label' key.
    video_names, annotations = get_video_names_and_annotations(data, subset)
    class_to_idx = get_class_labels(data)
    idx_to_class = {idx: name for name, idx in class_to_idx.items()}

    # Frame stride used for sub-sampling. Clamped to 1 so that requesting a higher
    # frame rate than the source keeps every frame instead of giving range() a zero step.
    step = max(1, ORIGINAL_FPS // fps)

    dataset = []
    total = len(video_names)
    for i, (video_name, annotation) in enumerate(zip(video_names, annotations)):
        if i % 100 == 0:
            print("Dataset loading [{}/{}]".format(i, total))
        video_path = os.path.join(video_root_path, video_name)
        if need_audio:
            audio_path = os.path.join(audio_root_path, video_name + '.mp3')
            assert os.path.exists(audio_path), audio_path
        else:
            audio_path = None
        srt_path = os.path.join(srt_root_path, video_name + '.srt')

        assert os.path.exists(video_path), video_path
        assert os.path.exists(srt_path), srt_path

        # The frame folder stores its frame count in a file named 'n_frames'.
        n_frames = int(load_value_file(os.path.join(video_path, 'n_frames')))
        if n_frames <= 0:
            # Report and skip videos without any frame.
            print(video_path)
            continue

        dataset.append({
            'video': video_path,
            'segment': [1, n_frames],
            'n_frames': n_frames,
            'video_id': video_name.split('/')[1],
            'srt': srt_path,
            'audio': audio_path,
            'label': class_to_idx[annotation['label']],
            # Frame numbers are 1-based.
            'frame_indices': list(range(1, n_frames + 1, step)),
        })
    return dataset, idx_to_class

In [13]:
# Build the sample list of the 'training' subset directly with make_dataset.
# With ORIGINAL_FPS=24 and fps=2 every 12th frame is kept (24 // 2 == 12).
dataitem,idx_to_class=make_dataset(annotation_path='dataset.json',
                                  video_root_path='dataset/imgs',
                                  audio_root_path='dataset/mp3',
                                  srt_root_path='dataset/srt',subset='training',
                                  fps=2,ORIGINAL_FPS=24
                                  )

Dataset loading [0/2000]
Dataset loading [100/2000]
Dataset loading [200/2000]
Dataset loading [300/2000]
Dataset loading [400/2000]
Dataset loading [500/2000]
Dataset loading [600/2000]
Dataset loading [700/2000]
Dataset loading [800/2000]
Dataset loading [900/2000]
Dataset loading [1000/2000]
Dataset loading [1100/2000]
Dataset loading [1200/2000]
Dataset loading [1300/2000]
Dataset loading [1400/2000]
Dataset loading [1500/2000]
Dataset loading [1600/2000]
Dataset loading [1700/2000]
Dataset loading [1800/2000]
Dataset loading [1900/2000]


In [8]:
# Inspect the video id of the third sample returned by make_dataset.
dataitem[2]['video_id']

'dis4_11'

In [9]:
class TSN_slide(object):
    """Temporal transform that slices a frame-index list into evenly spaced snippets.

    ``seq_len`` snippets of ``snippet_duration`` consecutive entries are taken, with the
    snippet start positions spread uniformly so that the last snippet ends no later than
    the end of the list.

    Args:
        seq_len: Number of snippets to return.
        snippet_duration: Number of frame indices per snippet.
        center: Accepted for signature compatibility; currently unused.
    """
    def __init__(self, seq_len=12, snippet_duration=16, center=False):
        self.seq_len = seq_len
        self.snippets_duration = snippet_duration

    def __call__(self, frame_indices):
        """Return a list of ``seq_len`` snippets (lists of frame indices).

        Inputs shorter than one snippet are padded by cycling through the available
        indices, so every snippet always has ``snippet_duration`` entries. An empty
        input yields an empty list.
        """
        leng = len(frame_indices)
        if leng == 0:
            return []

        # Distance between the start positions of two consecutive snippets.
        # A single snippet needs no stride (and would otherwise divide by zero).
        if self.seq_len > 1:
            pad = max((leng - self.snippets_duration) // (self.seq_len - 1), 0)
        else:
            pad = 0

        snippets = []
        for i in range(self.seq_len):
            begin = pad * i
            # The modulo only takes effect when the video is shorter than one snippet;
            # otherwise begin + snippet_duration never exceeds leng.
            snippets.append([frame_indices[j % leng]
                             for j in range(begin, begin + self.snippets_duration)])
        return snippets

In [14]:
class VE8Dataset(torch.utils.data.Dataset):
    """Dataset yielding video snippets, audio features and per-segment subtitle text.

    The sample list is built once at construction time by ``make_dataset``. Each item is
    the tuple ``(snippets, target, audios, visualization_item, srt_seg)``.

    Args:
        video_path: Root directory of the extracted frame folders.
        audio_path: Root directory of the ``.mp3`` files.
        annotation_path: Path of the annotation JSON file.
        srt_path: Root directory of the ``.srt`` subtitle files.
        subset: Subset name to load; ``'training'`` additionally enables audio masking.
        fps: Target frame rate forwarded to ``make_dataset``.
        spatial_transform: Per-frame transform; must provide ``randomize_parameters()``.
        temporal_transform: Callable turning a frame-index list into a list of snippets.
        target_transform: Callable turning a sample dict into its target.
        get_loader: Factory returning a ``loader(video_path, frame_indices)`` callable.
        need_audio: Whether audio features are computed; when false ``audios`` is ``[]``.
        alg: Algorithm name selecting the audio front-end
            (``VAANet``/``MBT``/``MBT_w_language``, ``TFN`` or ``MSAF``).
        audio_n_segments: Number of audio/subtitle segments; defaults to
            ``opt.audio_n_segments`` when ``None``.
        ORIGINAL_FPS: Frame rate at which the frames on disk were extracted.
    """

    # Filterbank frames kept per audio segment; with the 10 ms frame shift used
    # below this corresponds to about one second of audio.
    FRAMES_PER_SEGMENT = 100

    def __init__(self,
                 video_path,
                 audio_path,
                 annotation_path,
                 srt_path,
                 subset,
                 fps=30,
                 spatial_transform=None,
                 temporal_transform=None,
                 target_transform=None,
                 get_loader=get_default_video_loader,
                 need_audio=True,
                 alg=opt.alg,
                 audio_n_segments=None,
                 ORIGINAL_FPS=30
                 ):
        self.subset = subset
        self.data, self.class_names = make_dataset(
            video_root_path=video_path,
            annotation_path=annotation_path,
            audio_root_path=audio_path,
            srt_root_path=srt_path,
            subset=subset,
            fps=fps,ORIGINAL_FPS=ORIGINAL_FPS,
            need_audio=need_audio
        )
        self.spatial_transform = spatial_transform
        self.temporal_transform = temporal_transform
        self.target_transform = target_transform
        self.loader = get_loader()
        self.fps = fps
        self.ORIGINAL_FPS = ORIGINAL_FPS
        self.need_audio = need_audio
        self.alg = alg
        # Constants used to normalise the filterbank features.
        self.norm_mean = -6.6268077
        self.norm_std = 5.358466
        self.audio_n_segments = opt.audio_n_segments if audio_n_segments is None else audio_n_segments
        self.sentinet = SenticNet()

    def _load_audio(self, data_item):
        """Compute the audio features of one sample for the configured ``alg``.

        Returns:
            ``(audios, duration_ms)``. ``duration_ms`` is the length of the decoded
            waveform in milliseconds when the waveform was decoded here with
            torchaudio, otherwise ``None``.

        Raises:
            ValueError: If ``self.alg`` has no audio front-end.
        """
        if self.alg == 'VAANet' or self.alg == 'MBT' or self.alg == 'MBT_w_language':
            per_segment = self.FRAMES_PER_SEGMENT
            timeseries_length = per_segment * self.audio_n_segments
            waveform, sr = torchaudio.load(data_item['audio'])
            duration_ms = waveform.shape[1] / sr * 1000
            waveform = waveform - waveform.mean()
            fbank = torchaudio.compliance.kaldi.fbank(waveform, htk_compat=True, sample_frequency=sr, use_energy=False,
                                                      window_type='hanning', num_mel_bins=128, dither=0.0, frame_shift=10)
            if fbank.shape[0] <= timeseries_length:
                # Too short: repeat the features along time, then cut to length.
                k = timeseries_length // fbank.shape[0] + 1
                audios = fbank.repeat(k, 1)[:timeseries_length, :]
            else:
                # Long enough: split into equal blocks and take one randomly
                # positioned window of per_segment frames from each block.
                blk = int(fbank.shape[0] / self.audio_n_segments)
                aud = []
                for i in list(range(0, fbank.shape[0], blk))[:self.audio_n_segments]:
                    ind = i + int(random.random() * (blk - per_segment))
                    aud.append(fbank[ind:ind + per_segment])
                audios = torch.cat(aud)
            if audios.shape[0] != timeseries_length:
                # Report unexpected lengths instead of failing.
                print(audios.shape)
            audios = torch.as_tensor(audios, dtype=torch.float32)
            if self.subset == 'training':
                # Frequency/time masking augmentation; the transforms are applied to a
                # (1, freq, time) view of the (time, freq) features.
                freqm = torchaudio.transforms.FrequencyMasking(24)
                timem = torchaudio.transforms.TimeMasking(192)
                audios = torch.transpose(audios, 0, 1)
                audios = audios.unsqueeze(0)
                audios = freqm(audios)
                audios = timem(audios)
                audios = audios.squeeze(0)
                audios = torch.transpose(audios, 0, 1)
            audios = (audios - self.norm_mean) / (self.norm_std * 2)
            return audios, duration_ms

        if self.alg == 'TFN':
            timeseries_length = 4096
            feature = preprocess_audio(data_item['audio']).T
            k = timeseries_length // feature.shape[0] + 1
            feature = np.tile(feature, reps=(k, 1))
            return torch.FloatTensor(feature[:timeseries_length, :]), None

        if self.alg == 'MSAF':
            timeseries_length = 186
            X, sample_rate = librosa.load(data_item['audio'], duration=2.45, sr=22050 * 2, offset=0.5)
            sample_rate = np.array(sample_rate)
            audios = librosa.feature.mfcc(y=X, sr=sample_rate, n_mfcc=66)
            k = timeseries_length // audios.shape[1] + 1
            audios = np.tile(audios, reps=(1, k))
            return torch.FloatTensor(audios[:, :timeseries_length]), None

        raise ValueError('Unsupported alg for audio loading: {!r}'.format(self.alg))

    def __getitem__(self, index):
        """Load and transform the sample at ``index``.

        Returns:
            ``(snippets, target, audios, visualization_item, srt_seg)`` where ``snippets``
            stacks the transformed snippets, ``audios`` holds the audio features (``[]``
            when ``need_audio`` is false), ``visualization_item`` is ``[video_id]`` and
            ``srt_seg`` is a list of ``audio_n_segments`` subtitle strings.
        """
        data_item = self.data[index]
        video_path = data_item['video']
        frame_indices = data_item['frame_indices']
        snippets_frame_idx = self.temporal_transform(frame_indices)

        if self.need_audio:
            audios, audio_duration_ms = self._load_audio(data_item)
        else:
            audios, audio_duration_ms = [], None

        snippets = []
        for snippet_frame_idx in snippets_frame_idx:
            snippet = self.loader(video_path, snippet_frame_idx)
            snippets.append(snippet)

        # Draw the random augmentation parameters once per sample.
        self.spatial_transform.randomize_parameters()
        snippets_transformed = []
        for snippet in snippets:
            snippet = [self.spatial_transform(img) for img in snippet]
            # Stack the frames and move the frame axis behind the first axis of each image.
            snippet = torch.stack(snippet, 0).permute(1, 0, 2, 3)
            snippets_transformed.append(snippet)
        snippets = torch.stack(snippets_transformed, 0)

        target = self.target_transform(data_item)
        visualization_item = [data_item['video_id']]

        # Distribute the subtitles over the audio segments. The total duration comes
        # from the waveform; it is decoded here only if that has not happened above.
        srt_content = read_srt(data_item['srt'])
        if audio_duration_ms is None:
            if data_item['audio'] is not None:
                waveform, sr = torchaudio.load(data_item['audio'])
                audio_duration_ms = waveform.shape[1] / sr * 1000
            else:
                # No audio file is recorded for this sample (need_audio=False):
                # approximate the duration from the frame count.
                # NOTE(review): assumes n_frames counts frames at ORIGINAL_FPS - confirm.
                audio_duration_ms = data_item['n_frames'] / self.ORIGINAL_FPS * 1000
        srt_seg = srt2seg(self.audio_n_segments, audio_duration_ms, srt_content)

        return snippets, target, audios, visualization_item, srt_seg

    def __len__(self):
        return len(self.data)
def srt2seg(segnum, totaltime, srt_content):
    """Distribute subtitle texts over ``segnum`` equally long time segments.

    Args:
        segnum: Number of segments the total duration is divided into.
        totaltime: Total duration, in the same unit as the subtitle timestamps.
        srt_content: Iterable of dicts with the keys ``'start'``, ``'end'`` and ``'text'``.

    Returns:
        A list of ``segnum`` strings. The text of every subtitle is appended (with a
        leading space) to each segment between the one containing its start time and
        the one containing its end time, inclusive. An empty list if ``segnum <= 0``.
    """
    if segnum <= 0:
        return []

    result = ['' for _ in range(segnum)]
    seglen = totaltime / segnum
    last_idx = segnum - 1

    for content in srt_content:
        start = content['start']
        end = content['end']

        # Segment i covers [i * seglen, (i + 1) * seglen) for start times and
        # (i * seglen, (i + 1) * seglen] for end times.
        start_idx = None
        end_idx = None
        for i in range(segnum):
            seg_begin = i * seglen
            seg_end = (i + 1) * seglen
            if seg_begin <= start < seg_end:
                start_idx = i
            if seg_begin < end <= seg_end:
                end_idx = i

        # Timestamps outside the covered range are clamped to the nearest segment.
        if start_idx is None:
            start_idx = 0 if start < 0 else last_idx
        if end_idx is None:
            end_idx = last_idx if end > 0 else start_idx
        # Never drop a subtitle: it always lands at least in its start segment.
        end_idx = max(end_idx, start_idx)

        for j in range(start_idx, end_idx + 1):
            result[j] = result[j] + ' ' + content['text']
    return result

In [15]:
from core.utils import get_spatial_transform
class ClassLabel(object):
    """Target transform that extracts the value stored under ``'label'``."""

    def __call__(self, target):
        label = target['label']
        return label

# Build the three transforms that are passed to VE8Dataset below.
spatial_transform = get_spatial_transform(opt, 'train')
# NOTE: TSN_slide accepts `center` but does not use it.
temporal_transform = TSN_slide(seq_len=opt.seq_len, snippet_duration=opt.snippet_duration, center=False)
target_transform = ClassLabel()

In [16]:

# Build the 'training' dataset. fps equals ORIGINAL_FPS here, so make_dataset keeps
# every frame (stride 30 // 30 == 1).
dataset=VE8Dataset(annotation_path='dataset.json',
                   video_path='dataset/imgs',
                   audio_path='dataset/mp3',
                   srt_path='dataset/srt',
                   subset='training',ORIGINAL_FPS=30
                ,fps=30,spatial_transform=spatial_transform,temporal_transform=temporal_transform,target_transform=target_transform)

# One sample per batch, in dataset order (shuffle=False), loaded in the main process.
loader=DataLoader(dataset,batch_size=1,shuffle=False,num_workers=0,pin_memory=True,drop_last=True)

Dataset loading [0/2000]
Dataset loading [100/2000]
Dataset loading [200/2000]
Dataset loading [300/2000]
Dataset loading [400/2000]
Dataset loading [500/2000]
Dataset loading [600/2000]
Dataset loading [700/2000]
Dataset loading [800/2000]
Dataset loading [900/2000]
Dataset loading [1000/2000]
Dataset loading [1100/2000]
Dataset loading [1200/2000]
Dataset loading [1300/2000]
Dataset loading [1400/2000]
Dataset loading [1500/2000]
Dataset loading [1600/2000]
Dataset loading [1700/2000]
Dataset loading [1800/2000]
Dataset loading [1900/2000]


In [10]:
# Fetch a single batch from the loader, print it in full, then stop iterating.
for i in loader:
    print(i)
    break

[tensor([[[[[[194., 192., 192.,  ..., 188., 189., 189.],
            [193., 190., 190.,  ..., 185., 186., 186.],
            [191., 188., 188.,  ..., 180., 181., 182.],
            ...,
            [184., 183., 182.,  ..., 186., 187., 188.],
            [187., 185., 184.,  ..., 188., 189., 190.],
            [189., 187., 186.,  ..., 191., 191., 191.]],

           [[194., 191., 191.,  ..., 188., 189., 189.],
            [193., 189., 189.,  ..., 185., 186., 186.],
            [191., 187., 186.,  ..., 180., 181., 182.],
            ...,
            [184., 183., 181.,  ..., 186., 187., 188.],
            [186., 185., 184.,  ..., 188., 189., 190.],
            [189., 187., 186.,  ..., 191., 191., 191.]],

           [[194., 191., 191.,  ..., 188., 189., 189.],
            [193., 189., 189.,  ..., 185., 186., 186.],
            [191., 187., 186.,  ..., 180., 181., 182.],
            ...,
            [184., 183., 181.,  ..., 186., 187., 188.],
            [186., 185., 184.,  ..., 188., 189.,

In [11]:
# Summarise the first batch: print the shape of every element that has one
# (tensors) and the element itself otherwise (e.g. nested lists of strings).
for j in loader:
    for i in j:
        print(getattr(i, 'shape', i))
    break

torch.Size([1, 12, 3, 16, 112, 112])
torch.Size([1])
torch.Size([1, 1200, 128])
[['dis4_1']]
[[''], [''], [''], [''], [''], [''], [''], [''], [''], [''], [''], ['']]


In [12]:

# Override a few options in place on the shared `opt` namespace.
# NOTE(review): objects already built from `opt` (e.g. `temporal_transform`, `dataset`)
# keep the values they were constructed with; re-run those cells if the new values
# are meant to apply to them.
opt.seq_len=30
opt.audio_n_segments=30
opt.snippet_duration=16
opt.val_len=16

In [13]:
# Replay the first steps of VE8Dataset.__getitem__ by hand for sample 0:
# look up the sample record and turn its frame indices into snippets.
data_item = dataset.data[0]
video_path = data_item['video']
frame_indices = data_item['frame_indices']
snippets_frame_idx = dataset.temporal_transform(frame_indices)

In [14]:

# Initialise the run: seed, device list, resolved paths, model, EMA copy, loss, optimizer.
setup_seed()
opt.device_ids = list(range(device_count()))
local2global_path(opt)
#print(opt)
model, parameters = generate_model(opt)
model_ema = ModelEMA(model,decay=0.999)
criterion = get_loss(opt)
# NOTE(review): the loss is moved to the GPU unconditionally, although
# opt.use_cuda defaults to False - confirm a CUDA device is always available.
criterion = criterion.cuda()
optimizer = get_optim(opt, parameters)

# Create the experiment directory if needed; exist_ok avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(opt.exp_name, exist_ok=True)
writer = SummaryWriter(logdir=opt.exp_name)

# Bundle the tokenizer and its special-token settings for text processing.
tokenizer, max_input_length, init_token_idx, eos_token_idx, _, _ = initialize_tokenizer()
text_tools = {
    'tokenizer': tokenizer,
    'max_input_length': max_input_length,
    'init_token_idx': init_token_idx,
    'eos_token_idx': eos_token_idx
}

# train
# spatial_transform = get_spatial_transform(opt, 'train')
# temporal_transform = TSN(seq_len=opt.seq_len, snippet_duration=opt.snippet_duration, center=False)
# target_transform = ClassLabel()

/mnt/d/MART-main/datasets/ve8/results/result_20240920_001656/tensorboard /mnt/d/MART-main/datasets/ve8/results/result_20240920_001656/checkpoints
Namespace(lambda_0=0.5, exp_name='test', resnet101_pretrained='/mnt/d/MART-main/resnet-101-kinetics.pth', root_path='/mnt/d/MART-main/datasets/ve8', video_path='/mnt/d/MART-main/datasets/ve8/imgs', annotation_path='/mnt/d/MART-main/datasets/ve8/ve8_01.json', result_path='/mnt/d/MART-main/datasets/ve8/results/result_20240920_001656', expr_name='', audio_path='/mnt/d/MART-main/datasets/ve8/mp3', alg='MBT', srt_path='/mnt/d/MART-main/datasets/ve8/srt', batch_size=1, accu_step=36, snippet_duration=16, sample_size=112, n_classes=8, seq_len=30, val_len=16, r_act=12, loss_func='pcce_ve8_av', learning_rate=0.0001, weight_decay=0.001, fps=30, audio_embed_size=256, audio_n_segments=30, audio_time=100, dataset='ve8', use_cuda=False, debug=False, dl=True, n_threads=10, n_epochs=100, device=1, device_ids=[0], log_path='/mnt/d/MART-main/datasets/ve8/result

In [15]:
# Per-dataset label metadata: class names and one sentiment flag per class,
# listed in the same order.
label_info = {
    've8': {
        'emotion': np.array(['Anger', 'Anticipation', 'Disgust', 'Fear',  'Joy', 'Sad', 'Surprise', 'Trust']),
        'sentiment': np.array([0, 1, 0, 0, 1, 0, 1, 1]),
    },
}

In [17]:
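# seq_len: sequence length, i.e. the number of snippets sampled from each video. 12
# batch: batch size, i.e. the number of videos processed at once. 1
# nc: number of channels; video frames normally have 3 colour channels (RGB). 3
# snippet_duration: snippet duration, i.e. the number of frames in each snippet. 16
# sample_size: sample size, i.e. the spatial size of each frame (height and width). 112
# _: an underscore variable is conventionally used to discard a value that is not needed.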


In [18]:
def test_performance(opt, inputs, model, criterion, i=0, print_attention=True, period=30, return_attention=False, isTrain=True):
    if opt.alg=='VAANet':
        visual, target, audio = inputs
        outputs = model(visual, audio)
        y_pred, alpha, beta, gamma, temporal_score = outputs
        loss = criterion(y_pred, target)
    else:
        visual, target, audio = inputs
        outputs = model(visual, audio)
        y_pred, temporal_score = outputs
        #loss = criterion(y_pred, target)
    return y_pred,temporal_score

In [20]:
# List the entries of the two 'Joy' directories (under dataset/mp4 and dataset/videos).
video_dir=os.listdir('dataset/mp4/Joy')
video_large=os.listdir('dataset/videos/Joy')


In [24]:
# Display the listing of dataset/mp4/Joy.
video_dir

['dis4_1.mp4',
 'dis4_10.mp4',
 'dis4_11.mp4',
 'dis4_12.mp4',
 'dis4_13.mp4',
 'dis4_14.mp4',
 'dis4_15.mp4',
 'dis4_16.mp4',
 'dis4_17.mp4',
 'dis4_18.mp4',
 'dis4_19.mp4',
 'dis4_2.mp4',
 'dis4_20.mp4',
 'dis4_21.mp4',
 'dis4_22.mp4',
 'dis4_23.mp4',
 'dis4_24.mp4',
 'dis4_25.mp4',
 'dis4_26.mp4',
 'dis4_27.mp4',
 'dis4_28.mp4',
 'dis4_29.mp4',
 'dis4_3.mp4',
 'dis4_30.mp4',
 'dis4_31.mp4',
 'dis4_32.mp4',
 'dis4_33.mp4',
 'dis4_34.mp4',
 'dis4_35.mp4',
 'dis4_36.mp4',
 'dis4_37.mp4',
 'dis4_38.mp4',
 'dis4_39.mp4',
 'dis4_4.mp4',
 'dis4_40.mp4',
 'dis4_41.mp4',
 'dis4_42.mp4',
 'dis4_43.mp4',
 'dis4_44.mp4',
 'dis4_45.mp4',
 'dis4_46.mp4',
 'dis4_47.mp4',
 'dis4_48.mp4',
 'dis4_49.mp4',
 'dis4_5.mp4',
 'dis4_50.mp4',
 'dis4_51.mp4',
 'dis4_52.mp4',
 'dis4_53.mp4',
 'dis4_54.mp4',
 'dis4_55.mp4',
 'dis4_56.mp4',
 'dis4_57.mp4',
 'dis4_58.mp4',
 'dis4_59.mp4',
 'dis4_6.mp4',
 'dis4_60.mp4',
 'dis4_61.mp4',
 'dis4_62.mp4',
 'dis4_63.mp4',
 'dis4_64.mp4',
 'dis4_65.mp4',
 'dis4_66.mp4'

In [29]:
# Check the shape of the concatenated embeddings of the most recent model call.
# NOTE(review): in the visible notebook `embedding` is only assigned by the extraction
# loop in a later cell, so this cell fails on a fresh top-to-bottom run.
torch.concat(embedding,dim=0).shape

torch.Size([2, 768])

In [28]:

# Run the model over every sample and store its embeddings as one .npy file per video id.
os.makedirs('dataset/embedding', exist_ok=True)
# Inference only: no_grad avoids building the autograd graph for every sample.
# NOTE(review): the model is left in whatever train/eval mode it currently has -
# consider calling model.eval() first if deterministic features are wanted.
with torch.no_grad():
    for i, data_item in enumerate(loader):
        visual, target, audio, visualization_item, batch_size = process_data_item(opt, data_item)
        _,_,embedding=model(visual,audio)
        # `loader` iterates `dataset` in order with batch_size=1, so batch i is
        # sample i of the very dataset behind the loader.
        video_id = dataset.data[i]['video_id']
        np.save('dataset/embedding/'+video_id+'.npy',torch.concat(embedding,dim=0).detach().cpu().numpy())

feature shape torch.Size([1, 768]) torch.Size([1, 768])
out shape torch.Size([1, 8]) torch.Size([1, 8])
feature shape torch.Size([1, 768]) torch.Size([1, 768])
out shape torch.Size([1, 8]) torch.Size([1, 8])
feature shape torch.Size([1, 768]) torch.Size([1, 768])
out shape torch.Size([1, 8]) torch.Size([1, 8])
feature shape torch.Size([1, 768]) torch.Size([1, 768])
out shape torch.Size([1, 8]) torch.Size([1, 8])
feature shape torch.Size([1, 768]) torch.Size([1, 768])
out shape torch.Size([1, 8]) torch.Size([1, 8])
feature shape torch.Size([1, 768]) torch.Size([1, 768])
out shape torch.Size([1, 8]) torch.Size([1, 8])
feature shape torch.Size([1, 768]) torch.Size([1, 768])
out shape torch.Size([1, 8]) torch.Size([1, 8])
feature shape torch.Size([1, 768]) torch.Size([1, 768])
out shape torch.Size([1, 8]) torch.Size([1, 8])
feature shape torch.Size([1, 768]) torch.Size([1, 768])
out shape torch.Size([1, 8]) torch.Size([1, 8])
feature shape torch.Size([1, 768]) torch.Size([1, 768])
out shap

In [49]:
from scipy.io import savemat
# For every full-length video, gather the embeddings of its numbered clips
# (<stem>_1.npy, <stem>_2.npy, ...) and export them as a single .mat file.
embedding_dir=os.listdir('dataset/embedding')
video_large_dir=os.listdir('dataset/videos/Joy')
embedding_files = set(embedding_dir)  # constant-time membership tests
for video in video_large_dir:
    stem = video.split(".")[0]
    clips = []
    i = 1
    # Clip numbers are consecutive; stop at the first missing one.
    while f'{stem}_{i}.npy' in embedding_files:
        clips.append(np.load(f'dataset/embedding/{stem}_{i}.npy'))
        i += 1
    if not clips:
        print(f'no embeddings found for {video}, skipping')
        continue
    # Stack once instead of concatenating inside the loop.
    data = np.stack(clips, axis=0)
    print(data.shape)
    # Along the second axis, index 0 is saved as the visual feature and index 1 as the audio feature.
    dict_data={'visual_feature':data[:,0], 'audio_feature':data[:,1]}
    savemat(f'dataset/{stem}.mat', dict_data)

(72, 2, 768)
(80, 2, 768)
(150, 2, 768)
(83, 2, 768)
(90, 2, 768)
(97, 2, 768)
(115, 2, 768)
(107, 2, 768)
(66, 2, 768)
(57, 2, 768)
(121, 2, 768)
(100, 2, 768)
(105, 2, 768)
(103, 2, 768)
(111, 2, 768)
(73, 2, 768)
(108, 2, 768)
(72, 2, 768)
(115, 2, 768)
(65, 2, 768)
(110, 2, 768)
