





### Swin_MAE_Interface



### Dependent Scripts

In [None]:
with open('/content/utils/pos_embed.py', 'w') as f:
    f.write('''
import numpy as np

import torch

def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    Build a fixed 2D sine/cosine position-embedding table for a square grid.

    embed_dim: width of each embedding vector (the helper asserts it is even)
    grid_size: int, number of grid cells along both height and width
    cls_token: when True, an all-zero row is prepended for the class token
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/o or w/ cls_token)
    """
    axis_coords = np.arange(grid_size, dtype=np.float32)
    # np.meshgrid defaults to 'xy' indexing, so the first plane holds the
    # column (w) coordinate and the second the row (h) coordinate.
    mesh = np.stack(np.meshgrid(axis_coords, axis_coords), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])

    table = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not cls_token:
        return table
    cls_row = np.zeros([1, embed_dim])
    return np.concatenate([cls_row, table], axis=0)


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Encode a two-plane coordinate grid as sine/cosine embeddings.

    embed_dim: total output width; must be even (half is spent on each plane)
    grid: indexable with 0 and 1; each plane is flattened by the 1D encoder
    out: (M, embed_dim) where M is the number of positions in one plane
    """
    assert embed_dim % 2 == 0

    per_plane_dim = embed_dim // 2
    # plane 0 fills the first half of the channels, plane 1 the second half
    halves = [get_1d_sincos_pos_embed_from_grid(per_plane_dim, grid[plane]) for plane in (0, 1)]
    return np.concatenate(halves, axis=1)  # (M, D)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Sine/cosine encoding of scalar positions.

    embed_dim: output dimension for each position (must be even)
    pos: array of positions to encode; any shape, flattened to (M,)
    out: (M, D) -- first D/2 columns are sines, last D/2 columns are cosines
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    exponents = np.arange(half_dim, dtype=float) / (embed_dim / 2.)
    inv_freq = 1. / 10000 ** exponents  # (D/2,)

    flat_pos = pos.reshape(-1)  # (M,)
    angles = np.einsum('m,d->md', flat_pos, inv_freq)  # (M, D/2), outer product

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)

def interpolate_pos_embed(model, checkpoint_model):
    """
    Resize the checkpoint's 'pos_embed' entry in place to match the model's patch grid.

    model: object exposing patch_embed.num_patches and pos_embed (a tensor)
    checkpoint_model: state dict; only its 'pos_embed' entry is read and replaced
    Nothing happens when the key is absent or the (square) grid sizes already agree.
    """
    if 'pos_embed' not in checkpoint_model:
        return

    ckpt_embed = checkpoint_model['pos_embed']
    channels = ckpt_embed.shape[-1]
    patch_count = model.patch_embed.num_patches
    # tokens in the model's embedding that are not patches (e.g. class token)
    extra_count = model.pos_embed.shape[-2] - patch_count
    # side length (height == width) of the checkpoint grid and of the model grid
    old_side = int((ckpt_embed.shape[-2] - extra_count) ** 0.5)
    new_side = int(patch_count ** 0.5)
    if old_side == new_side:
        return

    print("Position interpolate from %dx%d to %dx%d" % (old_side, old_side, new_side, new_side))
    # the leading extra tokens are carried over untouched
    prefix = ckpt_embed[:, :extra_count]
    # only the patch position tokens are interpolated, as a channels-first image
    grid_tokens = ckpt_embed[:, extra_count:]
    grid_tokens = grid_tokens.reshape(-1, old_side, old_side, channels).permute(0, 3, 1, 2)
    grid_tokens = torch.nn.functional.interpolate(
        grid_tokens, size=(new_side, new_side), mode='bicubic', align_corners=False)
    grid_tokens = grid_tokens.permute(0, 2, 3, 1).flatten(1, 2)
    checkpoint_model['pos_embed'] = torch.cat((prefix, grid_tokens), dim=1)
''')

In [None]:
with open('/content/utils/misc.py', 'w') as f:
    f.write('''


import builtins
import datetime
import os
import time
from collections import defaultdict, deque
from pathlib import Path

import torch
import torch.distributed as dist
#from torch._six import inf
from torch import inf


class SmoothedValue(object):
    """Running statistic holder.

    Keeps the most recent ``window_size`` values in a deque for windowed
    statistics (median, avg, max, value) and a running count/total for the
    average over everything ever recorded (global_avg).
    """

    def __init__(self, window_size=20, fmt=None):
        # fmt may reference: median, avg, global_avg, max, value
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        """Record ``value`` once in the window and with weight ``n`` in the global totals."""
        self.deque.append(value)
        self.total += value * n
        self.count += n

    def synchronize_between_processes(self):
        """
        Sum count and total across processes when torch.distributed is initialised.
        Warning: does not synchronize the deque!
        """
        if is_dist_avail_and_initialized():
            stats = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
            dist.barrier()
            dist.all_reduce(stats)
            reduced_count, reduced_total = stats.tolist()
            self.count = int(reduced_count)
            self.total = reduced_total

    @property
    def median(self):
        """Median of the values currently in the window."""
        window = torch.tensor(list(self.deque))
        return window.median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window, computed in float32."""
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        """Weighted mean over every update since construction."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        # every statistic is evaluated, whether or not fmt references it
        stats = {
            'median': self.median,
            'avg': self.avg,
            'global_avg': self.global_avg,
            'max': self.max,
            'value': self.value,
        }
        return self.fmt.format(**stats)


class MetricLogger(object):
    """Collection of named SmoothedValue meters with periodic progress printing.

    Meters are created lazily on first update and are also reachable as
    attributes (``logger.loss`` is ``logger.meters['loss']``).
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Record one value per keyword; None values are skipped, tensors are converted with .item()."""
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # __getattr__ only runs after normal lookup has failed, so the instance
        # dict cannot contain attr. Read 'meters' through __dict__ rather than
        # self.meters: on a half-built instance (copy/pickle, __new__) the
        # attribute is missing and self.meters would re-enter this method
        # until RecursionError.
        meters = self.__dict__.get('meters')
        if meters is not None and attr in meters:
            return meters[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = [
            "{}: {}".format(name, str(meter))
            for name, meter in self.meters.items()
        ]
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        """Ask every meter to synchronise its global count/total across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register (or replace) the meter stored under ``name``."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable`` while printing progress.

        A status line (index, ETA, all meters, iteration time, data-loading
        time and, when CUDA is available, peak allocated memory in MB) is
        printed every ``print_freq`` items and for the last item; a total-time
        summary is printed once the iterable is exhausted.

        iterable: must support len()
        print_freq: number of items between status lines
        header: optional prefix for every printed line
        """
        if not header:
            header = ''
        num_items = len(iterable)
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        # pad the running index to the width of the total count
        space_fmt = ':' + str(len(str(num_items))) + 'd'
        use_cuda = torch.cuda.is_available()
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if use_cuda:
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for i, obj in enumerate(iterable):
            # time spent waiting for the item vs. time for the whole iteration
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == num_items - 1:
                eta_seconds = iter_time.global_avg * (num_items - i)
                fields = {
                    'eta': str(datetime.timedelta(seconds=int(eta_seconds))),
                    'meters': str(self),
                    'time': str(iter_time),
                    'data': str(data_time),
                }
                if use_cuda:
                    fields['memory'] = torch.cuda.max_memory_allocated() / MB
                print(log_msg.format(i, num_items, **fields))
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # max(..., 1): an empty iterable must not raise ZeroDivisionError
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / max(num_items, 1)))


def setup_for_distributed(is_master):
    """
    Replace builtins.print with a wrapper that stays silent unless this is the
    master process, the call passes force=True, or the world size exceeds 8.
    Every line that is printed gets a wall-clock time stamp prefix.
    """
    original_print = builtins.print

    def print(*args, **kwargs):
        # 'force' is consumed here and never reaches the real print
        forced = kwargs.pop('force', False)
        forced = forced or (get_world_size() > 8)
        if not (is_master or forced):
            return
        stamp = datetime.datetime.now().time()
        original_print('[{}] '.format(stamp), end='')  # print with time stamp
        original_print(*args, **kwargs)

    builtins.print = print


def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is both available and initialised."""
    return dist.is_available() and dist.is_initialized()


def get_world_size():
    """Return the distributed world size, or 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1


def get_rank():
    """Return this process's distributed rank, or 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0


def is_main_process():
    """Return True when the current process has rank 0."""
    rank = get_rank()
    return rank == 0


def save_on_master(*args, **kwargs):
    """Forward all arguments to torch.save, but only on the rank-0 process."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)


def init_distributed_mode(args):
    """Fill the distributed fields of ``args`` from the environment and start the process group.

    Sources are tried in this order:
      1. args.dist_on_itp: OpenMPI variables (plus MASTER_ADDR / MASTER_PORT for the URL),
         which are also re-exported as LOCAL_RANK / RANK / WORLD_SIZE;
      2. RANK and WORLD_SIZE (with LOCAL_RANK) already present in the environment;
      3. SLURM_PROCID, with the GPU chosen as rank modulo the visible device count;
      4. otherwise distributed mode is disabled and the function returns early.

    In branches 2 and 3 args.dist_url (and in branch 3 args.world_size) are
    read later without being set here, so the caller must provide them.
    """
    env = os.environ
    if args.dist_on_itp:
        args.rank = int(env['OMPI_COMM_WORLD_RANK'])
        args.world_size = int(env['OMPI_COMM_WORLD_SIZE'])
        args.gpu = int(env['OMPI_COMM_WORLD_LOCAL_RANK'])
        args.dist_url = "tcp://%s:%s" % (env['MASTER_ADDR'], env['MASTER_PORT'])
        # together with MASTER_ADDR / MASTER_PORT these complete the set
        # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
        exported = (('LOCAL_RANK', args.gpu), ('RANK', args.rank), ('WORLD_SIZE', args.world_size))
        for name, value in exported:
            env[name] = str(value)
    elif 'RANK' in env and 'WORLD_SIZE' in env:
        args.rank = int(env["RANK"])
        args.world_size = int(env['WORLD_SIZE'])
        args.gpu = int(env['LOCAL_RANK'])
    elif 'SLURM_PROCID' in env:
        args.rank = int(env['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        setup_for_distributed(is_master=True)  # hack
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                            world_size=args.world_size, rank=args.rank)
    dist.barrier()
    # from here on only rank 0 prints (unless forced)
    setup_for_distributed(args.rank == 0)


class NativeScalerWithGradNormCount:
    """Wrapper around torch.cuda.amp.GradScaler that runs backward, optionally
    steps the optimizer, and reports the gradient norm of that step."""

    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        """Back-propagate the scaled loss and, if ``update_grad``, step the optimizer.

        clip_grad: max norm for gradient clipping; None measures the norm without clipping
        parameters: parameters whose gradients are clipped / measured
        Returns the gradient norm of this step, or None when update_grad is False.
        """
        scaled_loss = self._scaler.scale(loss)
        scaled_loss.backward(create_graph=create_graph)
        if not update_grad:
            return None

        clipping = clip_grad is not None
        if clipping:
            assert parameters is not None
        # unscale the gradients of optimizer's assigned params in-place
        self._scaler.unscale_(optimizer)
        if clipping:
            norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
        else:
            norm = get_grad_norm_(parameters)
        self._scaler.step(optimizer)
        self._scaler.update()
        return norm

    def state_dict(self):
        """Return the wrapped scaler's state dict."""
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """Restore the wrapped scaler from ``state_dict``."""
        self._scaler.load_state_dict(state_dict)


def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """Return the norm of all gradients taken together, without modifying them.

    parameters: a single tensor or an iterable of tensors; entries whose
        .grad is None are ignored
    norm_type: order of the norm; infinity gives the largest absolute entry
    Returns a scalar tensor (zero when no parameter has a gradient).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if not grads:
        return torch.tensor(0.)

    # everything is gathered on the device of the first gradient
    device = grads[0].device
    if norm_type == inf:
        return max(g.abs().max().to(device) for g in grads)
    per_param = [torch.norm(g, norm_type).to(device) for g in grads]
    return torch.norm(torch.stack(per_param), norm_type)


def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
    """Write a checkpoint for ``epoch`` under args.output_dir.

    With a loss scaler, a dict holding the model, optimizer and scaler state,
    the epoch and ``args`` is saved (rank 0 only) as 'checkpoint-<epoch>.pth'.
    Without one, saving is delegated to ``model.save_checkpoint``.
    """
    output_dir = Path(args.output_dir)
    tag = str(epoch)

    if loss_scaler is None:
        model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % tag,
                              client_state={'epoch': epoch})
        return

    target = output_dir / ('checkpoint-%s.pth' % tag)
    payload = {
        'model': model_without_ddp.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'scaler': loss_scaler.state_dict(),
        'args': args,
    }
    save_on_master(payload, target)


def load_model(args, model_without_ddp):
    """Load optional encoder and decoder checkpoints into ``model_without_ddp``.

    args.checkpoint_encoder: path (or falsy); every entry of its 'model' dict
        whose key exists in the model is used (shapes are not checked here)
    args.checkpoint_decoder: path (or falsy); only keys that start with
        'decoder', exist in the model and have a matching shape are used
    Decoder entries override encoder entries with the same key. The merged
    weights are loaded non-strictly and the load result is printed.
    """
    target_state = model_without_ddp.state_dict()

    merged = {}
    if args.checkpoint_encoder:
        print(f'Use encoder weights: {args.checkpoint_encoder}')
        encoder_state = torch.load(args.checkpoint_encoder, map_location='cpu')['model']
        merged = {name: weight for name, weight in encoder_state.items() if name in target_state}

    decoder_weights = {}
    if args.checkpoint_decoder:
        print(f'Use decoder weights: {args.checkpoint_decoder}')
        decoder_state = torch.load(args.checkpoint_decoder, map_location='cpu')['model']
        for name, weight in decoder_state.items():
            if not name.startswith('decoder') or name not in target_state:
                continue
            if weight.shape != target_state[name].shape:
                print(f"Delete: '{name}'; "
                      f"Weight shape: {weight.shape}; "
                      f"Model shape: {target_state[name].shape}")
                continue
            decoder_weights[name] = weight

    merged.update(decoder_weights)

    if merged:
        result = model_without_ddp.load_state_dict(merged, strict=False)
        print(result)


def all_reduce_mean(x):
    """Return the mean of ``x`` over all processes; ``x`` itself when the world size is 1."""
    world_size = get_world_size()
    if world_size <= 1:
        return x
    reduced = torch.tensor(x).cuda()
    dist.all_reduce(reduced)
    reduced /= world_size
    return reduced.item()
''')

