# In [None]:
''' 
This script does conditional image generation with a diffusion model. It was written for MNIST and is adapted here to 16x16 TESS CCD images conditioned on 12 angle features.

This code is modified from,
https://github.com/cloneofsimo/minDiffusion

Diffusion model is based on DDPM,
https://arxiv.org/abs/2006.11239

The conditioning idea is taken from 'Classifier-Free Diffusion Guidance',
https://arxiv.org/abs/2207.12598

This technique also features in Imagen, 'Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding',
https://arxiv.org/abs/2205.11487

'''

from typing import Dict, Tuple
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import numpy as np
# 162/45:
class ResidualConvBlock(nn.Module):
    """Standard ResNet-style block: two Conv3x3 -> BatchNorm -> GELU stages.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolution stages.
        is_res: when True, add a skip connection and divide the sum by 1.414
            (roughly sqrt(2)) to keep the output scale comparable. The skip
            comes from the input when channel counts match, otherwise from
            the output of the first stage.
    """

    def __init__(
        self, in_channels: int, out_channels: int, is_res: bool = False
    ) -> None:
        super().__init__()
        self.same_channels = in_channels == out_channels
        self.is_res = is_res
        self.conv1 = self._stage(in_channels, out_channels)
        self.conv2 = self._stage(out_channels, out_channels)

    @staticmethod
    def _stage(c_in: int, c_out: int) -> nn.Sequential:
        """Build one 3x3 same-padding Conv2d -> BatchNorm2d -> GELU stage."""
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, 3, 1, 1),
            nn.BatchNorm2d(c_out),
            nn.GELU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h1 = self.conv1(x)
        h2 = self.conv2(h1)
        if not self.is_res:
            return h2
        # The input can only be added directly when its channel count matches.
        skip = x if self.same_channels else h1
        return (skip + h2) / 1.414


class UnetDown(nn.Module):
    """Encoder stage: a (non-residual) ResidualConvBlock, then 2x2 max-pooling.

    Maps (B, in_channels, H, W) to (B, out_channels, H // 2, W // 2).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv = ResidualConvBlock(in_channels, out_channels)
        pool = nn.MaxPool2d(2)
        self.model = nn.Sequential(conv, pool)

    def forward(self, x):
        return self.model(x)


class UnetUp(nn.Module):
    """Decoder stage: join the skip connection, upsample 2x, then refine.

    ``x`` and ``skip`` are concatenated along dim 1, so their channel counts
    must sum to ``in_channels``. A kernel-2/stride-2 ConvTranspose2d doubles H
    and W, and two non-residual ResidualConvBlocks follow at ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        upsample = nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
        refine = [ResidualConvBlock(out_channels, out_channels) for _ in range(2)]
        self.model = nn.Sequential(upsample, *refine)

    def forward(self, x, skip):
        merged = torch.cat((x, skip), 1)
        return self.model(merged)


class EmbedFC(nn.Module):
    """Small MLP embedder: Linear -> GELU -> Linear.

    The input is flattened to ``(-1, input_dim)`` first, so any tensor whose
    element count is a multiple of ``input_dim`` is accepted. The output has
    shape ``(rows, emb_dim)``.
    """

    def __init__(self, input_dim, emb_dim):
        super().__init__()
        self.input_dim = input_dim
        self.model = nn.Sequential(
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        )

    def forward(self, x):
        flat = x.view(-1, self.input_dim)
        return self.model(flat)


