In [1]:
import librosa
import torch

from model import get_fband_model, get_swin_transformer_encoder, get_st_classifier, get_wav2img_model
from utils import parse_config

# Project configuration object; it is passed to every model factory in the next cell.
config = parse_config()

### Initialize the models

In [2]:
# Build the four pipeline stages from the shared config. `.eval()` switches any
# dropout / batch-norm layers to inference behaviour before weights are transferred.
fband_transfer = get_fband_model(config).eval()
wav2img_transfer = get_wav2img_model(config).eval()
# The encoder factory imported from `model` is `get_swin_transformer_encoder`;
# `get_swin_transformer` is not defined in this notebook (NameError on a fresh kernel).
st_model = get_swin_transformer_encoder(config).eval()
st_classifier = get_st_classifier(config).eval()

### Load the pretrained checkpoint and transfer its weights into each stage

In [3]:
# Forward slashes work on Windows as well as Linux/macOS. The original backslash
# literal named a single file containing backslashes on POSIX systems.
pre_ckpt_path = 'ckpt/infer/hts_at/model_stopby_train_loss_v0.pth'
# Map onto the CPU so the transfer also works on machines without a GPU.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints
# (consider weights_only=True on torch >= 1.13; the cells below only index it
# like a flat name -> tensor mapping).
pre_ckpt = torch.load(pre_ckpt_path, map_location=torch.device('cpu'))

In [4]:
# The front-end's parameters are stored in the checkpoint under an extra prefix
# (every recorded match below carries 'pre.'). Match each local parameter to the
# first checkpoint key that equals its name or ends with '.<name>'. A plain
# substring test could also hit an unrelated key that merely contains the name.
fband_state = {}
unmatched = []
for name, _ in fband_transfer.named_parameters():
    ckpt_key = next((k for k in pre_ckpt if k == name or k.endswith('.' + name)), None)
    if ckpt_key is None:
        unmatched.append(name)
        continue
    print(f'match {name} in {ckpt_key}')
    fband_state[name] = pre_ckpt[ckpt_key]
# Fail loudly rather than save a model with some parameters still at their init values.
if unmatched:
    raise KeyError(f'no checkpoint entry for fband parameters: {unmatched}')
# load_state_dict copies values in place and raises on a shape mismatch. In contrast,
# Parameter.set_ silently adopts the checkpoint tensor's shape. strict=False because
# only parameters (not buffers) are transferred here.
fband_transfer.load_state_dict(fband_state, strict=False)

torch.save(fband_transfer.state_dict(), 'workspace/fband_transfer.pth')

match spectrogram_extractor.stft.conv_real.weight in pre.spectrogram_extractor.stft.conv_real.weight
match spectrogram_extractor.stft.conv_imag.weight in pre.spectrogram_extractor.stft.conv_imag.weight
match logmel_extractor.melW in pre.logmel_extractor.melW


In [5]:
# wav2img parameter name -> checkpoint key. In the checkpoint these entries live
# under 'pre.' and the down-sampling layer is named 'long_audio_down_sample_layer'.
key_map = {
    'bn0.weight': 'pre.bn0.weight',
    'bn0.bias': 'pre.bn0.bias',
    'down_sample_layer.weight': 'pre.long_audio_down_sample_layer.weight',
    'down_sample_layer.bias': 'pre.long_audio_down_sample_layer.bias',
}
# bn0's running statistics are buffers rather than parameters, so they are mapped separately.
buffer_key_map = {
    f'bn0.{stat}': f'pre.bn0.{stat}'
    for stat in ('running_mean', 'running_var', 'num_batches_tracked')
}

