In [1]:
import os
import gc
import cv2
import math
import copy
import time
import random

# For data manipulation
import numpy as np
import pandas as pd

# Pytorch Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp

# Albumentations for augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2


from sklearn.metrics import f1_score,roc_auc_score


import timm
from timm.models.efficientnet import *

# Utils
import joblib
from tqdm import tqdm
from collections import defaultdict


import warnings
warnings.filterwarnings("ignore")

from sklearn.metrics import f1_score

import matplotlib.pyplot as plt
import glob

import pickle

In [2]:
CONFIG = dict(
    seed=2022,  # RNG seed value
    valid_batch_size=32,  # scan folders (df rows) per DataLoader batch
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    train_batch=16,  # slices sampled from each scan per dataset item
)

In [3]:
def set_seed(seed=42, deterministic=False):
    '''Seed every RNG this notebook draws from, for REPRODUCIBILITY.

    Seeds Python's ``random`` module (the dataset samples slices with it), NumPy,
    PyTorch on the CPU and the current CUDA device, and exports PYTHONHASHSEED.

    Args:
        seed: integer seed applied to all generators.
        deterministic: if True, force deterministic cuDNN kernels and disable the
            cuDNN autotuner (repeatable GPU results, possibly slower). The default
            False keeps the original speed-oriented cuDNN settings.
    '''
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # cuDNN autotuning (benchmark=True) may select different kernels between runs;
    # deterministic mode trades that speed for run-to-run repeatability.
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic
    # PYTHONHASHSEED is read at interpreter start-up, so this only affects
    # Python processes launched afterwards, not the running kernel.
    os.environ['PYTHONHASHSEED'] = str(seed)
    
# Seed with the value declared in CONFIG so the configured seed is the one in effect.
set_seed(CONFIG["seed"])

In [4]:
# Every match is later treated as a scan directory (the dataset calls os.listdir on it).
# glob's order depends on the filesystem, so sort it to make the row order of df
# -- and hence the saved prediction CSVs -- reproducible across machines.
test_ct_list = sorted(glob.glob(os.path.join("work_test/test_crop", "*")))

In [5]:
# One row per scan directory; the dataset reads the 'path' column.
df = pd.DataFrame(data=test_ct_list, columns=["path"])

In [6]:
# Maps each scan path (as stored in df['path']) to a (start_idx, end_idx) pair that
# indexes into that scan's numerically sorted slice list; the dataset samples
# slices from this range. The file is produced outside this notebook.
# NOTE: pickle.load can execute arbitrary code -- only load this file from a trusted source.
TEST_DIC_PATH = 'work_test/test_dic1_05.pickle'
with open(TEST_DIC_PATH, 'rb') as pickle_file:
    test_dic = pickle.load(pickle_file)

In [7]:
# Every scan in df needs a slice range in test_dic; otherwise the dataset raises a
# KeyError inside a DataLoader worker partway through inference. Fail early instead.
missing_scans = sorted(set(df["path"]) - set(test_dic))
assert not missing_scans, (
    f"{len(missing_scans)} scan folder(s) have no entry in test_dic, e.g. {missing_scans[:3]}"
)
len(test_dic)

5281

