In [11]:
from typing import Optional

import numpy as np
import torch
from IPython.display import display
from ipywidgets import Audio, widgets
from scipy.io.wavfile import write

SAMPLE_RATE = 22_050  # Hz; rate used when playing back or saving audio in this notebook

def to_audio_widget(wav: torch.Tensor, normalize: bool = False):
    """Build a playable audio player for a 2-D waveform, wrapped in an ``Output`` widget.

    Dimension 0 of ``wav`` is summed away to produce a single mono track
    (presumably channels or stems, with samples along dimension 1; TODO confirm).

    Args:
        wav: 2-D tensor to play.
        normalize: Forwarded to ``IPython.display.Audio``. When False, the samples
            must already lie in [-1, 1], otherwise IPython raises ``ValueError``.

    Returns:
        An ``ipywidgets.Output`` widget displaying the player. It can be placed in
        ``HBox``/``VBox`` containers.

    Raises:
        AssertionError: if ``wav`` is not 2-D.
    """
    # ipywidgets.Audio only accepts encoded bytes through ``value`` and has no
    # ``rate``/``normalize`` traits, so it cannot take a raw sample array.
    # IPython's display Audio accepts raw samples plus a sample rate.
    from IPython.display import Audio as IPythonAudio

    assert len(wav.shape) == 2, f"Expected 2D tensor, got shape: {wav.shape}"

    # Sum over dim 0 but keep it, giving a (1, n_samples) array that IPython
    # reads as a single channel. Detach so grad-tracking tensors can convert.
    audio_data = wav.detach().sum(dim=0, keepdim=True).cpu().numpy()
    player = IPythonAudio(data=audio_data, rate=SAMPLE_RATE, normalize=normalize)

    # Display objects are not widgets; capture the player in an Output widget.
    out = widgets.Output()
    with out:
        display(player)
    return out

def wrap_in_out(*obj):
    """Return a new ``widgets.Output`` holding the rich display of every argument.

    This lets non-widget display objects be placed inside widget containers.
    """
    container = widgets.Output()
    with container:  # everything displayed in this context is captured by `container`
        display(*obj)
    return container

def grid_widget(grid_of_objs):
    """Lay out a nested iterable of widgets as a grid.

    Each inner iterable becomes one ``HBox`` row, and the rows are stacked
    top-to-bottom in a ``VBox``.
    """
    rows = [widgets.HBox(list(cells)) for cells in grid_of_objs]
    return widgets.VBox(rows)

def save_audio(wav: torch.Tensor, filename: str, sample_rate: Optional[int] = None):
    """Downmix ``wav`` to mono and save it as a 32-bit float WAV file.

    Args:
        wav: Either a multi-dimensional tensor whose dimension 0 is summed away
            (presumably channels or stems, with samples along dimension 1), or a
            1-D tensor, which is written as-is as mono.
        filename: Destination path, passed to ``scipy.io.wavfile.write``.
        sample_rate: Rate written to the WAV header. Defaults to ``SAMPLE_RATE``.

    The signal is divided by its peak absolute value only when some sample lies
    outside [-1, 1]; in-range audio is written unchanged.
    """
    if sample_rate is None:
        sample_rate = SAMPLE_RATE

    # Detach so tensors that require grad can be converted to numpy.
    signal = wav.detach()
    if signal.dim() > 1:
        signal = signal.sum(dim=0)

    # scipy expects (n_samples,) or (n_samples, n_channels). A (1, n_samples)
    # array would be written as a single frame with n_samples channels.
    audio_data = signal.cpu().numpy().astype(np.float32)

    # Rescale to [-1, 1] only if out of range; guard against empty input,
    # where .max() would raise.
    if audio_data.size:
        peak = float(np.max(np.abs(audio_data)))
        if peak > 1.0:
            audio_data = audio_data / np.float32(peak)

    write(filename, sample_rate, audio_data)


In [13]:
# Read sample.png into a tensor
from PIL import Image
import torch
from torchvision import transforms

# Load the image. Image.open keeps the file handle open for lazy decoding, so
# copy the pixels out inside a context manager. This releases the handle while
# leaving `image` usable afterwards.
image_path = 'sample.png'
with Image.open(image_path) as image_file:
    image = image_file.copy()

# ToTensor converts a PIL image to a float tensor laid out as (C, H, W).
transform = transforms.ToTensor()
image_tensor = transform(image)