In [None]:
with open('/content/utils/lr_sched.py', 'w') as f:
    f.write('''
import math

def adjust_learning_rate(optimizer, epoch, args):
    """Set the learning rate: linear warmup, then half-cycle cosine decay.

    epoch: may be fractional for a per-iteration schedule
    args: provides lr, min_lr, warmup_epochs and epochs
    Each param group gets the new rate, multiplied by its own "lr_scale" when
    that key is present. Returns the unscaled rate.
    """
    warmup = args.warmup_epochs
    if epoch < warmup:
        lr = args.lr * epoch / warmup
    else:
        phase = math.pi * (epoch - warmup) / (args.epochs - warmup)
        lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * (1. + math.cos(phase))

    for group in optimizer.param_groups:
        group["lr"] = lr * group["lr_scale"] if "lr_scale" in group else lr
    return lr
''')

In [None]:
with open('/content/utils/engine_pretrain.py', 'w') as f:
    f.write('''
import math
import sys

import torch
import numpy as np
import torchvision

import utils.misc as misc
import utils.lr_sched as lr_sched


def train_one_epoch(model: torch.nn.Module,
                    data_loader, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    log_writer=None,
                    args=None):
    """Run one pre-training epoch with gradient accumulation and mixed precision.

    model: called as ``model(images)`` and expected to return (loss, pred, mask);
        must also provide ``unpatchify(pred)``
    data_loader: yields mappings with an 'image' entry laid out as (N, C, H, W);
        only channel 0 is used, replicated to three channels
    loss_scaler: callable performing backward and, when asked, the optimizer step
    log_writer: optional TensorBoard-style writer (add_scalar / add_image / log_dir)
    args: provides accum_iter plus the fields read by the lr scheduler
    Returns a dict mapping each meter name to its global average.
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    accum_iter = args.accum_iter

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    # Input and reconstruction of the most recent batch; they stay None when
    # the loader yields nothing, so the image logging below can be skipped.
    stacked_img = None
    y1 = None

    for data_iter_step, samples in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

        input_image = samples['image']
        # Replicate channel 0 into three identical channels directly in torch;
        # this avoids a numpy round trip that only works for CPU tensors.
        stacked_img = input_image[:, :1].repeat(1, 3, 1, 1).to(device, non_blocking=True)
        with torch.cuda.amp.autocast():
            loss, pred, mask = model(stacked_img)
            y1 = model.unpatchify(pred)
        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # the optimizer only steps on the last micro-batch of each accumulation window
        is_update_step = (data_iter_step + 1) % accum_iter == 0
        loss /= accum_iter
        loss_scaler(loss, optimizer, parameters=model.parameters(),
                    update_grad=is_update_step)
        if is_update_step:
            optimizer.zero_grad()

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None and is_update_step:
            # We use epoch_1000x as the x-axis in tensorboard.
            # This calibrates different curves when batch size changes.
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', lr, epoch_1000x)

    # Once per epoch, log up to ten inputs and reconstructions from the last
    # batch. Guarded because log_writer is optional and the loader may be empty.
    if log_writer is not None and stacked_img is not None:
        grid_batch_input = torchvision.utils.make_grid(stacked_img[0:10, :, :, :])
        log_writer.add_image(f"input_image/{epoch}", grid_batch_input, epoch)
        grid_batch_output = torchvision.utils.make_grid(y1[0:10, :, :, :])
        log_writer.add_image(f'output_image/{epoch}', grid_batch_output, epoch)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
''')

