In [1]:
from modules1 import *
from torch.optim import Adam
from  tqdm.notebook import tqdm
import matplotlib.animation as animation
import imageio
import matplotlib.pyplot as plt
import torchvision
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torch. optim.lr_scheduler import ExponentialLR


In [2]:
batch_size = 275


def _to_signed_range(t):
    """Map a tensor from the [0, 1] range produced by ToTensor to [-1, 1]."""
    return (t * 2) - 1


# Augment with random horizontal flips, convert to a tensor, then rescale.
preprocess = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Lambda(_to_signed_range),
])

# NOTE(review): the dataset root is an absolute, user-specific path; consider
# moving it to a configurable DATA_DIR so the notebook runs on other machines.
dataset = torchvision.datasets.FashionMNIST(
    root="C:/Users/ericy/Downloads",
    train=True,
    download=True,
    transform=preprocess,
)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [3]:
# Prefer the second GPU (the original hard-coded choice) when it exists, but
# fall back to the default GPU or the CPU so the notebook still runs on
# machines with fewer devices instead of failing at model.to(device).
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# NOTE(review): time_dim=200 looks like it is meant to match time_steps (200)
# used for the diffusion schedule below -- confirm against Unet's definition.
model = Unet(time_dim=200, dims=((16, 64), (64, 128)), channels=1, device=device)
model = model.to(device)

In [4]:
class schedule():
    """Precomputed diffusion coefficients plus noising, loss and sampling helpers.

    Every coefficient tensor is derived once in ``__init__`` from the betas
    returned by ``beta_schedule`` and is indexed by the timestep ``t``.
    """

    def __init__(self, timesteps = 200, device="cuda"):
        """Build the per-timestep coefficient tables.

        Args:
            timesteps: Number of diffusion steps, forwarded to ``beta_schedule``.
            device: Device forwarded to ``beta_schedule``; also used for the
                noise drawn in ``add_noise``.
        """
        self.device = device
        # NOTE(review): beta_schedule is defined outside this cell; it is assumed
        # to return a 1-D tensor holding one beta per timestep -- confirm.
        betas = beta_schedule(timesteps=timesteps, device = device)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # Shift right by one and pad with 1.0 so that entry t holds the
        # cumulative product up to step t-1.
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
        self.betas = betas
        self.sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
        self.posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

    def add_noise(self, x, t):
        """Noise a batch ``x`` to the timesteps given in ``t``.

        Args:
            x: Batch tensor whose first dimension is the batch dimension.
            t: 1-D integer index tensor with one timestep per batch element,
                on the same device as the coefficient tables.

        Returns:
            Tuple ``(noisy_x, noise)`` where
            ``noisy_x = sqrt_alphas_cumprod[t] * x + sqrt_one_minus_alphas_cumprod[t] * noise``
            and ``noise`` is the standard-normal sample that was mixed in.
        """
        noise = torch.randn_like(x, device = self.device)
        # Reshape the per-sample coefficients to (batch, 1, ..., 1) so they
        # broadcast over all non-batch dimensions, whatever the rank of x.
        coeff_shape = (-1,) + (1,) * (x.dim() - 1)
        signal_scale = torch.index_select(self.sqrt_alphas_cumprod, 0, t).reshape(coeff_shape)
        noise_scale = torch.index_select(self.sqrt_one_minus_alphas_cumprod, 0, t).reshape(coeff_shape)
        return signal_scale * x + noise_scale * noise, noise

    def loss(self, model, x0, t):
        """Mean-squared error between the injected noise and the model's prediction.

        Args:
            model: Callable invoked as ``model(xt, t)``.
            x0: Clean batch passed to ``add_noise``.
            t: Per-sample timesteps passed to ``add_noise`` and to the model.

        Returns:
            Scalar loss tensor.
        """
        xt, noise = self.add_noise(x0, t)
        pred = model(xt, t)
        return F.mse_loss(pred, noise)

    @torch.no_grad()
    def sample(self, model, device, time_steps):
        """Run the reverse process for one 1x28x28 image and record every step.

        The model's train/eval mode is not changed here; callers that need
        inference-mode layers should switch the model themselves.

        Args:
            model: Callable invoked as ``model(img, t)`` with ``t`` of shape (1,).
            device: Device on which the image and timestep tensors are created.
            time_steps: Number of reverse steps; indices ``time_steps-1 .. 0``
                of the coefficient tables are used.

        Returns:
            Array of shape ``(time_steps, 28, 28, 1)`` with pixel values in
            [0, 255], ordered from the first (noisiest) reverse step to the
            last one.
        """
        M = np.zeros((time_steps, 28, 28, 1))
        # NOTE(review): the starting noise is scaled by 0.5 rather than drawn
        # from a unit normal -- kept as in the original; confirm it is intended.
        img = torch.randn((1, 1, 28, 28), device=device) * 0.5
        for t in reversed(range(time_steps)):
            t_batch = torch.unsqueeze(torch.tensor(t, device = device), dim=0)
            predicted_noise = model(img, t_batch)
            model_mean = self.sqrt_recip_alphas[t] * (
                img - self.betas[t] * predicted_noise / self.sqrt_one_minus_alphas_cumprod[t]
            )
            noise = torch.randn_like(model_mean)
            img = model_mean + torch.sqrt(self.posterior_variance[t]) * noise
            # Map [-1, 1] to [0, 255] and clamp before the uint8 cast: values
            # outside the range would otherwise wrap around (or be undefined
            # for negatives) and show up as speckle in the saved frames.
            out = ((img[0].detach().cpu().permute(1, 2, 0) * 0.5 + 0.5) * 255).clamp(0, 255)
            M[t] = out.numpy().astype(np.uint8)
        # M is indexed by timestep; flip so frame 0 is the first reverse step.
        return np.flip(M, axis = 0)


