In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import torch
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.sampler import BatchSampler, RandomSampler
from sklearn.model_selection import KFold, GroupKFold

In [3]:
from lared_laughter.fusion.dataset import FatherDataset, FatherDatasetSubset, CacheExtractor
from lared_laughter.accel.dataset import AccelExtractor
from lared_laughter.audio.dataset import AudioLaughterExtractor
from lared_laughter.video.dataset import VideoExtractor
from lared_laughter.video.dataset.transforms import get_kinetics_val_transform
from lared_laughter.constants import annot_exp_path, datasets_path
from lared_laughter.utils import get_metrics, load_examples
from lared_laughter.audio.models.resnet import get_pretrained_body as get_audio_feature_extractor
from lared_laughter.video.models.models import make_slow_pretrained_body as get_video_feature_extractor

In [4]:
# Read the examples listed in 'examples_without_calibration.csv', which lives in
# the 'processed' folder under annot_exp_path.
examples_csv_path = os.path.join(
    annot_exp_path, 'processed', 'examples_without_calibration.csv'
)
examples = load_examples(examples_csv_path)

In [8]:
# Display the '_ini_time' / '_end_time' columns of every row that belongs to one
# specific example hash and was recorded under the 'video' condition.
target_hash = '72fe462eccb27f971193e5724c64a7c4537f04bc88593f180ef18121a9dfc779'
row_mask = (examples['hash'] == target_hash) & (examples['condition'] == 'video')
examples[row_mask][['_ini_time', '_end_time']]

Unnamed: 0,_ini_time,_end_time
2220,1233.738593,1242.419436
3856,1233.738593,1242.419436


In [14]:
# Per-modality data locations; all of them live in the 'loose' folder under
# datasets_path.
loose_dir = os.path.join(datasets_path, 'loose')
accel_ds_path = os.path.join(loose_dir, 'accel_long.pkl')
videos_path = os.path.join(loose_dir, 'video')
audios_path = os.path.join(loose_dir, "lared_audios.pkl")

# One extractor per modality. Video and audio wrap their raw extractor in a
# CacheExtractor together with a pretrained feature-extractor body moved to the
# GPU (.cuda() requires a CUDA device) and a cache file path.
# min_len / max_len are both 1.5 -- presumably a fixed window length in seconds;
# confirm against AccelExtractor / AudioLaughterExtractor.
extractors = dict(
    accel=AccelExtractor(accel_ds_path, min_len=1.5, max_len=1.5),
    video=CacheExtractor(
        model=get_video_feature_extractor().cuda(),
        extractor=VideoExtractor(
            videos_path,
            transform=get_kinetics_val_transform(8, 256, False),
        ),
        cache_path='./video_cache.pkl',
    ),
    audio=CacheExtractor(
        model=get_audio_feature_extractor().cuda(),
        extractor=AudioLaughterExtractor(audios_path, min_len=1.5, max_len=1.5),
        cache_path='./audio_cache.pkl',
    ),
)

def collate_fn(batch):
    """Turn the single pre-batched sample produced by the loader into tensors.

    The loader in this notebook passes a ``BatchSampler`` as ``sampler``, so the
    dataset is indexed with a whole list of indices at once and returns one
    dict that already holds the full batch. The DataLoader then hands that dict
    to this function wrapped in a one-element list.

    Args:
        batch: One-element list whose only item is a dict mapping a field name
            (e.g. ``'accel'``, ``'label'``) to array-like batch data.

    Returns:
        dict with the same keys, each value converted with ``torch.tensor``.

    Raises:
        IndexError: if ``batch`` is empty.
    """
    # Local import: the notebook's import cells do not bring numpy into scope.
    import numpy as np

    def _as_tensor(value):
        # The dataset returns some fields as a Python list of ndarrays.
        # torch.tensor on such a list converts element by element, which is
        # very slow and emits a UserWarning; stacking first yields the same
        # tensor (same shape and dtype) through the fast path.
        # Plain lists of Python scalars are deliberately left alone so that
        # torch's own dtype inference (e.g. float32 for floats) is unchanged.
        if (
            isinstance(value, (list, tuple))
            and len(value) > 0
            and all(isinstance(item, np.ndarray) for item in value)
        ):
            value = np.stack(value)
        return torch.tensor(value)

    sample = batch[0]
    return {key: _as_tensor(value) for key, value in sample.items()}

# Full dataset over all examples: labels are read from the 'intensity' column
# and examples are identified by the 'hash' column.
ds = FatherDataset(examples, extractors, label_column='intensity', id_column='hash')

# Train and eval views over the same indices (1..99); only the eval flag differs.
subset_indices = range(1, 100)
train_ds = FatherDatasetSubset(ds, subset_indices, eval=False)
eval_ds = FatherDatasetSubset(ds, subset_indices, eval=True)

# Seeded generator shared by the random sampler and the loader itself.
g = torch.Generator().manual_seed(22)

