<a href="https://colab.research.google.com/github/liangyuRain/chest-xray-pneumonia/blob/master/Project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms

In [0]:
def train(model, device, train_loader, optimizer, epoch, log_interval):
    """Run one optimisation pass over ``train_loader``.

    Every batch is moved to ``device``, pushed through ``model``, scored with
    ``model.loss`` and followed by a single ``optimizer`` step.  A progress
    line is printed for each batch whose index is a multiple of
    ``log_interval``; ``epoch`` is only used in that message.

    Returns:
        The NumPy mean of the per-batch loss values gathered during the pass.
    """
    model.train()
    batch_losses = []
    for step, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Gradients accumulate across backward() calls, so reset them first.
        optimizer.zero_grad()
        batch_loss = model.loss(model(inputs), targets)
        batch_losses.append(batch_loss.item())
        batch_loss.backward()
        optimizer.step()

        if step % log_interval != 0:
            continue

        # Progress is reported in samples (step * size of the current batch)
        # and as a percentage of the number of batches.
        seen = step * len(inputs)
        total = len(train_loader.dataset)
        percent = 100. * step / len(train_loader)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, seen, total, percent, batch_loss.item()))
    return np.mean(batch_losses)

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, label in test_loader:
            data, label = data.to(device), label.to(device)
            output = model(data)
            test_loss += model.loss(output, label, reduction='sum').item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(label.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_accuracy = 100. * correct / (len(test_loader.dataset) * test_loader.dataset.sequence_length)
    
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return test_loss, test_accuracy

In [0]:
# Number of the last epoch the training loop runs (its range ends at EPOCHS + 1).
EPOCHS = 5000

# Settings handed to optim.SGD when the optimizer is built.
# NOTE(review): a learning rate of 0.9 looks unusually large for SGD with
# momentum -- confirm it is intentional.
LR = 0.9
MOMENTUM = 0.9
WEIGHT_DECAY = 1e-4

In [0]:
# Load a ResNet-18 with pretrained weights.  It is bound to `model_conv`
# because every statement below refers to the network by that name; the
# original bound it to `model` only, so the next line raised NameError.
model_conv = torchvision.models.resnet18(pretrained=True)

# Freeze the pretrained layers so only the new head below is trained.
for param in model_conv.parameters():
    param.requires_grad = False

# Replace the final fully connected layer with a fresh 2-class head.  A newly
# constructed nn.Linear has requires_grad=True, so it stays trainable.
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

# train() and test() call `model.loss(output, label)` and
# `model.loss(output, label, reduction='sum')`.  A torchvision ResNet has no
# `loss` attribute, so attach cross-entropy under that name.
model_conv.loss = F.cross_entropy

# NOTE(review): `device` is assigned in a later cell of this notebook; that
# cell has to run before this one.
model_conv = model_conv.to(device)

# The training loop passes the network as `model`; keep both names pointing
# at the same object.
model = model_conv

criterion = nn.CrossEntropyLoss()
# Only the parameters of the new head are given to the optimizer.
optimizer = optim.SGD(model_conv.fc.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)

In [0]:
# Use the GPU only when it is both requested (USE_CUDA) and actually available.
use_cuda = USE_CUDA and torch.cuda.is_available()
torch.manual_seed(SEED)

if use_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print('Using device', device)

import multiprocessing
print('num cpus:', multiprocessing.cpu_count())

# Extra keyword arguments (presumably for the DataLoaders -- confirm at the
# call site): one worker per CPU plus pinned memory on CUDA, nothing on CPU.
kwargs = {}
if use_cuda:
    kwargs.update(num_workers=multiprocessing.cpu_count(), pin_memory=True)

try:
    for epoch in range(start_epoch, EPOCHS + 1):
        train(model, device, train_loader, optimizer, epoch, LOG_INTERVAL)
        test(model, device, test_loader)
except KeyboardInterrupt as ke:
    print('Interrupted')
except:
    import traceback
    traceback.print_exc()
finally: