In [1]:
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable
import librosa
import io
import pickle5 as pickle
import imageio

In [2]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    """
    Function for visualizing images: Given a tensor of images, number of images, and
    size per image, plots and prints the images in an uniform grid.
    """
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()


def plot_gan_losses(d_loss, g_loss):
    fig, ax = plt.subplots(figsize=(10, 8))
    ax.plot(d_loss.detach(), label="discriminator loss")
    ax.plot(g_loss.detach(), label="generator loss")
    ax.set_xlabel("n_epochs")
    ax.legend()
    ax.grid(True)
    plt.show()

In [3]:
def plot_spectrogram_mag(STFT_amp,title,fig=None,ax=None):
    def add_colorbar(fig,ax,im):
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size='5%', pad=0.05)
        fig.colorbar(im, cax=cax, orientation='vertical')
    
    if fig is None:
        fig,ax=plt.subplots(1,1,figsize=(12,6))
    
    times = np.linspace(0,2*128/121,STFT_amp.shape[1])
    freqs = librosa.mel_frequencies(n_mels=STFT_amp.shape[0], fmin=0.0, fmax=16000/2,htk=True)
    
    X,Y = np.meshgrid(times,freqs)
    
    im0 = ax.pcolor(X,Y,STFT_amp,shading='auto',cmap='magma')
    
    add_colorbar(fig,ax,im0)
    labelRow='Temps'
    labelCol='Fréquences'
    ax.set_xlabel(labelRow)
    ax.set_ylabel(labelCol)
    ax.set_title(title)

In [32]:
def generate_spectro_gif(pkl, title, nb_img, duration, loop):
    with open(pkl, 'rb') as f:
        x = pickle.load(f)
    img = x['img_list']
    imgs = []
    div = np.round(len(img)/nb_img)+1
    for i in range(len(img)):
        if i%div==0 :
            imgs.append(img[i])
    for i in range(len(imgs)):
        plot_spectrogram_mag(imgs[i][0],title='Spectro '+str(i),fig=None,ax=None)
        plt.savefig('./img'+str(i)+'.png')
        plt.clf()
        
    fp_in = "img*.png"
    fp_out = title

    img, *imgs = [Image.open(f) for f in sorted(glob.glob(fp_in))]
    img.save(fp=fp_out, format='GIF', append_images=imgs,
             save_all=True, duration=duration, loop=loop)

In [33]:
generate_spectro_gif(pkl='results_loss_techno__wgan__z_256__lr_0.0001__k_5__e_20.pkl', title='test.gif',nb_img=10, duration=200, loop=0)

<Figure size 864x432 with 0 Axes>

<Figure size 864x432 with 0 Axes>

<Figure size 864x432 with 0 Axes>

<Figure size 864x432 with 0 Axes>

<Figure size 864x432 with 0 Axes>

<Figure size 864x432 with 0 Axes>

<Figure size 864x432 with 0 Axes>

<Figure size 864x432 with 0 Axes>

<Figure size 864x432 with 0 Axes>

<Figure size 864x432 with 0 Axes>