In [1]:
import os

import torch
import torchvision.utils
from matplotlib import pyplot as plt
from torch import nn
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

import d2l.torch

# TensorBoard logger for this experiment; event files go under this run directory.
writer = SummaryWriter(log_dir='runs/astronomical_sources_classifier')

# import from local project
import utils
from FitsImageFolder import FitsImageFolder

In [8]:
print("torch version: ", torch.__version__)

# Seed torch's global RNG (CPU and CUDA) so the dataset split, DataLoader shuffling,
# random augmentations and weight initialisation are reproducible after a kernel restart.
SEED = 42
torch.manual_seed(SEED)

# Root folder of the FITS source cut-outs. Set SOURCES_DATA_ROOT to point elsewhere
# instead of editing this machine-specific default.
src_root_path = os.environ.get(
    "SOURCES_DATA_ROOT",
    os.path.join("/home/duncan/PycharmProjects/MyResearchProject_Duncan", "data/sources"),
)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

torch version:  1.10.1+cu102


In [4]:
# Per-split preprocessing. Training samples are converted to tensors and then randomly
# mirrored horizontally and rotated by up to +/-45 degrees; validation samples are only
# converted to tensors.
# NOTE(review): assumes FitsImageFolder's loader yields something ToTensor accepts
# (e.g. an HxWxC ndarray) — confirm against FitsImageFolder.
data_transforms = dict(
    train=transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(degrees=45),
        ]
    ),
    val=transforms.Compose([transforms.ToTensor()]),
)

In [9]:
# Two views over the same folder: the training split gets random augmentation,
# while validation/test samples are only converted to tensors. (A single dataset
# built with the 'train' transform would evaluate on randomly flipped/rotated images.)
dataset = FitsImageFolder(root=src_root_path, transform=data_transforms['train'])
eval_dataset = FitsImageFolder(root=src_root_path, transform=data_transforms['val'])
# NOTE(review): assumes FitsImageFolder lists files in a deterministic order (as
# torchvision's ImageFolder does), so index k is the same sample in both views.
assert len(eval_dataset) == len(dataset) and eval_dataset.classes == dataset.classes

# 80 / 10 / 10 split; the test split absorbs the rounding remainder.
train_set_size = int(len(dataset) * 0.8)
validation_set_size = int(len(dataset) * 0.1)
test_set_size = len(dataset) - train_set_size - validation_set_size