In [8]:
class Covid19Dataset_valid(Dataset):
    """Evaluation dataset returning a fixed-size random sample of slices per scan.

    Each row of ``df`` names one scan directory (column ``'path'``) whose files are
    slice images named ``<int>.<ext>``. For every item, ``train_batch`` positions are
    drawn (see ``_sample_indices``) from the scan's numerically sorted slice list;
    each chosen slice is read with OpenCV, converted BGR -> RGB, passed through
    ``transforms`` and written into a ``(train_batch, 3, 256, 256)`` float tensor.

    Args:
        df: DataFrame with a ``'path'`` column of scan directories.
        train_batch: number of slices sampled per scan.
        transforms: callable used as ``transforms(image=img)['image']``; it must
            return a 3x256x256 tensor. Required by ``__getitem__`` despite the
            ``None`` default.
        slice_ranges: optional mapping ``path -> (start_idx, end_idx)``. When None
            (default), the notebook-level ``test_dic`` is looked up at access time.
    """

    def __init__(self, df, train_batch=10, transforms=None, slice_ranges=None):
        self.df = df
        self.path = df['path'].values
        self.transforms = transforms
        self.img_batch = train_batch
        self.slice_ranges = slice_ranges

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _sample_indices(ct_len, start_idx, end_idx, k):
        """Pick ``k`` positions into a scan's sorted slice list of length ``ct_len``.

        - Range holds at least ``k`` slices: ``k`` distinct positions from it.
        - Else, scan has more than 20 slices and the range is non-empty: ``k``
          positions from the range, drawn with replacement.
        - Otherwise: ``k`` positions from the whole scan, drawn with replacement.
          (Previously a long scan with an empty range raised IndexError here.)
        """
        if (end_idx - start_idx) >= k:
            return random.sample(range(start_idx, end_idx), k)
        if ct_len > 20 and end_idx > start_idx:
            return [random.choice(range(start_idx, end_idx)) for _ in range(k)]
        return [random.choice(range(ct_len)) for _ in range(k)]

    def __getitem__(self, index):
        img_path = self.path[index]
        # NOTE(review): names starting with "._" (macOS AppleDouble files) are mapped
        # to the plain name, so if both "N.png" and "._N.png" exist the real file is
        # listed twice. Kept as-is because the (start_idx, end_idx) ranges were
        # presumably computed over this same list -- confirm before changing.
        file_names = [f[2:] if f.startswith("._") else f for f in os.listdir(img_path)]
        # Slice files are named "<int>.<ext>"; order positions by that integer.
        slice_numbers = [int(name.split('.')[0]) for name in file_names]
        index_sort = sorted(range(len(slice_numbers)), key=lambda k: slice_numbers[k])
        ct_len = len(slice_numbers)

        ranges = test_dic if self.slice_ranges is None else self.slice_ranges
        start_idx, end_idx = ranges[img_path]

        sample_idx = self._sample_indices(ct_len, start_idx, end_idx, self.img_batch)

        # Must match the per-slice output shape of `transforms`.
        img_sample = torch.zeros((self.img_batch, 3, 256, 256))
        for count, idx in enumerate(sample_idx):
            slice_path = os.path.join(img_path, file_names[index_sort[idx]])
            img = cv2.imread(slice_path)
            if img is None:
                # cv2.imread signals a missing or unreadable file by returning None.
                raise FileNotFoundError(f"Could not read slice image: {slice_path}")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img_sample[count] = self.transforms(image=img)['image']

        return {
            'image': img_sample,
            'id': img_path
        }
        
        
        
        
        


In [9]:
def prepare_loaders(num_workers=8):
    """Build the DataLoader over every test scan listed in ``df``.

    Reads the notebook globals ``df``, ``CONFIG`` and ``data_transforms``. Each batch
    holds up to ``CONFIG["valid_batch_size"]`` scans, each with ``CONFIG['train_batch']``
    sampled slices, in ``df`` row order (``shuffle=False``).

    Args:
        num_workers: number of DataLoader worker processes (default 8, as before).

    Returns:
        The test DataLoader.
    """
    valid_dataset = Covid19Dataset_valid(df, CONFIG['train_batch'], transforms=data_transforms["valid"])
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=CONFIG["valid_batch_size"],
        num_workers=num_workers,
        shuffle=False,
        # Page-locked host memory only speeds up host -> GPU copies.
        pin_memory=CONFIG["device"].type == "cuda",
    )
    return valid_loader

In [10]:
# Eval-time preprocessing applied to every slice: resize to 256x256, normalise with
# albumentations' default Normalize() parameters, then convert the HWC image into
# the CHW tensor the dataset expects.
_valid_pipeline = A.Compose(
    [
        A.Resize(256, 256),
        A.Normalize(),
        ToTensorV2(),
    ],
    p=1.0,
)
data_transforms = {"valid": _valid_pipeline}