# NOTE(review): audio_tensor is random noise shaped (3, 512*512). It is not
# derived from image_tensor; confirm this placeholder is intended.
audio_tensor = torch.randn(3, 512 * 512)
print(audio_tensor.shape)

torch.Size([3, 262144])


In [29]:
from torch.utils.data import DataLoader
from data import MultiSourceDataset
# Test split of Slakh2100 with four stems. Tie `sr` to SAMPLE_RATE so the
# dataset and the playback/saving helpers agree. Keep the clip length a perfect
# square so each stem can later be folded into a 512x512 plane.
SLAKH_TEST_DIR = "/home/jingyi49/multi-source-diffusion-models/data/slakh2100/test"  # NOTE: machine-specific absolute path
SAMPLE_LENGTH = 512 * 512  # 262144 samples

testset = MultiSourceDataset(
    sr=SAMPLE_RATE,
    channels=1,
    min_duration=12,
    max_duration=640,
    aug_shift=True,
    sample_length=SAMPLE_LENGTH,
    audio_files_dir=SLAKH_TEST_DIR,
    stems=['bass', 'drums', 'guitar', 'piano'],
)

dataloader = DataLoader(testset, batch_size=1, shuffle=True)
    

# Take one (shuffled) batch. The name `batch` avoids shadowing the `data`
# module imported above.
# NOTE(review): indexing below assumes batch is (batch, stem, samples), with
# stems in the order passed to the dataset; confirm against MultiSourceDataset.
batch = next(iter(dataloader))

# Fold each stem's 512*512 samples into a (1, 512, 512) "image".
x1, x2, x3, x4 = (batch[:, i, :].squeeze().reshape(1, 512, 512) for i in range(4))

# Add the images of all 4 instruments together to form the mixture
# (x4 was previously computed but left out of the sum).
y = torch.stack([x1, x2, x3, x4]).sum(dim=0)

Found 152 tracks.
* skipped because sources are not aligned!
[5710592. 5710592. 5710592. 5529856. 5529856. 5529856. 6424320. 6424320.
 6424320. 5435392. 5435392. 8514048. 8514048. 8514048. 6312448. 6312448.
 4978688. 4978688. 4978688. 7500032. 7500032. 7500032. 5928960. 5928192.
 5928960. 9242368. 9242368. 9242368. 6868992. 6868992. 6868992. 6195968.
 6195968. 6195968. 5946880. 5946880. 5946880. 3969280. 3969280. 3969280.
 4957184. 4957184. 4957184. 2765312. 2765312. 2765312. 6243584. 6243584.
 6243584. 6684160. 6684160. 6684160. 6381056. 6381056. 6381056. 3714816.
 3714816. 3714816. 5682944. 5682944. 5682944. 5526016. 5526016. 5526016.
 6045440. 6045440. 6045440. 5372928. 5372928. 5372928. 6481664. 6481664.
 6481664. 5613312. 5613312. 5613312. 5755136. 5755136. 5755136. 5439488.
 5439488. 5439488. 3741184. 3741184. 3741184. 3793920. 3793920. 3793920.
 7833344. 7833344. 7833344. 5648896. 5648896. 5648896. 6258944. 6258944.
 6258944. 4614912. 4614912. 4614912. 5891840. 5891840. 5891840.

In [34]:
# Squeeze audio_tensor toward one dimension. NOTE(review): squeeze only drops
# size-1 dimensions, so e.g. a (3, N) tensor is left unchanged.
audio_tensor = torch.squeeze(audio_tensor)

torch.Size([1, 262144])

In [38]:
import soundfile
# Write the mixture `y` as a mono WAV file. Summing stems can push samples
# outside [-1, 1], which would be distorted when converted to the default
# integer WAV subtype. Rescale by the peak in that case (same policy as
# save_audio); in-range audio is written unchanged.
OUTPUT_WAV = "/home/jingyi49/DiT/y.wav"  # NOTE: machine-specific absolute path
audio_tensor = y.reshape(1, 512 * 512)
mix = audio_tensor.squeeze().detach().cpu().numpy()
if mix.size:
    peak = float(np.max(np.abs(mix)))
    if peak > 1.0:
        mix = mix / peak
soundfile.write(file=OUTPUT_WAV, data=mix, samplerate=SAMPLE_RATE)