In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0,1"

import sys
print(sys.path)

['/home/nesl/Documents/IROS24/CED_Methods_Eval', '/home/nesl/anaconda3/envs/iros24/lib/python312.zip', '/home/nesl/anaconda3/envs/iros24/lib/python3.12', '/home/nesl/anaconda3/envs/iros24/lib/python3.12/lib-dynload', '', '/home/nesl/anaconda3/envs/iros24/lib/python3.12/site-packages']


In [2]:
import glob
import shutil
import random
import numpy as np
import pandas as pd
from os.path import join
from torch.utils.data import Dataset
import torch
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary
import configparser
import matplotlib.pyplot as plt
from tqdm import tqdm
import datetime
import json

In [3]:
from torch.utils.data import DataLoader
from torch import nn
from torch import optim
from multimodal import *

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [5]:
time_window = 1
audio_rate = 16000
audio_input_length = int(audio_rate * time_window)
n_train_multi_data_per_class = 800
n_test_multi_data_per_class = 200

audio_train_set = ESC70Select(
    time_window=time_window,
    folds=[1, 2, 3, 4],
    transforms=lambda x: nn.functional.pad(x,((audio_input_length-x.shape[1])//2, (audio_input_length-x.shape[1])//2))\
        if (x.shape[1] % 2) == 0 \
        else nn.functional.pad(x,((audio_input_length-x.shape[1])//2, (audio_input_length-x.shape[1])//2 + 1)), 
    overwrite=False,
    use_bc_learning=False,
    audio_rate=audio_rate)

imu_train_set = WISDMSelect(
    folds=[1, 2, 3, 4],
    time_window=time_window,
    overwrite=False,
    normalize_acc=True)

audio_train_loader = torch.utils.data.DataLoader(audio_train_set, 
                                            batch_size=32, 
                                            shuffle=True)

imu_train_loader = DataLoader(imu_train_set, batch_size=128, 
                            shuffle=True, num_workers=4)

multimodal_train_set = MultimodalDataset(audio_train_set, 
                                         imu_train_set, 
                                         num_data_per_class=n_train_multi_data_per_class, 
                                         time_window=time_window,
                                         overwrite=False
                                         )

print('Audio data: ({}, {}, {})'.format(len(audio_train_set.sounds), 
                          audio_train_set.sounds[0].shape[0], 
                          audio_train_set.sounds[0].shape[1]))
print('IMU data: ({}, {}, {})'.format(len(imu_train_set.imus), 
                          imu_train_set.imus[0].shape[0], 
                          imu_train_set.imus[0].shape[1]))
print('Audio data: ({}, {}, {})'.format(len(multimodal_train_set.sounds), 
                          multimodal_train_set.sounds[0].shape[0], 
                          multimodal_train_set.sounds[0].shape[1]))
print('IMU data: ({}, {}, {})'.format(len(multimodal_train_set.imus), 
                          multimodal_train_set.imus[0].shape[0], 
                          multimodal_train_set.imus[0].shape[1]))

print(multimodal_train_set.get_label_mapping())

multimodal_train_loader = DataLoader(multimodal_train_set, batch_size=8, 
                            shuffle=True, num_workers=4)
 

./Audio/ESC50/meta/esc50.csv
./Audio/kitchen20/kitchen20.csv
./Audio/silent_sound/silent_sound.csv
loading  fold1
loading  fold2
loading  fold3
loading  fold4
loading  fold1
loading  fold2
loading  fold3
loading  fold4
Loading...
Audio data: (480, 1, 16000)
IMU data: (60613, 6, 20)
Audio data: (7200, 1, 16000)
IMU data: (7200, 6, 20)
{'brush_teeth': 0, 'click_mouse': 1, 'drink': 2, 'eat': 3, 'flush_toilet': 4, 'sit': 5, 'type': 6, 'walk': 7, 'wash': 8}


In [6]:
audio_test_set = ESC70Select(
    time_window=time_window,
    folds=[5],
    transforms=lambda x: nn.functional.pad(x,((audio_input_length-x.shape[1])//2, (audio_input_length-x.shape[1])//2))\
        if (x.shape[1] % 2) == 0 \
        else nn.functional.pad(x,((audio_input_length-x.shape[1])//2, (audio_input_length-x.shape[1])//2 + 1)), 
    overwrite=False,
    use_bc_learning=False,
    audio_rate=audio_rate)

audio_test_loader = DataLoader(audio_test_set, batch_size=32, 
                            shuffle=False, num_workers=2)

imu_test_set = WISDMSelect(
    folds=[5],
    time_window=time_window,
    overwrite=False,
    normalize_acc=True)

imu_test_loader = DataLoader(imu_test_set, batch_size=128, 
                            shuffle=False, num_workers=2)


multimodal_test_set = MultimodalDataset(audio_test_set, 
                                        imu_test_set, 
                                        num_data_per_class=n_test_multi_data_per_class,
                                        time_window=time_window,
                                        overwrite=False
                                        )


print('Audio data: ({}, {}, {})'.format(len(multimodal_test_set.sounds), 
                          multimodal_test_set.sounds[0].shape[0], 
                          multimodal_test_set.sounds[0].shape[1]))
print('IMU data: ({}, {}, {})'.format(len(multimodal_test_set.imus), 
                          multimodal_test_set.imus[0].shape[0], 
                          multimodal_test_set.imus[0].shape[1]))

multimodal_test_loader = DataLoader(multimodal_test_set, batch_size=8, 
                            shuffle=True, num_workers=4)

./Audio/ESC50/meta/esc50.csv
./Audio/kitchen20/kitchen20.csv
./Audio/silent_sound/silent_sound.csv
loading  fold5
loading  fold5
Saving multimodal dataset...
Loading...
Audio data: (1800, 1, 16000)
IMU data: (1800, 6, 20)


# Audio Module: BEATs
Use a pre-trained model to extract sound features.

In [7]:
sys.path.append('/home/liying/Documents/MS thesis/master-thesis/BEATs')
print(sys.path)

['/home/nesl/Documents/IROS24/CED_Methods_Eval', '/home/nesl/anaconda3/envs/iros24/lib/python312.zip', '/home/nesl/anaconda3/envs/iros24/lib/python3.12', '/home/nesl/anaconda3/envs/iros24/lib/python3.12/lib-dynload', '', '/home/nesl/anaconda3/envs/iros24/lib/python3.12/site-packages', '/home/liying/Documents/MS thesis/master-thesis/BEATs']


In [8]:
audio_train_set.get_label_mapping()

{'blender': 0,
 'no_sound': 1,
 'stove-burner': 2,
 'water-flowing': 3,
 'drawer': 4,
 'clean-dishes': 5,
 'chopping': 6,
 'eat': 7,
 'peel': 8,
 'toilet_flush': 9,
 'footsteps': 10,
 'brushing_teeth': 11,
 'drinking_sipping': 12,
 'mouse_click': 13,
 'keyboard_typing': 14}

In [9]:
from BEATs import BEATs, BEATsConfig

# load the pre-trained checkpoints
checkpoint = torch.load('./BEATs/BEATs_iter3_plus_AS2M.pt')

cfg = BEATsConfig(checkpoint['cfg'])
BEATs_model = BEATs(cfg)
BEATs_model.load_state_dict(checkpoint['model'])
BEATs_model.eval()

# extract the the audio representation
audio_input_16khz = audio_train_set.sounds[0]
padding_mask = torch.zeros_like(audio_train_set.sounds[0]).bool()

representation = BEATs_model.extract_features(audio_input_16khz, padding_mask=padding_mask)[0]

ModuleNotFoundError: No module named 'BEATs'

In [10]:
print(summary(BEATs_model))

Layer (type:depth-idx)                                  Param #
BEATs                                                   --
├─Linear: 1-1                                           393,984
├─Conv2d: 1-2                                           131,072
├─Dropout: 1-3                                          --
├─TransformerEncoder: 1-4                               --
│    └─Sequential: 2-1                                  --
│    │    └─Conv1d: 3-1                                 4,719,488
│    │    └─SamePad: 3-2                                --
│    │    └─GELU: 3-3                                   --
│    └─ModuleList: 2-2                                  --
│    │    └─TransformerSentenceEncoderLayer: 3-4        7,092,244
│    │    └─TransformerSentenceEncoderLayer: 3-5        7,092,244
│    │    └─TransformerSentenceEncoderLayer: 3-6        7,092,244
│    │    └─TransformerSentenceEncoderLayer: 3-7        7,092,244
│    │    └─TransformerSentenceEncoderLayer: 3-8        7,092,244

## Experiment: finetune model on ESC70Select dataset

In [11]:
class BEATsFinetuned(nn.Module):
    def __init__(self, BEATs_pretrained_model, n_class, predictor_dropout=0.0):
        super().__init__()
        self.BEATs = BEATs_pretrained_model
        self.predictor_dropout = nn.Dropout(predictor_dropout)
        self.predictor = nn.Linear(768, n_class)
 
    def forward(self, x):
        x = self.BEATs.extract_features(x, padding_mask=torch.zeros_like(x).bool())[0]
        x = self.predictor_dropout(x)
        logits = self.predictor(x).mean(dim=1)
        lprobs = torch.sigmoid(logits)
        return lprobs

    def extract_features(self, x):
        x = self.BEATs.extract_features(x, padding_mask=torch.zeros_like(x).bool())[0]
        x = self.predictor_dropout(x)
        logits = self.predictor(x).mean(dim=1)
        return logits

In [12]:
# Freeze the pretrained model
for param in BEATs_model.parameters():
    param.requires_grad = False

BEATs_finetuned_model = BEATsFinetuned(BEATs_model, audio_train_set.nClasses)
output = BEATs_finetuned_model(audio_train_set.sounds[0])
output.shape

torch.Size([1, 15])

### Training

In [13]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(BEATs_finetuned_model.parameters(), lr=0.001, momentum=0.9)
BEATs_finetuned_model.to(device)

print(audio_train_set.get_label_mapping())
# Training loop
n_epochs = 100
summary = {'loss': [[] for _ in range(n_epochs)], 'acc': [[] for _ in range(n_epochs)]}
for e in range(n_epochs):
    for i, (sounds, labels) in enumerate(audio_train_loader):
        optimizer.zero_grad()
        sounds = sounds.squeeze(dim=1).to(device)
        labels = labels.to(device)
        # Run the Net
        x = BEATs_finetuned_model(sounds)
        # print(x.shape)
        # print(labels.shape)
        # x = x.view(x.size()[:-1])

        # Optimize net
        loss = criterion(x, labels.long())
        loss.backward()
        optimizer.step()
        summary['loss'][e].append(loss.item())

            # Calculat accuracy
        _, pred = x.data.topk(1, dim=1)
        pred = pred.view(pred.shape[:-1])
        acc = torch.sum(pred == labels)/x.shape[0]
        summary['acc'][e].append(acc.item())
        
    print('Loss: {}, Accuracy: {}'.format(np.mean(summary['loss'][e]), np.mean(summary['acc'][e])))

{'no_sound': 0, 'blender': 1, 'stove-burner': 2, 'water-flowing': 3, 'drawer': 4, 'clean-dishes': 5, 'chopping': 6, 'eat': 7, 'peel': 8, 'toilet_flush': 9, 'footsteps': 10, 'brushing_teeth': 11, 'drinking_sipping': 12, 'mouse_click': 13, 'keyboard_typing': 14}
Loss: 2.703262436389923, Accuracy: 0.08125
Loss: 2.6863094727198282, Accuracy: 0.18125
Loss: 2.668513762950897, Accuracy: 0.3
Loss: 2.6522530714670816, Accuracy: 0.3645833333333333
Loss: 2.635932270685832, Accuracy: 0.4666666666666667
Loss: 2.620411479473114, Accuracy: 0.5125
Loss: 2.606025139490763, Accuracy: 0.5645833333333333
Loss: 2.5914759993553163, Accuracy: 0.6208333333333333
Loss: 2.577638653914134, Accuracy: 0.6666666666666666
Loss: 2.564376974105835, Accuracy: 0.7125
Loss: 2.551648751894633, Accuracy: 0.71875
Loss: 2.5392409880956013, Accuracy: 0.7375
Loss: 2.5272345622380574, Accuracy: 0.7625
Loss: 2.5158973693847657, Accuracy: 0.7604166666666666
Loss: 2.5044994155565896, Accuracy: 0.7666666666666667
Loss: 2.4939009388

### Testing

In [14]:
test_accuracy = []
for i, (sounds, labels) in enumerate(audio_test_loader):
        # Run the Net
        sounds = sounds.squeeze(dim=1).to(device)
        labels = labels.to(device)
        x = BEATs_finetuned_model(sounds)

        # loss = criterion(x, labels.long())
        # summary['loss'][e].append(loss.item())
        # Calculat accuracy
        _, pred = x.data.topk(1, dim=1)
        pred = pred.view(pred.shape[:-1])
        acc = torch.sum(pred == labels)/x.shape[0]
        summary['acc'][e].append(acc.item())
print(np.mean(summary['acc'][e]))

0.89


# IMU Module: LIMU-BERT
A pre-trained autoencoder model using LIMU-BERT architecture to extract acceleration features.

In [11]:
sys.path.append('/home/liying/Documents/MS thesis/master-thesis/LIMUBert')
print(sys.path)

['/home/liying/Documents/MS thesis/master-thesis', '/home/liying/miniconda3/envs/pytorch-gpu/lib/python39.zip', '/home/liying/miniconda3/envs/pytorch-gpu/lib/python3.9', '/home/liying/miniconda3/envs/pytorch-gpu/lib/python3.9/lib-dynload', '', '/home/liying/miniconda3/envs/pytorch-gpu/lib/python3.9/site-packages', '/home/liying/Documents/MS thesis/master-thesis/BEATs', '/home/liying/Documents/MS thesis/master-thesis/LIMUBert']


In [12]:
from LIMUBert.utils import load_model_config, Preprocess4Normalization, IMUDataset
from LIMUBert.models import LIMUBertModel4Pretrain

In [13]:
# Load LIMU-BERT model

model_cfg = load_model_config('pretrain_base', 'base', 'v1', path_bert='LIMUBert/config/limu_bert.json')
if model_cfg is None:
    print("Unable to find corresponding model config!")

pipeline = [Preprocess4Normalization(model_cfg.feature_num)]
LIMUBert_model = LIMUBertModel4Pretrain(model_cfg, output_embed=True)

In [14]:
# load the pre-trained checkpoints
checkpoint = torch.load('./LIMUBert/saved/pretrain_base_wisdm_20_100/wisdm.pt')

# cfg = BEATsConfig(checkpoint['cfg'])
# BEATs_model = BEATs(cfg)
LIMUBert_model.load_state_dict(checkpoint)
# LIMUBert_model.eval()
print(summary(LIMUBert_model))

Layer (type:depth-idx)                   Param #
LIMUBertModel4Pretrain                   --
├─Transformer: 1-1                       --
│    └─Embeddings: 2-1                   --
│    │    └─Linear: 3-1                  504
│    │    └─Embedding: 3-2               8,640
│    │    └─LayerNorm: 3-3               144
│    └─MultiHeadedSelfAttention: 2-2     --
│    │    └─Linear: 3-4                  5,256
│    │    └─Linear: 3-5                  5,256
│    │    └─Linear: 3-6                  5,256
│    └─Linear: 2-3                       5,256
│    └─LayerNorm: 2-4                    144
│    └─PositionWiseFeedForward: 2-5      --
│    │    └─Linear: 3-7                  10,512
│    │    └─Linear: 3-8                  10,440
│    └─LayerNorm: 2-6                    144
├─Linear: 1-2                            5,256
├─Linear: 1-3                            5,256
├─LayerNorm: 1-4                         144
├─Linear: 1-5                            438
Total params: 62,646
Trainable param

In [15]:
LIMUBert_model.to('cpu')
print(len(imu_test_loader.dataset.imus))
# for i, (imus, labels) in enumerate(imu_test_loader):
#     # Run the Net
# #     print(imus.shape)
#     imus = imus.transpose(-1, 1).to('cpu')
#     labels = labels.to('cpu')
#     x = LIMUBert_model(imus)
#     if i % 50 == 0:
#         print(x.shape)
# print(np.mean(summary['acc'][e]))

# test_output=LIMUBert_model(torch.rand(1, 100, 6))
# test_output.shape

14398


## Experiment: finetune model on WISDMSelect dataset

In [12]:
class LIMUBertFinetuned(nn.Module):
    def __init__(self, LIMUBert_pretrained_model, n_class, predictor_dropout=0.0):
        super().__init__()
        self.LIMUBert = LIMUBert_pretrained_model
        self.predictor_dropout = nn.Dropout(predictor_dropout)
#         self.predictor = nn.Linear(72, n_class)
        self.predictor = GRU(72, n_class)
 
    def forward(self, x):
        x = self.LIMUBert(x.transpose(-1,1)) # Input to LIMUBert model is N * L * C
        x = self.predictor_dropout(x)
#         logits = self.predictor(x).mean(dim=1)
        logits = self.predictor(x)
        lprobs = torch.sigmoid(logits)
        return lprobs

    def extract_features(self, x):
        x = self.LIMUBert(x.transpose(-1,1))
        x = self.predictor_dropout(x)
#         logits = self.predictor(x).mean(dim=1)
        logits = self.predictor(x)
        return logits
    
class GRU(nn.Module):
    def __init__(self, input_feature_dim, output_feature_dim, training=False):
        super().__init__()
        self.dropout = True
        self.num_rnn = 2
        self.num_linear = 1
        self.rnn_io = [[input_feature_dim, 20], [20, output_feature_dim]]
        self.num_layers = [2, 1]
        for i in range(self.num_rnn):
            self.__setattr__('gru' + str(i), nn.GRU(self.rnn_io[i][0], self.rnn_io[i][1], num_layers=self.num_layers[i],
                                         batch_first=True))
        

    def forward(self, input_seqs, training=False):
        h = input_seqs
        for i in range(self.num_rnn):
            rnn = self.__getattr__('gru' + str(i))
            h, _ = rnn(h)
            h = nn.functional.relu(h)
#         print(h.shape)  
        h = h[:, -1, :]
#         print(h.shape)
        if self.dropout:
            h = nn.functional.dropout(h, training=training)
        return h

In [13]:
# Freeze the pretrained model
for param in LIMUBert_model.parameters():
    param.requires_grad = False
    
LIMUBert_finetuned_model = LIMUBertFinetuned(LIMUBert_model, imu_train_set.nClasses)
# output = LIMUBert_finetuned_model(imu_train_set.imus[0].unsqueeze(0))
# output.shape
print(summary(LIMUBert_finetuned_model))

Layer (type:depth-idx)                        Param #
LIMUBertFinetuned                             --
├─LIMUBertModel4Pretrain: 1-1                 --
│    └─Transformer: 2-1                       --
│    │    └─Embeddings: 3-1                   (9,288)
│    │    └─MultiHeadedSelfAttention: 3-2     (15,768)
│    │    └─Linear: 3-3                       (5,256)
│    │    └─LayerNorm: 3-4                    (144)
│    │    └─PositionWiseFeedForward: 3-5      (20,952)
│    │    └─LayerNorm: 3-6                    (144)
│    └─Linear: 2-2                            (5,256)
│    └─Linear: 2-3                            (5,256)
│    └─LayerNorm: 2-4                         (144)
│    └─Linear: 2-5                            (438)
├─Dropout: 1-2                                --
├─GRU: 1-3                                    --
│    └─GRU: 2-6                               8,160
│    └─GRU: 2-7                               720
Total params: 71,526
Trainable params: 8,880
Non-trainable params

### Training

In [16]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(LIMUBert_finetuned_model.parameters(), lr=0.001, momentum=0.9)
# optimizer = optim.Adam(params=LIMUBert_finetuned_model.parameters(), lr=1e-3)
LIMUBert_finetuned_model.to(device)

imu_train_loader = DataLoader(imu_train_set, batch_size=128, 
                            shuffle=True, num_workers=4)
print(imu_train_set.get_label_mapping())
# Training loop
n_epochs = 50
summary = {'loss': [[] for _ in range(n_epochs)], 'acc': [[] for _ in range(n_epochs)]}
for e in range(n_epochs):
    for i, (imus, labels) in enumerate(imu_train_loader):
        optimizer.zero_grad()
        imus = imus.to(device)
        labels = labels.to(device)
        # Run the Net
        x = LIMUBert_finetuned_model(imus)
        # print(x.shape)
        # print(labels.shape)
        # x = x.view(x.size()[:-1])

        # Optimize net
        loss = criterion(x, labels.long())
        loss.backward()
        optimizer.step()
        summary['loss'][e].append(loss.item())

            # Calculat accuracy
        _, pred = x.data.topk(1, dim=1)
        pred = pred.view(pred.shape[:-1])
        acc = torch.sum(pred == labels)/x.shape[0]
        summary['acc'][e].append(acc.item())
        
    print('Loss: {}, Accuracy: {}'.format(np.mean(summary['loss'][e]), np.mean(summary['acc'][e])))

{'walking': 0, 'jogging': 1, 'sitting': 2, 'standing': 3, 'typing': 4, 'teeth': 5, 'pasta': 6, 'drinking': 7}
Loss: 2.078253234060187, Accuracy: 0.1540648496464679
Loss: 2.0764355885355097, Accuracy: 0.1673331767320633
Loss: 2.0747011887399776, Accuracy: 0.18660479329134289
Loss: 2.073084316755596, Accuracy: 0.20871240606433467
Loss: 2.0715565054040206, Accuracy: 0.22311795118607972
Loss: 2.070141440943668, Accuracy: 0.23756109036897358
Loss: 2.068834124113384, Accuracy: 0.24810385343275573
Loss: 2.06751742613943, Accuracy: 0.2566118422307466
Loss: 2.0662729765239516, Accuracy: 0.2641635339511068
Loss: 2.065074012154027, Accuracy: 0.2707706768261759
Loss: 2.063909068860506, Accuracy: 0.2780592105890575
Loss: 2.062872856541684, Accuracy: 0.2832307330871883
Loss: 2.0618502039658395, Accuracy: 0.28916353395110683
Loss: 2.060886337882594, Accuracy: 0.2919736843360098
Loss: 2.059885213249608, Accuracy: 0.2940836467240986
Loss: 2.058993831433748, Accuracy: 0.2954816730398881
Loss: 2.05809270

In [67]:
import copy
criterion = nn.CrossEntropyLoss()
# optimizer = optim.SGD(LIMUBert_finetuned_model.parameters(), lr=0.001, momentum=0.9)
optimizer = optim.Adam(params=LIMUBert_finetuned_model.parameters(), lr=1e-3)
LIMUBert_finetuned_model.to(device)

imu_tmp_train_set = copy.deepcopy(imu_train_set)
# imu_tmp_train_set.imus[0][0:3] /= 9.8
imu_tmp_train_set.imus = [imu_tmp_train_set.imus[0]]*len(imu_tmp_train_set)
imu_tmp_train_set.labels = [imu_tmp_train_set.labels[0]]*len(imu_tmp_train_set)

imu_tmp_train_loader = DataLoader(imu_tmp_train_set, batch_size=128, 
                            shuffle=True, num_workers=4)



# Training loop
n_epochs = 100
summary = {'loss': [[] for _ in range(n_epochs)], 'acc': [[] for _ in range(n_epochs)]}
for e in range(n_epochs):
    for i, (imus, labels) in enumerate(imu_tmp_train_loader):
        optimizer.zero_grad()
        imus = imus.to(device)
        labels = labels.to(device)
        # Run the Net
        x = LIMUBert_finetuned_model(imus)
        # print(x.shape)
        # print(labels.shape)
        # x = x.view(x.size()[:-1])

        # Optimize net
        loss = criterion(x, labels.long())
        loss.backward()
        optimizer.step()
        summary['loss'][e].append(loss.item())

            # Calculat accuracy
        _, pred = x.data.topk(1, dim=1)
        pred = pred.view(pred.shape[:-1])
        acc = torch.sum(pred == labels)/x.shape[0]
        summary['acc'][e].append(acc.item())
        
    print('Loss: {}, Accuracy: {}'.format(np.mean(summary['loss'][e]), np.mean(summary['acc'][e])))

Loss: 1.8995697397934763, Accuracy: 1.0
Loss: 1.8805682696794208, Accuracy: 1.0
Loss: 1.880459771658245, Accuracy: 1.0
Loss: 1.8804107791499087, Accuracy: 1.0
Loss: 1.8804002197165237, Accuracy: 1.0
Loss: 1.8803832317653455, Accuracy: 1.0
Loss: 1.8803760930111533, Accuracy: 1.0
Loss: 1.8803772851040488, Accuracy: 1.0


KeyboardInterrupt: 

### Testing

In [18]:
test_summary = {'loss': [], 'acc': []}
for i, (imus, labels) in enumerate(imu_test_loader):
        # Run the Net
        imus = imus.to(device)
        labels = labels.to(device)
        x = LIMUBert_finetuned_model(imus)

        # loss = criterion(x, labels.long())
        # summary['loss'][e].append(loss.item())
        # Calculat accuracy
        _, pred = x.data.topk(1, dim=1)
        pred = pred.view(pred.shape[:-1])
        acc = torch.sum(pred == labels)/x.shape[0]
        test_summary['acc'].append(acc.item())
print(np.mean(test_summary['acc']))

0.3212503270679341


# Multimodal Model

## Generate Audio and IMU Embeddings

In [38]:
class AudioModule(nn.Module):
    def __init__(self, BEATs_pretrained_model, dropout_p=0.0):
        super().__init__()
        self.BEATs = BEATs_pretrained_model # need to freeze params
        self.dropout = nn.Dropout(dropout_p)
 
    def forward(self, x):
        x = self.BEATs.extract_features(x, padding_mask=torch.zeros_like(x).bool())[0]
        embeddings = self.dropout(x)
#         embeddings = embeddings.mean(dim=1)
        return embeddings


class IMUModule(nn.Module):
    def __init__(self, LIMUBert_pretrained_model, dropout_p=0.0):
        super().__init__()
        self.imu_model = LIMUBert_pretrained_model # need to freeze params
        self.dropout = nn.Dropout(dropout_p)
 
    def forward(self, x):
        x = self.imu_model(x)
        embeddings = self.dropout(x)
        return embeddings
    

class GRU(nn.Module):
    def __init__(self, input_feature_dim, output_feature_dim, training=False):
        super().__init__()
        self.dropout = True
        self.num_rnn = 2
        self.num_linear = 1
        self.rnn_io = [[input_feature_dim, 256], [256, output_feature_dim]]
        self.num_layers = [2, 1]
        for i in range(self.num_rnn):
            self.__setattr__('gru' + str(i), nn.GRU(self.rnn_io[i][0], self.rnn_io[i][1], num_layers=self.num_layers[i],
                                         batch_first=True))
        

    def forward(self, input_seqs, training=False):
        h = input_seqs
        for i in range(self.num_rnn):
            rnn = self.__getattr__('gru' + str(i))
            h, _ = rnn(h)
            h = nn.functional.relu(h)
#         print(h.shape)  
        h = h[:, -1, :]
#         print(h.shape)
        if self.dropout:
            h = nn.functional.dropout(h, training=training)
        return h

In [39]:
def get_multimodal_embed_dataset(multimodal_loader, audio_module, imu_module, device, overwrite=False):
    
    save_path = multimodal_loader.dataset.db_path.split('.npz')[0] + '_embeddings.npz'
    config_file = './Multimodal/dataset_config.json'
    
    if not os.path.isfile(save_path) or overwrite:
        dataset = {}
        dataset['audio_embeddings'] = []
        dataset['imu_embeddings'] = []
        dataset['labels'] = []
        
        audio_module.eval()
        imu_module.eval()

        for i, (sounds, imus, labels) in enumerate(multimodal_loader):
            sounds = sounds.squeeze(dim=1).to(device)
            imus = imus.permute(0, 2, 1).to(device)
            
            with torch.no_grad():
                audio_embeddings = audio_module(sounds).cpu().numpy()
                imu_embeddings = imu_module(imus).cpu().numpy()
            
            dataset['audio_embeddings'].append(audio_embeddings)
            dataset['imu_embeddings'].append(imu_embeddings)
            dataset['labels'].append(labels.numpy())
            
        dataset['audio_embeddings'] = np.concatenate(dataset['audio_embeddings'], axis=0)
        dataset['imu_embeddings'] = np.concatenate(dataset['imu_embeddings'], axis=0)
        dataset['labels'] = np.concatenate(dataset['labels'], axis=0)
        
        np.savez(save_path, **dataset)
        
    else:
        dataset = np.load(save_path, allow_pickle=True)
    
    if not os.path.isfile(config_file) or overwrite:
        dataset_config = {}
        dataset_config['db_path'] = save_path
        dataset_config['classes'] = multimodal_loader.dataset.classes
        dataset_config['nClasses'] = multimodal_loader.dataset.nClasses
        dataset_config['time_window'] = multimodal_loader.dataset.time_window
        dataset_config['num_data_per_class'] = multimodal_loader.dataset.num_data_per_class
        dataset_config['label_mapping'] = multimodal_loader.dataset.get_label_mapping()
        
        with open(config_file, 'w') as f:
            json.dump(dataset_config, f)
        
    else:
        with open(config_file, 'r') as f:
            dataset_config = json.load(f)
        
    return dataset, dataset_config

In [31]:
# get Multimodal Embedding Dataset for training and testing
audio_module = AudioModule(BEATs_model).to(device)
imu_module = IMUModule(LIMUBert_model).to(device)

embed_dataset, embed_dataset_config = get_multimodal_embed_dataset(DataLoader(multimodal_train_set, batch_size=32, 
                            shuffle=False, num_workers=4), audio_module, imu_module, device=device, overwrite=False)
multimodal_embed_train_set = MultimodalEmbed(embed_dataset, embed_dataset_config)

embed_dataset, embed_dataset_config = get_multimodal_embed_dataset(DataLoader(multimodal_test_set, batch_size=32, 
                            shuffle=False, num_workers=4), audio_module, imu_module, device=device, overwrite=False)
multimodal_embed_test_set = MultimodalEmbed(embed_dataset, embed_dataset_config)

In [32]:
multimodal_embed_test_set.imu_embeddings.shape

(1800, 20, 72)

In [33]:
# get DataLoader for training and testing
multimodal_embed_train_loader = DataLoader(multimodal_embed_train_set, batch_size=128, 
                            shuffle=True, num_workers=4)
multimodal_embed_test_loader = DataLoader(multimodal_embed_test_set, batch_size=128, 
                            shuffle=False, num_workers=4)

In [34]:
class AudioAndIMUFusion(nn.Module):
    def __init__(
        self,
        audio_feature_dim1,
        imu_feature_dim1,
        audio_feature_dim2,
        imu_feature_dim2,
        fusion_output_dim,
        output_dim,
        dropout_p=0.0,
        ):
        """
        
        """
        super().__init__()
        self.audio_gru = GRU(audio_feature_dim1, audio_feature_dim2)
        self.imu_gru = GRU(imu_feature_dim1, imu_feature_dim2)
        self.fusion = nn.Linear(
            in_features=(audio_feature_dim2 + imu_feature_dim2), 
            out_features=fusion_output_dim
            )
        self.fc = nn.Linear(
            in_features=fusion_output_dim, 
            out_features=output_dim
            )
        self.dropout = nn.Dropout(dropout_p)
    
    def forward(self, audio_embeddings, imu_embeddings):
        audio_features = self.audio_gru(audio_embeddings)
#         audio_features = audio_embeddings
#         print(audio_features.shape)
        imu_features = self.imu_gru(imu_embeddings)
        combined = torch.cat([audio_features, imu_features], dim=1)
        fused = self.dropout(
            nn.functional.relu(
                self.fusion(combined)
                )
            )
        logits = self.fc(fused)
        pred = torch.sigmoid(logits)
        return pred
    
    def extract_features(self, audio_embeddings, imu_embeddings):
        audio_features = self.audio_gru(audio_embeddings)
        imu_features = self.imu_gru(imu_embeddings)
        combined = torch.cat([audio_features, imu_features], dim=1)
        fused = self.dropout(
            nn.functional.relu(
                self.fusion(combined)
                )
            )
        return fused

In [40]:
# Freeze the two feature models
# for param in BEATs_model.parameters():
#     param.requires_grad = False
# for param in LIMUBert_model.parameters():
#     param.requires_grad = False

audio_feature_dim1 = 768
imu_feature_dim1 = 72
audio_feature_dim2 = 128
imu_feature_dim2 = 128
fusion_output_dim = 128
n_class = multimodal_embed_train_set.nClasses

multimodal_model = AudioAndIMUFusion(audio_feature_dim1, 
                                     imu_feature_dim1, 
                                     audio_feature_dim2, 
                                     imu_feature_dim2, 
                                     fusion_output_dim,
                                     n_class
                                    )

In [41]:
print(summary(multimodal_model))

TypeError: 'dict' object is not callable

## Training

In [43]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(multimodal_model.parameters(), lr=0.001, momentum=0.9)

# Place on GPU
multimodal_model.to(device)

# Training loop
n_epochs = 200
summary = {'loss': [[] for _ in range(n_epochs)], 'acc': [[] for _ in range(n_epochs)]}
for e in range(n_epochs):
    for i, (audio_embeds, imu_embeds, labels) in enumerate(tqdm(multimodal_embed_train_loader)):
        # Zero the grads
        optimizer.zero_grad()
        audio_embeds = audio_embeds.to(device)
        imu_embeds = imu_embeds.to(device)
        labels = labels.to(device)
        
        # Run the Net
        x = multimodal_model(audio_embeds, imu_embeds)
#         print(x.shape)
        # print(labels.shape)
        # x = x.view(x.size()[:-1])

        # Optimize net
        loss = criterion(x, labels.long())
        loss.backward()
        optimizer.step()
        summary['loss'][e].append(loss.item())

            # Calculat accuracy
        _, pred = x.data.topk(1, dim=1)
        pred = pred.view(pred.shape[:-1])
        acc = torch.sum(pred == labels)/x.shape[0]
        summary['acc'][e].append(acc.item())
 
    print('Epoch: {}, Loss: {}, Accuracy: {}'.format(e, np.mean(summary['loss'][e]), np.mean(summary['acc'][e])))

100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 29.62it/s]


Epoch: 0, Loss: 2.194183115373578, Accuracy: 0.22436951754385964


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 29.34it/s]


Epoch: 1, Loss: 2.193968421534488, Accuracy: 0.23081140350877194


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 29.13it/s]


Epoch: 2, Loss: 2.1938008509184184, Accuracy: 0.23848684210526316


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 28.82it/s]


Epoch: 3, Loss: 2.1936462009162234, Accuracy: 0.24300986842105263


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 27.76it/s]


Epoch: 4, Loss: 2.193410216716298, Accuracy: 0.24917763157894737


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 29.08it/s]


Epoch: 5, Loss: 2.1932818805962278, Accuracy: 0.25219298245614036


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 28.25it/s]


Epoch: 6, Loss: 2.1930450598398843, Accuracy: 0.2583607456140351


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 28.64it/s]


