# In [None]:
import os
import math
import time
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt
from PIL import Image
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.utils import make_grid, save_image
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from einops import rearrange

from diffusers import DDPMScheduler
from diffusers.models.attention import BasicTransformerBlock
from diffusers.optimization import get_cosine_schedule_with_warmup
from accelerate import Accelerator

# Seed torch's global RNG so weight init, noise draws and augmentation are repeatable
# (cudnn.benchmark below may still introduce kernel-level nondeterminism).
torch.manual_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


# ---- Data ----
batch_size = 128
high_res_image_size = 64   # side length of the target / generated images
low_res_image_size = 32    # side length of the conditioning images
channels = 3

# ---- Diffusion / model / optimisation ----
num_train_timesteps = 1000
num_inference_steps = 50
learning_rate = 2e-4
epochs = 100
num_cra_blocks = 3
encoder_scale_factor_r = 8

# ---- Evaluation / runtime ----
save_image_epochs = 2
eval_batch_size = 16
gradient_accumulation_steps = 1
# fp16 autocast is a GPU feature and Accelerate refuses "fp16" when no GPU is
# available, so fall back to full precision on CPU-only machines.
mixed_precision = "fp16" if device.type == "cuda" else "no"

torch.backends.cudnn.benchmark = True
num_workers = 4
prefetch_factor = 2

# ---- Output locations ----
output_dir = "./adis_model_output"
os.makedirs(output_dir, exist_ok=True)
os.makedirs(f"{output_dir}/samples", exist_ok=True)

tensorboard_dir = f"{output_dir}/tensorboard"
os.makedirs(tensorboard_dir, exist_ok=True)
writer = SummaryWriter(tensorboard_dir)


class CustomImageDataset(Dataset):
    """Paired high-/low-resolution views of the images under ``root_dir``.

    Image files are collected from the ``infected`` and ``notinfected``
    sub-directories of ``root_dir`` (missing sub-directories are skipped) in
    sorted order, so the index -> file mapping is stable across runs.

    Each item is ``((high_res_image, low_res_image), 0)``: both views are
    produced from the same source image by ``high_res_transform`` and
    ``low_res_transform``; the label is always 0. Any random transform steps
    (e.g. a random horizontal flip) are replayed with the same RNG seed for
    both views, so the two views stay spatially aligned.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.webp')

    def __init__(self, root_dir, high_res_transform, low_res_transform):
        self.root_dir = root_dir
        self.high_res_transform = high_res_transform
        self.low_res_transform = low_res_transform
        self.image_paths = []

        for subdir in ["infected", "notinfected"]:
            subdir_path = os.path.join(root_dir, subdir)
            if not os.path.isdir(subdir_path):
                continue
            # os.listdir order is filesystem-dependent; sort for reproducibility.
            for file in sorted(os.listdir(subdir_path)):
                if file.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.image_paths.append(os.path.join(subdir_path, file))

        print(f"Found {len(self.image_paths)} images in dataset")

    def __len__(self):
        return len(self.image_paths)

    def _apply_with_seed(self, transform, image, seed):
        """Run ``transform`` with torch's CPU RNG temporarily seeded to ``seed``."""
        # fork_rng restores the previous RNG state afterwards, so only the
        # single seed draw in __getitem__ advances the global generator.
        with torch.random.fork_rng(devices=[]):
            torch.manual_seed(seed)
            return transform(image)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]

        try:
            with Image.open(img_path) as raw:
                image = raw.convert('RGB')
            # Both transforms may contain random steps; replaying one seed makes
            # them take identical random decisions, so the low-res condition is
            # never mirrored relative to the high-res target.
            seed = int(torch.randint(0, 2**31 - 1, (1,)).item())
            high_res_image = self._apply_with_seed(self.high_res_transform, image, seed)
            low_res_image = self._apply_with_seed(self.low_res_transform, image, seed)
            return (high_res_image, low_res_image), 0
        except Exception as e:
            # Best-effort: an unreadable file yields an all-zero pair instead of
            # aborting the epoch.
            print(f"Error loading image {img_path}: {e}")
            placeholder_high = torch.zeros((channels, high_res_image_size, high_res_image_size))
            placeholder_low = torch.zeros((channels, low_res_image_size, low_res_image_size))
            return (placeholder_high, placeholder_low), 0

