In [1]:
# True -> run on the studies in the first 100 rows of train.csv instead of test.csv.
DEBUG: bool = False

In [2]:
import os, gc, ast, cv2, time, timm, pickle, random, pydicom

from timm0412 import timm as timm4smp 

import argparse
import warnings
import threading
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
from glob import glob
import albumentations
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp

import torch
import torch.nn as nn
import torch.optim as optim
import torch.cuda.amp as amp
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset


%matplotlib inline

# Let cuDNN benchmark and cache the fastest conv algorithms (input shapes are fixed).
torch.backends.cudnn.benchmark = True
# All models and input tensors are placed on the default CUDA device.
device = torch.device('cuda')

  warn(f"Failed to load image Python extension: {e}")


In [3]:
# ---- paths / loader ----
data_dir = './'
model_dir_seg = './kaggle'  # every checkpoint (seg and cls) is loaded from here
batch_size_seg = 1
num_workers = 2

# ---- stage 1 (segmentation): [0], [1] = in-plane resize, [2] = number of sampled slices ----
image_size_seg = (128, 128, 128)
msk_size = image_size_seg[0]  # predicted masks are indexed in units of msk_size

# ---- stage 2 (classification) ----
image_size_cls = 224
n_slice_per_c = 15  # slices sampled per vertebra
n_ch = 5            # adjacent DICOM slices stacked as channels (the mask adds one more)

In [4]:
# Build `df`: one row per study with its StudyInstanceUID and DICOM folder.
if DEBUG:
    df_src = pd.read_csv(os.path.join(data_dir, 'train.csv'))[0:100]
    image_subdir = 'train_images'
else:
    df_src = pd.read_csv(os.path.join(data_dir, 'test.csv'))
    # test_images and test.csv are inconsistent in the dev dataset: substitute the
    # studies that actually exist in test_images for the dev run.
    # The len() guard keeps an empty test.csv from raising IndexError on iloc[0].
    if len(df_src) > 0 and df_src.iloc[0].row_id == '1.2.826.0.1.3680043.10197_C1':
        df_src = pd.DataFrame({"row_id": ['1.2.826.0.1.3680043.22327_C1', '1.2.826.0.1.3680043.25399_C1', '1.2.826.0.1.3680043.5876_C1'],
                               "StudyInstanceUID": ['1.2.826.0.1.3680043.22327', '1.2.826.0.1.3680043.25399', '1.2.826.0.1.3680043.5876'],
                               "prediction_type": ["C1", "C1", "C1"]}
                              )
    image_subdir = 'test_images'

df = pd.DataFrame({
    'StudyInstanceUID': df_src['StudyInstanceUID'].unique().tolist()
})
df['image_folder'] = df['StudyInstanceUID'].apply(lambda x: os.path.join(data_dir, image_subdir, x))

df.tail()

Unnamed: 0,StudyInstanceUID,image_folder
0,1.2.826.0.1.3680043.22327,./test_images/1.2.826.0.1.3680043.22327
1,1.2.826.0.1.3680043.25399,./test_images/1.2.826.0.1.3680043.25399
2,1.2.826.0.1.3680043.5876,./test_images/1.2.826.0.1.3680043.5876


# Dataset

In [5]:
def load_dicom(path):
    """Read one DICOM file and resize its pixel array for the segmentation model.

    Args:
        path: path to a single DICOM file.

    Returns:
        The pixel array resized with area interpolation; ``cv2.resize`` receives
        ``(image_size_seg[0], image_size_seg[1])`` as its (width, height) dsize.
    """
    # pydicom.read_file is deprecated (removed in pydicom 3.0); dcmread is the
    # supported reader and has existed since pydicom 1.0.
    dicom = pydicom.dcmread(path)
    data = dicom.pixel_array
    data = cv2.resize(data, (image_size_seg[0], image_size_seg[1]), interpolation=cv2.INTER_AREA)
    return data