Epoch: 7, Loss: 2.192833938096699, Accuracy: 0.2649396929824561


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 29.38it/s]


Epoch: 8, Loss: 2.1925896636226723, Accuracy: 0.2759046052631579


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 29.01it/s]


Epoch: 9, Loss: 2.192352102513899, Accuracy: 0.28495065789473684


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 30.41it/s]


Epoch: 10, Loss: 2.192117745416206, Accuracy: 0.296875


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 30.09it/s]


Epoch: 11, Loss: 2.19190948469597, Accuracy: 0.2975603070175439


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 30.67it/s]


Epoch: 12, Loss: 2.1915958680604635, Accuracy: 0.30633223684210525


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 29.52it/s]


Epoch: 13, Loss: 2.1913199006465445, Accuracy: 0.3159265350877193


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 29.90it/s]


Epoch: 14, Loss: 2.1910416996269895, Accuracy: 0.3204495614035088


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 30.86it/s]


Epoch: 15, Loss: 2.190722754127101, Accuracy: 0.32497258771929827


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 29.68it/s]


Epoch: 16, Loss: 2.1904305115080716, Accuracy: 0.33292214912280704


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 31.38it/s]


Epoch: 17, Loss: 2.190071373655085, Accuracy: 0.3397752192982456


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 28.60it/s]