def _make_transform(size):
    """Build the resize -> random h-flip -> tensor -> normalize pipeline for one side length.

    Normalize with mean=std=0.5 maps ToTensor's [0, 1] range to [-1, 1].
    """
    return transforms.Compose([
        transforms.Resize((size, size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * channels, [0.5] * channels),
    ])


high_res_transform = _make_transform(high_res_image_size)
low_res_transform = _make_transform(low_res_image_size)


train_dataset = CustomImageDataset(
    root_dir="./data/train",
    high_res_transform=high_res_transform,
    low_res_transform=low_res_transform
)
if len(train_dataset) == 0:
    # Fail with a clear message instead of the sampler's generic
    # "num_samples should be a positive integer" error.
    raise RuntimeError("No training images found under ./data/train/{infected,notinfected}")

# DataLoader rejects prefetch_factor / persistent_workers when num_workers == 0,
# so only pass them when worker processes are actually used.
_worker_kwargs = (
    {"prefetch_factor": prefetch_factor, "persistent_workers": True}
    if num_workers > 0
    else {}
)

dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=torch.cuda.is_available(),  # pinning only helps host->GPU copies
    drop_last=True,
    **_worker_kwargs,
)

class SinusoidalPosEmb(nn.Module):
    """Fixed sinusoidal embedding of scalar timesteps.

    For a 1-D input of length B returns a (B, 2 * (dim // 2)) float tensor:
    the first half holds sin(t * f_k), the second half cos(t * f_k), with
    frequencies f_k decaying geometrically from 1 to 1/10000.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half = self.dim // 2
        # Log-spaced frequency step; requires dim >= 4 (half > 1).
        step = math.log(10000) / (half - 1)
        freqs = torch.exp(torch.arange(half, device=x.device) * -step)
        angles = x[:, None] * freqs[None, :]
        return torch.cat([angles.sin(), angles.cos()], dim=-1)

class ResidualBlock(nn.Module):
    """Two 3x3 conv -> GroupNorm(8) -> SiLU stages with a residual shortcut.

    Spatial size is preserved (padding=1, stride 1). When ``time_emb_dim`` is
    given, a SiLU + Linear projection of the time embedding is added as a
    per-channel bias between the two stages. The shortcut is a 1x1 conv when
    the channel count changes, otherwise the identity. ``out_channels`` must be
    divisible by 8 for GroupNorm.
    """

    def __init__(self, in_channels, out_channels, time_emb_dim=None):
        super().__init__()
        # Registration order is kept stable so parameter init and state_dict keys match.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm1 = nn.GroupNorm(8, out_channels)
        self.act1 = nn.SiLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(8, out_channels)
        self.act2 = nn.SiLU()
        if in_channels == out_channels:
            self.residual_conv = nn.Identity()
        else:
            self.residual_conv = nn.Conv2d(in_channels, out_channels, 1)

        if time_emb_dim is None:
            self.time_mlp = None
        else:
            self.time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, out_channels))

    def forward(self, x, t_emb=None):
        out = self.act1(self.norm1(self.conv1(x)))
        use_time = t_emb is not None and self.time_mlp is not None
        if use_time:
            # (B, C) -> (B, C, 1, 1) so it broadcasts over the spatial dims.
            out = out + self.time_mlp(t_emb)[:, :, None, None]
        out = self.act2(self.norm2(self.conv2(out)))
        shortcut = self.residual_conv(x)
        return out + shortcut

class CrossResolutionAttentionBlock(nn.Module):
    """Let a feature map attend to a context feature map.

    Both inputs are flattened from (B, C, H, W) to token sequences of shape
    (B, H*W, C) and passed to a diffusers ``BasicTransformerBlock`` with
    ``cross_attention_dim=dim``. The context may have a different spatial size
    than ``x``; the output has the same shape as ``x``.

    NOTE(review): ``only_cross_attention=True`` presumably makes the block's
    first attention layer attend to the context as well — confirm against the
    installed diffusers version.
    """

    def __init__(self, dim, num_heads=8, dim_head=32):
        super().__init__()
        self.attn = BasicTransformerBlock(
            dim=dim,
            num_attention_heads=num_heads,
            attention_head_dim=dim_head,
            cross_attention_dim=dim,
            only_cross_attention=True
        )

    def forward(self, x, context):
        # Record the spatial size before flattening; recovering it from the
        # token count (sqrt(H*W)) is wrong for non-square feature maps.
        height, width = x.shape[-2], x.shape[-1]
        tokens = rearrange(x, 'b c h w -> b (h w) c')
        context_tokens = rearrange(context, 'b c h w -> b (h w) c')
        out = self.attn(hidden_states=tokens, encoder_hidden_states=context_tokens)
        return rearrange(out, 'b (h w) c -> b c h w', h=height, w=width)

class PixelShuffleUpBlock(nn.Module):
    """Upsample by ``scale_factor`` with a 3x3 conv followed by pixel shuffle.

    The conv expands channels to ``out_channels * scale_factor**2``; the
    shuffle trades those extra channels for a ``scale_factor``-times larger
    height and width, giving ``out_channels`` output channels.
    """

    def __init__(self, in_channels, out_channels, scale_factor=2):
        super().__init__()
        expanded_channels = out_channels * scale_factor * scale_factor
        self.conv = nn.Conv2d(in_channels, expanded_channels, 3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)

    def forward(self, x):
        widened = self.conv(x)
        return self.pixel_shuffle(widened)

class ADISModel(nn.Module):
    """Noise-prediction network conditioned on a low-resolution image.

    Two encoder branches run side by side: a time-conditioned branch over the
    noisy high-resolution image and an unconditioned branch over the
    low-resolution image. Every "down" stage halves the spatial size in both
    branches. Low-res features are bilinearly upsampled x2 to line up with the
    noisy branch, fused by addition after stage 1, and used as cross-attention
    context after stage 2. A decoder with two x2 pixel-shuffle upsamplers
    brings the result back to the input size.

    Spatial flow for an HxW noisy input (H, W divisible by 4) and an
    (H/2)x(W/2) low-res input:
        noisy   H   -> H/2 -> H/4
        low-res H/2 -> H/4 -> H/8 (upsampled to H/4 as attention context)
        decoder H/4 -> H/2 -> H
    The low-res input must be exactly half the noisy input's size, otherwise
    the additive fusion after stage 1 fails on mismatched shapes.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        base_dim=64,
        time_emb_dim=256,
        cra_blocks=3,
        cra_dim=256,
        cra_heads=8
    ):
        super().__init__()

        self.time_emb = SinusoidalPosEmb(base_dim)
        self.time_mlp = nn.Sequential(
            nn.Linear(base_dim, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim)
        )

        self.noisy_encoder_in = ResidualBlock(in_channels, base_dim, time_emb_dim)
        self.noisy_encoder_down1 = ResidualBlock(base_dim, base_dim * 2, time_emb_dim)

        self.low_res_encoder_in = ResidualBlock(in_channels, base_dim)
        self.low_res_encoder_down1 = ResidualBlock(base_dim, base_dim * 2)

        # ResidualBlock keeps spatial size, so the "down" stages need an explicit
        # 2x reduction; without it the two x2 decoder upsamplers produce an
        # output 4x larger than the input. Parameter-free, so the state_dict
        # layout is unchanged.
        self.downsample = nn.AvgPool2d(kernel_size=2)

        self.low_res_feat_upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

        current_dim = base_dim * 2
        self.noisy_encoder_down2 = ResidualBlock(current_dim, cra_dim, time_emb_dim)
        self.low_res_encoder_down2 = ResidualBlock(current_dim, cra_dim)

        self.low_res_feat_upsample_2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

        self.cra_blocks = nn.ModuleList([
            CrossResolutionAttentionBlock(dim=cra_dim, num_heads=cra_heads)
            for _ in range(cra_blocks)
        ])

        self.decoder_in = ResidualBlock(cra_dim, current_dim, time_emb_dim)

        self.decoder_up1 = PixelShuffleUpBlock(current_dim, current_dim)
        self.decoder_res1 = ResidualBlock(current_dim, base_dim, time_emb_dim)

        self.decoder_up2 = PixelShuffleUpBlock(base_dim, base_dim)
        self.decoder_res2 = ResidualBlock(base_dim, base_dim, time_emb_dim)

        self.final_conv = nn.Conv2d(base_dim, out_channels, 1)

    def forward(self, noisy_image_u_t, timestep_t, low_res_image_u_low):
        """Predict the noise for ``noisy_image_u_t``.

        Args:
            noisy_image_u_t: (B, in_channels, H, W) noisy high-resolution image.
            timestep_t: (B,) diffusion timesteps.
            low_res_image_u_low: (B, in_channels, H/2, W/2) conditioning image.

        Returns:
            (B, out_channels, H, W) tensor, the same spatial size as the input.
        """
        t_emb = self.time_mlp(self.time_emb(timestep_t))

        # Stage 1: features, then halve the resolution in both branches.
        phi_d = self.low_res_encoder_in(low_res_image_u_low)
        phi_d = self.downsample(self.low_res_encoder_down1(phi_d))             # H/2 -> H/4

        phi_u = self.noisy_encoder_in(noisy_image_u_t, t_emb)
        phi_u = self.downsample(self.noisy_encoder_down1(phi_u, t_emb))        # H -> H/2

        phi_u = phi_u + self.low_res_feat_upsample(phi_d)                      # fuse at H/2

        # Stage 2: halve again, then cross-attend noisy features to low-res context.
        phi_u_cra = self.downsample(self.noisy_encoder_down2(phi_u, t_emb))    # H/2 -> H/4
        phi_d_cra = self.downsample(self.low_res_encoder_down2(phi_d))         # H/4 -> H/8
        phi_d_cra_context = self.low_res_feat_upsample_2(phi_d_cra)            # H/8 -> H/4

        cra_out = phi_u_cra
        for cra_block in self.cra_blocks:
            cra_out = cra_block(cra_out, context=phi_d_cra_context)

        # Decoder: H/4 -> H/2 -> H.
        dec = self.decoder_in(cra_out, t_emb)
        dec = self.decoder_up1(dec)
        dec = self.decoder_res1(dec, t_emb)
        dec = self.decoder_up2(dec)
        dec = self.decoder_res2(dec, t_emb)

        predicted_noise = self.final_conv(dec)
        return predicted_noise


