In [1]:
import torch as tc 
import torch.nn as nn  
import numpy as np
from tqdm import tqdm
import os,sys,cv2
from torch.cuda.amp import autocast
import matplotlib.pyplot as plt
import albumentations as A
import segmentation_models_pytorch as smp
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
from torch.nn.parallel import DataParallel
from glob import glob
import tifffile as tiff
from dotenv import load_dotenv
import os

import sys
sys.path.append('..')
import util

log_board = util.diagnostics.LogBoard('log_dir', 6005)
log_board.launch()

In [2]:
p_augm = 0.05 #0.5
#add rotate.  less p_augm

class CFG:
    # ============== pred target =============
    target_size = 1

    # ============== model CFG =============
    model_name = 'Unet'
    router_backbone = 'resnext50_32x4d'
    backbones = [
        'resnext50_32x4d',
        'resnext50_32x4d',
        'resnext50_32x4d',
        'resnext50_32x4d',
        # 'resnext101_32x8d',
        # 'se_resnext50_32x4d',
        # 'mit_b2'
    ]

    in_chans = 1   #5 # 65
    # ============== training CFG =============
    image_size = 1024 # 512 # 512
    input_size = 1024 # 512 # 512

    train_batch_size = 2 #4 #16
    valid_batch_size = 2

    epochs = 31 #30 #25
    lr = 8e-4
    chopping_percentile=1e-5
    # ============== fold =============
    valid_id = 1


    # ============== augmentation =============
    train_aug_list = [
        A.Rotate(limit=270, p= 0.5),
        A.RandomScale(scale_limit=(0.8,1.25),interpolation=cv2.INTER_CUBIC,p=p_augm),
        A.RandomCrop(input_size, input_size,p=1),
        A.RandomGamma(p=p_augm*2/3),
        A.RandomBrightnessContrast(p=p_augm,),
        A.GaussianBlur(p=p_augm),
        A.MotionBlur(p=p_augm),
        A.GridDistortion(num_steps=5, distort_limit=0.3, p=p_augm),
        ToTensorV2(transpose_mask=True),
    ]
    train_aug = A.Compose(train_aug_list)
    valid_aug_list = [
        ToTensorV2(transpose_mask=True),
    ]
    valid_aug = A.Compose(valid_aug_list)

In [3]:
def to_1024(img , image_size = 1024):
    if image_size > img.shape[1]:
       img = np.rot90(img)
       start1 = (CFG.image_size - img.shape[0])//2 
       top =     img[0                    : start1,   0: img.shape[1] ]
       bottom  = img[img.shape[0] -start1 : img.shape[0],   0 : img.shape[1] ]
       img_result = np.concatenate((top,img,bottom ),axis=0)
       img_result = np.rot90(img_result)
       img_result = np.rot90(img_result)
       img_result = np.rot90(img_result)
    else :
       img_result = img
    return img_result

def to_1024_no_rot(img, image_size = 1024):
    if image_size > img.shape[0]:  
       start1 = ( image_size - img.shape[0])//2
       top =     img[0                    : start1,   0: img.shape[1] ]
       bottom  = img[img.shape[0] -start1 : img.shape[0],   0 : img.shape[1] ]
       img_result = np.concatenate((top,img,bottom ),axis=0)
    else: 
       img_result = img
    return img_result

#  add border
def to_1024_1024(img  , image_size = 1024 ):
     img_result = to_1024(img, image_size )
     return img_result
    
#  drop border
def to_original ( im_after, img, image_size = 1024 ):
    top_ = 0
    left_ = 0
    if (im_after.shape[0] > img.shape[0]):
             top_  = ( image_size - img.shape[0])//2 
    if    (im_after.shape[1] > img.shape[1]) :
             left_  = ( image_size - img.shape[1])//2  
    if (top_>0)or (left_>0) :
             img_result = im_after[top_                    : img.shape[0] + top_,   left_: img.shape[1] + left_ ]
             #print(im_after.shape,'-->',img_result.shape)
    else:
             img_result = im_after
    return img_result  