In [None]:
with open('/content/swin_unet.py', 'w') as f:
    f.write('''
import torch
import torch.nn as nn
import torch.nn.functional as func

from einops import rearrange
from typing import Optional


class DropPath(nn.Module):
    """Stochastic depth: during training, zero whole samples with probability
    ``drop_prob`` and rescale the survivors by 1 / (1 - drop_prob)."""

    def __init__(self, drop_prob: float = 0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        active = self.training and self.drop_prob != 0.
        if not active:
            return x

        survive_prob = 1 - self.drop_prob
        # one random draw per sample, broadcast over all remaining dimensions
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = survive_prob + torch.rand(mask_shape, dtype=x.dtype, device=x.device)
        mask.floor_()  # 1 with probability survive_prob, otherwise 0
        return x.div(survive_prob) * mask


class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping patches and project each one to ``embed_dim``.

    patch_size: side length of a square patch (also the convolution stride)
    in_c: number of input channels
    embed_dim: number of output channels per patch
    norm_layer: optional callable building a normalisation layer from embed_dim
    """

    def __init__(self, patch_size: int = 4, in_c: int = 3, embed_dim: int = 96, norm_layer: nn.Module = None):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=(patch_size,) * 2, stride=(patch_size,) * 2)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def padding(self, x: torch.Tensor) -> torch.Tensor:
        """Zero-pad the bottom/right of a (B, C, H, W) tensor up to the next multiple of patch_size.

        Each dimension is padded independently, so a dimension that is already
        divisible receives no padding even when the other one needs some.
        """
        _, _, H, W = x.shape
        pad_h = (-H) % self.patch_size
        pad_w = (-W) % self.patch_size
        if pad_h or pad_w:
            # pad order is last dimension first: (W_left, W_right, H_top, H_bottom, C_front, C_back)
            x = func.pad(x, (0, pad_w, 0, pad_h, 0, 0))
        return x

    def forward(self, x):
        """Map (B, C, H, W) images to channels-last patch features (B, H', W', embed_dim)."""
        x = self.padding(x)
        x = self.proj(x)
        x = rearrange(x, 'B C H W -> B H W C')
        x = self.norm(x)
        return x


class PatchMerging(nn.Module):
    """Halve the spatial resolution of a channels-last map and double its channels.

    Every 2x2 neighbourhood is concatenated along the channel axis (4 * dim),
    normalised, then linearly reduced to 2 * dim.
    """

    def __init__(self, dim: int, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.norm = norm_layer(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    @staticmethod
    def padding(x: torch.Tensor) -> torch.Tensor:
        """Zero-pad a (B, H, W, C) tensor at the bottom/right so H and W are even."""
        _, H, W, _ = x.shape

        extra_h, extra_w = H % 2, W % 2
        if extra_h or extra_w:
            # pad order is last dimension first: C untouched, then W, then H
            x = func.pad(x, (0, 0, 0, extra_w, 0, extra_h))
        return x

    @staticmethod
    def merging(x: torch.Tensor) -> torch.Tensor:
        """Concatenate the four pixels of every 2x2 block along the channel axis.

        Channel order: top-left, bottom-left, top-right, bottom-right.
        """
        quadrants = [x[:, row::2, col::2, :] for col in (0, 1) for row in (0, 1)]
        return torch.cat(quadrants, -1)

    def forward(self, x):
        merged = self.merging(self.padding(x))
        return self.reduction(self.norm(merged))


class PatchExpanding(nn.Module):
    """Double the spatial resolution of a channels-last map and halve its channels.

    A bias-free linear layer widens the channels to 2 * dim; those are then
    redistributed over a 2x2 spatial block, leaving dim // 2 channels per
    position, which are normalised.
    """

    def __init__(self, dim: int, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.expand = nn.Linear(dim, 2 * dim, bias=False)
        self.norm = norm_layer(dim // 2)

    def forward(self, x: torch.Tensor):
        widened = self.expand(x)
        upsampled = rearrange(widened, 'B H W (P1 P2 C) -> B (H P1) (W P2) C', P1=2, P2=2)
        return self.norm(upsampled)


class FinalPatchExpanding(nn.Module):
    """Upsample a channels-last map by a factor of 4 per side, keeping ``dim`` channels.

    A bias-free linear layer widens the channels to 16 * dim; those are then
    redistributed over a 4x4 spatial block and normalised.
    """

    def __init__(self, dim: int, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.expand = nn.Linear(dim, 16 * dim, bias=False)
        self.norm = norm_layer(dim)

    def forward(self, x: torch.Tensor):
        widened = self.expand(x)
        upsampled = rearrange(widened, 'B H W (P1 P2 C) -> B (H P1) (W P2) C', P1=4, P2=4)
        return self.norm(upsampled)


class Mlp(nn.Module):
    """Two-layer feed-forward network: linear, activation, dropout, linear, dropout.

    hidden_features and out_features fall back to in_features when omitted
    (or otherwise falsy).
    """

    def __init__(self, in_features: int, hidden_features: int = None, out_features: int = None,
                 act_layer=nn.GELU, drop: float = 0.):
        super().__init__()
        hidden_width = hidden_features or in_features
        output_width = out_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))


class WindowAttention(nn.Module):
    """Multi-head self-attention inside non-overlapping square windows, with a
    learned relative position bias and optional cyclic shift.

    dim: number of input/output channels
    window_size: side length of a window; H and W of the input must be multiples of it
    num_heads: number of attention heads (dim is split evenly between them)
    qkv_bias: whether the query/key/value projection has a bias
    attn_drop / proj_drop: dropout on the attention weights / on the output projection
    shift: when True the map is rolled by window_size // 2 before partitioning
        and cross-boundary attention is masked out

    Inputs and outputs are channels-last: (B, H, W, dim).
    """

    def __init__(self, dim: int, window_size: int, num_heads: int, qkv_bias: Optional[bool] = True,
                 attn_drop: Optional[float] = 0., proj_drop: Optional[float] = 0., shift: bool = False):
        super().__init__()
        self.window_size = window_size
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        if shift:
            self.shift_size = window_size // 2
        else:
            self.shift_size = 0

        # one learnable bias per head for every possible (dy, dx) offset inside a window
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1) ** 2, num_heads))
        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)

        # For every pair of positions in a window, precompute the row of the
        # bias table that holds their relative offset.
        coords_size = torch.arange(self.window_size)
        coords = torch.stack(torch.meshgrid([coords_size, coords_size], indexing="ij"))
        coords_flatten = torch.flatten(coords, 1)

        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        # shift offsets to start at 0, then flatten (dy, dx) into a single index
        relative_coords[:, :, 0] += self.window_size - 1
        relative_coords[:, :, 1] += self.window_size - 1
        relative_coords[:, :, 0] *= 2 * self.window_size - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def window_partition(self, x: torch.Tensor) -> torch.Tensor:
        """Split (B, H, W, C) into windows: (B * H/ws * W/ws, ws, ws, C), row-major over windows."""
        B, H, W, C = x.shape
        ws = self.window_size

        x = x.reshape(B, H // ws, ws, W // ws, ws, C)
        return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, ws, ws, C)

    def window_reverse(self, windows: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """Inverse of window_partition: reassemble windows into a (B, H, W, C) map."""
        ws = self.window_size
        rows, cols = H // ws, W // ws  # each axis uses its own window count

        x = windows.reshape(-1, rows, cols, ws, ws, windows.shape[-1])
        return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, H, W, windows.shape[-1])

    def create_mask(self, x: torch.Tensor) -> torch.Tensor:
        """Build the additive attention mask for the shifted configuration.

        Positions that end up in the same window only because of the cyclic
        roll get -100 added to their attention logits; all other pairs get 0.
        Returns a tensor of shape (num_windows, ws*ws, ws*ws).
        """
        _, H, W, _ = x.shape

        assert H % self.window_size == 0 and W % self.window_size == 0, "H or W is not divisible by window_size"

        # label each region of the rolled map with a distinct id
        img_mask = torch.zeros((1, H, W, 1), device=x.device)
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = self.window_partition(img_mask)
        mask_windows = mask_windows.reshape(-1, self.window_size * self.window_size)
        # non-zero difference means the two positions come from different regions
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)

        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x):
        _, H, W, _ = x.shape

        if self.shift_size > 0:
            x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            mask = self.create_mask(x)
        else:
            mask = None

        x = self.window_partition(x)
        Bn, Mh, Mw, _ = x.shape
        tokens = Mh * Mw
        x = x.reshape(Bn, tokens, -1)

        # the fused projection is laid out as (q|k|v, head, head_dim) along the channel axis
        qkv = self.qkv(x).reshape(Bn, tokens, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # each (Bn, heads, tokens, head_dim)
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size ** 2, self.window_size ** 2, -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # the same per-window mask is applied to every image in the batch
            nW = mask.shape[0]
            attn = attn.view(Bn // nW, nW, self.num_heads, tokens, tokens) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, tokens, tokens)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = attn @ v
        # merge heads back into the channel axis and restore the window layout
        x = x.permute(0, 2, 1, 3).reshape(Bn, Mh, Mw, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        # the window count along the width comes from W, so non-square maps
        # are reassembled with the right batch size and shape
        x = self.window_reverse(x, H, W)

        if self.shift_size > 0:
            x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        return x


class SwinTransformerBlock(nn.Module):
    """Pre-norm transformer block: windowed attention then an MLP, each wrapped
    in a residual connection with optional stochastic depth.

    shift selects the shifted-window variant of the attention; mlp_ratio
    scales the hidden width of the MLP relative to dim.
    """

    def __init__(self, dim, num_heads, window_size=7, shift=False, mlp_ratio=4., qkv_bias=True,
                 drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(dim, window_size=window_size, num_heads=num_heads, qkv_bias=qkv_bias,
                                    attn_drop=attn_drop, proj_drop=drop, shift=shift)
        # the same drop-path module is shared by both residual branches
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden_width = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden_width, act_layer=act_layer, drop=drop)

    def forward(self, x):
        # attention branch
        attended = self.drop_path(self.attn(self.norm1(x)))
        x = attended + x

        # feed-forward branch
        transformed = self.drop_path(self.mlp(self.norm2(x)))
        return transformed + x


class BasicBlock(nn.Module):
    """One encoder stage: ``depths[index]`` Swin blocks at width
    ``embed_dim * 2 ** index``, optionally followed by patch merging.

    Blocks alternate between regular (even position) and shifted (odd
    position) windows. Drop-path rates grow linearly from 0 to ``drop_path``
    across all blocks of all stages; this stage takes its own slice.
    """

    def __init__(self, index: int, embed_dim: int = 96, window_size: int = 7, depths: tuple = (2, 2, 6, 2),
                 num_heads: tuple = (3, 6, 12, 24), mlp_ratio: float = 4., qkv_bias: bool = True,
                 drop_rate: float = 0., attn_drop_rate: float = 0., drop_path: float = 0.1,
                 norm_layer=nn.LayerNorm, patch_merging: bool = True):
        super().__init__()
        stage_depth = depths[index]
        stage_dim = embed_dim * 2 ** index
        stage_heads = num_heads[index]

        all_rates = torch.linspace(0, drop_path, sum(depths)).tolist()
        stage_rates = all_rates[sum(depths[:index]):sum(depths[:index + 1])]

        stage_blocks = []
        for position in range(stage_depth):
            stage_blocks.append(SwinTransformerBlock(
                dim=stage_dim,
                num_heads=stage_heads,
                window_size=window_size,
                shift=(position % 2 != 0),
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=stage_rates[position],
                norm_layer=norm_layer))
        self.blocks = nn.ModuleList(stage_blocks)

        self.downsample = PatchMerging(dim=stage_dim, norm_layer=norm_layer) if patch_merging else None

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        if self.downsample is None:
            return x
        return self.downsample(x)


class BasicBlockUp(nn.Module):
    """One decoder stage: Swin blocks followed by optional patch expanding.

    ``index`` counts decoder stages; it is mapped to the encoder-side stage
    ``len(depths) - index - 2``, whose depth, width, head count and drop-path
    slice are reused here.
    """

    def __init__(self, index: int, embed_dim: int = 96, window_size: int = 7, depths: tuple = (2, 2, 6, 2),
                 num_heads: tuple = (3, 6, 12, 24), mlp_ratio: float = 4., qkv_bias: bool = True,
                 drop_rate: float = 0., attn_drop_rate: float = 0., drop_path: float = 0.1,
                 patch_expanding: bool = True, norm_layer=nn.LayerNorm):
        super().__init__()
        stage = len(depths) - index - 2
        stage_depth = depths[stage]
        stage_dim = embed_dim * 2 ** stage
        stage_heads = num_heads[stage]

        all_rates = torch.linspace(0, drop_path, sum(depths)).tolist()
        stage_rates = all_rates[sum(depths[:stage]):sum(depths[:stage + 1])]

        stage_blocks = []
        for position in range(stage_depth):
            stage_blocks.append(SwinTransformerBlock(
                dim=stage_dim,
                num_heads=stage_heads,
                window_size=window_size,
                shift=(position % 2 != 0),
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=stage_rates[position],
                norm_layer=norm_layer))
        self.blocks = nn.ModuleList(stage_blocks)

        if patch_expanding:
            self.upsample = PatchExpanding(dim=stage_dim, norm_layer=norm_layer)
        else:
            self.upsample = nn.Identity()

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return self.upsample(x)


class SwinUnet(nn.Module):
    """
    U-shaped network built from Swin Transformer stages.

    The encoder (self.layers) records its input at every stage; the decoder
    (self.layers_up) concatenates the matching recorded tensor on the channel
    axis, fuses it with a linear layer and runs the decoder stage.  A final
    patch expansion and a 1x1 convolution produce num_classes output channels.
    """

    def __init__(self, patch_size: int = 4, in_chans: int = 3, num_classes: int = 1000, embed_dim: int = 96,
                 window_size: int = 7, depths: tuple = (2, 2, 6, 2), num_heads: tuple = (3, 6, 12, 24),
                 mlp_ratio: float = 4., qkv_bias: bool = True, drop_rate: float = 0., attn_drop_rate: float = 0.,
                 drop_path_rate: float = 0.1, norm_layer=nn.LayerNorm, patch_norm: bool = True):
        super().__init__()

        # Hyper-parameters shared by the stage builders below.
        self.window_size = window_size
        self.depths = depths
        self.num_heads = num_heads
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.drop_rate = drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.drop_path = drop_path_rate
        self.norm_layer = norm_layer

        embed_norm = norm_layer if patch_norm else None
        bottleneck_dim = embed_dim * 2 ** (len(depths) - 1)

        # Sub-modules are created in a fixed order so that parameter names and
        # the random initialisation sequence stay the same.
        self.patch_embed = PatchEmbedding(
            patch_size=patch_size, in_c=in_chans, embed_dim=embed_dim,
            norm_layer=embed_norm)
        self.pos_drop = nn.Dropout(p=drop_rate)
        self.layers = self.build_layers()
        self.first_patch_expanding = PatchExpanding(dim=bottleneck_dim, norm_layer=norm_layer)
        self.layers_up = self.build_layers_up()
        self.skip_connection_layers = self.skip_connection()
        self.norm_up = norm_layer(embed_dim)
        self.final_patch_expanding = FinalPatchExpanding(dim=embed_dim, norm_layer=norm_layer)
        self.head = nn.Conv2d(in_channels=embed_dim, out_channels=num_classes, kernel_size=(1, 1), bias=False)
        self.apply(self.init_weights)

    @staticmethod
    def init_weights(m):
        """Initialise Linear layers (truncated normal, zero bias) and LayerNorm (weight 1, bias 0)."""
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def build_layers(self):
        """Create the encoder stages; every stage except the last one ends with patch merging."""
        last_stage = self.num_layers - 1
        stages = nn.ModuleList()
        for i in range(self.num_layers):
            stages.append(BasicBlock(
                index=i,
                depths=self.depths,
                embed_dim=self.embed_dim,
                num_heads=self.num_heads,
                drop_path=self.drop_path,
                window_size=self.window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=self.qkv_bias,
                drop_rate=self.drop_rate,
                attn_drop_rate=self.attn_drop_rate,
                norm_layer=self.norm_layer,
                patch_merging=(i != last_stage)))
        return stages

    def build_layers_up(self):
        """Create the decoder stages; every stage except the last one ends with patch expanding."""
        stages_up = nn.ModuleList()
        for i in range(self.num_layers - 1):
            stages_up.append(BasicBlockUp(
                index=i,
                depths=self.depths,
                embed_dim=self.embed_dim,
                num_heads=self.num_heads,
                drop_path=self.drop_path,
                window_size=self.window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=self.qkv_bias,
                drop_rate=self.drop_rate,
                attn_drop_rate=self.attn_drop_rate,
                patch_expanding=(i < self.num_layers - 2),
                norm_layer=self.norm_layer))
        return stages_up

    def skip_connection(self):
        """Create one Linear(2 * dim, dim) per decoder stage to fuse the concatenated skip tensor."""
        fusers = nn.ModuleList()
        for i in range(self.num_layers - 1):
            width = self.embed_dim * 2 ** (self.num_layers - 2 - i)
            fusers.append(nn.Linear(width * 2, width))
        return fusers

    def forward(self, x):
        """Encode, decode with skip connections, and return the per-class map in B C H W layout."""
        x = self.pos_drop(self.patch_embed(x))

        # Remember the input of every encoder stage for the skip connections.
        skips = []
        for stage in self.layers:
            skips.append(x)
            x = stage(x)

        x = self.first_patch_expanding(x)

        for i, (fuse, stage_up) in enumerate(zip(self.skip_connection_layers, self.layers_up)):
            # Decoder stage i pairs with the (i + 2)-th recorded tensor from the end.
            x = torch.cat([x, skips[-(i + 2)]], -1)
            x = stage_up(fuse(x))

        x = self.final_patch_expanding(self.norm_up(x))

        x = rearrange(x, 'B H W C -> B C H W')
        return self.head(x)
''')