# Key trick: the BatchSampler is passed as `sampler` (not `batch_sampler`), so
# every "index" the DataLoader draws is a list of up to 5 indices. The dataset
# is therefore indexed once per batch with that list, and collate_fn receives a
# one-element list containing the already-batched sample.
index_batches = BatchSampler(
    RandomSampler(train_ds, generator=g),
    batch_size=5,
    drop_last=False,
)
loader = DataLoader(
    dataset=train_ds,
    sampler=index_batches,
    num_workers=0,
    generator=g,
    collate_fn=collate_fn,
)


loaded pre-trained model
missing keys []
unexpected keys ['model.blocks.5.proj.weight', 'model.blocks.5.proj.bias']
loaded pre-trained model
missing keys []
unexpected keys ['bn2.weight', 'bn2.bias', 'bn2.running_mean', 'bn2.running_var', 'bn2.num_batches_tracked', 'bn3.weight', 'bn3.bias', 'bn3.running_mean', 'bn3.running_var', 'bn3.num_batches_tracked', 'linear1.weight', 'linear1.bias', 'linear2.weight', 'linear2.bias']


In [15]:
# Pull a single batch from the loader and print the shape of each modality's
# tensor as one tuple, in the order accel, audio, video.
batch = next(iter(loader))
print(tuple(batch[modality].shape for modality in ('accel', 'audio', 'video')))

[10, 1, 58, 41, 56]
(torch.Size([5, 3, 30]), torch.Size([5, 16, 9, 16]), torch.Size([5, 2048, 1, 2, 2]))


In [7]:
# Call store() on the two CacheExtractor-backed modalities, video first.
# Presumably this writes each feature cache to its cache_path -- confirm in
# CacheExtractor.
for modality in ('video', 'audio'):
    extractors[modality].store()

In [14]:
# Index the training subset with a list of three indices and display the shape
# of the resulting 'accel' entry.
train_ds[[5,6,7]]['accel'].shape

(3, 3, 30)

In [14]:
# Call store() on the video CacheExtractor again (same call as in the earlier
# store cell). Presumably persists the cache to its cache_path -- confirm in
# CacheExtractor.
extractors['video'].store()

In [13]:
# Index the full dataset with a one-element list to inspect the batched return
# format (a dict keyed by field name).
ds[[2]]

{'accel': [array([[-0.04914867, -0.26731655, -0.26731655, ...,  0.        ,
           0.        ,  0.        ],
         [ 0.67281616,  0.67281616,  0.67281616, ...,  0.        ,
           0.        ,  0.        ],
         [ 0.6884482 ,  0.9163148 ,  0.46057492, ...,  0.        ,
           0.        ,  0.        ]], dtype=float32)],
 'label': array([6.]),
 'intval': array([[2.41776103, 3.91776103]]),
 'index': [2]}

In [14]:
# NOTE(review): 'accel' is a plain AccelExtractor, not wrapped in CacheExtractor
# like 'video' and 'audio' -- confirm AccelExtractor actually provides store().
extractors['accel'].store()

In [6]:
# Sanity check: for every training item the first value of 'intval' must be
# strictly smaller than the second. On failure the assertion message is the
# offending dataset index.
for i in range(len(ds)):
    interval = ds.get_train_item(i)['intval']
    start, end = interval[0], interval[1]
    assert start < end, i

In [31]:
# Index the full dataset with a single integer to inspect the unbatched return
# format for one example.
ds[1]

{'accel': array([[ 0.99963385,  0.6424623 ,  0.6424623 , ...,  0.        ,
          0.        ,  0.        ],
        [-0.12390687, -0.12390687, -0.12390576, ...,  0.        ,
          0.        ,  0.        ],
        [ 0.65230256,  0.6523019 ,  0.44224632, ...,  0.        ,
          0.        ,  0.        ]], dtype=float32),
 'label': 4.0,
 'intval': [6.654659663666233, 7.487534029482958],
 'index': 1}

In [33]:
# Index the full dataset with a list of two indices to inspect the batched
# return format for several examples at once.
ds[[1,4]]

[('ec84022e240482976063d686d8e9b1d59e0cbf9b46e763f658e8f2d130d15bbe', 1.7983112267356356, 3.2983112267356356), ('08755169b187562233dc5fabd6268a969fa61915c0d14cc70c37a9b4aa3e38e7', 1.6789426242368215, 3.1789426242368215)]


{'accel': [array([[ 1.1782213 ,  1.1782213 ,  1.1782213 , ...,  0.        ,
           0.        ,  0.        ],
         [-0.12390796, -0.4699086 , -0.4699075 , ...,  0.        ,
           0.        ,  0.        ],
         [-0.18791772, -0.39797062,  0.02213785, ...,  0.        ,
           0.        ,  0.        ]], dtype=float32),
  array([[-0.48548397, -0.04999961, -0.04914819, ...,  0.        ,
           0.        ,  0.        ],
         [ 0.6728177 ,  1.0548579 ,  0.29077598, ...,  0.        ,
           0.        ,  0.        ],
         [ 0.00483601,  0.2327093 ,  1.1441808 , ...,  0.        ,
           0.        ,  0.        ]], dtype=float32)],
 'label': array([4., 4.]),
 'intval': array([[1.79831123, 3.29831123],
        [1.67894262, 3.17894262]]),
 'index': [1, 4]}

## Plot data from extractors