In [4]:
def min_max_normalization(x:tc.Tensor)->tc.Tensor:
    """input.shape=(batch,f1,...)"""
    shape=x.shape
    if x.ndim>2:
        x=x.reshape(x.shape[0],-1)
    
    min_=x.min(dim=-1,keepdim=True)[0]
    max_=x.max(dim=-1,keepdim=True)[0]
    if min_.mean()==0 and max_.mean()==1:
        return x.reshape(shape)
    
    x=(x-min_)/(max_-min_+1e-9)
    return x.reshape(shape)

def norm_with_clip(x:tc.Tensor,smooth=1e-5):
    dim=list(range(1,x.ndim))
    mean=x.mean(dim=dim,keepdim=True)
    std=x.std(dim=dim,keepdim=True)
    x=(x-mean)/(std+smooth)
    x[x>5]=(x[x>5]-5)*1e-3 +5
    x[x<-3]=(x[x<-3]+3)*1e-3-3
    return x

def add_noise(x:tc.Tensor,max_randn_rate=0.1,randn_rate=None,x_already_normed=False):
    """input.shape=(batch,f1,f2,...) output's var will be normalizate  """
    ndim=x.ndim-1
    if x_already_normed:
        x_std=tc.ones([x.shape[0]]+[1]*ndim,device=x.device,dtype=x.dtype)
        x_mean=tc.zeros([x.shape[0]]+[1]*ndim,device=x.device,dtype=x.dtype)
    else: 
        dim=list(range(1,x.ndim))
        x_std=x.std(dim=dim,keepdim=True)
        x_mean=x.mean(dim=dim,keepdim=True)
    if randn_rate is None:
        randn_rate=max_randn_rate*np.random.rand()*tc.rand(x_mean.shape,device=x.device,dtype=x.dtype)
    cache=(x_std**2+(x_std*randn_rate)**2)**0.5
    #https://blog.csdn.net/chaosir1991/article/details/106960408
    
    return (x-x_mean+tc.randn(size=x.shape,device=x.device,dtype=x.dtype)*randn_rate*x_std)/(cache+1e-7)
 
class Data_loader(Dataset):
     
    def __init__(self,paths,is_label):
        self.paths=paths
        self.paths.sort()
        self.is_label=is_label
    
    def __len__(self):
        return len(self.paths)
    
    def __getitem__(self,index):
         
        img = cv2.imread(self.paths[index],cv2.IMREAD_GRAYSCALE)
        
        img = to_1024_1024(img , image_size = CFG.image_size ) #  to_original( im_after, img_save, image_size = 1024)

        img = tc.from_numpy(img.copy())
        if self.is_label:
            img=(img!=0).to(tc.uint8)*255
        else:
            img=img.to(tc.uint8)
        return img

def load_data(paths,is_label=False):
    data_loader=Data_loader(paths,is_label)
    data_loader=DataLoader(data_loader, batch_size=8, num_workers=2)  
    data=[]
    for x in tqdm(data_loader):
        data.append(x)
    x=tc.cat(data,dim=0)
    del data
    if not is_label:
        ########################################################################
        TH=x.reshape(-1).numpy()
        index = -int(len(TH) * CFG.chopping_percentile)
        TH:int = np.partition(TH, index)[index]
        x[x>TH]=int(TH)
        ########################################################################
        TH=x.reshape(-1).numpy()
        index = -int(len(TH) * CFG.chopping_percentile)
        TH:int = np.partition(TH, -index)[-index]
        x[x<TH]=int(TH)
        ########################################################################
        x=(min_max_normalization(x.to(tc.float16)[None])[0]*255).to(tc.uint8)
    return x


#https://www.kaggle.com/code/kashiwaba/sennet-hoa-train-unet-simple-baseline
def dice_coef(y_pred:tc.Tensor,y_true:tc.Tensor, thr=0.5, dim=(-1,-2), epsilon=0.001):
    #y_pred=y_pred.sigmoid()
    y_true = y_true.to(tc.float32)
    y_pred = (y_pred>thr).to(tc.float32)
    inter = (y_true*y_pred).sum(dim=dim)
    den = y_true.sum(dim=dim) + y_pred.sum(dim=dim)
    dice = ((2*inter+epsilon)/(den+epsilon)).mean()
    return dice