def load_dicom_line_par(path):
    """Load one study folder as a fixed-depth uint8 volume.

    Files are ordered by the integer stem of their file name (``<n>.dcm``). Then
    ``image_size_seg[2]`` of them are picked at evenly spaced, rounded positions;
    positions repeat when the study has fewer files. Each picked file is read
    and resized by ``load_dicom``.

    Args:
        path: folder containing the study's DICOM files.

    Returns:
        images: uint8 array stacked along axis 0 (one entry per sampled file),
            min-max scaled to [0, 255] over the whole volume.
        indices: int array of the sampled positions in the sorted file list.

    Raises:
        FileNotFoundError: if ``path`` contains no files.
    """
    # os.path.basename rather than split('/') so the sort key works with any path separator.
    t_paths = sorted(
        glob(os.path.join(path, "*")),
        key=lambda p: int(os.path.basename(p).split(".")[0]),
    )
    n_scans = len(t_paths)
    if n_scans == 0:
        # Without this check np.quantile fails on the empty list with an unhelpful error.
        raise FileNotFoundError(f"No DICOM files found in {path!r}")

    # Evenly spaced positions over [0, n_scans - 1], rounded to integers.
    indices = np.quantile(list(range(n_scans)), np.linspace(0., 1., image_size_seg[2])).round().astype(int)
    images = np.stack([load_dicom(t_paths[i]) for i in indices], 0)

    # Min-max normalise the whole volume to uint8; the 1e-4 avoids division by zero.
    images = images - np.min(images)
    images = images / (np.max(images) + 1e-4)
    images = (images * 255).astype(np.uint8)

    return images, indices


class SegTestDataset(Dataset):
    """Per-study dataset for the 3D segmentation stage.

    Each item is ``(volume, slice_indices)``:
      * ``volume``: float tensor in [0, 1]. When the loaded volume has fewer than
        4 dims, it is replicated into 3 identical leading channels.
      * ``slice_indices``: the sampled slice positions from ``load_dicom_line_par``.
    """

    def __init__(self, df):
        self.df = df.reset_index()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        study = self.df.iloc[index]
        volume, slice_indices = load_dicom_line_par(study.image_folder)

        if volume.ndim < 4:
            # Grayscale volume -> 3 channels: (3, *volume.shape).
            volume = np.repeat(volume[np.newaxis], 3, axis=0)

        return torch.tensor(volume / 255.).float(), slice_indices


# Model

In [6]:
from conv2d_same import Conv2dSame
from conv3d_same import Conv3dSame


def _inflate_conv_params(conv2d, conv3d):
    """Copy a 2D conv's parameters into a freshly built 3D conv.

    The 2D kernel gets a trailing depth axis and is repeated ``kernel_size[0]``
    times along it (no rescaling). The bias, when present, is shared as-is because
    its shape (out_channels,) is identical in 2D and 3D.
    """
    depth = conv2d.kernel_size[0]
    conv3d.weight = torch.nn.Parameter(conv2d.weight.unsqueeze(-1).repeat(1, 1, 1, 1, depth))
    if conv2d.bias is not None:
        conv3d.bias = conv2d.bias


def convert_3d(module):
    """Recursively convert a 2D network into its 3D counterpart.

    Mapping (any other module is kept, with its children converted):
      * BatchNorm2d -> BatchNorm3d, sharing affine params and running stats.
      * Conv2dSame  -> Conv3dSame, with inflated weight and copied bias.
      * Conv2d      -> Conv3d, with inflated weight and copied bias.
      * MaxPool2d   -> MaxPool3d, AvgPool2d -> AvgPool3d, same hyper-parameters.

    Only element [0] of each conv's kernel_size/stride/padding/dilation is used,
    so square 2D convs are assumed. Unmapped container modules are modified in
    place: their children are replaced via ``add_module``.

    Args:
        module: root module to convert.

    Returns:
        A new module for the mapped layer types, otherwise ``module`` itself.
    """
    module_output = module
    if isinstance(module, torch.nn.BatchNorm2d):
        module_output = torch.nn.BatchNorm3d(
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats,
        )
        # weight/bias exist only when affine=True; share them with the new layer.
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig

    elif isinstance(module, Conv2dSame):
        # Checked before plain Conv2d (Conv2dSame presumably subclasses Conv2d).
        module_output = Conv3dSame(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
        )
        _inflate_conv_params(module, module_output)

    elif isinstance(module, torch.nn.Conv2d):
        module_output = torch.nn.Conv3d(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
            padding_mode=module.padding_mode
        )
        _inflate_conv_params(module, module_output)

    elif isinstance(module, torch.nn.MaxPool2d):
        module_output = torch.nn.MaxPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            ceil_mode=module.ceil_mode,
        )
    elif isinstance(module, torch.nn.AvgPool2d):
        module_output = torch.nn.AvgPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            ceil_mode=module.ceil_mode,
        )

    for name, child in module.named_children():
        module_output.add_module(name, convert_3d(child))

    return module_output