Epoch: 18, Loss: 2.189749600594504, Accuracy: 0.3540296052631579


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 30.75it/s]


Epoch: 19, Loss: 2.1893854852308308, Accuracy: 0.36828399122807015


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 29.13it/s]


Epoch: 20, Loss: 2.188965550640173, Accuracy: 0.37924890350877194


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 29.15it/s]


Epoch: 21, Loss: 2.188574657105563, Accuracy: 0.3959703947368421


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 30.64it/s]


Epoch: 22, Loss: 2.188086141619766, Accuracy: 0.40995065789473684


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 31.72it/s]


Epoch: 23, Loss: 2.187676245706123, Accuracy: 0.4095394736842105


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 28.17it/s]


Epoch: 24, Loss: 2.1871997389877054, Accuracy: 0.4243421052631579


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 28.47it/s]


Epoch: 25, Loss: 2.18669719445078, Accuracy: 0.42598684210526316


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 26.70it/s]


Epoch: 26, Loss: 2.186139658877724, Accuracy: 0.44353070175438597


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 30.05it/s]


Epoch: 27, Loss: 2.1855369325269733, Accuracy: 0.4557291666666667


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 29.25it/s]


Epoch: 28, Loss: 2.1849404259731897, Accuracy: 0.46710526315789475


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 29.59it/s]


