## Spectrogram model Inference

In [1]:
from data.load import load_imgs_png, get_loader, get_labels
from data.preprocess import spectrum_transform
from baseline.freqcnn.model import LargeConvSpecModel, LargeConvFEv1, LargeConvFEv3, LargeConvFEv7, \
    LargeConvFEv12, LargeConvFEv11
from baseline.densenet.densenet import DenseNetSpecModel
from audtorch.metrics.functional import pearsonr
from torchvision.models import densenet121
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import cv2
import torch
import os


# model preparation
print("setting up model...")
FC = [2048]
DROPOUT, CLS_BASE = 0.5, -1
# FE_MODEL, EMBEDDING_DIM = LargeConvFEv1, 101376  # embedding dim depends on fe_model
# FE_MODEL, EMBEDDING_DIM = LargeConvFEv3, 26880      # embedding dim depends on fe_model
# FE_MODEL, EMBEDDING_DIM = LargeConvFEv7, 23040     # embedding dim depends on fe_model
# FE_MODEL, EMBEDDING_DIM = LargeConvFEv11, 4096    # embedding dim depends on fe_model
FE_MODEL, EMBEDDING_DIM = LargeConvFEv12, 4096    # embedding dim depends on fe_model
device = "cuda" if torch.cuda.is_available() else "cpu"
val_model = LargeConvSpecModel(FE_MODEL, 3, 1, EMBEDDING_DIM,
                               fcs=FC, dropout=DROPOUT).half().to(device)
val_model_fp = os.path.join(os.getcwd(), "baseline", r"results/valence/models/largeconv_valence0319.pth")
val_model.load_state_dict(torch.load(val_model_fp))
val_model.eval()

aro_model = LargeConvSpecModel(FE_MODEL, 3, 1, EMBEDDING_DIM,
                               fcs=FC, dropout=DROPOUT).half().to(device)
aro_model_fp = os.path.join(os.getcwd(), "baseline", r"results/arousal/models/largeconv_arousal0152.pth")
aro_model.load_state_dict(torch.load(aro_model_fp))
aro_model.eval()

dom_model = LargeConvSpecModel(FE_MODEL, 3, 1, EMBEDDING_DIM,
                               fcs=FC, dropout=DROPOUT).half().to(device)
dom_model_fp = os.path.join(os.getcwd(), "baseline", r"results/dominance/models/largeconv_dominance0167.pth")
dom_model.load_state_dict(torch.load(dom_model_fp))
dom_model.eval()

print('done!')

setting up model...
done!


In [3]:
rick_spec_fp = r"D:\Documents\datasets\AIST4010\muse\4cOdK2wGLETKBW3PvgPWqT.png"
rick_spec = np.array(cv2.cvtColor(cv2.imread(rick_spec_fp), cv2.COLOR_BGR2RGB))
trans = spectrum_transform(resize=128, norm=True, freq_mask=None, time_mask=None)
trans_rick_spec = trans(rick_spec).unsqueeze(0).half().to(device)
val_pred = val_model(trans_rick_spec).item()
aro_pred = aro_model(trans_rick_spec).item()
dom_pred = dom_model(trans_rick_spec).item()

print(f"Spectrogram Model prediction - valence is {val_pred:.4f}    arousal is {aro_pred:.4f}    dominance is {dom_pred:.4f}")

Spectrogram Model prediction - valence is 0.6353    arousal is 0.5195    dominance is 0.6387


## PANN model Inference

In [1]:
from data.load import get_labels, get_wav_fp, WAV_DIR
from data.dataset import LazyWavDataset
from torch.utils.data import Dataset, DataLoader
from audtorch.metrics.functional import pearsonr
from PANN.model import WaveNet
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torch
import os


# model preparation
print("setting up model...")
# model preparation
sr = 22050
wsize, hsize, mel_bins = 520, 320, 128
fmin, fmax = 50, 8000
# fcs, dropout, act = [1024, 1024], 0.5, nn.ReLU
fcs, dropout, act = [2048], 0.5, nn.ReLU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
freeze = 0
val_model = WaveNet(1, 2048, sr=sr, wsize=wsize, hsize=hsize, mel_bins=mel_bins,
                    fmin=fmin, fmax=fmax, fcs=fcs, dropout=dropout, act=act, freeze=freeze, 
                    checkpoint_fp=None).half().to(device)
val_model_fp = os.path.join(os.getcwd(), "PANN", r"results/valence/models/freeze_none_valmse=0304.pth")
val_model.load_state_dict(torch.load(val_model_fp))
val_model.eval()

aro_model = WaveNet(1, 2048, sr=sr, wsize=wsize, hsize=hsize, mel_bins=mel_bins,
                    fmin=fmin, fmax=fmax, fcs=fcs, dropout=dropout, act=act, freeze=freeze, 
                    checkpoint_fp=None).half().to(device)
aro_model_fp = os.path.join(os.getcwd(), "PANN", r"results/arousal/models/freeze_none_valmse=0150.pth")
aro_model.load_state_dict(torch.load(aro_model_fp))
aro_model.eval()

dom_model = WaveNet(1, 2048, sr=sr, wsize=wsize, hsize=hsize, mel_bins=mel_bins,
                    fmin=fmin, fmax=fmax, fcs=fcs, dropout=dropout, act=act, freeze=freeze, 
                    checkpoint_fp=None).half().to(device)
dom_model_fp = os.path.join(os.getcwd(), "PANN", r"results/dominance/models/freeze_none_valmse=0194.pth")
dom_model.load_state_dict(torch.load(dom_model_fp))
dom_model.eval()

setting up model...


WaveNet(
  (wavecnn): Wavegram_Cnn14(
    (pre_conv0): Conv1d(1, 64, kernel_size=(11,), stride=(5,), padding=(5,), bias=False)
    (pre_bn0): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (pre_block1): ConvPreWavBlock(
      (conv1): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
      (conv2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,), bias=False)
      (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (pre_block2): ConvPreWavBlock(
      (conv1): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
      (conv2): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,), bias=False)
      (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm1d(128, eps=1e-05, 

In [7]:
rick_wav_fp = r"D:\Documents\datasets\AIST4010\muse\4cOdK2wGLETKBW3PvgPWqT.npy"
rick_wav = torch.from_numpy(np.load(rick_wav_fp)).unsqueeze(0).half().to(device)

val_pred = val_model(rick_wav).item()
aro_pred = aro_model(rick_wav).item()
dom_pred = dom_model(rick_wav).item()

print(f"Waveform model prediction - valence is {val_pred:.4f}    arousal is {aro_pred:.4f}    dominance is {dom_pred:.4f}")

Waveform model prediction - valence is 0.6729    arousal is 0.5103    dominance is 0.5278