class TimmSegModel(nn.Module):
    """2D timm encoder + smp U-Net decoder with a 7-channel segmentation head.

    Built in 2D and turned 3D afterwards with ``convert_3d``. Reads the
    module-level ``n_blocks`` to choose how many encoder stages feed the decoder.

    Raises:
        ValueError: if ``segtype`` is not 'unet' (the only implemented decoder).
    """

    def __init__(self, backbone, segtype='unet', pretrained=False):
        super(TimmSegModel, self).__init__()
        if segtype != 'unet':
            # Fail here instead of with an AttributeError on self.decoder in forward().
            raise ValueError(f"Unsupported segtype: {segtype!r}")

        self.encoder = timm4smp.create_model(
            backbone,
            in_chans=3,
            features_only=True,
            pretrained=pretrained
        )
        # Probe the encoder once to read the channel count of every stage.
        with torch.no_grad():
            g = self.encoder(torch.rand(1, 3, 64, 64))
        encoder_channels = [1] + [_.shape[1] for _ in g]
        decoder_channels = [160, 64, 48, 24, 16]
        self.decoder = smp.unet.decoder.UnetDecoder(
            encoder_channels=encoder_channels[:n_blocks+1],
            decoder_channels=decoder_channels[:n_blocks],
            n_blocks=n_blocks,
            attention_type='scse',
        )
        self.segmentation_head = nn.Conv2d(decoder_channels[n_blocks-1], 7, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    def forward(self, x):
        enc_features = self.encoder(x)[:n_blocks]
        # Zero-pad every encoder feature by one voxel at the far end of D, H and W
        # (e.g. 63->64, 31->32, 15->16, 7->8 for a 128^3 input). This presumably
        # makes the decoder's 2x upsampling line up between stages. F.pad
        # handles any channel count and size, unlike hard-coded zero tensors.
        enc_features = [F.pad(feat, (0, 1, 0, 1, 0, 1)) for feat in enc_features]
        # The decoder drops its first input, so a placeholder takes that slot.
        global_features = [0] + enc_features
        seg_features = self.decoder(*global_features)
        seg_features = self.segmentation_head(seg_features)
        return seg_features
  
    
    
class TimmModel(nn.Module):
    """Per-slice CNN encoder followed by a 2-layer bidirectional LSTM over slices.

    Reads module-level config: ``in_chans``, ``out_dim``, ``drop_rate``,
    ``drop_path_rate``, ``drop_rate_last``, ``n_slice_per_c`` and ``image_size``.

    Input:  (bs, n_slice_per_c, in_chans, image_size, image_size)
    Output: (bs, n_slice_per_c), one logit per slice. The final view needs out_dim == 1.

    Raises:
        ValueError: if ``backbone`` is neither an EfficientNet nor a ConvNeXt.
    """

    def __init__(self, backbone, pretrained=False):
        super(TimmModel, self).__init__()

        self.encoder = timm.create_model(
            backbone,
            in_chans=in_chans,
            num_classes=out_dim,
            features_only=False,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            pretrained=pretrained
        )

        # Replace the classifier with Identity so the encoder emits pooled features
        # of size hdim (e.g. 1280 for tf_efficientnetv2_s).
        if 'efficient' in backbone:
            hdim = self.encoder.conv_head.out_channels
            self.encoder.classifier = nn.Identity()
        elif 'convnext' in backbone:
            hdim = self.encoder.head.fc.in_features
            self.encoder.head.fc = nn.Identity()
        else:
            # Otherwise hdim would be unbound and fail with an UnboundLocalError below.
            raise ValueError(f"Unsupported backbone for TimmModel: {backbone!r}")

        self.lstm = nn.LSTM(hdim, 256, num_layers=2, dropout=drop_rate, bidirectional=True, batch_first=True)
        self.head = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.Dropout(drop_rate_last),
            nn.LeakyReLU(0.1),
            nn.Linear(256, out_dim),
        )

    def forward(self, x):  # (bs, nslice, ch, sz, sz)
        bs = x.shape[0]
        # Fold the slices into the batch so the 2D encoder sees one slice per sample.
        x = x.view(bs * n_slice_per_c, in_chans, image_size, image_size)
        feat = self.encoder(x)                   # (bs * n_slice_per_c, hdim)
        feat = feat.view(bs, n_slice_per_c, -1)  # (bs, n_slice_per_c, hdim)
        feat, _ = self.lstm(feat)                # (bs, n_slice_per_c, 512): 2 x 256 directions
        # contiguous() copies only when needed; view() requires contiguous memory.
        feat = feat.contiguous().view(bs * n_slice_per_c, -1)
        feat = self.head(feat)                   # (bs * n_slice_per_c, out_dim)
        feat = feat.view(bs, n_slice_per_c).contiguous()
        return feat

    
    