Epoch: 29, Loss: 2.184230457272446, Accuracy: 0.4739583333333333


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 29.84it/s]


Epoch: 30, Loss: 2.1835963600560238, Accuracy: 0.49300986842105265


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 28.84it/s]


Epoch: 31, Loss: 2.182827167343675, Accuracy: 0.5145285087719298


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 28.87it/s]


Epoch: 32, Loss: 2.1820050749862405, Accuracy: 0.5272752192982456


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 28.55it/s]


Epoch: 33, Loss: 2.181138833363851, Accuracy: 0.5327576754385965


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 28.00it/s]


Epoch: 34, Loss: 2.1801978496083043, Accuracy: 0.5461896929824561


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 29.01it/s]


Epoch: 35, Loss: 2.179154027972305, Accuracy: 0.5648300438596491


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 28.45it/s]


Epoch: 36, Loss: 2.1780270191661097, Accuracy: 0.5689418859649122


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 30.17it/s]


Epoch: 37, Loss: 2.1768298358247993, Accuracy: 0.5681195175438597


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 28.73it/s]


Epoch: 38, Loss: 2.1755009534066185, Accuracy: 0.5753837719298246


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 29.02it/s]


Epoch: 39, Loss: 2.1740432580312095, Accuracy: 0.5971765350877193


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 27.91it/s]