class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        
        #comment out if your model contains a sigmoid or equivalent activation layer
        # inputs = inputs.sigmoid()   
        
        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
        intersection = (inputs * targets).sum()                            
        dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)  
        
        return 1 - dice

class SurfaceDiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(SurfaceDiceLoss, self).__init__()

    def forward(self, pred, surface_true, volume_true, smooth=1, k_size=3):
        pred = pred.sigmoid()
        surface_pred = pred * surface_true
        volume_pred = pred * volume_true
      
        surface_pred = surface_pred.view(-1)
        surface_true = surface_true.view(-1)
        volume_pred = volume_pred.view(-1)
        volume_true = volume_true.view(-1)
        
        surface_intersection = (surface_pred * surface_true).sum()
        volume_intersection = (volume_pred * volume_true).sum()

        surface_dice = (2. * (surface_intersection + (0.35*volume_intersection)) + smooth) / (surface_pred.sum() + surface_true.sum() + (0.35*volume_pred.sum()) + (0.35*volume_true.sum()) + smooth)

        return 1 - surface_dice


class Kaggld_Dataset(Dataset):
    def __init__(self,x:list,y:list,arg=False):
        super(Dataset,self).__init__()
        self.x=x#list[(C,H,W),...]
        self.y=y#list[(C,H,W),...]
        self.image_size=CFG.image_size
        self.in_chans=CFG.in_chans
        self.arg=arg
        if arg:
            self.transform=CFG.train_aug
        else: 
            self.transform=CFG.valid_aug

    def __len__(self) -> int:
        return sum([y.shape[0]-self.in_chans for y in self.y])
    
    def __getitem__(self,index):
        i=0
        for x in self.x:
            if index>x.shape[0]-self.in_chans:
                index-=x.shape[0]-self.in_chans
                i+=1
            else:
                break
        x=self.x[i]
        y=self.y[i]
        
        print (f'x.shape[1] ={x.shape[1]}    x.shape[2]={x.shape[2]}')
        
        x_index= (x.shape[1]-self.image_size)//2 #np.random.randint(0,x.shape[1]-self.image_size)
        y_index= (x.shape[2]-self.image_size)//2 # np.random.randint(0,x.shape[2]-self.image_size)
        # i i+5 
        x=x[index:index+self.in_chans   ,   x_index:x_index+self.image_size,   y_index:y_index+self.image_size]
        # i+2
        y=y[index+self.in_chans//2   ,      x_index:x_index+self.image_size,   y_index:y_index+self.image_size]

        data = self.transform(image=x.numpy().transpose(1,2,0), mask=y.numpy())
        x = data['image']
        y = data['mask']>=127
        if self.arg:
            i=np.random.randint(4)
            x=x.rot90(i,dims=(1,2))
            y=y.rot90(i,dims=(0,1))
            for i in range(3):
                if np.random.randint(2):
                    x=x.flip(dims=(i,))
                    if i>=1:
                        y=y.flip(dims=(i-1,))
        return x,y#(uint8,uint8)

In [5]:

# Load training/valid data

train_x=[]
train_y=[]

root_path="/root/data/train"
train_datasets = ["kidney_1_dense"]
val_datasets = ["kidney_3_dense"]
def load_tiff_tensor(data_dir: str, dataset: str, ds_type: str, force_reload: bool = False) -> tc.Tensor:
    assert ds_type in ["images", "labels"], f"Invalid type {ds_type}"
    cache_fn = f"./bin/{dataset}/{ds_type}.pt"
    if not force_reload and os.path.exists(cache_fn):
        print(f'Loading {dataset} from cache')
        return tc.load(cache_fn)
    else:
        print(f'Loading {dataset} from tiff and saving to cache...')
        data = load_data(glob(f"{data_dir}/{dataset}/{ds_type}/*"),is_label=False)
        os.makedirs(os.path.dirname(cache_fn), exist_ok=True)
        tc.save(data, cache_fn)
        return data

force_reload = False
k1_img, k1_lbl = (
    load_tiff_tensor(root_path, "kidney_1_dense", "images", force_reload),
    load_tiff_tensor(root_path, "kidney_1_dense", "labels", force_reload),

    # util.data.kidney_1_fixed().squeeze(0).to(tc.uint8) * 255,
)
train_x = [k1_img, k1_img.permute(1,2,0), k1_img.permute(2,0,1)]
train_y = [k1_lbl, k1_lbl.permute(1,2,0), k1_lbl.permute(2,0,1)]

val_x, val_y = (
    load_tiff_tensor(root_path, "kidney_3_dense", "images", force_reload),
    load_tiff_tensor(root_path, "kidney_3_dense", "labels", force_reload),
)

train_dataset=Kaggld_Dataset(train_x,train_y,arg=True)
train_dataset = DataLoader(train_dataset, batch_size=CFG.train_batch_size ,num_workers=2, shuffle=True, pin_memory=True)
val_dataset=Kaggld_Dataset([val_x],[val_y])
val_dataset = DataLoader(val_dataset, batch_size=CFG.valid_batch_size, num_workers=2, shuffle=True, pin_memory=True)

Loading kidney_1_dense from cache


TensorFlow installation not found - running with reduced feature set.

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784



Loading kidney_1_dense from cache


Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.15.1 at http://localhost:6005/ (Press CTRL+C to quit)


Loading kidney_3_dense from cache
Loading kidney_3_dense from cache


In [6]:

# Mixture of Experts implementation inspired by [https://arxiv.org/pdf/2208.02813.pdf]

class Router(nn.Module):
    def __init__(self, CFG, weight=None):
        super().__init__()

        self.encoder = smp.Unet(
            encoder_name=CFG.router_backbone, 
            encoder_weights=weight,
            in_channels=CFG.in_chans,
            classes=CFG.target_size,
            activation=None,
        ).encoder
        self.classifier = nn.Sequential(
            nn.AdaptiveMaxPool2d(1),
            nn.Flatten(),
            nn.Linear(2048, len(CFG.backbones)),
        )

    def forward(self, x):
        z = self.encoder(x)[-1]
        z = self.classifier(z)
        return z

class CustomModel_MoE(nn.Module):
    def __init__(self, CFG, topk: int = 2, weight=None):
        super().__init__()
        assert 0 < topk <= len(CFG.backbones), f"topk should be in (0, {len(CFG.backbones)}]"
        self.topk = topk

        self.router = Router(CFG, weight)
        self.experts = nn.ModuleList([smp.Unet(
            encoder_name=backbone, 
            encoder_weights=weight,
            in_channels=CFG.in_chans,
            classes=CFG.target_size,
            activation=None,
        ) for backbone in CFG.backbones])
        self.register_buffer('n_experts', tc.tensor(len(CFG.backbones), dtype=tc.float32))

    def forward(self, image):
        router_logits = self.router(image)
        preds_list = [expert(image)[:, 0].sigmoid() for expert in self.experts]
        return preds_list, router_logits

    def predict(self, image):
        # Distribute inputs to top-k experts with router
        dist = nn.functional.softmax(self.router(image), dim=-1)

        batch_size, _num_classes = dist.size()
        top_k_values, top_k_indices = tc.topk(dist, self.topk, dim=1)
        batch_indices = tc.arange(batch_size).unsqueeze(1).expand(-1, self.topk)

        k_hot_weights = tc.zeros_like(dist)
        k_hot_outputs = tc.zeros_like(dist)
        k_hot_outputs[batch_indices, top_k_indices] = 1
        k_hot_weights[batch_indices, top_k_indices] = top_k_values / top_k_values.sum(dim=-1, keepdim=True)

        agg_pred = tc.zeros_like(image, dtype=tc.float32).squeeze(1)
        for i in range(len(self.experts)):
            idx = k_hot_outputs[:, i].bool()
            if idx.sum() > 0:
                agg_pred[idx] += self.experts[i](image[idx])[:, 0].sigmoid() * k_hot_weights[idx, i].unsqueeze(-1).unsqueeze(-1)
        
        return agg_pred

def build_model(weight="imagenet"):
    load_dotenv()

    print('model_name', CFG.model_name)
    print('router_backbone', CFG.router_backbone)
    print('expert_backbones', CFG.backbones)

    model = CustomModel_MoE(CFG, weight=weight)

    return model.cuda()

In [7]:

# Clear tensorboard

log_board.clear('train [MoE 512]')
log_board.clear('val [MoE 512]')
# log_board.clear('train (edge-better2)')
# log_board.clear('val (edge-better2)')

In [10]:
save_every = 500
extra_log_every = 25
explore_experts_until = 1000

t_logger = log_board.get_logger('train [MoE diff learning]')
v_logger = log_board.get_logger('val [MoE diff learning]')

model=build_model()
# model=DataParallel(model)
model.train()

edge = util.Edge().cuda()
dice_score = util.DiceScore().cuda()
loss_fc=DiceLoss()
router_loss_fn = nn.CrossEntropyLoss()
edge_weighted_dice_loss = util.EdgeWeightedDiceLoss(alpha=0.85).cuda()
focal_loss = util.BinaryFocalLoss(
    alpha=0.8,
    gamma=1.5,
)

optimizer_router = tc.optim.AdamW(model.router.parameters(),lr=CFG.lr*0.1)
optimizer=tc.optim.AdamW(model.experts.parameters(),lr=CFG.lr)
scaler=tc.cuda.amp.GradScaler()
scheduler = tc.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=CFG.lr,
                                                steps_per_epoch=len(train_dataset), epochs=CFG.epochs+1,
                                                pct_start=0.1,)

