In [1]:
import os
import gc
import cv2
import math
import copy
import time
import random

# For data manipulation
import numpy as np
import pandas as pd

# Pytorch Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp

# Albumentations for augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2


from sklearn.metrics import f1_score,roc_auc_score


import timm
from timm.models.efficientnet import *

# Utils
import joblib
from tqdm import tqdm
from collections import defaultdict


import warnings
warnings.filterwarnings("ignore")

from sklearn.metrics import f1_score

import matplotlib.pyplot as plt
import glob

In [2]:
# Collect every cropped test slice: one sub-directory per CT scan, one file per slice.
test_ct_all_list = glob.glob("work_test/test_crop/*/*")

In [3]:
# Sanity check: number of files matched by the glob pattern.
len(test_ct_all_list)

874235

In [4]:
# Checkpoint file whose state dict is loaded into the network below.
weights_path="model_weights/job_51_effnetb3a.bin"

In [5]:
# One row per image file; `path` is the only column the dataset reads.
df = pd.DataFrame({"path": test_ct_all_list})

In [6]:
class Covid19Dataset(Dataset):
    """Dataset that yields one image per row of ``df``.

    Args:
        df: DataFrame with a ``path`` column holding image file paths.
        transforms: Optional callable invoked as ``transforms(image=img)`` that
            returns a mapping with an ``'image'`` entry. When ``None`` the RGB
            array is returned exactly as loaded.

    Each item is a dict with:
        ``'image'``: the (optionally transformed) RGB image.
        ``'id'``: the file path the image was read from.
    """

    def __init__(self, df, transforms=None):
        self.df = df
        self.path = df['path'].values
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Load, colour-convert and transform the image at ``index``.

        Raises:
            FileNotFoundError: if the file is missing or cannot be decoded.
        """
        img_path_ = self.path[index]

        # cv2.imread signals failure by returning None instead of raising, so
        # check explicitly and report the offending path rather than letting a
        # later call fail on a None image with an unrelated message.
        img = cv2.imread(img_path_)
        if img is None:
            raise FileNotFoundError(
                f"Image is missing or unreadable: {img_path_}"
            )

        # OpenCV decodes to BGR; convert to RGB channel order.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transforms is not None:
            img = self.transforms(image=img)['image']

        return {
            'image': img,
            'id': img_path_,
        }
            

In [7]:
# df

In [8]:
# Preprocessing applied to every image at inference time: resize to 256x256,
# normalise with the transform's default statistics, then ToTensorV2.
data_transforms = dict(
    valid=A.Compose(
        transforms=[
            A.Resize(height=256, width=256),
            A.Normalize(),
            ToTensorV2(),
        ],
        p=1.0,
    ),
)

In [9]:
def prepare_loaders(batch_size=128, num_workers=16):
    """Build the inference DataLoader over the module-level ``df``.

    The dataset wraps ``df`` with the ``"valid"`` pipeline from
    ``data_transforms``. Batches are produced in frame order
    (``shuffle=False``) so outputs line up with the input paths.

    Args:
        batch_size: Number of images per batch (default 128).
        num_workers: Number of loader worker processes (default 16); lower
            this on machines with fewer CPU cores.

    Returns:
        A ``DataLoader`` yielding dicts with ``'image'`` and ``'id'`` entries.
    """
    test_dataset = Covid19Dataset(df, transforms=data_transforms["valid"])

    return DataLoader(
        test_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=True,
    )

In [10]:
class Net(nn.Module):
    """EfficientNet-B3a backbone followed by a 224-dimensional embedding layer.

    ``forward`` returns the output of ``self.emb`` (the embedding), not a
    classification logit.

    Args:
        pretrained: Forwarded to ``efficientnet_b3a``. Defaults to ``True``
            (the original behaviour); pass ``False`` to skip fetching the
            pretrained backbone weights when a full checkpoint is loaded
            afterwards anyway.

    The attribute names (``b0`` .. ``b8``, ``emb``, ``logit``) are the keys of
    the state dict, so they must not be renamed if existing checkpoints are
    to load.
    """

    def __init__(self, pretrained=True):
        super(Net, self).__init__()

        e = efficientnet_b3a(pretrained=pretrained, drop_rate=0.3, drop_path_rate=0.2)

        # Stem: convolution + norm + activation.
        self.b0 = nn.Sequential(
            e.conv_stem,
            e.bn1,
            e.act1,
        )
        # The seven backbone stages, kept as separate attributes.
        self.b1 = e.blocks[0]
        self.b2 = e.blocks[1]
        self.b3 = e.blocks[2]
        self.b4 = e.blocks[3]
        self.b5 = e.blocks[4]
        self.b6 = e.blocks[5]
        self.b7 = e.blocks[6]
        # Head: convolution + norm + activation.
        self.b8 = nn.Sequential(
            e.conv_head,
            e.bn2,
            e.act2,
        )

        self.emb = nn.Linear(1536, 224)
        # Not used by forward(); kept so the module's state dict still has the
        # 'logit.*' entries.
        self.logit = nn.Linear(224, 1)

    # @torch.cuda.amp.autocast()
    def forward(self, image):
        """Return a ``(batch, 224)`` embedding for a batch of images."""
        batch_size = len(image)

        # NOTE(review): this rescaling is applied on top of whatever
        # normalisation the data pipeline already did — confirm it matches
        # how the checkpoint was trained.
        x = 2 * image - 1

        for stage in (self.b0, self.b1, self.b2, self.b3, self.b4,
                      self.b5, self.b6, self.b7, self.b8):
            x = stage(x)

        # Global average pool to one vector per image.
        x = F.adaptive_avg_pool2d(x, 1).reshape(batch_size, -1)

        return self.emb(x)



In [11]:
model = Net()
#model = nn.DataParallel(model)
# Deserialise onto the CPU first so the checkpoint loads regardless of which
# device it was saved from, then move the whole model in one step.
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
# Same cuda-or-cpu choice as the `device` used for the input batches, so the
# notebook does not crash here on a machine without a GPU.
model = model.to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))

In [12]:
@torch.inference_mode()
def get_embeddings(model, dataloader, device):
    """Run ``model`` over ``dataloader`` and collect its outputs per batch.

    Args:
        model: Module called as ``model(images)``; switched to eval mode here.
        dataloader: Sized iterable of dicts with an ``'image'`` tensor and an
            ``'id'`` entry per batch.
        device: Device the image batch is moved to before the forward pass.

    Returns:
        ``(EMBEDS, IDS)``: two lists with one entry per batch — the model
        output as a NumPy array, and that batch's ``'id'`` value, in loader
        order.
    """
    model.eval()

    EMBEDS = []
    IDS = []

    for data in tqdm(dataloader, total=len(dataloader)):
        images = data['image'].to(device, dtype=torch.float)
        outputs = model(images)

        # Move results off the device right away so they do not accumulate
        # in device memory over the whole run.
        EMBEDS.append(outputs.cpu().numpy())
        IDS.append(data['id'])

    return EMBEDS, IDS

In [13]:
# Build the inference DataLoader over every path collected in `df`.
test_loader = prepare_loaders()

In [14]:
# Prefer the first GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [15]:
# Run the model over the whole test loader. `embed` and `name` are lists with
# one entry per batch: the output array and that batch's file paths.
embed,name=get_embeddings(model,test_loader,device)

100%|██████████| 6830/6830 [13:11<00:00,  8.63it/s]


In [16]:
# Flatten the per-batch lists into one array of embeddings (rows) and one
# array of paths, keeping them aligned row for row.
embed, name = np.vstack(embed), np.concatenate(name)


In [17]:
# Map each file path to its embedding vector (insertion order is preserved).
dict_all = {path: vector for path, vector in zip(name, embed)}

In [18]:
# Two-column table: the file path and its embedding array.
df_224 = pd.DataFrame(
    data=list(dict_all.items()),
    columns=["path", "embed"],
)

In [19]:
# df_224.head()

In [20]:
# Derive the scan folder name and the integer slice number from each path.
# os.path is used instead of splitting on "/" so paths written with the
# platform's native separator are parsed correctly as well.
df_224["ct_path"] = df_224["path"].map(
    lambda p: os.path.basename(os.path.dirname(p))
)
# The slice number is the file name up to its first dot, e.g. "12.png" -> 12.
df_224["ct_slice"] = df_224["path"].map(
    lambda p: int(os.path.basename(p).split(".")[0])
)

In [21]:
# Number of slices belonging to each scan, repeated on every row of that scan.
df_224["ct_len"] = (
    df_224
    .groupby("ct_path")["ct_slice"]
    .transform("count")
)

In [22]:
# Order rows scan by scan, and by slice number within each scan.
df_224.sort_values(["ct_path", "ct_slice"], inplace=True)

In [23]:
# Discard the pre-sort index so rows are numbered 0..n-1 in sorted order.
df_224=df_224.reset_index(drop=True)

In [24]:
# df_224.ct_path.nunique()

In [25]:
# Persist the per-slice embedding table to disk.
# NOTE(review): pickle files should only be read back from trusted sources.
df_224.to_pickle("work_test/test_224_embed.pkl")  