model = ADISModel(
    in_channels=channels,
    out_channels=channels,
    base_dim=128,
    time_emb_dim=512,
    cra_blocks=num_cra_blocks,
    cra_dim=512,
    cra_heads=8
)
model.to(device)

model_size = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {model_size:,}")

# Linear-beta DDPM schedule; the model is trained to predict the added noise.
noise_scheduler = DDPMScheduler(
    num_train_timesteps=num_train_timesteps,
    beta_schedule="linear",
    prediction_type="epsilon",
    clip_sample=False
)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    eps=1e-8
)

# Total optimizer steps = batches per epoch * epochs / accumulation steps.
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=500,
    num_training_steps=(len(dataloader) * epochs) // gradient_accumulation_steps,
)

accelerator = Accelerator(
    mixed_precision=mixed_precision,
    gradient_accumulation_steps=gradient_accumulation_steps,
    log_with="tensorboard",
    project_dir=tensorboard_dir,
)
# `log_with` only selects the tracker backend; `accelerator.log` records
# nothing until the trackers have been started here.
accelerator.init_trackers("adis")

model, optimizer, dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, dataloader, lr_scheduler
)


@torch.no_grad()
def sample_and_save_images(epoch, model, scheduler, low_res_batch):
    """Run reverse diffusion conditioned on ``low_res_batch`` and save/log/show the result.

    Args:
        epoch: label used in the output file name, TensorBoard step and plot title.
        model: noise-prediction network with the ADISModel call signature.
        scheduler: diffusers-style scheduler; ``set_timesteps`` is called on it,
            so its timestep list is overwritten.
        low_res_batch: conditioning images, one generated sample per image.

    Writes ``{output_dir}/samples/epoch_{epoch}.png``, adds the grid to the
    TensorBoard ``writer`` and displays it with matplotlib.
    """
    num_images = low_res_batch.shape[0]
    # Generate on the same device as the conditioning batch so the two always match.
    sample_device = low_res_batch.device

    sample = torch.randn(
        num_images,
        channels,
        high_res_image_size,
        high_res_image_size,
        device=sample_device
    )

    scheduler.set_timesteps(num_inference_steps)

    for t in tqdm(scheduler.timesteps, desc="Sampling"):
        noise_pred = model(
            noisy_image_u_t=sample,
            timestep_t=t.unsqueeze(0).repeat(num_images).to(sample_device),
            low_res_image_u_low=low_res_batch
        )
        sample = scheduler.step(noise_pred, t, sample).prev_sample

    # Map model space [-1, 1] back to [0, 1] for saving/display.
    images = (sample / 2 + 0.5).clamp(0, 1)
    image_grid = make_grid(images, nrow=int(math.sqrt(num_images))).cpu()

    grid_image_path = f"{output_dir}/samples/epoch_{epoch}.png"
    save_image(image_grid, grid_image_path)

    writer.add_image("generated_images", image_grid, epoch)

    fig = plt.figure(figsize=(10, 10))
    plt.imshow(transforms.ToPILImage()(image_grid))
    plt.axis('off')
    plt.title(f"Epoch {epoch}")
    plt.show()
    # Close explicitly so figures don't pile up across epochs in non-inline backends.
    plt.close(fig)

