In [47]:
''' 
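This script does conditional image generation using a diffusion model.
It was originally written for MNIST; in this notebook it is trained on 16x16 TESS
CCD images, conditioned on a 12-element vector of angle and distance features.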

This code is modified from,
https://github.com/cloneofsimo/minDiffusion

Diffusion model is based on DDPM,
https://arxiv.org/abs/2006.11239

The conditioning idea is taken from 'Classifier-Free Diffusion Guidance',
https://arxiv.org/abs/2207.12598

This technique also features in Imagen, 'Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding',
https://arxiv.org/abs/2205.11487

'''

from typing import Dict, Tuple
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import numpy as np


In [54]:
class ResidualConvBlock(nn.Module):
    """ResNet-style block: two stacked 3x3 conv -> BatchNorm -> GELU stages.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both conv stages.
        is_res: when True, a skip connection is added to the second stage's
            output and the sum is divided by 1.414; when False the second
            stage's output is returned as-is.

    The 3x3 convolutions use stride 1 and padding 1, so the spatial size of
    the input is preserved.
    """

    def __init__(
        self, in_channels: int, out_channels: int, is_res: bool = False
    ) -> None:
        super().__init__()
        self.same_channels = in_channels == out_channels
        self.is_res = is_res
        self.conv1 = self._conv_stage(in_channels, out_channels)
        self.conv2 = self._conv_stage(out_channels, out_channels)

    @staticmethod
    def _conv_stage(n_in: int, n_out: int) -> nn.Sequential:
        """Build one conv(3x3, stride 1, pad 1) -> BatchNorm -> GELU stage."""
        return nn.Sequential(
            nn.Conv2d(n_in, n_out, 3, 1, 1),
            nn.BatchNorm2d(n_out),
            nn.GELU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        first = self.conv1(x)
        second = self.conv2(first)

        if not self.is_res:
            return second

        # The raw input can only be used as the shortcut when its channel count
        # matches the output; otherwise the first stage's output is used.
        shortcut = x if self.same_channels else first
        return (shortcut + second) / 1.414


class UnetDown(nn.Module):
    """Encoder stage: a (non-residual) ResidualConvBlock followed by 2x2 max pooling.

    The conv block maps `in_channels` to `out_channels`; the pooling step then
    halves the height and width of the feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.model = nn.Sequential(
            ResidualConvBlock(in_channels, out_channels),
            nn.MaxPool2d(2),
        )

    def forward(self, x):
        return self.model(x)


class UnetUp(nn.Module):
    """Decoder stage: upscale the feature maps, then refine with two conv blocks.

    `in_channels` must equal the channel count of `x` plus that of `skip`,
    because the two are concatenated before the transposed convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.model = nn.Sequential(
            # kernel 2 / stride 2 transposed conv doubles height and width
            nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
            ResidualConvBlock(out_channels, out_channels),
            ResidualConvBlock(out_channels, out_channels),
        )

    def forward(self, x, skip):
        # join the decoder features with the encoder skip features along channels
        merged = torch.cat((x, skip), 1)
        return self.model(merged)


class EmbedFC(nn.Module):
    """Small fully connected embedding network: Linear -> GELU -> Linear.

    Args:
        input_dim: size of each flattened input vector.
        emb_dim: size of the produced embedding.

    `forward` first views its input as (-1, input_dim), so any input whose
    element count is a multiple of `input_dim` is accepted.
    """

    def __init__(self, input_dim, emb_dim):
        super().__init__()
        self.input_dim = input_dim
        self.model = nn.Sequential(
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        )

    def forward(self, x):
        flat = x.view(-1, self.input_dim)
        return self.model(flat)


class ContextUnet(nn.Module):
    """Small U-Net that predicts noise from a noisy image, a context vector and a timestep.

    Args:
        in_channels: number of channels of the input (and output) image.
        n_feat: base feature width. It is passed to GroupNorm(8, ...), so it
            must be divisible by 8.
        n_context: length of the per-sample context vector (default 12, the
            value that was previously hard-coded).

    Spatial layout: two stride-2 poolings, an AvgPool2d(4), and a matching
    ConvTranspose2d(kernel 4, stride 4). The skip concatenations therefore only
    line up when the input height and width are multiples of 16.
    """

    def __init__(self, in_channels, n_feat = 256, n_context = 12):
        super(ContextUnet, self).__init__()

        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_context = n_context

        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)

        self.down1 = UnetDown(n_feat, n_feat)
        self.down2 = UnetDown(n_feat, 2 * n_feat)

        self.to_vec = nn.Sequential(nn.AvgPool2d(4), nn.GELU())

        # One time embedding and one context embedding per decoder stage.
        # (These were previously constructed twice; the duplicates are removed.)
        self.timeembed1 = EmbedFC(1, 2*n_feat)
        self.timeembed2 = EmbedFC(1, 1*n_feat)
        self.contextembed1 = EmbedFC(n_context, 2*n_feat)
        self.contextembed2 = EmbedFC(n_context, 1*n_feat)

        self.up0 = nn.Sequential(
            # undoes the AvgPool2d(4) in to_vec
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, 4, 4),
            nn.GroupNorm(8, 2 * n_feat),
            nn.ReLU(),
        )

        self.up1 = UnetUp(4 * n_feat, n_feat)
        self.up2 = UnetUp(2 * n_feat, n_feat)
        self.out = nn.Sequential(
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
            nn.GroupNorm(8, n_feat),
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
        )

    def forward(self, x, c, t, context_mask):
        """Predict the noise contained in `x`.

        Args:
            x: (noisy) image batch.
            c: context, reshaped here to (batch, n_context).
            t: timestep input; it is viewed as (-1, 1) by the time embeddings.
            context_mask: same number of elements as `c`; a 1 blocks the
                corresponding context entry, a 0 keeps it.

        Returns:
            Tensor with `in_channels` channels and the spatial size of `x`.
        """
        x = self.init_conv(x)
        down1 = self.down1(x)
        down2 = self.down2(down1)
        hiddenvec = self.to_vec(down2)

        c = c.reshape((c.shape[0], self.n_context))

        # mask out context if context_mask == 1
        context_mask = context_mask.reshape((x.shape[0], self.n_context))
        # NOTE: this maps mask 0 -> -1 and mask 1 -> 0, so kept context entries are
        # negated (not just passed through) and dropped ones are zeroed. The sign is
        # left untouched so that already-trained checkpoints keep their meaning.
        context_mask = (-1*(1-context_mask))
        c = c * context_mask

        # embed context, time step
        cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1)
        temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
        cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)

        # could concatenate the context embedding here instead of adaGN
        # hiddenvec = torch.cat((hiddenvec, temb1, cemb1), 1)

        up1 = self.up0(hiddenvec)
        # up2 = self.up1(up1, down2) # if want to avoid add and multiply embeddings
        # context embedding scales the features, time embedding shifts them
        up2 = self.up1(cemb1*up1+ temb1, down2)
        up3 = self.up2(cemb2*up2+ temb2, down1)
        out = self.out(torch.cat((up3, x), 1))
        return out


def ddpm_schedules(beta1, beta2, T):
    """
    Returns pre-computed schedules for DDPM sampling, training process.

    beta_t grows linearly from `beta1` (index 0) to `beta2` (index T); every
    returned tensor is 1-D float32 with T + 1 entries, indexed by timestep.

    Args:
        beta1: smallest noise variance, must satisfy 0 < beta1 < beta2.
        beta2: largest noise variance, must be < 1.
        T: number of diffusion steps, must be > 0.

    Returns:
        dict mapping schedule name to tensor (see the inline comments below).

    Raises:
        AssertionError: if the betas are not ordered inside (0, 1) or T <= 0.
    """
    # The lower bound is checked too: the message always promised (0, 1), and a
    # non-positive beta1 gives NaN / division by zero in the terms below.
    assert 0.0 < beta1 < beta2 < 1.0, "beta1 and beta2 must be in (0, 1)"
    assert T > 0, "T must be a positive number of steps"

    beta_t = (beta2 - beta1) * torch.arange(0, T + 1, dtype=torch.float32) / T + beta1
    sqrt_beta_t = torch.sqrt(beta_t)
    alpha_t = 1 - beta_t
    # cumulative product of alpha_t, computed as exp(cumsum(log(alpha_t)))
    log_alpha_t = torch.log(alpha_t)
    alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp()

    sqrtab = torch.sqrt(alphabar_t)
    oneover_sqrta = 1 / torch.sqrt(alpha_t)

    sqrtmab = torch.sqrt(1 - alphabar_t)
    mab_over_sqrtmab = (1 - alpha_t) / sqrtmab

    return {
        "alpha_t": alpha_t,  # \alpha_t
        "oneover_sqrta": oneover_sqrta,  # 1/\sqrt{\alpha_t}
        "sqrt_beta_t": sqrt_beta_t,  # \sqrt{\beta_t}
        "alphabar_t": alphabar_t,  # \bar{\alpha_t}
        "sqrtab": sqrtab,  # \sqrt{\bar{\alpha_t}}
        "sqrtmab": sqrtmab,  # \sqrt{1-\bar{\alpha_t}}
        "mab_over_sqrtmab": mab_over_sqrtmab,  # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}}
    }


class DDPM(nn.Module):
    """Denoising diffusion wrapper around a noise-prediction network.

    Args:
        nn_model: network called as nn_model(x_t, c, t, context_mask) that
            returns the predicted noise.
        betas: (beta1, beta2) pair passed to ddpm_schedules.
        n_T: number of diffusion steps.
        device: device the network is moved to and on which training
            timesteps / masks are created.
        drop_prob: probability of dropping the whole context of a training
            sample (classifier-free guidance training).
    """

    def __init__(self, nn_model, betas, n_T, device, drop_prob=0.1):
        super(DDPM, self).__init__()
        self.nn_model = nn_model.to(device)

        # register_buffer allows accessing dictionary produced by ddpm_schedules
        # e.g. can access self.sqrtab later
        for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
            self.register_buffer(k, v)

        self.n_T = n_T
        self.device = device
        self.drop_prob = drop_prob
        self.loss_mse = nn.MSELoss()

    def forward(self, x, c):
        """
        this method is used in training, so samples t and noise randomly

        Returns the MSE between the noise that was added to `x` and the noise
        predicted by the network.
        """
        batch = x.shape[0]
        _ts = torch.randint(1, self.n_T+1, (batch,)).to(self.device)  # t ~ Uniform{1, ..., n_T}
        noise = torch.randn_like(x)  # eps ~ N(0, 1)

        x_t = (
            self.sqrtab[_ts, None].reshape((batch, 1, 1, 1)) * x
            + self.sqrtmab[_ts, None].reshape((batch, 1, 1, 1)) * noise
        )  # This is the x_t, which is sqrt(alphabar) x_0 + sqrt(1-alphabar) * eps
        # We should predict the "error term" from this x_t. Loss is what we return.

        # Dropout context with some probability. The decision is made ONCE PER
        # SAMPLE and broadcast over all context entries: classifier-free guidance
        # (see `sample`) needs the network to have seen fully masked contexts,
        # which an independent per-entry draw would almost never produce.
        per_sample = torch.zeros_like(c.reshape(c.shape[0], -1)[:, :1]) + self.drop_prob
        drop = torch.bernoulli(per_sample)
        context_mask = (
            drop.reshape((c.shape[0],) + (1,) * (c.dim() - 1))
            .expand_as(c)
            .contiguous()
            .to(self.device)
        )

        # return MSE between added noise, and our predicted noise
        return self.loss_mse(noise, self.nn_model(x_t, c, _ts / self.n_T, context_mask))

    def _reverse_step(self, x_i, eps, i, z):
        """Apply one reverse-diffusion update for timestep `i` given predicted noise `eps`."""
        return (
            self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i])
            + self.sqrt_beta_t[i] * z
        )

    def _keep_frame(self, i):
        """True for the timesteps whose intermediate result is stored for plotting."""
        return i % 20 == 0 or i == self.n_T or i < 8

    def sample(self, n_sample, size, device, guide_w = 0.0):
        '''
        the c_i, context, is a random context vector. It is not real data. This function will
        not give good predictions. Look to sample_c for better results

        Args:
            n_sample: number of images to generate.
            size: per-image shape, e.g. (1, 16, 16).
            device: device to sample on.
            guide_w: guidance scale; w>0 means more guidance.

        Returns:
            (x_i, x_i_store): the final samples and a numpy array of the
            intermediate samples kept by `_keep_frame`.

        Gradients are not disabled here; call under torch.no_grad() for inference.
        '''

        # we follow the guidance sampling scheme described in 'Classifier-Free Diffusion Guidance'
        # to make the fwd passes efficient, we concat two versions of the dataset,
        # one with context_mask=0 and the other context_mask=1
        # we then mix the outputs with the guidance scale, w
        # where w>0 means more guidance

        # context width: taken from the network when it exposes it, else the
        # historical value of 12
        n_context = getattr(self.nn_model, "n_context", 12)

        x_i = torch.randn(n_sample, *size).to(device)  # x_T ~ N(0, 1), sample initial noise
        c_i = torch.rand((n_sample, 1, n_context)).to(device)  # random (not real) context

        # don't drop context at test time
        context_mask = torch.zeros_like(c_i).to(device)

        # double the batch
        c_i = c_i.repeat(2, 1, 1)
        context_mask = context_mask.repeat(2, 1, 1)
        context_mask[n_sample:] = 1. # makes second half of batch context free

        x_i_store = [] # keep track of generated steps in case want to plot something
        print()
        for i in range(self.n_T, 0, -1):
            print(f'sampling timestep {i}',end='\r')
            # normalised timestep, already sized for the doubled batch
            t_is = torch.full((2 * n_sample, 1, 1, 1), i / self.n_T, device=device)

            # double batch: first half conditional, second half context free
            x_doubled = torch.cat([x_i, x_i])

            # no noise is added on the final step
            z = torch.randn(n_sample, *size).to(device) if i > 1 else 0

            # split predictions and compute weighting
            eps = self.nn_model(x_doubled, c_i, t_is, context_mask)
            eps1 = eps[:n_sample]
            eps2 = eps[n_sample:]
            eps = (1+guide_w)*eps1 - guide_w*eps2

            x_i = self._reverse_step(x_i, eps, i, z)
            if self._keep_frame(i):
                x_i_store.append(x_i.detach().cpu().numpy())

        x_i_store = np.array(x_i_store)
        return x_i, x_i_store


    def sample_c(self, c_i, n_sample, size, device):
        '''
        this is different than the function sample above
        this always conditions on the given context and applies NO guidance, so there is
        no need to concat 2 versions of the dataset or have a guidance scale w.
        Also context_mask=0 always since no mask used

        taking n_sample samples of EACH datapoint. There are n_datapoint datapoints

        Args:
            c_i: contexts, one per datapoint along dim 0.
            n_sample: number of images generated per datapoint.
            size: per-image shape, e.g. (1, 16, 16).
            device: device to sample on.

        Returns:
            (x_i, x_i_store): x_i holds n_datapoint*n_sample images ordered so
            that the n_sample images of datapoint k occupy rows
            k*n_sample .. (k+1)*n_sample - 1; x_i_store is a numpy array of
            the intermediate samples kept by `_keep_frame`.

        Gradients are not disabled here; call under torch.no_grad() for inference.
        '''
        n_datapoint = c_i.shape[0]
        n_total = n_datapoint * n_sample

        x_i = torch.randn(n_total, *size).to(device)  # x_T ~ N(0, 1), sample initial noise

        # repeat each context n_sample times in a row to make up a row of the output grid
        c_i = c_i.repeat_interleave(n_sample, dim=0).to(device)

        # don't drop context at test time. To include context make context_mask all 0's
        context_mask = torch.zeros_like(c_i).to(device)

        x_i_store = [] # keep track of generated steps in case want to plot something
        print()
        for i in range(self.n_T, 0, -1):
            print(f'sampling timestep {i}',end='\r')
            t_is = torch.full((n_total, 1, 1, 1), i / self.n_T, device=device)

            # no noise is added on the final step
            z = torch.randn(n_total, *size).to(device) if i > 1 else 0

            eps = self.nn_model(x_i, c_i, t_is, context_mask)

            x_i = self._reverse_step(x_i, eps, i, z)
            if self._keep_frame(i):
                x_i_store.append(x_i.detach().cpu().numpy())

        x_i_store = np.array(x_i_store)
        return x_i, x_i_store
        
        

In [59]:
# @title TESS Dataset
import os
from PIL import Image
import torch
import pickle
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms

class TESSDataset(Dataset):
    """Pairs of (12 angle/distance features, flattened CCD image) keyed by FFI number.

    Args:
        angle_filename: name of the pickle file (inside `angle_folder`) that
            maps an FFI number string to a dict of features; each dict is read
            for the keys in `ANGLE_KEYS` and for "orbit".
        angle_folder: directory holding `angle_filename`.
        ccd_folder: directory holding one pickled image array per file.

    SECURITY: the files are read with `pickle.load`, which can execute
    arbitrary code. Only point this class at trusted data.
    """

    # Order of the features inside the context vector returned as "x".
    ANGLE_KEYS = (
        '1/ED', '1/MD', '1/ED^2', '1/MD^2',
        'Eel', 'Eaz', 'Mel', 'Maz',
        'E3el', 'E3az', 'M3el', 'M3az',
    )

    def __init__(
        self,
        angle_filename,
        angle_folder="/pdo/users/jlupoiii/TESS/data/angles/",
        ccd_folder="/pdo/users/jlupoiii/TESS/data/ccds_background_subtracted/",
    ):
        # data matrices
        X = []
        Y = []
        ffi_nums = []

        with open(os.path.join(angle_folder, angle_filename), "rb") as fh:
            self.angles_dic = pickle.load(fh)

        # self.ffi_num_to_orbit_dic = pickle.load(open(angle_folder+'ffi_num_to_orbit_dic.pkl', "rb"))

        # os.listdir returns entries in arbitrary order; sort so that dataset
        # indices are reproducible between runs.
        for filename in sorted(os.listdir(ccd_folder)):
            # NOTE(review): presumably character 27 of the file name encodes the
            # camera/CCD and only '3' is wanted - confirm against the naming scheme.
            if len(filename) < 40 or filename[27] != '3': continue

            # characters 18..25 of the file name hold the 8-character FFI number
            ffi_num = filename[18:18+8]
            try:
                angles = self.angles_dic[ffi_num]
            except KeyError:
                # no angle data for this image: skip it (before paying for the load)
                continue

            with open(os.path.join(ccd_folder, filename), "rb") as fh:
                image_arr = pickle.load(fh)

            X.append(np.array([angles[key] for key in self.ANGLE_KEYS]))
            Y.append(image_arr.flatten())
            ffi_nums.append(ffi_num)

        # NOTE(review): the PIL round trip is kept from the original code; it
        # presumably only serves as a dtype conversion - confirm before replacing.
        self.data = [Image.fromarray(x) for x in X]
        self.labels = [Image.fromarray(y) for y in Y]
        self.ffi_nums = ffi_nums

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return a dict with the context "x", image "y", "ffi_num" and "orbit" for `idx`."""
        ffi_num = self.ffi_nums[idx]
        orbit = self.angles_dic[ffi_num]["orbit"]

        # The transforms are built per call (not stored on self) so the dataset
        # holds no lambdas as attributes.
        transform = transforms.Compose([
            transforms.ToTensor(),
            lambda s: s.reshape(1, len(self.ANGLE_KEYS))
        ])
        target_transform = transforms.Compose([
            lambda s: np.array(s),
            lambda s: s.reshape((16,16)),
            transforms.ToTensor()
        ])

        angles_image = transform(self.data[idx])
        ffi_image = target_transform(self.labels[idx])

        # X: 1x12 vector of angles and distances
        # Y: 16x16 image
        return {"x":angles_image, "y":ffi_image, "ffi_num": ffi_num, "orbit": orbit}


# MAKE DATASET

# we are calculating Y GIVEN X
tess_dataset = TESSDataset('angles_O11-54_data_dic.pkl')
print(len(tess_dataset))

# Sanity check on one example: tensor shapes first, then its identifiers.
example = tess_dataset[1]
for field in ('x', 'y'):
    print(example[field].shape)
for field in ('ffi_num', 'orbit'):
    print(example[field])

# # # plt.plot(tess_dataset[1]['x'][0])
# # print('x:', tess_dataset[1]['x'][0])

# # plt.imshow(tess_dataset[1]['y'][0], vmin=0, vmax=1)
# # plt.colorbar()
# # plt.show()
# # plt.close()

# # # displays all datapoints
# # idx = 0
# # for i in range(len(tess_dataset)):
# #     print('------')
# #     print(idx)
# #     idx += 1
# #     print('x:', tess_dataset[i]['x'][0])
# #     plt.imshow(tess_dataset[i]['y'][0], vmin=0, vmax=1)
# #     plt.colorbar()
# #     plt.show()
# #     plt.close()

# print('x:', tess_dataset[9]['x'][0])
# plt.imshow(tess_dataset[9]['y'][0], vmin=0, vmax=1)
# plt.colorbar()
# plt.show()
# plt.close()

25960
torch.Size([1, 12])
torch.Size([1, 16, 16])
00006869
12


In [None]:
def _save_prediction_grid(x_all, n_sample, orbits_real, ffi_nums_real, ep, out_path):
    """Save a grid figure: one row per datapoint, real image first, then its n_sample predictions."""
    n_cols = n_sample + 1
    n_rows = x_all.shape[0] // n_cols

    # squeeze=False keeps `axes` 2-D even when there is a single row
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 30), squeeze=False)
    fig.subplots_adjust(top=1.7)
    for idx in range(x_all.shape[0]):
        row, col = divmod(idx, n_cols)
        image = x_all[idx, 0, :, :].cpu().detach().numpy()
        axes[row, col].imshow(image, cmap='gray', vmin=0, vmax=1)
        axes[row, col].axis('off')

    # set labels for sampled columns
    for i in range(n_sample):
        axes[0, i+1].set_title(f"Sample {i+1} \n ", fontsize=12)

    # set labels for each datapoint
    for j in range(n_rows):
        data_title = f"O{orbits_real[j]} , ffi {ffi_nums_real[j]}"
        if j==0: data_title = f"Original\n{data_title}"
        axes[j, 0].set_title(data_title, fontsize=12)

    # Sets title for whole figure
    fig.suptitle(f"Predictions for epoch {ep}", fontsize = 25)

    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)


def _save_loss_curve(loss_history, out_path):
    """Save a line plot of the per-epoch loss values to `out_path`."""
    fig, ax = plt.subplots()
    ax.plot(loss_history)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('MSE Loss')
    ax.set_title('MSE Loss over Epochs')
    fig.savefig(out_path)
    plt.close(fig)


def train_TESS(save_dir):
    """Train the conditional diffusion model on the TESS dataset.

    Every 25 epochs (and on the last one) a grid of real vs. generated images
    and the loss curve are written to `save_dir`; the final weights are saved
    there as well.

    Args:
        save_dir: existing directory for all outputs (with or without a
            trailing slash).

    Raises:
        ValueError: if the dataset yields no full batch.
    """
    # hardcoding these here
    n_epoch = 300
    batch_size = 16
    n_T = 600 # 400
    # fall back to the CPU when no GPU is present instead of failing at .to(device)
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    n_feat = 256 # 128 ok, 256 better (but slower)
    lrate = 1e-4
    save_model = True
    # save_dir = './model_O13/'
    dataset_filename = "angles_O11-54_data_dic.pkl"
    n_datapoint = 10  # rows of the evaluation figure
    n_sample = 5      # generated images per row

    ddpm = DDPM(nn_model=ContextUnet(in_channels=1, n_feat=n_feat), betas=(1e-4, 0.02), n_T=n_T, device=device, drop_prob=0.1)
    ddpm.to(device)

    # optionally load a model
    # ddpm.load_state_dict(torch.load("./data/diffusion_outputs/ddpm_unet01_mnist_9.pth"))

    dataset = TESSDataset(dataset_filename)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=5, drop_last=True)
    if len(dataloader) == 0:
        # with drop_last=True fewer than batch_size items means no batch at all
        raise ValueError(f"dataset has {len(dataset)} items, fewer than one batch of {batch_size}")
    optim = torch.optim.Adam(ddpm.parameters(), lr=lrate)

    loss_history = []

    for ep in range(n_epoch):
        print(f'epoch {ep}')
        ddpm.train()

        # linear lrate decay
        optim.param_groups[0]['lr'] = lrate*(1-ep/n_epoch)

        pbar = tqdm(dataloader)
        loss_ema = None

        for data_dic in pbar:
            optim.zero_grad()
            x = data_dic['y'].to(device)
            c = data_dic['x'].to(device)
            ffi_nums = data_dic['ffi_num']
            orbits = data_dic['orbit']

            # FOR THE GIFS MAKE SURE THEY ARE CAPPED AT 0 TO 1. THATS WHY I CANT SEE THEM

            loss = ddpm(x, c)
            loss.backward()
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.95 * loss_ema + 0.05 * loss.item()
            pbar.set_description(f"loss: {loss_ema:.4f}")
            optim.step()

        loss_history.append(loss_ema)

        # for eval, save an image of rows of datapoint predictions. The first column are the real
        # images and the rest are predictions
        if ep%25==0 or ep == int(n_epoch-1):
            ddpm.eval()
            with torch.no_grad():
                # choose the first n_datapoint datapoints of the LAST training batch to do
                # predictions on. The dataloader has shuffle=True so these datapoints are
                # always random
                x_real = x[:n_datapoint]
                c_real = c[:n_datapoint]
                ffi_nums_real = ffi_nums[:n_datapoint]
                orbits_real = orbits[:n_datapoint]
                n_rows = x_real.shape[0]  # may be < n_datapoint if the batch is smaller

                x_gen, x_gen_store = ddpm.sample_c(c_real, n_sample, (1, 16, 16), device)

                # each row: the real image followed by its n_sample predictions
                pieces = []
                for i in range(n_rows):
                    pieces.append(x_real[i:i+1])
                    pieces.append(x_gen[i*n_sample:(i+1)*n_sample])
                x_all = torch.cat(pieces)

                # save images
                image_path = os.path.join(save_dir, f"image_ep{ep}.pdf")
                _save_prediction_grid(x_all, n_sample, orbits_real, ffi_nums_real, ep, image_path)
                print('saved image at ' + image_path)

                # save loss graph
                _save_loss_curve(loss_history, os.path.join(save_dir, 'loss_graph.png'))

                # # create gif of images evolving over time, based on x_gen_store
                # fig, axs = plt.subplots(nrows=int(n_datapoint), ncols=n_sample,sharex=True,sharey=True, figsize=(3, 8))# ,figsize=(8,3))
                # def animate_diff(i, x_gen_store):

                #     print('max and min', np.max(x_gen_store), np.min(x_gen_store))

                #     # x_gen_store_clipped = np.clip(x_gen_store, 0, 1)

                #     print(f'gif animating frame {i} of {x_gen_store.shape[0]}', end='\r')
                #     plots = []
                #     for row in range(n_datapoint):
                #         for col in range(n_sample):
                #             axs[row, col].clear()
                #             axs[row, col].set_xticks([])
                #             axs[row, col].set_yticks([])
                #             # plots.append(axs[row, col].imshow(x_gen_store[i,(row*10)+col,0],cmap='gray'))
                #             plots.append(axs[row, col].imshow(-x_gen_store[i,(row*n_sample)+col,0],cmap='gray',vmin=0, vmax=1))
                #     return plots
                # # ani = FuncAnimation(fig, animate_diff, fargs=[x_gen_store],  interval=200, blit=False, repeat=True, frames=x_gen_store.shape[0])
                # ani = FuncAnimation(fig, animate_diff, fargs=[np.clip(x_gen_store, 0, 1)],  interval=200, blit=False, repeat=True, frames=x_gen_store.shape[0])
                # ani.save(save_dir + f"gif_ep{ep}.gif", dpi=100, writer=PillowWriter(fps=5))
                # print('saved gif at ' + save_dir + f"gif_ep{ep}.gif")

        # optionally save model
        if save_model and ep == int(n_epoch-1):
            model_path = os.path.join(save_dir, f"model_{ep}.pth")
            torch.save(ddpm.state_dict(), model_path)
            print('saved model at ' + model_path)

if __name__ == "__main__":
    # Directory (relative to the current working directory) that receives the
    # prediction figures, the loss plot and the final model checkpoint.
    save_dir = 'model_TESS_O11-54_new/'
    # Create it up front; exist_ok=True makes re-running this cell harmless.
    os.makedirs(save_dir, exist_ok=True)
    train_TESS(save_dir)


epoch 0


loss: 0.0117: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 52.72it/s]



saved image at model_TESS_O11-54_new/image_ep0.pdf
epoch 1


loss: 0.0089: 100%|█████████████████████████| 1622/1622 [00:32<00:00, 49.52it/s]


epoch 2


loss: 0.0059: 100%|█████████████████████████| 1622/1622 [00:34<00:00, 47.67it/s]


epoch 3


loss: 0.0058: 100%|█████████████████████████| 1622/1622 [00:32<00:00, 49.62it/s]


epoch 4


loss: 0.0049: 100%|█████████████████████████| 1622/1622 [00:33<00:00, 48.37it/s]


epoch 5


loss: 0.0031: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 52.54it/s]


epoch 6


loss: 0.0032: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 53.89it/s]


epoch 7


loss: 0.0038: 100%|█████████████████████████| 1622/1622 [00:33<00:00, 49.08it/s]


epoch 8


loss: 0.0030: 100%|█████████████████████████| 1622/1622 [00:31<00:00, 51.54it/s]


epoch 9


loss: 0.0026: 100%|█████████████████████████| 1622/1622 [00:26<00:00, 61.00it/s]


epoch 10


loss: 0.0031: 100%|█████████████████████████| 1622/1622 [00:31<00:00, 51.93it/s]


epoch 11


loss: 0.0038: 100%|█████████████████████████| 1622/1622 [00:28<00:00, 56.33it/s]


epoch 12


loss: 0.0024: 100%|█████████████████████████| 1622/1622 [00:31<00:00, 51.16it/s]


epoch 13


loss: 0.0034: 100%|█████████████████████████| 1622/1622 [00:32<00:00, 50.36it/s]


epoch 14


loss: 0.0023: 100%|█████████████████████████| 1622/1622 [00:29<00:00, 55.26it/s]


epoch 15


loss: 0.0023: 100%|█████████████████████████| 1622/1622 [00:31<00:00, 51.04it/s]


epoch 16


loss: 0.0024: 100%|█████████████████████████| 1622/1622 [00:33<00:00, 48.50it/s]


epoch 17


loss: 0.0024: 100%|█████████████████████████| 1622/1622 [00:29<00:00, 54.33it/s]


epoch 18


loss: 0.0021: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 53.42it/s]


epoch 19


loss: 0.0021: 100%|█████████████████████████| 1622/1622 [00:32<00:00, 50.13it/s]


epoch 20


loss: 0.0028: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 52.89it/s]


epoch 21


loss: 0.0014: 100%|█████████████████████████| 1622/1622 [00:28<00:00, 56.62it/s]


epoch 22


loss: 0.0020: 100%|█████████████████████████| 1622/1622 [00:28<00:00, 56.91it/s]


epoch 23


loss: 0.0016: 100%|█████████████████████████| 1622/1622 [00:28<00:00, 57.09it/s]


epoch 24


loss: 0.0016: 100%|█████████████████████████| 1622/1622 [00:32<00:00, 50.51it/s]


epoch 25


loss: 0.0023: 100%|█████████████████████████| 1622/1622 [00:29<00:00, 54.39it/s]



saved image at model_TESS_O11-54_new/image_ep25.pdf
epoch 26


loss: 0.0019: 100%|█████████████████████████| 1622/1622 [00:31<00:00, 51.79it/s]


epoch 27


loss: 0.0016: 100%|█████████████████████████| 1622/1622 [00:29<00:00, 54.90it/s]


epoch 28


loss: 0.0019: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 53.12it/s]


epoch 29


loss: 0.0026: 100%|█████████████████████████| 1622/1622 [00:26<00:00, 60.33it/s]


epoch 30


loss: 0.0012: 100%|█████████████████████████| 1622/1622 [00:32<00:00, 50.31it/s]


epoch 31


loss: 0.0019: 100%|█████████████████████████| 1622/1622 [00:32<00:00, 49.63it/s]


epoch 32


loss: 0.0014: 100%|█████████████████████████| 1622/1622 [00:32<00:00, 49.65it/s]


epoch 33


loss: 0.0026: 100%|█████████████████████████| 1622/1622 [00:33<00:00, 48.01it/s]


epoch 34


loss: 0.0014: 100%|█████████████████████████| 1622/1622 [00:31<00:00, 51.34it/s]


epoch 35


loss: 0.0017: 100%|█████████████████████████| 1622/1622 [00:29<00:00, 55.24it/s]


epoch 36


loss: 0.0012: 100%|█████████████████████████| 1622/1622 [00:32<00:00, 50.18it/s]


epoch 37


loss: 0.0015: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 53.95it/s]


epoch 38


loss: 0.0016: 100%|█████████████████████████| 1622/1622 [00:32<00:00, 49.36it/s]


epoch 39


loss: 0.0014: 100%|█████████████████████████| 1622/1622 [00:32<00:00, 50.00it/s]


epoch 40


loss: 0.0013: 100%|█████████████████████████| 1622/1622 [00:28<00:00, 56.61it/s]


epoch 41


loss: 0.0022: 100%|█████████████████████████| 1622/1622 [00:32<00:00, 50.43it/s]


epoch 42


loss: 0.0019: 100%|█████████████████████████| 1622/1622 [00:31<00:00, 51.07it/s]


epoch 43


loss: 0.0015: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 54.00it/s]


epoch 44


loss: 0.0012: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 54.03it/s]


epoch 45


loss: 0.0023: 100%|█████████████████████████| 1622/1622 [00:31<00:00, 51.37it/s]


epoch 46


loss: 0.0013: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 52.89it/s]


epoch 47


loss: 0.0013: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 52.45it/s]


epoch 48


loss: 0.0015: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 53.00it/s]


epoch 49


loss: 0.0020: 100%|█████████████████████████| 1622/1622 [00:29<00:00, 54.27it/s]


epoch 50


loss: 0.0011: 100%|█████████████████████████| 1622/1622 [00:29<00:00, 54.20it/s]



saved image at model_TESS_O11-54_new/image_ep50.pdf
epoch 51


loss: 0.0012: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 52.52it/s]


epoch 52


loss: 0.0012: 100%|█████████████████████████| 1622/1622 [00:32<00:00, 50.57it/s]


epoch 53


loss: 0.0015: 100%|█████████████████████████| 1622/1622 [00:28<00:00, 57.16it/s]


epoch 54


loss: 0.0012: 100%|█████████████████████████| 1622/1622 [00:34<00:00, 46.94it/s]


epoch 55


loss: 0.0010: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 52.74it/s]


epoch 56


loss: 0.0011: 100%|█████████████████████████| 1622/1622 [00:27<00:00, 58.34it/s]


epoch 57


loss: 0.0014: 100%|█████████████████████████| 1622/1622 [00:27<00:00, 59.25it/s]


epoch 58


loss: 0.0015: 100%|█████████████████████████| 1622/1622 [00:31<00:00, 51.23it/s]


epoch 59


loss: 0.0010: 100%|█████████████████████████| 1622/1622 [00:33<00:00, 48.54it/s]


epoch 60


loss: 0.0011: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 52.64it/s]


epoch 61


loss: 0.0010: 100%|█████████████████████████| 1622/1622 [00:32<00:00, 49.59it/s]


epoch 62


loss: 0.0011: 100%|█████████████████████████| 1622/1622 [00:31<00:00, 50.90it/s]


epoch 63


loss: 0.0019: 100%|█████████████████████████| 1622/1622 [00:32<00:00, 50.33it/s]


epoch 64


loss: 0.0010: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 53.49it/s]


epoch 65


loss: 0.0010: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 52.54it/s]


epoch 66


loss: 0.0013: 100%|█████████████████████████| 1622/1622 [00:33<00:00, 48.54it/s]


epoch 67


loss: 0.0010: 100%|█████████████████████████| 1622/1622 [00:31<00:00, 50.98it/s]


epoch 68


loss: 0.0010: 100%|█████████████████████████| 1622/1622 [00:31<00:00, 51.44it/s]


epoch 69


loss: 0.0008: 100%|█████████████████████████| 1622/1622 [00:32<00:00, 50.37it/s]


epoch 70


loss: 0.0009: 100%|█████████████████████████| 1622/1622 [00:32<00:00, 50.47it/s]


epoch 71


loss: 0.0013: 100%|█████████████████████████| 1622/1622 [00:31<00:00, 51.26it/s]


epoch 72


loss: 0.0009: 100%|█████████████████████████| 1622/1622 [00:32<00:00, 49.32it/s]


epoch 73


loss: 0.0010: 100%|█████████████████████████| 1622/1622 [00:31<00:00, 51.70it/s]


epoch 74


loss: 0.0013: 100%|█████████████████████████| 1622/1622 [00:31<00:00, 51.72it/s]


epoch 75


loss: 0.0009: 100%|█████████████████████████| 1622/1622 [00:29<00:00, 54.81it/s]



saved image at model_TESS_O11-54_new/image_ep75.pdf
epoch 76


loss: 0.0008: 100%|█████████████████████████| 1622/1622 [00:31<00:00, 51.57it/s]


epoch 77


loss: 0.0015: 100%|█████████████████████████| 1622/1622 [00:32<00:00, 49.65it/s]


epoch 78


loss: 0.0009: 100%|█████████████████████████| 1622/1622 [00:32<00:00, 49.92it/s]


epoch 79


loss: 0.0009: 100%|█████████████████████████| 1622/1622 [00:33<00:00, 48.78it/s]


epoch 80


loss: 0.0008: 100%|█████████████████████████| 1622/1622 [00:33<00:00, 47.87it/s]


epoch 81


loss: 0.0012: 100%|█████████████████████████| 1622/1622 [00:31<00:00, 51.00it/s]


epoch 82


loss: 0.0008: 100%|█████████████████████████| 1622/1622 [00:32<00:00, 49.94it/s]


epoch 83


loss: 0.0008: 100%|█████████████████████████| 1622/1622 [00:33<00:00, 47.84it/s]


epoch 84


loss: 0.0009: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 52.46it/s]


epoch 85


loss: 0.0007: 100%|█████████████████████████| 1622/1622 [00:33<00:00, 49.03it/s]


epoch 86


loss: 0.0009: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 53.61it/s]


epoch 87


loss: 0.0007: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 52.86it/s]


epoch 88


loss: 0.0009: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 52.60it/s]


epoch 89


loss: 0.0006: 100%|█████████████████████████| 1622/1622 [00:32<00:00, 50.05it/s]


epoch 90


loss: 0.0007: 100%|█████████████████████████| 1622/1622 [00:31<00:00, 51.36it/s]


epoch 91


loss: 0.0024: 100%|█████████████████████████| 1622/1622 [00:33<00:00, 48.67it/s]


epoch 92


loss: 0.0013: 100%|█████████████████████████| 1622/1622 [00:33<00:00, 48.93it/s]


epoch 93


loss: 0.0006: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 53.56it/s]


epoch 94


loss: 0.0005: 100%|█████████████████████████| 1622/1622 [00:32<00:00, 49.26it/s]


epoch 95


loss: 0.0007: 100%|█████████████████████████| 1622/1622 [00:28<00:00, 56.57it/s]


epoch 96


loss: 0.0005: 100%|█████████████████████████| 1622/1622 [00:34<00:00, 47.25it/s]


epoch 97


loss: 0.0011: 100%|█████████████████████████| 1622/1622 [00:32<00:00, 50.56it/s]


epoch 98


loss: 0.0007: 100%|█████████████████████████| 1622/1622 [00:31<00:00, 51.07it/s]


epoch 99


loss: 0.0010: 100%|█████████████████████████| 1622/1622 [00:33<00:00, 48.05it/s]


epoch 100


loss: 0.0008: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 53.23it/s]



saved image at model_TESS_O11-54_new/image_ep100.pdf
epoch 101


loss: 0.0009: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 52.87it/s]


epoch 102


loss: 0.0009: 100%|█████████████████████████| 1622/1622 [00:28<00:00, 57.40it/s]


epoch 103


loss: 0.0010: 100%|█████████████████████████| 1622/1622 [00:29<00:00, 54.48it/s]


epoch 104


loss: 0.0009: 100%|█████████████████████████| 1622/1622 [00:29<00:00, 54.76it/s]


epoch 105


loss: 0.0008: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 52.47it/s]


epoch 106


loss: 0.0008: 100%|█████████████████████████| 1622/1622 [00:33<00:00, 48.72it/s]


epoch 107


loss: 0.0009: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 53.18it/s]


epoch 108


loss: 0.0012: 100%|█████████████████████████| 1622/1622 [00:34<00:00, 46.88it/s]


epoch 109


loss: 0.0008: 100%|█████████████████████████| 1622/1622 [00:33<00:00, 48.68it/s]


epoch 110


loss: 0.0010: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 53.91it/s]


epoch 111


loss: 0.0008: 100%|█████████████████████████| 1622/1622 [00:32<00:00, 50.41it/s]


epoch 112


loss: 0.0009: 100%|█████████████████████████| 1622/1622 [00:31<00:00, 51.35it/s]


epoch 113


loss: 0.0007: 100%|█████████████████████████| 1622/1622 [00:32<00:00, 50.19it/s]


epoch 114


loss: 0.0008: 100%|█████████████████████████| 1622/1622 [00:33<00:00, 48.08it/s]


epoch 115


loss: 0.0008: 100%|█████████████████████████| 1622/1622 [00:27<00:00, 59.13it/s]


epoch 116


loss: 0.0008: 100%|█████████████████████████| 1622/1622 [00:31<00:00, 52.15it/s]


epoch 117


loss: 0.0006: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 53.36it/s]


epoch 118


loss: 0.0006: 100%|█████████████████████████| 1622/1622 [00:26<00:00, 60.38it/s]


epoch 119


loss: 0.0006: 100%|█████████████████████████| 1622/1622 [00:29<00:00, 54.72it/s]


epoch 120


loss: 0.0009: 100%|█████████████████████████| 1622/1622 [00:31<00:00, 51.91it/s]


epoch 121


loss: 0.0006: 100%|█████████████████████████| 1622/1622 [00:31<00:00, 51.17it/s]


epoch 122


loss: 0.0010: 100%|█████████████████████████| 1622/1622 [00:31<00:00, 51.00it/s]


epoch 123


loss: 0.0008: 100%|█████████████████████████| 1622/1622 [00:24<00:00, 67.32it/s]


epoch 124


loss: 0.0006: 100%|█████████████████████████| 1622/1622 [00:28<00:00, 56.89it/s]


epoch 125


loss: 0.0006: 100%|█████████████████████████| 1622/1622 [00:29<00:00, 54.35it/s]


epoch 216


loss: 0.0004: 100%|█████████████████████████| 1622/1622 [00:25<00:00, 62.89it/s]


epoch 217


loss: 0.0005:  45%|███████████▊              | 738/1622 [00:13<00:17, 51.59it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

loss: 0.0006: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 52.46it/s]


epoch 238


loss: 0.0005: 100%|█████████████████████████| 1622/1622 [00:31<00:00, 51.98it/s]


epoch 239


loss: 0.0004: 100%|█████████████████████████| 1622/1622 [00:26<00:00, 61.48it/s]


epoch 240


loss: 0.0004: 100%|█████████████████████████| 1622/1622 [00:27<00:00, 58.83it/s]


epoch 241


loss: 0.0004: 100%|█████████████████████████| 1622/1622 [00:26<00:00, 60.09it/s]


epoch 242


loss: 0.0004:  39%|██████████▏               | 633/1622 [00:09<00:18, 54.08it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

loss: 0.0005: 100%|█████████████████████████| 1622/1622 [00:24<00:00, 65.99it/s]


epoch 255


loss: 0.0004: 100%|█████████████████████████| 1622/1622 [00:31<00:00, 51.08it/s]


epoch 256


loss: 0.0005: 100%|█████████████████████████| 1622/1622 [00:28<00:00, 56.39it/s]


epoch 258


loss: 0.0004:  69%|█████████████████▏       | 1119/1622 [00:21<00:10, 47.58it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

loss: 0.0004: 100%|█████████████████████████| 1622/1622 [00:27<00:00, 59.58it/s]


epoch 273


loss: 0.0004: 100%|█████████████████████████| 1622/1622 [00:31<00:00, 51.34it/s]


epoch 274


loss: 0.0003: 100%|█████████████████████████| 1622/1622 [00:27<00:00, 59.27it/s]


epoch 275


loss: 0.0004: 100%|█████████████████████████| 1622/1622 [00:26<00:00, 60.17it/s]



saved image at model_TESS_O11-54_new/image_ep275.pdf
epoch 276


loss: 0.0003: 100%|█████████████████████████| 1622/1622 [00:29<00:00, 55.79it/s]


epoch 277


loss: 0.0004:   5%|█▎                         | 81/1622 [00:02<00:30, 51.04it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

loss: 0.0004: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 53.83it/s]


epoch 289


loss: 0.0004: 100%|█████████████████████████| 1622/1622 [00:30<00:00, 53.26it/s]


epoch 290


loss: 0.0004: 100%|█████████████████████████| 1622/1622 [00:29<00:00, 54.85it/s]


epoch 291


loss: 0.0004:  59%|███████████████▎          | 953/1622 [00:19<00:13, 51.01it/s]