In [None]:
import cv2
import numpy as np
import pydicom
%matplotlib inline
import matplotlib.pyplot as plt
import torch

In [None]:
# This function reads in a .dcm file, checks the important fields for our device, and returns a numpy array
# of just the imaging data
def check_dicom(filename):
    """Read a DICOM file and return its pixel data if the study is in scope.

    A file is accepted only when BodyPartExamined is 'CHEST', Modality is
    'DX' and PatientPosition is 'AP' or 'PA'. A missing tag counts as a
    failed check, so the file is rejected instead of raising AttributeError.

    Parameters
    ----------
    filename : str or path-like
        Path to the .dcm file.

    Returns
    -------
    numpy.ndarray or None
        ``ds.pixel_array`` for an accepted file, ``None`` otherwise.
    """
    print('Load file {} ...'.format(filename))
    ds = pydicom.dcmread(filename)

    # getattr with a default: real-world DICOMs frequently omit optional tags.
    body_part = getattr(ds, 'BodyPartExamined', None)
    modality = getattr(ds, 'Modality', None)
    position = getattr(ds, 'PatientPosition', None)

    if body_part == 'CHEST' and modality == 'DX' and position in ('AP', 'PA'):
        print('File {} loaded.'.format(filename))
        return ds.pixel_array

    print('File {} is not adequate to be scored by PneumoNet.'.format(filename))
    return None
    
    
# This function takes the numpy array output by check_dicom and
# runs the appropriate pre-processing needed for our model input
def preprocess_image(img, img_mean, img_std, img_size, max_pixel_value=None):
    """Resize, scale to [0, 1], replicate to RGB and normalise a grayscale image.

    Parameters
    ----------
    img : numpy.ndarray
        2-D pixel array, e.g. the output of ``check_dicom``.
    img_mean, img_std : array-like
        Per-channel mean and standard deviation. They are applied AFTER the
        pixels are scaled to [0, 1] (the scale on which ImageNet-style
        statistics such as [0.485, 0.456, 0.406] are defined).
    img_size : tuple
        Either a cv2-style ``(width, height)`` pair, or a Keras-style
        ``(height, width, channels)`` / ``(batch, height, width, channels)``
        shape. Only the spatial dimensions are used.
    max_pixel_value : float, optional
        Pixel value that maps to 1.0. Defaults to 255 for ``uint8`` images
        and to the image's own maximum otherwise.
        NOTE(review): confirm this matches the scaling used during training.

    Returns
    -------
    numpy.ndarray
        float32 array of shape ``(1, height, width, 3)``.

    Raises
    ------
    ValueError
        If ``img`` is None or not 2-D, or ``img_size`` has an unsupported length.
    """
    if img is None:
        raise ValueError('preprocess_image received no image (was it rejected by check_dicom?)')
    raw = np.asarray(img)
    if raw.ndim != 2:
        raise ValueError('expected a 2-D grayscale image, got shape {}'.format(raw.shape))

    # cv2.resize takes dsize as (width, height); accept Keras-style shapes too.
    if len(img_size) == 2:
        width, height = img_size
    elif len(img_size) in (3, 4):
        height, width = img_size[-3], img_size[-2]
    else:
        raise ValueError('img_size must have 2, 3 or 4 entries, got {!r}'.format(img_size))
    dsize = (int(width), int(height))

    if max_pixel_value is None:
        max_pixel_value = 255.0 if raw.dtype == np.uint8 else float(raw.max())
    if not max_pixel_value > 0:  # all-zero image (or NaN): avoid dividing by zero
        max_pixel_value = 1.0

    # float32 is accepted by both cv2.resize and cv2.cvtColor (int16/int32 are not)
    scaled = raw.astype(np.float32) / np.float32(max_pixel_value)
    resized = cv2.resize(scaled, dsize)
    rgb = cv2.cvtColor(resized, cv2.COLOR_GRAY2RGB)

    mean = np.asarray(img_mean, dtype=np.float32)
    std = np.asarray(img_std, dtype=np.float32)
    normalised = ((rgb - mean) / std).astype(np.float32, copy=False)

    return np.expand_dims(normalised, 0)