step = 0
agg_expert_load = tc.ones(len(model.experts), dtype=tc.float32).cuda()
valid_iterator = iter(val_dataset)
for epoch in range(CFG.epochs):
    for i,(x,y) in enumerate(train_dataset):
        step+=1
        # prepare data
        x=x.cuda().to(tc.float32)
        y=y.cuda().to(tc.float32)
        x=norm_with_clip(x.reshape(-1,*x.shape[2:])).reshape(x.shape)
        x=add_noise(x,max_randn_rate=0.5,x_already_normed=True)

        with autocast():
            # compute prediction
            pred_list, expert_dist_logits = model(x)
            expert_dist = expert_dist_logits.detach().softmax(dim=-1)
            if step < explore_experts_until:
                expert_dist = tc.ones_like(expert_dist) / model.n_experts

            # rebalance expert distribution for load/train balancing
            dist_balancer = (agg_expert_load.sum() / (agg_expert_load))
            dist_balancer = model.n_experts * dist_balancer.softmax(dim=-1)
            dist_balancer, dist_balancer.sum()
            balanced_dist = (expert_dist * dist_balancer)
            balanced_dist /= balanced_dist.sum(1, keepdim=True)
            agg_expert_load += balanced_dist.sum(0)
            agg_expert_load *= 0.95

            # combine expert predictions
            final_pred = tc.stack([
                p * w.unsqueeze(-1).unsqueeze(-1)
                for p, w in zip(pred_list, balanced_dist.unbind(1))
            ]).sum(0)
            
            # create labels for router
            expert_sample_dice = tc.stack([dice_score(p.detach(), y, 'separate') for p in pred_list], dim=1)
            router_labels = tc.argmax(expert_sample_dice, dim=1)

        # compute loss
        experts_loss = edge_weighted_dice_loss(final_pred.unsqueeze(1), y.unsqueeze(1)) * 4
        router_loss = router_loss_fn(expert_dist_logits, router_labels)
        loss = experts_loss + router_loss

        # gradient step
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.step(optimizer_router)
        scaler.update()
        optimizer.zero_grad()
        optimizer_router.zero_grad()
        scheduler.step()

        # performance logging
        with tc.no_grad():
            t_logger.add_scalars('loss', {
                'experts': experts_loss / model.n_experts,
                'router': router_loss,
            }, step)
            t_logger.add_scalar('lr', optimizer.param_groups[0]['lr'], step)
            t_logger.add_scalar('dice', dice_coef(final_pred.detach()>0.5,y), step)
            t_logger.add_scalar('surface-dice', dice_coef(
                edge((final_pred.detach() > 0.5).float().unsqueeze(1)).squeeze(1),
                edge(y.unsqueeze(1)).squeeze(1)
            ), step)
            t_logger.add_scalars('expert-load', {
                f'expert-{i}': w for i, w in enumerate(expert_dist.mean(0))
            }, step)
            del loss,pred_list, expert_dist_logits

            if (step + 1) % extra_log_every == 0:
                model.eval()
                try:
                    x, y= next(valid_iterator)
                except StopIteration:
                    valid_iterator = iter(val_dataset)
                    x, y = next(valid_iterator)
                
                x=x.cuda().to(tc.float32)
                y=y.cuda().to(tc.float32)
                # x, y = x.squeeze(1), y.squeeze(1).squeeze(1)
                x=norm_with_clip(x.reshape(-1,*x.shape[2:])).reshape(x.shape)

                with autocast():
                    pred = model.predict(x)
                v_logger.add_scalar('dice', dice_coef(pred.detach()>0.5,y), step)
                v_logger.add_scalar('surface-dice', dice_coef(
                    edge((pred > 0.5).float().unsqueeze(1)).squeeze(1),
                    edge(y.unsqueeze(1)).squeeze(1)
                ), step)
                model.train()
            
            if (step + 1) % save_every == 0:
                tc.save(model.state_dict(), f'./bin/_tmp_models/unet_MoE.pt')

