In [1]:
import cv2
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import numpy as np
from collections import defaultdict
from pathlib import Path
from tqdm import tqdm
import time
import matplotlib.pyplot as plt
from segmentation.scr.utils.utils import set_seed
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import segmentation_models_pytorch as smp
from segmentation.scr.utils import losses, transforms
from segmentation.models.unet import unet

from segmentation.scr.utils import losses, transforms
from segmentation.scr.utils.utils import set_seed
from segmentation.scr.utils.metrics import dice_coef
from segmentation.scr.utils.utils import save_model

In [None]:
class CFG:
    """Notebook-wide settings for the model, training schedule and augmentation.

    The visible code only reads attributes from the class itself (e.g. ``CFG.lr``);
    it is never instantiated there.
    """
    # ============== pred target =============
    # Number of output channels; passed to the model as n_classes.
    target_size = 1

    # ============== model CFG =============
    # NOTE(review): model_name and backbone are not read anywhere in the visible
    # notebook (the model is built from unet.UNet) — confirm they are still needed.
    model_name = 'Unet'
    backbone = 'se_resnext50_32x4d'

    # Number of consecutive slices stacked as input channels; also passed to the
    # model as n_channels.
    in_chans = 3 # 65
    # ============== training CFG =============
    # Side length of the square crop Kaggld_Dataset cuts out of each volume.
    image_size = 512
    # Side length used by the RandomCrop step of train_aug_list below.
    input_size=512

    train_batch_size = 4
    # Number of batches whose gradients are accumulated before one optimizer step;
    # evaluates to 4 for a batch size of 4 (16 // 4).
    n_accumulate = max(1, 16 // train_batch_size)
    valid_batch_size = train_batch_size * 2

    epochs = 20
    # Used both as the AdamW learning rate and as OneCycleLR's max_lr.
    lr = 3e-4
    # Fraction of elements filter_noise uses to locate its upper and lower
    # clipping thresholds.
    chopping_percentile=1e-3
    # ============== fold =============
    # NOTE(review): valid_id is not read anywhere in the visible notebook.
    valid_id = 1


    # ============== augmentation =============
    # Pipeline applied by Kaggld_Dataset when it is created with arg=True.
    train_aug_list = [
        A.Rotate(limit=45, p=0.5),
        # NOTE(review): albumentations documents a tuple scale_limit as offsets
        # from 1, so (0.8, 1.25) would scale by roughly 1.8x-2.25x rather than
        # 0.8x-1.25x — confirm intent. Down-scaling here would hand the following
        # RandomCrop an image smaller than input_size, so do not change blindly.
        A.RandomScale(scale_limit=(0.8,1.25),interpolation=cv2.INTER_CUBIC,p=0.5),
        A.RandomCrop(input_size, input_size,p=1),
        A.RandomGamma(p=0.75),
        A.RandomBrightnessContrast(p=0.5,),
        A.GaussianBlur(p=0.5),
        A.MotionBlur(p=0.5),
        A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.5),
        ToTensorV2(transpose_mask=True),
    ]
    train_aug = A.Compose(train_aug_list)
    # Pipeline applied by Kaggld_Dataset when arg is false: tensor conversion only.
    valid_aug_list = [
        ToTensorV2(transpose_mask=True),
    ]
    valid_aug = A.Compose(valid_aug_list)

