In [1]:
import torch
import torch.nn as nn

from audio_mae import AudioMaskedAutoencoderViT
from functools import partial



# Encoder 조사

### 1. patch embedding - 생략

### 2. position embedding 과정 추적

In [2]:
pos_embedding = nn.Parameter(torch.zeros(1, 512 + 1, 768), requires_grad=False)[:, 1:, :]
print(pos_embedding.shape)
print(pos_embedding[:, :1, :].shape)

torch.Size([1, 512, 768])
torch.Size([1, 1, 768])


In [3]:
cls_token = nn.Parameter(torch.zeros(1, 1, 768)) 
print(cls_token.shape)

torch.Size([1, 1, 768])


In [4]:
cls_token = cls_token + pos_embedding[:, :1, :]
print(cls_token.shape)
cls_tokens = cls_token.expand(2, -1, -1)
print(cls_tokens.shape)

torch.Size([1, 1, 768])
torch.Size([2, 1, 768])


In [5]:
x = torch.ones([2, 102, 768])
x = torch.cat((cls_tokens, x), dim=1)
print(x.shape)

torch.Size([2, 103, 768])


### 3. random masking 분석

In [6]:
mask_ratio = 0.8
x = torch.rand([1, 512, 768]) # batch, patch 개수, embedding dimmension

In [7]:
N, L, D = x.shape  # batch, length, dim [2, 512, 768]
len_keep = int(L * (1 - mask_ratio)) # 512 * (1 - 0.8 ) = 102

noise = torch.rand(N, L, device=x.device)  # noise in [0, 1], noise shape : [N, L], [batch, patch 개수], [2, 512]

# sort noise for each sample
ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove # torch.Size([2, 512])
ids_restore = torch.argsort(ids_shuffle, dim=1) # HHJ : ids_restore가 tensor element의 순서를 나타냄

# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) # ids_keep.unsqueeze(-1) : torch.Size([2, 102, 1]) ids_keep.unsqueeze(-1).repeat(1, 1, D) : torch.Size([2, 102, 768])
# x_masked : torch.Size([2, 102, 768])

# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore) # [N, L], [2, 512]

# return x_masked, mask, ids_restore

In [8]:
mask_ratio = 0.25
x = torch.rand([1, 4, 3]) # batch, patch 개수, embedding dimmension
x = torch.tensor([[[0.4548, 0.2263, 0.0600],
         [0.0222, 0.2445, 0.2514],
         [0.0398, 0.6840, 0.2381],
         [0.0882, 0.4236, 0.3248]]])
x

tensor([[[0.4548, 0.2263, 0.0600],
         [0.0222, 0.2445, 0.2514],
         [0.0398, 0.6840, 0.2381],
         [0.0882, 0.4236, 0.3248]]])

In [9]:
# print(x)
N, L, D = x.shape  # torch.Size([2, 4, 3])
len_keep = int(L * (1 - mask_ratio)) # 3

noise = torch.tensor([[0.2119, 0.0595, 0.3355, 0.0211]]) 
print(f"Noise : {noise}")

ids_shuffle = torch.argsort(noise, dim=1) 
ids_restore = torch.argsort(ids_shuffle, dim=1) 
print(f"ids_shuffle : {ids_shuffle}")
print(f"ids_restore : {ids_restore}")

# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
print(f"ids_keep : {ids_keep}")
print(f"index = {ids_keep.unsqueeze(-1).repeat(1, 1, D)}")
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 
print(f"x_masked : {x_masked}")

mask = torch.ones([N, L], device=x.device)
print(mask)
mask[:, :len_keep] = 0
print(mask)
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore) # [N, L], [2, 512]
print(f"mask : {mask}")

# return x_masked, mask, ids_restore

Noise : tensor([[0.2119, 0.0595, 0.3355, 0.0211]])
ids_shuffle : tensor([[3, 1, 0, 2]])
ids_restore : tensor([[2, 1, 3, 0]])
ids_keep : tensor([[3, 1, 0]])
index = tensor([[[3, 3, 3],
         [1, 1, 1],
         [0, 0, 0]]])
