In [1]:
"""Simple simulation for time sampling"""
import sys
import os

# import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

sys.path.insert(0, os.path.abspath('..'))

from src.factories import ModelFactory
from src.factories import DatasetFactory
from src.utils.load_cfg import ConfigLoader
from src.utils.misc import MiscUtils
from src.models.pytorch_ssim.ssim import SSIM

from tools.complexity import (get_model_complexity_info,
                              is_supported_instance,
                              flops_to_string,
                              get_model_parameters_number)

# Collect and analyze FLOPS

In [2]:
def collect_flops(model, units='GMac', precision=3):
    """Wrapper to collect flops and number of parameters at each layer"""
    total_flops = model.compute_average_flops_cost()

    def accumulate_flops(self):
        if is_supported_instance(self):
            return self.__flops__ / model.__batch_counter__
        else:
            sum = 0
            for m in self.children():
                sum += m.accumulate_flops()
            return sum

    def flops_repr(self):
        accumulated_flops_cost = self.accumulate_flops()
        return ', '.join([flops_to_string(accumulated_flops_cost, units=units, precision=precision),
                          '{:.3%} MACs'.format(accumulated_flops_cost / total_flops),
                          self.original_extra_repr()])

    def add_extra_repr(m):
        m.accumulate_flops = accumulate_flops.__get__(m)
        flops_extra_repr = flops_repr.__get__(m)
        if m.extra_repr != flops_extra_repr:
            m.original_extra_repr = m.extra_repr
            m.extra_repr = flops_extra_repr
            assert m.extra_repr != m.original_extra_repr

    def del_extra_repr(m):
        if hasattr(m, 'original_extra_repr'):
            m.extra_repr = m.original_extra_repr
            del m.original_extra_repr
        if hasattr(m, 'accumulate_flops'):
            del m.accumulate_flops

    model.apply(add_extra_repr)

    # Retrieve flops and param at each layer and sub layer (2 levels)
    flops_dict, param_dict = {}, {}
    for i in model._modules.keys():
        item = model._modules[i]
        if isinstance(model._modules[i], torch.nn.modules.container.Sequential):
            for j in model._modules[i]._modules.keys():
                key = '{}-{}'.format(i, j)
                flops_dict[key] = item._modules[j].accumulate_flops()
                param_dict[key] = get_model_parameters_number(item._modules[j])
        else:
            flops_dict[i] = item.accumulate_flops()
            param_dict[i] = get_model_parameters_number(item)

    model.apply(del_extra_repr)
    return flops_dict, param_dict