class TimmModelType2(nn.Module):
    """Per-slice CNN encoder + two LSTM heads over all 7 * n_slice_per_c slices of a study.

    Reads module-level config: ``in_chans``, ``out_dim``, ``drop_rate``,
    ``drop_path_rate``, ``drop_rate_last``, ``n_slice_per_c`` and ``image_size``.

    Input:  (bs, n_slice_per_c * 7, in_chans, image_size, image_size)
    Output: (per-slice logits of shape (bs * n_slice_per_c * 7, out_dim),
             study-level logit of shape (bs, 1))

    Raises:
        ValueError: if ``backbone`` is neither an EfficientNet nor a ConvNeXt.
    """

    def __init__(self, backbone, pretrained=False):
        super(TimmModelType2, self).__init__()

        self.encoder = timm.create_model(
            backbone,
            in_chans=in_chans,
            num_classes=out_dim,
            features_only=False,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            pretrained=pretrained
        )

        # Replace the classifier with Identity so the encoder emits pooled features.
        if 'efficient' in backbone:
            hdim = self.encoder.conv_head.out_channels
            self.encoder.classifier = nn.Identity()
        elif 'convnext' in backbone:
            hdim = self.encoder.head.fc.in_features
            self.encoder.head.fc = nn.Identity()
        else:
            # Otherwise hdim would be unbound and fail with an UnboundLocalError below.
            raise ValueError(f"Unsupported backbone for TimmModelType2: {backbone!r}")

        # NOTE(review): on a 2D (N, 256) input InstanceNorm1d treats the tensor as
        # unbatched (C=N, L=256), i.e. it normalises each row across its 256
        # features. Kept as-is to match the trained weights.
        self.lstm = nn.LSTM(hdim, 256, num_layers=2, dropout=drop_rate, bidirectional=True, batch_first=True)
        self.head = nn.Sequential(
            nn.Linear(512, 256),
            nn.InstanceNorm1d(256),  # replaced BatchNorm1d for training with batch_size = 1
            nn.Dropout(drop_rate_last),
            nn.LeakyReLU(0.1),
            nn.Linear(256, out_dim),
        )
        self.lstm2 = nn.LSTM(hdim, 256, num_layers=2, dropout=drop_rate, bidirectional=True, batch_first=True)
        self.head2 = nn.Sequential(
            nn.Linear(512, 256),
            nn.InstanceNorm1d(256),  # replaced BatchNorm1d for training with batch_size = 1
            nn.Dropout(drop_rate_last),
            nn.LeakyReLU(0.1),
            nn.Linear(256, 1),
        )

    def forward(self, x):  # (bs, nc*7, ch, sz, sz)
        bs = x.shape[0]
        # Fold every slice of every vertebra into the batch for the 2D encoder.
        x = x.view(bs * n_slice_per_c * 7, in_chans, image_size, image_size)

        feat = self.encoder(x)
        # One sequence of length 7 * n_slice_per_c per study.
        feat = feat.view(bs, n_slice_per_c * 7, -1)
        feat1, _ = self.lstm(feat)
        feat1 = feat1.contiguous().view(bs * n_slice_per_c * 7, 512)
        feat2, _ = self.lstm2(feat)

        # Study-level logit from the first timestep of the bidirectional lstm2 output.
        return self.head(feat1), self.head2(feat2[:, 0])

