In [None]:
# this implicitly imports numpy, pandas, etc
from fastai.basics import *
from fastai.vision.all import *

In [None]:
from unet import Unet
from copy import deepcopy

In [None]:
import time
import torch
import torch.nn.functional as F
from torch.nn.parallel import DataParallel
from matplotlib import pyplot as plt
from PIL import Image

# Pick the compute device: the default CUDA GPU when one is visible, otherwise the CPU.
# (Apple-silicon users may want torch.device("mps") instead - untested.)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# GPU inventory / utilisation. Needs the NVIDIA driver; on machines without one this
# shell command will likely just print an error.
!nvidia-smi

In [None]:
# Confirm which device was selected above (cuda or cpu).
device

In [None]:
# Number of visible CUDA GPUs; the DataParallel wrapper used below spreads batches across them.
torch.cuda.device_count()

In [None]:
# CUDA version this PyTorch build was compiled against (None for CPU-only builds).
print(torch.version.cuda)

In [None]:
import torchvision
from torch.utils.data import Dataset
from torchvision import transforms

# --- Configuration ----------------------------------------------------------
image_size = 64               # side length (pixels) every image is resized to
sample_percentage = 10        # keep only the first N% of the filtered rows
valid_bptclasses = [1, 2, 3]  # BPT classes retained in the sample
global_num_quantiles = 6      # not referenced in the cells below

# --- Load and filter the tabular catalogue ----------------------------------
tab_df = pd.read_csv("agn_cleaned.csv")

keep_rows = (
    (tab_df['oh_p50'] > 0)
    & (tab_df['lgm_tot_p50'] > 0)
    & (tab_df['sfr_tot_p50'] > -10)
    & tab_df['bptclass'].isin(valid_bptclasses)
)
tab_df = tab_df[keep_rows]

# Deterministic (not random) subsample: the first `sample_percentage` % of rows, by position.
n_keep = int(sample_percentage / 100 * tab_df.shape[0])
tab_df = tab_df.iloc[:n_keep]
num_samples = tab_df.shape[0]

# Row-aligned object IDs; used to locate each galaxy's image file.
objIDs = tab_df['objID']

In [None]:
# Summary statistics of the filtered, subsampled catalogue.
tab_df.describe()

In [None]:
# Property matrix of shape (N, 4) with a fixed column order:
# 0 = z, 1 = lgm_tot_p50, 2 = sfr_tot_p50, 3 = oh_p50 (metallicity).
property_columns = ['z', 'lgm_tot_p50', 'sfr_tot_p50', 'oh_p50']
tab_dataset = np.column_stack([tab_df[col].values for col in property_columns])

# NOTE: do not recover object IDs from tab_dataset; use `objIDs` instead.
print(tab_dataset.shape)

In [None]:
# Clip each property at mean +/- 4 std, then standardise using the *pre-clip* mean/std.
# NOTE: tab_dataset is clipped in place and then re-bound to its standardised version,
# so re-running this cell standardises already-standardised data. Run it once per kernel.
mean = tab_dataset.mean(axis=0)
std = tab_dataset.std(axis=0)

lower_bound = mean - 4 * std
upper_bound = mean + 4 * std

# The per-column bounds broadcast across rows; out= keeps the clip in place.
np.clip(tab_dataset, lower_bound, upper_bound, out=tab_dataset)

tab_dataset = (tab_dataset - mean) / std

In [None]:
# Load every catalogue image as uint8 RGB of shape (num_samples, image_size, image_size, 3).
# Pixels stay in 0..255 here. Scaling happens later in the DataBlock via
# Normalize.from_stats(0.5, 0.5), presumably after ImageBlock has converted uint8 to
# [0, 1] floats. The buffer is allocated as uint8 directly because that pipeline expects
# uint8 arrays; building a float64 buffer and casting afterwards only wasted 8x memory.
im_dataset = np.zeros((num_samples, image_size, image_size, 3), dtype=np.uint8)

for index, objID in enumerate(objIDs):
    image_path = f"images-sdss/{objID}.jpg"
    # `with` releases the file handle. convert("RGB") guards against grayscale/CMYK
    # JPEGs, which would not fit the 3-channel slot.
    with Image.open(image_path) as img:
        im_dataset[index] = img.convert("RGB").resize((image_size, image_size))

im_dataset.shape, tab_dataset.shape

