In [None]:
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
from torch import nn
import torch
import torch.nn.functional as F

In [None]:
# Pick the best available accelerator: CUDA GPU, then Apple-silicon MPS, then CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

In [None]:
# Show which accelerator was selected ('cuda', 'mps' or 'cpu').
device

In [None]:
from datasets import load_dataset
from utils.ch15util import transforms

In [None]:
# Flowers-102 training split from the Hugging Face Hub; `transforms` is registered
# with set_transform so it is applied when examples are accessed.
dataset = load_dataset(path='huggan/flowers-102-categories', split='train')
dataset.set_transform(transforms)

In [None]:
from torchvision.utils import make_grid

In [None]:
# First 16 transformed images, laid out 8 per row with 2 px padding.
preview = dataset[:16]["input"]
grid = make_grid(preview, nrow=8, padding=2)

In [None]:
# Preview the grid; make_grid returns CHW, imshow expects HWC.
fig, ax = plt.subplots(figsize=(8, 2), dpi=300)
ax.imshow(grid.permute(1, 2, 0).numpy())
ax.axis('off')

In [None]:
# Image side length in pixels. Nothing reads it before it is set again (to the same
# value) in the cell that builds the UNet.
resolution = 64

In [None]:
# Number of images per DataLoader batch (used for the noising preview and for training).
batch_size = 4

In [None]:
# Shuffled mini-batches of `batch_size` examples drawn from the transformed dataset.
train_dataloader = torch.utils.data.DataLoader(
    dataset,
    shuffle=True,
    batch_size=batch_size,
)

In [None]:
# Draw one batch for the noising demo and rescale it to [-1, 1]
# (assumes `transforms` yields values in [0, 1] — TODO confirm).
preview_batch = next(iter(train_dataloader))
clean_images = preview_batch["input"] * 2 - 1
nums = clean_images.shape[0]
noise = torch.randn(clean_images.shape)
print(clean_images.shape)
print(noise.shape)

In [None]:
from utils.ch15util import DDIMScheduler

In [None]:
# Noise schedule over 1000 training timesteps; later cells use it to noise images
# (add_noise) and to sample new ones (generate).
noise_scheduler = DDIMScheduler(num_train_timesteps=1000)

In [None]:
# Start the comparison stack from the clean (un-noised) batch.
allimgs = clean_images

In [None]:
# Stack the clean batch followed by the same batch noised at t = 199, 399, 599, 799, 999.
# Rebuilding from `clean_images` on every run keeps the cell idempotent; concatenating
# onto the global `allimgs` appended extra rows each time the cell was re-run.
noise_levels = [clean_images]
for step in range(200, 1001, 200):
    # One timestep per image, sized from the batch rather than a hard-coded 4,
    # so a different `batch_size` still lines up with `noise`.
    timesteps = torch.full((clean_images.shape[0],), step - 1, dtype=torch.long)
    noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
    noise_levels.append(noisy_images)
allimgs = torch.cat(noise_levels)

In [None]:
# Arrange the stacked images 4 per row with 6 px padding.
imgs = make_grid(allimgs, nrow=4, padding=6)

In [None]:
# Grid of the clean batch (first row) followed by one row per noise level
# (t = 199, 399, ..., 999), one column per image in the batch.
imgs = make_grid(allimgs, nrow=batch_size, padding=6)
fig = plt.figure(dpi=300)
# Map [-1, 1] back to [0, 1] and clip. Noised samples fall outside that range, and
# imshow would otherwise clip them with a warning. permute(1, 2, 0) turns CHW into HWC;
# permute(2, 1, 0) would also swap H and W and mirror every image across its diagonal.
plt.imshow(((imgs + 1) / 2).clamp(0, 1).permute(1, 2, 0).cpu().numpy())
plt.axis('off')

In [None]:
from utils.unet_util import UNet, Attention

In [None]:
# 1x1 convolution that projects 128 channels to 3 * 128 (query, key and value stacked).
n1 = nn.Conv2d(in_channels=128, out_channels=128 * 3, kernel_size=1, bias=False)

In [None]:
# Random stand-in for a 128-channel, 64x64 feature map (batch of one).
x = torch.randn((1, 128, 64, 64))

In [None]:
# Project the feature map to stacked query/key/value channels:
# (1, 128, 64, 64) -> (1, 384, 64, 64).
o1 = n1(x)

In [None]:
# Expected: torch.Size([1, 384, 64, 64]).
o1.shape

In [None]:
# Split the 384 projected channels into three 128-channel tensors: query, key, value.
q, k, v = torch.chunk(o1, chunks=3, dim=1)

In [None]:
from einops import rearrange

In [None]:
# Split the 128 query channels into h=4 heads of 32 and flatten the 64x64 grid:
# (1, 128, 64, 64) -> (1, 4, 32, 4096).
rearrange(q, 'b (h c) x y -> b h c (x y)', h=4).shape

In [None]:
# Attention block from utils.unet_util, built for 128 channels.
# NOTE(review): `attn` is not used in the cells shown here; presumably created only
# for interactive inspection.
attn = Attention(128)

In [None]:
# Build the UNet for 64x64 images and move it to the selected device.
# The leading 3 is presumably the number of image channels (RGB) — TODO confirm.
resolution = 64
model = UNet(
    3,
    hidden_dims=[128, 256, 512, 1024],
    image_size=resolution,
)
model = model.to(device)

In [None]:
# Display the module tree of the UNet.
model

In [None]:
# Total number of parameters in the UNet, reported in millions.
num = sum(param.numel() for param in model.parameters())
print(f"number of parameters: {num / 1e6:.2f}M")

In [None]:
from diffusers.optimization import get_scheduler