def analyze_flops_all_models(model, device):
    # Compute FLOPS for different models
    model_factory = ModelFactory()

    # ----------------------------------------------------------------------------
    # RGB model
    print('RGB model')
    rgb_model = model.light_model.rgb

    macs, params = get_model_complexity_info(
        rgb_model,
        (3, 224, 244),
        as_strings=True,
        print_per_layer_stat=False,
    )

    print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
    print('{:<30}  {:<8}'.format('Number of parameters: ', params))
    print('\n{:<15} {:>12} {:>12} {:>12} {:>12}\n'.format(
        'Layer', 'Flops (GMac)', 'Param (M)', 'AccFlops', 'AccParam') + '-'*67)

    flops_dict, param_dict = collect_flops(rgb_model)
    total_flops, total_param = 0, 0
    flops_rgb_part1 = 0
    for k in flops_dict:
        total_flops += flops_dict[k]*1e-9
        total_param += param_dict[k]*1e-6
        if k == 'layer3-0':
            flops_rgb_part1 = total_flops * 1e9
        print('{:<15} {:>12.5f} {:>12.5f} {:>12.2f} {:>12.2f}'.format(
            k, flops_dict[k]*1e-9, param_dict[k]*1e-6, total_flops, total_param))
    del rgb_model
    flops_rgb_part2 = total_flops*1e9 - flops_rgb_part1

    # ----------------------------------------------------------------------------
    # Spec model
    print('\n\nSpec model')
    spec_model = model.light_model.spec

    macs, params = get_model_complexity_info(
        spec_model,
        (1, 224, 244),
        as_strings=True,
        print_per_layer_stat=False,
    )

    print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
    print('{:<30}  {:<8}'.format('Number of parameters: ', params))
    print('\n{:<15} {:>12} {:>12} {:>12} {:>12}\n'.format(
        'Layer', 'Flops (GMac)', 'Param (M)', 'AccFlops', 'AccParam') + '-'*67)

    flops_dict, param_dict = collect_flops(spec_model)
    total_flops, total_param = 0, 0
    flops_spec_part1 = 0
    for k in flops_dict:
        total_flops += flops_dict[k]*1e-9
        total_param += param_dict[k]*1e-6
        if k == 'layer3-0':
            flops_spec_part1 = total_flops * 1e9
        print('{:<15} {:>12.5f} {:>12.5f} {:>12.2f} {:>12.2f}'.format(
            k, flops_dict[k]*1e-9, param_dict[k]*1e-6, total_flops, total_param))
    del spec_model
    flops_spec_part2 = total_flops*1e9 - flops_spec_part1

    # ----------------------------------------------------------------------------
    # Hallucination model
    # create dummy model because the one in real model checks for num_segment
    # here we only look at complexity per frame
    print('\n\nHallu model')
    hallu_model = model_factory.generate(
        model_name='HalluConvLSTM',
        device=device,
        num_segments=1,  # Test per frame
        attention_dim=[32, 14, 14],
        rnn_input_dim=32,
        rnn_hidden_dim=32,
        rnn_num_layers=1,
        has_encoder_decoder=True,
    ).to(device)
    macs, params = get_model_complexity_info(
        hallu_model,
        (1, 32, 14, 14),  # Test per frame
        as_strings=True,
        print_per_layer_stat=False,
    )

    print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
    print('{:<30}  {:<8}'.format('Number of parameters: ', params))
    print('\n{:<15} {:>12} {:>12} {:>12} {:>12}\n'.format(
        'Layer', 'Flops (GMac)', 'Param (M)', 'AccFlops', 'AccParam') + '-'*67)

    flops_dict, param_dict = collect_flops(hallu_model)
    total_flops, total_param = 0, 0
    for k in flops_dict:
        total_flops += flops_dict[k]*1e-9
        total_param += param_dict[k]*1e-6
        print('{:<15} {:>12.5f} {:>12.5f} {:>12.2f} {:>12.2f}'.format(
            k, flops_dict[k]*1e-9, param_dict[k]*1e-6, total_flops, total_param))
    del hallu_model
    flops_hallu = total_flops*1e9

    # ----------------------------------------------------------------------------
    # Action recognition model
    # create dummy model because the one in real model checks for num_segment
    # here we only look at complexity per frame
    print('\n\nActreg model')
    actreg_model = model_factory.generate(
        model_name='ActregGRU',
        device=device,
        modality=['RGB', 'Spec'],
        num_segments=1,  # Test per frame
        num_class=[125, 352],
        dropout=0.5,
        feature_dim=2048,
        rnn_input_size=512,
        rnn_hidden_size=512,
        rnn_num_layers=1,
    ).to(device)

    macs, params = get_model_complexity_info(
        actreg_model,
        (1, 2*2048),  # Test 2 modalities
        as_strings=True,
        print_per_layer_stat=False,
    )

    print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
    print('{:<30}  {:<8}'.format('Number of parameters: ', params))
    print('\n{:<15} {:>12} {:>12} {:>12} {:>12}\n'.format(
        'Layer', 'Flops', 'Param', 'AccFlops', 'AccParam') + '-'*67)

    flops_dict, param_dict = collect_flops(actreg_model)
    total_flops, total_param = 0, 0
    for k in flops_dict:
        total_flops += flops_dict[k]
        total_param += param_dict[k]
        print('{:<15} {:>12.0f} {:>12.0f} {:>12.0f} {:>12.0f}'.format(
            k, flops_dict[k], param_dict[k], total_flops, total_param))
    del actreg_model
    flops_actreg = total_flops

    return {'rgb_part1': flops_rgb_part1,
            'rgb_part2': flops_rgb_part2,
            'spec_part1': flops_spec_part1,
            'spec_part2': flops_spec_part2,
            'hallu': flops_hallu,
            'actreg': flops_actreg,
            }

