In [1]:
# Automatically re-import edited project modules (e.g. clipmbt) before executing code,
# so changes to the library are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import warnings
warnings.filterwarnings('ignore')
import sys
sys.path.append("/root/clip")

import matplotlib.pyplot as plt

import torch
from torch.nn.utils.rnn import pad_sequence
import torch.nn as nn
import torchaudio.transforms as atrans
from clipmbt.data.data_loading import load_data, Collate_Constrastive
from clipmbt.datasets import tess
from clipmbt.constants import *
from clipmbt.tokenizer import make_audio_input


# Dataset module whose constants drive the cells below: SAMPLING_RATE, SPEC_HOP_LEN_MS,
# SPEC_WINDOW_SZ_MS (spectrogram settings) and DATA_DIR (location of the .wav files).
ds_to_use = tess

In [3]:

def _stack_batch(batch):
    """Stack a ``(batch, height, width)`` tensor vertically into a ``(batch * height, width)`` copy."""
    # Unpacking the shape also rejects anything that is not 3-D (ValueError).
    batch_sz, height, width = batch.shape
    # clone() so the returned grid never aliases the caller's tensor.
    return batch.clone().reshape(batch_sz * height, width)


def image_grid(batch1=None, batch2=None):
    """Tile one or two image batches into a single 2-D grid for display.

    Each batch is stacked vertically, so a ``(batch, height, width)`` tensor becomes a
    ``(batch * height, width)`` image. When both batches are given, the two grids are
    placed side by side (concatenated along columns). This requires both to have the
    same ``batch * height``; otherwise ``torch.cat`` raises.

    Args:
        batch1: optional 3-D tensor ``(batch, height, width)``.
        batch2: optional 3-D tensor ``(batch, height, width)``.

    Returns:
        A new 2-D tensor that does not share storage with the inputs.

    Raises:
        AssertionError: if both ``batch1`` and ``batch2`` are ``None``.
        ValueError: if a provided batch is not 3-D.
    """
    assert (batch1 is not None) or (batch2 is not None), "batch1 or batch2 must be present, cannot both be empty!"
    # Build grids in argument order so batch1 errors surface before batch2 errors.
    grids = [_stack_batch(batch) for batch in (batch1, batch2) if batch is not None]
    if len(grids) == 2:
        return torch.cat(grids, dim=1)
    return grids[0]


def show_batch(batch_1=None, batch_2=None, title="Image batch", size=5):
    """Draw one or two batches as a single tiled image in a new square figure.

    Args:
        batch_1: optional batch passed to ``image_grid`` as its first batch.
        batch_2: optional batch passed to ``image_grid`` as its second batch.
        title: figure-level title.
        size: width and height of the figure, in inches.
    """
    # Build the grid first so an invalid batch fails before a figure is created.
    grid_image = image_grid(batch_1, batch_2)
    figure_side = (size, size)
    plt.figure(figsize=figure_side)
    plt.suptitle(f"{title}")
    plt.imshow(grid_image, interpolation="nearest")

In [4]:
# Convert the dataset's millisecond settings into sample counts for the STFT.
hop_len = int(ds_to_use.SAMPLING_RATE * ds_to_use.SPEC_HOP_LEN_MS / 1000)
win_len = int(ds_to_use.SAMPLING_RATE * ds_to_use.SPEC_WINDOW_SZ_MS / 1000)

# Waveform -> mel spectrogram (torchaudio's default power=2.0) -> decibel scale.
audio_transform = nn.Sequential(
    atrans.MelSpectrogram(
        sample_rate=ds_to_use.SAMPLING_RATE,
        # STFT geometry
        n_fft=1024,
        win_length=win_len,
        hop_length=hop_len,
        pad_mode="constant",
        normalized=True,
        # mel filterbank size
        n_mels=NUM_MELS,
    ),
    atrans.AmplitudeToDB(),
)

In [14]:
from clipmbt.data.utils import invert_melspec
import torchaudio
from pathlib import Path
import librosa  # NOTE(review): unused in this cell — confirm no later cell needs it before removing
import scipy  # NOTE(review): unused in this cell — confirm no later cell needs it before removing

