In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from tqdm import tqdm
from CheXpertDataset import CheXpertDataset

IMG_DIR = "/groups/CS156b/2022/team_dirs/liquid_death/train_img_npy"
LABEL_FILE_PATH = "/groups/CS156b/2022/team_dirs/liquid_death/train_labels.npy"
STATE_DICT_PATH = "/home/dnee/CS156b/simple_net_save.pt"

BATCH_SIZE = 4
NUM_CORES = 8
NUM_EPOCHS = 2

DISEASES = [
    "No Finding",
    "Enlarged Cardiomediastinum",
    "Cardiomegaly",
    "Lung Opacity",
    "Lung Lesion",
    "Edema",
    "Consolidation",
    "Pneumonia",
    "Atelectasis",
    "Pneumothorax",
    "Pleural Effusion",
    "Pleural Other",
    "Fracture",
    "Support Devices"
]

CLASSES = [
    "Negative",
    "Unsure",
    "Positive"
]

print("Setup Done")

Setup Done


In [None]:
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv1 = nn.Conv2d(1, 32, 5)
        self.conv2 = nn.Conv2d(32, 32, 3)
        self.conv3 = nn.Conv2d(32, 16, 3)
        self.fc1 = nn.Linear(6 * 6 * 16, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 42)
        self.sm = nn.Softmax(dim=1)
    
    def forward(self, x):
        # Start with 1024x1024x1
        x = self.pool(F.relu(self.conv1(x))) # Turns to 510x510x32
        x = self.pool(F.relu(self.conv2(x))) # Turns to 254x254x32
        x = self.pool(F.relu(self.conv2(x))) # Turns to 126x126x32
        x = self.pool(F.relu(self.conv2(x))) # Turns to 62x62x32
        x = self.pool(F.relu(self.conv2(x))) # Turns to 30x30x32
        x = self.pool(F.relu(self.conv2(x))) # Turns to 14x14x32
        x = self.pool(F.relu(self.conv3(x))) # Turns to 6X6X16
        x = torch.flatten(x, 1) # start on dim 1
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # reshape into 3 categories per disease
        x = torch.reshape(self.fc3(x), (-1, 3, 14)) 
        # Softmax along dimension to normalize probabilities
        x = self.sm(x)
        return x

In [8]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

dataset = CheXpertDataset(IMG_DIR, LABEL_FILE_PATH)
N = len(dataset)
trainset, testset = torch.utils.data.random_split(dataset, [N - (N//5), N//5])

trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_CORES)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_CORES)

net = SimpleNet()
net.to(device)

criterion = nn.CrossEntropyLoss(reduction="sum")
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
print("Objects Done")

Device: cuda:0
Objects Done


In [9]:
for epoch in range(NUM_EPOCHS):  # loop over the dataset multiple times
    for i, data in enumerate(tqdm(trainloader)):
        # get the inputs; data is a list of [inputs, labels]
        inputs, targets = data[0].to(device), data[1].to(device)
        
        # zero the parameter gradients
        optimizer.zero_grad()
        
        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

torch.save(net.state_dict(), STATE_DICT_PATH)
print('Finished Training')

100%|██████████| 30460/30460 [34:26<00:00, 14.74it/s]
100%|██████████| 30460/30460 [34:11<00:00, 14.85it/s]

Finished Training





In [None]:
correct_count = [[0 for _ in CLASSES] for _ in DISEASES]
total_count = [[0 for _ in CLASSES] for _ in DISEASES]
with torch.no_grad():
    for i, data in enumerate(tqdm(testloader)):
        images, labels = data[0].to(device), data[1].to(device)
        outputs = net(images)
        _, preds = torch.max(outputs, 1)
        for i in range(len(images)):
            for j in range(len(DISEASES)):
                if preds[i][j] == labels[i][j]:
                    correct_count[j][labels[i][j]] += 1
                total_count[j][labels[i][j]] += 1

In [13]:
for i, d in enumerate(DISEASES):
    print("------------")
    print(f"Disease {d}: ")
    for j, c in enumerate(CLASSES):
        print(f"Class {c}: {correct_count[i][j]} / {total_count[i][j]}")

------------
Disease No Finding: 
Class Negative: 26248 / 26248
Class Unsure: 0 / 1353
Class Positive: 0 / 2858
------------
Disease Enlarged Cardiomediastinum: 
Class Negative: 0 / 2567
Class Unsure: 26663 / 26663
Class Positive: 0 / 1229
------------
Disease Cardiomegaly: 
Class Negative: 0 / 1778
Class Unsure: 25205 / 25205
Class Positive: 0 / 3476
------------
Disease Lung Opacity: 
Class Negative: 0 / 773
Class Unsure: 14634 / 14634
Class Positive: 0 / 15052
------------
Disease Lung Lesion: 
Class Negative: 0 / 106
Class Unsure: 29238 / 29238
Class Positive: 0 / 1115
------------
Disease Edema: 
Class Negative: 0 / 2478
Class Unsure: 20173 / 20173
Class Positive: 0 / 7808
------------
Disease Consolidation: 
Class Negative: 0 / 3216
Class Unsure: 25474 / 25474
Class Positive: 0 / 1769
------------
Disease Pneumonia: 
Class Negative: 0 / 250
Class Unsure: 29809 / 29809
Class Positive: 0 / 400
------------
Disease Atelectasis: 
Class Negative: 0 / 69
Class Unsure: 25750 / 25750
Cla

<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=9e63b1a8-baad-4be7-b61d-5ce16070297e' target="_blank">
 </img>
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>