Epoch: 40, Loss: 2.172477253696375, Accuracy: 0.6066337719298246


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 28.35it/s]


Epoch: 41, Loss: 2.17074587888885, Accuracy: 0.6160910087719298


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 30.27it/s]


Epoch: 42, Loss: 2.1689223992197135, Accuracy: 0.6278782894736842


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 30.07it/s]


Epoch: 43, Loss: 2.1667879882611727, Accuracy: 0.6380208333333334


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 29.43it/s]


Epoch: 44, Loss: 2.1645492629001013, Accuracy: 0.6403508771929824


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 26.76it/s]


Epoch: 45, Loss: 2.1619228480155006, Accuracy: 0.6355537280701754


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 27.09it/s]


Epoch: 46, Loss: 2.1589486557140685, Accuracy: 0.6225328947368421


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 29.11it/s]


Epoch: 47, Loss: 2.155693819648341, Accuracy: 0.6177357456140351


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 30.24it/s]


Epoch: 48, Loss: 2.1519038593559934, Accuracy: 0.620202850877193


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 30.16it/s]


Epoch: 49, Loss: 2.147798358348378, Accuracy: 0.6152686403508771


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 30.54it/s]


Epoch: 50, Loss: 2.143156884009378, Accuracy: 0.5974506578947368


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 28.40it/s]