# Split sample indices with a fixed seed so split membership is identical on every
# run, i.e. test samples never end up in the training split after a restart.
SPLIT_SEED = 42
train_split, validation_split, test_split = random_split(
    range(len(dataset)),
    [train_set_size, validation_set_size, test_set_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
# The split was taken over range(len(dataset)), so each Subset's indices are sample ids.
train_set = torch.utils.data.Subset(dataset, train_split.indices)
validation_set = torch.utils.data.Subset(eval_dataset, validation_split.indices)
test_set = torch.utils.data.Subset(eval_dataset, test_split.indices)

training_loader = DataLoader(
    train_set,
    batch_size=16,
    shuffle=True,
    num_workers=8
)

# Sample order does not change validation accuracy; keep it fixed for reproducible output.
validation_loader = DataLoader(
    validation_set,
    batch_size=16,
    shuffle=False,
    num_workers=8
)

dataloaders = {'train': training_loader, 'val': validation_loader}

print("Full set size:", len(dataset))
print("Train set size: ", train_set_size)
print("Validation set size: ", validation_set_size)
print("Test set size: ", test_set_size)

class_names = dataset.classes
print(class_names)

Full set size: 76144
Train set size:  60915
Validation set size:  7614
Test set size:  7615
['GALAXY', 'QSO', 'STAR']


In [11]:
# Take one training batch and log its first sample to TensorBoard as a false-colour
# image built from channels 0, 1 and 2.
# NOTE(review): the names i/r/g assume that channel order — confirm against how
# FitsImageFolder stacks the photometric bands.
dataiter = iter(training_loader)
images, labels = next(dataiter)
i = images[0][0]
r = images[0][1]
g = images[0][2]

color_fits = torch.stack((i, r, g), dim=0)
# TensorBoard clips float images to [0, 1]; rescale the raw pixel values into that
# range (make_grid works on a copy) so the preview is not saturated.
img_grid = torchvision.utils.make_grid(color_fits, normalize=True)
writer.add_image('an fake rgb source image', img_grid)
writer.flush()  # make the image visible in TensorBoard without waiting for the flush interval

In [12]:
def train_model(model, lr, wd, lr_period, lr_decay, num_epochs=10,
                train_loader=None, valid_loader=None):
    """Train ``model`` with momentum SGD and a step learning-rate schedule.

    Training loss/accuracy are plotted about five times per epoch and validation
    accuracy once per epoch through a d2l ``Animator``. A summary of the last epoch
    and the throughput are printed at the end.

    Args:
        model: network to train. It is wrapped in ``nn.DataParallel`` over the devices
            returned by ``d2l.torch.try_all_gpus()`` and moved to the first of them.
            The wrapper shares ``model``'s parameters, so ``model`` itself is trained.
        lr: initial SGD learning rate (momentum is fixed at 0.9).
        wd: SGD weight decay.
        lr_period: ``StepLR`` step size, in epochs.
        lr_decay: ``StepLR`` gamma; the learning rate is multiplied by it every
            ``lr_period`` epochs.
        num_epochs: number of passes over the training data.
        train_loader: training batches; defaults to the notebook-level ``training_loader``.
        valid_loader: validation batches; defaults to the notebook-level ``validation_loader``.
    """
    if train_loader is None:
        train_loader = training_loader
    if valid_loader is None:
        valid_loader = validation_loader

    devices = d2l.torch.try_all_gpus()
    # Wrap and move the model first, then build the optimizer over the parameters
    # exactly as they will be trained.
    net = nn.DataParallel(model, device_ids=devices).to(devices[0])
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=wd)
    # Per-sample losses; train_batch_ch13 returns a loss summed over the batch, which is
    # divided by the accumulated sample count below.
    criterion = nn.CrossEntropyLoss(reduction="none")
    # Decay LR by a factor of [gamma] every [step_size] epochs
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_period, gamma=lr_decay)

    num_batches, timer = len(train_loader), d2l.torch.Timer()
    # Plot roughly five points per epoch; max() avoids a modulo by zero when the
    # loader has fewer than five batches.
    plot_every = max(1, num_batches // 5)
    legend = ['train loss', 'train acc', 'valid acc']
    animator = d2l.torch.Animator(xlabel='epoch', xlim=[1, num_epochs], legend=legend)

    # (summed loss, correct predictions, samples seen); reset at the start of every epoch.
    metric = d2l.torch.Accumulator(3)
    for epoch in range(num_epochs):
        net.train()
        metric = d2l.torch.Accumulator(3)
        for i, (inputs, labels) in enumerate(train_loader):
            timer.start()
            l, acc = d2l.torch.train_batch_ch13(net=net, X=inputs, y=labels, loss=criterion,
                                                trainer=optimizer, devices=devices)
            metric.add(l, acc, labels.shape[0])
            timer.stop()
            if (i + 1) % plot_every == 0 or i == num_batches - 1:
                animator.add(x=epoch + (i + 1) / num_batches,
                             y=(metric[0] / metric[2], metric[1] / metric[2], None))

        valid_acc = d2l.torch.evaluate_accuracy_gpu(net, valid_loader)
        animator.add(epoch + 1, (None, None, valid_acc))
        scheduler.step()

    # Nothing was trained (num_epochs == 0 or an empty loader): no metrics to report.
    if metric[2] == 0:
        return

    measures = (f'train loss {metric[0] / metric[2]:.6f}, '
                f'train acc {metric[1] / metric[2]:.6f}')

    # Throughput assumes every epoch saw as many samples as the last one.
    print(measures + f'\n{metric[2] * num_epochs / timer.sum():.1f}'
                     f' examples/sec on {str(devices)}')

In [13]:
def test_accuracy(net):
    testloader = DataLoader(test_set, batch_size=8, shuffle=False, num_workers=2)
    confusion_matrix = torch.zeros(3, 3)
    correct = 0
    total = 0
    net.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            for t, p in zip(labels.view(-1), predicted.view(-1)):
                confusion_matrix[t.long(), p.long()] += 1

    print("confusion_matrix: ", confusion_matrix)
    print(confusion_matrix.diag() / confusion_matrix.sum(1))
    print("accuracy over all: ", correct / total)
    return correct / total

In [14]:
def visualize_model(model, validation_dataloader, num_images=9):
    """Plot ``num_images`` samples from a loader, each titled with the model's prediction.

    Images are laid out three per row on a new figure; drawing of each image is
    delegated to ``utils.showImages``. The model runs in eval mode under
    ``torch.no_grad()``, and its previous training/eval mode is restored on exit.
    Uses the notebook-level ``device`` and ``class_names``.

    Args:
        model: classifier whose output has per-class scores along dim 1.
        validation_dataloader: yields ``(inputs, labels)`` batches; labels are unused.
        num_images: how many samples to show; fewer are shown if the loader runs out.
    """
    if num_images <= 0:
        return

    n_cols = 3
    # Ceiling division so a count that is not a multiple of 3 (or is below 3) still fits.
    n_rows = -(-num_images // n_cols)

    was_training = model.training
    model.eval()
    images_so_far = 0
    plt.figure()

    try:
        with torch.no_grad():
            for inputs, _labels in validation_dataloader:
                inputs = inputs.to(device)
                preds = model(inputs).argmax(dim=1)
                cpu_inputs = inputs.cpu()  # one transfer per batch, not per image

                for j in range(inputs.size(0)):
                    images_so_far += 1
                    ax = plt.subplot(n_rows, n_cols, images_so_far)
                    ax.set_title('predicted: {}'.format(class_names[preds[j]]))
                    utils.showImages(cpu_inputs[j])

                    if images_so_far == num_images:
                        return
    finally:
        model.train(mode=was_training)


In [None]:
# ResNet-18 over 5-channel inputs with one output per dataset class; deriving the
# class count from the dataset keeps the head in sync if the class folders change.
model = d2l.torch.resnet18(num_classes=len(class_names), in_channels=5)
torch.cuda.empty_cache()  # release cached GPU memory left over from earlier cells
train_model(model, lr=2e-4, wd=5e-4, lr_period=4, lr_decay=0.9, num_epochs=10)