In [5]:
# Number of diffusion steps shared by the schedule, training and sampling cells.
time_steps = 200

# Coefficient tables are created on the same device as the model.
S = schedule(timesteps=time_steps, device=device)

In [None]:
optimizer = Adam(model.parameters(), lr=0.001)
scheduler = ExponentialLR(optimizer, gamma=0.997)

epochs = 1000
L = []  # per-batch training losses (0-dim CPU tensors)

for epoch in tqdm(range(epochs), desc="Epoch", position=0):
    for step, (batch, _) in enumerate(tqdm(loader, desc="Batch", position=1, leave=False)):
        optimizer.zero_grad()
        batch = batch.to(device)
        # One uniformly random timestep per image in the batch.
        t = torch.randint(0, time_steps, (len(batch),), device=device).long()
        loss = S.loss(model, batch, t)
        loss.backward()
        optimizer.step()
        L.append(loss.detach().cpu())
        # At the first batch of every epoch, save a GIF of a full reverse
        # trajectory and report the current loss.
        if step == 0:
            M = np.squeeze(S.sample(model, device, time_steps),axis=3).astype(np.uint8)
            imageio.mimwrite("animations0/output" + str(epoch) + ".gif", M, 'GIF', duration = 0.04)
            print("Epoch:", epoch, "Loss:", loss.item())
    # The original wrote `scheduler.step` without parentheses, which only
    # references the method, so the learning rate never decayed. Step once per
    # epoch: with gamma=0.997 a per-batch step would roughly halve the learning
    # rate every epoch and drive it to ~0 long before training finishes.
    scheduler.step()

Epoch:   0%|          | 0/1000 [00:00<?, ?it/s]

Batch:   0%|          | 0/219 [00:00<?, ?it/s]

Epoch: 0 Loss: 0.42243826389312744


Batch:   0%|          | 0/219 [00:00<?, ?it/s]

Epoch: 1 Loss: 0.36524900794029236


Batch:   0%|          | 0/219 [00:00<?, ?it/s]

Epoch: 2 Loss: 0.36237674951553345


Batch:   0%|          | 0/219 [00:00<?, ?it/s]

Epoch: 3 Loss: 0.36971208453178406


Batch:   0%|          | 0/219 [00:00<?, ?it/s]

Epoch: 4 Loss: 0.3639644384384155


Batch:   0%|          | 0/219 [00:00<?, ?it/s]

Epoch: 5 Loss: 0.3624126613140106


Batch:   0%|          | 0/219 [00:00<?, ?it/s]

Epoch: 6 Loss: 0.37481680512428284


Batch:   0%|          | 0/219 [00:00<?, ?it/s]

Epoch: 7 Loss: 0.3579775094985962


Batch:   0%|          | 0/219 [00:00<?, ?it/s]

Epoch: 8 Loss: 0.367175430059433


In [None]:
# Sample one full reverse trajectory and keep every 10th frame for the GIF.
M = np.squeeze(S.sample(model, device, time_steps),axis=3).astype(np.uint8)
frames = M[::10]
# A stride of 10 stops short of the last index unless (len - 1) is a multiple
# of 10, which would drop the final denoised image; append it in that case.
if (len(M) - 1) % 10 != 0:
    frames = np.concatenate([frames, M[-1:]], axis=0)
imageio.mimwrite("output4.gif", frames, 'GIF')