In [1]:
import os
import gc
import cv2
import math
import copy
import time
import random

# For data manipulation
import numpy as np
import pandas as pd

# Pytorch Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp

# Albumentations for augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2


from sklearn.metrics import f1_score,roc_auc_score


import timm
from timm.models.efficientnet import *

# Utils
import joblib
from tqdm import tqdm
from collections import defaultdict


import warnings
warnings.filterwarnings("ignore")

from sklearn.metrics import f1_score

import matplotlib.pyplot as plt
import glob

In [2]:
test_ct_all_list=list(glob.glob("/home/fate/covid19_CT2/input/test_crop/*/*")) 

In [3]:
len(test_ct_all_list)

874235

In [4]:
weights_path="/home/fate/covid19_CT/model/f1/job_51_effnetb3a.bin"

In [6]:
df = pd.DataFrame(test_ct_all_list, columns = ['path'])

In [5]:
class Covid19Dataset(Dataset):
    def __init__(self, df,transforms=None):
        self.df = df
  
        self.path = df['path'].values
 
        self.transforms = transforms
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
 

        img_path_ = self.path[index]
     

        img = cv2.imread(img_path_)
        try:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        except:
            print(img_path_ )
        img = self.transforms(image=img)['image']

            

        return {
            'image': img,

            'id' : img_path_ 
        }
            

In [7]:
df

Unnamed: 0,path
0,/home/fate/covid19_CT2/input/test_crop/ct_scan...
1,/home/fate/covid19_CT2/input/test_crop/ct_scan...
2,/home/fate/covid19_CT2/input/test_crop/ct_scan...
3,/home/fate/covid19_CT2/input/test_crop/ct_scan...
4,/home/fate/covid19_CT2/input/test_crop/ct_scan...
...,...
874230,/home/fate/covid19_CT2/input/test_crop/ct_scan...
874231,/home/fate/covid19_CT2/input/test_crop/ct_scan...
874232,/home/fate/covid19_CT2/input/test_crop/ct_scan...
874233,/home/fate/covid19_CT2/input/test_crop/ct_scan...


In [8]:
data_transforms = {

    
    "valid": A.Compose([
        A.Resize(256, 256),

        A.Normalize(),
        ToTensorV2()], p=1.)
}

In [14]:
def prepare_loaders():


  

    test_dataset = Covid19Dataset(df, transforms=data_transforms["valid"])


    test_loader = DataLoader(test_dataset, batch_size=128, 
                              num_workers=16, shuffle=False, pin_memory=True)

    
    return test_loader

In [15]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        e = efficientnet_b3a(pretrained=True, drop_rate=0.3, drop_path_rate=0.2)
        self.b0 = nn.Sequential(
            e.conv_stem,
            e.bn1,
            e.act1,
        )
        self.b1 = e.blocks[0]
        self.b2 = e.blocks[1]
        self.b3 = e.blocks[2]
        self.b4 = e.blocks[3]
        self.b5 = e.blocks[4]
        self.b6 = e.blocks[5]
        self.b7 = e.blocks[6]
        self.b8 = nn.Sequential(
            e.conv_head, #384, 1536
            e.bn2,
            e.act2,
        )
        #self.logit = nn.Linear(1536,1)
        
        
        self.emb = nn.Linear(1536,224)
        self.logit = nn.Linear(224,1)
        
#         self.mask = nn.Sequential(
#             nn.Conv2d(136, 128, kernel_size=3, padding=1),
#             nn.BatchNorm2d(128),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(128, 128, kernel_size=3, padding=1),
#             nn.BatchNorm2d(128),
#             nn.ReLU(inplace=True),
#             nn.Conv2d(128, 1, kernel_size=1, padding=0),
#         )


    # @torch.cuda.amp.autocast()
    def forward(self, image):
        batch_size = len(image)
        x = 2*image-1     # ; print('input ',   x.shape)

        x = self.b0(x) #; print (x.shape)  # torch.Size([2, 40, 256, 256])train_batch
        x = self.b1(x) #; print (x.shape)  # torch.Size([2, 24, 256, 256])
        x = self.b2(x) #; print (x.shape)  # torch.Size([2, 32, 128, 128])
        x = self.b3(x) #; print (x.shape)  # torch.Size([2, 48, 64, 64])
        x = self.b4(x) #; print (x.shape)  # torch.Size([2, 96, 32, 32])
        x = self.b5(x) #; print (x.shape)  # torch.Size([2, 136, 32, 32])
        #------------
#         mask = self.mask(x)
        #-------------
        x = self.b6(x) #; print (x.shape)  # torch.Size([2, 232, 16, 16])
        x = self.b7(x) #; print (x.shape)  # torch.Size([2, 384, 16, 16])
        x = self.b8(x) #; print (x.shape)  # torch.Size([2, 1536, 16, 16])
        x = F.adaptive_avg_pool2d(x,1).reshape(batch_size,-1)
