In [1]:
import sys
sys.path.append('../input/timm-pytorch-image-models/pytorch-image-models-master')

In [2]:
# !conda install -y --channel conda-forge pyvips
!conda install ../input/pyvips-install/pyvips/*.tar.bz2


Downloading and Extracting Packages
######################################################################## | 100% 
######################################################################## | 100% 
######################################################################## | 100% 
######################################################################## | 100% 
######################################################################## | 100% 
######################################################################## | 100% 
######################################################################## | 100% 
######################################################################## | 100% 
######################################################################## | 100% 
######################################################################## | 100% 
######################################################################## | 100% 
###########################################################

In [3]:
import cv2
from openslide import OpenSlide
import tifffile as tiff
import os
import gc

In [4]:
import glob

# test_image_paths = glob.glob("../input/mayo-clinic-strip-ai/train/*.tif")
test_image_paths = glob.glob("../input/mayo-clinic-strip-ai/test/*.tif")

In [5]:
# scale = 4
# output_dir = "/kaggle/working/"
# too_big_for_process = []
# # for path in test_image_paths[180:200]:
# for path in test_image_paths:

#     print(path)
#     slide = OpenSlide(path)

#     if slide.dimensions[0]*slide.dimensions[1] < 4131662535:
#         image_id = os.path.splitext(os.path.basename(path))[0]
#         image = tiff.imread(path)
#         print(f"{image.shape}")
#         cv2.imwrite(os.path.join(output_dir, f"{image_id}.jpg"), image[::scale,::scale,::-1])
#         del image
#         gc.collect()
#     else:
#         print("Skip process for avoiding OOM possibility.")
#         too_big_for_process.append(path)

# test_image_paths = glob.glob("/kaggle/working/*.jpg") 

In [6]:
import zipfile
import torch

import numpy as np
import pandas as pd
from tqdm.auto import tqdm
# from efficientnet_pytorch import EfficientNet
from torch.utils.data import Dataset, DataLoader
from torch.optim import lr_scheduler
from torchvision import models
import torch.nn as nn
import timm
import torchvision
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
import pyvips
import scipy.stats
import random
from fastai.vision import *
from fastai.layers import AdaptiveConcatPool2d, Flatten, Mish

In [7]:
class Model(nn.Module):
    def __init__(self, arch='tf_efficientnetv2_s', n=1, pre=False):
        super().__init__()
#         m = torch.hub.load('../input/facebookresearchsemisupervisedimagenet1kmodels/semi-supervised-ImageNet1K-models-master', model=arch, source='local')
        m = timm.create_model(arch, pretrained=False)
        self.enc = nn.Sequential(*list(m.children())[:-2])       
        nc = list(m.children())[-1].in_features
        self.head = nn.Sequential(AdaptiveConcatPool2d(),Flatten(),nn.Linear(2*nc,512),
                            Mish(),nn.BatchNorm1d(512), nn.Dropout(0.5),nn.Linear(512,n))
        
    def forward(self, x):
        x = [x for x in x]
        shape = x[0].shape
        n = 16
        x = torch.stack(x,1).view(-1,shape[1],shape[2],shape[3])
        #x: bs*N x 3 x 128 x 128
        x = self.enc(x)
        #x: bs*N x C x 4 x 4
        shape = x.shape
        #concatenate the output for tiles into a single map
        x = x.view(-1,n,shape[1],shape[2],shape[3]).permute(0,2,1,3,4).contiguous()\
          .view(-1,shape[1],shape[2]*n,shape[3])
        #x: bs x C x N*4 x 4
        x = self.head(x)
        #x: bs x n
        return x

In [8]:
class Model2(nn.Module):
    def __init__(self, arch='swinv2_tiny_window16_256', n=1, pre=False):
        super().__init__()
#         m = torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', arch)
        m = timm.create_model(arch, pretrained=pre, num_classes=0)
        self.enc = m
#         nc = list(m.children())[-1].in_features
        self.nc=768
        self.head = nn.Sequential(AdaptiveConcatPool2d(),Flatten(),nn.Linear(2*self.nc,512),
                            Mish(),nn.BatchNorm1d(512), nn.Dropout(0.5),nn.Linear(512,n))
        
    def forward(self, x):
        x = [x for x in x]
        shape = x[0].shape
        n = N
        x = torch.stack(x,1).view(-1,shape[1],shape[2],shape[3])
        #x: bs*N x 3 x 128 x 128
        x = self.enc(x)
        #x: bs*N x C x 4 x 4
        shape = x.shape
#         print(x.shape)
        #concatenate the output for tiles into a single map
        x =  x.view(-1, 16, self.nc, 1).permute(0,2,1,3).contiguous() # 1024 for swinv2_base_window16_256, 768 for swinv2_tiny_window8_256, 192 for vit_tiny_patch16_384, 384 for vit_small, deit3_small_patch16_384_in21ft1k
        #x: bs x C x N*4 x 4
        x = self.head(x)
        #x: bs x n
        return x

In [9]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

tile_sz=384
sz = 384
sz2 = 256
N=64
ims_per_batch = 16

cuda:0


In [10]:
# n_class = 2
# backbone = timm.create_model(model_name, pretrained=False, num_classes=0).to(device)
# model = nn.Sequential(
#     backbone,
#     nn.Dropout(0.2),
#     nn.Linear(backbone.num_features, n_class)
# ).to(device)

model = Model().to(device)

# model_path = "../input/mayoinfermodelbfjctp/20220821_models/kaggle/working/models/*.pth"
model_path = "../input/mayo-concat-tile-pooling/models/*.pth"
model_paths = glob.glob(model_path)

model_paths

['../input/mayo-concat-tile-pooling/models/concat-tile-pooling-128-fold2.pth',
 '../input/mayo-concat-tile-pooling/models/concat-tile-pooling-128-fold1.pth',
 '../input/mayo-concat-tile-pooling/models/concat-tile-pooling-128-fold3.pth',
 '../input/mayo-concat-tile-pooling/models/concat-tile-pooling-128-fold0.pth']

In [11]:
model2 = Model2().to(device)
model2_path = "../input/mayo-concat-tile-pooling-transformer/models/*.pth"
model2_paths = glob.glob(model2_path)

model2_paths

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


['../input/mayo-concat-tile-pooling-transformer/models/concat-tile-pooling-256-swinv2-fold1.pth',
 '../input/mayo-concat-tile-pooling-transformer/models/concat-tile-pooling-256-swinv2-fold2.pth',
 '../input/mayo-concat-tile-pooling-transformer/models/concat-tile-pooling-256-swinv2-fold4.pth',
 '../input/mayo-concat-tile-pooling-transformer/models/concat-tile-pooling-256-swinv2-fold0.pth',
 '../input/mayo-concat-tile-pooling-transformer/models/concat-tile-pooling-256-swinv2-fold3.pth']

In [12]:
def tile(img, sz=128, N=64):
    shape = img.shape
    pad0,pad1 = (sz - shape[0]%sz)%sz, (sz - shape[1]%sz)%sz
    img = np.pad(img,[[pad0//2,pad0-pad0//2],[pad1//2,pad1-pad1//2],[0,0]],constant_values=255)
    img = img.reshape(img.shape[0]//sz,sz,img.shape[1]//sz,sz,3)
    img = img.transpose(0,2,1,3,4).reshape(-1,sz,sz,3)
    if len(img) < N:
        img = np.pad(img,[[0,N-len(img)],[0,0],[0,0],[0,0]],constant_values=255)
    scores = []
    for im in img:
        scores.append(len(cv2.imencode(".jpg", im)[1]))
#     idxs = np.argsort(img.reshape(img.shape[0],-1).sum(-1))[:N]
#     img = img[idxs]
    scores, img = zip(*sorted(zip(scores, img), reverse=True, key=lambda x: x[0]))
    high_info_ind = pd.Series(scores).where(pd.Series(scores) > 30000).idxmin()
#     print(high_info_ind is not np.nan)
    
    bg_ind = pd.Series(scores).where(pd.Series(scores) > 10000).idxmin()
    bg_cand = img[bg_ind:]
    for bgi, bg in enumerate(bg_cand):
        bg = bg.reshape(bg.shape[0] * bg.shape[1], bg.shape[2])
        white, _ = scipy.stats.mode(bg, axis=0)
        diff_to_white = (255,255,255) - white
        if sum(diff_to_white[0]) < 128*3:
            break
        else:
            diff_to_white = [[0,0,0]]
    
    if high_info_ind < N or high_info_ind is np.nan:
        return img[:N], bg_cand[bgi], diff_to_white
    else:
        high_info_indexes = random.sample(list(range(high_info_ind)), N)
        img2 = []
        for i in high_info_indexes:
            img2.append(img[i])
        return img2, bg_cand[bgi], diff_to_white

def vips2numpy(vi):
    format_to_dtype = {
       'uchar': np.uint8,
       'char': np.int8,
       'ushort': np.uint16,
       'short': np.int16,
       'uint': np.uint32,
       'int': np.int32,
       'float': np.float32,
       'double': np.float64,
       'complex': np.complex64,
       'dpcomplex': np.complex128,
    }
    return np.ndarray(buffer=vi.write_to_memory(),dtype=format_to_dtype[vi.format],shape=[vi.height, vi.width, vi.bands])

def return_tiled_images(image_path, transform, N=64, ims_per_batch=16 ,max_size=20000, crop_size=384):
    image = pyvips.Image.thumbnail(image_path, max_size)
    image = vips2numpy(image)
    width, height, c = image.shape
    print(f"Input width: {width} height: {height}")
    images, bg, diff_to_white = tile(image, sz=crop_size, N=N)
    output_images = []
    for idx, img in enumerate(images):
#         img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
#         img = cv2.imencode(".jpg", img, [cv2.IMWRITE_JPEG_QUALITY, 100])[1]
#         img = cv2.imdecode(img, flags=cv2.IMREAD_COLOR)
        img = img + diff_to_white
        img = img / img.max()
        img = np.clip(img * 255, a_min = 0, a_max = 255).astype(np.uint8)
        img = transform(img)
        
        output_images.append(img)
    
    img_indexes = random.sample(list(range(0, N)), N)
    
    batched_images = []
    one_batch = []
    for i, index in enumerate(img_indexes):
        one_batch.append(output_images[index])
        if i%ims_per_batch==(ims_per_batch-1):
            batched_images.append(torch.stack(one_batch, dim=0))
            one_batch = []

    batched_images = torch.stack(batched_images, dim=0)
#     del img, image, images, output_images; gc.collect()
    return batched_images

In [13]:
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToPILImage(),
    torchvision.transforms.Resize((sz, sz)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

preds = []

for path in test_image_paths:
    images = return_tiled_images(image_path=path, transform=transform, N=N, max_size=20000, crop_size=tile_sz)
    images2 = []
    for batch in images:
        images2.append(torchvision.transforms.Resize((sz2, sz2))(batch))
        
    images2 = torch.stack(images2, dim=0)
    images = images.to(device)
    images2 = images2.to(device)

    
    for i, m_path in enumerate(model_paths):
        model.load_state_dict(
        torch.load(
            m_path, map_location=device
            )
        )
        model.train(False)
        with torch.cuda.amp.autocast():
#             pred = model(torch.unsqueeze(images, 0))
            pred = model(images)

        pred = torch.sigmoid(pred).to('cpu').detach().numpy().copy()

        pred_ce = ((1 - pred)**2).mean()
        pred_laa = (pred**2).mean()

        preds.append((path, i, "model1", pred_ce, pred_laa))

    for i, m_path in enumerate(model2_paths):
        model2.load_state_dict(
        torch.load(
            m_path, map_location=device
            )
        )
        model2.train(False)
        
        with torch.cuda.amp.autocast():
#             pred = model(torch.unsqueeze(images, 0))
            pred2 = model2(images2)
    
        pred2 = torch.sigmoid(pred2).to('cpu').detach().numpy().copy()
#         print(pred2)

        pred2_ce = ((1 - pred2)**2).mean()
        pred2_laa = (pred2**2).mean()

        preds.append((path, i, "model2", pred2_ce, pred2_laa))

        
    del images, images2
    gc.collect()

Input width: 20000 height: 11187
Input width: 20000 height: 4937
Input width: 20000 height: 4005
Input width: 9512 height: 20000


In [14]:
# if len(too_big_for_process)>0:
#     for i in too_big_for_process:
#         preds.append((i, 0, 0.82, 0.28))

In [15]:
# images = return_tiled_images(image_path=path, transform=transform, N=N, max_size=20000, crop_size=tile_sz)
# images2 = []
# for batch in images:
#     images2.append(torchvision.transforms.Resize((sz2, sz2))(batch))


In [16]:
# test_data = torch.zeros([4, 16, 3, 256, 256], dtype=torch.float16).to(device)
# with torch.cuda.amp.autocast():
#     t = model2(test_data)

In [17]:
# t

In [18]:
def path_to_patient_id(path):
    return os.path.basename(path).split("_")[0]

df = pd.DataFrame(preds, columns=("path", "fold", "model", "CE", "LAA"))
df["patient_id"] = df["path"].map(path_to_patient_id)

In [19]:
# df

In [20]:
df.groupby("patient_id").mean().drop("fold", axis=1).to_csv("submission.csv")

In [21]:
df.groupby("patient_id").mean().drop("fold", axis=1)

Unnamed: 0_level_0,CE,LAA
patient_id,Unnamed: 1_level_1,Unnamed: 2_level_1
006388,0.416504,0.252441
008e5c,0.485352,0.201294
00c058,0.461914,0.192383
01adc5,0.486816,0.123596