In [None]:
# Column views of the standardised property matrix (column order fixed when it was built).
z_values, lgm_tot_p50_values, sfr_tot_p50_values, oh_p50_values = tab_dataset.T

# One histogram panel per property in a 2x2 grid.
fig, axs = plt.subplots(2, 2, figsize=(10,10))
panels = [
    (axs[0, 0], z_values, 'z values'),
    (axs[0, 1], lgm_tot_p50_values, 'lgm_tot_p50 values'),
    (axs[1, 0], sfr_tot_p50_values, 'sfr_tot_p50 values'),
    (axs[1, 1], oh_p50_values, 'oh_p50 values'),
]
for panel_ax, panel_values, panel_title in panels:
    panel_ax.hist(panel_values, bins='auto', color='#0504aa', alpha=0.7, rwidth=0.85)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('Value')
    panel_ax.set_ylabel('Frequency')

plt.tight_layout()
plt.show()

In [None]:
# Mean and (population) std of the standardised oh_p50 column; these should be close
# to 0 and 1, though not exact because values were clipped before standardising.
average_oh_p50, std_oh_p50 = oh_p50_values.mean(), oh_p50_values.std()

average_oh_p50, std_oh_p50

In [None]:
# First rows of the clipped and standardised property matrix.
tab_dataset[:10]

In [None]:
# Reproducible 80% / 20% train-validation split over row indices.
np.random.seed(42)

N = len(im_dataset)
n_train = int(0.8 * N)
indices = np.random.permutation(N)
train_idxs, valid_idxs = indices[:n_train], indices[n_train:]

In [None]:
batch_size = 128               # items per batch
epochs = 100                   # epochs for fit_one_cycle
image_display_interval = 20    # ConditionalDDPMCallback logs samples every N epochs

def get_x(i):
    "Image for catalogue row `i` (uint8, image_size x image_size x 3)."
    return im_dataset[i]

def get_y(i):
    "Standardised 4-property vector for catalogue row `i`."
    return tab_dataset[i]

# Per-item resize (images were already resized on load, so this is a safeguard) and
# per-batch normalisation (x - 0.5) / 0.5.
resize_tfm = Resize(image_size)
normalize_tfm = Normalize.from_stats(0.5, 0.5)
# Optional augmentation, currently disabled:
#batch_tfms=[Normalize.from_stats(0.5, 0.5), *aug_transforms(do_flip=True, flip_vert=True, max_rotate=10.0, max_zoom=1.1, max_lighting=0.2, max_warp=0.2, p_affine=0.75, p_lighting=0.75, xtra_tfms=None)]

# Items are integer row indices; the split reuses valid_idxs from the split cell.
dblock = DataBlock(
    blocks=(ImageBlock, RegressionBlock),
    get_x=get_x,
    get_y=get_y,
    splitter=IndexSplitter(valid_idxs),
    item_tfms=resize_tfm,
    batch_tfms=normalize_tfm,
)

In [None]:
# create dataloaders: the items are the integer row indices 0..N-1, which
# get_x / get_y map to an image and a property vector respectively.
dls = dblock.dataloaders(range(len(im_dataset)), bs=batch_size)

In [None]:
# Sanity check on one normalised training batch: pixel max/min/mean/std
# (expected to lie roughly within [-1, 1] after Normalize.from_stats(0.5, 0.5)).
xb, yb = dls.one_batch()
xb.max(), xb.min(), xb.mean(), xb.std()

In [None]:
# Shape of the property targets; presumably (batch_size, 4) - confirm.
print(yb.shape)

