In [1]:
from segmentation.scr.utils.utils import set_seed
from segmentation.scr.utils.metrics import dice_coef
from segmentation.models.smp import get_pretrained_model
from segmentation.config import CFG
import torch
import torch.nn as nn
import numpy as np
from segmentation.scr.inference.tta import tta

from segmentation.scr.train_model.data_loader import *
from segmentation.scr.utils.utils import set_seed, resize_to_size
from segmentation.scr.train_model.evaluate_dataset import evaluate_dataset
import matplotlib.pyplot as plt



In [2]:
model_my_pc= get_pretrained_model(path_to_model='weight\mobnet_my_pc.pth', train_parallel=False, CFG=CFG)


In [3]:
path_img_dir="data\\train\\kidney_3_sparse\\images"
path_lb_dir="data\\train\\kidney_3_dense\\labels"
img_path, lb_path = create_img_lb_paths(path_img_dir = path_img_dir, path_lb_dir = path_lb_dir)

Общее число изображений с масками сегментации : 501


In [4]:
def create_tillings(images, overlap_pct=CFG.tilling_overlap_pct):
    if len(images.shape) == 2:
        # print('ss')
        images = images.unsqueeze(0).repeat(CFG.in_chans, 1, 1)
    min_overlap = float(overlap_pct) * 0.01
    max_stride = CFG.image_size * (1.0 - min_overlap)

    height, width = images.shape[-2], images.shape[-1]
    num_patches = np.ceil(np.array([height, width]) / max_stride).astype(np.int32)

    starts = [
        np.int32(np.linspace(0, height - CFG.input_size, num_patches[0])),
        np.int32(np.linspace(0, width - CFG.input_size, num_patches[1])),
    ]
    stops = [starts[0] + CFG.input_size, starts[1] + CFG.input_size]

    indexs = []
    tills = []
    for y1, y2 in zip(starts[0], stops[0]):
        for x1, x2 in zip(starts[1], stops[1]):
            tills.append(
                    images[..., y1:y2, x1:x2]
            )
            indexs.append((y1, y2, x1, x2))

    return tills, indexs