In [None]:
with open('/content/swin_mae_inference.py', 'w') as f:
    f.write('''
import sys
sys.path.insert(0, '/content/drive/MyDrive/ie643_course_project_24M1644')
import os
import torch
import torch.nn as nn
import numpy as np
from functools import partial
from einops import rearrange
from swin_unet import PatchEmbedding, BasicBlock, PatchExpanding, BasicBlockUp
from utils.pos_embed import get_2d_sincos_pos_embed


class SwinMAE(nn.Module):
    """
    Masked Autoencoder with Swin Transformer backbone (CUDA-safe version)

    The encoder embeds the image into a (B, H, W, C) token grid, replaces the
    masked tokens with a learnable mask token and runs the Swin stages.  The
    decoder expands the latent and predicts the pixels of every patch.

    Masking operates on square cells of r x r tokens.  Cells are numbered
    row-major (cell = row * d + col, with d cells per side).  They are chosen
    at random by window_masking, or given explicitly by the caller through
    window_masking_ / forward(x, window_arr).
    """

    def __init__(self, img_size=224, patch_size=4, mask_ratio=0.25, in_chans=3,
                 decoder_embed_dim=768, norm_pix_loss=False,
                 depths=(2, 2, 2, 2), embed_dim=96, num_heads=(3, 6, 12, 24),
                 window_size=7, qkv_bias=True, mlp_ratio=4.,
                 drop_path_rate=0.1, drop_rate=0., attn_drop_rate=0.,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6), patch_norm=True):
        super().__init__()
        self.mask_ratio = mask_ratio
        assert img_size % patch_size == 0
        self.num_patches = (img_size // patch_size) ** 2
        self.patch_size = patch_size
        self.norm_pix_loss = norm_pix_loss
        self.depths = depths
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.drop_path = drop_path_rate
        self.drop_rate = drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.norm_layer = norm_layer

        # Encoder
        self.patch_embed = PatchEmbedding(patch_size=patch_size, in_c=in_chans, embed_dim=embed_dim,
                                          norm_layer=norm_layer if patch_norm else None)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim), requires_grad=False)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.layers = self.build_layers()

        # Decoder
        self.first_patch_expanding = PatchExpanding(dim=decoder_embed_dim, norm_layer=norm_layer)
        self.layers_up = self.build_layers_up()
        self.norm_up = norm_layer(embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim // 8, patch_size ** 2 * in_chans, bias=True)

        self.initialize_weights()

    def initialize_weights(self):
        """Fill pos_embed with fixed 2D sin-cos values and initialise the mask token and all layers."""
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.num_patches ** 0.5), cls_token=False)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
        nn.init.normal_(self.mask_token, std=.02)
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m):
        """Xavier-uniform for Linear weights (zero bias); weight 1 / bias 0 for LayerNorm."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    # ---- Patch operations ----
    def patchify(self, imgs):
        """
        imgs: (N, C, H, W) with H == W divisible by patch_size
        return: (N, L, patch_size**2 * C) with L = (H / patch_size) ** 2

        The channel count is read from imgs instead of being fixed to 3.
        """
        p = self.patch_size
        n, c = imgs.shape[0], imgs.shape[1]
        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(n, c, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        return x.reshape(n, h * w, p ** 2 * c)

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 * C) with L a perfect square
        return: (N, C, H, W)

        Inverse of patchify; the channel count is derived from the last axis.
        """
        p = self.patch_size
        h = w = int(x.shape[1] ** 0.5)
        c = x.shape[2] // (p * p)
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        return x.reshape(x.shape[0], c, h * p, w * p)

    # ---- Window masking ----
    def _mask_tokens(self, x, sparse_keep, sparse_restore, r, remove):
        """
        Shared back end of window_masking and window_masking_.

        x: (B, L, D) flattened token grid
        sparse_keep: (B, K) long tensor with the indices of the cells to keep
        sparse_restore: passed through unchanged when remove is True
        Returns (x_masked, mask) or, when remove is True,
        (x_masked, mask, sparse_restore).  mask is (B, L), 0 = kept, 1 = masked.
        """
        B, L, D = x.shape
        device = x.device
        side = int(L ** 0.5)
        d = int(L ** 0.5 // r)

        # Token index of the top-left corner of every kept cell ...
        sparse_keep = sparse_keep.long().to(device)
        corner = torch.div(sparse_keep, d, rounding_mode='floor') * d * r ** 2 + sparse_keep % d * r
        # ... then every token of the r x r cell, in the original (i, j) order.
        offsets = [side * i + j for i in range(r) for j in range(r)]
        index_keep = torch.cat([corner + off for off in offsets], dim=1)

        # Boolean map of kept tokens; built on-device instead of a per-sample
        # numpy setdiff1d round trip.
        keep_flag = torch.zeros(B, L, dtype=torch.bool, device=device)
        keep_flag[torch.arange(B, device=device).unsqueeze(1), index_keep] = True
        mask = (~keep_flag).float()

        if remove:
            x_masked = torch.gather(x, dim=1, index=index_keep.unsqueeze(-1).repeat(1, 1, D))
            kept_side = int(x_masked.shape[1] ** 0.5)
            x_masked = x_masked.reshape(B, kept_side, -1, D)
            return x_masked, mask, sparse_restore

        token = self.mask_token.to(device=device, dtype=x.dtype)
        x_masked = torch.where(keep_flag.unsqueeze(-1), x, token)
        x_masked = x_masked.reshape(B, side, L // side, D)
        return x_masked, mask

    def window_masking(self, x, r=4, remove=False, mask_len_sparse=False):
        """
        Randomly mask a fraction mask_ratio of the r x r token cells.

        x: (B, H, W, C) token grid.  mask_len_sparse is accepted for interface
        compatibility and is not used.
        """
        B, H, W, C = x.shape
        x = x.reshape(B, H * W, C)
        d = int((H * W) ** 0.5 // r)

        noise = torch.rand(B, d ** 2, device=x.device)
        sparse_shuffle = torch.argsort(noise, dim=1)
        sparse_restore = torch.argsort(sparse_shuffle, dim=1)
        sparse_keep = sparse_shuffle[:, :int(d ** 2 * (1 - self.mask_ratio))]
        return self._mask_tokens(x, sparse_keep, sparse_restore, r, remove)

    def window_masking_(self, x, window_arr, r=4, remove=False, mask_len_sparse=False, index=27):
        """
        Device-safe variant used in inference with a fixed window_arr.

        window_arr lists the cells to mask (row-major cell index); every other
        cell is kept.  Entries outside the cell grid are ignored.  The number
        of cells is derived from x rather than fixed to 196, and no random
        numbers are drawn.  mask_len_sparse and index are accepted for
        interface compatibility and are not used.
        """
        B, H, W, C = x.shape
        x = x.reshape(B, H * W, C)
        device = x.device
        d = int((H * W) ** 0.5 // r)
        n_cells = d ** 2

        requested = {int(k) for k in window_arr}
        masked = sorted(k for k in requested if 0 <= k < n_cells)
        kept = [k for k in range(n_cells) if k not in requested]

        sparse_keep = torch.tensor([kept], dtype=torch.long, device=device).repeat(B, 1)
        # Deterministic "shuffle": kept cells first, masked cells last.
        sparse_shuffle = torch.tensor([kept + masked], dtype=torch.long, device=device).repeat(B, 1)
        sparse_restore = torch.argsort(sparse_shuffle, dim=1)
        return self._mask_tokens(x, sparse_keep, sparse_restore, r, remove)

    # ---- Network construction ----
    def build_layers(self):
        """Create the encoder stages; every stage except the last one ends with patch merging."""
        layers = nn.ModuleList()
        for i in range(len(self.depths)):
            layers.append(
                BasicBlock(
                    index=i,
                    depths=self.depths,
                    embed_dim=self.embed_dim,
                    num_heads=self.num_heads,
                    drop_path=self.drop_path,
                    window_size=self.window_size,
                    mlp_ratio=self.mlp_ratio,
                    qkv_bias=self.qkv_bias,
                    drop_rate=self.drop_rate,
                    attn_drop_rate=self.attn_drop_rate,
                    norm_layer=self.norm_layer,
                    patch_merging=(i < len(self.depths) - 1)
                )
            )
        return layers

    def build_layers_up(self):
        """Create the decoder stages; every stage except the last one ends with patch expanding."""
        layers_up = nn.ModuleList()
        for i in range(len(self.depths) - 1):
            layers_up.append(
                BasicBlockUp(
                    index=i,
                    depths=self.depths,
                    embed_dim=self.embed_dim,
                    num_heads=self.num_heads,
                    drop_path=self.drop_path,
                    window_size=self.window_size,
                    mlp_ratio=self.mlp_ratio,
                    qkv_bias=self.qkv_bias,
                    drop_rate=self.drop_rate,
                    attn_drop_rate=self.attn_drop_rate,
                    patch_expanding=(i < len(self.depths) - 2),
                    norm_layer=self.norm_layer
                )
            )
        return layers_up

    # ---- Forward passes ----
    def forward_encoder(self, x, window_arr=None):
        """
        Embed, mask and encode the image.

        When window_arr is given, exactly those cells are masked
        (window_masking_); previously the argument was ignored and a random
        mask was always used.  With window_arr=None the random mask is used.
        """
        x = self.patch_embed(x)
        if window_arr is None:
            x, mask = self.window_masking(x, remove=False, mask_len_sparse=False)
        else:
            x, mask = self.window_masking_(x, window_arr, remove=False, mask_len_sparse=False)
        for layer in self.layers:
            x = layer(x)
        return x, mask

    def forward_decoder(self, x):
        """Expand the latent back to the token grid and predict the pixels of every patch."""
        x = self.first_patch_expanding(x)
        for layer in self.layers_up:
            x = layer(x)
        x = self.norm_up(x)
        x = rearrange(x, 'B H W C -> B (H W) C')
        x = self.decoder_pred(x)
        return x

    def forward_loss(self, imgs, pred, mask):
        """
        Mean squared error over the masked patches only.

        imgs: (N, C, H, W); pred: (N, L, p*p*C); mask: (N, L) with 1 = masked.
        The denominator is clamped to 1 so an all-visible mask gives 0, not NaN.
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6) ** 0.5
        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)
        loss = (loss * mask).sum() / mask.sum().clamp(min=1)
        return loss

    def forward(self, x, window_arr=None):
        """Return (loss, pred, mask); window_arr selects the cells to mask, None means random masking."""
        latent, mask = self.forward_encoder(x, window_arr)
        pred = self.forward_decoder(latent)
        loss = self.forward_loss(x, pred, mask)
        return loss, pred, mask


def swin_mae(**kwargs):
    """
    Build a SwinMAE with the default 224x224 / patch-4 configuration.

    Keyword arguments override the presets below.  Previously, passing any of
    the preset names (for example img_size) raised a duplicate-keyword
    TypeError; now the caller value wins.
    """
    config = dict(
        img_size=224, patch_size=4, in_chans=3,
        decoder_embed_dim=768,
        depths=(2, 2, 2, 2), embed_dim=96, num_heads=(3, 6, 12, 24),
        window_size=7, qkv_bias=True, mlp_ratio=4,
        drop_path_rate=0.1, drop_rate=0, attn_drop_rate=0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6))
    config.update(kwargs)
    return SwinMAE(**config)
''')