In [None]:
class ConditionalDDPMCallback(Callback):
    """Turn a fastai training loop into conditional DDPM training and sampling.

    Training/validation batches ``(x0, y0)`` are replaced by ``(xt, t, y0)`` with
    target ``eps``. Here ``xt`` is ``x0`` noised to a random timestep ``t`` and ``eps``
    is the noise that was added, so the model learns to predict noise. ``y0`` is
    dropped (set to ``None``) for roughly 10% of batches so the same model also learns
    the unconditional prediction used by classifier-free guidance. While predictions
    are being gathered (``Learner.get_preds``), each batch instead runs the full reverse
    chain and yields generated images.

    Args:
        n_steps: number of diffusion timesteps.
        beta_min, beta_max: endpoints of the linear variance schedule.
        num_conditioned_properties: length of the conditioning vector.
        targets: stored as ``self.targets``; ``sample()`` overwrites it with the
            random conditions it draws.
        cfg_scale: guidance scale; values ``<= 0`` skip the unconditional pass and
            use the conditional noise prediction directly.
        sort_column: column of the random conditions used to order generated
            samples (3 = metallicity / oh_p50 in this notebook). ``None``, or a
            column index outside the conditioning vector, disables sorting.
    """
    def __init__(self, n_steps, beta_min, beta_max, num_conditioned_properties, targets, cfg_scale=0, sort_column=3):
        store_attr()
        self.tensor_type=TensorImage

    def before_fit(self):
        "Precompute the variance schedule and derived terms on the data device."
        self.beta = torch.linspace(self.beta_min, self.beta_max, self.n_steps).to(self.dls.device) # variance schedule, linearly increased with timestep
        self.alpha = 1. - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        self.sigma = torch.sqrt(self.beta)  # std of the noise injected at each reverse step

    def sample_timesteps(self, x, dtype=torch.long):
        "One uniformly random timestep in [0, n_steps) per item of `x`."
        return torch.randint(self.n_steps, (x.shape[0],), device=x.device, dtype=dtype)

    def generate_noise(self, x):
        "Standard-normal noise shaped like `x`, wrapped as `self.tensor_type`."
        return self.tensor_type(torch.randn_like(x))

    def noise_image(self, x, eps, t):
        "Forward process: sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * eps."
        alpha_bar_t = self.alpha_bar[t][:, None, None, None]  # (B, 1, 1, 1) to broadcast over image dims
        return torch.sqrt(alpha_bar_t)*x + torch.sqrt(1-alpha_bar_t)*eps # noisify the image

    def before_batch_training(self):
        "Replace the batch by (noisy image, timestep, condition) and make the noise the target."
        x0 = self.xb[0] # original images
        # Drop the condition for ~10% of batches so the model also learns the
        # unconditional noise prediction needed for classifier-free guidance.
        y0 = self.yb[0] if np.random.random() > 0.1 else None
        eps = self.generate_noise(x0)       # noise, same shape as x0
        t = self.sample_timesteps(x0)       # one random timestep per image
        xt = self.noise_image(x0, eps, t)   # forward-diffuse x0 to step t
        self.learn.xb = (xt, t, y0) # input to our model is noisy image, timestep and condition
        self.learn.yb = (eps,) # ground truth is the noise

    def sampling_algo(self, xt, t, train_targets=None):
        """One reverse step x_t -> x_{t-1}.

        Predicts the noise (with classifier-free guidance when ``cfg_scale > 0``) and
        reconstructs a clamped estimate of x_0. It then samples x_{t-1} around the
        DDPM posterior mean, using ``sigma_t`` as the standard deviation. No noise is
        added at t == 0.
        """
        t_batch = torch.full((xt.shape[0],), t, device=xt.device, dtype=torch.long)
        z = self.generate_noise(xt) if t > 0 else torch.zeros_like(xt)
        alpha_t = self.alpha[t] # get noise level at current timestep
        alpha_bar_t = self.alpha_bar[t]
        sigma_t = self.sigma[t]
        # alpha_bar_{t-1}, with alpha_bar_{-1} := 1 (explicitly floating point)
        alpha_bar_t_1 = self.alpha_bar[t-1] if t > 0 else torch.tensor(1., device=xt.device)
        beta_bar_t = 1 - alpha_bar_t
        beta_bar_t_1 = 1 - alpha_bar_t_1
        predicted_noise = self.model(xt, t_batch, targets=train_targets)
        if self.cfg_scale>0:
            # Classifier-free guidance: uncond + cfg_scale * (cond - uncond)
            uncond_predicted_noise = self.model(xt, t_batch, targets=None)
            predicted_noise = torch.lerp(uncond_predicted_noise, predicted_noise, self.cfg_scale)
        x0hat = (xt - torch.sqrt(beta_bar_t) * predicted_noise)/torch.sqrt(alpha_bar_t)
        x0hat = torch.clamp(x0hat, -1, 1)
        # Posterior mean: sqrt(abar_{t-1}) * beta_t / (1 - abar_t) * x0hat
        #               + sqrt(alpha_t) * (1 - abar_{t-1}) / (1 - abar_t) * xt
        xt = x0hat * torch.sqrt(alpha_bar_t_1)*(1-alpha_t)/beta_bar_t + xt * torch.sqrt(alpha_t)*beta_bar_t_1/beta_bar_t + sigma_t*z

        return xt

    def sample(self):
        "Generate one batch of images from random N(0, 1) conditions via the full reverse chain."
        xt = self.generate_noise(self.xb[0]) # a full batch at once, shaped like the current input
        # Random conditions, one row per image. They are created on the images' device so
        # the conditioning MLP sees a same-device input even without DataParallel scattering.
        self.targets = torch.randn(xt.shape[0], self.num_conditioned_properties, device=xt.device)
        # Order samples by one property (metallicity by default) for easier visual comparison.
        if self.sort_column is not None and 0 <= self.sort_column < self.num_conditioned_properties:
            self.targets = self.targets[self.targets[:, self.sort_column].argsort()]
        for t in progress_bar(reversed(range(self.n_steps)), total=self.n_steps, leave=False):
            xt = self.sampling_algo(xt, t, self.targets)
        return xt

    def before_batch_sampling(self):
        "During get_preds: use generated images as the prediction and skip the normal forward pass."
        xt = self.sample()
        self.learn.pred = (xt,)
        raise CancelBatchException

    def after_validate(self):
        "Every `image_display_interval` epochs, log up to 36 freshly generated samples to W&B."
        if (self.epoch+1) % image_display_interval == 0:
            with torch.no_grad():
                xt = self.sample()
                wandb.log({"preds": [wandb.Image(torch.tensor(im)) for im in xt[0:36]]})

    def before_batch(self):
        # `gather_preds` is presumably only reachable while Learner.get_preds has its
        # gathering callback attached, i.e. when sampling rather than training.
        if not hasattr(self, 'gather_preds'): self.before_batch_training()
        else: self.before_batch_sampling()