#         return logitx = F.dropout(x, 0.2, training=self.training)
        x = self.emb(x)
        logit = self.logit(x)
        #return logit, mask
        return x



In [16]:
model=Net()
#model = nn.DataParallel(model)
model.load_state_dict(torch.load(weights_path))
model=model.cuda()

In [17]:
@torch.inference_mode()
def get_embeddings(model, dataloader, device):
    model.eval()
    
    LABELS = []
    EMBEDS = []
    IDS = []
    
    bar = tqdm(enumerate(dataloader), total=len(dataloader))
    
    for step, data in bar:  

        
    
        data_img = data['image']
    
        images = data_img.to(device, dtype=torch.float)



        ids = data['id']

        outputs = model(images)
        

        EMBEDS.append(outputs.cpu().numpy())
        IDS.append(ids)
    

    
    return EMBEDS, IDS

In [18]:
test_loader = prepare_loaders()

In [21]:
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [22]:
embed,name=get_embeddings(model,test_loader,device)

100%|██████████| 6830/6830 [13:00<00:00,  8.75it/s]


In [23]:
embed = np.vstack(embed)
name = np.concatenate(name)


In [24]:
name.shape

(874235,)

In [25]:
dict_all=dict(zip(name, embed))

In [30]:
df_224=pd.DataFrame(list(dict_all.items()),
                   columns=['path', 'embed'])

In [34]:
df_224.head()

Unnamed: 0,path,embed
0,/home/fate/covid19_CT2/input/test_crop/ct_scan...,"[0.017912194, -2.0158145e-06, 1.0221355, 0.618..."
1,/home/fate/covid19_CT2/input/test_crop/ct_scan...,"[0.022885323, -2.4122714e-06, 1.2911112, 0.791..."
2,/home/fate/covid19_CT2/input/test_crop/ct_scan...,"[0.005369455, -7.034596e-07, 0.33232358, 0.212..."
3,/home/fate/covid19_CT2/input/test_crop/ct_scan...,"[0.015921496, -1.4665876e-06, 0.85001236, 0.55..."
4,/home/fate/covid19_CT2/input/test_crop/ct_scan...,"[0.013488156, -1.3310564e-06, 0.72830224, 0.46..."


In [40]:
df_224["ct_path"]=df_224["path"].apply(lambda x: x.split("/")[-2])
df_224["ct_slice"]=df_224["path"].apply(lambda x: int(x.split("/")[-1].split(".")[0]))

In [42]:
df_224["ct_len"]=df_224.groupby(["ct_path"])["ct_slice"].transform('count')

In [43]:
df_224.sort_values(by=['ct_path', 'ct_slice'], inplace=True)

In [45]:
df_224=df_224.reset_index(drop=True)

In [47]:
df_224

Unnamed: 0,path,embed,ct_path,ct_slice,ct_len
0,/home/fate/covid19_CT2/input/test_crop/ct_scan...,"[0.0029444182, 2.0627795e-08, 0.1728667, 0.119...",ct_scan_0,0,50
1,/home/fate/covid19_CT2/input/test_crop/ct_scan...,"[0.002479718, 1.1204702e-07, 0.17129861, 0.109...",ct_scan_0,1,50
2,/home/fate/covid19_CT2/input/test_crop/ct_scan...,"[-3.0830503e-05, 2.5171153e-07, 0.0064469655, ...",ct_scan_0,2,50
3,/home/fate/covid19_CT2/input/test_crop/ct_scan...,"[0.0011787838, 1.9301933e-07, 0.09569399, 0.06...",ct_scan_0,3,50
4,/home/fate/covid19_CT2/input/test_crop/ct_scan...,"[0.0054802196, -4.1533087e-07, 0.27055162, 0.1...",ct_scan_0,4,50
...,...,...,...,...,...
874230,/home/fate/covid19_CT2/input/test_crop/ct_scan...,"[0.0043582036, -9.32421e-07, 0.17744961, 0.096...",ct_scan_999,38,43
874231,/home/fate/covid19_CT2/input/test_crop/ct_scan...,"[0.009558492, -9.1470235e-07, 0.6620663, 0.387...",ct_scan_999,39,43
874232,/home/fate/covid19_CT2/input/test_crop/ct_scan...,"[0.0074868137, -7.4435866e-07, 0.5292205, 0.32...",ct_scan_999,40,43
874233,/home/fate/covid19_CT2/input/test_crop/ct_scan...,"[0.00828557, -1.0583517e-06, 0.5592833, 0.3367...",ct_scan_999,41,43


In [49]:
df_224.ct_path.nunique()

5281

In [50]:
df_224.to_pickle("test_224_embed.pkl")  

In [51]:
df_p=pd.read_pickle("test_224_embed.pkl") 

In [54]:
from pandas.util.testing import assert_frame_equal
assert_frame_equal(df_p,df_224)