# All frame action recognition with and without sampling

In [3]:
def actreg_nosample(model, sample):
    # Feed to the feature extraction
    x = model.light_model(sample)

    # Feed to the actreg model with appropriate number of frames
    model.actreg_model.rnn.flatten_parameters()
    x = model.actreg_model.relu(model.actreg_model.fc1(x))
    x = x.view(-1, model.num_segments, model.actreg_model.rnn_input_size)
    x, _ = model.actreg_model.rnn(x, None)
    x = model.actreg_model.relu(x)

    output_list_nosampled = []
    for t in range(model.num_segments):
        output_list_nosampled.append(model.actreg_model.classify(x[:, t, :]))

    score_verb_nosampled = np.stack([F.softmax(item[0], dim=1)[0].cpu()
                                     for item in output_list_nosampled], axis=0)
    score_noun_nosampled = np.stack([F.softmax(item[1], dim=1)[0].cpu()
                                     for item in output_list_nosampled], axis=0)
    return score_verb_nosampled, score_noun_nosampled


def actreg_withsample(model, sample, ssim_list, theta):
    # Feed to the feature extraction
    x = model.light_model(sample)

    # Sample to ignore frames with low SSIM
    x = x[ssim_list > theta]
    num_sampled = x.shape[0]

    # Feed to the actreg model with appropriate number of frames
    model.actreg_model.rnn.flatten_parameters()
    x = model.actreg_model.relu(model.actreg_model.fc1(x))
    x = x.view(-1, num_sampled, model.actreg_model.rnn_input_size)
    x, _ = model.actreg_model.rnn(x, None)
    x = model.actreg_model.relu(x)

    output_list_sampled = []
    t1, t2 = 0, 0
    while t1 < model.num_segments:
        if ssim_list[t1] > theta:  # not removed frame -> compute new result
            output_list_sampled.append(model.actreg_model.classify(x[:, t2, :]))
            t2 += 1
        else:  # removed frame -> reuse prev result
            output_list_sampled.append(output_list_sampled[-1])
        t1 += 1
    score_verb_sampled = np.stack([F.softmax(item[0], dim=1)[0].cpu()
                                   for item in output_list_sampled], axis=0)
    score_noun_sampled = np.stack([F.softmax(item[1], dim=1)[0].cpu()
                                   for item in output_list_sampled], axis=0)
    return score_verb_sampled, score_noun_sampled

# Experiment on val set

In [4]:
# dataset_cfg = '../configs/dataset_cfgs/epickitchens_short.yaml'
dataset_cfg = '../configs/dataset_cfgs/epickitchens.yaml'
train_cfg = '../configs/train_cfgs/train_san_freeze_short.yaml'
model_cfg = '../configs/model_cfgs/pipeline2_rgbspec_san19pairfreeze_actreggru_halluconvlstm.yaml'
theta = -0.25

# weight = 'saved_models/san19freeze_halluconvlstm_actreggru/dim32_layer1_nsegment3/epoch_00049.model'
# model_cfg_mod = None
weight = '../saved_models/san19freeze_halluconvlstm_actreggru/dim32_layer1_nsegment10/epoch_00049.model'
model_cfg_mod = {'num_segments': 10, 'hallu_model_cfg': 'exp_cfgs/haluconvlstm_32_1.yaml'}

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build model and data loader
# Load configurations
model_name, model_params = ConfigLoader.load_model_cfg(model_cfg)
dataset_name, dataset_params = ConfigLoader.load_dataset_cfg(dataset_cfg)
train_params = ConfigLoader.load_train_cfg(train_cfg)