def _show_grid(image_grid, title):
    """Display one [0, 1] image grid in its own figure, then close the figure."""
    fig = plt.figure(figsize=(10, 10))
    plt.imshow(transforms.ToPILImage()(image_grid.cpu()))
    plt.axis('off')
    plt.title(title)
    plt.show()
    plt.close(fig)


def save_real_examples(dataloader, num_examples=9):
    """Save and display grids of real high-res / low-res training pairs.

    Takes the first batch from a fresh iterator over ``dataloader`` (items are
    ``((high_res, low_res), label)`` as produced by CustomImageDataset), keeps
    the first ``num_examples`` of each view, maps pixels with ``x / 2 + 0.5``
    (i.e. [-1, 1] -> [0, 1]), writes ``real_examples_high_res.png`` and
    ``real_examples_low_res.png`` into ``output_dir`` and shows both grids.
    """
    batch, _ = next(iter(dataloader))
    high_res_images = (batch[0][:num_examples] / 2 + 0.5).clamp(0, 1)
    low_res_images = (batch[1][:num_examples] / 2 + 0.5).clamp(0, 1)

    image_grid_high = make_grid(high_res_images, nrow=3)
    image_grid_low = make_grid(low_res_images, nrow=3)

    save_image(image_grid_high, f"{output_dir}/real_examples_high_res.png")
    save_image(image_grid_low, f"{output_dir}/real_examples_low_res.png")

    # Titles follow the configured sizes rather than hard-coded 64/32.
    _show_grid(
        image_grid_high,
        f"Real Dataset Examples (High-Res {high_res_image_size}x{high_res_image_size})",
    )
    _show_grid(
        image_grid_low,
        f"Real Dataset Examples (Low-Res {low_res_image_size}x{low_res_image_size})",
    )

