In [None]:
import logger
import fp16_util
import unet
import gaussian_diffusion
import script_util
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import torchvision
import os
import matplotlib.pyplot as plt
import torchvision.models as models
from tqdm import tqdm
import os
# Synchronous CUDA kernel launches make CUDA error tracebacks point at the failing op
# (a debugging aid that slows execution). Set it before anything queries CUDA.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# Use the first GPU when available; otherwise fall back to CPU so the notebook still runs.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# NOTE(review): torch.load unpickles whole nn.Module objects, which can execute arbitrary
# code, so only load trusted checkpoints. On PyTorch versions where `weights_only` defaults
# to True, loading a full pickled module additionally needs `weights_only=False`.
# map_location puts the weights on `device` regardless of where they were saved.
# eval() disables training-only behaviour (e.g. dropout), because both networks are used
# only for inference below: the UNet under no_grad, the classifier for input gradients.
model = torch.load('unet/model/path', map_location=device).eval()
classifier = torch.load('classifier/model/path', map_location=device).eval()

In [None]:
# Each input image is a horizontal strip of three 256x256 patches (left | middle | right),
# which the cells below split with torch.chunk(..., chunks=3, dim=3).
# Resizing to 256x768 keeps every patch 256x256. Resizing the whole strip to 256x256 would
# give patch widths 86/86/84, and the channel-wise torch.cat of the patches would fail.
transform = transforms.Compose([
    transforms.Resize((256, 3 * 256)),
    transforms.ToTensor(),
    # ImageNet per-channel mean/std normalization.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
import torch.nn.functional as F
import math

def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    """Return ``timesteps`` evenly spaced betas from ``start`` to ``end`` (both inclusive)."""
    schedule = torch.linspace(start=start, end=end, steps=timesteps)
    return schedule

def betas_cos(n_steps, max_beta=0.999):
    """Cosine noise schedule.

    beta_i = min(1 - abar((i + 1) / n_steps) / abar(i / n_steps), max_beta), where
    abar(s) = cos((s + 0.008) / 1.008 * pi / 2) ** 2.

    Returns:
        A float32 tensor with ``n_steps`` entries.
    """
    def alpha_bar(s):
        return math.cos((s + 0.008) / 1.008 * math.pi / 2) ** 2

    return torch.Tensor([
        min(1 - alpha_bar((step + 1) / n_steps) / alpha_bar(step / n_steps), max_beta)
        for step in range(n_steps)
    ])



def get_index_from_list(vals, t, x_shape):
    """Gather per-timestep schedule values and reshape them for broadcasting.

    Args:
        vals: 1-D tensor indexed by timestep (e.g. ``sqrt_alphas_cumprod``).
        t: 1-D integer (long) tensor of timestep indices, one per batch element.
        x_shape: shape of the tensor the result will be multiplied with; only its
            length (rank) is used.

    Returns:
        Tensor of shape ``(len(t), 1, ..., 1)`` with ``len(x_shape)`` dimensions,
        placed on ``t.device``.
    """
    batch_size = t.shape[0]
    # Index on the device that holds `vals`, so the schedule may live on CPU or GPU.
    # The result is then moved to the timestep tensor's device.
    out = vals.gather(-1, t.to(vals.device))
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

def forward_diffusion_sample(x_0, t, device=device):
    """Noise only the middle third of the channels of ``x_0`` to timestep ``t``.

    ``x_0`` is split into three equal chunks along dim 1. The notebook builds it as a
    nine-channel (left | middle | right) stack, and only the middle chunk is diffused.

    Args:
        x_0: input tensor whose dim 1 is divisible into three chunks.
        t: long tensor of timestep indices, one per batch element.
        device: device on which the results are returned.

    Returns:
        ``(noisy_middle, noise)``, both on ``device``.
    """
    _, x_mid, _ = torch.chunk(x_0, chunks=3, dim=1)
    eps = torch.randn_like(x_mid)

    coef_signal = get_index_from_list(sqrt_alphas_cumprod, t, x_mid.shape).to(device)
    coef_noise = get_index_from_list(sqrt_one_minus_alphas_cumprod, t, x_mid.shape).to(device)

    # NOTE(review): the usual closed-form forward process is
    # sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps. The extra 0.5 below halves the noise
    # term. This is presumably intentional for the fusion method; confirm.
    noisy = coef_signal * x_mid.to(device) + 0.5 * coef_noise * eps.to(device)
    return noisy, eps.to(device)


# Total number of diffusion steps and the cosine beta schedule over them.
T = 1_000
beta = betas_cos(n_steps=T)

In [None]:
# Quantities derived from `beta`, used by the forward process and the DDIM sampler.
alphas = 1.0 - beta
alphas_cumprod = torch.cumprod(alphas, dim=0)
# abar_{t-1}, with a leading 1.0 so that the t = 0 entry is defined.
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], pad=(1, 0), value=1.0)
sqrt_recip_alphas = (1.0 / alphas).sqrt()
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()
posterior_variance = beta * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