class ContextUnet(nn.Module):
    """Conditional U-Net that predicts the noise in a diffusion-model input.

    Layout: a residual stem, two downsampling stages (2x max-pool each), a
    pooled bottleneck, and two upsampling stages. The input of each upsampling
    stage is modulated by a context embedding and a timestep embedding
    (``cemb * h + temb``). The pooling sizes are tuned for 16x16 inputs
    (16 -> 8 -> 4 -> 1 -> 4 -> 8 -> 16). Heights and widths that are multiples
    of 16 also line up with the skip connections.

    Args:
        in_channels: channels of the input image and of the predicted noise.
        n_feat: base feature width. GroupNorm(8, ...) is applied to n_feat and
            2 * n_feat channels, so n_feat must be divisible by 8.
        n_cond: length of the flat conditioning vector ``c`` (default 12,
            the number of angle features used for the TESS data).
    """

    def __init__(self, in_channels, n_feat=256, n_cond=12):
        super(ContextUnet, self).__init__()

        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_cond = n_cond

        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)

        self.down1 = UnetDown(n_feat, n_feat)
        self.down2 = UnetDown(n_feat, 2 * n_feat)

        # Bottleneck: 4x4 (for a 16x16 input) average-pooled to 1x1.
        self.to_vec = nn.Sequential(nn.AvgPool2d(4), nn.GELU())

        # Scalar-timestep and context embeddings, one pair per decoder stage.
        self.timeembed1 = EmbedFC(1, 2 * n_feat)
        self.timeembed2 = EmbedFC(1, 1 * n_feat)
        self.contextembed1 = EmbedFC(n_cond, 2 * n_feat)
        self.contextembed2 = EmbedFC(n_cond, 1 * n_feat)

        # Upsample the bottleneck back to the spatial size of down2.
        self.up0 = nn.Sequential(
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, 4, 4),
            nn.GroupNorm(8, 2 * n_feat),
            nn.ReLU(),
        )

        # UnetUp concatenates its input with the skip, hence the doubled in-channels.
        self.up1 = UnetUp(4 * n_feat, n_feat)
        self.up2 = UnetUp(2 * n_feat, n_feat)
        self.out = nn.Sequential(
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
            nn.GroupNorm(8, n_feat),
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
        )

    def forward(self, x, c, t, context_mask):
        """Predict the noise contained in ``x``.

        Args:
            x: noisy image batch, (B, in_channels, H, W).
            c: context; reshaped here to (B, n_cond).
            t: B timestep values (EmbedFC flattens them to (B, 1)). DDPM in
                this file passes the normalised timestep ``t / n_T``.
            context_mask: same number of elements as ``c``, reshaped to
                (B, n_cond). An entry of 1 drops the matching context entry
                and an entry of 0 keeps it.

        Returns:
            Predicted noise with the same shape as ``x``.
        """
        x = self.init_conv(x)
        down1 = self.down1(x)
        down2 = self.down2(down1)
        hiddenvec = self.to_vec(down2)

        c = c.reshape((c.shape[0], self.n_cond))
        context_mask = context_mask.reshape((x.shape[0], self.n_cond))
        # NOTE: this maps mask 0 (keep) -> -1 and mask 1 (drop) -> 0, so kept
        # context is fed to the embedders *negated*, not merely "flipped".
        # Training and sampling both pass through this line, so the network
        # learns the sign. It is kept so existing weights stay valid.
        context_mask = -1 * (1 - context_mask)
        c = c * context_mask

        # Embeddings shaped (B, C, 1, 1) so they broadcast over H and W.
        cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1)
        temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
        cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)

        up1 = self.up0(hiddenvec)
        # Condition each decoder stage: scale by the context embedding, shift by the time embedding.
        up2 = self.up1(cemb1 * up1 + temb1, down2)
        up3 = self.up2(cemb2 * up2 + temb2, down1)
        out = self.out(torch.cat((up3, x), 1))
        return out


def ddpm_schedules(beta1, beta2, T):
    """
    Pre-compute the linear-beta DDPM schedules for timesteps 0..T.

    beta_t rises linearly from beta1 (t=0) to beta2 (t=T). Every returned
    tensor is 1-D float32 of length T + 1 and is indexed by timestep.

    Args:
        beta1: beta at t=0; must satisfy 0 < beta1 < beta2.
        beta2: beta at t=T; must be < 1.
        T: number of diffusion steps; must be >= 1.

    Returns:
        dict with keys alpha_t, oneover_sqrta, sqrt_beta_t, alphabar_t,
        sqrtab, sqrtmab and mab_over_sqrtmab (see inline comments).

    Raises:
        ValueError: if the betas are not ordered inside (0, 1), or T < 1.
    """
    # beta1 <= 0 would give alpha_0 >= 1 (NaNs in sqrt_beta_t / 0-by-0 in
    # mab_over_sqrtmab). T == 0 divides by zero below.
    if not 0.0 < beta1 < beta2 < 1.0:
        raise ValueError(
            f"need 0 < beta1 < beta2 < 1, got beta1={beta1}, beta2={beta2}"
        )
    if T < 1:
        raise ValueError(f"T must be >= 1, got {T}")

    beta_t = (beta2 - beta1) * torch.arange(0, T + 1, dtype=torch.float32) / T + beta1
    sqrt_beta_t = torch.sqrt(beta_t)
    alpha_t = 1 - beta_t
    # Cumulative product computed in log space for numerical stability.
    log_alpha_t = torch.log(alpha_t)
    alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp()

    sqrtab = torch.sqrt(alphabar_t)
    oneover_sqrta = 1 / torch.sqrt(alpha_t)

    sqrtmab = torch.sqrt(1 - alphabar_t)
    mab_over_sqrtmab_inv = (1 - alpha_t) / sqrtmab

    return {
        "alpha_t": alpha_t,  # \alpha_t
        "oneover_sqrta": oneover_sqrta,  # 1/\sqrt{\alpha_t}
        "sqrt_beta_t": sqrt_beta_t,  # \sqrt{\beta_t}
        "alphabar_t": alphabar_t,  # \bar{\alpha_t}
        "sqrtab": sqrtab,  # \sqrt{\bar{\alpha_t}}
        "sqrtmab": sqrtmab,  # \sqrt{1-\bar{\alpha_t}}
        "mab_over_sqrtmab": mab_over_sqrtmab_inv,  # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}}
    }


