In [1]:
import os, sys
import copy
import json
import math
import time
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision
from torchvision import transforms

from PIL import Image

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [2]:
# Directory containing the per-database CSV index files (imgs_data0<N>.csv).
# The path is relative, so it resolves against the notebook's working directory.
data_dir = '../../data/processed'

In [3]:
class ChestXrayDataSet(Dataset):
    """Binary (healthy vs. sick) chest X-ray dataset backed by a CSV index file.

    The CSV must provide the columns ``class label``, ``split``,
    ``round_number`` and ``img_filepath``. A row is kept when its ``split``
    equals the requested split and its ``round_number`` is less than or equal
    to the requested round, so a later round always contains the earlier ones.

    Args:
        data_filepath: Path of the CSV index file.
        split: Value of the ``split`` column to keep (the notebook uses
            ``'train'``, ``'valid'`` and ``'test'``).
        round_number: Highest ``round_number`` to include.

    Attributes:
        dataframe: Filtered rows, including the derived ``target`` column.
        image_paths: Array of image file paths, one per sample.
        targets: ``float32`` tensor with 0 for ``'No Finding'`` and 1 otherwise.
        augment: ``True`` only for the ``'train'`` split; controls whether the
            random horizontal flip is part of the preprocessing pipeline.
        CLASSES_LABELS: Human-readable class names, indexed by target value.
        TARGET_DICT: Mapping from class name to target value.
    """

    def __init__(self, data_filepath, split, round_number):
        self.split = split
        self.round_number = round_number

        dataframe = pd.read_csv(data_filepath)
        # Anything that is not explicitly 'No Finding' counts as the positive class.
        dataframe['target'] = dataframe["class label"].ne('No Finding').astype(int)
        selected = (dataframe['split'] == split) & (dataframe['round_number'] <= round_number)
        self.dataframe = dataframe[selected].reset_index(drop=True)

        self.image_paths = self.dataframe["img_filepath"].values
        self.targets = torch.tensor(self.dataframe['target'].values, dtype=torch.float32)
        self.CLASSES_LABELS = ['Healthy', 'Sick']
        self.TARGET_DICT = {label: i for i, label in enumerate(self.CLASSES_LABELS)}

        # Random augmentation belongs to training only: applying it to the
        # validation/test splits would make evaluation non-deterministic.
        self.augment = split == 'train'
        # Built on first sample access instead of on every __getitem__ call;
        # datasets that are only counted never pay for it.
        self._preprocess = None

    def __len__(self):
        """Return the number of samples kept for this split and round."""
        return len(self.dataframe)

    def _get_preprocess(self):
        """Return the (cached) torchvision preprocessing pipeline."""
        if self._preprocess is None:
            steps = [
                transforms.ToTensor(),
                # An int size rescales the shorter edge to 224 and keeps the
                # aspect ratio.
                # NOTE(review): non-square images stay non-square; confirm the
                # inputs are square, otherwise default batching will fail.
                transforms.Resize(224),
                # Channel statistics commonly used with torchvision's
                # ImageNet-pretrained models.
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
            if self.augment:
                steps.append(transforms.RandomHorizontalFlip())
            self._preprocess = transforms.Compose(steps)
        return self._preprocess

    def __getitem__(self, index):
        """Load, preprocess and return ``(image_tensor, target)`` for ``index``."""
        # The context manager releases the file handle right away; convert()
        # returns a fully loaded copy that outlives the closed file.
        with Image.open(self.image_paths[index]) as raw_image:
            image = raw_image.convert('RGB')
        image = self._get_preprocess()(image)

        return image, self.targets[index]

In [4]:
def split_class_stats(dataset, split_name):
    """Return size and class-balance percentages for one dataset split.

    Args:
        dataset: Object exposing ``__len__`` and a ``targets`` sequence of
            0/1 values (here a ``ChestXrayDataSet``).
        split_name: Split label used in the generated column names.

    Returns:
        Dict with the keys ``'<split>_set size'``, ``'0 target (<split>) [%]'``
        and ``'1 target (<split>) [%]'``. Percentages are NaN for an empty
        split and 0.0 for a class that does not occur (instead of raising).
    """
    targets = np.asarray(dataset.targets)
    size = len(dataset)
    stats = {f'{split_name}_set size': size}
    for label in (0, 1):
        count = int(np.count_nonzero(targets == label))
        stats[f'{label} target ({split_name}) [%]'] = 100 * count / size if size else float('nan')
    return stats


# One record per (database, round); turned into a DataFrame once at the end.
rounds_records = []
index = []
for database in [1, 2, 3, 4, 5]:
    data_filepath = f'{data_dir}/imgs_data0{database}.csv'
    for round_number in [1, 2, 3, 4, 5]:
        train_set = ChestXrayDataSet(data_filepath, 'train', round_number)
        valid_set = ChestXrayDataSet(data_filepath, 'valid', round_number)
        test_set = ChestXrayDataSet(data_filepath, 'test', round_number)

        # Insertion order defines the column order of the summary table.
        stats = {}
        for split_name, dataset in (('train', train_set), ('valid', valid_set), ('test', test_set)):
            stats.update(split_class_stats(dataset, split_name))
        stats['total'] = len(train_set) + len(valid_set) + len(test_set)

        index.append(f'db{database} round{round_number}')
        rounds_records.append(stats)

rounds_stats = pd.DataFrame(rounds_records, index=index)
rounds_stats

Unnamed: 0,train_set size,0 target (train) [%],1 target (train) [%],valid_set size,0 target (valid) [%],1 target (valid) [%],test_set size,0 target (test) [%],1 target (test) [%],total
db1 round1,5576,55.344333,44.655667,1634,55.569155,44.430845,1661,56.291391,43.708609,8871
db1 round2,7687,54.741772,45.258228,2237,55.163165,44.836835,2359,55.871132,44.128868,12283
db1 round3,9783,54.533374,45.466626,2865,55.636998,44.363002,2965,56.053963,43.946037,15613
db1 round4,11821,53.827933,46.172067,3492,55.183276,44.816724,3607,55.031882,44.968118,18920
db1 round5,13904,53.27244,46.72756,4130,54.430993,45.569007,4260,54.765258,45.234742,22294
db2 round1,5262,56.746484,43.253516,1850,53.675676,46.324324,1859,53.738569,46.261431,8971
db2 round2,7213,56.176348,43.823652,2531,53.299091,46.700909,2540,53.385827,46.614173,12284
db2 round3,9270,55.382956,44.617044,3200,53.0,47.0,3240,53.425926,46.574074,15710
db2 round4,11204,55.016066,44.983934,3868,53.619442,46.380558,3903,53.394824,46.605176,18975
db2 round5,13188,54.231119,45.768881,4484,53.256021,46.743979,4586,53.118186,46.881814,22258