# Load Models

In [7]:
kernel_type = 'timm3d_effv2_unet4b_128_128_128_dsv2_flip12_shift333p7_gd1p5_bs4_lr3e4_20x50ep'
backbone = 'tf_efficientnetv2_s_in21ft1k'
models_seg = []

n_blocks = 4  # encoder stages fed to the U-Net decoder (read by TimmSegModel)
for fold in range(5):
    model = TimmSegModel(backbone, pretrained=False)
    model = convert_3d(model)
    model = model.to(device)
    load_model_file = os.path.join(model_dir_seg, f'{kernel_type}_fold{fold}_best.pth')
    # map_location='cpu' keeps loading independent of the device the checkpoint was
    # saved from; load_state_dict copies the tensors onto the model's device.
    sd = torch.load(load_model_file, map_location='cpu')
    if 'model_state_dict' in sd.keys():
        sd = sd['model_state_dict']
    # Strip a leading 'module.' (e.g. left by a DataParallel-wrapped model).
    sd = {k[len('module.'):] if k.startswith('module.') else k: v for k, v in sd.items()}
    model.load_state_dict(sd, strict=True)
    model.eval()
    models_seg.append(model)

len(models_seg)

5

In [8]:
kernel_type = '0920_1bonev2_effv2s_224_15_6ch_augv2_mixupp5_drl3_rov1p2_bs8_lr23e5_eta23e6_50ep'
backbone = 'tf_efficientnetv2_s_in21ft1k'

# Module-level config read by TimmModel / TimmModelType2.
image_size = 224
in_chans = 6
models_cls1 = []
out_dim = 1
drop_rate = 0.
drop_rate_last = 0.
drop_path_rate = 0.
for fold in range(5):
    model = TimmModel(backbone, pretrained=False)
    model = model.to(device)
    load_model_file = os.path.join(model_dir_seg, f'{kernel_type}_fold{fold}_best.pth')
    # map_location='cpu' keeps loading independent of the device the checkpoint was
    # saved from; load_state_dict copies the tensors onto the model's device.
    sd = torch.load(load_model_file, map_location='cpu')
    if 'model_state_dict' in sd.keys():
        sd = sd['model_state_dict']
    # Strip a leading 'module.' (e.g. left by a DataParallel-wrapped model).
    sd = {k[len('module.'):] if k.startswith('module.') else k: v for k, v in sd.items()}
    model.load_state_dict(sd, strict=True)
    model.eval()
    models_cls1.append(model)

len(models_cls1)

5

In [9]:
kernel_type = '0920_2d_lstmv22headv2_convnn_224_15_6ch_8flip_augv2_drl3_rov1p2_rov3p2_bs4_lr6e5_eta6e6_lw151_50ep'
backbone = 'convnext_nano'
in_chans = 6
models_cls2 = []

for fold in range(5):
    model = TimmModelType2(backbone, pretrained=False)
    model = model.to(device)
    load_model_file = os.path.join(model_dir_seg, f'{kernel_type}_fold{fold}_best.pth')
    # map_location='cpu' keeps loading independent of the device the checkpoint was
    # saved from; load_state_dict copies the tensors onto the model's device.
    sd = torch.load(load_model_file, map_location='cpu')
    if 'model_state_dict' in sd.keys():
        sd = sd['model_state_dict']
    # Strip a leading 'module.' (e.g. left by a DataParallel-wrapped model).
    sd = {k[len('module.'):] if k.startswith('module.') else k: v for k, v in sd.items()}
    model.load_state_dict(sd, strict=True)
    model.eval()
    models_cls2.append(model)

len(models_cls2)

5