# Scale of the DDIM stochastic term; 0 makes sigma (and thus the added noise) zero.
eta = 0

def get_sigma(eta, device=None):
    """Return the DDIM noise scale for every timestep.

    sigma_t = eta * sqrt((1 - abar_{t-1}) / (1 - abar_t)) * sqrt(1 - abar_t / abar_{t-1})

    Args:
        eta: multiplier of the stochastic term; 0 makes sigma identically zero.
        device: optional device for the result. When ``None`` the result stays on the
            device of the module-level schedule tensors. (Previously this read a global
            ``t`` to pick the device, which raised NameError or silently depended on
            whatever ``t`` last held.)

    Returns:
        Tensor with the same shape as ``alphas_cumprod``.
    """
    sigma = (
        eta
        * torch.sqrt((1 - alphas_cumprod_prev) / (1 - alphas_cumprod))
        * torch.sqrt(1 - alphas_cumprod / alphas_cumprod_prev)
    )
    return sigma if device is None else sigma.to(device)

In [None]:
# Convert a tensor to an image and display it.
def show_tensor_image(image, mean=None, std=None):
    """Display an image tensor with matplotlib.

    Args:
        image: CHW tensor, or NCHW batch (only the first image is shown).
        mean, std: optional per-channel sequences. When both are given the tensor is
            de-normalized as ``x * std + mean`` (the inverse of
            ``transforms.Normalize(mean, std)``). Otherwise it is assumed to lie in
            [-1, 1] and is mapped to [0, 1] with ``(x + 1) / 2``.

    Pixel values are clamped to [0, 1] before the uint8 conversion, so out-of-range
    values saturate instead of producing wrapped/undefined uint8 values.

    Raises:
        ValueError: if only one of ``mean`` / ``std`` is given.
    """
    if (mean is None) != (std is None):
        raise ValueError("mean and std must be given together")

    # Take the first image of a batch.
    if image.dim() == 4:
        image = image[0]
    img = image.detach().cpu()

    if mean is None:
        img = (img + 1) / 2
    else:
        std_chw = torch.as_tensor(std, dtype=img.dtype).view(-1, 1, 1)
        mean_chw = torch.as_tensor(mean, dtype=img.dtype).view(-1, 1, 1)
        img = img * std_chw + mean_chw

    # CHW -> HWC, then scale to 0..255.
    hwc = (img.clamp(0, 1).permute(1, 2, 0) * 255.).numpy().astype(np.uint8)
    if hwc.shape[-1] == 1:
        # matplotlib cannot display an (H, W, 1) array; drop the singleton channel.
        hwc = hwc[..., 0]
    plt.imshow(hwc)

In [None]:
# Classifier guidance
def cond_fn(x, y=1, t=None, s=0):
    """Classifier-guidance gradient: ``s * d/dx log_softmax(classifier(x, t))[y]``.

    Args:
        x: input batch passed to the module-level ``classifier``.
        y: target class index, either a plain int (the default) or a tensor of
            per-sample indices. A single index is broadcast over the batch.
        t: timestep tensor forwarded to the classifier.
        s: guidance scale multiplying the gradient (0 yields a zero gradient).

    Returns:
        Tensor with the same shape as ``x``.
    """
    assert y is not None
    # Gradients are needed here even when called from a no_grad sampling loop.
    with torch.enable_grad():
        x_in = x.detach().requires_grad_(True)
        logits = classifier(x_in, t)
        log_probs = F.log_softmax(logits, dim=-1)
        # Accept an int label (the default, which has no `.view`) as well as a tensor.
        labels = torch.as_tensor(y, dtype=torch.long, device=logits.device).view(-1)
        rows = torch.arange(logits.shape[0], device=logits.device)
        selected = log_probs[rows, labels]
        return torch.autograd.grad(selected.sum(), x_in)[0] * s