model_name Unet
router_backbone resnext50_32x4d
expert_backbones ['resnext50_32x4d', 'resnext50_32x4d', 'resnext50_32x4d', 'resnext50_32x4d']


x.shape[1] =1706    x.shape[2]=1510x.shape[1] =1706    x.shape[2]=1510

x.shape[1] =1706    x.shape[2]=1510x.shape[1] =1706    x.shape[2]=1510

x.shape[1] =1706    x.shape[2]=1510x.shape[1] =1706    x.shape[2]=1510

x.shape[1] =1706    x.shape[2]=1510x.shape[1] =1706    x.shape[2]=1510

x.shape[1] =1024    x.shape[2]=2279
x.shape[1] =2279    x.shape[2]=1303
x.shape[1] =1024    x.shape[2]=2279
x.shape[1] =2279    x.shape[2]=1303
x.shape[1] =1303    x.shape[2]=1024
x.shape[1] =2279    x.shape[2]=1303
x.shape[1] =1303    x.shape[2]=1024
x.shape[1] =2279    x.shape[2]=1303
x.shape[1] =2279    x.shape[2]=1303
x.shape[1] =1024    x.shape[2]=2279
x.shape[1] =1303    x.shape[2]=1024
x.shape[1] =1024    x.shape[2]=2279
x.shape[1] =1024    x.shape[2]=2279
x.shape[1] =1303    x.shape[2]=1024
x.shape[1] =1024    x.shape[2]=2279
x.shape[1] =1303    x.shape[2]=1024
x.shape[1] =2279    x.shape[2]=1303
x.shape[1] =2279    x.shape[2]=1303
x.shape[1] =1303    x.shape[2]=1024
x.shape[1] =1303    x.shape[

In [9]:
tc.save(model.module.state_dict(), f'./bin/_tmp_models/unet2.5d_IN_PROGRESS.pt')

In [10]:

# Just testing inference options

save_every = 500
extra_log_every = 25
explore_experts_until = 1000

v_logger = log_board.get_logger('val [MoE topk=2]')

model=build_model()
model.load_state_dict(tc.load(f'./bin/_tmp_models/unet_MoE.pt',map_location='cuda:0'))
model.eval()
model.topk = 2

edge = util.Edge().cuda()
dice_score = util.DiceScore().cuda()
loss_fc=DiceLoss()
router_loss_fn = nn.CrossEntropyLoss()
edge_weighted_dice_loss = util.EdgeWeightedDiceLoss(alpha=0.85).cuda()
focal_loss = util.BinaryFocalLoss(
    alpha=0.8,
    gamma=1.5,
)

optimizer_router = tc.optim.AdamW(model.router.parameters(),lr=CFG.lr*0.1)
optimizer=tc.optim.AdamW(model.experts.parameters(),lr=CFG.lr)
scaler=tc.cuda.amp.GradScaler()
scheduler = tc.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=CFG.lr,
                                                steps_per_epoch=len(train_dataset), epochs=CFG.epochs+1,
                                                pct_start=0.1,)

