In [1]:
import csv
import os

import cv2
import numpy as np
from scipy.stats import wilcoxon

import torch
import torchvision.transforms as transforms
from tqdm.notebook import tqdm

from models import SwinTransformer

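## Load model

> **NOTE:** this export contains no cell that creates `model` (presumably the imported `SwinTransformer`) or `device`, yet `inference` below depends on both. Restore that cell before running the notebook top to bottom.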

In [16]:
# Preprocessing applied to every slice before it reaches the model.
# The same mean/std pair is used for all three channels
# (values presumably come from the training data — TODO confirm).
m = 0.39221061670618984  # channel mean for Normalize
s = 0.11469786773730418  # channel std for Normalize
t = transforms.Compose(
    [
        transforms.ToPILImage(),          # numpy HWC image -> PIL image
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((m,) * 3, (s,) * 3),
    ]
)

In [20]:
def inference(path, st=3, ed=7, model=model, t=t, alpha=0.05, verbose=True):
    """Classify one CT scan folder as COVID (1) or non-COVID (0).

    Slice files in ``path`` are ordered by the integer before the first '.'
    in their filename. Only the slices between the ``st``/10 and ``ed``/10
    fractions of the stack are scored. Each slice's scalar model output is
    collected and passed to ``wilcoxon_rank_test``. The scan is labelled
    positive when the resulting p-value is below ``alpha``.

    Args:
        path: directory holding the slice images of a single scan.
        st: start of the slice window, in tenths of the stack (default 3).
        ed: end of the slice window, in tenths of the stack, inclusive
            (default 7).
        model: network returning one scalar per slice. The default is the
            global ``model`` as bound when this cell was executed.
        t: transform applied to each RGB numpy image.
        alpha: significance level for the positive decision (default 0.05).
        verbose: if True, print the folder path and its p-value.

    Returns:
        1 if the p-value is below ``alpha``, else 0.

    Raises:
        IOError: if a slice file in the window cannot be decoded by OpenCV.

    Note: images are moved to the global ``device``, which must be defined.
    """
    model.eval()
    # Numeric order, so '10.png' sorts after '9.png'.
    slices = sorted(os.listdir(path), key=lambda name: int(name.split('.')[0]))
    ct_len = len(slices)
    start_idx = int(round(ct_len / 10 * st, 0))
    # The +1 makes the upper slice inclusive. Clamp it so short stacks,
    # empty folders or ed=10 never index past the last slice.
    end_idx = min(int(round(ct_len / 10 * ed, 0)) + 1, ct_len)

    pop = []
    # Pure inference: don't build an autograd graph for every slice.
    with torch.no_grad():
        for name in slices[start_idx:end_idx]:
            img_path = os.path.join(path, name)
            img = cv2.imread(img_path)
            if img is None:
                raise IOError(f'could not read image: {img_path}')
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = t(img).to(device).unsqueeze(0)
            pop.append(model(img).item())

    p_value = wilcoxon_rank_test(pop)
    if verbose:
        print(path)
        print(p_value)
    return int(p_value < alpha)

In [21]:
def wilcoxon_rank_test(pop):
    """One-sided Wilcoxon signed-rank test on confidently scored slices.

    Only scores lying within ``2 * sqrt(0.2)`` of +1 or of -1 (bounds
    inclusive) are kept. Everything else, including scores near 0, is
    discarded. The kept scores are tested with ``alternative='greater'``,
    i.e. whether they tend to be above zero.

    Args:
        pop: sequence of per-slice scalar scores.

    Returns:
        The test's p-value, or 1.0 when no score survives the filter.
    """
    scores = np.array(pop)
    radius = np.sqrt(0.2) * 2

    def band(center):
        # Scores inside [center - radius, center + radius].
        return scores[(scores >= center - radius) & (scores <= center + radius)]

    kept = np.concatenate((band(1), band(-1)))
    if not len(kept):
        return 1.0
    return wilcoxon(kept, alternative='greater').pvalue

In [22]:
# Classify every scan folder under test_path.
# `inference` returns 1 for COVID, 0 otherwise.
test_path = '/ssd2/covid/data/test/'
# os.listdir order is filesystem-dependent. Sort so the submission lists are
# reproducible, and keep only sub-directories (one per scan), so a stray
# file in test_path cannot crash `inference`.
test_folder = sorted(
    name for name in os.listdir(test_path)
    if os.path.isdir(os.path.join(test_path, name))
)

covid = []
non_covid = []
for folder in tqdm(test_folder):
    path = os.path.join(test_path, folder)
    pred = inference(path)
    if pred == 1:
        covid.append(folder)
    else:
        non_covid.append(folder)

  0%|          | 0/3455 [00:00<?, ?it/s]

/ssd2/covid/data/test/ac8e7009-1d1f-47df-822b-c496d0ecf323
0.9999998173583405
/ssd2/covid/data/test/a6c678e1-8421-41e4-af41-0d2849ca00f7
1.0
/ssd2/covid/data/test/a4428be1-9c3a-422d-8a97-ae37ad487945
0.9999999999999984
/ssd2/covid/data/test/37a446ba-5474-4607-872e-7673a9e1e310
0.9999999982385248
/ssd2/covid/data/test/4fb045c8-7516-419e-a4fc-913ec56dfea3
0.9999999878750473
/ssd2/covid/data/test/4f39ffae-8378-4b00-b6ac-e1bdcad6e0e7
0.9999974820229123
/ssd2/covid/data/test/76aadc6f-c00f-419e-9469-404d4bc6c99b
1.0
/ssd2/covid/data/test/367d91ac-ad1c-4234-b939-c565d392bc0f
1.0
/ssd2/covid/data/test/06e351ec-bc28-4545-bc74-f4767c4de6ba
1.0
/ssd2/covid/data/test/4e23758d-d529-4a7c-a77a-903cebff4c68
1.0
/ssd2/covid/data/test/284e200a-3b38-46b9-9bb7-951e22288563
4.37423947172168e-20
/ssd2/covid/data/test/71b55860-6bde-4752-8e7f-8ce07e62e3e3
0.9999999999742812
/ssd2/covid/data/test/25b3d787-b237-40df-a73b-b5a573bc6821
0.9999998173583405
/ssd2/covid/data/test/591e137f-9cc6-4373-a026-d443cca6dafc


In [23]:
# Ratio of scans classified non-COVID to scans classified COVID.
# NaN instead of ZeroDivisionError when no scan was flagged COVID.
len(non_covid) / len(covid) if covid else float('nan')

5.657032755298651

In [24]:
# Number of scan folders that `inference` classified as non-COVID.
len(non_covid)

2936

In [25]:
# Number of scan folders that `inference` classified as COVID.
len(covid)

519

In [26]:
# Write the submission: each file holds all predicted scan IDs as a single
# comma-separated CSV row.
SUBMISSION_DIR = 'submission/w_mse_epoch_70_f1_0.9172_alpha0.05'


def write_id_row(file_path, ids):
    """Write ``ids`` as one CSV row to ``file_path`` (UTF-8, overwriting)."""
    # newline='' lets the csv module control line endings itself (csv docs).
    with open(file_path, 'w', encoding='UTF8', newline='') as f:
        csv.writer(f).writerow(ids)


# Create the output directory on a fresh checkout instead of failing.
os.makedirs(SUBMISSION_DIR, exist_ok=True)
write_id_row(os.path.join(SUBMISSION_DIR, 'covid'), covid)
write_id_row(os.path.join(SUBMISSION_DIR, 'non-covid'), non_covid)