x_masked : tensor([[[0.0882, 0.4236, 0.3248],
         [0.0222, 0.2445, 0.2514],
         [0.4548, 0.2263, 0.0600]]])
tensor([[1., 1., 1., 1.]])
tensor([[0., 0., 0., 1.]])
mask : tensor([[0., 0., 1., 0.]])


# decoder 분석

In [10]:
from einops import rearrange

In [11]:
audio_mels = torch.ones([2, 1, 1024, 128])

# Paper recommended archs
model  = AudioMaskedAutoencoderViT(
        num_mels=128, mel_len=1024, in_chans=1,
        patch_size=16, embed_dim=768, encoder_depth=12, num_heads=12,
        decoder_embed_dim=512, decoder_depth=16, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6))
        
latent, msk,  ids_restore = model.forward_encoder(audio_mels, mask_ratio = 0.8)
print("latent:", latent.shape ,msk.shape, ids_restore.shape)

x : torch.Size([2, 1, 1024, 128]), before patch_embed
x : torch.Size([2, 512, 768]), after patch_embed
x : torch.Size([2, 512, 768]), before random masking
x : torch.Size([2, 102, 768]), after random masking
cls : torch.Size([2, 1, 768])
x : torch.Size([2, 103, 768]), before encoder
x : torch.Size([2, 103, 768])
latent: torch.Size([2, 103, 768]) torch.Size([2, 512]) torch.Size([2, 512])


In [12]:
x = latent
ids_restore = ids_restore

print(f"x : {x.shape}") # x : torch.Size([2, 103, 768])
x = model.decoder_embed(x[:, 1:, :]) # nn.Linear
print(f"x after decoder_embed : {x.shape}") # torch.Size([2, 102, 512])

# append mask tokens to sequence

mask_tokens = model.mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
print(f"mask_tokens : {mask_tokens.shape}") # torch.Size([2, 410, 512])
x_ = torch.cat([x, mask_tokens], dim=1)  # no cls token
print(f"x_cat : {x_.shape}") # torch.Size([2, 512, 512])
x = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
print(f"x : {x.shape}") # torch.Size([2, 512, 512])

b, l, c = x.shape

assert l == model.grid_h * model.grid_w, "input feature has wrong size"

# add pos embed
print(f"x before pos embed: {x.shape}") # torch.Size([2, 512, 512])
x = x + model.decoder_pos_embed
print(f"x after pos embed: {x.shape}") # torch.Size([2, 512, 512])
x = x.view(b, model.grid_h, model.grid_w, c)
print(f"x after reshape: {x.shape}") # torch.Size([2, 64, 8, 512])
# apply Transformer blocks
for blk in model.decoder_blocks:
    x = blk(x)
print(f"x after transformer block: {x.shape}") # torch.Size([2, 64, 8, 512])
x = rearrange(x, 'b h w c -> b (h w) c')
print(f"x after rerrange: {x.shape}") # torch.Size([2, 512, 512])
x = model.decoder_norm(x)
print(f"x after decoder_norm: {x.shape}") # torch.Size([2, 512, 512])
# predictor projection
x = model.decoder_pred(x)
print(f"x predictor projection: {x.shape}") # torch.Size([2, 512, 256])

x : torch.Size([2, 103, 768])
x after decoder_embed : torch.Size([2, 102, 512])
mask_tokens : torch.Size([2, 410, 512])
x_cat : torch.Size([2, 512, 512])
x : torch.Size([2, 512, 512])
x before pos embed: torch.Size([2, 512, 512])
x after pos embed: torch.Size([2, 512, 512])
x after reshape: torch.Size([2, 64, 8, 512])
x after transformer block: torch.Size([2, 64, 8, 512])
x after rerrange: torch.Size([2, 512, 512])
x after decoder_norm: torch.Size([2, 512, 512])
x predictor projection: torch.Size([2, 512, 256])