### Interface

In [None]:
import os, sys, time, glob, json
from pathlib import Path
import numpy as np
import torch
import h5py
import matplotlib.pyplot as plt
import cv2
import argparse
from functools import partial
import torch.nn as nn
from skimage.segmentation import felzenszwalb
from skimage.measure import regionprops
from sklearn.metrics import precision_recall_curve, average_precision_score
from matplotlib.colors import LinearSegmentedColormap
import gradio as gr
import tempfile
import nest_asyncio

nest_asyncio.apply()

# ---- CONFIG - edit these if needed ----
WINDOW_SIZE = 32  # side length, in pixels, of the square sliding window (gamma)
WINDOW_STEP = 32  # stride, in pixels, between consecutive windows (k)
WINDOW_BATCH = 2  # number of windows handed to run_windows_batched per call (presumably tuned for an L4 GPU)
SEED = 42  # seed for the NumPy and PyTorch random generators below
# Slices whose ground-truth mask sums to <= this value are skipped in
# process_single; with -1 no slice is ever skipped (a mask sum is never negative).
MIN_GT_PIXELS = -1
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# -------------------------------
print("Device:", DEVICE)
# Seed both random generators so repeated runs draw the same random numbers.
np.random.seed(SEED)
torch.manual_seed(SEED)

