In [1]:
# Automatically reload edited Python modules (e.g. lared_laughter) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import types
from functools import partial

In [3]:
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold
from sklearn.metrics import roc_auc_score
from torch.utils.data import DataLoader, random_split, Subset
from IPython.display import clear_output
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

In [4]:
from lared_laughter.fusion.dataset import FeatureDataset, FatherDataset
from lared_laughter.accel.dataset import AccelDictDataset
from lared_laughter.constants import cloud_data_path
from lared_laughter.utils import get_metrics
from lared_laughter.fusion.model import FusionModel
from lared_laughter.accel.system import System as AccelSystem

In [5]:
class System(pl.LightningModule):
    """Lightning wrapper around ``FusionModel`` for pre-extracted modality features.

    Args:
        modalities: names of the modalities to fuse; each must be a key of the
            feature-size table in ``__init__`` ('audio', 'video', 'accel').
        accel_feature_extractor: forwarded unchanged to ``FusionModel``
            (``do_fold`` passes None unless 'accel' is among the modalities).
        loss: 'classification' -> BCE-with-logits loss and ROC-AUC metric;
              'regression'     -> L1 loss and L1 metric.

    After ``trainer.test`` the attribute ``test_results`` holds
    ``{'metric': <test metric>, 'proba': <model outputs on CPU>}``. Despite the
    key name, 'proba' contains raw model outputs: logits for classification
    (the loss is ``binary_cross_entropy_with_logits``), predictions for regression.
    """

    def __init__(self, modalities, accel_feature_extractor, loss='classification'):
        super().__init__()
        # Dimensionality of each modality's feature vector, passed to FusionModel.
        feature_sizes = {
            'audio': 64,
            'video': 9216,
            'accel': 128,
        }
        active_feature_sizes = {k: feature_sizes[k] for k in modalities}
        self.model = FusionModel(active_feature_sizes, accel_feature_extractor)
        self.loss_fn = {
            'classification': F.binary_cross_entropy_with_logits,
            'regression': F.l1_loss,
        }[loss]

        # roc_auc_score takes (y_true, y_score), hence the argument swap.
        self.performance_metric = {
            'classification': lambda input, target: roc_auc_score(target, input),
            'regression': F.l1_loss,
        }[loss]

    def _flat_output(self, batch):
        """Run the fusion model and flatten its output to one value per example.

        ``reshape(-1)`` is used instead of ``squeeze()``: squeeze collapses a
        batch of size 1 to a 0-d tensor, which then breaks ``torch.cat`` in the
        epoch-end hooks. NOTE(review): assumes FusionModel emits exactly one
        value per example (e.g. shape (batch, 1)) — confirm.
        """
        return self.model(batch).reshape(-1)

    @staticmethod
    def _collect_step_outputs(step_outputs):
        """Concatenate the (outputs, labels) pairs returned by each step, on CPU."""
        all_outputs = torch.cat([o[0] for o in step_outputs]).cpu()
        all_labels = torch.cat([o[1] for o in step_outputs]).cpu()
        return all_outputs, all_labels

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        output = self._flat_output(batch)
        loss = self.loss_fn(output, batch['label'].float())

        # Logging to TensorBoard by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=.001)
        return optimizer

    def validation_step(self, batch, batch_idx):
        output = self._flat_output(batch)
        val_loss = self.loss_fn(output, batch['label'].float())
        self.log('val_loss', val_loss)

        return (output, batch['label'])

    def validation_epoch_end(self, validation_step_outputs):
        all_outputs, all_labels = self._collect_step_outputs(validation_step_outputs)
        # Labels cast to float, matching how they are fed to the loss.
        val_metric = self.performance_metric(all_outputs, all_labels.float())
        self.log('val_metric', val_metric)

    def test_step(self, batch, batch_idx):
        output = self._flat_output(batch)

        return (output, batch['label'])

    def test_epoch_end(self, test_step_outputs):
        all_outputs, all_labels = self._collect_step_outputs(test_step_outputs)
        test_metric = self.performance_metric(all_outputs, all_labels.float())
        self.test_results = {'metric': test_metric, 'proba': all_outputs}
        self.log('test_metric', test_metric)