In [10]:
def _read_slice_window(t_paths, center, n_scans):
    """Stack the n_ch DICOM slices centred on ``center`` into an (H, W, n_ch) array.

    Positions outside [0, n_scans) and slices that fail to decode are zero-filled.
    Bounds are checked explicitly so negative positions do not wrap around to the
    far end of the scan through Python's negative indexing.
    """
    slices = []
    for offset in range(-n_ch // 2 + 1, n_ch // 2 + 1):  # n_ch = 5 -> offsets -2..2
        pos = center + offset
        pixels = None
        if 0 <= pos < n_scans:
            try:
                pixels = pydicom.dcmread(t_paths[pos]).pixel_array
            except Exception:
                pixels = None  # unreadable slice -> zero-filled below
        slices.append(pixels)

    # Zero-fill with the shape of a real slice so np.stack never sees mixed shapes.
    ref = next((s for s in slices if s is not None), None)
    fill_shape = ref.shape if ref is not None else (512, 512)
    return np.stack([s if s is not None else np.zeros(fill_shape) for s in slices], -1)


def _bone_bbox(msk_cid):
    """Bounding box (x1, x2, y1, y2, z1, z2) of one vertebra mask, in mask voxels.

    z runs along axis 0, x along axis 1 and y along axis 2 of ``msk_cid``. The box
    is widened by one voxel at the low end. Threshold 0.1 is tried first, then
    0.05. If both are empty, the indexing below raises IndexError.
    """
    for thr in (0.1, 0.05):
        occupied = msk_cid > thr
        z = np.where(occupied.sum(1).sum(1) > 0)[0]
        x = np.where(occupied.sum(0).sum(1) > 0)[0]
        y = np.where(occupied.sum(0).sum(0) > 0)[0]
        if len(x) and len(y) and len(z):
            break

    x1, x2 = max(0, x[0] - 1), min(occupied.shape[1], x[-1] + 1)
    y1, y2 = max(0, y[0] - 1), min(occupied.shape[2], y[-1] + 1)
    z1, z2 = max(0, z[0] - 1), min(occupied.shape[0], z[-1] + 1)
    return x1, x2, y1, y2, z1, z2


def _crop_bone(msk, cid, t_paths, n_scans):
    """Return n_slice_per_c uint8 tensors of shape (image_size_cls, image_size_cls, n_ch + 1)."""
    x1, x2, y1, y2, z1, z2 = _bone_bbox(msk[cid])

    # Map the mask z-range (msk_size slices) back to positions in the DICOM series.
    zz1, zz2 = int(z1 / msk_size * n_scans), int(z2 / msk_size * n_scans)
    inds_ = np.linspace(z1, z2 - 1, n_slice_per_c).astype(int)   # mask slices
    inds = np.linspace(zz1, zz2 - 1, n_slice_per_c).astype(int)  # DICOM slices

    bone = []
    for ind_, ind in zip(inds_, inds):
        data = _read_slice_window(t_paths, ind, n_scans)

        # Min-max scale the window to uint8 before any resizing.
        data = data - np.min(data)
        data = data / (np.max(data) + 1e-4)
        data = (data * 255).astype(np.uint8)

        # Scale the mask bbox to the DICOM resolution, then crop and resize each channel.
        xx1 = int(x1 / msk_size * data.shape[0])
        xx2 = int(x2 / msk_size * data.shape[0])
        yy1 = int(y1 / msk_size * data.shape[1])
        yy2 = int(y2 / msk_size * data.shape[1])
        data = data[xx1:xx2, yy1:yy2]
        data = np.stack([cv2.resize(data[:, :, i], (image_size_cls, image_size_cls), interpolation=cv2.INTER_LINEAR) for i in range(n_ch)], -1)

        # Mask slice offset by n_ch//2 (falls back to ind_ near the end of the volume);
        # kept as-is, presumably to match the training-time preprocessing.
        if (ind_ + n_ch // 2) >= msk_size:
            msk_this = msk[cid, ind_, :, :]
        else:
            msk_this = msk[cid, ind_ + n_ch // 2, :, :]
        msk_this = msk_this[x1:x2, y1:y2]
        msk_this = (msk_this * 255).astype(np.uint8)
        msk_this = cv2.resize(msk_this, (image_size_cls, image_size_cls), interpolation=cv2.INTER_LINEAR)

        # Append the mask as the last channel -> (image_size_cls, image_size_cls, n_ch + 1).
        data = np.concatenate([data, msk_this[:, :, np.newaxis]], -1)
        bone.append(torch.tensor(data))
    return bone


def load_bone(msk, cid, t_paths, cropped_images, index, inst_id):
    """Build the classifier input for vertebra ``cid`` and store it in ``cropped_images[cid]``.

    Used as a thread target. Concurrency is bounded by the module-level semaphore ``sema``.

    Args:
        msk: predicted masks for one study, indexed ``msk[cid, z, x, y]``.
        cid: vertebra channel index into ``msk``.
        t_paths: DICOM paths sorted by the integer in their file name.
        cropped_images: list written in place at position ``cid``.
        index: unused (kept for interface compatibility).
        inst_id: study id, used only in the failure warning.

    The stored value is a uint8 tensor of shape
    (n_slice_per_c, image_size_cls, image_size_cls, n_ch + 1). It is all zeros
    when cropping fails (e.g. an empty mask).
    """
    # `with` guarantees the semaphore is released even if something below raises.
    with sema:
        n_scans = len(t_paths)
        try:
            bone = _crop_bone(msk, cid, t_paths, n_scans)
        except Exception as exc:
            # Best effort: replace (not extend) any partial result with exactly
            # n_slice_per_c zero slices so the downstream view() shapes stay valid.
            warnings.warn(f"load_bone: vertebra {cid} of {inst_id} failed ({exc!r}); using zeros")
            bone = [torch.zeros((image_size_cls, image_size_cls, n_ch + 1), dtype=torch.uint8)
                    for _ in range(n_slice_per_c)]
        cropped_images[cid] = torch.stack(bone, 0)

def load_cropped_images(msk, image_folder, index, inst_id, n_ch=n_ch):
    """Crop all 7 vertebrae of one study, one ``load_bone`` thread per vertebra.

    Args:
        msk: predicted masks for the study (first axis = vertebra channel).
        image_folder: folder with the study's DICOM files.
        index, inst_id: forwarded to ``load_bone``.
        n_ch: accepted for interface compatibility; ``load_bone`` reads the module-level n_ch.

    Returns:
        Crops of all 7 vertebrae concatenated along dim 0, vertebra-major
        (e.g. torch.Size([105, 224, 224, 6]) with the config above).

    Raises:
        RuntimeError: if a worker thread died without storing its crop.
    """
    # os.path.basename rather than split('/') so the sort key works with any path separator.
    t_paths = sorted(glob(os.path.join(image_folder, "*")),
                     key=lambda p: int(os.path.basename(p).split(".")[0]))

    # Local buffers so the result never depends on module-level state that the
    # caller may rebind between calls.
    crops = [None] * 7
    workers = [
        threading.Thread(target=load_bone, args=(msk, cid, t_paths, crops, index, inst_id))
        for cid in range(7)
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    missing = [cid for cid, crop in enumerate(crops) if crop is None]
    if missing:
        raise RuntimeError(f"Cropping failed for vertebra channels {missing} of {inst_id}")
    return torch.cat(crops, 0)


# Predict

In [11]:
# Caps how many load_bone threads can hold the semaphore (i.e. do work) at the same time.
sema = threading.Semaphore(value=12)

# One item per study; batch_size_seg studies per segmentation forward pass.
dataset_seg = SegTestDataset(df)
loader_seg = torch.utils.data.DataLoader(
    dataset_seg,
    batch_size=batch_size_seg,
    shuffle=False,
    num_workers=num_workers,
)

In [12]:
outputs1 = []  # per batch: (bs, 7, n_slice_per_c) slice-level probabilities
outputs2 = []  # per batch: (bs, 1) patient-level probabilities

bar = tqdm(loader_seg)
with torch.no_grad():
    for batch_id, (images, indices) in enumerate(bar):
        indices = indices.numpy()
        images = images.to(device)

        # --- Stage 1: 3D segmentation, sigmoid probabilities averaged over folds ---
        pred_masks = []
        for model in models_seg:
            pred_masks.append(model(images).sigmoid())
        pred_masks = torch.stack(pred_masks, 0).mean(0).cpu().numpy()

        # --- Build the classifier input: (bs, 7 * n_slice_per_c, n_ch + 1, H, W) ---
        cls_inp = []
        for i in range(pred_masks.shape[0]):
            row = df.iloc[batch_id * batch_size_seg + i]
            # Fresh per-study buffers, in case load_cropped_images uses these
            # module-level names.
            threads = [None] * 7
            cropped_images = [None] * 7
            crops = load_cropped_images(pred_masks[i], row.image_folder, indices, row.StudyInstanceUID)
            cls_inp.append((crops.permute(0, 3, 1, 2) / 255.).float())
        cls_inp = torch.stack(cls_inp, 0).to(device)

        pred_cls1, pred_cls2 = [], []
        # CLS 2: one sequence per study -> slice logits + patient logit.
        for model in models_cls2:
            logits, logits2 = model(cls_inp)
            pred_cls1.append(logits.sigmoid().view(-1, 7, n_slice_per_c))
            pred_cls2.append(logits2.sigmoid())

        # CLS 1: one sequence per vertebra. -1 folds batch and vertebra together,
        # so this works for any batch_size_seg instead of assuming a single study.
        cls_inp = cls_inp.view(-1, n_slice_per_c, n_ch + 1, image_size_cls, image_size_cls).contiguous()
        for model in models_cls1:
            logits = model(cls_inp)
            pred_cls1.append(logits.sigmoid().view(-1, 7, n_slice_per_c))

        # Average all cls1 + cls2 slice predictions, and the cls2 patient predictions.
        pred_cls1 = torch.stack(pred_cls1, 0).mean(0)
        pred_cls2 = torch.stack(pred_cls2, 0).mean(0)
        outputs1.append(pred_cls1.cpu())
        outputs2.append(pred_cls2.cpu())


100%|█████████████████████████████████████████████| 3/3 [00:27<00:00,  9.11s/it]


# Output

In [13]:
# Merge the per-batch results:
#   outputs1 -> (n_studies, 7, n_slice_per_c) slice-level fracture probabilities
#   outputs2 -> (n_studies, 1) patient-level fracture probabilities
outputs1, outputs2 = torch.cat(outputs1), torch.cat(outputs2)

In [14]:
# Vertebra probability = mean over its slices -> (n_studies, 7);
# patient probability -> (n_studies,).
# Both are clamped to [1e-4, 0.9999], presumably to keep a log-loss finite.
PRED1 = outputs1.mean(dim=-1).clamp(min=0.0001, max=0.9999)
PRED2 = outputs2.view(-1).clamp(min=0.0001, max=0.9999)

In [15]:
# Per study: C1..C7 followed by patient_overall (8 ids per study, in df order).
row_ids = [
    f'{uid}_{suffix}'
    for uid in df['StudyInstanceUID']
    for suffix in [f'C{i + 1}' for i in range(7)] + ['patient_overall']
]

In [16]:
# (n_studies, 8) = [C1..C7, patient_overall] per study, flattened row-major so
# the order lines up with row_ids.
fractured = torch.cat([PRED1, PRED2.unsqueeze(1)], 1).view(-1).numpy()
df_sub = pd.DataFrame({
    'row_id': row_ids,
    'fractured': fractured,
})

In [17]:
# Write the predictions to submission.csv in the working directory, without the pandas index column.
df_sub.to_csv('submission.csv', index=False)

In [18]:
# Preview the submission frame.
df_sub

Unnamed: 0,row_id,fractured
0,1.2.826.0.1.3680043.22327_C1,0.014081
1,1.2.826.0.1.3680043.22327_C2,0.048528
2,1.2.826.0.1.3680043.22327_C3,0.009748
3,1.2.826.0.1.3680043.22327_C4,0.016578
4,1.2.826.0.1.3680043.22327_C5,0.03251
5,1.2.826.0.1.3680043.22327_C6,0.100792
6,1.2.826.0.1.3680043.22327_C7,0.493738
7,1.2.826.0.1.3680043.22327_patient_overall,0.565615
8,1.2.826.0.1.3680043.25399_C1,0.012958
9,1.2.826.0.1.3680043.25399_C2,0.013387