step = 0
agg_expert_load = tc.ones(len(model.experts), dtype=tc.float32).cuda()
valid_iterator = iter(val_dataset)
with tc.no_grad():
    while True:
        step += 1
        try:
            x, y= next(valid_iterator)
        except StopIteration:
            valid_iterator = iter(val_dataset)
            x, y = next(valid_iterator)
        x=x.cuda().to(tc.float32)
        y=y.cuda().to(tc.float32)
        # x, y = x.squeeze(1), y.squeeze(1).squeeze(1)
        x=norm_with_clip(x.reshape(-1,*x.shape[2:])).reshape(x.shape)

        with autocast():
            pred = model.predict(x)
        v_logger.add_scalar('dice', dice_coef(pred.detach()>0.5,y), step)
        v_logger.add_scalar('surface-dice', dice_coef(
            edge((pred > 0.5).float().unsqueeze(1)).squeeze(1),
            edge(y.unsqueeze(1)).squeeze(1)
        ), step)




# for epoch in range(CFG.epochs):
#     for i,(x,y) in enumerate(train_dataset):
#         step+=1
#         # prepare data
#         x=x.cuda().to(tc.float32)
#         y=y.cuda().to(tc.float32)
#         x=norm_with_clip(x.reshape(-1,*x.shape[2:])).reshape(x.shape)
#         x=add_noise(x,max_randn_rate=0.5,x_already_normed=True)