Epoch: 51, Loss: 2.137959028545179, Accuracy: 0.5726425438596491


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 29.41it/s]


Epoch: 52, Loss: 2.1318868461408114, Accuracy: 0.5531798245614035


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 30.46it/s]


Epoch: 53, Loss: 2.125082626677396, Accuracy: 0.5416666666666666


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 30.81it/s]


Epoch: 54, Loss: 2.117576101370025, Accuracy: 0.5078125


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 28.37it/s]


Epoch: 55, Loss: 2.1091084187490896, Accuracy: 0.4765625


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 28.69it/s]


Epoch: 56, Loss: 2.09979978778906, Accuracy: 0.4605263157894737


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 30.16it/s]


Epoch: 57, Loss: 2.08948584606773, Accuracy: 0.44243421052631576


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 29.90it/s]


Epoch: 58, Loss: 2.0778987407684326, Accuracy: 0.4231085526315789


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 29.99it/s]


Epoch: 59, Loss: 2.066295360264025, Accuracy: 0.3998081140350877


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 29.05it/s]


Epoch: 60, Loss: 2.0536481497580543, Accuracy: 0.38308662280701755


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 29.43it/s]


Epoch: 61, Loss: 2.040314222636976, Accuracy: 0.3788377192982456


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 26.83it/s]


Epoch: 62, Loss: 2.027423885830662, Accuracy: 0.36499451754385964


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 30.03it/s]


Epoch: 63, Loss: 2.0140277461001745, Accuracy: 0.36499451754385964


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 29.45it/s]


Epoch: 64, Loss: 2.000983215214913, Accuracy: 0.3578673245614035


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 30.25it/s]


Epoch: 65, Loss: 1.988324941250316, Accuracy: 0.35183662280701755


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 30.21it/s]


Epoch: 66, Loss: 1.9768498801348502, Accuracy: 0.3432017543859649


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 29.71it/s]


Epoch: 67, Loss: 1.9663994458683751, Accuracy: 0.34196820175438597


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 28.65it/s]


Epoch: 68, Loss: 1.955798862273233, Accuracy: 0.35074013157894735


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 28.95it/s]


Epoch: 69, Loss: 1.9463492339117485, Accuracy: 0.3691063596491228


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 28.33it/s]


Epoch: 70, Loss: 1.9386420981925832, Accuracy: 0.4055646929824561


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 28.82it/s]


Epoch: 71, Loss: 1.9316917218660052, Accuracy: 0.40858004385964913


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 30.24it/s]


Epoch: 72, Loss: 1.9253009704121373, Accuracy: 0.41968201754385964


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 30.60it/s]


Epoch: 73, Loss: 1.9187896795440138, Accuracy: 0.3832236842105263


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 29.38it/s]


Epoch: 74, Loss: 1.9132465207785891, Accuracy: 0.3658168859649123


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 30.52it/s]


Epoch: 75, Loss: 1.9088310839837057, Accuracy: 0.3522478070175439


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 28.75it/s]


Epoch: 76, Loss: 1.9049855345173885, Accuracy: 0.3471765350877193


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 28.51it/s]


Epoch: 77, Loss: 1.9010453893427264, Accuracy: 0.3541666666666667


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 29.09it/s]


Epoch: 78, Loss: 1.8984284400939941, Accuracy: 0.3560855263157895


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 27.19it/s]


Epoch: 79, Loss: 1.8948384607047366, Accuracy: 0.35238486842105265


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 27.96it/s]


Epoch: 80, Loss: 1.8921494128411276, Accuracy: 0.3574561403508772


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 30.22it/s]


Epoch: 81, Loss: 1.8887460566403573, Accuracy: 0.37225877192982454


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 28.88it/s]


Epoch: 82, Loss: 1.8868696752347445, Accuracy: 0.37979714912280704


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 28.49it/s]


Epoch: 83, Loss: 1.884484512764111, Accuracy: 0.38308662280701755


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 29.18it/s]


Epoch: 84, Loss: 1.882176633466754, Accuracy: 0.39555921052631576


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 30.31it/s]


Epoch: 85, Loss: 1.8803682724634807, Accuracy: 0.4088541666666667


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 29.17it/s]


Epoch: 86, Loss: 1.8795105114317776, Accuracy: 0.41488486842105265


100%|███████████████████████████████████████████| 57/57 [00:01<00:00, 30.16it/s]


Epoch: 87, Loss: 1.8768789454510337, Accuracy: 0.4206414473684211


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 25.04it/s]


Epoch: 88, Loss: 1.875058270337289, Accuracy: 0.4457236842105263


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 23.69it/s]


Epoch: 89, Loss: 1.8737963856312267, Accuracy: 0.46066337719298245


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 23.95it/s]


Epoch: 90, Loss: 1.8715836311641492, Accuracy: 0.4809484649122807


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 23.30it/s]


Epoch: 91, Loss: 1.8702415792565596, Accuracy: 0.5072642543859649


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 23.06it/s]


Epoch: 92, Loss: 1.8689846344161452, Accuracy: 0.5182291666666666


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 23.96it/s]


Epoch: 93, Loss: 1.8675867611901802, Accuracy: 0.5313870614035088


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 23.85it/s]


Epoch: 94, Loss: 1.8654062852524875, Accuracy: 0.5323464912280702


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 23.95it/s]


Epoch: 95, Loss: 1.8634892848500035, Accuracy: 0.5271381578947368


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 24.17it/s]


Epoch: 96, Loss: 1.861466315754673, Accuracy: 0.5045230263157895


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 23.55it/s]


Epoch: 97, Loss: 1.8590222597122192, Accuracy: 0.5037006578947368


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 24.02it/s]


Epoch: 98, Loss: 1.856354322349816, Accuracy: 0.4897203947368421


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 24.23it/s]


Epoch: 99, Loss: 1.8537598952912449, Accuracy: 0.48739035087719296


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 23.62it/s]


Epoch: 100, Loss: 1.8509997129440308, Accuracy: 0.4683388157894737


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 23.59it/s]


Epoch: 101, Loss: 1.8479704020316141, Accuracy: 0.46450109649122806


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 23.76it/s]


Epoch: 102, Loss: 1.8452358036710506, Accuracy: 0.4613486842105263


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 23.82it/s]