In [53]:
class Data_loader_inference(Dataset):
    def __init__(self,img_path,lb_path):
        self.img_paths=img_path
        self.lb_path = lb_path
        h_m, w_m = 0,0
        for path in self.img_paths:
            img=cv2.imread(path,cv2.IMREAD_GRAYSCALE)
            h, w  = img.shape
            h_m = max(h, h_m)
            w_m = max(w, w_m)
        self.img_size = (h_m, w_m)
        
        
    
    def __len__(self):
        return sum([len(self.lb_path)- CFG.in_chans for y in self.lb_path])
        #return len(self.paths)
    
    def __getitem__(self,index):
        paths_img  = self.img_paths[index : index +  CFG.in_chans]
        path_lb  = self.lb_path[index + CFG.in_chans // 2]
        
        images = []

        for path_img in paths_img:
            img=cv2.imread(path_img,cv2.IMREAD_GRAYSCALE)
            h, w  = img.shape
            img=torch.from_numpy(img).to(torch.uint8)
            if h < self.img_size[0] or w < self.img_size[1]:
                img = resize_to_size(img=img, image_size=self.img_size)
            images.append(img)
            
        label = cv2.imread(path_lb,cv2.IMREAD_GRAYSCALE)
        label=torch.from_numpy(label!=0).to(torch.uint8)*255
        if h < self.img_size[0] or w < self.img_size[1]:  
            label = resize_to_size(img=label, image_size=self.img_size)
        images= torch.stack(images)
      

        return images, label

In [54]:
loader_inf = Data_loader_inference(img_path, lb_path)

In [55]:
val_loader = DataLoader(loader_inf, batch_size=2, shuffle=False, )

In [56]:
batch = next(iter(val_loader))

In [57]:
images, masks = batch
images = (min_max_normalization(images.to(torch.float16)[None])[0]*255).to(torch.uint8)
tills, indexs = create_tillings(images=images, overlap_pct=10)
till_shape = len(tills)
tills = torch.cat(tills)
tills =tills.to(torch.float32)
tills = norm_with_clip(tills)

In [60]:
model_my_pc.eval()
with torch.no_grad():
    ans= tta(model=model_my_pc, x=tills.cuda().to(torch.float32))
ans = ans.cpu()
#ans = ans.sigmoid()

In [62]:
ans = ans.reshape(till_shape, 2, CFG.input_size, CFG.input_size)

In [65]:
mask_pred = torch.zeros_like(masks,dtype=torch.float32)
mask_count = torch.zeros_like(masks,dtype=torch.float32)
for i,(y1,y2,x1,x2) in enumerate(indexs):
    mask_pred[...,y1:y2, x1:x2] += ans[i]
    mask_count[...,y1:y2, x1:x2] += 1
 
mask_pred /= mask_count

In [69]:
dice_coef(mask_pred[1], masks > 0, from_logits= False, thr=0.5)

tensor(0.8928)

In [None]:
im1 = loader_inf[0][0]
im1 = (min_max_normalization(im1.to(torch.float16)[None])[0]*255).to(torch.uint8)
tills, indexs = create_tillings(images=im1.repeat(CFG.in_chans,1,1))
img_shape = tills[0].shape
tills = torch.cat(tills)

In [None]:
im1.shape

In [None]:
img_shape 

In [None]:
11 % 3

In [None]:
15 + (3 - 15% 3)

In [None]:
11 // 3

In [None]:
lb_path[500]

In [None]:
val_x=load_data(img_path ,is_label=False)
print(val_x.shape)
val_y=load_data(lb_path ,is_label=True)
print(val_y.shape)

In [None]:
set_seed(42)
val_dataset_3=Kaggld_Dataset([val_x],[val_y], arg=False, CFG=CFG_3)
val_loader_3 = DataLoader(val_dataset_3, batch_size=CFG.valid_batch_size, shuffle=False, )

val_dataset_5=Kaggld_Dataset([val_x],[val_y], arg=False, CFG=CFG)
val_loader_5 = DataLoader(val_dataset_5, batch_size=CFG.valid_batch_size, shuffle=False, )

In [None]:
def tta(x, model):
    model.eval()
    x_n=[torch.rot90(x,k=i,dims=(-2,-1)) for i in range(4)]

    shape = x.shape
    for i in range(4):
        if i ==0:
            x_n[i] = x_n[i]
        else:
            x_n[i]=torch.flip(x_n[i] , dims = (-2,))
    with torch.no_grad():
        pred = [model(b) for b in x_n]
    
        torch.flip(pred[i] , dims = (-2,))
    pred=torch.cat(pred,dim=0)
    pred = pred.sigmoid()
    pred =pred.reshape(4,shape[0],*shape[2:])
    
    
    for i in range(4):
        if i ==0:
            pred[i] = pred[i]
        else:
            pred[i]=torch.flip(pred[i] , dims = (-2,))
    pred=[torch.rot90(pred[i],k=-i,dims=(-2,-1)) for i in range(4)]
    pred=torch.stack(pred,dim=0).mean(0)
            

    return pred

In [None]:
def create_tilling( images, overlap_pct = 50):
    if len(images.shape) == 2:
        #print('ss')
        images = images.unsqueeze(0).repeat(CFG.in_chans,1,1)
    min_overlap = float(overlap_pct) * 0.01
    max_stride = CFG.image_size * (1.0 - min_overlap)

    #print(images.shape)
    height, width = images.shape[1], images.shape[2]
    num_patches = np.ceil(np.array([height, width]) / max_stride).astype(
                np.int32
            )

    starts = [
                np.int32(np.linspace(0, height - CFG.input_size, num_patches[0])),
                np.int32(np.linspace(0, width - CFG.input_size, num_patches[1])),
            ]
    stops = [starts[0] + CFG.input_size, starts[1] + CFG.input_size]
        
    indexs=[]
    tills=[]
    for y1, y2 in zip(starts[0], stops[0]):
        for x1, x2 in zip(starts[1], stops[1]):
            tills.append(images[...,y1:y2, x1:x2])
            indexs.append((y1,y2,x1,x2))
                
  
    return tills, indexs

In [None]:
def evaluate_dataset_2(model, val_loader):
    model.eval()
    timer = tqdm(range(len(val_loader)))
    val_scores = 0
    scores = []
    for i, (x, y) in enumerate(val_loader):
        x = x.cuda().to(torch.float32)
        y = y.cuda().to(torch.float32)
        x = norm_with_clip(x.reshape(-1, *x.shape[2:])).reshape(x.shape)
        with torch.no_grad():
            pred = tta(model=model, x=x)

        score = dice_coef(pred.detach(), y, from_logits=False)
        scores.append(score.detach().cpu().item())
        val_scores = (val_scores * i + score) / (i + 1)
        timer.set_description(f"eval--> score:{val_scores:.4f}")
        timer.update()
    timer.close()
    return scores, val_scores

In [None]:
scores_2, val_scores_2 = evaluate_dataset_2(model=model_my_pc, val_loader=val_loader_5)

In [None]:
print(val_scores_2, val_scores_my)

In [None]:
scores_my, val_scores_my= evaluate_dataset(model=model_my_pc ,val_loader=val_loader_5)

In [None]:
import scipy
scipy.stats.ttest_rel(a=scores_my, b=scores_2)