In [11]:
class Net(nn.Module):
    """EfficientNet-B3a backbone with a 224-d embedding and a single-logit head.

    The timm backbone is split into named stages: ``b0`` (stem), ``b1``-``b7`` (its
    seven block groups) and ``b8`` (conv head). ``load_state_dict`` matches weights
    by these attribute names, so they must not be renamed.

    Args:
        pretrained: passed to timm to initialise the backbone with pretrained
            weights. Defaults to True (original behaviour); pass False when a full
            checkpoint is loaded right afterwards, to skip fetching those weights.
    """

    def __init__(self, pretrained=True):
        super(Net, self).__init__()

        e = efficientnet_b3a(pretrained=pretrained, drop_rate=0.3, drop_path_rate=0.2)
        self.b0 = nn.Sequential(
            e.conv_stem,
            e.bn1,
            e.act1,
        )
        self.b1 = e.blocks[0]
        self.b2 = e.blocks[1]
        self.b3 = e.blocks[2]
        self.b4 = e.blocks[3]
        self.b5 = e.blocks[4]
        self.b6 = e.blocks[5]
        self.b7 = e.blocks[6]
        self.b8 = nn.Sequential(
            e.conv_head,
            e.bn2,
            e.act2,
        )

        # 1536 must equal the channel count produced by the conv head (b8).
        self.emb = nn.Linear(1536, 224)
        self.logit = nn.Linear(224, 1)

    def forward(self, image):
        """Map a batch of images to logits of shape ``(batch, 1)``."""
        batch_size = len(image)
        # NOTE(review): the data pipeline already applies A.Normalize(); this extra
        # affine rescale presumably mirrors training and is kept so the loaded
        # weights see the same input distribution -- confirm against training code.
        x = 2 * image - 1
        for stage in (self.b0, self.b1, self.b2, self.b3, self.b4,
                      self.b5, self.b6, self.b7, self.b8):
            x = stage(x)
        # Global average pool to one feature vector per image.
        x = F.adaptive_avg_pool2d(x, 1).reshape(batch_size, -1)
        x = self.emb(x)
        return self.logit(x)



In [12]:
@torch.inference_mode()
def inference(model, dataloader, device):
    """Score every scan as the mean sigmoid probability over its sampled slices.

    Args:
        model: network mapping ``(N, C, H, W)`` images to ``(N, 1)`` logits.
        dataloader: yields dicts whose ``'image'`` has shape
            ``(scans, slices, C, H, W)`` and whose ``'id'`` holds one id per scan.
        device: device the images are moved to before the forward pass.

    Returns:
        ``(pred_y, IDS)``: ``pred_y`` is an array of shape ``(n_scans,)`` with the
        per-scan mean probability, ``IDS`` the matching scan ids, both in
        dataloader order. An empty dataloader yields two empty arrays.
    """
    model.eval()

    per_scan_probs = []  # one (scans_in_batch, slices) array per batch
    IDS = []
    for data in tqdm(dataloader, total=len(dataloader)):
        ct_b, img_b, c, h, w = data['image'].size()
        # Fold the slice axis into the batch axis; rows stay grouped scan by scan.
        images = data['image'].reshape(-1, c, h, w).to(device, dtype=torch.float, non_blocking=True)
        outputs = model(images)
        probs = torch.sigmoid(outputs).cpu().numpy()
        # Reshape with this batch's own slice count rather than the last batch's.
        per_scan_probs.append(probs.reshape(ct_b, img_b))
        IDS.append(data["id"])

    gc.collect()

    if not per_scan_probs:
        # np.concatenate raises on an empty list; return empty results instead.
        return np.empty(0, dtype=np.float32), np.empty(0, dtype=object)

    pred_y = np.concatenate(per_scan_probs).mean(axis=1)
    IDS = np.concatenate(IDS)
    return pred_y, IDS
    

In [13]:
# state_dict checkpoint for the Net defined above.
weights_path = "model_weights/job_51_effnetb3a.bin"

In [14]:
model = Net()
# Load onto the CPU first so the checkpoint also opens on machines without the GPU it
# was saved from, then move to the configured device (cuda:0 when available) --
# the same device inference() sends the images to.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
model = model.to(CONFIG["device"])

