In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt

In [None]:
import diffusers

# Two-step DDIM noise schedule.
# Net ends in tanh and is regressed onto cos/sin point coordinates rather than onto the
# added Gaussian noise, so step() must treat the model output as the predicted clean sample.
scheduler = diffusers.schedulers.DDIMScheduler(num_train_timesteps=2, prediction_type="sample")

In [None]:
class Net(nn.Module):
    """Small MLP mapping 2-D points to 2-D points.

    Architecture: five ``Linear`` + ``ReLU`` stages (2 -> 10, then 10 -> 10 four
    times) followed by ``Linear(10, 2)``. The result is passed through tanh, so
    every output coordinate lies in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        widths = [2, 10, 10, 10, 10, 10]
        stack = []
        # Modules are created in the same order as a hand-written Sequential, so
        # parameter initialisation and state_dict keys ("layers.0", "layers.2", ...) match.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        stack.append(nn.Linear(widths[-1], 2))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run the MLP on ``x`` (last dimension 2) and bound the result with tanh."""
        hidden = self.layers(x)
        return torch.tanh(hidden)
    
import math

# Seed once, before the net is built, so weight init and training noise are reproducible.
SEED = 0
torch.manual_seed(SEED)

# N_POINTS evenly spaced angles on the half-open interval [0, 2*pi); a closed
# linspace would put a near-duplicate of the theta=0 point at theta=2*pi.
N_POINTS = 100
theta = torch.arange(N_POINTS) * (2 * math.pi / N_POINTS)
# x: points on the unit circle, row i = (cos theta_i, sin theta_i) -> shape (N_POINTS, 2).
x = torch.stack([torch.cos(theta), torch.sin(theta)], dim=1)
# y: the same circle with every angle shifted by +0.1 rad.
y = torch.stack([torch.cos(theta + 0.1), torch.sin(theta + 0.1)], dim=1)

# Snapshots of the net's predictions collected during training.
saved = []

net = Net()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
criterion = nn.MSELoss()

In [None]:
# Train the net to map DDIM-noised copies of the circle points back to the clean points.
for epoch in range(10000):
    noise = torch.randn_like(x)
    # torch.randint's upper bound is exclusive: this draws from every training timestep
    # 0 .. num_train_timesteps - 1, one per point.
    t = torch.randint(0, scheduler.config.num_train_timesteps, (x.size(0),))

    noisy_x = scheduler.add_noise(x, noise, t)

    # Regression target is the clean point set itself: row i of x is (cos theta_i, sin theta_i).
    pred_x = net(noisy_x)
    loss = criterion(pred_x, x)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 1000 == 0:
        print(f"Loss: {loss.item()}")
        saved.append(pred_x.detach().clone())

In [None]:
def generate(num_inference_steps=2):
    """Sample new points by running the DDIM reverse process from Gaussian noise.

    Starts from ``torch.randn_like(x)`` and, for each timestep in
    ``scheduler.timesteps``, feeds the current sample through ``net`` and
    advances it with ``scheduler.step``. How the net output is interpreted
    (noise vs. clean sample) is governed by the scheduler's ``prediction_type``.

    Args:
        num_inference_steps: number of DDIM denoising steps. It should not
            exceed ``scheduler.config.num_train_timesteps``.

    Returns:
        Tensor with the same shape as ``x`` holding the generated points.
    """
    sample = torch.randn_like(x)
    scheduler.set_timesteps(num_inference_steps)
    with torch.no_grad():
        for t in scheduler.timesteps:
            # step() expects one scalar timestep shared by the whole batch, taken from the
            # schedule itself so it is always a valid index into the noise table.
            model_output = net(sample)
            sample = scheduler.step(model_output, t, sample).prev_sample
    return sample

samples = generate()
samples[:5]

In [None]:
# Push 100 fresh Gaussian points through the trained net in a single forward pass
# and show where they land.
new_points = torch.randn(100, 2)
output = net(new_points)
coords = output.detach()

ax = plt.gca()
ax.scatter(coords[:, 0], coords[:, 1])
ax.set_aspect('equal', adjustable='box')
plt.show()

In [None]:
import os

from PIL import Image
from tqdm.auto import tqdm

path = "blog/12-dotcloud/frames"
# savefig does not create missing directories.
os.makedirs(path, exist_ok=True)

# One PNG per training snapshot: the target circle x, with that snapshot's predictions on top.
for i, frame in enumerate(tqdm(saved)):
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.scatter(x[:, 0], x[:, 1])
    ax.scatter(frame[:, 0].detach(), frame[:, 1].detach())
    # Fixed limits keep the view identical across frames so the animation does not jump.
    # The net's tanh output and the unit circle both fit inside [-1.1, 1.1].
    ax.set_xlim(-1.1, 1.1)
    ax.set_ylim(-1.1, 1.1)
    ax.set_aspect('equal', anchor='SW')
    ax.set_title(f'Frame {i}')
    fig.savefig(f'{path}/{i:03d}.png', bbox_inches='tight', pad_inches=0.25)
    # Close each figure explicitly; otherwise every frame stays in memory.
    plt.close(fig)
    

In [None]:
# Assemble the PNG frames in `path` (sorted by filename) into an animated GIF.
import os

import imageio
import numpy as np
from PIL import Image

frame_files = sorted(f for f in os.listdir(path) if f.endswith('.png'))
if not frame_files:
    raise FileNotFoundError(f'no .png frames found in {path!r}')

images = []
target_size = None
for filename in frame_files:
    with Image.open(os.path.join(path, filename)) as im:
        frame = im.convert('RGBA')
    # mimsave needs all frames to share one shape, and bbox_inches='tight' can change the
    # saved size slightly. Scale every frame to the first frame's size.
    # ndarray.resize would only reshape and zero-fill the raw buffer, not scale the image.
    if target_size is None:
        target_size = frame.size
    elif frame.size != target_size:
        frame = frame.resize(target_size)
    images.append(np.asarray(frame))

imageio.mimsave('animation.gif', images, duration=1)

In [None]:
# Open the finished GIF in the system image viewer.
# Image.show() launches an external viewer and returns None; nothing is rendered inline.
from PIL import Image

with Image.open('animation.gif') as gif:
    gif.show()