In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import logging
import types
from functools import partial

In [3]:
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold
from sklearn.metrics import roc_auc_score
from IPython.display import clear_output
import pytorch_lightning as pl
logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)

In [4]:
from lared_laughter.fusion.dataset import FatherDataset, FatherDatasetSubset
from lared_laughter.accel.dataset import AccelExtractor
from lared_laughter.audio.dataset import AudioLaughterExtractor
from lared_laughter.video.dataset import VideoExtractor
from lared_laughter.video.dataset.transforms import get_kinetics_val_transform
from lared_laughter.constants import annot_exp_path, datasets_path
from lared_laughter.utils import load_examples
from lared_laughter.fusion.model import FusionModel
from lared_laughter.accel.system import System as AccelSystem

In [5]:
class System(pl.LightningModule):
    def __init__(self, modalities, task='classification'):
        super().__init__()
       
        self.model = FusionModel(modalities)
        self.loss_fn = {
            'classification':F.binary_cross_entropy_with_logits,
            'regression': F.l1_loss
        }[task]

        self.performance_metric = {
            'classification': lambda input, target: roc_auc_score(target, input),
            'regression': F.l1_loss
        }[task]

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        output = self.model(batch).squeeze()
        loss = self.loss_fn(output, batch['label'].float())

        # Logging to TensorBoard by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=.001)
        return optimizer

    def validation_step(self, batch, batch_idx):
        output = self.model(batch).squeeze()
        val_loss = self.loss_fn(output, batch['label'].float())
        self.log('val_loss', val_loss)

        return (output, batch['label'])

    def validation_epoch_end(self, validation_step_outputs):
        all_outputs = torch.cat([o[0] for o in validation_step_outputs]).cpu()
        all_labels = torch.cat([o[1] for o in validation_step_outputs]).cpu()

        val_metric = self.performance_metric(all_outputs, all_labels)
        self.log('val_metric', val_metric)

    def test_step(self, batch, batch_idx):
        output = self.model(batch).squeeze()

        return (output, batch['index'], batch['label'])

    def test_epoch_end(self, test_step_outputs):
        all_outputs = torch.cat([o[0] for o in test_step_outputs]).cpu()
        all_indices = torch.cat([o[1] for o in test_step_outputs]).cpu()
        all_labels = torch.cat([o[2] for o in test_step_outputs]).cpu()

        test_metric = self.performance_metric(all_outputs, all_labels)
        self.test_results = {
            'metric': test_metric,
            'index': all_indices,
            'proba': all_outputs
        }
        self.log('test_metric', test_metric)

In [6]:
def train_accel_model(train_dl, task, log_name=None):
    system = AccelSystem('resnet', task)
    trainer = pl.Trainer(
        # callbacks=[EarlyStopping(monitor="val_loss", mode="min")],
        accelerator='gpu',
        log_every_n_steps=1,
        max_epochs=-1,
        enable_model_summary=False,
        enable_progress_bar=False,
        logger= pl.loggers.TensorBoardLogger(save_dir='logs/', version=log_name),
        enable_checkpointing=False)
    trainer.fit(system, train_dl)

    # freeze params
    for param in system.parameters():
        param.requires_grad = False

    return system.model

In [7]:
def do_fold(train_ds, test_ds, modalities, trainer_params={}, log_prefix=None, task='classification', deterministic=False):

    num_epochs = {
        ('audio',): 12,
        ('accel',): 3,#40,
        ('video',): 17
    }

    # data loaders
    batch_size = 16 if 'video' in modalities else 64
    data_loader_train = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, num_workers=10,
        collate_fn=None)
    data_loader_val = torch.utils.data.DataLoader(
        test_ds, batch_size=batch_size, shuffle=False, num_workers=10,
        collate_fn=None)

    system = System(modalities, task=task)
    trainer_fn = partial(pl.Trainer, **trainer_params)
    trainer = trainer_fn(
        accelerator='gpu',
        log_every_n_steps=1,
        max_epochs=num_epochs[modalities],
        logger= pl.loggers.TensorBoardLogger(
            save_dir='logs/', name='', 
            version=log_prefix+'_fusion' if log_prefix else None),
        deterministic=deterministic,
        enable_checkpointing=False)
        
    trainer.fit(system, data_loader_train)

    trainer.test(system, data_loader_val)
    return system.test_results