#         with autocast():
#             # compute prediction
#             pred_list, expert_dist_logits = model(x)
#             expert_dist = expert_dist_logits.detach().softmax(dim=-1)
#             if step < explore_experts_until:
#                 expert_dist = tc.ones_like(expert_dist) / model.n_experts

#             # rebalance expert distribution for load/train balancing
#             dist_balancer = (agg_expert_load.sum() / (agg_expert_load))
#             dist_balancer = model.n_experts * dist_balancer.softmax(dim=-1)
#             dist_balancer, dist_balancer.sum()
#             balanced_dist = (expert_dist * dist_balancer)
#             balanced_dist /= balanced_dist.sum(1, keepdim=True)
#             agg_expert_load += balanced_dist.sum(0)
#             agg_expert_load *= 0.95

#             # combine expert predictions
#             final_pred = tc.stack([
#                 p * w.unsqueeze(-1).unsqueeze(-1)
#                 for p, w in zip(pred_list, balanced_dist.unbind(1))
#             ]).sum(0)
            
#             # create labels for router
#             expert_sample_dice = tc.stack([dice_score(p.detach(), y, 'separate') for p in pred_list], dim=1)
#             router_labels = tc.argmax(expert_sample_dice, dim=1)

#         # compute loss
#         experts_loss = edge_weighted_dice_loss(final_pred.unsqueeze(1), y.unsqueeze(1)) * 4
#         router_loss = router_loss_fn(expert_dist_logits, router_labels)
#         loss = experts_loss + router_loss

#         # gradient step
#         scaler.scale(loss).backward()
#         scaler.step(optimizer)
#         scaler.step(optimizer_router)
#         scaler.update()
#         optimizer.zero_grad()
#         optimizer_router.zero_grad()
#         scheduler.step()

#         # performance logging
#         with tc.no_grad():
#             t_logger.add_scalars('loss', {
#                 'experts': experts_loss / model.n_experts,
#                 'router': router_loss,
#             }, step)
#             t_logger.add_scalar('lr', optimizer.param_groups[0]['lr'], step)
#             t_logger.add_scalar('dice', dice_coef(final_pred.detach()>0.5,y), step)
#             t_logger.add_scalar('surface-dice', dice_coef(
#                 edge((final_pred.detach() > 0.5).float().unsqueeze(1)).squeeze(1),
#                 edge(y.unsqueeze(1)).squeeze(1)
#             ), step)
#             t_logger.add_scalars('expert-load', {
#                 f'expert-{i}': w for i, w in enumerate(expert_dist.mean(0))
#             }, step)
#             del loss,pred_list, expert_dist_logits

