In [1]:
from time import time
import torch as tc 
import torch.nn as nn  
import numpy as np
from tqdm import tqdm
import os,sys,cv2
from torch.cuda.amp import autocast
import matplotlib.pyplot as plt
import albumentations as A
import segmentation_models_pytorch as smp
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
from glob import glob
import tifffile as tiff
from dotenv import load_dotenv
import os
from typing import Optional

import sys
sys.path.append('..')
import util

# import tensorrt as trt
# import pycuda.driver as cuda
# import pycuda.autoinit  # This is needed for initializing CUDA driver


class CFG:
    # Hyper-parameters shared by the router, the experts and the notebook cells.

    # --- architecture ---
    model_name = 'Unet'
    router_backbone = 'resnext50_32x4d'
    # SegMoE builds one expert per entry of this list.
    backbones = ["efficientnet-b0" for _ in range(25)]

    # --- spatial sizes (side lengths) ---
    input_size = 1024      # full image side fed to the router
    exp_input_size = 128   # square patch side fed to each expert

    # --- batch sizes ---
    batch = 2
    exp_batch = 128        # Expert.forward chunk size in training mode
    exp_batch_eval = 256   # Expert.forward chunk size in eval mode

class DiceScore(nn.Module):
    """Soft Dice coefficient between predicted probabilities and labels.

    dice = (2 * sum(p * y) + smooth) / (sum(p) + sum(y) + smooth)
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        # Registered as a buffer so it moves with the module across devices.
        self.register_buffer('smooth', tc.tensor(smooth))

    def forward(self, p_y: tc.Tensor, y: tc.Tensor, w: Optional[tc.Tensor] = None, mode: str = 'together') -> tc.Tensor:
        """Compute the (optionally weighted) soft Dice score.

        Args:
            p_y: (B, ...) tensor of probabilities.
            y: (B, ...) tensor of binary labels with the same number of
                elements per sample as ``p_y``.
            w: optional (B,) tensor of per-sample weights. It multiplies both
                the intersection and the cardinality of each sample.
            mode: 'together' -> a single scalar Dice over the (weighted) batch;
                'separate' -> one Dice value per sample, shape (B,).

        Returns:
            Scalar tensor ('together') or (B,) tensor ('separate').

        Raises:
            AssertionError: if ``mode`` is not 'together' or 'separate'.
        """
        assert mode in ['together', 'separate'], "Mode must be 'together' or 'separate'"
        # reshape (not view): accepts non-contiguous inputs such as permuted tensors.
        flat_prob = p_y.reshape(p_y.shape[0], -1)
        flat_y = y.reshape(y.shape[0], -1)

        intersection = (flat_prob * flat_y).sum(1)
        cardinality = flat_prob.sum(1) + flat_y.sum(1)
        if w is not None:
            # Out-of-place: avoids modifying tensors that are part of the autograd graph.
            intersection = intersection * w
            cardinality = cardinality * w
        if mode == 'together':
            return (2. * intersection.sum() + self.smooth) / (cardinality.sum() + self.smooth)
        return (2. * intersection + self.smooth) / (cardinality + self.smooth)

def count_parameters(model):
    """Return the number of trainable (``requires_grad``) scalar parameters in ``model``."""
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

In [2]:
class Router(nn.Module):
    """Predicts, for each cell of a grid over the image, logits over the experts.

    The grid is ``CFG.input_size // CFG.exp_input_size`` cells per side. There is
    one output channel per entry in ``CFG.backbones``.
    """

    def __init__(self, in_f: int, out_f: int, CFG: CFG, weight=None):
        super().__init__()

        # Only the encoder of an smp U-Net is kept; the decoder/head (and hence out_f) is discarded.
        self.encoder = smp.Unet(
            encoder_name=CFG.router_backbone,
            encoder_weights=weight,
            in_channels=in_f,
            classes=out_f,
            activation=None,
        ).encoder
        # Width of the last encoder feature map, read from the encoder instead of
        # hard-coding 2048 so that CFG.router_backbone can be any smp encoder.
        deepest_channels = self.encoder.out_channels[-1]
        self.classifier = nn.Sequential(
            nn.AdaptiveMaxPool2d(CFG.input_size // CFG.exp_input_size),
            nn.Conv2d(deepest_channels, len(CFG.backbones), kernel_size=1, stride=1, padding=0, bias=False),
        )

    def forward(self, x: tc.Tensor) -> tc.Tensor:
        """Return router logits of shape (B, n_experts, grid, grid).

        Assumes x is (B, in_f, H, W). TODO: confirm H and W are divisible by the
        encoder's total stride.
        """
        # The encoder returns a list of feature maps; the last one is used.
        z = self.encoder(x)[-1]
        z = self.classifier(z)
        return z

class Expert(nn.Module):
    """Wraps an smp U-Net and runs its input through it in fixed-size chunks.

    The chunk size is ``CFG.exp_batch`` in training mode and
    ``CFG.exp_batch_eval`` in eval mode.
    """

    def __init__(self, backbone: str,  in_f: int, out_f: int, CFG: "CFG", weight=None):
        super().__init__()
        # (training chunk size, eval chunk size) used by forward().
        self.batch = CFG.exp_batch, CFG.exp_batch_eval
        self.net = smp.Unet(
            encoder_name=backbone,
            encoder_weights=weight,
            in_channels=in_f,
            classes=out_f,
            activation=None,
        )

    def forward(self, x: tc.Tensor) -> tc.Tensor:
        """Apply ``self.net`` chunk-by-chunk along dim 0 and concatenate the results.

        An empty input is passed straight to ``self.net``, so the output's shape
        and dtype come from the network. Previously this case raised
        UnboundLocalError.
        """
        batch = self.batch[0] if self.training else self.batch[1]
        if len(x) == 0:
            return self.net(x)
        # A single cat over all chunks avoids re-copying the growing output each iteration.
        return tc.cat([self.net(chunk) for chunk in x.split(batch)], 0)

class SegMoE(nn.Module):
    """Patch-wise mixture-of-experts segmentation model.

    The router scores every ``exp_input_size`` x ``exp_input_size`` patch of the
    image against all experts (one smp U-Net per entry of ``CFG.backbones``).

    * Eval mode: each patch is predicted only by its ``topk`` highest-scoring
      experts. Their sigmoid outputs are blended with the renormalised router
      probabilities.
    * Training mode: every expert processes every patch. Each expert's Dice loss
      is weighted by the (load-balanced) routing distribution. The router is
      trained with cross-entropy towards the expert with the best per-patch Dice.
      ``backward()`` is called inside ``forward`` through ``loss_scaler``; the
      caller presumably performs ``scaler.step`` / ``scaler.update``.

    Assumes single-channel images of shape (B, 1, input_size, input_size).
    TODO: confirm for other callers.
    """

    def __init__(self, topk: int, CFG: CFG, loss_scaler: tc.cuda.amp.GradScaler, weight=None):
        super().__init__()
        assert 0 < topk <= len(CFG.backbones), f"topk should be in (0, {len(CFG.backbones)}]"
        self.topk = topk
        self.dice = DiceScore()
        self.scaler = loss_scaler
        self.router_loss = nn.CrossEntropyLoss()

        # Patch geometry comes from the config passed in (not the module-level CFG),
        # so a different config object is honoured by every helper.
        self.input_size = CFG.input_size
        self.exp_input_size = CFG.exp_input_size
        self.grid_size = CFG.input_size // CFG.exp_input_size

        self.router = Router(1, 1, CFG, weight)
        self.experts = nn.ModuleList([smp.Unet(
            encoder_name=backbone,
            encoder_weights=weight,
            in_channels=1,
            classes=1,
            activation=None,
        ) for backbone in CFG.backbones])
        self.register_buffer('n_experts', tc.tensor(len(self.experts), dtype=tc.float32))

        # Running, exponentially decayed per-expert load used by _smooth_weight.
        # A non-persistent buffer follows .cuda()/.to() without adding a state_dict key.
        self.register_buffer('expert_load_dist', tc.ones(len(self.experts), dtype=tc.float32), persistent=False)

    def _to_patches(self, x: tc.Tensor) -> tc.Tensor:
        """Split (B, 1, H, W) into (B * grid * grid, 1, s, s) patches in row-major grid order."""
        s = self.exp_input_size
        return x.unfold(2, s, s).unfold(3, s, s).reshape(-1, 1, s, s)

    def _smooth_weight(self, expert_dist: tc.Tensor) -> tc.Tensor:
        """Re-weight routing probabilities towards under-used experts and update the load stats.

        Args:
            expert_dist: (N, n_experts) routing probabilities (detached in forward).

        Returns:
            (N, n_experts) re-weighted distribution whose rows sum to 1.
        """
        # Inverse relative load: experts with little recent load get a larger factor.
        dist_balancer = self.expert_load_dist.sum() / self.expert_load_dist
        dist_balancer = self.n_experts * dist_balancer.softmax(dim=-1)
        balanced_dist = expert_dist * dist_balancer
        balanced_dist /= balanced_dist.sum(1, keepdim=True)  # must sum to 1 on expert-dist axis
        # Accumulate this batch's load, then decay the running statistics.
        self.expert_load_dist += balanced_dist.sum(0)
        self.expert_load_dist *= 0.9

        return balanced_dist

    def _assemble_patches(self, patches: tc.Tensor) -> tc.Tensor:
        """Inverse of _to_patches: stitch patches back into (B, 1, input_size, input_size)."""
        s = self.exp_input_size
        return patches \
            .view(-1, self.grid_size, self.grid_size, 1, s, s) \
            .permute(0, 3, 1, 4, 2, 5) \
            .reshape(-1, 1, self.input_size, self.input_size)

    def _predict(self, x: tc.Tensor, dist: tc.Tensor) -> tc.Tensor:
        """Sparse top-k inference.

        Each patch is evaluated only by its ``topk`` experts. Their sigmoid
        outputs are weighted by the renormalised top-k router probabilities.
        """
        batch_size = dist.size(0)
        top_k_values, top_k_indices = tc.topk(dist, self.topk, dim=1)
        # Build the row index on the same device as dist so advanced indexing never mixes devices.
        batch_indices = tc.arange(batch_size, device=dist.device).unsqueeze(1).expand(-1, self.topk)

        k_hot_weights = tc.zeros_like(dist)
        k_hot_outputs = tc.zeros_like(dist)
        k_hot_outputs[batch_indices, top_k_indices] = 1
        k_hot_weights[batch_indices, top_k_indices] = top_k_values / top_k_values.sum(dim=-1, keepdim=True)

        agg_pred = tc.zeros_like(x, dtype=tc.float32).squeeze(1)
        for i, expert in enumerate(self.experts):
            idx = k_hot_outputs[:, i].bool()
            if idx.any():
                weights = k_hot_weights[idx, i][:, None, None]
                agg_pred[idx] += expert(x[idx])[:, 0].sigmoid() * weights

        return self._assemble_patches(agg_pred)

    def forward(self, image: tc.Tensor, label: Optional[tc.Tensor] = None, flat_dist: bool = False) -> (tc.Tensor, tc.Tensor):
        """Run the mixture of experts.

        Args:
            image: assumed (B, 1, input_size, input_size).
            label: same shape as ``image``; required in training mode.
            flat_dist: training only. If True, use a uniform routing distribution
                instead of the load-balanced one.

        Returns:
            (prediction of shape (B, 1, input_size, input_size), per-patch routing
            distribution of shape (B * grid * grid, n_experts)). In training mode
            the prediction is detached, because the losses have already been
            back-propagated here.

        Raises:
            AssertionError: if ``label`` is None in training mode.
        """
        n_experts = len(self.experts)
        with autocast():
            router_logits = self.router(image).permute(0, 2, 3, 1).reshape(-1, n_experts)
            expert_dist = router_logits.detach().softmax(1)
            expert_inputs = self._to_patches(image)

            # Sparse predict if eval, else prep dist & labels for training
            if not self.training:
                return self._predict(expert_inputs, expert_dist), expert_dist
            assert label is not None, "Expected label tensor in training mode."
            expert_labels = self._to_patches(label)
            if flat_dist:
                expert_dist = tc.ones_like(expert_dist) / self.n_experts
            else:
                expert_dist = self._smooth_weight(expert_dist)

        experts_sample_dice = []
        agg_pred = tc.zeros_like(expert_inputs, dtype=tc.float32)
        for expert, w in zip(self.experts, expert_dist.unbind(1)):  # every expert sees all examples
            with autocast():
                z = expert(expert_inputs).sigmoid()
                expert_loss = 1 - self.dice(z, expert_labels, w)
                sample_dice = self.dice(z.detach(), expert_labels, mode='separate')

            self.scaler.scale(expert_loss).backward()
            experts_sample_dice.append(sample_dice)
            # z's graph has already been consumed by backward(); keep only the values.
            agg_pred += z.detach() * w[:, None, None, None]

        # Router target: the expert with the best Dice on each patch.
        experts_sample_dice = tc.stack(experts_sample_dice, dim=1)
        router_labels = tc.argmax(experts_sample_dice, dim=1)
        # router_logits were produced under autocast; compute the loss in fp32 for stability.
        router_loss = self.router_loss(router_logits.float(), router_labels)
        self.scaler.scale(router_loss).backward()

        return self._assemble_patches(agg_pred), expert_dist




In [3]:
scaler = tc.cuda.amp.GradScaler()
# topk=1: at inference each patch is predicted by its single highest-scoring expert.
model = SegMoE(1, CFG, scaler, "imagenet").cuda()
model.train();  # trailing ';' suppresses the (very long) module repr

In [4]:
# Switch to inference mode (train(False) is what eval() does) and inspect the expert-load statistics.
model.train(False)
model.expert_load_dist

tensor([1., 1., 1., 1.], device='cuda:0')

In [7]:
tc.manual_seed(0)  # reproducible smoke-test inputs
batch_size = 2
image_shape = (batch_size, 1, CFG.input_size, CFG.input_size)
x = tc.rand(image_shape, device='cuda')
# DiceScore expects binary labels.
y = (tc.rand(image_shape, device='cuda') > 0.5).to(tc.float32)

# Inference smoke test: set eval mode explicitly (no reliance on earlier cells).
# No autograd graph is needed, so skip building one through every expert.
model.eval()
with tc.no_grad():
    pred = model(x, y)  # (prediction, expert_dist) tuple