In [6]:
def train_accel_model(train_dl, test_dl, log_name=None):
    """Train a ResNet accelerometer model and turn it into a frozen feature extractor.

    The model is fit with early stopping on ``val_loss`` computed over
    ``test_dl``. Afterwards every parameter is frozen, and ``model.forward`` is
    rebound so that it returns ``model.layers(x)`` averaged over the last axis
    instead of the model's normal output.

    Args:
        train_dl: DataLoader used for training.
        test_dl: DataLoader used for validation / early stopping.
        log_name: TensorBoard version name; None lets Lightning choose one.

    Returns:
        The trained ``AccelSystem`` with frozen parameters and patched forward.

    NOTE(review): the module is left in train mode; any BatchNorm/Dropout layers
    may still behave in training mode when the fusion model is trained.
    """
    accel_system = AccelSystem('resnet')
    tb_logger = pl.loggers.TensorBoardLogger(save_dir='logs/', version=log_name)
    pl.Trainer(
        callbacks=[EarlyStopping(monitor="val_loss", mode="min")],
        accelerator='gpu',
        log_every_n_steps=1,
        max_epochs=-1,
        logger=tb_logger,
        enable_checkpointing=False,
    ).fit(accel_system, train_dl, test_dl)

    # Freeze: nn.Module.requires_grad_ toggles every parameter of the module.
    accel_system.requires_grad_(False)

    def pooled_features(model, x):
        # Average the layer outputs over the trailing axis
        # (presumably time — NOTE(review): confirm the tensor layout).
        return model.layers(x).mean(dim=-1)

    accel_system.model.forward = types.MethodType(pooled_features, accel_system.model)
    return accel_system

In [7]:
def do_fold(train_ds, test_ds, modalities, trainer_params=None, log_prefix=None, loss='classification'):
    """Train and evaluate a fusion ``System`` on a single train/test split.

    If 'accel' is in ``modalities``, an accelerometer feature extractor is first
    trained on the same split (``train_accel_model``) and handed to ``System``.

    Args:
        train_ds: training dataset.
        test_ds: held-out dataset, used both for early stopping and for the
            final ``trainer.test`` call.
        modalities: list of modality names to fuse.
        trainer_params: optional extra ``pl.Trainer`` keyword arguments. They
            override the defaults set here. A 'callbacks' entry is appended after
            the built-in EarlyStopping callback. The caller's dict is not mutated.
        log_prefix: TensorBoard version prefix ('<prefix>_accel' /
            '<prefix>_fusion'); None lets Lightning choose.
        loss: 'classification' or 'regression', forwarded to ``System``.

    Returns:
        ``system.test_results``: dict with 'metric' and 'proba'. Outputs follow
        ``test_ds`` order because the evaluation loader does not shuffle.

    NOTE(review): early stopping monitors the same split that is reported as the
    test result, so the fold metric is likely optimistically biased.
    """
    # Copy so neither a shared default nor the caller's dict is ever mutated.
    trainer_params = dict(trainer_params or {})
    extra_callbacks = list(trainer_params.pop('callbacks', []))

    # data loaders
    loader_kwargs = {'batch_size': 100, 'num_workers': 10, 'collate_fn': None}
    data_loader_train = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    data_loader_val = DataLoader(test_ds, shuffle=False, **loader_kwargs)

    accel_model = None
    if 'accel' in modalities:
        # train or load the accel method
        print('Training accel model..')
        accel_model = train_accel_model(
            data_loader_train, data_loader_val,
            log_name=log_prefix + '_accel' if log_prefix else None)

    system = System(modalities, accel_model, loss=loss)

    # Defaults first, then caller overrides. Previously the caller's values for
    # these keys were silently discarded, because explicit keywords passed to a
    # functools.partial call override the partial's stored keywords.
    trainer_kwargs = {
        'accelerator': 'gpu',
        'log_every_n_steps': 1,
        'max_epochs': -1,
        'enable_checkpointing': False,
    }
    trainer_kwargs.update(trainer_params)
    if 'logger' not in trainer_kwargs:
        trainer_kwargs['logger'] = pl.loggers.TensorBoardLogger(
            save_dir='logs/', version=log_prefix + '_fusion' if log_prefix else None)

    trainer = pl.Trainer(
        callbacks=[EarlyStopping(monitor="val_loss", mode="min")] + extra_callbacks,
        **trainer_kwargs)
    trainer.fit(system, data_loader_train, data_loader_val)

    trainer.test(system, data_loader_val)
    return system.test_results