if model_cfg_mod is not None:
    model_params.update(model_cfg_mod)

dataset_params.update({
    'modality': model_params['modality'],
    'num_segments': model_params['num_segments'],
    'new_length': model_params['new_length'],
})

# Build model
model_factory = ModelFactory()
model = model_factory.generate(model_name, device=device,
                               model_factory=model_factory, **model_params)
model.load_model(weight)
model = model.to(device)
model.eval()

# Get training augmentation and transforms
train_augmentation = MiscUtils.get_train_augmentation(model.modality, model.crop_size)
train_transform, val_transform = MiscUtils.get_train_val_transforms(
    modality=model.modality,
    input_mean=model.input_mean,
    input_std=model.input_std,
    scale_size=model.scale_size,
    crop_size=model.crop_size,
    train_augmentation=train_augmentation,
)

# Data loader
dataset_factory = DatasetFactory()
loader_params = {
    'batch_size': train_params['batch_size'],
    'num_workers': train_params['num_workers'],
    'pin_memory': True,
}

val_dataset = dataset_factory.generate(dataset_name, mode='val',
                                       transform=val_transform,
                                       **dataset_params)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_params)

# Analyze flops
flops_dict = analyze_flops_all_models(model, device)

RGB model
Computational complexity:       4.08 GMac
Number of parameters:           15.55 M 

Layer           Flops (GMac)    Param (M)     AccFlops     AccParam
-------------------------------------------------------------------
conv_in              0.01049      0.00019         0.01         0.00
bn_in                0.00700      0.00013         0.02         0.00
conv0                0.05597      0.00410         0.07         0.00
bn0                  0.00175      0.00013         0.08         0.00
layer0-0             0.05012      0.00287         0.13         0.01
layer0-1             0.05012      0.00287         0.18         0.01
layer0-2             0.05012      0.00287         0.23         0.01
conv1                0.05597      0.01638         0.28         0.03
bn1                  0.00175      0.00051         0.28         0.03
layer1-0             0.24619      0.04245         0.53         0.07
layer1-1             0.24619      0.04245         0.78         0.11
layer1-2             0

In [5]:
# Go through the dataset
ssim_criterion = SSIM(window_size=3, channel=32)
correct_verb_nosampled = []
correct_noun_nosampled = []
correct_verb_sampled = []
correct_noun_sampled = []
ssim_all = []

with torch.no_grad():
    # for i, (sample, target) in enumerate(val_loader):
    for (sample, target) in tqdm(val_loader):
        sample = {k: v.to(device) for k, v in sample.items()}
        target = {k: v.to(device) for k, v in target.items()}
        target_verb, target_noun = target['verb'].item(), target['noun'].item()

        # Get attention and hallucination
        model(sample)
        attn = model._attn[0]
        hallu = model._hallu[0]

        # Compute ssim
        ssim_list = np.zeros(model.num_segments)
        for t in range(1, model.num_segments):
            ssim = -ssim_criterion(attn[t].unsqueeze(dim=0),
                                   hallu[t-1].unsqueeze(dim=0)).item()
            ssim_list[t] = ssim

        # Compute score for each frame
        score_verb_nosampled, score_noun_nosampled = actreg_nosample(model, sample)
        score_verb_sampled, score_noun_sampled = actreg_withsample(model, sample, ssim_list, theta)

        # Collect results
        ssim_all.append(ssim_list)
        correct_verb_nosampled.append(score_verb_nosampled.argmax(axis=1) == target_verb)
        correct_noun_nosampled.append(score_noun_nosampled.argmax(axis=1) == target_noun)
        correct_verb_sampled.append(score_verb_sampled.argmax(axis=1) == target_verb)
        correct_noun_sampled.append(score_noun_sampled.argmax(axis=1) == target_noun)

        # if i % 100 == 0:
        #     print('{}/{}'.format(i, len(val_loader)))