from swin_mae_inference import SwinMAE  # Assuming swin_mae_inference.py is in the same directory

# Custom yellow colormap (black -> yellow)
YELLOW_MAP = LinearSegmentedColormap.from_list("black_to_yellow", ["black", "yellow"], N=256)


# --- Utility functions ------------------------------------------------------
def coordinates_to_patch_indexes(x, y, patch_size=16, image_size=224):
    """
    Convert pixel coordinates to patch-grid coordinates.

    Each coordinate is floor-divided by ``patch_size`` and clamped to the
    range ``[0, image_size // patch_size - 1]``.

    Returns:
        Tuple ``(px, py)`` of patch indices for ``x`` and ``y``.
    """
    last_patch = image_size // patch_size - 1

    def to_patch(coord):
        patch = max(coord // patch_size, 0)
        return min(patch, last_patch)

    return to_patch(x), to_patch(y)


def sliding_windows_for_image(image_shape, window_size, step, y_start=48, x_start=25):
    """
    Yield the top-left corner ``(x, y)`` of every sliding window.

    Windows are enumerated row by row (outer loop over ``y``, inner loop over
    ``x``) and only positions where the whole window fits inside the image are
    produced.

    Args:
        image_shape: Sequence whose first two entries are ``(H, W)``.
        window_size: Side length of the square window in pixels.
        step: Stride between consecutive windows in pixels.
        y_start: First row scanned. Defaults to 48, the offset that was
            previously hard-coded.
        x_start: First column scanned. Defaults to 25, the offset that was
            previously hard-coded.

    Yields:
        ``(x, y)`` tuples of window origins.
    """
    H, W = image_shape[:2]
    for y in range(y_start, H - window_size + 1, step):
        for x in range(x_start, W - window_size + 1, step):
            yield x, y


def minmax_normalization_uint8(image):
    """
    Linearly rescale ``image`` so its minimum maps to 0 and its maximum to 255.

    A constant image (max == min) yields an all-zero array. The result is cast
    to ``uint8`` with ``astype``.
    """
    lowest, highest = image.min(), image.max()
    if highest == lowest:
        return np.zeros_like(image, dtype=np.uint8)
    unit_range = (image - lowest) / (highest - lowest)
    return (255 * unit_range).astype(np.uint8)


def reconstruction_loss_lp(orig, recon):
    """
    Return the element-wise absolute difference between ``orig`` and ``recon``.

    An input with three dimensions and a last axis of length 3 is reduced to
    its first channel before the comparison; any other input is used as is.
    """

    def first_channel(arr):
        has_three_channels = arr.ndim == 3 and arr.shape[2] == 3
        return arr[..., 0] if has_three_channels else arr

    return np.abs(first_channel(orig) - first_channel(recon))


def compute_auprc(y_pred, y_true):
    """
    Compute the average precision and the precision-recall curve.

    Both inputs are flattened and the labels are cast to ``int`` before being
    passed to scikit-learn.

    Returns:
        ``(auprc, precisions, recalls, thresholds)``. If either scikit-learn
        call raises, the result is ``(nan, None, None, None)``.
    """
    scores = y_pred.flatten()
    labels = y_true.flatten()
    try:
        auprc = average_precision_score(labels.astype(int), scores)
        curve = precision_recall_curve(labels.astype(int), scores)
    except Exception:
        return float('nan'), None, None, None
    precisions, recalls, thresholds = curve
    return auprc, precisions, recalls, thresholds


def calculate_confusion_matrix(gt, pred):
    """
    Count true positives, false positives and false negatives.

    Only the values 1 (positive) and 0 (negative) are counted; any other value
    in ``gt`` or ``pred`` contributes to none of the three totals.

    Returns:
        Tuple ``(tp, fp, fn)``.
    """
    gt_pos, gt_neg = (gt == 1), (gt == 0)
    pred_pos, pred_neg = (pred == 1), (pred == 0)
    true_pos = np.sum(gt_pos & pred_pos)
    false_pos = np.sum(gt_neg & pred_pos)
    false_neg = np.sum(gt_pos & pred_neg)
    return true_pos, false_pos, false_neg


def calculate_dice_score(gt, pred):
    """
    Compute the Dice coefficient ``2*TP / (2*TP + FP + FN)``.

    Returns:
        Tuple ``(dice, tp, fp, fn)``; ``dice`` is 0.0 when the denominator is
        zero.
    """
    tp, fp, fn = calculate_confusion_matrix(gt, pred)
    total = 2 * tp + fp + fn
    if total > 0:
        dice = (2 * tp) / total
    else:
        dice = 0.0
    return dice, tp, fp, fn


# --- run_one_batch_of_windows -----------------------------------------------
def run_windows_batched(image_rgb, window_arrs, model, device):
    """
    Run the model once per window and collect the reconstructed images.

    Args:
        image_rgb: np.ndarray of shape (H, W, 3); it is converted to a float
            tensor of shape (1, 3, H, W) on ``device``.
        window_arrs: iterable of patch-index lists, one per window.
        model: callable invoked as ``model(x, window_arr)``; it must also
            expose ``unpatchify(pred)``.
        device: torch device the input tensor is moved to.

    Returns:
        List with one (H, W, 3) np.ndarray per entry of ``window_arrs``, in
        the same order.

    If ``model(x, window_arr)`` raises ``TypeError`` the call is retried as
    ``model(x)`` (note: that fallback does NOT apply window-specific masking).
    A tuple/list result is unpacked as ``(loss, pred, mask)``; any other
    result is taken to be ``pred`` itself.
    """
    batch = torch.from_numpy(image_rgb[np.newaxis]).float().permute(0, 3, 1, 2).to(device)  # (1,3,H,W)
    reconstructions = []

    for window_arr in window_arrs:
        with torch.no_grad():
            try:
                output = model(batch, window_arr)
            except TypeError:
                # The model probably only accepts (x,).
                output = model(batch)

            if isinstance(output, (tuple, list)):
                _, pred, _ = output
            else:
                pred = output

            # Back to image layout: tensor -> (3,H,W) array -> (H,W,3).
            recon_tensor = model.unpatchify(pred)
            recon_chw = recon_tensor.detach().cpu().numpy()[0]
            reconstructions.append(np.transpose(recon_chw, (1, 2, 0)))

            # Drop the references and release cached GPU memory before the next window.
            del pred, recon_tensor, recon_chw
            torch.cuda.empty_cache()

    return reconstructions
# --- main evaluation per-slice using sliding windows -------------------------
def evaluate_slice_with_sliding_window(image_rgb, label_mask, model,
                                       window_size=WINDOW_SIZE, step=WINDOW_STEP,
                                       window_batch=WINDOW_BATCH, device=DEVICE,
                                       save_dir=None, base_name="subj", slice_idx=0):
    """
    Compute combined heatmap for a single slice using sliding windows.

    For every window the model is run with the list of 16x16-pixel cells the
    window touches, and the absolute difference between the first channel of
    the input and of the model output inside the window is accumulated.
    Overlapping windows are averaged, near-black background (<= 0.01) is
    zeroed, and the map is divided by its 99.5th percentile and clipped.

    Args:
        image_rgb: (H, W, 3) array; values above 1 are treated as 0..255 and
            rescaled to 0..1.
        label_mask, save_dir, base_name, slice_idx: accepted for interface
            compatibility; not used here.
        model: passed to run_windows_batched.
        window_size, step: sliding-window geometry in pixels.
        window_batch: number of windows per run_windows_batched call; values
            below 1 are treated as 1.
        device: torch device for inference.

    Returns:
        (combined_heatmap, recon_for_vis): heatmap is (H, W) float32 in
        [0..1]; recon_for_vis is the model output of the first window of the
        last batch (NOT a stitched reconstruction), or a pseudo reconstruction
        when no window was processed.
    """
    H, W = image_rgb.shape[:2]
    img_float = image_rgb.astype(np.float32)
    if img_float.max() > 1.0:
        img_float = img_float / 255.0

    sum_heat = np.zeros((H, W), dtype=np.float32)
    coverage = np.zeros((H, W), dtype=np.float32)

    patch_size, image_size = 16, 224
    grid = image_size // patch_size

    windows = list(sliding_windows_for_image((H, W), window_size, step))
    window_patch_idx_list = []
    for (x, y) in windows:
        # Cells touched by the window: clamp the two opposite corners instead
        # of visiting every pixel.
        col0, row0 = coordinates_to_patch_indexes(x, y, patch_size=patch_size, image_size=image_size)
        col1, row1 = coordinates_to_patch_indexes(x + window_size - 1, y + window_size - 1,
                                                  patch_size=patch_size, image_size=image_size)
        # Row-major cell index (row * grid + col). The index used to be built
        # as col * grid + row, i.e. transposed with respect to the [y, x]
        # slicing used below and to the row-major cell numbering the model's
        # window masking decodes.
        window_patch_idx_list.append(
            [row * grid + col for row in range(row0, row1 + 1) for col in range(col0, col1 + 1)])

    batch_len = max(1, int(window_batch))  # guards against a non-advancing loop
    last_recon_full = None
    for start in range(0, len(windows), batch_len):
        batch_arrs = window_patch_idx_list[start:start + batch_len]
        batch_windows = windows[start:start + batch_len]

        recon_list = run_windows_batched(img_float, batch_arrs, model, device)

        for (x, y), recon in zip(batch_windows, recon_list):
            recon_win = recon[y:y + window_size, x:x + window_size, :]
            orig_win = img_float[y:y + window_size, x:x + window_size, :]

            o = orig_win[..., 0].astype(np.float32)
            r = recon_win[..., 0].astype(np.float32)

            sum_heat[y:y + window_size, x:x + window_size] += np.abs(o - r).astype(np.float32)
            coverage[y:y + window_size, x:x + window_size] += 1.0

        # Keep one raw model output as a convenient example for visualisation.
        # NOTE: this is NOT the stitched full reconstruction.
        last_recon_full = recon_list[0] if recon_list else None

        del recon_list, batch_arrs
        torch.cuda.empty_cache()

    # Average overlapping windows; pixels never covered keep a heat of zero.
    coverage_safe = np.where(coverage == 0, 1.0, coverage)
    avg_heat = sum_heat / coverage_safe
    avg_heat = avg_heat * (img_float[..., 0] > 0.01)

    # Robust scale: 99.5th percentile of the non-zero heat values.
    nonzero_vals = avg_heat[avg_heat > 0]
    if nonzero_vals.size > 0:
        p = np.percentile(nonzero_vals, 99.5)
        if p <= 0:
            p = nonzero_vals.max() if nonzero_vals.max() > 0 else 1.0
    else:
        p = 1.0

    combined_heatmap = np.clip(avg_heat / p, 0.0, 1.0).astype(np.float32)

    # Pseudo-reconstruction for visualisation only (original attenuated by the
    # heatmap); used when no model output is available.
    pseudo_recon = img_float.copy()
    pseudo_recon[..., 0] = np.clip(img_float[..., 0] * (1.0 - combined_heatmap), 0.0, 1.0)
    pseudo_recon = np.stack([pseudo_recon[..., 0]] * 3, axis=-1)

    recon_for_vis = last_recon_full if (last_recon_full is not None) else pseudo_recon

    return combined_heatmap, recon_for_vis


# --- post-processing & saving visualizations --------------------------------
def post_process_and_save(comb_heat_map, org_img, gt, recon_img,
                          save_dir, subj_name, slice_num,
                          heatmap_dpi=140, heatmap_figsize=(12, 4),
                          compact_dpi=140, compact_figsize=(12, 4)):
    """
    Turn a heatmap into a binary anomaly mask and save a three-panel figure.

    Save:
     - <subj>_slice_<NNN>_anomaly.png        : Atropos seg | Combined heatmap (yellow) | Final anomaly segmentation
    And return (gt_u8, anomaly_mask).

    Args:
        comb_heat_map: 2D heatmap; cast to float32 here.
        org_img: original slice; if 3D, only channel 0 is used.
        gt: ground-truth array; binarised with ``gt > 0``.
        recon_img: not used by this function.
        save_dir: output directory, created if missing.
        subj_name, slice_num: used to build the output file name.
        heatmap_dpi, heatmap_figsize: DPI and size of the saved figure.
        compact_dpi, compact_figsize: not used by this function.

    Returns:
        (gt_u8, anomaly_mask): two uint8 arrays with values 0/1.

    NOTE: The earlier 'reconstruction' PNG is intentionally NOT saved (to save space).
    """
    Path(save_dir).mkdir(parents=True, exist_ok=True)
    base = f"{subj_name}_slice_{slice_num:03d}"

    # prepare arrays
    org_gray = org_img[..., 0] if org_img.ndim == 3 else org_img
    org_gray = org_gray.astype(np.float32)
    comb_heat_map = comb_heat_map.astype(np.float32)
    gt_u8 = (gt > 0).astype(np.uint8)

    # --- Compute Atropos / segmentation or fallback ---
    # img_seg is only used for the first panel of the figure; it does not
    # influence the returned anomaly mask.
    try:
        import ants
        ants_image = ants.from_numpy(org_gray)
        img_ = ants.resample_image(ants_image, (224, 224), 1, 0)
        mask = ants.get_mask(img_)
        img_seg = ants.atropos(a=img_, m='[0.2,1x1]', c='[2,0]', i='kmeans[4]', x=mask)
        img_seg = img_seg['segmentation'].numpy()
        # resize back if necessary
        if img_seg.shape != org_gray.shape:
            import skimage.transform as sktf
            img_seg = sktf.resize(img_seg, org_gray.shape, order=0, preserve_range=True).astype(np.int32)
    except Exception:
        # Any failure above (including a missing ants package) falls back to
        # felzenszwalb segmentation of the grayscale slice.
        img_seg = felzenszwalb(org_gray, scale=75, sigma=0.1, min_size=10)

    # --- Final anomaly segmentation logic (same as your pipeline) ---
    # Inverted heatmap, zeroed wherever the slice intensity is <= 0.3.
    heat_map_rev = (1 - comb_heat_map) * (org_gray > 0.3)
    # NOTE(review): a 1x1 structuring element presumably makes this erosion a
    # no-op - confirm whether a larger kernel was intended.
    kernel = np.ones((1, 1), np.uint8)
    eroded_image = cv2.morphologyEx((heat_map_rev * 255).astype('uint8'), cv2.MORPH_ERODE, kernel)
    # Boolean gate: True where the quantised inverted heat exceeds 0.5.
    eroded_image = (eroded_image / 255) > 0.5
    # Segment the heatmap itself, then zero the labels outside the gate.
    segments_old = felzenszwalb(comb_heat_map, scale=75, sigma=0.8, min_size=100)
    segments = eroded_image * segments_old
    # Rank the remaining regions by their mean heat and keep the 7 hottest.
    region_props = regionprops(segments, intensity_image=comb_heat_map)
    intensity_sorted_regions = sorted(region_props, key=lambda prop: prop.intensity_mean, reverse=True)
    top_regions = intensity_sorted_regions[:7]
    predicted_mask_comb = np.zeros_like(gt_u8)
    for rr in top_regions:
        predicted_mask_comb = predicted_mask_comb + (segments == rr.label)
    # Dilate the union of the selected regions with a 5x5 kernel.
    kernel = np.ones((5, 5), np.uint8)
    predicted_mask_comb = cv2.morphologyEx(predicted_mask_comb.astype('uint8'), cv2.MORPH_DILATE, kernel)

    anomaly_mask = (predicted_mask_comb > 0).astype(np.uint8)

    # --- PNG: Anomaly visualization only (reduced size) ---
    fig, axes = plt.subplots(1, 3, figsize=heatmap_figsize, dpi=heatmap_dpi)
    # Atropos / segmentation
    axes[0].imshow(org_gray, cmap='gray')
    axes[0].imshow(img_seg, alpha=0.5, cmap='nipy_spectral')
    axes[0].set_title("Atropos / segmentation"); axes[0].axis('off')
    # Combined heatmap with yellow colormap
    axes[1].imshow(org_gray, cmap='gray')
    axes[1].imshow(comb_heat_map, cmap=YELLOW_MAP, alpha=0.7)
    axes[1].set_title("Combined Heatmap"); axes[1].axis('off')
    # Final anomaly segmentation overlay
    axes[2].imshow(org_gray, cmap='gray')
    axes[2].imshow(anomaly_mask, alpha=0.5, cmap='hot')
    axes[2].set_title("Anomaly segmentation"); axes[2].axis('off')

    anomaly_png = os.path.join(save_dir, f"{base}_anomaly.png")
    plt.tight_layout(pad=0)
    fig.savefig(anomaly_png, bbox_inches='tight', pad_inches=0)
    # Close the figure so it is not kept alive across slices.
    plt.close(fig)

    return gt_u8, anomaly_mask


# --- run for single pair ----------------------------------------------------
def process_single(image_path, seg_path, checkpoint_path, save_dir):
    """
    Run the sliding-window anomaly pipeline on every slice of one subject.

    Args:
        image_path: HDF5 file with an 'image' dataset of shape (1, H, W, D).
        seg_path: HDF5 file with a 'label' (preferred) or 'prediction'
            dataset indexed the same way.
        checkpoint_path: checkpoint readable by ``torch.load``; either a raw
            state dict or a dict holding it under 'model' or 'state_dict'.
            A leading 'module.' prefix on the keys is stripped.
        save_dir: directory for the PNG outputs, created if missing.

    Returns:
        ``(results, subj, save_dir)`` where ``results`` maps the slice index
        to a dict with the keys dice, tp, fp, fn, auprc, time_sec, vis and
        anomaly.

    Raises:
        ValueError: if an input file is missing or the segmentation file has
            neither a 'label' nor a 'prediction' dataset.
    """
    # Validate the inputs first so a wrong path fails before the expensive
    # model construction and checkpoint load.
    if not os.path.exists(image_path) or not os.path.exists(seg_path):
        raise ValueError("Missing files")

    Path(save_dir).mkdir(parents=True, exist_ok=True)

    # Instantiate & load model
    model = SwinMAE(norm_layer=partial(nn.LayerNorm, eps=1e-6), patch_norm=True)
    torch.serialization.add_safe_globals([argparse.Namespace])
    ck = torch.load(checkpoint_path, map_location='cpu')
    if isinstance(ck, dict) and 'model' in ck:
        st = ck['model']
    elif isinstance(ck, dict) and 'state_dict' in ck:
        st = ck['state_dict']
    else:
        st = ck
    new_st = {(k[len('module.'):] if k.startswith('module.') else k): v for k, v in st.items()}
    # strict=False tolerates mismatches silently; surface them so a wrong
    # checkpoint does not go unnoticed.
    load_report = model.load_state_dict(new_st, strict=False)
    if load_report.missing_keys or load_report.unexpected_keys:
        print(f"Warning: checkpoint/model mismatch - {len(load_report.missing_keys)} missing, "
              f"{len(load_report.unexpected_keys)} unexpected keys")
    model.to(DEVICE)
    model.eval()

    subj = Path(image_path).stem
    print(f"\n>>> Subject: {subj}")

    results = {}
    with h5py.File(image_path, 'r') as f_img, h5py.File(seg_path, 'r') as f_seg:
        img_ds = f_img['image']  # shape (1,H,W,D)
        seg_key = 'label' if 'label' in f_seg else ('prediction' if 'prediction' in f_seg else None)
        if seg_key is None:
            raise ValueError("seg h5 missing 'label' or 'prediction'")
        seg_ds = f_seg[seg_key]
        num_slices = img_ds.shape[3]

        for s in range(num_slices):
            orig = np.array(img_ds[0, :, :, s], dtype=np.float32)
            gt_mask = np.array(seg_ds[0, :, :, s], dtype=np.uint8)
            if gt_mask.sum() <= MIN_GT_PIXELS:
                continue

            # Replicate the grayscale slice into three identical channels.
            img_rgb = np.stack([orig, orig, orig], axis=-1)
            if img_rgb.max() > 1.0:
                img_rgb = img_rgb / 255.0

            t0 = time.time()
            combined_heatmap, recon_for_vis = evaluate_slice_with_sliding_window(
                img_rgb, gt_mask, model,
                window_size=WINDOW_SIZE, step=WINDOW_STEP, window_batch=WINDOW_BATCH, device=DEVICE,
                save_dir=save_dir, base_name=subj, slice_idx=s
            )
            t1 = time.time()

            gt_out, pred_mask = post_process_and_save(
                comb_heat_map=combined_heatmap,
                org_img=img_rgb,
                gt=gt_mask,
                recon_img=recon_for_vis,
                save_dir=save_dir,
                subj_name=subj,
                slice_num=s
            )

            dice, tp, fp, fn = calculate_dice_score(gt_out, pred_mask)
            auprc, precisions, recalls, _ = compute_auprc(combined_heatmap, gt_out)

            # Compact three-panel figure: original | original + GT | reconstruction.
            vispath = os.path.join(save_dir, f"{subj}_slice_{s:03d}_compact.png")
            fig, axes = plt.subplots(1, 3, figsize=(12, 4), dpi=140)
            axes[0].imshow(orig, cmap='gray')
            axes[0].set_title("Original Slice"); axes[0].axis('off')
            axes[1].imshow(orig, cmap='gray'); axes[1].imshow(gt_out, alpha=0.35, cmap='Reds')
            axes[1].set_title("Original+GT"); axes[1].axis('off')
            axes[2].imshow(recon_for_vis[..., 0], cmap='gray')
            axes[2].set_title("Reconstructed"); axes[2].axis('off')
            plt.tight_layout(pad=0)
            fig.savefig(vispath, bbox_inches='tight', pad_inches=0)
            plt.close(fig)

            results[s] = {
                "dice": float(dice),
                "tp": int(tp),
                "fp": int(fp),
                "fn": int(fn),
                # A non-finite AUPRC (e.g. the NaN returned on failure) is stored as None.
                "auprc": float(auprc) if np.isfinite(auprc) else None,
                "time_sec": float(t1 - t0),
                "vis": vispath,
                "anomaly": os.path.join(save_dir, f"{subj}_slice_{s:03d}_anomaly.png")
            }

            # cleanup
            del combined_heatmap, pred_mask, gt_out, recon_for_vis
            torch.cuda.empty_cache()

    print("Done processing.")
    return results, subj, save_dir


# --- Gradio functions ---
def process_files(image_h5, seg_h5, checkpoint):
    """Run inference on the uploaded files and build the initial UI outputs.

    Args:
        image_h5: Uploaded image .h5 file; either an object exposing the
            on-disk path as ``.name`` or a plain path string.
        seg_h5: Uploaded anomaly segmentation .h5 file (same forms as above).
        checkpoint: Uploaded model checkpoint file (same forms as above).

    Returns:
        A 5-tuple in the order the click handler expects:
        ``(dropdown_update, metrics_text, compact_image, anomaly_image, state)``.
        On any failure the dropdown update is a no-op, the text carries the
        reason, and the remaining three entries are ``None``.
    """
    import shutil  # local import: only needed to clean up the scratch directory

    uploads = (image_h5, seg_h5, checkpoint)
    if not all(uploads):
        return gr.update(), "Please upload all files", None, None, None

    # Accept both upload wrappers (path stored in `.name`) and bare path strings.
    image_path, seg_path, ckpt_path = (getattr(f, "name", f) for f in uploads)

    tmp_dir = tempfile.mkdtemp()
    try:
        results, subj, save_dir = process_single(image_path, seg_path, ckpt_path, tmp_dir)
    except Exception as e:
        # Nothing will ever reference the scratch directory again; don't leak it.
        shutil.rmtree(tmp_dir, ignore_errors=True)
        return gr.update(), f"Error: {str(e)}", None, None, None

    if not results:
        # No visualization is shown in this case, so the scratch directory is unused.
        shutil.rmtree(tmp_dir, ignore_errors=True)
        return gr.update(), "No slices processed (all skipped due to small GT)", None, None, None

    # `results` is keyed by slice index; show the lowest index first.
    slices = sorted(results.keys())
    slice_choices = [f"Slice {s}" for s in slices]
    initial_slice = slice_choices[0]
    initial_s = slices[0]

    metrics = results[initial_s]
    metrics_str = f"DICE: {metrics['dice']:.4f}\nAUPRC: {metrics['auprc'] if metrics['auprc'] is not None else 'N/A'}\nTP: {metrics['tp']}\nFP: {metrics['fp']}\nFN: {metrics['fn']}\nTime: {metrics['time_sec']:.2f} sec"

    # Kept in the session state so later slice selections can be answered
    # without re-running inference.
    state = {"results": results, "subj": subj, "save_dir": save_dir, "slices": slices}

    return gr.update(choices=slice_choices, value=initial_slice), metrics_str, metrics["vis"], metrics["anomaly"], state


def update_slice(selected, state):
    """Return the metrics text and image paths for the chosen slice.

    Args:
        selected: Dropdown label of the form ``"Slice <index>"``.
        state: Dict holding a ``"results"`` mapping from slice index to that
            slice's metrics entry.

    Returns:
        ``(metrics_text, vis_path, anomaly_path)``; when either argument is
        empty the text is ``"No slice selected"`` and both paths are ``None``.
    """
    if not (state and selected):
        return "No slice selected", None, None

    # The second space-separated token of the label is the slice index.
    slice_idx = int(selected.split(" ")[1])
    entry = state["results"][slice_idx]

    report_lines = [
        f"DICE: {entry['dice']:.4f}",
        f"AUPRC: {'N/A' if entry['auprc'] is None else entry['auprc']}",
        f"TP: {entry['tp']}",
        f"FP: {entry['fp']}",
        f"FN: {entry['fn']}",
        f"Time: {entry['time_sec']:.2f} sec",
    ]

    return "\n".join(report_lines), entry["vis"], entry["anomaly"]


# --- Gradio interface ---
# Build the UI: three file uploads, a process button, a slice selector and
# the per-slice metrics/visualization outputs.
with gr.Blocks(title="SwinMAE Inference") as demo:
    gr.Markdown("# SwinMAE Inferencing Interface")
    gr.Markdown("Upload the image .h5, anomaly seg .h5, and model checkpoint .pth. Process to generate visualizations per slice.")

    # All three uploads are required; process_files reports a message otherwise.
    with gr.Row():
        image_h5 = gr.File(label="Image .h5 File")
        seg_h5 = gr.File(label="Anomaly Seg .h5 File")
        checkpoint = gr.File(label="Checkpoint .pth File")

    process_btn = gr.Button("Process Files")

    # Starts with no choices; process_files fills it with "Slice <index>" labels.
    slice_dropdown = gr.Dropdown(label="Select Slice", choices=[], interactive=True)
    metrics_text = gr.Textbox(label="Slice Metrics", lines=6)
    compact_image = gr.Image(label="Compact Visualization (Original | Orig+GT | Recon)")
    anomaly_image = gr.Image(label="Anomaly Visualization (Atropos/Seg | Heatmap| Anomaly Seg)")

    # Holds the dict returned as the last output of process_files so that
    # update_slice can look up per-slice results.
    state = gr.State()

    # The outputs list must stay in the same order as the 5-tuple returned
    # by process_files.
    process_btn.click(
        process_files,
        inputs=[image_h5, seg_h5, checkpoint],
        outputs=[slice_dropdown, metrics_text, compact_image, anomaly_image, state]
    )
    # The outputs list must stay in the same order as the 3-tuple returned
    # by update_slice.
    slice_dropdown.change(
        update_slice,
        inputs=[slice_dropdown, state],
        outputs=[metrics_text, compact_image, anomaly_image]
    )

# Start the app with Gradio's default launch settings.
demo.launch()

Device: cpu
It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://de09c94b193f67f010.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