In [None]:
@torch.no_grad()
def sample_timestep(eta, x, t, y=None):
    """One DDIM reverse step for the middle patch of a three-chunk channel stack.

    Args:
        eta: scale of the stochastic term (0 -> deterministic update, no noise added).
        x: batch whose dim 1 splits into three equal chunks (left | middle | right).
            The model sees the whole stack; only the middle chunk is updated.
        t: long tensor of current timestep indices, one per batch element.
        y: optional conditioning forwarded to ``model``.

    Returns:
        The middle chunk at step t-1, on ``device``.
    """
    # NOTE(review): assumes the model returns an epsilon prediction with the same number
    # of channels as the middle chunk; TODO confirm (e.g. no learned-variance channels).
    noise_pred = model(x, t, y)
    _, x_mid, _ = torch.chunk(x, chunks=3, dim=1)
    x_mid = x_mid.to(device)

    abar_t = get_index_from_list(alphas_cumprod, t, x_mid.shape).to(device)
    abar_prev_t = get_index_from_list(alphas_cumprod_prev, t, x_mid.shape).to(device)
    # sigma_t = eta * sqrt((1 - abar_{t-1}) / (1 - abar_t)) * sqrt(1 - abar_t / abar_{t-1}),
    # evaluated only at the gathered timesteps. This is the same formula as `get_sigma`,
    # computed locally so it does not depend on a global `t`.
    sigma_t = (
        eta
        * torch.sqrt((1 - abar_prev_t) / (1 - abar_t))
        * torch.sqrt(1 - abar_t / abar_prev_t)
    )

    # x_0 estimated from the noise prediction.
    pred_x0 = (x_mid - torch.sqrt(1 - abar_t) * noise_pred) / torch.sqrt(abar_t)
    # The clamp guards against tiny negative values from float round-off (sqrt -> NaN).
    dir_coef = torch.sqrt(torch.clamp(1 - abar_prev_t - sigma_t ** 2, min=0.0))
    prev_x = torch.sqrt(abar_prev_t) * pred_x0 + dir_coef * noise_pred

    if bool((t == 0).all()):
        return prev_x
    # Works for batches with mixed timesteps: samples already at t == 0 get no noise.
    nonzero_mask = (t != 0).float().view(-1, *([1] * (prev_x.dim() - 1))).to(device)
    return prev_x + nonzero_mask * sigma_t * torch.randn_like(prev_x)

    