Epoch: 103, Loss: 1.841891190461945, Accuracy: 0.4557291666666667


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 23.53it/s]


Epoch: 104, Loss: 1.8385020954567088, Accuracy: 0.45490679824561403


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 23.82it/s]


Epoch: 105, Loss: 1.8352773189544678, Accuracy: 0.4496984649122807


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 23.59it/s]


Epoch: 106, Loss: 1.831532827594824, Accuracy: 0.4682017543859649


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 23.54it/s]


Epoch: 107, Loss: 1.827790289594416, Accuracy: 0.48519736842105265


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 23.50it/s]


Epoch: 108, Loss: 1.8232437079412895, Accuracy: 0.5


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 23.94it/s]


Epoch: 109, Loss: 1.8193160170002987, Accuracy: 0.5098684210526315


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 23.86it/s]


Epoch: 110, Loss: 1.8144570379926448, Accuracy: 0.5400219298245614


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 23.70it/s]


Epoch: 111, Loss: 1.8093062430097346, Accuracy: 0.5219298245614035


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 23.76it/s]


Epoch: 112, Loss: 1.804750118339271, Accuracy: 0.4765625


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 23.70it/s]


Epoch: 113, Loss: 1.7988313373766447, Accuracy: 0.4575109649122807


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 23.78it/s]


Epoch: 114, Loss: 1.7940554848888464, Accuracy: 0.4570997807017544


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 23.95it/s]


Epoch: 115, Loss: 1.7884508852373089, Accuracy: 0.4466831140350877


100%|███████████████████████████████████████████| 57/57 [00:02<00:00, 22.70it/s]


Epoch: 116, Loss: 1.7830021736914652, Accuracy: 0.4439418859649123


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 16.91it/s]


Epoch: 117, Loss: 1.7772473473297923, Accuracy: 0.43146929824561403


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.08it/s]


Epoch: 118, Loss: 1.7715444857614082, Accuracy: 0.4361293859649123


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 16.94it/s]


Epoch: 119, Loss: 1.7669897267692967, Accuracy: 0.4331140350877193


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.26it/s]


Epoch: 120, Loss: 1.7612137480785972, Accuracy: 0.43516995614035087


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 16.88it/s]


Epoch: 121, Loss: 1.7555984697843854, Accuracy: 0.4465460526315789


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.24it/s]


Epoch: 122, Loss: 1.7509016300502576, Accuracy: 0.4420230263157895


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.17it/s]


Epoch: 123, Loss: 1.7459457736266286, Accuracy: 0.4465460526315789


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.30it/s]


Epoch: 124, Loss: 1.7410001461965996, Accuracy: 0.45285087719298245


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.28it/s]


Epoch: 125, Loss: 1.7365215493921649, Accuracy: 0.4557291666666667


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.09it/s]


Epoch: 126, Loss: 1.731098449021055, Accuracy: 0.45860745614035087


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.19it/s]


Epoch: 127, Loss: 1.7261441255870618, Accuracy: 0.45860745614035087


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.18it/s]


Epoch: 128, Loss: 1.722020057209751, Accuracy: 0.4603892543859649


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.31it/s]


Epoch: 129, Loss: 1.717756047583463, Accuracy: 0.4595668859649123


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.36it/s]


Epoch: 130, Loss: 1.713565232460959, Accuracy: 0.46011513157894735


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.31it/s]


Epoch: 131, Loss: 1.7099061305062813, Accuracy: 0.4616228070175439


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.20it/s]


Epoch: 132, Loss: 1.7055753992314924, Accuracy: 0.4777960526315789


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.24it/s]


Epoch: 133, Loss: 1.701400503777621, Accuracy: 0.4836896929824561


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.20it/s]


Epoch: 134, Loss: 1.697499850340057, Accuracy: 0.48862390350877194


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.11it/s]


Epoch: 135, Loss: 1.6937561181553624, Accuracy: 0.48807565789473684


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.18it/s]


Epoch: 136, Loss: 1.6904932678791516, Accuracy: 0.4876644736842105


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.25it/s]


Epoch: 137, Loss: 1.686907216122276, Accuracy: 0.49890350877192985


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.13it/s]


Epoch: 138, Loss: 1.6832357080359208, Accuracy: 0.4960252192982456


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.33it/s]


Epoch: 139, Loss: 1.6798090913839507, Accuracy: 0.49588815789473684


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.14it/s]


Epoch: 140, Loss: 1.6767173273521556, Accuracy: 0.49136513157894735


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.18it/s]


Epoch: 141, Loss: 1.6729628516916644, Accuracy: 0.49575109649122806


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 16.60it/s]


Epoch: 142, Loss: 1.6697233894415069, Accuracy: 0.4987664473684211


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 16.91it/s]


Epoch: 143, Loss: 1.6668849995261745, Accuracy: 0.5023300438596491


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.14it/s]


Epoch: 144, Loss: 1.6639115831308198, Accuracy: 0.5050712719298246


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 16.39it/s]


Epoch: 145, Loss: 1.6608276032564933, Accuracy: 0.5017817982456141


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.10it/s]


Epoch: 146, Loss: 1.6577759834758021, Accuracy: 0.5015076754385965


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.52it/s]


Epoch: 147, Loss: 1.6547612892953973, Accuracy: 0.5010964912280702


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.38it/s]


Epoch: 148, Loss: 1.6519467391465839, Accuracy: 0.501233552631579


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.54it/s]


Epoch: 149, Loss: 1.6496032183630425, Accuracy: 0.5037006578947368


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.40it/s]


Epoch: 150, Loss: 1.6471934423112033, Accuracy: 0.5052083333333334


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.10it/s]


Epoch: 151, Loss: 1.6447709777898956, Accuracy: 0.5069901315789473


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.37it/s]


Epoch: 152, Loss: 1.641581418221457, Accuracy: 0.5128837719298246


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.35it/s]


Epoch: 153, Loss: 1.6394850768541034, Accuracy: 0.5086348684210527


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.44it/s]


Epoch: 154, Loss: 1.6372258203071461, Accuracy: 0.5142543859649122


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.38it/s]


Epoch: 155, Loss: 1.6351094162255002, Accuracy: 0.5138432017543859


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.48it/s]


Epoch: 156, Loss: 1.6327980443050987, Accuracy: 0.5106907894736842


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.42it/s]


Epoch: 157, Loss: 1.6310204351157473, Accuracy: 0.5100054824561403


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.28it/s]


Epoch: 158, Loss: 1.6292054318545157, Accuracy: 0.5120614035087719


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.30it/s]


Epoch: 159, Loss: 1.6268990374448007, Accuracy: 0.5138432017543859


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.48it/s]


Epoch: 160, Loss: 1.62504429984511, Accuracy: 0.5119243421052632


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.38it/s]


Epoch: 161, Loss: 1.6232090330960458, Accuracy: 0.5121984649122807


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.38it/s]


Epoch: 162, Loss: 1.6213527604153282, Accuracy: 0.5127467105263158


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.30it/s]


Epoch: 163, Loss: 1.620081665223105, Accuracy: 0.510827850877193


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.40it/s]


Epoch: 164, Loss: 1.6187152067820232, Accuracy: 0.5087719298245614


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.46it/s]


Epoch: 165, Loss: 1.6168782376406485, Accuracy: 0.5102796052631579


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.48it/s]


Epoch: 166, Loss: 1.6158853288282429, Accuracy: 0.5104166666666666


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.51it/s]


Epoch: 167, Loss: 1.6142952755877846, Accuracy: 0.5112390350877193


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.63it/s]


Epoch: 168, Loss: 1.6131928008899354, Accuracy: 0.5067160087719298


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.32it/s]


Epoch: 169, Loss: 1.6118136100601732, Accuracy: 0.5053453947368421


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.27it/s]


Epoch: 170, Loss: 1.6102737418392248, Accuracy: 0.5098684210526315


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.25it/s]