In [None]:
class EMA(Callback):
    """Exponential moving average (EMA) of the model weights.

    Keeps a frozen copy of the model. After every *training* batch, each of its
    parameters is updated as ``beta * ema + (1 - beta) * current``. Averaging starts
    once ``pct_start`` of the epochs have elapsed; before that, the copy simply mirrors
    the live weights. At the end of ``fit`` the EMA weights are copied into the
    learner's model.

    The copy is done with ``load_state_dict`` rather than swapping the module object,
    so the optimizer keeps pointing at the trained parameters. NOTE(review): callbacks
    with a higher ``order`` that reload weights in their own ``after_fit`` (e.g.
    SaveModelCallback loading its best checkpoint) run afterwards and will overwrite
    the EMA weights. The EMA copy remains available as ``self.ema_model``.
    """
    def __init__(self, beta=0.995, pct_start=0.3):
        store_attr()

    def before_fit(self):
        self.ema_model = deepcopy(self.model).eval().requires_grad_(False)
        self.step_start_ema = int(self.pct_start*self.n_epoch)  # first epoch at which averaging starts
        self.step = 0  # number of EMA updates/resets performed during this fit

    def update_model_average(self):
        "Blend the live parameters into the EMA copy."
        for current_params, ma_params in zip(self.model.parameters(), self.ema_model.parameters()):
            ma_params.data = self.update_average(ma_params.data, current_params.data)

    def update_average(self, old, new):
        return old * self.beta + (1 - self.beta) * new

    def step_ema(self):
        "Before `step_start_ema` mirror the live weights; afterwards average them in."
        if self.epoch < self.step_start_ema:
            self.reset_parameters()
        else:
            self.update_model_average()
        self.step += 1

    def reset_parameters(self):
        self.ema_model.load_state_dict(self.model.state_dict())

    def after_batch(self):
        # Only training batches update the EMA. Validation and get_preds (sampling)
        # batches run with `training` False and must leave it untouched.
        if not self.training: return
        self.step_ema()

    def after_fit(self):
        "Load the averaged weights into the learner's model in place."
        self.model.load_state_dict(self.ema_model.state_dict())

    # Kept for backward compatibility; fastai has no `after_training` event, so only
    # `after_fit` is actually triggered.
    after_training = after_fit

