In [1]:
import torch
from IPython.display import Audio, display, clear_output
import matplotlib.pyplot as plt
import random
import asyncio
import threading
from librosa.beat import beat_track
from pythonosc import dispatcher, osc_server
import numpy as np
from datasets import load_dataset
from diffusers import DiffusionPipeline, DDIMScheduler, AudioDiffusionPipeline
from diffusers_local import UNet2DModel as UNet2DModel_local
import soundfile as sf


from NetworkBending import NetworkBending
import time  # for safely terminating the loop

# Lock that the generation loop acquires around every pipeline call.
osc_lock = threading.Lock()

# Network-bending controller; it is handed to the custom UNet below and also
# receives the OSC control messages.
NB = NetworkBending()

# Hub repository providing the pipeline, the UNet weights and the scheduler config.
MODEL_ID = "teticio/audio-diffusion-ddim-256"

device = "cuda" if torch.cuda.is_available() else "cpu"

audio_diffusion = DiffusionPipeline.from_pretrained(MODEL_ID)

# Replace the stock UNet with the local variant that accepts the
# network-bending controller, so the audio can be bent while it is generated.
audio_diffusion.unet = UNet2DModel_local.from_pretrained(
    MODEL_ID,
    subfolder="unet",
    network_bending=NB,
)

# Use the DDIM scheduler shipped with the same repository.
audio_diffusion.scheduler = DDIMScheduler.from_pretrained(
    MODEL_ID,
    subfolder="scheduler",
)

audio_diffusion.to(device)

ds = load_dataset('teticio/audio-diffusion-256')

# CPU generator with a fixed seed, passed to the pipeline on every call.
generator = torch.Generator(device="cpu").manual_seed(42)

# Interpolation state, rebound at runtime by the OSC `interpolation` handler:
# the blend factor and the indices of the two encoded images being blended.
alpha = 0
latent1 = 0
latent2 = 1

  from .autonotebook import tqdm as notebook_tqdm
  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
unet\diffusion_pytorch_model.safetensors not found
Loading pipeline components...:   0%|          | 0/3 [00:00<?, ?it/s]An error occurred while trying to fetch C:\Users\danhearn\.cache\huggingface\hub\models--teticio--audio-diffusion-ddim-256\snapshots\f5606c5138496ecdcbd096a4446eb6d03ae690cb\unet: Error no file named diffusion_pytorch_model.safetensors found in directory C:\Users\danhearn\.cache\huggingface\hub\models--teticio--audio-diffusion-ddim-256\snapshots\f5606c5138496ecdcbd096a4446eb6d03ae690cb\unet.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
Loading pipeline components...: 100%|██████████| 3/3 [00:00<00:00, 13.04it/s]
An error occurred while trying to fetch teticio/audio-diffusion-ddim-256: teticio/audio-diffusion-ddim-256 does not appear to have a file named diffusion_pytorch_model.safetensors.
Defaulting to unsafe serializat

In [2]:
def loop_it(audio: np.ndarray,
        sample_rate: int,
        loops: int = 4) -> "np.ndarray | None":
    """Cut a beat-aligned segment out of ``audio`` and repeat it.

    Beat positions are obtained from ``beat_track`` in sample units. The
    segment starts at the first detected beat and spans the largest of
    16, 12, 8 or 4 beats for which enough beats were detected.

    Args:
        audio (np.ndarray): audio as numpy array
        sample_rate (int): sample rate of audio
        loops (int): number of times to loop

    Returns:
        np.ndarray or None: the segment tiled ``loops`` times, or None when
        four or fewer beats were detected. Only the audio is returned; the
        sample rate is not part of the result.
    """
    _, beats = beat_track(y=audio, sr=sample_rate, units='samples')
    # Try the longest bar first; indexing beats[n] needs strictly more than n beats.
    for beats_in_bar in (16, 12, 8, 4):
        if len(beats) > beats_in_bar:
            segment = audio[beats[0]:beats[beats_in_bar]]
            return np.tile(segment, loops)
    return None

In [3]:
# Load NUM_IMAGES randomly chosen images from the training split of the dataset.
# Each draw is independent, so the same image can be picked more than once.
NUM_IMAGES = 10

train_split = ds['train']
images = [random.choice(train_split)['image'] for _ in range(NUM_IMAGES)]