In [8]:
def do_cross_validation(ds, modalities, random_state, task='classification', log_prefix=None, first_fold=False, deterministic=False):
    cv_splits = list(KFold(n_splits=10, random_state=random_state, shuffle=True).split(range(len(ds))))
    if first_fold:
        # only do first fold
        cv_splits = [cv_splits[0]]
    else:
        # skip the first fold
        cv_splits = cv_splits[1:]

    all_results = []
    for f, (train_idx, test_idx) in enumerate(cv_splits):
        # create dss    
        train_ds = FatherDatasetSubset(ds, train_idx)
        test_ds = FatherDatasetSubset(ds, test_idx)

        fold_outputs = do_fold(train_ds, test_ds, modalities,
            log_prefix=log_prefix+f'fold{f}' if log_prefix else None,
            task=task,
            deterministic=deterministic)
        all_results.append(fold_outputs)
        clear_output(wait=False)

    outputs = [r['proba'].numpy() for r in all_results]
    indices = [r['index'].numpy() for r in all_results]
    metrics = [r['metric'] for r in all_results]
    return metrics, outputs, indices

In [9]:
def get_table(regression=False):
    examples = load_examples(os.path.join(annot_exp_path, 'processed', 'examples_without_calibration.csv'))
    if regression:
        examples.loc[examples['intensity'].isna(), 'intensity'] = 0
        examples.loc[~examples['pressed_key'], 'intensity'] = 0
        label_column = 'intensity'
    else:
        label_column = 'pressed_key'

    res = {}
    for input_modalities in [('accel',)]:
        input_modality_res = {}

        for label_modality in ['audio', 'video', 'av']:

            filtered_examples = examples[examples['condition'] == label_modality].reset_index()

            print(f'Using {len(filtered_examples)} examples')

            # create the feature datasets
            extractors = {}
            if 'audio' in input_modalities:
                audios_path = os.path.join(datasets_path, "loose", "lared_audios.pkl")
                extractors['audio'] = {
                    'extractor': AudioLaughterExtractor(audios_path, min_len=1.5, max_len=1.5), 
                    'id_column': 'hash'
                }
            if 'video' in input_modalities:
                videos_path = '/home/jose/data/lared_video/video'
                extractors['video'] = {
                    'extractor': VideoExtractor(videos_path,
                        transform=get_kinetics_val_transform(8, 256, False)), 
                    'id_column': 'hash'
                }
            if 'accel' in input_modalities:
                accel_ds_path = os.path.join(datasets_path, 'loose', 'accel_long.pkl')
                extractors['accel'] = {
                    'extractor': AccelExtractor(accel_ds_path, min_len=1.5, max_len=1.5), 
                    'id_column': 'hash'
                }
            ds = FatherDataset(filtered_examples, extractors, label_column=label_column, id_column='hash', )
            assert len(ds) == 1318
            input_modality_res[label_modality] = []
            for i in range(1):
                
                seed = 22+i
                pl.utilities.seed.seed_everything(seed, workers=True)

                metrics, probas, indices = do_cross_validation(ds,
                    first_fold=False,
                    modalities=input_modalities,
                    task='regression' if regression else 'classification',
                    deterministic=True,
                    random_state=seed,
                    log_prefix=f'({"-".join(input_modalities)})L({label_modality})_run{i}')

                input_modality_res[label_modality].append({
                    'metrics': metrics,
                    'probas': probas,
                    'indices': indices,
                    'seed': seed
                })
                
                torch.cuda.empty_cache()

        res['-'.join(input_modalities)] = input_modality_res
    return res

In [10]:
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
res = get_table(False)

In [57]:
def make_table(res):

    input_mod_map = {
        'accel': 'Acceleration',
        'video': 'Video',
        'audio': 'Audio',
    }

    label_mod_map = {
        'video': 'Video',
        'audio': 'Audio',
        'av': 'Audiovisual'
    }

    t = []
    for input_mod, input_res in res.items():
        index = pd.MultiIndex.from_tuples(
            [('', 'Input')] +
            [('Label Modality', l) for l in ['Audio', 'Video', 'AV']], names=["first", "second"])

        row = pd.Series(index=index)

        row[('', 'Input')] = input_mod_map[input_mod]
        for label_mod, label_res in input_res.items():
            metrics = np.concatenate([r['metrics'] for r in label_res])
            row[('Label Modality', label_mod_map[label_mod])] = f'{np.mean(metrics):.3f} ({np.std(metrics):.3f})'

        
        t.append(row)
    return pd.DataFrame(t)

In [58]:
print(make_table(res).to_latex(
    index=False,
))

\begin{tabular}{lllrl}
\toprule
             & \multicolumn{4}{l}{Label Modality} \\
       Input &          Audio &         Video &  AV &   Audiovisual \\
\midrule
Acceleration &  0.720 (0.042) & 0.697 (0.038) & NaN & 0.692 (0.042) \\
\bottomrule
\end{tabular}



  row = pd.Series(index=index)
  print(make_table(res).to_latex(