# Loss 분석

In [13]:
imgs, pred, mask = audio_mels, x, msk

loss =  model.forward_loss(audio_mels, pred, mask)
print(loss)

tensor(2.4189, grad_fn=<DivBackward0>)


In [14]:
print(f"imgs : {imgs.shape}")

target = model.patchify(imgs)

print(f"after patchfy : {target.shape}")
print(f"pred : {pred.shape}")
if model.norm_pix_loss:
    print("1")
    mean = target.mean(dim=-1, keepdim=True)
    var = target.var(dim=-1, keepdim=True)
    target = (target - mean) / (var + 1.e-6) ** .5

loss = (pred - target) ** 2
loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches

imgs : torch.Size([2, 1, 1024, 128])
after patchfy : torch.Size([2, 512, 256])
pred : torch.Size([2, 512, 256])


In [17]:
unpatch = model.unpatchify(target)

In [19]:
target.shape

torch.Size([2, 512, 256])

In [18]:
unpatch.shape

torch.Size([2, 1, 1024, 128])

# Model 구조 분석

In [None]:
audio_mels = torch.ones([2, 1, 1024, 128])

# Paper recommended archs
model  = AudioMaskedAutoencoderViT(
        num_mels=128, mel_len=1024, in_chans=1,
        patch_size=16, embed_dim=768, encoder_depth=12, num_heads=12,
        decoder_embed_dim=512, decoder_depth=16, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6))
        

- input

In [None]:
loss, pred, mask = model(audio_mels)
print(loss, pred.shape, mask.shape) # mask_ratio = 0.8

- 의문 : torch.Size([2, 512, 256]) 이 사이즈는 어떻게 나왔나?
- forward_encoder에서 mask의 값이긴 함
- 아래 cell에서 확인할 것

### encoder

In [None]:
latent, msk,  ids_restore = model.forward_encoder(audio_mels, mask_ratio = 0.8)
print("latent:", latent.shape ,msk.shape, ids_restore.shape)

### decoder

In [None]:
pred = model.forward_decoder(latent, ids_restore) 
print(pred.shape)

In [None]:
loss =  model.forward_loss(audio_mels, pred, mask)
print(loss)

### input test

In [None]:
import torch
import librosa
from transform import MelSpectrogram_transform

In [None]:
audio, sr = librosa.load('./N6e5C5sXdBI_0.000_10.000.wav', sr = 16000)
print(len(audio), sr, len(audio) / sr)

In [None]:
import torch
import torchaudio.transforms as T


sr = 16000

spectrogram = T.MelSpectrogram(sample_rate= sr,
                            hop_length = int(sr * 0.01),
                            n_fft = 512,
                            n_mels = 128,
                            window_fn = torch.hann_window,f_max=8000
                            ) # 1 : energy, 2 : power

spectrogram_a = T.Spectrogram(hop_length = int(sr * 0.01),
                            n_fft = 512,
                            window_fn = torch.hann_window) # 1 : energy, 2 : power


In [None]:
spectrogram_a(torch.ones(160000)).shape

In [None]:
(sr * 0.001)

In [None]:
win_len = int(sr * 0.025)
hop_len = int(sr * 0.01)
print(win_len, hop_len)

In [None]:
mel_spectrogram = MelSpectrogram_transform(sample_rate = sr, hop_length=hop_len, win_length=win_len, n_fft = 1024)

In [None]:
mel = mel_spectrogram(torch.Tensor(audio))

In [None]:
mel.shape

In [None]:
mel = mel.transpose(0, 1)
print(mel.shape)

In [None]:
mel = mel.unsqueeze(dim = 0)
mel = mel.unsqueeze(dim = 0)

In [None]:
mel.shape

In [None]:
latent, mask, ids_restore = model.forward_encoder(mel, mask_ratio = 0)