# Print statistics
print('Accuracy without sampling')
print('- Verb: {:.4f}%'.format(np.concatenate(correct_verb_nosampled).mean() * 100))
print('- Noun: {:.4f}%'.format(np.concatenate(correct_noun_nosampled).mean() * 100))
print('Accuracy with sampling')
print('- Verb: {:.4f}%'.format(np.concatenate(correct_verb_sampled).mean() * 100))
print('- Noun: {:.4f}%'.format(np.concatenate(correct_noun_sampled).mean() * 100))

ssim_all = np.concatenate(ssim_all)

100%|██████████| 2398/2398 [22:15<00:00,  1.80it/s]

Accuracy without sampling
- Verb: 46.7348%
- Noun: 30.0083%
Accuracy with sampling
- Verb: 39.3036%
- Noun: 25.2335%





In [6]:
n_removed = (ssim_all <= theta).sum()
total_flops = len(ssim_all) * (flops_dict['rgb_part1'] + flops_dict['rgb_part2'] +
                               flops_dict['spec_part1'] + flops_dict['spec_part2'] +
                               flops_dict['hallu'] + flops_dict['actreg']
                               )
# sampled_flops = n_removed * (flops_dict['rgb_part1'] + flops_dict['hallu'])
saved_flops = n_removed * (flops_dict['rgb_part2'] + 
                           flops_dict['spec_part1'] + flops_dict['spec_part2'] +
                           flops_dict['actreg'])

print('Val set statistics:')
print('- Total FLOPS:               {:.2f}*1e9'.format(total_flops * 1e-9))
print('- Overall saved FLOPS:       {:.2f}*1e9'.format(saved_flops * 1e-9))
print('- Overall saving percentage: {:.2f}%'.format(saved_flops / total_flops * 100))
print('Per frame statistics (on average):')
print('- Original FLOPS per frame:  {:.2f}*1e9'.format(total_flops / len(ssim_all) * 1e-9))
print('- Saved FLOPS per frame:     {:.2f}*1e9'.format(saved_flops / len(ssim_all) * 1e-9))

Val set statistics:
- Total FLOPS:               196224.01*1e9
- Overall saved FLOPS:       99154.03*1e9
- Overall saving percentage: 50.53%
Per frame statistics (on average):
- Original FLOPS per frame:  8.18*1e9
- Saved FLOPS per frame:     4.13*1e9


In [8]:
# Double check the math
print('len frames =', len(ssim_all))
print('saved frames =', (ssim_all <= theta).sum())
print('saved frames percentage =', (ssim_all <= theta).sum() / len(ssim_all) * 100)

print('-'*10)
full_com = flops_dict['rgb_part1']  + flops_dict['rgb_part2'] + \
           flops_dict['spec_part1'] + flops_dict['spec_part2'] + \
           flops_dict['hallu']      + flops_dict['actreg']
saved_com = flops_dict['rgb_part2'] + flops_dict['spec_part1'] + flops_dict['spec_part2'] + flops_dict['actreg']
print('full_com per frame  =', full_com)
print('saved_com per frame =', saved_com)

print('-'*10)
full_com_total  = full_com  * len(ssim_all)
saved_com_total = saved_com * (ssim_all <= theta).sum()
print('full_com total  =', full_com_total)
print('saved_com total =', saved_com_total)
print('saved_com total percent =', saved_com_total / full_com_total * 100)

len frames = 23980
saved frames = 17020
saved frames percentage = 70.97581317764804
----------
full_com per frame  = 8182819269.000002
saved_com per frame = 5825736345.000002
----------
full_com total  = 196224006070620.03
saved_com total = 99154032591900.03
saved_com total percent = 50.53104081272043