# This function loads our trained model (optionally overriding its weights from a
# separate state_dict file) and switches it to inference mode
def load_model(model_path, weights_path=None, map_location=None):
    """Load a pickled PyTorch model and put it in evaluation mode.

    Parameters
    ----------
    model_path : str or path-like
        File written by ``torch.save(model)`` (a whole nn.Module).
    weights_path : str or path-like, optional
        File written by ``torch.save(model.state_dict())``. If given, these
        weights are loaded into the model with ``load_state_dict``.
    map_location : optional
        Forwarded to ``torch.load``. When None and CUDA is unavailable,
        tensors are mapped to the CPU so a GPU-saved model still loads.

    Returns
    -------
    torch.nn.Module
        The model in eval mode.

    Security: loading a whole model unpickles arbitrary Python objects.
    Only load files from a trusted source.
    """
    import inspect

    if map_location is None and not torch.cuda.is_available():
        map_location = 'cpu'

    # PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle a
    # full nn.Module; releases older than 1.13 do not accept the keyword at all.
    has_weights_only = 'weights_only' in inspect.signature(torch.load).parameters
    model_kwargs = {'weights_only': False} if has_weights_only else {}
    model = torch.load(model_path, map_location=map_location, **model_kwargs)

    if weights_path is not None:
        # A plain state_dict only holds tensors, so the restricted loader suffices.
        state_kwargs = {'weights_only': True} if has_weights_only else {}
        state_dict = torch.load(weights_path, map_location=map_location, **state_kwargs)
        model.load_state_dict(state_dict)

    model.eval()
    return model

# This function uses our device's threshold parameter to predict whether or not
# the image shows the presence of pneumonia using our trained model
def predict_image(model, img, thresh):
    """Return 1 if the model's score for ``img`` is strictly above ``thresh``, else 0.

    Parameters
    ----------
    model : callable
        Typically the torch.nn.Module returned by ``load_model``. Its output
        must contain exactly one element (one score for the image).
    img : numpy.ndarray or torch.Tensor
        A 4-D NumPy batch whose last axis has 1 or 3 channels (NHWC, as made
        by ``preprocess_image``) is converted to an NCHW float32 tensor, the
        layout PyTorch conv layers expect. Other NumPy arrays are converted
        without transposing. Tensors are used as-is.
        If the model exposes parameters, the input is moved to their device
        and floating dtype.
    thresh : float
        Decision threshold.
        NOTE(review): this assumes the model outputs a probability (e.g. it
        ends in a sigmoid). For a raw logit, 0.5 is not a sensible cut-off.

    Returns
    -------
    int
        1 (positive) or 0 (negative).

    Raises
    ------
    ValueError
        If the model output does not have exactly one element.
    """
    if isinstance(img, np.ndarray):
        batch = np.asarray(img, dtype=np.float32)
        if batch.ndim == 4 and batch.shape[-1] in (1, 3):
            batch = batch.transpose(0, 3, 1, 2)  # NHWC -> NCHW
        tensor = torch.from_numpy(np.ascontiguousarray(batch))
    else:
        tensor = img

    # Align the input with the model's device/dtype when it has parameters.
    parameters = getattr(model, 'parameters', None)
    if callable(parameters) and isinstance(tensor, torch.Tensor):
        first = next(iter(parameters()), None)
        if first is not None:
            dtype = tensor.dtype
            if first.is_floating_point() and tensor.is_floating_point():
                dtype = first.dtype
            tensor = tensor.to(device=first.device, dtype=dtype)

    with torch.no_grad():  # inference only: no autograd graph needed
        output = torch.as_tensor(model(tensor))

    if output.numel() != 1:
        raise ValueError('expected a single score from the model, got output of shape {}'
                         .format(tuple(output.shape)))
    score = output.item()
    return 1 if score > thresh else 0

In [None]:
test_dicoms = ['test1.dcm','test2.dcm','test3.dcm','test4.dcm','test5.dcm','test6.dcm']

model_path = 'my_model.pt' #path to saved model
weight_path = 'PneumoVGG16_weights.pt' #path to saved best weights
# NOTE(review): weight_path is not used below; confirm my_model.pt already
# contains the best weights, or load them into the model explicitly.

# cv2.resize takes a (width, height) pair; the Keras-style shape (1, 224, 224, 3)
# is not a valid dsize. 224x224 matches VGG16 (this might differ for other backbones).
IMG_SIZE = (224, 224)
img_mean = np.array([0.485, 0.456, 0.406]) # per-channel mean used during training preprocessing (ImageNet statistics)
img_std = np.array([0.229, 0.224, 0.225]) # per-channel std dev used during training preprocessing (ImageNet statistics)
my_model = load_model(model_path) #loads model
thresh = 0.5 # classification threshold chosen for the model

# use the .dcm files to test your prediction
for dicom_path in test_dicoms:
    img = check_dicom(dicom_path)
    if img is None:
        continue  # file rejected by the DICOM header checks

    img_proc = preprocess_image(img, img_mean, img_std, IMG_SIZE)
    pred = predict_image(my_model, img_proc, thresh)
    print('{}: prediction = {}'.format(dicom_path, pred))