In [None]:
# Load the video
# Extract video features
# Audio -> spectrogram

In [14]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torch.nn.functional as F

import pandas as pd
import numpy as np

import pickle
import re
import random
import math
from gensim.models.keyedvectors import KeyedVectors

from functools import partial
from timm.models.vision_transformer import DropPath, Mlp, Attention

In [15]:
class MSRVTT_DataLoader(Dataset):
    """Dataset over a pickled list of MSR-VTT records.

    Every record is a dict from which this class reads the keys 'id',
    '2d_pooled', '3d_pooled', 'audio', 'caption', 'eval_caption' and
    'category'. Each item combines the pooled video features, a fixed-length
    audio matrix and a fixed-length matrix of caption word vectors.

    Args:
        data_path: Path of the pickle file holding the list of records.
        we: Word-embedding lookup. It must expose ``key_to_index`` and accept
            a list of words in ``we[words]`` (gensim ``KeyedVectors`` does).
        we_dim: Size of a single word vector.
        max_words: Captions are zero-padded / truncated to this many words.
        num_frames_multiplier: The audio is padded / truncated to
            ``1024 * num_frames_multiplier`` frames along its second axis.
        training: When True a random caption is drawn from 'caption';
            otherwise the fixed 'eval_caption' is used.
    """

    def __init__(
            self,
            data_path,
            we,  # word embedding lookup
            we_dim=300,
            max_words=30,
            num_frames_multiplier=5,  # controls the fixed audio length
            training=True
    ):
        # Close the file deterministically instead of leaking the handle.
        # NOTE(review): unpickling can execute arbitrary code; only load
        # files from a trusted source.
        with open(data_path, 'rb') as data_file:
            self.data = pickle.load(data_file)
        self.we = we
        self.we_dim = we_dim
        self.max_words = max_words
        self.max_video = 30  # not read anywhere inside this class
        self.num_frames_multiplier = num_frames_multiplier
        self.training = training

    def __len__(self):
        """Return the number of records in the pickle."""
        return len(self.data)

    def _zero_pad_tensor(self, tensor, size):
        """Return ``tensor`` cut or zero-padded to exactly ``size`` rows.

        Padding rows have ``self.we_dim`` columns and dtype float32.
        """
        if len(tensor) >= size:
            return tensor[:size]
        zero = np.zeros((size - len(tensor), self.we_dim), dtype=np.float32)
        return np.concatenate((tensor, zero), axis=0)

    def _tokenize_text(self, sentence):
        """Split ``sentence`` into runs of word characters and apostrophes."""
        return re.findall(r"[\w']+", str(sentence))

    def _words_to_we(self, words):
        """Map words to a ``(max_words, we_dim)`` tensor of word vectors.

        Words missing from the embedding vocabulary are dropped; if none is
        left an all-zero tensor is returned.
        """
        words = [word for word in words if word in self.we.key_to_index]
        if words:
            we = self._zero_pad_tensor(self.we[words], self.max_words)
            return torch.from_numpy(we)
        return torch.zeros(self.max_words, self.we_dim)

    def _get_caption(self, idx):
        """Chooses random caption if training. Uses set caption if evaluating."""
        record = self.data[idx]
        if self.training:
            sentence = random.choice(record['caption'])
        else:
            sentence = record['eval_caption']
        return self._words_to_we(self._tokenize_text(sentence))

    def __getitem__(self, idx):
        """Build one sample dict for record ``idx``.

        Returns:
            dict with 'video' (concatenated, L2-normalised 2D and 3D pooled
            features), 'text' (caption word vectors), 'video_id', 'audio'
            (float tensor padded / truncated along axis 1), 'nframes' (the
            audio length before padding / truncation) and 'category'.
        """
        record = self.data[idx]
        video_id = record['id']

        # Pooled 2D / 3D features are normalised separately, then joined.
        feat_2d = F.normalize(torch.from_numpy(record['2d_pooled']).float(), dim=0)
        feat_3d = F.normalize(torch.from_numpy(record['3d_pooled']).float(), dim=0)
        video = torch.cat((feat_2d, feat_3d))

        # Accept either a torch tensor or a numpy array for the audio and
        # bring it to a fixed number of frames along axis 1.
        audio = torch.as_tensor(record['audio'], dtype=torch.float32)
        target_length = 1024 * self.num_frames_multiplier
        nframes = audio.shape[1]
        p = target_length - nframes
        if p > 0:
            # Too short: append p zero frames on the right of the last axis.
            audio = F.pad(audio, (0, p), mode='constant', value=0.0)
        elif p < 0:
            # Too long: keep the first target_length frames.
            audio = audio[:, :target_length]

        caption = self._get_caption(idx)
        category = record['category']

        return {'video': video, 'text': caption, 'video_id': video_id,
                'audio': audio, 'nframes': nframes, 'category': category}