In [8]:
def do_cross_validation(ds, modalities, metrics_name='classification', log_prefix=None, n_splits=4, seed=22):
    """Shuffled K-fold cross-validation with pooled out-of-fold predictions.

    Each fold trains a fresh model via ``do_fold``. Every test fold's outputs
    are written back at their original dataset indices, and metrics are computed
    once over the pooled predictions.

    Args:
        ds: dataset exposing ``get_all_labels()``; NOTE(review): assumed to
            return labels in dataset-index order — confirm.
        modalities: modality names forwarded to ``do_fold``.
        metrics_name: 'classification' or 'regression'. Selects both the
            training loss and the metric set passed to ``get_metrics``.
        log_prefix: TensorBoard prefix; each fold appends 'fold<k>'.
        n_splits: number of folds (default 4, the previous hard-coded value).
        seed: KFold shuffling seed (default 22, the previous hard-coded value).

    Returns:
        ``(outputs, run_metrics)``: a float tensor of shape (len(ds),) with the
        out-of-fold model outputs, and the result of ``get_metrics``.
    """
    folds = KFold(n_splits=n_splits, random_state=seed, shuffle=True).split(range(len(ds)))

    # NaN-filled instead of torch.empty: a position that were ever left
    # unfilled shows up as NaN rather than arbitrary uninitialised memory.
    outputs = torch.full((len(ds),), float('nan'))
    for fold_idx, (train_idx, test_idx) in enumerate(folds):
        fold_results = do_fold(
            Subset(ds, train_idx), Subset(ds, test_idx), modalities,
            log_prefix=log_prefix + f'fold{fold_idx}' if log_prefix else None,
            loss=metrics_name)
        outputs[test_idx] = fold_results['proba'].cpu()
        clear_output(wait=True)

    labels = torch.Tensor(ds.get_all_labels())
    run_metrics = get_metrics(outputs, labels, metrics_name)
    return outputs, run_metrics

In [9]:
# One row per collected rating (columns include hash, condition, intensity).
examples_csv_path = '../dataset/computational_examples.csv'
examples = pd.read_csv(examples_csv_path)

In [10]:
examples.head(n=5)