In [None]:
# Training length and AdamW optimizer over all UNet parameters.
num_epochs = 100
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    betas=(0.95, 0.999),
    weight_decay=1e-5,
    eps=1e-8,
)

In [None]:
# Cosine learning-rate decay after 300 warm-up steps. The total step count assumes one
# scheduler step per batch, which matches the training loop.
lr_schedule = get_scheduler(
    'cosine',
    optimizer=optimizer,
    num_warmup_steps=300,
    num_training_steps=num_epochs * len(train_dataloader),
)

In [None]:
# Uncomment to resume from a previously saved checkpoint instead of training from scratch.
# Passing map_location=device lets a checkpoint saved on one device (e.g. CUDA) load on another.
# model.load_state_dict(torch.load('files/models/difussion.pth', map_location=device))

In [None]:
from tqdm import tqdm

In [None]:
# Train the UNet to predict the noise added at a uniformly sampled timestep
# (L1 loss between predicted and true noise). The weights are saved after every epoch.
checkpoint_path = 'files/models/difussion-new.pth'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)  # torch.save won't create the folder
model.train()  # a later cell calls model.eval(); restore training mode if this cell is re-run
for epoch in range(num_epochs):
    loop = tqdm(train_dataloader, leave=False)
    tloss = 0
    for step, batch in enumerate(loop):
        # Rescale inputs to [-1, 1] (assumes `transforms` yields [0, 1] — TODO confirm).
        clean_images = (batch["input"] * 2 - 1).to(device)
        nums = clean_images.shape[0]
        # Sample noise directly on the target device, with the same shape and dtype.
        noise = torch.randn_like(clean_images)
        timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (nums,), device=device)
        noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps=timesteps)
        noise_pred = model(noisy_images, timesteps=timesteps)["sample"]
        loss = F.l1_loss(noise_pred, noise)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_schedule.step()
        tloss += loss.item()
        if step % 100 == 0:
            loop.set_postfix(epoch=epoch, step=step, tloss=tloss / (step + 1))
    torch.save(model.state_dict(), checkpoint_path)

In [None]:
# Seed torch's global RNG and keep the returned default generator. It is passed to
# sampling so the generated images can be reproduced.
generator = torch.manual_seed(1)

In [None]:
# Switch the model to evaluation mode before sampling (the cell also displays the model).
model.eval()

In [None]:
# Sample 10 images in 50 inference steps with the seeded generator. `imgs` presumably
# holds the intermediate results that later cells slice — TODO confirm.
generated_images, imgs = noise_scheduler.generate(
    model,
    device,
    batch_size=10,
    num_inference_steps=50,
    eta=1.0,
    generator=generator,
    use_clipped_model_output=True,
)

In [None]:
# Final samples. Presumably an array/list of HWC images that imshow can show directly — TODO confirm.
imgnp=generated_images["sample"]

In [None]:
# 2 x 5 grid of the ten generated samples, without ticks.
fig, axes = plt.subplots(2, 5, figsize=(10, 4), dpi=300)
for ax, img in zip(axes.flat, imgnp[:10]):
    ax.imshow(img)
    ax.set_xticks([])
    ax.set_yticks([])
fig.tight_layout()  # once, after every panel exists, instead of on each iteration

In [None]:
# Keep every 10th intermediate result (indices 9, 19, 29, ...). With one entry per
# inference step (50 steps), this yields 5 snapshots.
steps = imgs[slice(9, None, 10)]
imgs20 = list()

In [None]:
# Collect 20 panels: for samples 1, 3, 6 and 9, their five saved snapshots in order.
# Built as a fresh list so re-running this cell does not append duplicate panels.
imgs20 = [steps[i][j] for j in [1, 3, 6, 9] for i in range(5)]

In [None]:
# 4 x 5 grid: each row follows one sample through the five saved snapshots,
# with each column titled by its timestep label.
fig, axes = plt.subplots(4, 5, figsize=(10, 8), dpi=300)
for i, (ax, img) in enumerate(zip(axes.flat, imgs20)):
    k = i % 5
    ax.imshow(img)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(f't={800-200*k}', fontsize=15, c="r")
# Lay out once, after all titles are set. Calling it inside the loop (before the title)
# left the last panel's title out of the layout.
fig.tight_layout()

In [None]:
from openai import OpenAI

In [None]:
# Read the key from the environment instead of hard-coding it in the notebook.
# A key that was committed in plain text should be treated as leaked and revoked.
# If unset this is None, and the OpenAI client then raises a clear error of its own.
api_key = os.environ.get("OPENAI_API_KEY")

In [None]:
# Endpoint for the OpenAI-compatible API. The default is a third-party proxy, not
# api.openai.com, and the API key is sent to it. Set OPENAI_BASE_URL to override.
base_url = os.environ.get('OPENAI_BASE_URL', 'https://api.gptapi.us/v1')

In [None]:
# API client bound to the configured endpoint and key.
client = OpenAI(base_url=base_url, api_key=api_key)

In [None]:
# Request one DALL-E 3 image for the prompt. Per the OpenAI Images API reference,
# DALL-E 3 accepts only 1024x1024, 1792x1024 or 1024x1792 ("512x512" is a DALL-E 2
# size), so `size` is left at the API default.
response = client.images.generate(
    prompt="an astronaut in a space suit riding a unicorn",
    model="dall-e-3",
    n=1,
    quality="standard",
)

In [None]:
# URL of the first generated image (with n=1, the only one).
img_url = response.data[0].url

In [None]:
from PIL import Image

In [None]:
# Display the image URL. NOTE(review): PIL's `Image` is imported but not used in the
# cells shown; presumably intended for downloading and displaying the picture.
img_url