In [15]:
test_loader = prepare_loaders()  # one pass covers every test scan, in df order

In [16]:
# A single inference pass: one random slice sample per scan.
pred_y, name = inference(model, test_loader, device=CONFIG['device'])
total_pred = [pred_y]

100%|██████████| 166/166 [01:19<00:00,  2.10it/s]


In [17]:
# Element-wise average of the per-pass predictions (one value per scan).
final_pred = np.stack(total_pred).mean(axis=0)

In [18]:
# scan path -> averaged prediction
dict_all = {scan_path: prob for scan_path, prob in zip(name, final_pred)}

In [19]:
pred_rows = list(dict_all.items())
cnn_one_pred_df = pd.DataFrame(pred_rows, columns=['path', 'pred'])

In [20]:
# to_csv does not create missing directories.
os.makedirs("output", exist_ok=True)
cnn_one_pred_df.to_csv("output/cnn_one_pred_df.csv", index=False)

In [21]:
# How many inference passes to average; each pass re-draws the random slice sample.
times_list = [10, 50]

In [22]:
# For each pass count, average that many inference passes (each draws its slice
# sample anew) and save the per-scan means as a CSV.
os.makedirs("output", exist_ok=True)  # to_csv does not create missing directories
for times in times_list:
    total_pred = []
    for _ in range(times):
        pred_y, name = inference(model, test_loader, device=CONFIG['device'])
        total_pred.append(pred_y)
    final_pred = np.mean(total_pred, axis=0)
    dict_all = dict(zip(name, final_pred))

    cnn_times_pred_df = pd.DataFrame(list(dict_all.items()),
                                     columns=['path', 'pred'])
    cnn_times_pred_df.to_csv(f"output/cnn_{times}_pred_df.csv", index=False)
    print("save")

100%|██████████| 166/166 [01:15<00:00,  2.20it/s]
100%|██████████| 166/166 [01:15<00:00,  2.20it/s]
100%|██████████| 166/166 [01:15<00:00,  2.19it/s]
100%|██████████| 166/166 [01:14<00:00,  2.21it/s]
100%|██████████| 166/166 [01:15<00:00,  2.18it/s]
100%|██████████| 166/166 [01:15<00:00,  2.18it/s]
100%|██████████| 166/166 [01:14<00:00,  2.22it/s]
100%|██████████| 166/166 [01:16<00:00,  2.18it/s]
100%|██████████| 166/166 [01:14<00:00,  2.24it/s]
100%|██████████| 166/166 [01:15<00:00,  2.21it/s]


save


100%|██████████| 166/166 [01:15<00:00,  2.19it/s]
100%|██████████| 166/166 [01:13<00:00,  2.25it/s]
100%|██████████| 166/166 [01:14<00:00,  2.23it/s]
100%|██████████| 166/166 [01:14<00:00,  2.23it/s]
100%|██████████| 166/166 [01:14<00:00,  2.23it/s]
100%|██████████| 166/166 [01:14<00:00,  2.23it/s]
100%|██████████| 166/166 [01:14<00:00,  2.21it/s]
100%|██████████| 166/166 [01:14<00:00,  2.22it/s]
100%|██████████| 166/166 [01:14<00:00,  2.22it/s]
100%|██████████| 166/166 [01:14<00:00,  2.22it/s]
100%|██████████| 166/166 [01:14<00:00,  2.23it/s]
100%|██████████| 166/166 [01:14<00:00,  2.22it/s]
100%|██████████| 166/166 [01:14<00:00,  2.21it/s]
100%|██████████| 166/166 [01:14<00:00,  2.22it/s]
100%|██████████| 166/166 [01:15<00:00,  2.21it/s]
100%|██████████| 166/166 [01:14<00:00,  2.22it/s]
100%|██████████| 166/166 [01:16<00:00,  2.18it/s]
100%|██████████| 166/166 [01:16<00:00,  2.18it/s]
100%|██████████| 166/166 [01:15<00:00,  2.19it/s]
100%|██████████| 166/166 [01:15<00:00,  2.20it/s]


save