Unnamed: 0.2,Unnamed: 0.1,Unnamed: 0,person,cam,hit_id,condition,calibration,hash,ini_time,end_time,...,gt_offset,gt_laughter,is_laughter,confidence,intensity,attempt,pressed_key,onset,offset,rating_hash
0,0,0,25,1,9c45e4f0c5442e796eb93e73e94dc6c2dfca7b9c4c54ff...,video,False,1170917790b51bc5a8dacacc4d8ed8c410b7ea6bb7ea4b...,7360.29,7361.54,...,4.420238,True,True,7,4,0,True,3.33667,6.639973,7af591213b827db95c12c56e76e0b1fe518f2088d11aad...
1,1947,1947,25,1,bff9b86d833a595e6fe5a54f45093fa168cda45db1143e...,video,False,1170917790b51bc5a8dacacc4d8ed8c410b7ea6bb7ea4b...,7360.29,7361.54,...,4.420238,True,True,1,6,0,True,2.569236,6.639973,25df21dc0f25e11a7c4aba77e502269d42a7bb548044f2...
2,546,546,25,1,4198c11729cea33268040a725998f16478a6564d4af091...,audio,False,1170917790b51bc5a8dacacc4d8ed8c410b7ea6bb7ea4b...,7360.29,7361.54,...,4.420238,True,True,7,5,0,True,2.10322,4.26322,2cb0148d83e939600a9e1d71872ba748334e4d8d0cafa0...
3,2440,2440,25,1,a9760ede24043c59a0151b09a46e866fa43f74bd60b682...,audio,False,1170917790b51bc5a8dacacc4d8ed8c410b7ea6bb7ea4b...,7360.29,7361.54,...,4.420238,True,True,7,6,0,True,2.78322,4.14322,b3cc8b0750211f5b5100f20641793ad043dc7cc823ad9a...
4,1058,1058,25,1,f4c9842cec7be99eeaaea36d0c7d077c4d5d94596dc731...,av,False,1170917790b51bc5a8dacacc4d8ed8c410b7ea6bb7ea4b...,7360.29,7361.54,...,4.420238,True,True,7,7,0,True,2.792656,3.893757,bf6cd2aeaf7c77c2c2ff873e6f603b7d46cd64c74e9ebd...


In [13]:
def build_feature_datasets(input_modalities):
    """Load the feature dataset for each requested input modality.

    Returns a dict keyed by modality name ('audio', 'video', 'accel') that is
    passed to ``FatherDataset``. The result depends only on
    ``input_modalities``, so it is built once per modality set instead of once
    per label condition.
    """
    datasets = {}
    if 'audio' in input_modalities:
        datasets['audio'] = FeatureDataset('../audio/features/resnet_bigger.pkl',
            key="resnet_bigger_bn2")
    if 'video' in input_modalities:
        datasets['video'] = FeatureDataset('../video/features/slowfast_test.pkl',
            key="proj_input")
    if 'accel' in input_modalities:
        accel_ds_path = os.path.join(cloud_data_path, 'accel', 'accel_ds.pkl')
        datasets['accel'] = AccelDictDataset(accel_ds_path, example_len=60)
    return datasets


N_RUNS = 1  # repetitions of the full cross-validation per configuration

# res[input modalities][label condition] -> list of metric dicts, one per run.
res = {}
for input_modalities in [['audio']]:
    # Load features once per input-modality set (loop-invariant w.r.t. labels).
    # NOTE(review): sharing these objects across label conditions assumes
    # FatherDataset only reads from them — confirm.
    datasets = build_feature_datasets(input_modalities)
    input_modality_res = {}

    for label_modality in ['audio', 'video', 'av']:
        # Keep only the ratings collected under this presentation condition.
        filtered_examples = examples[examples['condition'] == label_modality].reset_index()
        ds = FatherDataset(filtered_examples, datasets, id_column='hash', label_column='intensity')

        input_modality_res[label_modality] = []
        for i in range(N_RUNS):
            _, metrics = do_cross_validation(ds,
                modalities=input_modalities,
                metrics_name='regression',
                log_prefix=f'({"-".join(input_modalities)})L({label_modality})_run{i}')
            input_modality_res[label_modality].append(metrics)
            torch.cuda.empty_cache()

    res['-'.join(input_modalities)] = input_modality_res

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type        | Params
--------------------------------------
0 | model | FusionModel | 65    
--------------------------------------
65        Trainable params
0         Non-trainable params
65        Total params
0.000     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       test_metric          1.1928596496582031
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


In [14]:
# Metrics keyed by input-modality set, then label condition; one entry per run.
res

{'audio': {'audio': [{'mse': tensor(2.5993), 'l1': tensor(1.2770)}],
  'video': [{'mse': tensor(2.2184), 'l1': tensor(1.1632)}],
  'av': [{'mse': tensor(2.5274), 'l1': tensor(1.2467)}]}}

: 