# Best-effort preview of the training data: a failure here only prints a
# message and does not stop training.
try:
    save_real_examples(dataloader)
except Exception as err:
    print(f"Error saving real examples: {err}")

# Iterator from which conditioning batches are drawn for periodic sampling.
dataloader_iter = iter(dataloader)

# Filled in at the last epoch with the unwrapped model's weights.
final_model_state = None
global_step = 0

for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0
    steps_this_epoch = 0

    # Wall-clock timing works on CPU and GPU alike (torch.cuda.Event does not
    # exist without CUDA); synchronize first so queued GPU work is counted.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    epoch_start = time.perf_counter()
    progress_bar = tqdm(total=len(dataloader), desc=f"Epoch {epoch}")

    for step, (batch_data, _) in enumerate(dataloader):
        high_res_images, low_res_images = batch_data

        if high_res_images.shape[0] < 2:
            continue

        with accelerator.accumulate(model):
            noise = torch.randn_like(high_res_images)
            # One uniformly random diffusion timestep per image.
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps,
                (high_res_images.shape[0],),
                device=high_res_images.device
            )

            noisy_images = noise_scheduler.add_noise(high_res_images, noise, timesteps)

            noise_pred = model(
                noisy_image_u_t=noisy_images,
                timestep_t=timesteps,
                low_res_image_u_low=low_res_images
            )

            loss = F.mse_loss(noise_pred, noise)
            accelerator.backward(loss)

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        # Single host sync per step; reuse the value for all bookkeeping.
        loss_value = loss.detach().item()
        epoch_loss += loss_value
        steps_this_epoch += 1
        progress_bar.update(1)
        logs = {"loss": loss_value, "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
        progress_bar.set_postfix(**logs)
        global_step += 1

        if step % 10 == 0:
            accelerator.log(logs, step=global_step)

    if torch.cuda.is_available():
        torch.cuda.synchronize()
    epoch_time = time.perf_counter() - epoch_start

    # Average over the steps actually taken (skipped batches excluded; no
    # division by zero for an empty epoch).
    avg_loss = epoch_loss / max(steps_this_epoch, 1)
    writer.add_scalar("train/loss", avg_loss, epoch)
    writer.add_scalar("train/epoch_time", epoch_time, epoch)

    progress_bar.close()
    print(f"Epoch {epoch} completed in {epoch_time:.2f}s, Avg loss: {avg_loss:.4f}")

    if (epoch + 1) % save_image_epochs == 0 or epoch == epochs - 1:
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.eval()

        try:
            eval_batch, _ = next(dataloader_iter)
        except StopIteration:
            dataloader_iter = iter(dataloader)
            eval_batch, _ = next(dataloader_iter)

        eval_low_res = eval_batch[1][:eval_batch_size].to(device)

        # sample_and_save_images is itself decorated with @torch.no_grad().
        sample_and_save_images(epoch + 1, unwrapped_model, noise_scheduler, eval_low_res)

        if epoch == epochs - 1:
            final_model_state = unwrapped_model.state_dict()

        model.train()

writer.close()
# Flush/close any trackers started through Accelerate (no-op if none).
accelerator.end_training()
print("Training completed!")


@torch.no_grad()
def visualize_denoising(model, scheduler, low_res_image, num_steps=10):
    """Plot every intermediate sample of a short reverse-diffusion run.

    Args:
        model: noise-prediction network with the ADISModel call signature.
        scheduler: diffusers-style scheduler; its timesteps are reset to ``num_steps``.
        low_res_image: a single conditioning image without a batch dimension.
        num_steps: number of inference steps to run and display.
    """
    sample = torch.randn(1, channels, high_res_image_size, high_res_image_size).to(device)
    low_res_image = low_res_image.unsqueeze(0).to(device)

    scheduler.set_timesteps(num_steps)

    images = []
    for t in tqdm(scheduler.timesteps, desc="Visualizing Denoising"):
        noise_pred = model(
            noisy_image_u_t=sample,
            timestep_t=t.unsqueeze(0).to(device),
            low_res_image_u_low=low_res_image
        )
        sample = scheduler.step(noise_pred, t, sample).prev_sample

        frame = (sample / 2 + 0.5).clamp(0, 1)
        images.append(frame.cpu().permute(0, 2, 3, 1).numpy()[0])

    plt.figure(figsize=(15, 4))
    # Lay out by the number of frames actually produced, not the requested count.
    for i, img in enumerate(images):
        plt.subplot(1, len(images), i + 1)
        plt.imshow(img)
        plt.title(f"Step {i+1}")
        plt.axis('off')
    plt.tight_layout()
    plt.suptitle("Denoising Process", y=1.05)
    plt.show()


if final_model_state:
    # Rebuild a plain (non-Accelerate) model from the captured weights.
    final_model = ADISModel(
        in_channels=channels,
        out_channels=channels,
        base_dim=128,
        time_emb_dim=512,
        cra_blocks=num_cra_blocks,
        cra_dim=512,
        cra_heads=8
    )
    final_model.load_state_dict(final_model_state)
    final_model.to(device)
    final_model.eval()

    final_scheduler = noise_scheduler

    try:
        final_eval_batch, _ = next(dataloader_iter)
    except StopIteration:
        dataloader_iter = iter(dataloader)
        final_eval_batch, _ = next(dataloader_iter)

    # One sample per conditioning image; never request more than the batch holds.
    num_samples = min(16, final_eval_batch[1].shape[0])
    final_low_res_batch = final_eval_batch[1][:num_samples].to(device)

    sample = torch.randn(
        num_samples,
        channels,
        high_res_image_size,
        high_res_image_size,
        device=device
    )
    final_scheduler.set_timesteps(num_inference_steps)

    with torch.no_grad():
        for t in tqdm(final_scheduler.timesteps, desc="Final Generation"):
            noise_pred = final_model(
                noisy_image_u_t=sample,
                timestep_t=t.unsqueeze(0).repeat(num_samples).to(device),
                low_res_image_u_low=final_low_res_batch
            )
            sample = final_scheduler.step(noise_pred, t, sample).prev_sample

    sample_images_tensors = (sample / 2 + 0.5).clamp(0, 1).cpu()
    sample_images = [transforms.ToPILImage()(img) for img in sample_images_tensors]

    # Near-square grid; 16 samples gives the original 4x4 layout.
    grid_cols = max(1, math.ceil(math.sqrt(num_samples)))
    grid_rows = max(1, math.ceil(num_samples / grid_cols))

    plt.figure(figsize=(12, 12))
    plt.subplots_adjust(hspace=0.1, wspace=0.1)
    for i, image in enumerate(sample_images):
        plt.subplot(grid_rows, grid_cols, i + 1)
        plt.imshow(image)
        plt.axis('off')
    plt.tight_layout()
    plt.suptitle("Final Generated Samples (Conditioned on Low-Res)", y=1.02)
    plt.show()

    visualize_denoising(final_model, final_scheduler, final_low_res_batch[0], num_steps=10)

else:
    print("Final model state was not captured. Skipping final generation/visualization.")