In [16]:
# NOTE(review): machine-specific absolute paths; adjust them for your environment.
we_path = 'C:/Users/heeryung/code/24w_deep_daiv/GoogleNews-vectors-negative300.bin'
data_path = 'C:/Users/heeryung/code/24w_deep_daiv/msrvtt_category_test.pkl'

# Word vectors in binary word2vec format, used to embed caption words.
we = KeyedVectors.load_word2vec_format(we_path, binary=True)
dataset = MSRVTT_DataLoader(data_path=data_path, we=we)

# Raw records straight from the pickle, kept for inspection next to the
# processed dataset. The context manager closes the file handle after loading.
# NOTE(review): unpickling can execute arbitrary code; only load trusted files.
with open(data_path, 'rb') as raw_file:
    ori_dataset = pickle.load(raw_file)

In [30]:
# Display the first raw record loaded from the pickle.
ori_dataset[0]

{'id': 'video7020',
 'audio': tensor([[-80.0000, -80.0000, -80.0000,  ..., -42.0759, -39.5331, -71.7272],
         [-80.0000, -80.0000, -80.0000,  ..., -33.3933, -31.5323, -68.4606],
         [-80.0000, -80.0000, -80.0000,  ..., -13.8896, -23.3922, -69.1881],
         ...,
         [-64.5946, -65.5943, -66.7466,  ..., -31.8602, -35.7037, -48.6106],
         [-69.1042, -68.7961, -67.8672,  ..., -39.7601, -41.4094, -53.2567],
         [-67.2767, -67.4139, -68.0604,  ..., -63.2213, -64.2628, -67.5640]]),
 '3d': array([[0.00738  , 0.007713 , 0.00536  , ..., 0.00956  , 0.02512  ,
         0.0010195],
        [0.0008926, 0.0002233, 0.00241  , ..., 0.00905  , 0.00473  ,
         0.000507 ],
        [0.000391 , 0.003115 , 0.005142 , ..., 0.004242 , 0.00395  ,
         0.000588 ],
        ...,
        [0.       , 0.       , 0.02353  , ..., 0.001635 , 0.000309 ,
         0.003887 ],
        [0.       , 0.       , 0.00997  , ..., 0.01394  , 0.005226 ,
         0.00604  ],
        [0.0002294, 0.  

In [31]:
# Scan the raw records and keep the caption list and the evaluation caption
# of the clip whose id is 'video7061'.
for data in ori_dataset:
    if data['id'] != 'video7061':
        continue
    caption_7061 = data['caption']
    eval_7061 = data['eval_caption']

In [32]:
# Scan the raw records and keep the caption list and the evaluation caption
# of the clip whose id is 'video7118'.
for data in ori_dataset:
    if data['id'] != 'video7118':
        continue
    caption_7118 = data['caption']
    eval_7118 = data['eval_caption']

In [33]:
# Display the evaluation caption collected for video7118.
eval_7118

'a young girl in a horror movie is haunted'

In [34]:
# Collect the captions of both clips for export. The evaluation captions are
# stored under their own keys: previously 'eval_caption_7118' was assigned
# twice with the caption *list*, and neither eval caption was saved.
caption = {
    'caption_7061': caption_7061,
    'caption_7118': caption_7118,
    'eval_caption_7061': eval_7061,
    'eval_caption_7118': eval_7118,
}
caption

{'caption_7061': ['in a fish tank two red fishes are playing',
  'music playing in the background showing fish swiming around',
  'two orange and white fish are swimming together',
  'two fishes are moving in a water aquarium',
  'there are two fish floating in to the water',
  'a fish tank with two gold fish and plants',
  'goldfish chase each other around a blue tank to music',
  'gold fishes are swimming in the blue water of aquarium',
  'two orange and white fish frolic around an aquatic plant in a pool with a deep blue bottom',
  'two coy fish swimming and eating as they play',
  'two cute little gold fish are playing in water so nice to see',
  'the orange fishes are present in the aquarium looking very beautifull',
  'in the water fish seem to be eating from a plant',
  'the orange fish are swimming in the aquarium',
  'the orange fish are swimming in the aquarium',
  'two little koi fish keep swimming next to each other',
  'in an acquarium tank two fishes are playing and the c

In [29]:
import json

# Serialise the collected captions and write them to a JSON file.
# NOTE(review): machine-specific absolute path.
with open("C:/Users/heeryung/code/24w-Tri-Modalities/test_caption.json", "w") as json_file:
    json_file.write(json.dumps(caption))

In [4]:
ori_dataset[0] # compare against the actual video7020

{'id': 'video7020',
 'audio': tensor([[-80.0000, -80.0000, -80.0000,  ..., -42.0759, -39.5331, -71.7272],
         [-80.0000, -80.0000, -80.0000,  ..., -33.3933, -31.5323, -68.4606],
         [-80.0000, -80.0000, -80.0000,  ..., -13.8896, -23.3922, -69.1881],
         ...,
         [-64.5946, -65.5943, -66.7466,  ..., -31.8602, -35.7037, -48.6106],
         [-69.1042, -68.7961, -67.8672,  ..., -39.7601, -41.4094, -53.2567],
         [-67.2767, -67.4139, -68.0604,  ..., -63.2213, -64.2628, -67.5640]]),
 '3d': array([[0.00738  , 0.007713 , 0.00536  , ..., 0.00956  , 0.02512  ,
         0.0010195],
        [0.0008926, 0.0002233, 0.00241  , ..., 0.00905  , 0.00473  ,
         0.000507 ],
        [0.000391 , 0.003115 , 0.005142 , ..., 0.004242 , 0.00395  ,
         0.000588 ],
        ...,
        [0.       , 0.       , 0.02353  , ..., 0.001635 , 0.000309 ,
         0.003887 ],
        [0.       , 0.       , 0.00997  , ..., 0.01394  , 0.005226 ,
         0.00604  ],
        [0.0002294, 0.  

In [5]:
def conv1d(in_planes, out_planes, width=9, stride=1, bias=False):
    """Build a ``1 x width`` 2-D convolution padded along the last axis.

    The padding is ``width // 2`` on each side of the last axis and zero on
    the other axis, which matches the explicit even / odd case split this
    helper used before.
    """
    half_width = width // 2
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=(1, width),
        stride=stride,
        padding=(0, half_width),
        bias=bias,
    )

class SpeechBasicBlock(nn.Module):
    """Residual block made of two ``conv1d`` layers, each followed by batch norm.

    Args:
        inplanes: Channels of the block input.
        planes: Channels produced by both convolutions.
        width: Kernel width handed to ``conv1d``.
        stride: Stride of the first convolution only.
        downsample: Optional module applied to the input before it is added
            back as the residual.
    """
    expansion = 1

    def __init__(self, inplanes, planes, width=9, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv1d(inplanes, planes, width=width, stride=stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        # The second convolution uses conv1d's default stride.
        self.conv2 = conv1d(planes, planes, width=width)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Identity shortcut unless a downsample module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
    
class ResDavenet(nn.Module):
    """Residual convolutional encoder for 2-D audio feature maps.

    A ``(feat_dim, 1)`` convolution is followed by four stride-2 residual
    stages; two further stages are added only when ``layers`` has exactly six
    entries.

    Args:
        feat_dim: Height of the first convolution's kernel.
        block: Residual block class; must expose ``expansion``.
        layers: Number of blocks per stage.
        layer_widths: Channel widths; entry 0 is the stem width and entry
            ``i`` the width of stage ``i``.
        convsize: Kernel width passed to every block.
    """

    def __init__(self, feat_dim=40, block=SpeechBasicBlock, layers=[2, 2, 2, 2], layer_widths=[128, 128, 256, 512, 1024], convsize=9):
        super().__init__()
        self.feat_dim = feat_dim
        self.inplanes = layer_widths[0]
        # Registered (so it appears in the state dict) but not called in forward().
        self.batchnorm1 = nn.BatchNorm2d(1)
        self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=(self.feat_dim, 1), stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)

        # Stages 1-4 always exist and are registered as layer1 ... layer4.
        for stage_idx in range(1, 5):
            stage = self._make_layer(block, layer_widths[stage_idx], layers[stage_idx - 1],
                                     width=convsize, stride=2)
            setattr(self, f'layer{stage_idx}', stage)

        # Stages 5 and 6 are built only for the six-stage configuration.
        six_stages = len(layers) == 6
        for stage_idx in (5, 6):
            stage = None
            if six_stages:
                stage = self._make_layer(block, layer_widths[stage_idx], layers[stage_idx - 1],
                                         width=convsize, stride=2)
            setattr(self, f'layer{stage_idx}', stage)

        # Re-initialise: normal(0, sqrt(2 / (kh * kw * out_channels))) for
        # convolutions, weight 1 and bias 0 for batch norms.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kernel_h, kernel_w = module.kernel_size
                fan = kernel_h * kernel_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, width=9, stride=1):
        """Build one stage of ``blocks`` residual blocks and update ``self.inplanes``."""
        out_planes = planes * block.expansion
        downsample = None
        # A 1x1 projection is needed on the shortcut whenever the first block
        # changes the resolution or the channel count.
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        stack = [block(self.inplanes, planes, width=width, stride=stride, downsample=downsample)]
        self.inplanes = out_planes
        # The remaining blocks keep the resolution (stride 1, no projection).
        stack.extend(block(self.inplanes, planes, width=width, stride=1) for _ in range(1, blocks))
        return nn.Sequential(*stack)

    def forward(self, x):
        """Encode ``x``; a 3-D input first gets a channel axis inserted at dim 1."""
        if x.dim() == 3:
            x = x.unsqueeze(1)
        x = self.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        if self.layer5 is not None:
            x = self.layer6(self.layer5(x))
        # Drop dim 2 (only removed when its size is 1).
        return x.squeeze(2)

def load_DAVEnet(v2=True):
    """Create a ``ResDavenet`` audio encoder.

    ``v2=True`` builds the six-stage variant whose last stage is 4096
    channels wide; ``v2=False`` builds the four-stage variant ending at 1024
    channels.
    """
    if not v2:
        return ResDavenet(feat_dim=40, layers=[2, 2, 2, 2], convsize=9,
                          layer_widths=[128, 128, 256, 512, 1024])
    return ResDavenet(feat_dim=40, layers=[2,2,1,1,1,1], convsize=9,
                      layer_widths=[128,128,256,512,1024,2048,4096])

In [6]:
class Context_Gating(nn.Module):
    """Gate an input with a learned linear transform of itself."""

    def __init__(self, dimension):
        super().__init__()
        self.fc = nn.Linear(dimension, dimension)

    def forward(self, x):
        """Return ``x * sigmoid(fc(x))`` computed through a GLU.

        ``x`` and its linear transform are stacked along dim 1 and the GLU
        splits that same axis in half again, so the first half (``x``) is
        gated by the sigmoid of the second half (``fc(x)``).
        """
        gate = self.fc(x)
        stacked = torch.cat((x, gate), dim=1)
        return F.glu(stacked, dim=1)
    
class Gated_Embedding_Unit(nn.Module):
    """Linear projection to ``output_dimension`` followed by context gating."""

    def __init__(self, input_dimension, output_dimension):
        super().__init__()
        # Projection that changes the feature size.
        self.fc = nn.Linear(input_dimension, output_dimension)
        # Gating applied in the projected space.
        self.cg = Context_Gating(output_dimension)

    def forward(self, x):
        """Project ``x`` and gate the result."""
        return self.cg(self.fc(x))
    
class projection_net(nn.Module):
    """Project audio, text and video inputs into a common embedding size.

    Args:
        embed_dim: Output feature size of all three gated embedding units.
        video_dim: Feature size of the incoming video vectors.
        we_dim: Feature size of the incoming word vectors.
        cross_attention: Stored on the instance; not used by this class.
    """

    def __init__(
            self,
            embed_dim=1024,
            video_dim=4096,
            we_dim=300,
            cross_attention=False
    ):
        super().__init__()
        self.cross_attention = cross_attention

        # No fusion happens here: every modality is projected independently.
        self.DAVEnet = load_DAVEnet(v2=True)
        # 4096 is the width of the last stage of the v2 DAVEnet configuration.
        self.GU_audio = Gated_Embedding_Unit(4096, embed_dim)
        self.GU_video = Gated_Embedding_Unit(video_dim, embed_dim)
        self.GU_text_captions = Gated_Embedding_Unit(we_dim, embed_dim)

    def forward(self, video, audio_input, nframes, text=None):
        """Return ``(audio, text, video)`` embeddings.

        ``nframes`` is accepted but not used. ``text`` defaults to None, but
        it is passed straight into a linear layer, so a tensor is required.
        """
        # Swap dims 1 and 2 of the encoder output so the gated unit acts on
        # the last axis.
        # presumably (batch, channels, time) -> (batch, time, channels) — TODO confirm
        audio_tokens = self.DAVEnet(audio_input).permute(0, 2, 1)

        text = self.GU_text_captions(text)
        audio = self.GU_audio(audio_tokens)
        video = self.GU_video(video)
        return audio, text, video

In [7]:
class FusionBlock(nn.Module):
    """Pre-norm cross-attention block followed by an MLP, both with residuals.

    ``qkv_bias``, ``attn_drop`` and ``init_values`` are accepted but not used
    by this implementation.
    """

    def __init__(
            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
            drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_softmax=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.cross_attn = MultiHeadCrossAttention(d_model=dim, n_head=num_heads, use_softmax=use_softmax)
        # Stochastic depth is active only for a positive drop-path rate.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

    def forward(self, k, q, attention_mask=None):
        """Attend from ``q`` to ``k``; ``attention_mask`` is accepted but ignored.

        Only the query takes part in the first residual connection, and the
        same ``norm1`` layer normalises both key and query.
        """
        attended = self.cross_attn(self.norm1(k), self.norm1(q))
        fused = q + self.drop_path(attended)
        return fused + self.drop_path(self.mlp(self.norm2(fused)))

class ScaleDotProductAttention(nn.Module):
    """Scaled dot-product attention with an optional softmax over the keys."""

    def __init__(self, use_softmax):
        super().__init__()
        self.use_softmax = use_softmax
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, e=1e-12):
        """Return ``(weights @ v, weights)``.

        ``k`` must be 4-D, ``[batch_size, head, length, d_tensor]``. The
        scores are ``q @ k^T / sqrt(d_tensor)`` and are passed through a
        softmax only when ``use_softmax`` is set. ``e`` is unused.
        """
        _batch, _heads, _length, d_tensor = k.size()

        weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(d_tensor)
        if self.use_softmax:
            weights = self.softmax(weights)

        return torch.matmul(weights, v), weights
    
class MultiHeadCrossAttention(nn.Module):
    """Multi-head attention whose keys and values both come from ``k``."""

    def __init__(self, d_model, n_head, use_softmax):
        super().__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.use_softmax = use_softmax
        self.attention = ScaleDotProductAttention(use_softmax)

        self.w_k = nn.Linear(d_model, d_model)
        self.w_q = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_concat = nn.Linear(d_model, d_model)

    def forward(self, k, q):
        """Attend from ``q`` to ``k``; the attention weights are discarded."""
        queries = self.split(self.w_q(q))
        keys = self.split(self.w_k(k))
        # The values are a separate projection of the key input.
        values = self.split(self.w_v(k))
        context, _ = self.attention(queries, keys, values)
        return self.w_concat(self.concat(context))

    def split(self, tensor):
        """[batch_size, length, d_model] -> [batch_size, head, length, d_model // head]."""
        batch_size, length, d_model = tensor.size()
        per_head = d_model // self.n_head
        return tensor.view(batch_size, length, self.n_head, per_head).transpose(1, 2)

    def concat(self, tensor):
        """[batch_size, head, length, d_tensor] -> [batch_size, length, d_model]."""
        batch_size, head, length, d_tensor = tensor.size()
        merged = tensor.transpose(1, 2).contiguous()
        return merged.view(batch_size, length, self.d_model)

In [8]:
class FusionTransformer(nn.Module):
    """Stack of ``FusionBlock`` cross-attention layers.

    Args:
        embed_dim: Token feature size.
        depth: Number of fusion blocks.
        num_heads: Attention heads per block.
        mlp_ratio: Hidden-size ratio of each block's MLP.
        qkv_bias, drop_rate, attn_drop_rate: Forwarded to every block.
        drop_path_rate: Maximum stochastic-depth rate; block ``i`` receives
            the ``i``-th value of a linear ramp from 0 to this rate.
        norm_layer: Norm factory; defaults to ``LayerNorm(eps=1e-6)``.
        act_layer: Activation class; defaults to ``nn.GELU``.
        use_cls_token: Create a ``cls_token`` parameter when True.
        num_classes: Output size of ``mlp_head``.
        use_softmax: Forwarded to the attention inside every block.

    ``cls_token``, ``masking_token``, ``norm`` and ``mlp_head`` are created
    (and therefore part of the state dict) but are not used in ``forward``.
    """

    def __init__(self, embed_dim=1024, depth=1, num_heads=64, mlp_ratio=1, qkv_bias=True,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
                 act_layer=None,
                 use_cls_token=True,
                 num_classes=20,
                 use_softmax=False
                 ):
        super().__init__()

        self.embed_dim = embed_dim

        if use_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        else:
            self.cls_token = None

        self.masking_token = nn.Parameter(torch.zeros(embed_dim))

        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            FusionBlock(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
                attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, use_softmax=use_softmax
            )
            for i in range(depth)])

        self.norm = norm_layer(embed_dim) # TODO: not needed, remove?

        self.mlp_head = nn.Linear(embed_dim, num_classes)

    def forward(self, key, query, key_modal='', query_modal=''):
        """Fuse ``query`` tokens with ``key`` tokens through all blocks.

        The output of each block becomes the query of the next one, while the
        key stays fixed. Previously every block received the original query
        and only the last block's output was kept, so with ``depth > 1`` the
        earlier blocks had no effect; with ``depth == 0`` the method raised a
        NameError, whereas it now returns ``query`` unchanged. The result for
        ``depth == 1`` is the same as before.

        ``key_modal`` and ``query_modal`` are accepted but unused.
        """
        tokens = query
        for block in self.blocks:
            tokens = block(key, tokens)
        return tokens

In [9]:
class Classifier(nn.Module):
    """Three-layer MLP head: ``latent_dim -> 256 -> 128 -> num_classes``."""

    def __init__(self, latent_dim, num_classes):
        super().__init__()
        self.layer_1 = nn.Linear(latent_dim, 256)
        self.layer_2 = nn.Linear(256, 128)
        self.layer_3 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return the output of the last linear layer (no final activation)."""
        hidden = F.relu(self.layer_1(x))
        hidden = F.relu(self.layer_2(hidden))
        return self.layer_3(hidden)

In [10]:
class EverythingAtOnceModel(nn.Module):
    """Tri-modal classifier built on pairwise cross-attention fusion.

    Video, audio and text are projected to ``embed_dim`` tokens, fused in
    three key/query pairings by one shared ``FusionTransformer``, and each
    fused sequence is mean-pooled and classified by its own ``Classifier``.

    ``args`` and most keyword arguments are accepted but not read.
    ``use_softmax`` (True), ``use_cls_token`` (False) and ``num_classes``
    (20) are fixed inside the constructor.
    """

    def __init__(self,
                 args,
                 embed_dim=1024,
                 video_embed_dim=4096,
                 text_embed_dim=300,
                 video_max_tokens=None,
                 text_max_tokens=None,
                 audio_max_num_STFT_frames=None,
                 projection_dim=6144,
                 projection='gated',
                 strategy_audio_pooling='none',
                 davenet_v2=True,
                 individual_projections=True,
                 use_positional_emb=False
                 ):
        super().__init__()

        self.embed_dim = embed_dim
        self.use_softmax = True
        self.use_cls_token = False
        self.num_classes = 20

        # One fusion transformer is shared by all three modality pairings.
        self.fusion = FusionTransformer(
            embed_dim=self.embed_dim,
            use_softmax=self.use_softmax,
            use_cls_token=self.use_cls_token,
            num_classes=self.num_classes,
        )

        self.token_projection = 'projection_net'

        self.individual_projections = individual_projections
        self.use_positional_emb = use_positional_emb
        self.strategy_audio_pooling = strategy_audio_pooling

        # These four norm layers are registered but not applied in forward().
        self.video_norm_layer = nn.LayerNorm(self.embed_dim, eps=1e-6)
        self.text_norm_layer = nn.LayerNorm(self.embed_dim, eps=1e-6)
        self.audio_norm_layer = nn.LayerNorm(self.embed_dim, eps=1e-6)
        self.norm_layer = nn.LayerNorm(self.embed_dim, eps=1e-6)

        # The audio encoder is created inside the projection network; no
        # top-level `davenet` or `commonencoder` module is created here.
        self.token_proj = projection_net(embed_dim=self.embed_dim)

        # One classification head per modality pairing.
        self.classifier1 = Classifier(latent_dim=self.embed_dim, num_classes=self.num_classes)
        self.classifier2 = Classifier(latent_dim=self.embed_dim, num_classes=self.num_classes)
        self.classifier3 = Classifier(latent_dim=self.embed_dim, num_classes=self.num_classes)

    def extract_tokens(self, video, audio, text, nframes):
        """Run the projection network; returns ``(audio, text, video)`` embeddings."""
        return self.token_proj(video, audio, nframes, text)

    def forward(self, video, audio, nframes, text, category, force_cross_modal=False):
        """Return class logits ``(va, at, tv)`` for the three pairings.

        ``category`` and ``force_cross_modal`` are accepted but not used.
        """
        audio_tokens, text_tokens, video_tokens = self.extract_tokens(video, audio, text, nframes)
        # Insert an axis at dim 1 on the video embedding.
        # presumably (batch, embed_dim) -> (batch, 1, embed_dim) — TODO confirm
        video_tokens = video_tokens.unsqueeze(1)

        # Visual - Audio: video is the key, audio the query.
        va = self.classifier1(self.fusion(key=video_tokens, query=audio_tokens).mean(dim=1))

        # Audio - Text: audio is the key, text the query.
        at = self.classifier2(self.fusion(key=audio_tokens, query=text_tokens).mean(dim=1))

        # Text - Video: text is the key, video the query.
        tv = self.classifier3(self.fusion(key=text_tokens, query=video_tokens).mean(dim=1))

        return va, at, tv


In [13]:
import argparse


def _str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` would turn every non-empty string, including 'False', into
    True, so the accepted spellings are mapped explicitly instead.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# NOTE(review): the defaults below are machine-specific absolute paths.
parser = argparse.ArgumentParser()
parser.add_argument('--we_path', default='C:/Users/heeryung/code/24w_deep_daiv/GoogleNews-vectors-negative300.bin', type=str)
parser.add_argument('--data_path', default='C:/Users/heeryung/code/24w_deep_daiv/msrvtt_category_test.pkl', type=str)
parser.add_argument('--checkpoint_path', default='D:/download/epoch200.pth', type=str)
parser.add_argument('--token_projection', default='projection_net', type=str)
parser.add_argument('--use_softmax', default=True, type=_str2bool)
parser.add_argument('--use_cls_token', default=False, type=_str2bool)
parser.add_argument('--num_classes', default=20, type=int)
parser.add_argument('--batch_size', default=16, type=int)
# parse_known_args tolerates arguments that are not declared above.
args, unknown = parser.parse_known_args()

# Use the configurable checkpoint path instead of a second hard-coded copy.
ckpt_path = args.checkpoint_path
checkpoint = torch.load(ckpt_path)

model = EverythingAtOnceModel(args).cuda()
# NOTE(review): the recorded output of this cell is a RuntimeError — the
# checkpoint has `davenet.*` / `commonencoder.*` keys and 2048-wide classifier
# inputs that this model definition lacks. The model definition has to match
# the one used for training before this call can succeed.
model.load_state_dict(checkpoint['model_state_dict'])

RuntimeError: Error(s) in loading state_dict for EverythingAtOnceModel:
	Unexpected key(s) in state_dict: "davenet.batchnorm1.weight", "davenet.batchnorm1.bias", "davenet.batchnorm1.running_mean", "davenet.batchnorm1.running_var", "davenet.batchnorm1.num_batches_tracked", "davenet.conv1.weight", "davenet.bn1.weight", "davenet.bn1.bias", "davenet.bn1.running_mean", "davenet.bn1.running_var", "davenet.bn1.num_batches_tracked", "davenet.layer1.0.conv1.weight", "davenet.layer1.0.bn1.weight", "davenet.layer1.0.bn1.bias", "davenet.layer1.0.bn1.running_mean", "davenet.layer1.0.bn1.running_var", "davenet.layer1.0.bn1.num_batches_tracked", "davenet.layer1.0.conv2.weight", "davenet.layer1.0.bn2.weight", "davenet.layer1.0.bn2.bias", "davenet.layer1.0.bn2.running_mean", "davenet.layer1.0.bn2.running_var", "davenet.layer1.0.bn2.num_batches_tracked", "davenet.layer1.0.downsample.0.weight", "davenet.layer1.0.downsample.1.weight", "davenet.layer1.0.downsample.1.bias", "davenet.layer1.0.downsample.1.running_mean", "davenet.layer1.0.downsample.1.running_var", "davenet.layer1.0.downsample.1.num_batches_tracked", "davenet.layer1.1.conv1.weight", "davenet.layer1.1.bn1.weight", "davenet.layer1.1.bn1.bias", "davenet.layer1.1.bn1.running_mean", "davenet.layer1.1.bn1.running_var", "davenet.layer1.1.bn1.num_batches_tracked", "davenet.layer1.1.conv2.weight", "davenet.layer1.1.bn2.weight", "davenet.layer1.1.bn2.bias", "davenet.layer1.1.bn2.running_mean", "davenet.layer1.1.bn2.running_var", "davenet.layer1.1.bn2.num_batches_tracked", "davenet.layer2.0.conv1.weight", "davenet.layer2.0.bn1.weight", "davenet.layer2.0.bn1.bias", "davenet.layer2.0.bn1.running_mean", "davenet.layer2.0.bn1.running_var", "davenet.layer2.0.bn1.num_batches_tracked", "davenet.layer2.0.conv2.weight", "davenet.layer2.0.bn2.weight", "davenet.layer2.0.bn2.bias", "davenet.layer2.0.bn2.running_mean", "davenet.layer2.0.bn2.running_var", "davenet.layer2.0.bn2.num_batches_tracked", "davenet.layer2.0.downsample.0.weight", "davenet.layer2.0.downsample.1.weight", "davenet.layer2.0.downsample.1.bias", 
"davenet.layer2.0.downsample.1.running_mean", "davenet.layer2.0.downsample.1.running_var", "davenet.layer2.0.downsample.1.num_batches_tracked", "davenet.layer2.1.conv1.weight", "davenet.layer2.1.bn1.weight", "davenet.layer2.1.bn1.bias", "davenet.layer2.1.bn1.running_mean", "davenet.layer2.1.bn1.running_var", "davenet.layer2.1.bn1.num_batches_tracked", "davenet.layer2.1.conv2.weight", "davenet.layer2.1.bn2.weight", "davenet.layer2.1.bn2.bias", "davenet.layer2.1.bn2.running_mean", "davenet.layer2.1.bn2.running_var", "davenet.layer2.1.bn2.num_batches_tracked", "davenet.layer3.0.conv1.weight", "davenet.layer3.0.bn1.weight", "davenet.layer3.0.bn1.bias", "davenet.layer3.0.bn1.running_mean", "davenet.layer3.0.bn1.running_var", "davenet.layer3.0.bn1.num_batches_tracked", "davenet.layer3.0.conv2.weight", "davenet.layer3.0.bn2.weight", "davenet.layer3.0.bn2.bias", "davenet.layer3.0.bn2.running_mean", "davenet.layer3.0.bn2.running_var", "davenet.layer3.0.bn2.num_batches_tracked", "davenet.layer3.0.downsample.0.weight", "davenet.layer3.0.downsample.1.weight", "davenet.layer3.0.downsample.1.bias", "davenet.layer3.0.downsample.1.running_mean", "davenet.layer3.0.downsample.1.running_var", "davenet.layer3.0.downsample.1.num_batches_tracked", "davenet.layer4.0.conv1.weight", "davenet.layer4.0.bn1.weight", "davenet.layer4.0.bn1.bias", "davenet.layer4.0.bn1.running_mean", "davenet.layer4.0.bn1.running_var", "davenet.layer4.0.bn1.num_batches_tracked", "davenet.layer4.0.conv2.weight", "davenet.layer4.0.bn2.weight", "davenet.layer4.0.bn2.bias", "davenet.layer4.0.bn2.running_mean", "davenet.layer4.0.bn2.running_var", "davenet.layer4.0.bn2.num_batches_tracked", "davenet.layer4.0.downsample.0.weight", "davenet.layer4.0.downsample.1.weight", "davenet.layer4.0.downsample.1.bias", "davenet.layer4.0.downsample.1.running_mean", "davenet.layer4.0.downsample.1.running_var", "davenet.layer4.0.downsample.1.num_batches_tracked", "davenet.layer5.0.conv1.weight", "davenet.layer5.0.bn1.weight", 
"davenet.layer5.0.bn1.bias", "davenet.layer5.0.bn1.running_mean", "davenet.layer5.0.bn1.running_var", "davenet.layer5.0.bn1.num_batches_tracked", "davenet.layer5.0.conv2.weight", "davenet.layer5.0.bn2.weight", "davenet.layer5.0.bn2.bias", "davenet.layer5.0.bn2.running_mean", "davenet.layer5.0.bn2.running_var", "davenet.layer5.0.bn2.num_batches_tracked", "davenet.layer5.0.downsample.0.weight", "davenet.layer5.0.downsample.1.weight", "davenet.layer5.0.downsample.1.bias", "davenet.layer5.0.downsample.1.running_mean", "davenet.layer5.0.downsample.1.running_var", "davenet.layer5.0.downsample.1.num_batches_tracked", "davenet.layer6.0.conv1.weight", "davenet.layer6.0.bn1.weight", "davenet.layer6.0.bn1.bias", "davenet.layer6.0.bn1.running_mean", "davenet.layer6.0.bn1.running_var", "davenet.layer6.0.bn1.num_batches_tracked", "davenet.layer6.0.conv2.weight", "davenet.layer6.0.bn2.weight", "davenet.layer6.0.bn2.bias", "davenet.layer6.0.bn2.running_mean", "davenet.layer6.0.bn2.running_var", "davenet.layer6.0.bn2.num_batches_tracked", "davenet.layer6.0.downsample.0.weight", "davenet.layer6.0.downsample.1.weight", "davenet.layer6.0.downsample.1.bias", "davenet.layer6.0.downsample.1.running_mean", "davenet.layer6.0.downsample.1.running_var", "davenet.layer6.0.downsample.1.num_batches_tracked", "commonencoder.feature_extractor.0.weight", "commonencoder.feature_extractor.0.bias", "commonencoder.feature_extractor.2.weight", "commonencoder.feature_extractor.2.bias", "commonencoder.feature_extractor.4.weight", "commonencoder.feature_extractor.4.bias". 
	size mismatch for classifier1.layer_1.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([256, 1024]).
	size mismatch for classifier2.layer_1.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([256, 1024]).
	size mismatch for classifier3.layer_1.weight: copying a param with shape torch.Size([256, 2048]) from checkpoint, the shape in current model is torch.Size([256, 1024]).

In [None]:
# Fetch the first processed sample from the dataset.
data = dataset[0]

In [None]:
model.eval()

# Run the inputs on whatever device the model parameters live on; the model
# is moved to the GPU when it is created, so CPU tensors would fail.
device = next(model.parameters()).device

video = data['video'].to(device)
audio = data['audio'].to(device)
text = data['text'].to(device)
nframes = data['nframes']
category = data['category']

# Add a leading batch axis of size 1 to every modality.
video = video.view(-1, video.shape[-1])
audio = audio.view(-1, audio.shape[-2], audio.shape[-1])
text = text.view(-1, text.shape[-2], text.shape[-1])

# Inference only: no gradients are needed.
with torch.no_grad():
    pred = model(video, audio, nframes, text, category)

# The model returns one logit tensor per modality pairing (va, at, tv), so the
# argmax is taken per head; calling torch.argmax on the tuple itself fails.
# NOTE(review): how the three heads should be combined into a single
# prediction is not defined in this notebook — confirm the intended rule.
pred_category = {
    head: torch.argmax(logits, dim=1)
    for head, logits in zip(('video-audio', 'audio-text', 'text-video'), pred)
}
# accuracy = torch.mean((pred_category == category).float())
print("Real category:", category, 'Pred category:', pred_category)