import random

    
@torch.no_grad()
def sample_plot_image(eta, num, y=None, left=None, right=None, img_size=None, num_images=10):
    """Sample a middle patch from pure noise and plot snapshots of the reverse process.

    Args:
        eta: scale of the stochastic term, passed to ``sample_timestep``.
        num: unused; kept for backward compatibility.
        y: optional conditioning forwarded to ``sample_timestep``.
        left, right: context patches concatenated around the sampled patch. They default
            to the notebook globals ``left_img`` / ``right_img``. A CHW tensor gets a
            batch dimension added.
        img_size: side length of the sampled square patch. Defaults to the spatial size
            of ``left`` (the previous code read an ``IMG_SIZE`` that is never defined).
        num_images: number of evenly spaced snapshots to plot.
    """
    left = (left_img if left is None else left).to(device)
    right = (right_img if right is None else right).to(device)
    if left.dim() == 3:
        left = left.unsqueeze(0)
    if right.dim() == 3:
        right = right.unsqueeze(0)

    if img_size is None:
        height, width = left.shape[-2:]
    else:
        height = width = img_size
    img = torch.randn((1, 3, height, width), device=device)

    plt.figure(figsize=(15, 15))
    stepsize = max(1, T // num_images)
    for i in reversed(range(T)):
        t = torch.full((1,), i, device=device, dtype=torch.long)
        img = sample_timestep(eta, torch.cat((left, img, right), dim=1), t, y)
        # Plot exactly `num_images` snapshots, earliest step in the last column.
        if i % stepsize == 0 and i // stepsize < num_images:
            plt.subplot(1, num_images, i // stepsize + 1)
            plt.axis('off')
            show_tensor_image(img.detach().cpu())
    plt.show()

## Fusion

### Add noise

In [None]:
# Healthy brain image: a horizontal strip of three 256x256 patches (left | middle | right).
# convert('RGB') guarantees the 3 channels expected by the 3-channel Normalize in
# `transform`; it is a no-op for images that are already RGB.
brain_img = transform(Image.open('health/brain/img/path').convert('RGB')).unsqueeze(0)
# Tumor image, same layout.
tumor_img = transform(Image.open('tumor/img/path').convert('RGB')).unsqueeze(0)
# TODO: tumor-mask loading is disabled. It referenced names that are not defined anywhere
# in this notebook (`unhealth_list_mask`, `mask_transform`, `tumor_files`, `i`), so it
# raised NameError, and no later cell uses the mask. Re-enable once those exist:
# unhealth_list_mask.append(mask_transform(Image.open('data/classfire/unmask/'+tumor_files[i][:-4]+'_mask.tif')))

# NOTE(review): `L` (noising depth, also the number of denoising steps later) is not
# defined anywhere in this notebook. TODO: set it in a config cell before this one.

# Split the healthy image into its three patches along the width.
left_img, middle_img, right_img = torch.chunk(brain_img, chunks=3, dim=3)
# Stack the healthy patches on the channel axis to form a nine-channel tensor.
brain_img = torch.cat((left_img, middle_img, right_img), dim=1)
# Noise the middle patch to step L. forward_diffusion_sample only diffuses the middle third.
img, noise = forward_diffusion_sample(brain_img, torch.tensor([L]), device)
img_noisy = torch.cat((left_img.to(device), img.to(device), right_img.to(device)), dim=1)

# Same procedure for the tumor image.
left_tumor_img, middle_tumor_img, right_tumor_img = torch.chunk(tumor_img, chunks=3, dim=3)
tumor_img = torch.cat((left_tumor_img, middle_tumor_img, right_tumor_img), dim=1)
img_tumor, noise = forward_diffusion_sample(tumor_img, torch.tensor([L]), device)
img_noisy_tumor = torch.cat(
    (left_tumor_img.to(device), img_tumor.to(device), right_tumor_img.to(device)), dim=1
)

### Image addition

In [None]:
res_unhealth = []
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# Per-channel stats shaped (3, 1, 1) so they broadcast over a CHW patch.
mean_chw = torch.tensor(mean)[:, None, None].to(device)
std_chw = torch.tensor(std)[:, None, None].to(device)

# Fusion method 1: add the noisy healthy patches and the noisy tumor patches.
left_tumor_img, middle_tumor_img, right_tumor_img = torch.chunk(img_noisy_tumor, chunks=3, dim=1)
left_img, middle_img, right_img = torch.chunk(img_noisy, chunks=3, dim=1)

# Undo the normalization before adding, so the sum is taken in the original intensity
# space. `[0]` drops the batch dimension (CHW patches).
left_img = left_img[0] * std_chw + mean_chw
right_img = right_img[0] * std_chw + mean_chw
middle_img = middle_img[0] * std_chw + mean_chw
left_tumor_img = left_tumor_img[0] * std_chw + mean_chw
right_tumor_img = right_tumor_img[0] * std_chw + mean_chw
middle_tumor_img = middle_tumor_img[0] * std_chw + mean_chw
res_unhealth.append(
    torch.cat(
        (left_img + left_tumor_img, middle_img + middle_tumor_img, right_img + right_tumor_img),
        dim=0,
    ).unsqueeze(0)
)

# Fused patches, batch dimension restored. (Previously misspelled `lef_img`.)
left_img, middle_img, right_img = torch.chunk(res_unhealth[0], chunks=3, dim=1)

### Denoising

In [None]:
# Re-apply the ImageNet per-channel normalization (the inverse of the de-normalization
# used during fusion).
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

In [None]:
# The fusion cell produces exactly one fused image. (Previously this indexed with a
# leftover, undefined loop variable `i`.)
img_0 = res_unhealth[0]

left_img, middle_img, right_img = torch.chunk(img_0, chunks=3, dim=1)

# Fusion worked in de-normalized space, so renormalize before feeding the model.
left_img = normalize(left_img).to(device)
right_img = normalize(right_img).to(device)
middle_img = normalize(middle_img).to(device)

# Denoise the middle patch from step L-1 down to 0. The left/right context patches are
# re-attached unchanged at every step.
for j in tqdm(reversed(range(L)), total=L, desc="denoising"):
    t = torch.full((1,), j, device=device, dtype=torch.long)
    middle_img = sample_timestep(eta, torch.cat((left_img, middle_img, right_img), dim=1), t)
img_0 = torch.cat((left_img, middle_img, right_img), dim=1)

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
mean_chw = torch.tensor(mean)[:, None, None].to(device)
std_chw = torch.tensor(std)[:, None, None].to(device)
left_img, middle_img, right_img = torch.chunk(img_0, chunks=3, dim=1)

# Inverse normalization so the image displays/saves with its original intensities.
left_img = left_img[0] * std_chw + mean_chw
right_img = right_img[0] * std_chw + mean_chw
middle_img = middle_img[0] * std_chw + mean_chw
# ToPILImage scales float tensors by 255 and casts to uint8, so clamp to [0, 1] first to
# avoid wrap-around.
# NOTE(review): `trans_pil` was never defined in this notebook; ToPILImage is assumed to be
# what it referred to. Confirm.
img = transforms.ToPILImage()(middle_img.clamp(0, 1).cpu())
# Placeholder path: PIL infers the file format from the extension, so give it one (e.g. .png).
img.save('target/path')