# Indexing key_map per parameter raises KeyError for any parameter without a mapping.
wav2img_state = {
    name: pre_ckpt[key_map[name]] for name, _ in wav2img_transfer.named_parameters()
}
wav2img_state.update({name: pre_ckpt[key] for name, key in buffer_key_map.items()})
# load_state_dict copies values in place and raises on a shape mismatch
# (Parameter.set_ would silently adopt the checkpoint tensor's shape).
# strict=False because only the entries above are supplied.
load_result = wav2img_transfer.load_state_dict(wav2img_state, strict=False)
# strict=False would otherwise silently ignore a name the model does not have.
if load_result.unexpected_keys:
    raise KeyError(f'wav2img model has no entries named {load_result.unexpected_keys}')

torch.save(wav2img_transfer.state_dict(), 'workspace/wav2img_transfer.pth')

In [6]:
# The encoder's parameter names match the checkpoint keys unchanged.
# Indexing pre_ckpt raises KeyError for any parameter the checkpoint lacks.
st_state = {name: pre_ckpt[name] for name, _ in st_model.named_parameters()}
# load_state_dict copies values in place and raises on a shape mismatch
# (Parameter.set_ would silently adopt the checkpoint tensor's shape).
# strict=False: only parameters are supplied, so registered buffers keep the
# values the model was constructed with.
st_model.load_state_dict(st_state, strict=False)

torch.save(st_model.state_dict(), 'workspace/st_model.pth')

In [7]:
# The classifier head's parameter names match the checkpoint keys unchanged.
# Indexing pre_ckpt raises KeyError for any parameter the checkpoint lacks.
classifier_state = {name: pre_ckpt[name] for name, _ in st_classifier.named_parameters()}
# load_state_dict copies values in place and raises on a shape mismatch
# (Parameter.set_ would silently adopt the checkpoint tensor's shape).
# strict=False: only parameters are supplied, so registered buffers keep the
# values the model was constructed with.
st_classifier.load_state_dict(classifier_state, strict=False)

torch.save(st_classifier.state_dict(), 'workspace/st_classifier.pth')

### Smoke test: run a sample clip through the full pipeline

In [11]:
# Run one clip through all four stages and print each intermediate shape.
test_wav = 'datas/test.wav'  # forward slashes: portable across Windows and POSIX
# NOTE(review): 32 kHz presumably matches the rate the checkpoint was trained at; confirm against config.
signal, _ = librosa.load(test_wav, sr=32000, mono=True)
# The legacy torch.Tensor(...) constructor produced float32; make that explicit.
# unsqueeze adds the batch dimension -> (1, num_samples).
test_input = torch.as_tensor(signal, dtype=torch.float32).unsqueeze(0)
print(test_input.shape)

# Pure inference: do not record an autograd graph, so no activations are kept
# alive for backprop across the whole clip.
with torch.no_grad():
    fband = fband_transfer(test_input)
    print(fband.shape)
    img = wav2img_transfer(fband)
    print(img.shape)
    st_output = st_model(img)
    print(st_output.shape)
    classifier_output = st_classifier(st_output)
print(classifier_output[0].shape, classifier_output[1].shape, classifier_output[2].shape)

torch.Size([1, 984064])
torch.Size([1, 1, 3072, 64])
torch.Size([1, 1, 256, 256])
torch.Size([1, 768, 8, 8])


KeyboardInterrupt: 

In [9]:
# Take the first classifier output and detach it from autograd for display;
# the next cell takes its argmax.
# NOTE(review): the recorded values fall outside [0, 1], so these look like
# unnormalised per-class scores (logits) rather than probabilities. Confirm.
first_output = classifier_output[0]
result = first_output.detach()
result

tensor([[ 0.0818, -0.6886,  1.0419,  0.1545,  1.2445,  1.6563,  0.8026,  0.5598,
         -1.2866,  3.6938, -0.0896, -1.0914, -0.7614, -1.0926, -2.0282, -0.7174,
         -2.8176, -0.1787,  1.0420, -0.1627, -0.5629,  0.6500, -1.2423,  1.3095,
         -2.3625,  0.5163]])

In [10]:
# Index of the highest-scoring entry. Without `dim`, argmax flattens the tensor,
# so this is a class index only while the batch holds a single clip.
torch.argmax(result)

tensor(9)