class DDPM(nn.Module):
    """DDPM training/sampling wrapper with classifier-free guidance.

    Args:
        nn_model: noise-prediction network, called as
            ``nn_model(x_t, c, t / n_T, context_mask)``.
        betas: (beta1, beta2) endpoints of the linear beta schedule.
        n_T: number of diffusion steps.
        device: device the model and sampled tensors are placed on.
        drop_prob: probability that a training sample's whole context is
            dropped. This trains the unconditional branch used for guidance.
    """

    def __init__(self, nn_model, betas, n_T, device, drop_prob=0.1):
        super(DDPM, self).__init__()
        self.nn_model = nn_model.to(device)

        # register_buffer makes each schedule tensor follow .to(device) and
        # exposes it as an attribute, e.g. self.sqrtab.
        for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
            self.register_buffer(k, v)

        self.n_T = n_T
        self.device = device
        self.drop_prob = drop_prob
        self.loss_mse = nn.MSELoss()

    def forward(self, x, c):
        """Training loss: MSE between true and predicted noise at a random step.

        Args:
            x: clean images, (B, C, H, W).
            c: per-sample context; its first dimension must be B.

        Returns:
            Scalar MSE loss tensor.
        """
        batch = x.shape[0]
        _ts = torch.randint(1, self.n_T + 1, (batch,)).to(self.device)  # t ~ Uniform{1, ..., n_T}
        noise = torch.randn_like(x)  # eps ~ N(0, 1)

        # x_t = sqrt(alphabar_t) * x_0 + sqrt(1 - alphabar_t) * eps
        x_t = (
            self.sqrtab[_ts].reshape((batch, 1, 1, 1)) * x
            + self.sqrtmab[_ts].reshape((batch, 1, 1, 1)) * noise
        )

        # Classifier-free guidance: drop a sample's *entire* context with
        # probability drop_prob. Drawing one Bernoulli per sample (not one per
        # context element) makes the fully-masked input that sample() uses for
        # the unconditional branch actually occur during training.
        drop = torch.bernoulli(
            torch.full((batch,), self.drop_prob, device=self.device)
        )
        context_mask = drop.view(batch, *([1] * (c.dim() - 1))).expand_as(c)

        # return MSE between added noise, and our predicted noise
        return self.loss_mse(noise, self.nn_model(x_t, c, _ts / self.n_T, context_mask))

    def sample(self, n_sample, size, device, guide_w=0.0, context=None):
        """Generate ``n_sample`` images with classifier-free guidance.

        Each step runs the model on a doubled batch: the first half keeps the
        context and the second half has it fully masked. The two predictions
        are mixed as eps = (1 + w) * eps_cond - w * eps_uncond.

        Args:
            n_sample: number of images to generate.
            size: per-image shape, e.g. (1, 16, 16).
            device: device to sample on.
            guide_w: guidance strength w. At 0, only the conditional
                prediction is used.
            context: optional per-sample conditioning, reshaped to
                (n_sample, 1, -1). When None (the default), 12-value context
                vectors are drawn uniformly from [0, 1).

        Returns:
            (x_0, x_store): the final images, shaped (n_sample, *size), and a
            numpy array of intermediate snapshots. Snapshots are taken at the
            first step, every step divisible by 20, and the last 7 steps.
        """
        x_i = torch.randn(n_sample, *size).to(device)  # x_T ~ N(0, 1), sample initial noise

        if context is None:
            # Random context vectors (uniform in [0, 1)), one per sample.
            c_i = torch.rand((n_sample, 1, 12)).to(device)
        else:
            c_i = context.reshape(n_sample, 1, -1).to(device)

        # don't drop context at test time
        context_mask = torch.zeros_like(c_i).to(device)

        # double the batch; the second half is context free
        c_i = c_i.repeat(2, 1, 1)
        context_mask = context_mask.repeat(2, 1, 1)
        context_mask[n_sample:] = 1.

        x_i_store = []  # keep track of generated steps in case want to plot something
        print()
        for i in range(self.n_T, 0, -1):
            print(f'sampling timestep {i}', end='\r')
            t_is = torch.tensor([i / self.n_T]).to(device)
            t_is = t_is.repeat(n_sample, 1, 1, 1)

            # double batch
            x_i = x_i.repeat(2, 1, 1, 1)
            t_is = t_is.repeat(2, 1, 1, 1)

            # No fresh noise is added on the final step.
            z = torch.randn(n_sample, *size).to(device) if i > 1 else 0

            # split predictions and compute weighting
            eps = self.nn_model(x_i, c_i, t_is, context_mask)
            eps_cond = eps[:n_sample]
            eps_uncond = eps[n_sample:]
            eps = (1 + guide_w) * eps_cond - guide_w * eps_uncond

            x_i = x_i[:n_sample]
            x_i = (
                self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i])
                + self.sqrt_beta_t[i] * z
            )
            if i % 20 == 0 or i == self.n_T or i < 8:
                x_i_store.append(x_i.detach().cpu().numpy())

        x_i_store = np.array(x_i_store)
        return x_i, x_i_store