# Same word list for the OAF and YAF speakers, for one emotion code taken from the folder/file names.
emotion = "ps"
words = ["fail", "judge", "nag", "search", "youth", "late", "laud", "ring", "red", "shall", "gaze"]
of_files = [f"{ds_to_use.DATA_DIR}/OAF_{emotion}/OAF_{word}_{emotion.lower()}.wav" for word in words]
yf_files = [f"{ds_to_use.DATA_DIR}/YAF_{emotion}/YAF_{word}_{emotion.lower()}.wav" for word in words]

of_signals = [make_audio_input(of_file) for of_file in of_files]
yf_signals = [make_audio_input(yf_file) for yf_file in yf_files]


def batch_mel_specs(signals):
    """Apply ``audio_transform`` to each signal and zero-pad along time into one batch tensor.

    Each transformed signal is treated as 2-D ``(n_mels, time)``; it is transposed to
    ``(time, n_mels)`` so ``pad_sequence`` can pad the time axis. The result is
    ``(batch, n_mels, max_time)``. Padding uses 0, i.e. 0 dB after AmplitudeToDB.
    """
    per_signal = [audio_transform(signal).permute(1, 0) for signal in signals]
    # pad_sequence (batch_first=False) -> (max_time, batch, n_mels)
    return pad_sequence(per_signal).permute(1, 2, 0)


of_specs = batch_mel_specs(of_signals)
yf_specs = batch_mel_specs(yf_signals)

# Round-trip check: mel spectrogram -> waveform for the first OAF clip.
test_inv_file = Path(of_files[0])
print(test_inv_file)

# Read the STFT parameters from the forward transform itself so the inversion always
# matches what audio_transform actually used, instead of duplicating literals.
mel_transform = audio_transform[0]
n_fft = mel_transform.n_fft
hop_length = mel_transform.hop_length

# Invert the UNPADDED spectrogram: of_specs[0] is zero-padded along time to the longest
# clip in the batch, and those extra frames are not part of this recording.
# NOTE(review): mspec is dB-scaled (AmplitudeToDB); confirm invert_melspec expects dB input.
mspec = audio_transform(of_signals[0])

audio_signal = invert_melspec(mspec, n_fft=n_fft, sr=ds_to_use.SAMPLING_RATE, hop_len=hop_length)

torchaudio.save(f"./{test_inv_file.stem}_reconstructed.wav", audio_signal, ds_to_use.SAMPLING_RATE)


/root/intelpa-1/datasets/tess_dataset/OAF_ps/OAF_fail_ps.wav


In [6]:
# import plotly.io as pio
# pio.renderers.default = "notebook"
# from plotly.offline import init_notebook_mode
# init_notebook_mode(connected=True)

# from utils import *
# inference_b = get_inference_batch()
# load_and_display_attn("utt_v.jepa_labeled.contrastive_0.0_layers.12", inference_b, graph_title="unsupervised jepa 1.0 attn_weights")
# load_and_display_attn("utt_video.offset.contrastive.labeled0.0_backbone.out.sif", inference_b, graph_title="unsupervised 0.0 attn_weights")
# load_and_display_attn("utt_video.offset.contrastive.labeled0.75", graph_title="unsupervised 1.0 attn_weights")
# load_and_display_attn("utt_video.offset.contrastive.labeled0.1", graph_title="unsupervised 1.0 attn_weights")
# load_and_display_attn("utt_video.offset.contrastive.labeled1.0", inference_b, graph_title="unsupervised 1.0 attn_weights")
# load_and_display_attn("utt_video.offset.contrastive.labeled0.0_redux", inference_b, graph_title="unsupervised 0.0 attn_weights")
# load_and_display_attn("utt_video.offset.contrastive.labeled0.5_redux", inference_b, graph_title="unsupervised 0.5 attn_weights")
# load_and_display_attn("utt_video.offset.contrastive.labeled1.0_redux", inference_b, graph_title="unsupervised 1.0 attn_weights")
# load_and_display_attn("utt_resnet.backbone_contrastive_1.0", inference_b, graph_title="unsupervised 1.0 attn_weights")
# load_and_display_attn("utt_resnet.backbone_contrastive_0.0", inference_b, graph_title="unsupervised 0.0 attn_weights")

# load_and_display_attn("utt_ast.audio_vit.video_contrastive.loss_labeled0.0", inference_b, graph_title="unsupervised 0.0 attn_weights")
# load_and_display_attn("utt_ast_audio.only_contrastive.loss_labeled1.0", inference_b, graph_title="unsupervised 1.0 attn_weights")