Epoch: 171, Loss: 1.6087743357608193, Accuracy: 0.5112390350877193


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.18it/s]


Epoch: 172, Loss: 1.607601013099938, Accuracy: 0.5137061403508771


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.15it/s]


Epoch: 173, Loss: 1.6069268992072658, Accuracy: 0.5091831140350878


100%|███████████████████████████████████████████| 57/57 [00:03<00:00, 17.49it/s]


Epoch: 174, Loss: 1.6057585331431605, Accuracy: 0.5115131578947368


 53%|██████████████████████▋                    | 30/57 [00:02<00:01, 14.38it/s]


KeyboardInterrupt: 

In [186]:
torch.save({
    'model_config': {
        'n_class': n_class,
        'audio_feature_dim': audio_feature_dim,
        'imu_feature_dim': imu_feature_dim,
        'fusion_output_size': fusion_output_size,
        },
    'model_state_dict': multimodal_model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    }, 
    './saved_models/multimodal_embed_model_{}-{}.pt'.format(datetime.datetime.now().date().month, datetime.datetime.now().date().day))

## Testing

In [30]:
checkpoint = torch.load('./saved_models/multimodal_embed_model_4-17.pt', map_location=device)

multimodal_model = AudioAndIMUFusion(
                               checkpoint['model_config']['n_class'],
                               checkpoint['model_config']['audio_feature_dim'],
                               checkpoint['model_config']['imu_feature_dim'],
                               checkpoint['model_config']['fusion_output_size']
                               )
                                     
multimodal_model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [31]:
criterion = nn.CrossEntropyLoss()
multimodal_model.to(device)

test_loss = []
test_acc = []

for i, (audio_embeds, imu_embeds, labels) in enumerate(multimodal_embed_test_loader):
    # Zero the grads
    optimizer.zero_grad()
    audio_embeds = audio_embeds.to(device)
    imu_embeds = imu_embeds.to(device)
    labels = labels.to(device)

    # Run the Net
    x = multimodal_model(audio_embeds, imu_embeds)
    # Optimize net
    loss = criterion(x, labels.long())
    test_loss.append(loss.item())

        # Calculat accuracy
    _, pred = x.data.topk(1, dim=1)
    pred = pred.view(pred.shape[:-1])
    acc = torch.sum(pred == labels)/x.shape[0]
    test_acc.append(acc.item())
    
print('Loss: {}, Accuracy: {}'.format(np.mean(test_loss), np.mean(test_acc)))

Loss: 1.495086904229789, Accuracy: 0.9202586206896551


# Generate and save fusion embeddings

In [32]:
embeddings1 = 0
for i, (audio_embeds, imu_embeds, labels) in enumerate(multimodal_embed_train_loader):
    audio_embeds = audio_embeds.to(device)
    imu_embeds = imu_embeds.to(device)
    labels = labels.to(device)

    embeddings1 = multimodal_model.extract_features(audio_embeds, imu_embeds).detach().cpu().numpy()
print(embeddings1.shape)

(8, 128)


In [33]:
def get_fusion_embed_dataset(multimodal_embed_loader, multimodal_model, device, overwrite=False):
    
    save_path = multimodal_embed_loader.dataset.db_path.replace('MultimodalDataset', 'fusion')
    config_file = './Multimodal/dataset_config.json'
    multimodal_model.eval()
    
    if not os.path.isfile(save_path) or overwrite:
        dataset = {}
        dataset['embeddings'] = []
        dataset['labels'] = []

        for i, (audio_embeds, imu_embeds, labels) in enumerate(multimodal_embed_loader):
            audio_embeds = audio_embeds.to(device)
            imu_embeds = imu_embeds.to(device)

            embeddings = multimodal_model.extract_features(audio_embeds, imu_embeds).detach().cpu().numpy()
            
            dataset['embeddings'].append(embeddings)
            dataset['labels'].append(labels.numpy())
            
        dataset['embeddings'] = np.concatenate(dataset['embeddings'], axis=0)
        dataset['labels'] = np.concatenate(dataset['labels'], axis=0)
        np.savez(save_path, **dataset)
        
    else:
        dataset = np.load(save_path, allow_pickle=True)
    
    with open(config_file, 'r') as f:
        dataset_config = json.load(f)
        
    return dataset, dataset_config


In [34]:
fusion_dataset, fusion_dataset_config = get_fusion_embed_dataset(multimodal_embed_train_loader, multimodal_model, device)
fusion_embed_train_set = FusionEmbed(fusion_dataset, fusion_dataset_config)

fusion_dataset, fusion_dataset_config = get_fusion_embed_dataset(multimodal_embed_test_loader, multimodal_model, device)
fusion_embed_test_set = FusionEmbed(fusion_dataset, fusion_dataset_config)

In [35]:
print(len(audio_train_set.sounds))
print(len(imu_train_set.imus))
print(fusion_embed_train_set.embeddings.shape)
print(len(fusion_embed_test_set.label_mapping))


480
12102
(900, 128)
9


In [8]:
from sklearn.model_selection import KFold

In [45]:
k_folds = 5
kf = KFold(n_splits=k_folds, shuffle=True)

In [37]:
kf.split(multimodal_train_set)[0]

TypeError: 'generator' object is not subscriptable

In [46]:
kf

KFold(n_splits=5, random_state=None, shuffle=True)

In [50]:
print(len(multimodal_train_set))
folds = []

for fold, (train_idx, valid_idx) in enumerate(kf.split(multimodal_train_set)):
    print(f"Fold {fold + 1}", (train_idx, valid_idx))
    folds.append((train_idx, valid_idx))
    


7200
Fold 1 (array([   0,    1,    2, ..., 7197, 7198, 7199]), array([   3,   32,   52, ..., 7190, 7194, 7196]))
Fold 2 (array([   1,    2,    3, ..., 7197, 7198, 7199]), array([   0,    4,    6, ..., 7177, 7181, 7192]))
Fold 3 (array([   0,    2,    3, ..., 7196, 7197, 7199]), array([   1,   16,   18, ..., 7179, 7186, 7198]))
Fold 4 (array([   0,    1,    3, ..., 7195, 7196, 7198]), array([   2,    9,   12, ..., 7193, 7197, 7199]))
Fold 5 (array([   0,    1,    2, ..., 7197, 7198, 7199]), array([   5,   17,   19, ..., 7185, 7187, 7195]))


In [62]:
kf = KFold(n_splits=k_folds, shuffle=True)

folds = []

for fold, (train_idx, valid_idx) in enumerate(kf.split(multimodal_train_set)):
    print(f"Fold {fold + 1}", (train_idx, valid_idx))
    folds.append((train_idx, valid_idx))

for train_idx, valid_idx in folds:
    print(train_idx, valid_idx)
    


Fold 1 (array([   0,    2,    3, ..., 7196, 7198, 7199]), array([   1,    5,   11, ..., 7191, 7195, 7197]))
Fold 2 (array([   0,    1,    3, ..., 7197, 7198, 7199]), array([   2,    4,   18, ..., 7189, 7193, 7194]))
Fold 3 (array([   0,    1,    2, ..., 7197, 7198, 7199]), array([   3,    9,   10, ..., 7164, 7188, 7190]))
Fold 4 (array([   0,    1,    2, ..., 7195, 7197, 7198]), array([   6,    8,   12, ..., 7192, 7196, 7199]))
Fold 5 (array([   1,    2,    3, ..., 7196, 7197, 7199]), array([   0,    7,   21, ..., 7179, 7187, 7198]))
<class 'numpy.ndarray'> [   1    5   11 ... 7191 7195 7197]
<class 'numpy.ndarray'> [   2    4   18 ... 7189 7193 7194]
<class 'numpy.ndarray'> [   3    9   10 ... 7164 7188 7190]
<class 'numpy.ndarray'> [   6    8   12 ... 7192 7196 7199]
<class 'numpy.ndarray'> [   0    7   21 ... 7179 7187 7198]