# 162/46:
# @title TESS Dataset
import os
from PIL import Image
import torch
import pickle
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms

class TESSDataset(Dataset):
    """TESS CCD images paired with per-exposure angle features.

    Each item is a dict with:
      * "x": the 12 features named in ``ANGLE_KEYS``, as a (1, 12) tensor.
        This is the conditioning input.
      * "y": the CCD image from ``ccd_folder``, reshaped to a (1, 16, 16)
        tensor. This is the generation target.
      * "ffi_num": the 8-character FFI number sliced from the file name.
      * "orbit": the "orbit" entry of that FFI's angles record.

    Args:
        angle_folder: directory holding ``angles_O13_data_dic.pkl``, a pickled
            dict mapping FFI number -> dict of angle features.
        ccd_folder: directory of pickled image arrays (256 values each, since
            they are reshaped to 16x16).

    NOTE(review): pickle.load can execute arbitrary code, so only point these
    folders at trusted data.
    """

    # Feature order of the conditioning vector. Names such as 'Eel'/'Maz'
    # presumably mean Earth/Moon elevation/azimuth -- confirm with the data owner.
    ANGLE_KEYS = (
        '1/ED', '1/MD', '1/ED^2', '1/MD^2', 'Eel', 'Eaz',
        'Mel', 'Maz', 'E3el', 'E3az', 'M3el', 'M3az',
    )

    def __init__(
        self,
        angle_folder="/pdo/users/jlupoiii/TESS/data/angles/",
        ccd_folder="/pdo/users/jlupoiii/TESS/data/ccds_background_subtracted/",
    ):
        with open(os.path.join(angle_folder, 'angles_O13_data_dic.pkl'), "rb") as f:
            self.angles_dic = pickle.load(f)

        X = []  # angle feature vectors (conditioning)
        Y = []  # flattened images (targets)
        ffi_nums = []

        # Sorted so the sample order does not depend on directory-listing order.
        for filename in sorted(os.listdir(ccd_folder)):
            # Keep only long names whose 28th character is '3' (presumably a
            # camera/CCD selector in the naming scheme -- confirm).
            if len(filename) < 40 or filename[27] != '3':
                continue
            ffi_num = filename[18:18 + 8]
            try:
                angles = self.angles_dic[ffi_num]
            except KeyError:
                continue  # no angle record for this FFI; skip without loading the image
            with open(os.path.join(ccd_folder, filename), "rb") as f:
                image_arr = pickle.load(f)
            X.append(np.array([angles[key] for key in self.ANGLE_KEYS]))
            Y.append(image_arr.flatten())
            ffi_nums.append(ffi_num)

        # The model generates Y (image) given X (angles). Both are stored as
        # PIL images, and __getitem__ converts them with torchvision's ToTensor.
        self.data = [Image.fromarray(x) for x in X]
        self.labels = [Image.fromarray(y) for y in Y]
        self.ffi_nums = ffi_nums

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict with keys "x", "y", "ffi_num", "orbit"."""
        angles_image = self.data[idx]
        ffi_image = self.labels[idx]
        ffi_num = self.ffi_nums[idx]
        orbit = self.angles_dic[ffi_num]["orbit"]
        n_angles = len(self.ANGLE_KEYS)

        # Conditioning: PIL image -> tensor -> (1, n_angles).
        transform = transforms.Compose([
            transforms.ToTensor(),
            lambda s: s.reshape(1, n_angles),
        ])
        # Target: PIL image -> (16, 16) array -> (1, 16, 16) tensor.
        target_transform = transforms.Compose([
            lambda s: np.array(s),
            lambda s: s.reshape((16, 16)),
            transforms.ToTensor(),
        ])

        angles_image = transform(angles_image)
        ffi_image = target_transform(ffi_image)

        # The model generates "y" (the image) conditioned on "x" (the angles).
        return {"x": angles_image, "y": ffi_image, "ffi_num": ffi_num, "orbit": orbit}
# 162/47:
def train_TESS(save_dir='./model_O13/'):
    """Train the conditional DDPM on TESSDataset; save sample grids and the final model.

    After every epoch, one PNG grid per guidance weight in ``ws_test`` is
    written to ``save_dir``, with generated images on top and real dataset
    images below. Generated samples use random context vectors (see
    DDPM.sample), so the two halves are not paired image-for-image.

    Args:
        save_dir: output directory, created if missing. The default is the
            directory this function previously hard-coded.

    Raises:
        ValueError: if TESSDataset finds no usable samples.
    """
    # hardcoding these here
    n_epoch = 20
    batch_size = 128
    n_T = 400
    # Fall back to CPU so the function still runs without a GPU.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    n_feat = 128  # 128 ok, 256 better (but slower)
    lrate = 1e-4
    save_model = True
    ws_test = [0.0, 0.5, 2.0]  # strength of generative guidance
    n_sample = 40  # generated images per guidance weight

    os.makedirs(save_dir, exist_ok=True)

    dataset = TESSDataset()
    if len(dataset) == 0:
        raise ValueError("TESSDataset is empty; nothing to train on")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=5)

    ddpm = DDPM(nn_model=ContextUnet(in_channels=1, n_feat=n_feat), betas=(1e-4, 0.02), n_T=n_T, device=device, drop_prob=0.1)
    ddpm.to(device)
    optim = torch.optim.Adam(ddpm.parameters(), lr=lrate)

    # Fixed set of real images shown below the generated ones in every grid.
    x_real = torch.stack(
        [dataset[i]["y"] for i in range(min(n_sample, len(dataset)))]
    ).to(device)

    for ep in range(n_epoch):
        print(f'epoch {ep}')
        ddpm.train()

        # linear lrate decay
        optim.param_groups[0]['lr'] = lrate * (1 - ep / n_epoch)

        pbar = tqdm(dataloader)
        loss_ema = None
        for data_dic in pbar:
            optim.zero_grad()
            x = data_dic['y'].to(device)  # target images
            c = data_dic['x'].to(device)  # conditioning angles
            loss = ddpm(x, c)
            loss.backward()
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.95 * loss_ema + 0.05 * loss.item()
            pbar.set_description(f"loss: {loss_ema:.4f}")
            optim.step()

        # Save generated samples (top rows) followed by real images (bottom rows).
        ddpm.eval()
        with torch.no_grad():
            for w in ws_test:
                x_gen, _ = ddpm.sample(n_sample, (1, 16, 16), device, guide_w=w)
                x_all = torch.cat([x_gen, x_real])
                # Same 1 - x inversion as the MNIST grid in this file.
                grid = make_grid(x_all * -1 + 1, nrow=10)
                image_path = os.path.join(save_dir, f"image_ep{ep}_w{w}.png")
                save_image(grid, image_path)
                print('saved image at ' + image_path)

        if save_model and ep == n_epoch - 1:
            model_path = os.path.join(save_dir, f"model_{ep}.pth")
            torch.save(ddpm.state_dict(), model_path)
            print('saved model at ' + model_path)

if __name__ == '__main__':
    # NOTE(review): save_model_name is never passed to train_TESS() and is
    # not read anywhere in this block.
    save_model_name = "model_TESS_O13"
    train_TESS()
# 161/21:
def train_mnist(save_dir):
    """Train a class-conditional DDPM on MNIST (legacy experiment).

    Each epoch saves one PNG grid per guidance weight in ``ws_test``, with
    generated samples above real digits. A GIF of the sampling trajectory is
    also saved every 5th epoch and on the last epoch.

    NOTE(review): this function predates the TESS changes. ContextUnet in this
    file has no ``n_classes`` parameter, so the constructor call below raises
    TypeError. Its forward also expects a continuous 12-value context per
    sample and spatial sizes that are multiples of 16, not MNIST's integer
    labels and 28x28 images. The original class-conditional ContextUnet is
    needed to run this.

    Args:
        save_dir: directory for the PNG/GIF (and optional model) outputs;
            created if missing.
    """
    # hardcoding these here
    n_epoch = 20
    batch_size = 1024  # 256
    n_T = 400  # 500
    # Fall back to CPU so the function still runs without a GPU.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    n_classes = 10
    n_feat = 128  # 128 ok, 256 better (but slower)
    lrate = 1e-4
    save_model = False
    ws_test = [0.0, 0.5, 2.0]  # strength of generative guidance

    os.makedirs(save_dir, exist_ok=True)

    ddpm = DDPM(nn_model=ContextUnet(in_channels=1, n_feat=n_feat, n_classes=n_classes), betas=(1e-4, 0.02), n_T=n_T, device=device, drop_prob=0.1)
    ddpm.to(device)

    tf = transforms.Compose([transforms.ToTensor()])  # mnist is already normalised 0 to 1

    dataset = MNIST("./data", train=True, download=True, transform=tf)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=5)
    optim = torch.optim.Adam(ddpm.parameters(), lr=lrate)

    for ep in range(n_epoch):
        print(f'epoch {ep}')
        ddpm.train()

        # linear lrate decay
        optim.param_groups[0]['lr'] = lrate * (1 - ep / n_epoch)

        pbar = tqdm(dataloader)
        loss_ema = None
        for x, c in pbar:
            optim.zero_grad()
            x = x.to(device)
            c = c.to(device)
            loss = ddpm(x, c)
            loss.backward()
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.95 * loss_ema + 0.05 * loss.item()
            pbar.set_description(f"loss: {loss_ema:.4f}")
            optim.step()

        # for eval, save an image of currently generated samples (top rows)
        # followed by real images (bottom rows)
        ddpm.eval()
        with torch.no_grad():
            n_sample = 4 * n_classes
            for w in ws_test:
                x_gen, x_gen_store = ddpm.sample(n_sample, (1, 28, 28), device, guide_w=w)

                print(x_gen.shape)

                # Real digits from the last training batch, placed so grid
                # column k holds class k. Falls back to x[0] when the batch has
                # too few examples of a class (indexing then raises IndexError).
                x_real = torch.empty_like(x_gen)
                for k in range(n_classes):
                    for j in range(n_sample // n_classes):
                        try:
                            idx = torch.squeeze((c == k).nonzero())[j]
                        except IndexError:
                            idx = 0
                        x_real[k + (j * n_classes)] = x[idx]

                x_all = torch.cat([x_gen, x_real])
                grid = make_grid(x_all * -1 + 1, nrow=10)
                image_path = os.path.join(save_dir, f"image_ep{ep}_w{w}.png")
                save_image(grid, image_path)
                print('saved image at ' + image_path)

                if ep % 5 == 0 or ep == int(n_epoch - 1):
                    # create gif of images evolving over time, based on x_gen_store
                    fig, axs = plt.subplots(nrows=int(n_sample / n_classes), ncols=n_classes, sharex=True, sharey=True, figsize=(8, 3))

                    def animate_diff(i, x_gen_store):
                        print(f'gif animating frame {i} of {x_gen_store.shape[0]}', end='\r')
                        plots = []
                        for row in range(int(n_sample / n_classes)):
                            for col in range(n_classes):
                                axs[row, col].clear()
                                axs[row, col].set_xticks([])
                                axs[row, col].set_yticks([])
                                plots.append(axs[row, col].imshow(-x_gen_store[i, (row * n_classes) + col, 0], cmap='gray', vmin=(-x_gen_store[i]).min(), vmax=(-x_gen_store[i]).max()))
                        return plots

                    ani = FuncAnimation(fig, animate_diff, fargs=[x_gen_store], interval=200, blit=False, repeat=True, frames=x_gen_store.shape[0])
                    gif_path = os.path.join(save_dir, f"gif_ep{ep}_w{w}.gif")
                    ani.save(gif_path, dpi=100, writer=PillowWriter(fps=5))
                    print('saved image at ' + gif_path)
                    # Without this, every GIF leaves an open figure for the rest of the run.
                    plt.close(fig)

        # optionally save model
        if save_model and ep == int(n_epoch - 1):
            model_path = os.path.join(save_dir, f"model_{ep}.pth")
            torch.save(ddpm.state_dict(), model_path)
            print('saved model at ' + model_path)

if __name__ == "__main__":
    save_dir = './diffusion_outputs/'
    # train_mnist writes PNG/GIF files into save_dir, so it must exist.
    os.makedirs(save_dir, exist_ok=True)
    train_mnist(save_dir)
# 162/48:
def train_TESS(save_dir):
    """Train the conditional DDPM on TESSDataset; save sample grids and the final model.

    After every epoch, one PNG grid per guidance weight in ``ws_test`` is
    written to ``save_dir``, with generated images on top and real dataset
    images below. Generated samples use random context vectors (see
    DDPM.sample), so the two halves are not paired image-for-image.

    Args:
        save_dir: output directory, created if missing.

    Raises:
        ValueError: if TESSDataset finds no usable samples.
    """
    # hardcoding these here
    n_epoch = 20
    batch_size = 128
    n_T = 400
    # Fall back to CPU so the function still runs without a GPU.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    n_feat = 128  # 128 ok, 256 better (but slower)
    lrate = 1e-4
    save_model = True
    ws_test = [0.0, 0.5, 2.0]  # strength of generative guidance
    n_sample = 40  # generated images per guidance weight

    os.makedirs(save_dir, exist_ok=True)

    dataset = TESSDataset()
    if len(dataset) == 0:
        raise ValueError("TESSDataset is empty; nothing to train on")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=5)

    ddpm = DDPM(nn_model=ContextUnet(in_channels=1, n_feat=n_feat), betas=(1e-4, 0.02), n_T=n_T, device=device, drop_prob=0.1)
    ddpm.to(device)
    optim = torch.optim.Adam(ddpm.parameters(), lr=lrate)

    # Real images for the bottom rows. The old per-class lookup compared
    # continuous angles with integer class ids, never matched, and so filled
    # every tile with the same image.
    x_real = torch.stack(
        [dataset[i]["y"] for i in range(min(n_sample, len(dataset)))]
    ).to(device)

    for ep in range(n_epoch):
        print(f'epoch {ep}')
        ddpm.train()

        # linear lrate decay
        optim.param_groups[0]['lr'] = lrate * (1 - ep / n_epoch)

        pbar = tqdm(dataloader)
        loss_ema = None
        for data_dic in pbar:
            optim.zero_grad()
            x = data_dic['y'].to(device)  # target images
            c = data_dic['x'].to(device)  # conditioning angles
            loss = ddpm(x, c)
            loss.backward()
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.95 * loss_ema + 0.05 * loss.item()
            pbar.set_description(f"loss: {loss_ema:.4f}")
            optim.step()

        # Save generated samples (top rows) followed by real images (bottom rows).
        ddpm.eval()
        with torch.no_grad():
            for w in ws_test:
                x_gen, _ = ddpm.sample(n_sample, (1, 16, 16), device, guide_w=w)
                x_all = torch.cat([x_gen, x_real])
                # Same 1 - x inversion as the MNIST grid in this file.
                grid = make_grid(x_all * -1 + 1, nrow=10)
                # os.path.join keeps files inside save_dir even without a trailing slash.
                image_path = os.path.join(save_dir, f"image_ep{ep}_w{w}.png")
                save_image(grid, image_path)
                print('saved image at ' + image_path)

        if save_model and ep == n_epoch - 1:
            model_path = os.path.join(save_dir, f"model_{ep}.pth")
            torch.save(ddpm.state_dict(), model_path)
            print('saved model at ' + model_path)

if __name__ == "__main__":
    # Trailing slash keeps outputs inside the directory even where train_TESS
    # builds paths by concatenating file names onto save_dir.
    save_dir = 'model_TESS_O13/'
    os.makedirs(save_dir, exist_ok=True)
    train_TESS(save_dir)
# 162/49:
def train_TESS(save_dir, *, n_epoch=20, batch_size=128, n_T=400, n_feat=128,
               lrate=1e-4, save_model=True, ws_test=(0.0, 0.5, 2.0),
               device=None):
    '''
    Train a DDPM with a ContextUnet on TESSDataset and dump sample grids.

    Each epoch runs one pass over the data with Adam and a linearly decaying
    learning rate (lrate * (1 - ep / n_epoch)). Afterwards, for every guidance
    strength in ``ws_test``, ``ddpm.sample`` draws 40 images. They are saved as
    a PDF grid, stacked on top of 40 "real" targets taken from the last
    training batch.

    Args:
        save_dir: directory receiving the sample grids and (optionally) the
            final checkpoint. Paths are built with os.path.join, so a trailing
            separator is optional.
        n_epoch: number of training epochs.
        batch_size: DataLoader batch size.
        n_T: number of diffusion timesteps passed to DDPM.
        n_feat: feature width passed to ContextUnet.
        lrate: initial Adam learning rate.
        save_model: if True, save ddpm.state_dict() after the final epoch.
        ws_test: guidance strengths passed as ``guide_w`` when sampling.
        device: torch device string. Defaults to "cuda:0" when CUDA is
            available, otherwise "cpu".
    '''
    import os  # local import: os is not among the visible top-of-file imports

    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"

    n_sample = 40   # generated images per guidance strength (4 grid rows of 10)
    n_classes = 10  # grouping used for the real-image rows and the grid width

    ddpm = DDPM(nn_model=ContextUnet(in_channels=1, n_feat=n_feat),
                betas=(1e-4, 0.02), n_T=n_T, device=device, drop_prob=0.1)
    ddpm.to(device)

    dataset = TESSDataset()
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=5)

    optim = torch.optim.Adam(ddpm.parameters(), lr=lrate)

    for ep in range(n_epoch):
        print(f'epoch {ep}')
        ddpm.train()

        # linear lrate decay
        optim.param_groups[0]['lr'] = lrate * (1 - ep / n_epoch)

        pbar = tqdm(dataloader)
        loss_ema = None
        for data_dic in pbar:
            optim.zero_grad()
            # the batch's 'y' entry is what ddpm is trained on; 'x' is passed
            # to ddpm as its second (conditioning) argument
            x = data_dic['y'].to(device)
            c = data_dic['x'].to(device)

            loss = ddpm(x, c)
            loss.backward()
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.95 * loss_ema + 0.05 * loss.item()
            pbar.set_description(f"loss: {loss_ema:.4f}")
            optim.step()

        # for eval, save an image of currently generated samples (top rows)
        # followed by real images (bottom rows)
        ddpm.eval()
        with torch.no_grad():
            for w in ws_test:
                x_gen, _ = ddpm.sample(n_sample, (1, 16, 16), device, guide_w=w)

                # Real images come from x / c of the LAST training batch above.
                # NOTE(review): this class-matching selection is inherited from
                # the MNIST version. Here c is data_dic['x'] (TESS conditioning),
                # so `c == k` may rarely match and most slots fall back to x[0].
                # Confirm what c contains.
                x_real = torch.zeros_like(x_gen)
                for k in range(n_classes):
                    matches = torch.squeeze((c == k).nonzero())
                    for j in range(n_sample // n_classes):
                        try:
                            idx = matches[j]
                        except IndexError:
                            # fewer than j+1 matches (or a 0-d squeeze result)
                            idx = 0
                        x_real[k + j * n_classes] = x[idx]

                x_all = torch.cat([x_gen, x_real])
                # invert intensities (1 - v) for display
                grid = make_grid(x_all * -1 + 1, nrow=n_classes)
                image_path = os.path.join(save_dir, f"image_ep{ep}_w{w}.pdf")
                save_image(grid, image_path)
                print('saved image at ' + image_path)

    if save_model and n_epoch > 0:
        model_path = os.path.join(save_dir, f"model_{n_epoch - 1}.pth")
        torch.save(ddpm.state_dict(), model_path)
        print('saved model at ' + model_path)

if __name__ == "__main__":
    # os is not among the visible top-of-file imports, so bring it in here.
    import os

    # Keep the trailing separator: if train_TESS builds output paths by plain
    # string concatenation (save_dir + filename), the files would otherwise be
    # written next to the directory instead of inside it.
    save_dir = 'model_TESS_O13/'
    os.makedirs(save_dir, exist_ok=True)
    train_TESS(save_dir)
   1: %history
   2: %history -g -f filename
   3: %history -g -f model_train_TESS