In [5]:
import torchvision.models as models
import torchvision
import torch.nn as nn
model = torchvision.models.efficientnet_v2_m(pretrained=True)
in_features = model.classifier[1]
in_features

Linear(in_features=1280, out_features=1000, bias=True)

In [50]:
!pip install -U torchvision

Collecting torchvision
  Downloading torchvision-0.13.1-cp38-cp38-manylinux1_x86_64.whl (19.1 MB)
[K     |████████████████████████████████| 19.1 MB 3.3 MB/s eta 0:00:01
Collecting torch==1.12.1
  Downloading torch-1.12.1-cp38-cp38-manylinux1_x86_64.whl (776.3 MB)
[K     |████████████████████████████████| 776.3 MB 1.6 kB/s eta 0:00:01
Installing collected packages: torch, torchvision
  Attempting uninstall: torch
    Found existing installation: torch 1.11.0
    Uninstalling torch-1.11.0:
      Successfully uninstalled torch-1.11.0
  Attempting uninstall: torchvision
    Found existing installation: torchvision 0.12.0
    Uninstalling torchvision-0.12.0:
      Successfully uninstalled torchvision-0.12.0
Successfully installed torch-1.12.1 torchvision-0.13.1


In [52]:
import torch

print(torchvision.__version__)

0.12.0


In [1]:
import os
from PIL import Image

In [16]:
for i in range(1,31):
    filepath = f'/home/bernardo/datasets/pan-radiographs/splits/{i:02d}.txt'
    with open(filepath) as f:
        for line in f: #ler cada linha do txt
            fname = line.strip().split('/')[2] #retirar o \n
            sex = fname.split('-')[10]
            age = fname.split('-')[-2][1:]
            months = fname.split('-')[-1][1:3] #home/bernardo/datasets/pan-radiographs/1st-set

            if fname.split('-')[0] == 'pan': #separar os arquivos pan e panreport
                fpath = os.path.join(f'/home/bernardo/datasets/pan-radiographs/1st-set/images/{fname}')
            else:
                fpath = os.path.join(f'/home/bernardo/datasets/pan-radiographs/2nd-set/images/{fname}')
            im = Image.open(fpath)

In [11]:
import os
import cv2
import numpy as np

from PIL import Image

import torch
from torch.utils.data import Dataset
from torchvision import transforms as T

from skimage.io import imread
from skimage.transform import rescale, resize, downscale_local_mean

MEAN = [0.485, 0.456, 0.406]
STDV = [0.229, 0.224, 0.225]


In [26]:
transform = T.Compose([
    T.Resize((224, 224)),
    T.TrivialAugmentWide(),
    T.ToTensor(),
    T.Normalize((0.5), (0.5))
])

class RadiographSexDataset_17000(Dataset):
    def __init__(
        self,
        root_dir: str,
        fold_nums: list,
        transforms = transform,
        albumentations_package: bool=False
    ):
        super().__init__()

        self.root_dir = root_dir
        self.fold_nums = fold_nums
        self.transforms = transforms
        self.albumentations = albumentations_package
        
        train = [i for i in range(1, 20 + 1)]
        val   = [i for i in range(21, 25 + 1)]
        test  = [i for i in range(26, 30 + 1)]
        
        # labels
        self.filepaths = []
        for i in range(1,31):
            filepath = f'/home/bernardo/datasets/pan-radiographs/splits/{i:02d}.txt'
            with open(filepath) as f:
                for line in f:
                    fname = line.strip()
                    self.filepaths.append(fname)
                    
        # this maybe useful later for reproducibility
        self.filepaths.sort()

    def __len__(self) -> int:
        return len(self.filepaths)

    def _getitem_albumentations(self, filepath: str):
        # Read image with OpenCV2 and convert it from BGR (OpenCV2) to RGB (most common format)
        image = cv2.imread(filepath)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # apply transformation with albumentations package
        if self.transforms is not None:
            img_tensor = self.transforms(image=image)["image"]

        return img_tensor

    def _getitem_torchvision(self, filepath: str):
        image = Image.open(filepath)

        if self.transforms is not None:
            img_tensor = self.transforms(image)

        return img_tensor

    def __getitem__(self, index: int):
        # image and label
        filepath = self.filepaths[index]
        if fname.split('-')[0] == 'pan': 
            fpath = os.path.join(f'/home/bernardo/datasets/pan-radiographs/1st-set/images/{fname}')
        else:
            fpath = os.path.join(f'/home/bernardo/datasets/pan-radiographs/2nd-set/images/{fname}')

        # get label
        sex = fname.split('-')[10]
        age = fname.split('-')[-2][1:]
        months = fname.split('-')[-1][1:3]
        
        assert sex in ['F', 'M']
        if sex == 'F':
            label = 0
        else:
            label = 1
        
        label_tensor = torch.tensor(label, dtype=torch.int64)
        img_tensor = Image.open(fpath)        
        
        return img_tensor, label_tensor

In [27]:
ds = RadiographSexDataset_17000(
    root_dir=None,
    fold_nums=30,
    transforms=transform,
    albumentations_package=False
)

In [28]:
for img, label in ds:
    pass

In [29]:
# labels
filepaths = []
for i in range(1,31):
    filepath = f'/home/bernardo/datasets/pan-radiographs/splits/{i:02d}.txt'
    with open(filepath) as f:
        for line in f:
            fname = line.strip()
            filepaths.append(fname)

# for fold_num in range(self.fold_nums):
#     if fname.split('-')[0] == 'pan': 
#         fpath = os.path.join(f'/home/bernardo/datasets/pan-radiographs/1st-set/images/{fname}')
#     else:
#         fpath = os.path.join(f'/home/bernardo/datasets/pan-radiographs/2nd-set/images/images/{fname}')


## Contagem do nº de homens e mulheres

In [36]:
count_m = 0
count_f = 0
for i in range(1,30+1):
    filepath = f'/home/bernardo/datasets/pan-radiographs/splits/{i:02d}.txt'
    with open(filepath) as f:
        for line in f:
            img_relpath = line.strip()
            filename = img_relpath.split('/')[-1]
            sex = filename.split('-')[10]
            #print(sex)
            if sex == 'M':
                count_m +=1
            else:
                count_f +=1

print('Males:',count_m,'Females:',count_f)

Males: 6341 Females: 10483


In [58]:
count_1 = 0
count_2 = 0
for i in range(1,31):
    filepath = f'/home/bernardo/datasets/pan-radiographs/splits/{i:02d}.txt'
    with open(filepath) as f:
        for line in f: #ler cada linha do txt
            fname = line.strip().split('/')[2] #retirar o \n
            if fname.split('-')[0] == 'pan': #separar os arquivos pan e panreport
                fpath = os.path.join(f'/home/bernardo/datasets/pan-radiographs/1st-set/images/{fname}')
                count_1 +=1
            else:
                fpath = os.path.join(f'/home/bernardo/datasets/pan-radiographs/2nd-set/images/{fname}')
                count_2 +=1
            #im = Image.open(fpath)
    
print(count_1 + count_2)

16824


## Número total de radiografias

In [2]:
count = 0
for i in range(1,30+1):
    filepath = f'/home/bernardo/datasets/pan-radiographs/splits/{i:02d}.txt'
    with open(filepath) as f:
        content = f.read()
        colist = content.split('\n')
        for i in colist:
            if i:
                count +=1
print('Number of radiographs:',count)       

Number of radiographs: 16824


## Número de radiografias para o treinamento

In [7]:
# train_folds: int = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20],
# val_folds: int = [21,22,23,24,25]

count = 0
for i in range(1,20+1):
    filepath = f'/home/bernardo/datasets/pan-radiographs/splits/{i:02d}.txt'
    with open(filepath) as f:
        content = f.read()
        colist = content.split('\n')
        for i in colist:
            if i:
                count +=1
print('Number of radiographs:',count)   

Number of radiographs: 11212


In [9]:
for i in range(21,25+1):
    print(i)

21
22
23
24
25


## Número de radiografias para a validação

In [51]:
count = 0
for i in range(21,25+1):
    filepath = f'/home/bernardo/datasets/pan-radiographs/splits/{i:02d}.txt'
    with open(filepath) as f:
        content = f.read()
        colist = content.split('\n')
        for i in colist:
            if i:
                count +=1
print('Number of radiographs:',count)   

Number of radiographs: 2808


## Número de radiografias para o teste

In [6]:
count = 0
for i in range(25,30+1):
    filepath = f'/home/bernardo/datasets/pan-radiographs/splits/{i:02d}.txt'
    with open(filepath) as f:
        content = f.read()
        colist = content.split('\n')
        for i in colist:
            if i:
                count +=1
print('Number of radiographs:',count)  

Number of radiographs: 3357


## Nº de homens e mulheres durante o treinamento

In [33]:
# train folds : fold 1 {1,2,3,4,5} ; fold 2 {6,7,8,9,10} ; fold 3 {11,12,13,14,15} ; fold 4 {16,17,18,19,20}
# val folds : fold 1 {21,22,23,24,25}

count_m = 0
count_f = 0
for i in range(1,20+1):
    filepath = f'/home/bernardo/datasets/pan-radiographs/splits/{i:02d}.txt'
    with open(filepath) as f:
        for line in f:
            img_relpath = line.strip()
            filename = img_relpath.split('/')[-1]
            sex = filename.split('-')[10]
            if sex == 'M':
                count_m +=1
            else:
                count_f +=1

print('Males in train:',count_m)
print('Females in train:',count_f)

Males in train: 4200
Females in train: 7012


## Nº de homens e mulheres durante a validação

In [8]:
# train folds : fold 1 {1,2,3,4,5} ; fold 2 {6,7,8,9,10} ; fold 3 {11,12,13,14,15} ; fold 4 {16,17,18,19,20}
# val folds : fold 1 {21,22,23,24,25}

count_m = 0
count_f = 0
for i in range(21,25+1):
    filepath = f'/home/bernardo/datasets/pan-radiographs/splits/{i:02d}.txt'
    with open(filepath) as f:
        for line in f:
            img_relpath = line.strip()
            filename = img_relpath.split('/')[-1]
            sex = filename.split('-')[10]
            if sex == 'M':
                count_m +=1
            else:
                count_f +=1

print('Males in val:',count_m)
print('Females in val:',count_f)

Males in val: 1073
Females in val: 1735


## Nº de pessoas por intervalo de idade

In [38]:
# intervalos de idade: [0-10] ; [11-20] ; [21-30] ; [31-40] ; [41-50] ; [51-60] ; [61-70] ; [71-80] ; [81-90] ; [91-100]
count = 0
count_1 = 0 
count_2 = 0
count_3 = 0
count_4 = 0
count_5 = 0
count_6 = 0
count_7 = 0
count_8 = 0
count_9 = 0
count_10 = 0

count_ = 0

for i in range(1,30+1):
    filepath = f'/home/bernardo/datasets/pan-radiographs/splits/{i:02d}.txt'
    with open(filepath) as f:
        for line in f:
            img_relpath = line.strip()
            filename = img_relpath.split('/')[-1]
            age = filename.split('-')[12][1:]
            
            if age == 'NA':
                continue
            else:
                age = int(filename.split('-')[12][1:])
                #print(age)
                count += 1
                
            if age <= 10:
                count_1 += 1
            elif 10< age <= 20:
                count_2 += 1
            elif 20< age <=30 :
                count_3 += 1
            elif 30< age <=40:
                count_4 += 1
            elif 40< age <=50:
                count_5 += 1
            elif 50< age <=60:
                count_6 += 1
            elif 60< age <=70:
                count_7 += 1
            elif 70< age <=80:
                count_8 += 1
            elif 80< age <=90:
                count_9 += 1
            elif 90< age <=100:
                count_10 += 1 
            else:
                count_ += 1
            
        #contagem = {'Males':count_males,'Females':count_females}

print('Nº de radiografias com idade:',count)
print('Nº de pessoas entre [0,10] anos:',count_1)
print('Nº de pessoas entre [11,20] anos:',count_2)
print('Nº de pessoas entre [21,30] anos:', count_3)
print('Nº de pessoas entre [31,40] anos:', count_4)
print('Nº de pessoas entre [41,50] anos:', count_5)
print('Nº de pessoas entre [51,60] anos:', count_6)
print('Nº de pessoas entre [61,70] anos:', count_7)
print('Nº de pessoas entre [71,80] anos:', count_8)
print('Nº de pessoas entre [81,90] anos:', count_9)
print('Nº de pessoas entre [91,100] anos:', count_10)
print('Pessoas de fora:',count_)

Nº de radiografias com idade: 8013
Nº de pessoas entre [0,10] anos: 730
Nº de pessoas entre [11,20] anos: 1414
Nº de pessoas entre [21,30] anos: 2006
Nº de pessoas entre [31,40] anos: 1435
Nº de pessoas entre [41,50] anos: 1090
Nº de pessoas entre [51,60] anos: 765
Nº de pessoas entre [61,70] anos: 385
Nº de pessoas entre [71,80] anos: 139
Nº de pessoas entre [81,90] anos: 49
Nº de pessoas entre [91,100] anos: 0
Pessoas de fora: 0


## Nº de homens e mulheres por intervalo de idade

In [76]:
# intervalos de idade: [0-10] ; [11-20] ; [21-30] ; [31-40] ; [41-50] ; [51-60] ; [61-70] ; [71-80] ; [81-90] ; [91-100]
count = 0
count_1 = 0 
count_2 = 0
count_3 = 0
count_4 = 0
count_5 = 0
count_6 = 0
count_7 = 0
count_8 = 0
count_9 = 0
count_10 = 0

count_ = 0

count_11 = 0 
count_21 = 0
count_31 = 0
count_41 = 0
count_51 = 0
count_61 = 0
count_71 = 0
count_81 = 0
count_91 = 0
count_101 = 0

count__ = 0

for i in range(1,30+1):
    filepath = f'/home/bernardo/datasets/pan-radiographs/splits/{i:02d}.txt'
    with open(filepath) as f:
        for line in f:
            img_relpath = line.strip()
            filename = img_relpath.split('/')[-1]
            age = filename.split('-')[12][1:]
            if age == 'NA':
                continue
            else:
                age = int(filename.split('-')[12][1:])
                count += 1
            sex = filename.split('-')[10]
            if sex == 'M':
                if age <= 10:
                    count_11 += 1
                elif 10< age <= 20:
                    count_21 += 1
                elif 20< age <=30 :
                    count_31 += 1
                elif 30< age <=40:
                    count_41 += 1
                elif 40< age <=50:
                    count_51 += 1
                elif 50< age <=60:
                    count_61 += 1
                elif 60< age <=70:
                    count_71 += 1
                elif 70< age <=80:
                    count_81 += 1
                elif 80< age <=90:
                    count_91 += 1
                elif 90< age <=100:
                    count_101 += 1 
                else:
                    count__ += 1
                    
                
            else:
                if age <= 10:
                    count_1 += 1
                elif 10< age <= 20:
                    count_2 += 1
                elif 20< age <=30 :
                    count_3 += 1
                elif 30< age <=40:
                    count_4 += 1
                elif 40< age <=50:
                    count_5 += 1
                elif 50< age <=60:
                    count_6 += 1
                elif 60< age <=70:
                    count_7 += 1
                elif 70< age <=80:
                    count_8 += 1
                elif 80< age <=90:
                    count_9 += 1
                elif 90< age <=100:
                    count_10 += 1 
                else:
                    count_ += 1 
            
        #contagem = {'Males':count_males,'Females':count_females}

print('Nº de radiografias com idade:',count)
print('Nº de mulheres entre [0,10] anos:',count_1)
print('Nº de mulheres entre [11,20] anos:',count_2)
print('Nº de mulheres entre [21,30] anos:', count_3)
print('Nº de mulheres entre [31,40] anos:', count_4)
print('Nº de mulheres entre [41,50] anos:', count_5)
print('Nº de mulheres entre [51,60] anos:', count_6)
print('Nº de mulheres entre [61,70] anos:', count_7)
print('Nº de mulheres entre [71,80] anos:', count_8)
print('Nº de mulheres entre [81,90] anos:', count_9)
print('Nº de mulheres entre [91,100] anos:', count_10)
print('Mulheres de fora:',count_)
print()
print('Nº de radiografias com idade:',count)
print('Nº de homens entre [0,10] anos:',count_11)
print('Nº de homens entre [11,20] anos:',count_21)
print('Nº de homens entre [21,30] anos:', count_31)
print('Nº de homens entre [31,40] anos:', count_41)
print('Nº de homens entre [41,50] anos:', count_51)
print('Nº de homens entre [51,60] anos:', count_61)
print('Nº de homens entre [61,70] anos:', count_71)
print('Nº de homens entre [71,80] anos:', count_81)
print('Nº de homens entre [81,90] anos:', count_91)
print('Nº de homens entre [91,100] anos:', count_101)
print('Homens de fora:',count__)

Nº de radiografias com idade: 8013
Nº de mulheres entre [0,10] anos: 378
Nº de mulheres entre [11,20] anos: 843
Nº de mulheres entre [21,30] anos: 1267
Nº de mulheres entre [31,40] anos: 893
Nº de mulheres entre [41,50] anos: 692
Nº de mulheres entre [51,60] anos: 502
Nº de mulheres entre [61,70] anos: 233
Nº de mulheres entre [71,80] anos: 81
Nº de mulheres entre [81,90] anos: 31
Nº de mulheres entre [91,100] anos: 0
Mulheres de fora: 0

Nº de radiografias com idade: 8013
Nº de homens entre [0,10] anos: 352
Nº de homens entre [11,20] anos: 571
Nº de homens entre [21,30] anos: 739
Nº de homens entre [31,40] anos: 542
Nº de homens entre [41,50] anos: 398
Nº de homens entre [51,60] anos: 263
Nº de homens entre [61,70] anos: 152
Nº de homens entre [71,80] anos: 58
Nº de homens entre [81,90] anos: 18
Nº de homens entre [91,100] anos: 0
Homens de fora: 0


## Contagem de acertos

In [None]:
def get_count(ground_truth= list, prediction= list):
    acertos_homens, acertos_mulheres = 0, 0
    total_homens, total_mulheres = 0, 0
    
    for idx, (img, label) in tqdm(enumerate(val_dataloader)):
        img, label = img.cuda(), label.cuda()
        preds = model(img)
        prediction = torch.argmax(preds).item()
        ground_truth = label.item()
        
        if ground_truth == 0:
            total_mulheres += 1
        else:
            total_homens += 1

        if ground_truth == prediction:
            if ground_truth == 0:
                acertos_mulheres += 1
            else:
                acertos_homens += 1

    print(f'Mulheres: {total_mulheres}')
    print(f'Homens: {total_homens}')
    print(f'Total: {total_mulheres + total_homens}')
    print(f'Acertos Mulheres: {acertos_mulheres}')
    print(f'Acertos Homens: {acertos_homens}')

    return acertos_mulheres, acertos_homens

get_count(ground_truth, prediction)

In [None]:
# precison
def precision(outputs, labels):
    op = outputs.to(device)
    la = labels.to(device)
    _, preds = torch.max(op, dim=1)
    return torch.tensor(precision_score(la,preds, average=‘weighted’))


In [60]:
pip install captum

Collecting captum
  Downloading captum-0.5.0-py3-none-any.whl (1.4 MB)
[K     |████████████████████████████████| 1.4 MB 992 kB/s eta 0:00:01
Installing collected packages: captum
Successfully installed captum-0.5.0
Note: you may need to restart the kernel to use updated packages.


In [61]:
import matplotlib.pyplot as plt
import numpy as np

%matplotlib inline

import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF

from torchvision import models

from captum.attr import IntegratedGradients
from captum.attr import Saliency
from captum.attr import DeepLift
from captum.attr import NoiseTunnel
from captum.attr import visualization as viz

In [62]:
from torchvision import transforms as T
transform = T.Compose([
                T.Resize((inputs.img_size, inputs.img_size)),
                T.ToTensor(),
                T.Normalize(inputs.MEAN, inputs.STDV)
            ])

In [64]:
from torch.utils.data import DataLoader

val_dataset = FullRadiographSexDataset(root_dir=inputs.DATASET_DIR,fold_nums=inputs.val_folds,transforms=get_transforms(inputs, subset=subset))

val_dataloader = DataLoader(val_dataset,batch_size=1,shuffle=False,num_workers=0)

train_dataset = FullRadiographSexDataset(root_dir=inputs.DATASET_DIR,fold_nums=inputs.train_folds,transforms=get_transforms(inputs, subset=subset))

train_dataloader = DataLoader(train_dataset,batch_size=1,shuffle=False,num_workers=0)

Using only horizontal flip augmentation.
Using only horizontal flip augmentation.


In [69]:
# get the classes
for i in range(1,31):
    filepath = f'/home/bernardo/datasets/pan-radiographs/splits/{i:02d}.txt'
    with open(filepath) as f:
        for line in f: #ler cada linha do txt
            fname = line.strip().split('/')[2] #retirar o \n
            sex = fname.split('-')[10]
            #print(sex)
            if fname.split('-')[0] == 'pan': #separar os arquivos pan e panreport
                fpath = os.path.join(f'/home/bernardo/datasets/pan-radiographs/1st-set/images/{fname}')
            else:
                fpath = os.path.join(f'/home/bernardo/datasets/pan-radiographs/2nd-set/images/{fname}')

In [70]:
def attribute_image_features(algorithm, input, **kwargs):
    net.zero_grad()
    tensor_attributions = algorithm.attribute(input,
                                              target=labels[ind],
                                              **kwargs
                                             )
    
    return tensor_attributions

In [98]:
dataiter = iter(val_dataloader)
images, labels = dataiter.next()
ind = 1

In [104]:
inputs = Inputs(selected_model='efficientnet-b0')
model = get_classification_model(inputs.model_name, 2)
checkpoint = torch.load('/home/bernardo/github/sex-age-estimation/backup-bia/patch-1/checkpoint-efficientnet-b0-fold-2-max-acc.pth.tar')
model.load_state_dict(checkpoint['state_dict'])

saliency = Saliency(model)
grads = saliency.attribute(images, target=labels[ind].item())
    
grads = np.transpose(grads.squeeze().cpu().detach().numpy(), (1, 2, 0))

IndexError: index 1 is out of bounds for dimension 0 with size 1

In [105]:
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
import torch    
import cv2
import numpy as np
import requests
import torchvision.transforms as transforms
from pytorch_grad_cam import EigenCAM
from pytorch_grad_cam.utils.image import show_cam_on_image, scale_cam_image
from PIL import Image

In [106]:
COLORS = np.random.uniform(0, 255, size=(80, 3))

In [108]:
def parse_detections(results):
    detections = results.pandas().xyxy[0]
    detections = detections.to_dict()
    boxes, colors, names = [], [], []

    for i in range(len(detections["xmin"])):
        confidence = detections["confidence"][i]
        if confidence < 0.2:
            continue
        xmin = int(detections["xmin"][i])
        ymin = int(detections["ymin"][i])
        xmax = int(detections["xmax"][i])
        ymax = int(detections["ymax"][i])
        name = detections["name"][i]
        category = int(detections["class"][i])
        color = COLORS[category]

        boxes.append((xmin, ymin, xmax, ymax))
        colors.append(color)
        names.append(name)
    return boxes, colors, names

In [109]:
def draw_detections(boxes, colors, names, img):
    for box, color, name in zip(boxes, colors, names):
        xmin, ymin, xmax, ymax = box
        cv2.rectangle(
            img,
            (xmin, ymin),
            (xmax, ymax),
            color, 
            2)

        cv2.putText(img, name, (xmin, ymin - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2,
                    lineType=cv2.LINE_AA)
    return img