In [None]:
class Data_loader(Dataset):
    """Dataset that reads one grayscale image per path and returns it as a uint8 tensor.

    Args:
        paths: Iterable of image file paths. A sorted copy is stored, so the
            caller's own sequence is never reordered.
        is_label: When true, every non-zero pixel is mapped to 255 and every
            zero pixel stays 0; otherwise the pixel values are returned as read.

    Raises:
        OSError: from ``__getitem__`` when OpenCV cannot read the file.
    """
    def __init__(self,paths,is_label):
        # sorted() instead of list.sort(): same order, but no side effect on the argument.
        self.paths=sorted(paths)
        self.is_label=is_label

    def __len__(self):
        return len(self.paths)

    def __getitem__(self,index):
        path=self.paths[index]
        img=cv2.imread(path,cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise OSError(f"cv2.imread could not read image: {path}")
        img=torch.from_numpy(img)
        if self.is_label:
            return (img!=0).to(torch.uint8)*255
        return img.to(torch.uint8)

In [None]:
# NOTE(review): this constant is not referenced in the visible code — filter_noise
# reads CFG.chopping_percentile (same value, 1e-3) instead. Confirm whether it can go.
CHOPPING_PER = 1e-3
def min_max_normalization(x:torch.Tensor)->torch.Tensor:
    """Rescale every batch item of ``x`` linearly so its values span [0, 1].

    input.shape=(batch,f1,...). Inputs with more than two dimensions are
    flattened per batch item before the min/max are taken; 1-D and 2-D inputs
    are reduced over their last dimension.

    Args:
        x: Tensor to normalise. Non-floating dtypes are converted to float32
            first (the reduction below needs a floating dtype).

    Returns:
        Tensor with the same shape as ``x``. A batch item whose values are all
        equal maps to zeros instead of NaN.
    """
    if not x.is_floating_point():
        x=x.to(torch.float32)
    shape=x.shape
    flat=x.reshape(shape[0],-1) if x.ndim>2 else x

    min_=flat.min(dim=-1,keepdim=True)[0]
    max_=flat.max(dim=-1,keepdim=True)[0]
    # Shortcut kept from the original: treat the input as already normalised
    # when the per-item minima average 0 and the maxima average 1.
    if min_.mean()==0 and max_.mean()==1:
        return flat.reshape(shape)

    # 1e-9 rounds to 0 in float16, which turned a constant input into 0/0 = NaN.
    # Never go below the smallest normal value of the working dtype.
    eps=max(1e-9,torch.finfo(flat.dtype).tiny)
    return ((flat-min_)/(max_-min_+eps)).reshape(shape)

def norm_with_clip(x:torch.Tensor,smooth=1e-5):
    """Standardise each batch item of ``x`` and softly compress its outliers.

    Mean and standard deviation are taken over every dimension except the
    first. Standardised values above 5 keep only 1/1000 of their excess over 5;
    values below -3 keep only 1/1000 of their excess below -3.

    NOTE: a second function with the same name and the same body is defined
    later in this notebook and replaces this one when cells run in order.

    Args:
        x: Input tensor, first dimension is the batch.
        smooth: Added to the standard deviation to avoid division by zero.

    Returns:
        A new tensor with the same shape as ``x``.
    """
    reduce_dims=list(range(1,x.ndim))
    centre=x.mean(dim=reduce_dims,keepdim=True)
    spread=x.std(dim=reduce_dims,keepdim=True)
    z=(x-centre)/(spread+smooth)
    # Upper tail first, then lower tail, in the same order as the masked writes.
    z=torch.where(z>5,(z-5)*1e-3 +5,z)
    z=torch.where(z<-3,(z+3)*1e-3-3,z)
    return z

def filter_noise(x,chopping_percentile=None):
    """Clip the brightest and darkest outliers of ``x`` in place.

    ``k = int(number_of_elements * chopping_percentile)`` elements are treated
    as outliers at each end: values above the k-th largest element are set to
    it, then values below the element at sorted position k are set to that.

    Args:
        x: Array-like supporting ``reshape``, boolean-mask assignment and
            conversion by ``np.partition`` (the notebook passes a CPU uint8
            torch tensor). It is modified in place.
        chopping_percentile: Fraction of elements to treat as outliers per
            side. Defaults to ``CFG.chopping_percentile``.

    Returns:
        The same object ``x`` that was passed in.
    """
    if chopping_percentile is None:
        chopping_percentile=CFG.chopping_percentile
    n_clip=int(len(x.reshape(-1))*chopping_percentile)
    if n_clip<=0:
        # With k == 0 the index "-0" selected the minimum and flattened the
        # whole input to that value; nothing should be clipped in that case.
        return x
    # Thresholds are compared as read but written back truncated to int,
    # exactly as before (identical for integer data).
    upper=np.partition(x.reshape(-1),-n_clip)[-n_clip].item()
    x[x>upper]=int(upper)
    # Recomputed on the already upper-clipped data, matching the original order.
    lower=np.partition(x.reshape(-1),n_clip)[n_clip].item()
    x[x<lower]=int(lower)
    return x



In [None]:
def load_data(paths,is_label=False):
    """Read every image in ``paths`` and stack the slices into one tensor.

    Slices are read through ``Data_loader`` in batches of 16 and concatenated
    along dim 0. Image volumes (``is_label`` false) are then passed through
    ``filter_noise`` and rescaled to the 0-255 range with
    ``min_max_normalization`` over the whole volume; label volumes are
    returned as loaded.

    Args:
        paths: Image file paths handed to ``Data_loader``.
        is_label: Whether the files are segmentation masks.

    Returns:
        A ``torch.uint8`` tensor whose first dimension indexes the slices.
    """
    slice_loader=DataLoader(Data_loader(paths,is_label), batch_size=16)
    volume=torch.cat([batch for batch in tqdm(slice_loader)],dim=0)
    if is_label:
        return volume
    volume=filter_noise(volume)
    # The leading [None] makes the whole volume a single "batch item", so one
    # global min/max is used; [0] removes that axis again.
    scaled=min_max_normalization(volume.to(torch.float16)[None])[0]
    return (scaled*255).to(torch.uint8)

In [None]:
# Training data (kidney_1_dense): collect image/label files that share a file name.
# The directories are joined with pathlib instead of hard-coded "\\" separators,
# which only resolve on Windows.
path_img_dir = Path("data") / "train" / "kidney_1_dense" / "images"
path_lb_dir = Path("data") / "train" / "kidney_1_dense" / "labels"

path_img_dir = sorted(path_img_dir.rglob("*.tif"))
path_lb_dir = sorted(path_lb_dir.rglob("*.tif"))

# file name -> [image path, label path]; images are appended first, so index 0
# is the image and index 1 the label for every name that has both.
images_labels = defaultdict(list)
for img in path_img_dir:
    images_labels[img.name].append(img)
for lb in path_lb_dir:
    images_labels[lb.name].append(lb)
new_dict = {name: found for name, found in images_labels.items() if len(found) > 1}
print(f"Общее число изображений с масками сегментации : {len(new_dict)}")
img_path = [str(found[0]) for found in new_dict.values()]
lb_path = [str(found[1]) for found in new_dict.values()]

In [None]:
# Load the training volume and its labels from the paths collected above.
x=load_data(img_path ,is_label=False)
y=load_data(lb_path,is_label=True)
print(y.shape)

# Three views of the same volume: the original axis order plus two axis
# permutations, so that slicing along dim 0 walks each of the three axes.
train_x=[x, x.permute(1,2,0), x.permute(2,0,1)]
train_y=[y, y.permute(1,2,0), y.permute(2,0,1)]

In [None]:
# Validation data: images come from kidney_3_sparse, labels from kidney_3_dense;
# only files whose name exists in both directories are kept.
# The directories are joined with pathlib instead of hard-coded "\\" separators,
# which only resolve on Windows.
path_img_dir = Path("data") / "train" / "kidney_3_sparse" / "images"
path_lb_dir = Path("data") / "train" / "kidney_3_dense" / "labels"

path_img_dir = sorted(path_img_dir.rglob("*.tif"))
path_lb_dir = sorted(path_lb_dir.rglob("*.tif"))

# file name -> [image path, label path]; images are appended first, so index 0
# is the image and index 1 the label for every name that has both.
images_labels = defaultdict(list)
for img in path_img_dir:
    images_labels[img.name].append(img)
for lb in path_lb_dir:
    images_labels[lb.name].append(lb)
new_dict = {name: found for name, found in images_labels.items() if len(found) > 1}
print(f"Общее число изображений с масками сегментации : {len(new_dict)}")
img_path = [str(found[0]) for found in new_dict.values()]
lb_path = [str(found[1]) for found in new_dict.values()]

In [None]:
# Load the validation images and masks from whatever img_path / lb_path currently
# hold (the kidney_3 paths when the cells are run in order); the shapes are
# printed as a quick sanity check that both volumes line up.
val_x=load_data(img_path ,is_label=False)
print(val_x.shape)
val_y=load_data(lb_path ,is_label=True)
print(val_y.shape)

In [None]:
class Kaggld_Dataset(Dataset):
    """Random-crop dataset over a list of 3-D volumes.

    One item is a stack of ``CFG.in_chans`` consecutive slices (taken along
    dim 0 of one volume), cropped to ``CFG.image_size`` x ``CFG.image_size`` at
    a random position, together with the mask of the middle slice of the stack.

    Args:
        x: List of image volumes, each indexed as (slice, H, W).
        y: List of label volumes with the same shapes as ``x``.
        arg: When true, ``CFG.train_aug`` plus random 90-degree rotations and
            flips are applied; otherwise only ``CFG.valid_aug``.

    NOTE: the crop position is random even when ``arg`` is false, so
    validation items are not identical from one epoch to the next.
    """
    def __init__(self,x:list,y:list,arg=False):
        super().__init__()
        self.x=x#list[(C,H,W),...]
        self.y=y#list[(C,H,W),...]
        self.image_size=CFG.image_size
        self.in_chans=CFG.in_chans
        self.arg=arg
        self.transform=CFG.train_aug if arg else CFG.valid_aug

    def __len__(self) -> int:
        # Each volume contributes (number of slices - in_chans) window starts.
        return sum(vol.shape[0]-self.in_chans for vol in self.y)

    def __getitem__(self,index):
        if not 0<=index<len(self):
            raise IndexError(f"index {index} is out of range for a dataset of length {len(self)}")

        # Map the flat index to (volume, first slice). Each volume owns exactly
        # the number of indices that __len__ counts for it; the previous ">"
        # test gave the first volume one index too many and skipped window 0 of
        # every later volume.
        vol_id=0
        for vol in self.x:
            n_windows=vol.shape[0]-self.in_chans
            if index<n_windows:
                break
            index-=n_windows
            vol_id+=1
        x=self.x[vol_id]
        y=self.y[vol_id]

        # "+1" because randint excludes its upper bound: without it the last
        # valid offset was never drawn and a volume exactly image_size wide
        # raised ValueError.
        x_index=np.random.randint(0,x.shape[1]-self.image_size+1)
        y_index=np.random.randint(0,x.shape[2]-self.image_size+1)

        rows=slice(x_index,x_index+self.image_size)
        cols=slice(y_index,y_index+self.image_size)
        x=x[index:index+self.in_chans,rows,cols]
        # Label of the middle slice of the stack.
        y=y[index+self.in_chans//2,rows,cols]

        # The transform receives the image channels-last (H, W, C).
        data = self.transform(image=x.numpy().transpose(1,2,0), mask=y.numpy())
        x = data['image']
        y = data['mask']>=127
        if self.arg:
            # Same random quarter-turn for image and mask.
            quarter_turns=np.random.randint(4)
            x=x.rot90(quarter_turns,dims=(1,2))
            y=y.rot90(quarter_turns,dims=(0,1))
            # Axis 0 of x is the slice/channel axis and has no counterpart in
            # the 2-D mask, so only the two spatial flips are mirrored on y.
            for axis in range(3):
                if np.random.randint(2):
                    x=x.flip(dims=(axis,))
                    if axis>=1:
                        y=y.flip(dims=(axis-1,))
        return x,y#(image tensor, bool mask)
           
        

In [None]:
import torch as tc
def norm_with_clip(x:tc.Tensor,smooth=1e-5):
    """Z-score each batch item of ``x`` and shrink extreme values towards the limits.

    Statistics are computed over all dimensions except the first. After
    standardisation, the part of a value exceeding 5 is scaled by 1e-3, and
    likewise the part falling below -3.

    This definition overrides the earlier, identical ``norm_with_clip`` in
    this notebook.

    Args:
        x: Input tensor with the batch on dim 0.
        smooth: Stabiliser added to the standard deviation.

    Returns:
        A new tensor shaped like ``x``.
    """
    axes=list(range(1,x.ndim))
    mu=x.mean(dim=axes,keepdim=True)
    sigma=x.std(dim=axes,keepdim=True)
    out=(x-mu)/(sigma+smooth)

    too_high=out>5
    out[too_high]=(out[too_high]-5)*1e-3 +5
    too_low=out<-3
    out[too_low]=(out[too_low]+3)*1e-3-3
    return out

def add_noise(x:tc.Tensor,max_randn_rate=0.1,randn_rate=None,x_already_normed=False):
    """Add Gaussian noise to ``x`` and rescale so the result's variance is normalised.

    input.shape=(batch,f1,f2,...). The noise has standard deviation
    ``randn_rate * std`` per batch item; the sum is centred and divided by
    ``sqrt(std**2 + (std*randn_rate)**2)``.

    Args:
        x: Input tensor, batch on dim 0.
        max_randn_rate: Upper bound for the randomly drawn noise rate; only
            used when ``randn_rate`` is None.
        randn_rate: Fixed noise rate. When None, a rate is drawn per batch item
            as ``max_randn_rate * U(0,1) * U(0,1)`` (one NumPy draw shared by
            the batch, one torch draw per item).
        x_already_normed: When true, mean 0 and std 1 are assumed instead of
            being measured from ``x``.

    Returns:
        Tensor with the same shape, dtype and device as ``x``.
    """
    stat_shape=[x.shape[0]]+[1]*(x.ndim-1)
    if x_already_normed:
        x_std=tc.ones(stat_shape,device=x.device,dtype=x.dtype)
        x_mean=tc.zeros(stat_shape,device=x.device,dtype=x.dtype)
    else:
        reduce_dims=list(range(1,x.ndim))
        x_std=x.std(dim=reduce_dims,keepdim=True)
        x_mean=x.mean(dim=reduce_dims,keepdim=True)

    if randn_rate is None:
        randn_rate=max_randn_rate*np.random.rand()*tc.rand(x_mean.shape,device=x.device,dtype=x.dtype)

    # Standard deviation of signal plus independent noise.
    #https://blog.csdn.net/chaosir1991/article/details/106960408
    total_std=(x_std**2+(x_std*randn_rate)**2)**0.5
    noise=tc.randn(size=x.shape,device=x.device,dtype=x.dtype)*randn_rate*x_std
    return (x-x_mean+noise)/(total_std+1e-7)

In [None]:
# Let cuDNN pick the fastest convolution algorithms for the fixed input size.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

set_seed(42)
train_dataset=Kaggld_Dataset(train_x,train_y,arg=True)
train_loader = DataLoader(train_dataset, batch_size=CFG.train_batch_size , shuffle=True )
val_dataset=Kaggld_Dataset([val_x],[val_y])
val_loader = DataLoader(val_dataset, batch_size=CFG.valid_batch_size, shuffle=False, )

model=unet.UNet(n_channels=CFG.in_chans, n_classes=CFG.target_size)
model = model.cuda()

loss_fn=losses.DiceLoss()
optimizer=torch.optim.AdamW(model.parameters(),lr=CFG.lr)

# The training loop calls scheduler.step() once per optimizer update, i.e. once
# every CFG.n_accumulate batches plus once for the final partial group of an
# epoch. steps_per_epoch must count those updates; it previously used
# len(train_dataset) (a number of samples), which stretched the schedule so far
# that the learning rate never left the warm-up phase.
optimizer_steps_per_epoch = (len(train_loader) + CFG.n_accumulate - 1) // CFG.n_accumulate
# epochs is kept at CFG.epochs+1 as in the original — presumably headroom so the
# scheduler is never stepped past its total number of steps.
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=CFG.lr,
                                                steps_per_epoch=optimizer_steps_per_epoch, epochs=CFG.epochs+1,
                                                pct_start=0.1,)

# Per-epoch history and the best validation score seen so far.
train_metrics, val_metrics,train_losses, val_losses = [], [], [],[]
best_metric = -np.inf

In [None]:
# Train for CFG.epochs epochs with gradient accumulation, validate after every
# epoch, checkpoint on improvement and rewrite the history file each epoch.
# The progress bars get their own names so the imported `time` module is no
# longer overwritten by a tqdm object.
for epoch in range(CFG.epochs):
    # ---------------- training ----------------
    model.train()
    losss=0
    scores=0
    with tqdm(total=len(train_loader)) as train_bar:
        for i,(x,y) in enumerate(train_loader):
            x=x.cuda().to(torch.float32)
            y=y.cuda().to(torch.float32)
            # Normalise every slice on its own: (B,C,H,W) -> (B*C,H,W) and back.
            x=norm_with_clip(x.reshape(-1,*x.shape[2:])).reshape(x.shape)
            x=add_noise(x,max_randn_rate=0.5,x_already_normed=True)

            pred=model(x)
            loss=loss_fn(pred,y)
            # Only the gradient is scaled for accumulation; `loss` itself stays
            # unscaled so the logged value is comparable to the validation loss.
            (loss / CFG.n_accumulate).backward()

            # Step after every n_accumulate batches and on the last batch.
            if (i + 1) % CFG.n_accumulate == 0 or (i + 1 == len(train_loader)):
                optimizer.step()  # update weights
                optimizer.zero_grad()
                scheduler.step()

            # float() keeps the history free of tensors whatever dice_coef returns.
            score=float(dice_coef(pred.detach(),y))
            # Running means over the batches seen so far.
            losss=(losss*i+loss.item())/(i+1)
            scores=(scores*i+score)/(i+1)
            train_bar.set_description(f"epoch:{epoch},loss:{losss:.4f},score:{scores:.4f},lr{optimizer.param_groups[0]['lr']:.4e}")
            train_bar.update()
            del loss,pred
    train_losses.append(losss)
    train_metrics.append(scores)

    # ---------------- validation ----------------
    model.eval()
    val_losss=0
    val_scores=0
    # The bar is advanced once per batch, so its total is the number of batches.
    with tqdm(total=len(val_loader)) as val_bar, torch.no_grad():
        for i,(x,y) in enumerate(val_loader):
            x=x.cuda().to(torch.float32)
            y=y.cuda().to(torch.float32)
            x=norm_with_clip(x.reshape(-1,*x.shape[2:])).reshape(x.shape)

            pred=model(x)
            loss=loss_fn(pred,y)
            score=float(dice_coef(pred,y))
            val_losss=(val_losss*i+loss.item())/(i+1)
            val_scores=(val_scores*i+score)/(i+1)
            val_bar.set_description(f"val-->loss:{val_losss:.4f},score:{val_scores:.4f}")
            val_bar.update()
    val_metrics.append(val_scores)
    val_losses.append(val_losss)

    # ---------------- checkpoint ----------------
    if val_scores > best_metric:
        best_metric = val_scores
        save_model(
            model=model,
            optimizer=optimizer,
            model_name=model.__class__.__name__
            + "_best_model_at_"
            + str(epoch + 1),
            path='w',
            lr_scheduler=scheduler,
        )

    # ---------------- history file ----------------
    # Rewritten from scratch every epoch: one header line per series followed
    # by its tab-separated values.
    history=(
        ("train_loss",train_losses),
        ("val_loss",val_losses),
        ("val_metrics",val_metrics),
        ("train_metrics",train_metrics),
    )
    with open("train_results.txt", "w") as file_handler:
        for position,(title,values) in enumerate(history):
            file_handler.write(("\n" if position else "")+title+"\n")
            for item in values:
                file_handler.write("{}\t".format(item))