<a href="https://colab.research.google.com/github/liangyuRain/chest-xray-pneumonia/blob/master/Project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import numpy as np
import os

import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms

import torch.utils.data

import matplotlib as plt


In [9]:
def train(model, device, train_loader, optimizer, epoch, log_interval):
    model.train()
    losses = []
    for batch_idx, (data, label) in enumerate(train_loader):
        data, label = data.to(device), label.to(device)
        optimizer.zero_grad()
        
        output = model(data)
        loss = model.loss(output, label)
        losses.append(loss.item())
        loss.backward()
        
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
    return np.mean(losses)


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, label in test_loader:
            data, label = data.to(device), label.to(device)
            output = model(data)
            test_loss += model.testLoss(output, label).item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(label.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_accuracy = 100. * correct / len(test_loader.dataset)
    
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return test_loss, test_accuracy


In [12]:
LR = 0.1
WEIGHT_DECAY = 0.0001
MOMENTUM = 0.9
EPOCHS = 100
BATCH_SIZE = 32
USE_CUDA = True
SEED = 0
LOG_INTERVAL = 100


In [4]:
use_cuda = USE_CUDA and torch.cuda.is_available()
torch.manual_seed(SEED)
device = torch.device("cuda" if use_cuda else "cpu")
print('Using device', device)

import multiprocessing
print('num cpus:', multiprocessing.cpu_count())
kwargs = {'num_workers': multiprocessing.cpu_count(),
          'pin_memory': True} if use_cuda else {}


Using device cuda
num cpus: 8


In [5]:
model = torchvision.models.resnet18(pretrained=True)

# avoid training pretrained model
for param in model.parameters():
    param.requires_grad = False

# set categories to 2
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.fc.parameters(), 
                      lr=LR, 
                      momentum=MOMENTUM, 
                      weight_decay=WEIGHT_DECAY)

model.loss = nn.CrossEntropyLoss()
model.testLoss = nn.CrossEntropyLoss(reduction='sum')


In [6]:
INPUT_SIZE = (224, 224)


class ImageLoader(object):

    def __init__(self, batchSize):
        super(ImageLoader, self).__init__()
        transform_train = transforms.Compose([
            transforms.Resize(size=INPUT_SIZE),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        transform_test = transforms.Compose([
            transforms.Resize(size=INPUT_SIZE),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        train_dataset = datasets.ImageFolder(root='chest_xray_dataset/train',
                                             transform=transform_train)
        self.trainloader = torch.utils.data.DataLoader(train_dataset,
                                                       batch_size=batchSize,
                                                       shuffle=True,
                                                       num_workers=2)

        test_dataset = datasets.ImageFolder(root='chest_xray_dataset/test',
                                            transform=transform_test)
        self.testloader = torch.utils.data.DataLoader(test_dataset,
                                                      batch_size=batchSize,
                                                      shuffle=False,
                                                      num_workers=2)

        self.classes = ('normal', 'pneumonia')


In [7]:
def plot(data, title, xlab, ylab):
    x_val = [x[0] for x in data]
    y_val = [x[1] for x in data]
    plt.plot(x_val, y_val)
    plt.title(title)
    plt.xlabel(xlab)
    plt.ylabel(ylab)
    plt.show()


In [13]:
loader = ImageLoader(batchSize=BATCH_SIZE)
train_losses = []
test_losses = [] 
test_accuracies = []

try:
    for epoch in range(EPOCHS + 1):
        train_loss = train(model, device, loader.trainloader, optimizer, epoch, LOG_INTERVAL)
        test_loss, test_accuracy = test(model, device, loader.testloader)
        train_losses.append((epoch, train_loss))
        test_losses.append((epoch, test_loss))
        test_accuracies.append((epoch, test_accuracy))
except KeyboardInterrupt as ke:
    print('Interrupted')
except:
    import traceback
    traceback.print_exc()
finally:
    plot(train_losses, 'Train Losses', 'Epochs', 'Loss')
    plot(test_losses, 'Test Losses', 'Epochs', 'Loss')
    plot(test_accuracies, 'Test Accuracies', 'Epochs', 'Accuracy')







Test set: Average loss: 39.8822, Accuracy: 539/624 (86%)








Test set: Average loss: 39.6694, Accuracy: 545/624 (87%)








Test set: Average loss: 47.0202, Accuracy: 528/624 (85%)








Test set: Average loss: 25.7603, Accuracy: 537/624 (86%)








Test set: Average loss: 28.5177, Accuracy: 544/624 (87%)








Test set: Average loss: 26.3838, Accuracy: 542/624 (87%)








Test set: Average loss: 42.1078, Accuracy: 521/624 (83%)








Test set: Average loss: 44.7245, Accuracy: 505/624 (81%)








Test set: Average loss: 40.7128, Accuracy: 549/624 (88%)








Test set: Average loss: 56.2701, Accuracy: 507/624 (81%)








Test set: Average loss: 27.1260, Accuracy: 544/624 (87%)








Test set: Average loss: 30.9603, Accuracy: 526/624 (84%)








Test set: Average loss: 139.1444, Accuracy: 427/624 (68%)








Test set: Average loss: 28.0086, Accuracy: 541/624 (87%)








Test set: Average loss: 137.9536, Accuracy: 425/624 (68%)








Test set: Average loss: 59.2031, Accuracy: 484/624 (78%)








Test set: Average loss: 46.5067, Accuracy: 512/624 (82%)








Test set: Average loss: 55.2816, Accuracy: 497/624 (80%)








Test set: Average loss: 34.9512, Accuracy: 546/624 (88%)








Test set: Average loss: 131.1398, Accuracy: 441/624 (71%)








Test set: Average loss: 39.9204, Accuracy: 524/624 (84%)








Test set: Average loss: 24.5827, Accuracy: 540/624 (87%)








Test set: Average loss: 21.8341, Accuracy: 547/624 (88%)








Test set: Average loss: 42.9477, Accuracy: 510/624 (82%)








Test set: Average loss: 20.0852, Accuracy: 548/624 (88%)








Test set: Average loss: 14.5494, Accuracy: 558/624 (89%)








Test set: Average loss: 29.6680, Accuracy: 529/624 (85%)








Test set: Average loss: 55.6908, Accuracy: 497/624 (80%)








Test set: Average loss: 45.6431, Accuracy: 498/624 (80%)








Test set: Average loss: 25.9798, Accuracy: 547/624 (88%)








Test set: Average loss: 60.0829, Accuracy: 483/624 (77%)








Test set: Average loss: 73.0796, Accuracy: 477/624 (76%)








Test set: Average loss: 99.1867, Accuracy: 459/624 (74%)








Test set: Average loss: 148.1672, Accuracy: 415/624 (67%)








Test set: Average loss: 45.2439, Accuracy: 488/624 (78%)








Test set: Average loss: 90.8027, Accuracy: 451/624 (72%)








Test set: Average loss: 25.9665, Accuracy: 544/624 (87%)








Test set: Average loss: 33.3870, Accuracy: 532/624 (85%)







Interrupted


AttributeError: module 'matplotlib' has no attribute 'plot'