In [None]:
@delegates(Unet)
class ConditionalUnet(Unet):
    """`Unet` whose timestep embedding is shifted by a learned projection of a condition vector.

    Args:
        dim: base width forwarded to `Unet`. The condition projection outputs
            ``dim * 4`` features, which is assumed to match the width of
            ``time_mlp``'s output (TODO confirm against the Unet implementation).
        num_conditioned_properties: length of the conditioning vector. ``None``
            builds an unconditional model with no projection layer.
        **kwargs: forwarded to `Unet.__init__`.
    """
    def __init__(self, dim, num_conditioned_properties=None, **kwargs):
        super().__init__(dim=dim, **kwargs)
        if num_conditioned_properties is not None:
            self.target_MLP = nn.Linear(num_conditioned_properties, dim * 4)

    def forward(self, x, time, targets=None):
        """Predict noise for images `x` at timesteps `time`.

        ``targets=None`` gives the unconditional prediction, which is used for
        classifier-free guidance.
        """
        x = self.init_conv(x)
        t = self.time_mlp(time)
        if targets is not None:
            if not hasattr(self, 'target_MLP'):
                raise ValueError("targets were given but this ConditionalUnet was built "
                                 "without num_conditioned_properties")
            # Out-of-place add: avoids mutating a tensor that autograd may have saved.
            t = t + self.target_MLP(targets)

        return super().forward_blocks(x, t)

In [None]:
# Conditional U-Net on 3-channel images with a 4-property condition vector. It is moved
# to `device` and wrapped in DataParallel, which on a CPU-only machine simply calls the
# wrapped module.
model = DataParallel(
    ConditionalUnet(dim=128, channels=3, num_conditioned_properties=4).to(device)
)

In [None]:
import wandb
from fastai.callback.wandb import WandbCallback

In [None]:
num_timesteps = 1000

# DDPM noising/sampling: linear beta schedule 1e-4 -> 0.02, classifier-free guidance scale 3.
# `targets` is only an initial value; sampling replaces it with random N(0, 1) conditions.
ddpm_cb = ConditionalDDPMCallback(n_steps=num_timesteps, beta_min=0.0001, beta_max=0.02,
                                  num_conditioned_properties=4, targets=tab_dataset, cfg_scale=3)

# The model predicts the added noise and is trained with an L1 loss against it.
ddpm_learner = Learner(dls, model, cbs=[ddpm_cb, EMA()], loss_func=nn.L1Loss())

In [None]:
# Learning-rate range test. NOTE: the ConditionalDDPMCallback and EMA callbacks
# attached to the Learner are active during this run as well.
ddpm_learner.lr_find()

In [None]:
# Start the Weights & Biases run. WandbCallback and ConditionalDDPMCallback.after_validate log to it.
wandb.init(project="cond_ddpm_sdss", group="chesapeake_ml", tags=["ddpm", "ema"])

In [None]:
# Train with the one-cycle policy (max lr 1e-4). SaveModelCallback monitors train_loss
# and saves under the name "cond_ddpm_sdss"; WandbCallback logs metrics and the model.
ddpm_learner.fit_one_cycle(epochs, 1e-4, cbs =[SaveModelCallback(monitor="train_loss", fname="cond_ddpm_sdss"), 
                                           WandbCallback(log_preds=False, log_model=True)])

In [None]:
# Loss curves recorded during fit.
ddpm_learner.recorder.plot_loss()

In [None]:
# Generate images. While predictions are gathered, ConditionalDDPMCallback runs the full
# n_steps reverse chain once per validation batch, so this cell is slow.
preds = ddpm_learner.get_preds()

In [None]:
# Preview the first generated sample; 0.5*x + 0.5 maps the model's [-1, 1] output to [0, 1].
wandb.Image(torch.tensor(0.5*preds[0][0]+0.5)).image

In [None]:
# The generated images (first element returned by get_preds).
p = preds[0]

In [None]:
# Presumably (num_valid_items, 3, image_size, image_size) - confirm.
p.shape

In [None]:
# Per-channel mean over samples and spatial positions (rough colour-balance check).
p.mean(dim=(0,2,3))

In [None]:
# Show up to 50 generated samples in a 10-column grid. (pred + 1) / 2 maps the
# model's [-1, 1] output to [0, 1] for display; panels without a sample are hidden.
ncols = 10
n_show = min(len(preds[0]), 50)
nrows = max(1, math.ceil(n_show / ncols))
axs = subplots(nrows, ncols)[1].flatten()
for pred, ax in zip(preds[0][:n_show], axs):
    ((pred + 1) / 2).show(ax=ax, title=None)
for ax in axs[n_show:]:
    ax.axis('off')