#             if (step + 1) % extra_log_every == 0:
#                 model.eval()
#                 try:
#                     x, y= next(valid_iterator)
#                 except StopIteration:
#                     valid_iterator = iter(val_dataset)
#                     x, y = next(valid_iterator)
                
#                 x=x.cuda().to(tc.float32)
#                 y=y.cuda().to(tc.float32)
#                 # x, y = x.squeeze(1), y.squeeze(1).squeeze(1)
#                 x=norm_with_clip(x.reshape(-1,*x.shape[2:])).reshape(x.shape)

#                 with autocast():
#                     pred = model.predict(x)
#                 v_logger.add_scalar('dice', dice_coef(pred.detach()>0.5,y), step)
#                 v_logger.add_scalar('surface-dice', dice_coef(
#                     e  dge((pred > 0.5).float().unsqueeze(1)).squeeze(1),
#                     edge(y.unsqueeze(1)).squeeze(1)
#                 ), step)
#                 model.train()
            
#             if (step + 1) % save_every == 0:
#                 tc.save(model.state_dict(), f'./bin/_tmp_models/unet_MoE.pt')


#     time.close()
    
#     model.eval()
#     time=tqdm(range(len(val_dataset)))
#     val_losss=0
#     val_scores=0
#     for i,(x,y) in enumerate(val_dataset):
#         x=x.cuda().to(tc.float32)
#         y=y.cuda().to(tc.float32)
#         x=norm_with_clip(x.reshape(-1,*x.shape[2:])).reshape(x.shape)

#         with autocast():
#             with tc.no_grad():
#                 pred=model(x)
#                 loss=loss_fc(pred,y)
#         score=dice_coef(pred.detach(),y)
#         val_losss=(val_losss*i+loss.item())/(i+1)
#         val_scores=(val_scores*i+score)/(i+1)
#         time.set_description(f"val-->loss:{val_losss:.4f},score:{val_scores:.4f}")
#         time.update()

#     time.close()
# #tc.save(model.module.state_dict(),f"./{CFG.backbone}_{epoch}_loss{losss:.2f}_score{scores:.2f}_val_loss{val_losss:.2f}_val_score{val_scores:.2f}_midd_1024.pt")

# time.close()

model_name Unet
router_backbone resnext50_32x4d
expert_backbones ['resnext50_32x4d', 'resnext50_32x4d', 'resnext50_32x4d', 'resnext50_32x4d']


x.shape[1] =1706    x.shape[2]=1510
x.shape[1] =1706    x.shape[2]=1510
x.shape[1] =1706    x.shape[2]=1510
x.shape[1] =1706    x.shape[2]=1510x.shape[1] =1706    x.shape[2]=1510

x.shape[1] =1706    x.shape[2]=1510
x.shape[1] =1706    x.shape[2]=1510
x.shape[1] =1706    x.shape[2]=1510
x.shape[1] =1706    x.shape[2]=1510
x.shape[1] =1706    x.shape[2]=1510
x.shape[1] =1706    x.shape[2]=1510
x.shape[1] =1706    x.shape[2]=1510
x.shape[1] =1706    x.shape[2]=1510
x.shape[1] =1706    x.shape[2]=1510
x.shape[1] =1706    x.shape[2]=1510
x.shape[1] =1706    x.shape[2]=1510
x.shape[1] =1706    x.shape[2]=1510
x.shape[1] =1706    x.shape[2]=1510
x.shape[1] =1706    x.shape[2]=1510
x.shape[1] =1706    x.shape[2]=1510
x.shape[1] =1706    x.shape[2]=1510
x.shape[1] =1706    x.shape[2]=1510
x.shape[1] =1706    x.shape[2]=1510
x.shape[1] =1706    x.shape[2]=1510
x.shape[1] =1706    x.shape[2]=1510
x.shape[1] =1706    x.shape[2]=1510
x.shape[1] =1706    x.shape[2]=1510
x.shape[1] =1706    x.shape[

KeyboardInterrupt: 