In [7]:
# Encode every loaded image with the pipeline's `encode` (100 steps each).
# Iterating over `images` itself keeps this cell in step with however many
# images were loaded, instead of repeating a hard-coded count.
encoded_images = [
    audio_diffusion.encode([image], steps=100)
    for image in images
]

  hidden_states = F.scaled_dot_product_attention(
100%|██████████| 100/100 [00:02<00:00, 35.19it/s]
100%|██████████| 100/100 [00:02<00:00, 41.35it/s]
100%|██████████| 100/100 [00:02<00:00, 42.13it/s]
100%|██████████| 100/100 [00:02<00:00, 41.24it/s]
100%|██████████| 100/100 [00:02<00:00, 40.83it/s]
100%|██████████| 100/100 [00:02<00:00, 41.28it/s]
100%|██████████| 100/100 [00:02<00:00, 41.29it/s]
100%|██████████| 100/100 [00:02<00:00, 42.42it/s]
100%|██████████| 100/100 [00:02<00:00, 42.87it/s]
100%|██████████| 100/100 [00:02<00:00, 42.71it/s]


In [None]:
# Show each encoded image as a grayscale picture, one figure per encoding.
for latent in encoded_images:
    pixels = latent.cpu().squeeze()
    _, ax = plt.subplots()
    ax.imshow(pixels, cmap='gray')
    ax.axis('off')  # no ticks or frame around the picture
    plt.show()
    

In [4]:
def interpolation(address, *args):
    """OSC handler that updates the latent-interpolation globals.

    Args:
        address: OSC address of the message; "/alpha", "/latent1" and
            "/latent2" are handled, anything else is ignored.
        *args: OSC message arguments; only the first one is used. A message
            without arguments is ignored.

    Side effects:
        Rebinds the module-level ``alpha``, ``latent1`` or ``latent2``.
        ``alpha`` is stored as a float. A latent index is truncated to int
        first and then stored only if it differs from the other latent's
        index, so the two indices are never set to the same value here.

    Raises:
        ValueError / TypeError: if the first argument cannot be converted
            to a number.
    """
    global alpha, latent1, latent2
    if not args:
        return

    value = args[0]
    if address == "/alpha":
        alpha = float(value)
    elif address == "/latent1":
        # Convert before comparing: 1.7 truncates to 1 and must be checked as 1.
        index = int(value)
        if index != latent2:
            latent1 = index
    elif address == "/latent2":
        index = int(value)
        if index != latent1:
            latent2 = index

In [5]:
# OSC listener: dispatch incoming control messages on a background thread.
# `dispatcher` and `osc_server` come from the import cell at the top of the notebook.

ip = "127.0.0.1"
port = 9999

# Addresses handled by the network-bending controller.
network_bending_addresses = [
    "/rotate_x_radian",
    "/rotate_y_radian",
    "/rotate_z_radian",
    "/rotate_x_scaling_factor",
    "/rotate_y_scaling_factor",
    "/rotate_z_scaling_factor",
    "/scale_factor",
    "/layer",
    "/scale",
    "/reflect",
    "/erosion",
    "/dilation",
    "/gradient",
    "/sobel",
    "/add_rand_rows",
    "/normalize",
    "/rotate_x",
    "/rotate_y",
    "/rotate_z",
]

# Addresses that steer the interpolation between two encoded images.
interpolation_addresses = ["/alpha", "/latent1", "/latent2"]

d = dispatcher.Dispatcher()
for osc_address in network_bending_addresses:
    d.map(osc_address, NB.osc_receive)
for osc_address in interpolation_addresses:
    d.map(osc_address, interpolation)

s = osc_server.ThreadingOSCUDPServer((ip, port), d)

# Daemon thread, so the main program can exit even if the OSC server is still running.
osc_thread = threading.Thread(target=s.serve_forever, daemon=True)
osc_thread.start()
print(f"OSC server started - listening on port {port}")

OSC server started - listening on port 9999


In [10]:
# Main loop: keep rendering the current interpolation to "generated_image.png"
# and "output_audio.wav" until the kernel is interrupted.
try:
    while True:
        # Read the OSC-controlled globals once per frame so a handler thread
        # cannot change them halfway through building the slerp arguments.
        blend, index_a, index_b = alpha, latent1, latent2
        num_latents = len(encoded_images)

        # Skip frames whose indices would raise IndexError or that select the
        # same encoding twice. Negative indices keep their from-the-end meaning.
        # NOTE(review): the OSC handler already tries to keep the two indices
        # apart; slerp between identical tensors is presumably degenerate -
        # confirm against the diffusers implementation.
        indices_in_range = all(
            -num_latents <= index < num_latents for index in (index_a, index_b)
        )
        if not indices_in_range or index_a % num_latents == index_b % num_latents:
            time.sleep(0.01)  # wait for a usable selection instead of spinning
            continue

        # NOTE(review): no handler visible in this notebook acquires osc_lock,
        # so it only serialises code that also takes it - confirm whether
        # NetworkBending needs its own locking.
        with osc_lock:
            output = audio_diffusion(
                steps=10,
                noise=AudioDiffusionPipeline.slerp(
                    encoded_images[index_a], encoded_images[index_b], blend
                ),
                generator=generator,
                eta=0,
            )
            output.images[0].save("generated_image.png")
            sf.write("output_audio.wav", output.audios[0, 0], audio_diffusion.mel.get_sample_rate())

except KeyboardInterrupt:
    # shutdown() stops serve_forever(); server_close() releases the UDP port so
    # the listener cell can be run again without an "address in use" error.
    s.shutdown()
    s.server_close()
    print("Terminating the loop gracefully.")

100%|██████████| 10/10 [00:00<00:00, 28.36it/s]
100%|██████████| 10/10 [00:00<00:00, 42.90it/s]
100%|██████████| 10/10 [00:00<00:00, 43.90it/s]
100%|██████████| 10/10 [00:00<00:00, 46.16it/s]
100%|██████████| 10/10 [00:00<00:00, 39.99it/s]
100%|██████████| 10/10 [00:00<00:00, 42.91it/s]
100%|██████████| 10/10 [00:00<00:00, 40.44it/s]
100%|██████████| 10/10 [00:00<00:00, 44.60it/s]
100%|██████████| 10/10 [00:00<00:00, 42.81it/s]
100%|██████████| 10/10 [00:00<00:00, 43.15it/s]
100%|██████████| 10/10 [00:00<00:00, 42.57it/s]
100%|██████████| 10/10 [00:00<00:00, 39.82it/s]
100%|██████████| 10/10 [00:00<00:00, 42.09it/s]
100%|██████████| 10/10 [00:00<00:00, 43.59it/s]
100%|██████████| 10/10 [00:00<00:00, 41.54it/s]
100%|██████████| 10/10 [00:00<00:00, 43.93it/s]
100%|██████████| 10/10 [00:00<00:00, 35.93it/s]
100%|██████████| 10/10 [00:00<00:00, 42.54it/s]
100%|██████████| 10/10 [00:00<00:00, 42.78it/s]
100%|██████████| 10/10 [00:00<00:00, 42.86it/s]
100%|██████████| 10/10 [00:00<00:00, 42.

Terminating the loop gracefully.
