## Papers
- [Video Summarization Using Deep Neural Networks: A Survey](https://arxiv.org/pdf/2101.06072.pdf)
- [CLIP-It! Language-Guided Video Summarization](https://proceedings.neurips.cc/paper/2021/file/7503cfacd12053d309b6bed5c89de212-Paper.pdf)
- [DSNet (code)](https://github.com/li-plus/DSNet)
- [DSNet (paper)](https://liplus.me/publication/dsnet/dsnet.pdf)
- [PGL-SUM (2021): top-ranked model; its code and data are used below](https://github.com/e-apostolidis/PGL-SUM)


In [1]:
import os
import shutil

import pandas as pd
import numpy as np

## TVSum dataset

In [2]:
# TVSum ships two tab-separated files under data/: the annotation table (no header row;
# the last column is a comma-separated list of integer scores) and the video info table.
tvsum_path = os.path.join('tvsum', 'ydata-tvsum50-v1_1')
tvsum_data_path = os.path.join(tvsum_path, 'data')

df_anno = pd.read_csv(
    os.path.join(tvsum_data_path, 'ydata-tvsum50-anno.tsv'),
    sep='\t',
    header=None,
    names=['video_id', 'category', 'annotation'],
)
# Turn '4,4,2,...' into [4, 4, 2, ...].
df_anno['annotation'] = df_anno['annotation'].map(lambda raw: [int(score) for score in raw.split(',')])

df_info = pd.read_csv(os.path.join(tvsum_data_path, 'ydata-tvsum50-info.tsv'), sep='\t')

In [3]:
# First rows of the annotation table (categories are still short codes at this stage of a top-down run).
df_anno.head(5)

Unnamed: 0,video_id,category,annotation
0,AwmHb44_ouw,VT,"[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, ..."
1,AwmHb44_ouw,VT,"[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ..."
2,AwmHb44_ouw,VT,"[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, ..."
3,AwmHb44_ouw,VT,"[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, ..."
4,AwmHb44_ouw,VT,"[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ..."


In [4]:
df_info.head()

Unnamed: 0,category,video_id,title,url,length
0,VT,AwmHb44_ouw,#1306 How to change tires for off road vehicle...,https://www.youtube.com/watch?v=AwmHb44_ouw,5:54
1,VT,98MoyGZKHXc,How to use a tyre repair kit - Which? guide,https://www.youtube.com/watch?v=98MoyGZKHXc,3:07
2,VT,J0nA4VgnoCo,#0001: FLAT TIRE,https://www.youtube.com/watch?v=J0nA4VgnoCo,9:44
3,VT,gzDbaEs1Rlg,ŠKODA Tips How to Repair Your Tyre,https://www.youtube.com/watch?v=gzDbaEs1Rlg,4:48
4,VT,XzYM3PfTM4w,When to Replace Your Tires GMC,https://www.youtube.com/watch?v=XzYM3PfTM4w,1:51


In [5]:
# TVSum category codes -> human-readable category names.
category_map = {
        'VT': 'Changing Vehicle Tire',
        'VU': 'Getting Vehicle Unstuck',
        'GA': 'Grooming an Animal',
        'MS': 'Making Sandwich',
        'PK': 'Parkour',
        'PR': 'Parade',
        'FM': 'Flash Mob Gathering',
        'BK': 'Bee Keeping',
        'BT': 'Attempting Bike Tricks',
        'DS': 'Dog Show',
}

# Values that are not codes (i.e. names already mapped by a previous run) pass through unchanged.
# With a plain `.map(category_map)`, re-running this cell would turn every category into NaN.
df_info['category'] = df_info['category'].map(lambda code: category_map.get(code, code))
df_anno['category'] = df_anno['category'].map(lambda code: category_map.get(code, code))

assert df_info['category'].isin(category_map.values()).all(), 'unmapped category code in df_info'
assert df_anno['category'].isin(category_map.values()).all(), 'unmapped category code in df_anno'

df_info.head()

Unnamed: 0,category,video_id,title,url,length
0,Changing Vehicle Tire,AwmHb44_ouw,#1306 How to change tires for off road vehicle...,https://www.youtube.com/watch?v=AwmHb44_ouw,5:54
1,Changing Vehicle Tire,98MoyGZKHXc,How to use a tyre repair kit - Which? guide,https://www.youtube.com/watch?v=98MoyGZKHXc,3:07
2,Changing Vehicle Tire,J0nA4VgnoCo,#0001: FLAT TIRE,https://www.youtube.com/watch?v=J0nA4VgnoCo,9:44
3,Changing Vehicle Tire,gzDbaEs1Rlg,ŠKODA Tips How to Repair Your Tyre,https://www.youtube.com/watch?v=gzDbaEs1Rlg,4:48
4,Changing Vehicle Tire,XzYM3PfTM4w,When to Replace Your Tires GMC,https://www.youtube.com/watch?v=XzYM3PfTM4w,1:51


In [6]:
# Number of videos per category.
df_info.loc[:, 'category'].value_counts()

Changing Vehicle Tire      5
Getting Vehicle Unstuck    5
Grooming an Animal         5
Making Sandwich            5
Parkour                    5
Parade                     5
Flash Mob Gathering        5
Bee Keeping                5
Attempting Bike Tricks     5
Dog Show                   5
Name: category, dtype: int64

In [7]:
# All videos of a single category, for a quick look at titles and lengths.
df_info.query("category == 'Grooming an Animal'")

Unnamed: 0,category,video_id,title,url,length
10,Grooming an Animal,i3wAGJaaktw,"Pet Joy Spa Grooming Services | Brentwood, CA ...",https://www.youtube.com/watch?v=i3wAGJaaktw,2:36
11,Grooming an Animal,Bhxk-O1Y7Ho,Vlog #509 I'M A PUPPY DOG GROOMER! September 1...,https://www.youtube.com/watch?v=Bhxk-O1Y7Ho,7:30
12,Grooming an Animal,0tmA_C6XwfM,Nail clipper Gloria Pets professional grooming,https://www.youtube.com/watch?v=0tmA_C6XwfM,2:21
13,Grooming an Animal,3eYKfiOEJNs,Dog Grooming in Buenos Aires,https://www.youtube.com/watch?v=3eYKfiOEJNs,3:14
14,Grooming an Animal,xxdtq8mxegs,How to Clean Your Dog's Ears - Vetoquinol USA,https://www.youtube.com/watch?v=xxdtq8mxegs,2:24


In [8]:
# Annotation table again, now showing the human-readable category names.
df_anno.head(5)

Unnamed: 0,video_id,category,annotation
0,AwmHb44_ouw,Changing Vehicle Tire,"[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, ..."
1,AwmHb44_ouw,Changing Vehicle Tire,"[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ..."
2,AwmHb44_ouw,Changing Vehicle Tire,"[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, ..."
3,AwmHb44_ouw,Changing Vehicle Tire,"[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, ..."
4,AwmHb44_ouw,Changing Vehicle Tire,"[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ..."


In [9]:
# (rows, columns) of the annotation table; the output below is (1000, 3) = 50 videos x 20 annotation rows each.
df_anno.shape

(1000, 3)

In [10]:
# Annotation rows per video (the output shows 20 for every video).
df_anno.loc[:, 'video_id'].value_counts()

AwmHb44_ouw    20
EE-bNr36nyA    20
z_6gVvQb2d0    20
fWutDQy1nnY    20
4wU_LUjG5Ic    20
VuWGsYPqAX8    20
JKpqYvAdIsw    20
xmEERLqJ2kU    20
byxOvuiIJV0    20
_xMr-HKMfVA    20
WxtbjNsCQ8A    20
uGu_10sucQo    20
Se3oxnaPsz0    20
98MoyGZKHXc    20
oDXZc0tZe04    20
qqR6AEXwxoQ    20
EYqVtI9YWJA    20
eQu1rNs0an0    20
JgHubY5Vw3Y    20
iVt07TCkFM0    20
E11zDS9XGzg    20
NyBmCxDoHJU    20
kLxoNp-UchI    20
jcoYJXDG9sw    20
RBCABdttQmI    20
91IHQYk1IQM    20
PJrm840pAUI    20
GsAD1KT1xo8    20
J0nA4VgnoCo    20
gzDbaEs1Rlg    20
XzYM3PfTM4w    20
HT5vyqe0Xaw    20
sTEELN-vY30    20
vdmoEJ5YbrQ    20
xwqBXPGE9pQ    20
akI8YFjEmUw    20
i3wAGJaaktw    20
Bhxk-O1Y7Ho    20
0tmA_C6XwfM    20
3eYKfiOEJNs    20
xxdtq8mxegs    20
WG0MBPpPC6I    20
Hl-__g2gn_A    20
Yi4Ij2NM7U4    20
37rzWOQsNIw    20
LRw_obCPUt0    20
cjibtmSLxQ4    20
b626MiF1ew4    20
XkqCExn6_Us    20
-esJrBWj2d8    20
Name: video_id, dtype: int64

In [11]:
# Length of each annotation vector (presumably one score per frame); rows of the same video share a length.
df_anno['annotation'].map(len)

0      10597
1      10597
2      10597
3      10597
4      10597
       ...  
995     6912
996     6912
997     6912
998     6912
999     6912
Name: annotation, Length: 1000, dtype: int64

In [12]:
# Mean importance score per video over all frames and annotators.
# explode() leaves 'annotation' as an object column; cast it to float so the groupby mean runs
# on a numeric column (fast path, float64 result) instead of pandas' object-dtype fallback.
df_anno.explode('annotation').astype({'annotation': float}).groupby('video_id')['annotation'].mean()

video_id
-esJrBWj2d8    1.973155
0tmA_C6XwfM    2.041280
37rzWOQsNIw    2.012722
3eYKfiOEJNs    1.965794
4wU_LUjG5Ic    1.958489
91IHQYk1IQM    2.018735
98MoyGZKHXc    2.030173
AwmHb44_ouw    1.978079
Bhxk-O1Y7Ho    1.900111
E11zDS9XGzg    1.965424
EE-bNr36nyA    1.960354
EYqVtI9YWJA    1.991707
GsAD1KT1xo8    2.001768
HT5vyqe0Xaw    1.951613
Hl-__g2gn_A    1.943739
J0nA4VgnoCo    1.928668
JKpqYvAdIsw    2.019161
JgHubY5Vw3Y    1.946399
LRw_obCPUt0    1.960880
NyBmCxDoHJU    2.091498
PJrm840pAUI    1.920562
RBCABdttQmI    1.973986
Se3oxnaPsz0    1.968139
VuWGsYPqAX8    1.947136
WG0MBPpPC6I    1.986868
WxtbjNsCQ8A    1.929834
XkqCExn6_Us    1.993740
XzYM3PfTM4w    1.950406
Yi4Ij2NM7U4    1.910451
_xMr-HKMfVA    2.029431
akI8YFjEmUw    2.008636
b626MiF1ew4    2.011791
byxOvuiIJV0    2.129460
cjibtmSLxQ4    1.905498
eQu1rNs0an0    1.958193
fWutDQy1nnY    1.991656
gzDbaEs1Rlg    1.971623
i3wAGJaaktw    1.983777
iVt07TCkFM0    2.001560
jcoYJXDG9sw    1.956850
kLxoNp-UchI    2.042300
oDXZc0t

In [13]:
# Normalization constant for the TVSum scores; assumes no annotation score exceeds 5.
TVSUM_MAX_SCORE = 5

# Keep the annotations as numpy arrays for downstream cells.
df_anno['annotation'] = df_anno['annotation'].map(np.array)

# Frame-wise mean over each video's annotators, rescaled by TVSUM_MAX_SCORE.
# np.stack makes the element-wise average explicit and fails loudly if a video's annotation
# vectors differ in length, instead of relying on pandas averaging an object column of arrays.
video_ids, mean_scores = [], []
for video_id, scores in df_anno.groupby('video_id')['annotation']:
    video_ids.append(video_id)
    mean_scores.append(np.stack(scores.to_list()).mean(axis=0) / TVSUM_MAX_SCORE)

label_df = pd.DataFrame({'video_id': video_ids, 'annotation': mean_scores})
label_df

Unnamed: 0,video_id,annotation
0,-esJrBWj2d8,"[0.39, 0.39, 0.39, 0.39, 0.39, 0.39, 0.39, 0.3..."
1,0tmA_C6XwfM,"[0.22999999999999998, 0.22999999999999998, 0.2..."
2,37rzWOQsNIw,"[0.65, 0.65, 0.65, 0.65, 0.65, 0.65, 0.65, 0.6..."
3,3eYKfiOEJNs,"[0.38, 0.38, 0.38, 0.38, 0.38, 0.38, 0.38, 0.3..."
4,4wU_LUjG5Ic,"[0.39, 0.39, 0.39, 0.39, 0.39, 0.39, 0.39, 0.3..."
5,91IHQYk1IQM,"[0.33999999999999997, 0.33999999999999997, 0.3..."
6,98MoyGZKHXc,"[0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.3..."
7,AwmHb44_ouw,"[0.58, 0.58, 0.58, 0.58, 0.58, 0.58, 0.58, 0.5..."
8,Bhxk-O1Y7Ho,"[0.5900000000000001, 0.5900000000000001, 0.590..."
9,E11zDS9XGzg,"[0.32999999999999996, 0.32999999999999996, 0.3..."


## SumMe

In [14]:
# SumMe layout: raw videos in <root>/video, ground-truth .mat files in <root>/GT.
# Child paths are derived from the root (like the TVSum cell), so relocating the dataset is a one-line change.
summe_path = 'summe'
summe_video_path = os.path.join(summe_path, 'video')
summe_ann_path = os.path.join(summe_path, 'GT')

In [15]:
import scipy.io


# One row per SumMe video: a normalized id plus the 'gt_score' array stored in its .mat file.
res = {'video_id': [], 'annotation': []}

# sorted(): os.listdir order is filesystem-dependent; sorting makes the row order reproducible.
for mat_file in sorted(os.listdir(summe_ann_path)):
    stem, ext = os.path.splitext(mat_file)
    # Skip anything that is not a MATLAB file (e.g. .DS_Store), which loadmat would choke on.
    if ext.lower() != '.mat':
        continue
    # e.g. 'Bike Polo.mat' -> 'bike_polo'; splitext keeps any dots inside the name itself.
    res['video_id'].append(stem.replace(' ', '_').lower())

    mat = scipy.io.loadmat(os.path.join(summe_ann_path, mat_file))
    res['annotation'].append(mat['gt_score'].flatten())


label_summe_df = pd.DataFrame(res)
label_summe_df

Unnamed: 0,video_id,annotation
0,air_force_one,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
1,base_jumping,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
2,bearpark_climbing,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
3,bike_polo,"[0.2, 0.2, 0.2, 0.2, 0.06666666666666667, 0.06..."
4,bus_in_rock_tunnel,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
5,car_over_camera,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
6,car_railcrossing,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
7,cockpit_landing,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
8,cooking,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
9,eiffel_tower,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."


## DataLoader

In [21]:
import h5py

In [18]:
import torch
from torch.utils.data import Dataset, DataLoader
import h5py
import numpy as np
import json


class VideoData(Dataset):
    """ Frame features and ground-truth importance scores for one split of SumMe or TVSum.

    Features/scores are read from the PGL-SUM HDF5 file of the chosen dataset (see ``DATASETS``);
    the videos loaded are the ``<mode>_keys`` entries of split ``split_index`` in
    ``PGL-SUM/data/splits/<video_type>_splits.json``. All tensors are loaded into memory up front.
    """
    DATASETS = {
        'summe': 'PGL-SUM/data/SumMe/eccv16_dataset_summe_google_pool5.h5',
        'tvsum': 'PGL-SUM/data/TVSum/eccv16_dataset_tvsum_google_pool5.h5',
    }

    def __init__(self, mode, video_type, split_index):
        """ Custom Dataset class wrapper for loading the frame features and ground truth importance scores.
        :param str mode: The mode of the model, train or test (case-insensitive).
        :param str video_type: The Dataset being used, SumMe or TVSum (case-insensitive).
        :param int split_index: The index of the Dataset split being used (0 to 4 for the shipped splits).
        :raises ValueError: If `video_type` is not one of the keys of `DATASETS`.
        :raises IndexError: If `split_index` does not select a split in the splits file.
        """
        # Normalized so the '<mode>_keys' lookup agrees with get_loader's case-insensitive check.
        self.mode = mode.lower()
        dataset_key = video_type.lower()
        # Validate before touching the filesystem so a typo fails with a clear message.
        if dataset_key not in self.DATASETS:
            raise ValueError(f"Unknown video_type {video_type!r}; expected one of {sorted(self.DATASETS)}")

        self.split_filename = 'PGL-SUM/data/splits/' + dataset_key + '_splits.json'
        self.split_index = split_index  # it represents the current split (varies from 0 to 4)
        self.filename = self.DATASETS[dataset_key]

        with open(self.split_filename) as f:
            splits = json.load(f)
        if not 0 <= split_index < len(splits):
            raise IndexError(f"split_index {split_index} is out of range for {self.split_filename} "
                             f"({len(splits)} splits)")
        self.split = splits[split_index]

        self.list_frame_features = []
        self.list_gtscores = []
        # The context manager closes the HDF5 file even if a video key is missing.
        with h5py.File(self.filename, 'r') as hdf:
            for video_name in self.split[self.mode + '_keys']:
                frame_features = torch.Tensor(np.array(hdf[video_name + '/features']))
                gtscore = torch.Tensor(np.array(hdf[video_name + '/gtscore']))

                self.list_frame_features.append(frame_features)
                self.list_gtscores.append(gtscore)

    def __len__(self):
        """ Function to be called for the `len` operator of `VideoData` Dataset. """
        self.len = len(self.split[self.mode + '_keys'])
        return self.len

    def __getitem__(self, index):
        """ Function to be called for the index operator of `VideoData` Dataset.
        train mode returns: frame_features and gtscores
        test mode returns: frame_features and video name
        :param int index: The above-mentioned id of the data.
        """
        video_name = self.split[self.mode + '_keys'][index]
        frame_features = self.list_frame_features[index]
        gtscore = self.list_gtscores[index]

        if self.mode == 'test':
            return frame_features, video_name
        else:
            return frame_features, gtscore


def get_loader(mode, video_type, split_index):
    """ Build the data source for split `split_index` of the `video_type` Dataset.

    In train `mode` the `VideoData` is wrapped in a shuffling DataLoader with `batch_size` = 1;
    in any other mode the bare `VideoData` is returned.
    :param str mode: The mode of the model, train or test (compared case-insensitively to 'train').
    :param str video_type: The Dataset being used, SumMe or TVSum.
    :param int split_index: The index of the Dataset split being used.
    :return: A DataLoader in train mode, otherwise the VideoData Dataset itself.
    """
    is_training = mode.lower() == 'train'
    dataset = VideoData(mode, video_type, split_index)
    if not is_training:
        return dataset
    return DataLoader(dataset, batch_size=1, shuffle=True)


## Model

In [17]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class SelfAttention(nn.Module):
    def __init__(self, input_size=1024, output_size=1024, freq=10000, heads=1, pos_enc=None):
        """ The basic (multi-head) Attention 'cell' containing the learnable parameters of Q, K and V
        :param int input_size: Feature input size of Q, K, V.
        :param int output_size: Feature -hidden- size of Q, K, V; each head gets output_size // heads features.
        :param int freq: The frequency of the sinusoidal positional encoding.
        :param int heads: Number of heads for the attention module.
        :param str | None pos_enc: The type of the positional encoding [supported: Absolute, Relative]
            (case-insensitive).
        :raises AssertionError: If `pos_enc` is given but is not a supported encoding.
        """
        super(SelfAttention, self).__init__()

        self.permitted_encodings = ["absolute", "relative"]
        if pos_enc is not None:
            pos_enc = pos_enc.lower()
            assert pos_enc in self.permitted_encodings, f"Supported encodings: {*self.permitted_encodings,}"

        self.input_size = input_size
        self.output_size = output_size
        self.heads = heads
        self.pos_enc = pos_enc
        self.freq = freq
        head_size = output_size // heads
        self.Wk, self.Wq, self.Wv = nn.ModuleList(), nn.ModuleList(), nn.ModuleList()
        for _ in range(self.heads):
            self.Wk.append(nn.Linear(in_features=input_size, out_features=head_size, bias=False))
            self.Wq.append(nn.Linear(in_features=input_size, out_features=head_size, bias=False))
            self.Wv.append(nn.Linear(in_features=input_size, out_features=head_size, bias=False))
        # forward() concatenates the head outputs, giving heads * head_size features. That equals
        # output_size only when heads divides it (e.g. output_size=341, heads=4 gives 340), so the
        # projection is sized from the real concatenated width to avoid a shape mismatch.
        self.out = nn.Linear(in_features=heads * head_size, out_features=input_size, bias=False)

        self.softmax = nn.Softmax(dim=-1)
        self.drop = nn.Dropout(p=0.5)

    def getAbsolutePosition(self, T):
        """Calculate the sinusoidal positional encoding based on the absolute position of each considered frame.
        Based on 'Attention is all you need' paper (https://arxiv.org/abs/1706.03762)
        Even columns 2i hold sin(pos / freq^(2i/d)) and odd columns 2i+1 hold the matching cos, with
        d = input_size. When T is odd, the last column stays zero.
        :param int T: Number of frames contained in Q, K and V
        :return: Tensor with shape [T, T]
        """
        freq = self.freq
        d = self.input_size

        pos = torch.tensor([k for k in range(T)], device=self.out.weight.device)
        i = torch.tensor([k for k in range(T//2)], device=self.out.weight.device)

        # Reshape tensors each pos_k for each i indices
        pos = pos.reshape(pos.shape[0], 1)
        pos = pos.repeat_interleave(i.shape[0], dim=1)
        i = i.repeat(pos.shape[0], 1)

        AP = torch.zeros(T, T, device=self.out.weight.device)
        AP[pos, 2*i] = torch.sin(pos / freq ** ((2 * i) / d))
        AP[pos, 2*i+1] = torch.cos(pos / freq ** ((2 * i) / d))
        return AP

    def getRelativePosition(self, T):
        """Calculate the sinusoidal positional encoding based on the relative position of each considered frame.
        r_pos calculations as here: https://theaisummer.com/positional-embeddings/
        r_pos[i, j] = j - i + (T - 1) is shifted to be non-negative; the exponent denominator is d = 2 * T.
        :param int T: Number of frames contained in Q, K and V
        :return: Tensor with shape [T, T]
        """
        freq = self.freq
        d = 2 * T
        min_rpos = -(T - 1)

        i = torch.tensor([k for k in range(T)], device=self.out.weight.device)
        j = torch.tensor([k for k in range(T)], device=self.out.weight.device)

        # Reshape tensors each i for each j indices
        i = i.reshape(i.shape[0], 1)
        i = i.repeat_interleave(i.shape[0], dim=1)
        j = j.repeat(i.shape[0], 1)

        # Calculate the relative positions
        r_pos = j - i - min_rpos

        RP = torch.zeros(T, T, device=self.out.weight.device)
        idx = torch.tensor([k for k in range(T//2)], device=self.out.weight.device)
        RP[:, 2*idx] = torch.sin(r_pos[:, 2*idx] / freq ** ((i[:, 2*idx] + j[:, 2*idx]) / d))
        RP[:, 2*idx+1] = torch.cos(r_pos[:, 2*idx+1] / freq ** ((i[:, 2*idx+1] + j[:, 2*idx+1]) / d))
        return RP

    def forward(self, x):
        """ Compute the weighted frame features, based on either the global or local (multi-head) attention mechanism.
        Dot-product energies are not scaled by 1/sqrt(d_k) (the scaling lines are left commented out).
        :param torch.tensor x: Frame features with shape [T, input_size]
        :return: A tuple of:
                    y: Weighted features based on the attention weights, with shape [T, input_size]
                    att_weights : The attention weights of the LAST head (before dropout), with shape [T, T]
        """
        outputs = []
        for head in range(self.heads):
            K = self.Wk[head](x)
            Q = self.Wq[head](x)
            V = self.Wv[head](x)

            # Q *= 0.06                       # scale factor VASNet
            # Q /= np.sqrt(self.output_size)  # scale factor (i.e 1 / sqrt(d_k) )
            energies = torch.matmul(Q, K.transpose(1, 0))
            if self.pos_enc is not None:
                if self.pos_enc == "absolute":
                    AP = self.getAbsolutePosition(T=energies.shape[0])
                    energies = energies + AP
                elif self.pos_enc == "relative":
                    RP = self.getRelativePosition(T=energies.shape[0])
                    energies = energies + RP

            att_weights = self.softmax(energies)
            _att_weights = self.drop(att_weights)
            y = torch.matmul(_att_weights, V)

            # Save the current head output
            outputs.append(y)
        y = self.out(torch.cat(outputs, dim=1))
        return y, att_weights.clone()  # for now we don't deal with the weights (probably max or avg pooling)


class MultiAttention(nn.Module):
    # Global self-attention over all frames, optionally fused with per-segment local self-attention.
    def __init__(self, input_size=1024, output_size=1024, freq=10000, pos_enc=None,
                 num_segments=None, heads=1, fusion=None):
        """ Class wrapping the MultiAttention part of PGL-SUM; its key modules and parameters.
        :param int input_size: The expected input feature size.
        :param int output_size: The hidden feature size of the attention mechanisms; each local
            attention uses output_size // num_segments.
        :param int freq: The frequency of the sinusoidal positional encoding.
        :param None | str pos_enc: The selected positional encoding [absolute, relative].
        :param None | int num_segments: The selected number of segments to split the videos (None or >= 2).
        :param int heads: The selected number of global heads (local attentions always use 4 heads).
        :param None | str fusion: The selected type of feature fusion [add, mult, avg, max], case-insensitive.
            Local attention is only applied in forward() when both num_segments and fusion are set.
        :raises AssertionError: If num_segments < 2 or fusion is not a permitted method.
        """
        super(MultiAttention, self).__init__()

        # Global Attention, considering differences among all frames
        self.attention = SelfAttention(input_size=input_size, output_size=output_size,
                                       freq=freq, pos_enc=pos_enc, heads=heads)

        self.num_segments = num_segments
        if self.num_segments is not None:
            assert self.num_segments >= 2, "num_segments must be None or 2+"
            self.local_attention = nn.ModuleList()
            for _ in range(self.num_segments):
                # Local Attention, considering differences among the same segment with reduce hidden size
                self.local_attention.append(SelfAttention(input_size=input_size, output_size=output_size//num_segments,
                                                          freq=freq, pos_enc=pos_enc, heads=4))
        self.permitted_fusions = ["add", "mult", "avg", "max"]
        self.fusion = fusion
        if self.fusion is not None:
            self.fusion = self.fusion.lower()
            assert self.fusion in self.permitted_fusions, f"Fusion method must be: {*self.permitted_fusions,}"

    def forward(self, x):
        """ Compute the weighted frame features, based on the global and locals (multi-head) attention mechanisms.
        :param torch.Tensor x: Tensor with shape [T, input_size] containing the frame features.
        :return: A tuple of:
            weighted_value: Tensor with shape [T, input_size] containing the weighted frame features.
            attn_weights: Tensor with shape [T, T] containing the attention weights (global attention only;
                the local attention weights are discarded).
        """
        weighted_value, attn_weights = self.attention(x)  # global attention

        if self.num_segments is not None and self.fusion is not None:
            # Ceil division: the last segment may be shorter, and when T is small relative to
            # num_segments trailing segments can be empty (e.g. T=3, num_segments=4).
            segment_size = math.ceil(x.shape[0] / self.num_segments)
            for segment in range(self.num_segments):
                left_pos = segment * segment_size
                right_pos = (segment + 1) * segment_size
                local_x = x[left_pos:right_pos]
                weighted_local_value, attn_local_weights = self.local_attention[segment](local_x)  # local attentions

                # Normalize the features vectors
                # (.clone() before writing back into the same slice — presumably to keep autograd
                # from complaining about an in-place overwrite of a tensor it still needs.)
                weighted_value[left_pos:right_pos] = F.normalize(weighted_value[left_pos:right_pos].clone(), p=2, dim=1)
                weighted_local_value = F.normalize(weighted_local_value, p=2, dim=1)
                # Fuse the (normalized) global and local features for this segment in place.
                if self.fusion == "add":
                    weighted_value[left_pos:right_pos] += weighted_local_value
                elif self.fusion == "mult":
                    weighted_value[left_pos:right_pos] *= weighted_local_value
                elif self.fusion == "avg":
                    weighted_value[left_pos:right_pos] += weighted_local_value
                    weighted_value[left_pos:right_pos] /= 2
                elif self.fusion == "max":
                    weighted_value[left_pos:right_pos] = torch.max(weighted_value[left_pos:right_pos].clone(),
                                                                   weighted_local_value)

        return weighted_value, attn_weights


class PGL_SUM(nn.Module):
    def __init__(self, input_size=1024, output_size=1024, freq=10000, pos_enc=None,
                 num_segments=None, heads=1, fusion=None):
        """ PGL-SUM: multi-attention over frame features followed by a 2-layer regressor producing
        one importance score per frame.
        :param int input_size: The expected input feature size.
        :param int output_size: The hidden feature size of the attention mechanisms.
        :param int freq: The frequency of the sinusoidal positional encoding.
        :param None | str pos_enc: The selected positional encoding [absolute, relative].
        :param None | int num_segments: The selected number of segments to split the videos.
        :param int heads: The selected number of global heads.
        :param None | str fusion: The selected type of feature fusion.
        """
        super().__init__()

        # Submodules are created in the same order as before so parameter initialization
        # (and hence seeded runs) is unchanged.
        self.attention = MultiAttention(input_size=input_size, output_size=output_size, freq=freq,
                                        pos_enc=pos_enc, num_segments=num_segments, heads=heads, fusion=fusion)
        self.linear_1 = nn.Linear(input_size, input_size)
        hidden_size = self.linear_1.out_features
        self.linear_2 = nn.Linear(hidden_size, 1)

        self.drop = nn.Dropout(p=0.5)  # shared by both dropout points in forward()
        self.norm_y = nn.LayerNorm(input_size, eps=1e-6)
        self.norm_linear = nn.LayerNorm(hidden_size, eps=1e-6)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, frame_features):
        """ Score every frame of a video.
        :param torch.Tensor frame_features: Tensor of shape [T, input_size] with the frame features
            (pool5 layer of GoogleNet).
        :return: A tuple of:
            y: Tensor with shape [1, T] containing the frames importance scores in [0, 1].
            attn_weights: Tensor with shape [T, T] containing the attention weights.
        """
        weighted_value, attn_weights = self.attention(frame_features)

        # Residual connection -> dropout -> layer norm.
        hidden = self.norm_y(self.drop(weighted_value + frame_features))

        # Regressor: linear -> ReLU -> dropout -> layer norm -> linear -> sigmoid.
        hidden = self.norm_linear(self.drop(self.relu(self.linear_1(hidden))))
        scores = self.sigmoid(self.linear_2(hidden)).view(1, -1)

        return scores, attn_weights