In [2]:
import math
import os
import h5py

import torch
import torchvision
from torchvision.transforms import v2

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline


In [11]:
!pip install astronn equinox einops
from astroNN.datasets import load_galaxy10
from astroNN.datasets.galaxy10 import galaxy10cls_lookup, galaxy10_confusion

from pathlib import Path



In [4]:
import fastprogress

CUDA is available: False
If you want, you might want to switch to a GPU-accelerated session!


In [5]:
# The Galaxy10 file is downloaded on first use and cached under ~/.astroNN/datasets/.
images, labels = load_galaxy10()

# Select only class 5 ('Barred Spiral') via a boolean mask over the label array
# (assumes `images` and `labels` are index-aligned arrays — as load_galaxy10 returns them).
is_barred_spiral = labels == 5
imagesBarredSpiral = images[is_barred_spiral]

/home/simon/.astroNN/datasets/Galaxy10_DECals.h5 was found!


In [6]:
train_split = 0.6
valid_split = 0.2

test_split = 1 - train_split - valid_split

# Seed the split so train/valid/test membership is reproducible across kernel restarts.
split_seed = 42
train_datasetBS, valid_datasetBS, test_datasetBS = torch.utils.data.random_split(
    imagesBarredSpiral,
    [train_split, valid_split, test_split],
    generator=torch.Generator().manual_seed(split_seed),
)
# TODO: the replacement loop further down uses `train_datasetMW`; restore this split
# (with the same seeded generator) once `imagesMilky` is loaded.
#train_datasetMW, valid_datasetMW, test_datasetMW = torch.utils.data.random_split(
#        imagesMilky, [train_split, valid_split, test_split]
#)

In [7]:
fig_size = 4

def show_images(images, labels=None, num_columns=5, num_rows=3):
    """Plot images in a ``num_rows`` x ``num_columns`` grid.

    Args:
        images: iterable of channel-first (C, H, W) arrays/tensors; only the
            first ``num_rows * num_columns`` are drawn.
            NOTE(review): the raw Galaxy10 arrays are presumably (H, W, C) —
            confirm and transpose before passing them here.
        labels: optional iterable of titles, one per image. ``None`` entries
            get no title; every other value (including 0) is shown.
        num_columns: number of grid columns.
        num_rows: number of grid rows.
    """
    fig, axs = plt.subplots(nrows=num_rows, ncols=num_columns,
                            figsize=(fig_size * num_columns, fig_size * num_rows),
                            subplot_kw={'xticks': [], 'yticks': []},
                            # Always get a 2-D array of Axes so `.flat` also works for a 1x1 grid.
                            squeeze=False)

    labels = [None] * num_columns * num_rows if labels is None else labels

    for ax, image, label in zip(axs.flat, images, labels):
        # imshow expects HWC, so move the channel axis from first to last.
        image_hwc = np.transpose(image, [1, 2, 0])
        if image.shape[0] == 1:
            # Single channel: drop the channel axis and render in grayscale.
            ax.imshow(image_hwc[..., 0], cmap='gray')
        else:
            ax.imshow(image_hwc)
        # Compare against None so falsy labels such as class 0 still get a title.
        if label is not None:
            ax.set_title(f"{label}")

    plt.tight_layout()
    plt.show()

In [8]:
removenumber = 3*10**2
addnumber = removenumber
epochs = 20
# NOTE(review): 2045 is presumably the size of the Barred Spiral subset — confirm it
# against len(imagesBarredSpiral) instead of hard-coding it.
# Rounds needed to cycle through that many images when `removenumber` are swapped per
# round: ceil(2045 / removenumber). math.ceil returns an int (range() rejects floats),
# and unlike `x + (1 - x % 1)` it does not over-count by one when x is whole.
roundsofReplace = math.ceil(2045 / removenumber)

# Per-split replacement counts, rounded to ints because they are used as
# counts / index offsets.
removenumbertrain = round(removenumber * train_split)
addnumbertrain = round(addnumber * train_split)

removenumbervalid = round(removenumber * valid_split)
addnumbervalid = round(addnumber * valid_split)

In [9]:
def autoencoder_loss(recon_x, x, mu=None, logvar=None):
    """Reconstruction loss, plus the KL term when latent statistics are given.

    Args:
        recon_x: reconstructed batch, same shape as ``x``.
        x: input batch; dimension 0 is the batch dimension.
        mu: optional latent means (VAE encoder output).
        logvar: optional latent log-variances, same shape as ``mu``.

    Returns:
        Scalar tensor. The summed squared reconstruction error divided by the
        batch size. When both ``mu`` and ``logvar`` are given, the KL divergence
        KL(N(mu, exp(logvar)) || N(0, I)) is added, likewise summed and divided
        by the batch size.
    """
    batch_size = x.size(dim=0)
    mse_loss = torch.nn.functional.mse_loss(recon_x, x, reduction='sum') / batch_size

    # Plain autoencoder (or incomplete latent statistics): reconstruction term only.
    if mu is None or logvar is None:
        return mse_loss

    # Closed-form KL between a diagonal Gaussian and the standard normal
    # (Kingma & Welling, "Auto-Encoding Variational Bayes", Appendix B).
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch_size
    return mse_loss + kl_loss

In [None]:
learning_rate = 2e-4
# NOTE: `model` has to exist before this cell runs. It is created by the
# `torch.load` cell below, so run that cell first (or move this cell after it).
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [12]:
# This unpickles a whole nn.Module, so the `Autoencoder` class it was saved from must be
# defined (or imported) in this notebook before this cell runs; otherwise unpickling fails
# with "AttributeError: Can't get attribute 'Autoencoder' on <module '__main__'>".
# weights_only=False is required for full-module pickles (torch >= 2.6 defaults to True).
# It allows arbitrary code execution during unpickling, so only load trusted files.
# NOTE(review): this session reports CUDA unavailable; if the checkpoint was saved on a GPU,
# pass map_location as well.
model = torch.load(Path("model1"), weights_only=False)

AttributeError: Can't get attribute 'Autoencoder' on <module '__main__'>

In [None]:
def train(data, optimizer, model, loss_fn, device, master_bar):
    """Run one training epoch.

    Args:
        data: iterable of batches. Each batch is either an input tensor or a
            tuple/list whose first element is the input tensor; remaining
            elements (e.g. labels) are ignored. The input is also the
            reconstruction target.
        optimizer: Torch optimizer object
        model (nn.Module): Torch model to train; must return ``(x_hat, mu, logvar)``
        loss_fn: Torch loss function, called as ``loss_fn(x_hat, x, mu, logvar)``
        device (torch.device): Torch device to use for training
        master_bar (fastprogress.master_bar): Parent progress bar used to draw
            batch progress and display the current training loss

    Returns:
        float: Mean per-batch loss of this epoch
    """
    epoch_loss = []
    # Switch to training mode once per epoch rather than on every batch.
    model.train()

    for batch in fastprogress.progress_bar(data, parent=master_bar):
        # Accept bare tensor batches as well as (input, label) pairs. Unpacking a bare
        # tensor with `x, _ = batch` would instead split it along the batch axis.
        x = batch[0] if isinstance(batch, (tuple, list)) else batch
        x = x.to(device)

        optimizer.zero_grad()

        # Forward pass: the model reconstructs its own input.
        x_hat, mu, logvar = model(x)
        loss = loss_fn(x_hat, x, mu, logvar)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Keep every batch loss for the epoch mean and show the latest one.
        epoch_loss.append(loss.item())
        master_bar.child.comment = f"Train Loss: {epoch_loss[-1]:.3f}"

    return np.mean(epoch_loss)


def validate(data, model, loss_fn, device, master_bar):
    """Compute the mean loss on a validation set without updating the model.

    Args:
        data: iterable of batches. Each batch is either an input tensor or a
            tuple/list whose first element is the input tensor; remaining
            elements (e.g. labels) are ignored.
        model (nn.Module): Torch model to evaluate; must return ``(x_hat, mu, logvar)``
        loss_fn: Torch loss function, called as ``loss_fn(x_hat, x, mu, logvar)``
        device (torch.device): Torch device to run on
        master_bar (fastprogress.master_bar): Parent progress bar used to draw
            batch progress and display the current validation loss

    Returns:
        float: Mean per-batch loss on the validation set
    """
    epoch_loss = []

    # Eval mode (e.g. dropout off, batch-norm running stats) and no autograd graph.
    model.eval()
    with torch.no_grad():
        for batch in fastprogress.progress_bar(data, parent=master_bar):
            # Same batch convention as `train`: bare tensor or (input, label) pair.
            x = batch[0] if isinstance(batch, (tuple, list)) else batch
            x = x.to(device)

            # Reconstruct the validation batch and score it against itself.
            x_hat, mu, logvar = model(x)
            loss = loss_fn(x_hat, x, mu, logvar)

            # Keep every batch loss for the mean and show the latest one.
            epoch_loss.append(loss.item())
            master_bar.child.comment = f"Valid. Loss: {epoch_loss[-1]:.3f}"

    return np.mean(epoch_loss)

In [None]:
def train_model(model, optimizer, loss_function, device, num_epochs,
                train_data, valid_data):
    """Run model training for a fixed number of epochs.

    Each epoch runs one ``train`` pass over ``train_data``, then one
    ``validate`` pass over ``valid_data``. It writes both mean losses to the
    progress bar and redraws the live loss graph.

    Args:
        model (nn.Module): Torch model to train
        optimizer: Torch optimizer object
        loss_function: Torch loss function, forwarded to ``train`` and ``validate``
        device (torch.device): Torch device to use for training
        num_epochs (int): Number of epochs to train (no early stopping)
        train_data: iterable of training batches
        valid_data: iterable of validation batches

    Returns:
        list, list: Mean train loss per epoch, mean validation loss per epoch.
    """
    master_bar = fastprogress.master_bar(range(num_epochs))
    master_bar.names = ["Train", "Valid."]

    epoch_list, train_losses, valid_losses = [], [], []
    # The graph's x-range does not change between epochs, so compute it once.
    x_bounds = [1, num_epochs]

    for epoch in master_bar:
        epoch_train_loss = train(train_data, optimizer, model, loss_function, device, master_bar)
        epoch_valid_loss = validate(valid_data, model, loss_function, device, master_bar)

        # Record this epoch's losses for the live plot and the return value.
        epoch_list.append(epoch + 1)
        train_losses.append(epoch_train_loss)
        valid_losses.append(epoch_valid_loss)

        master_bar.write(
            f"Epoch {epoch + 1}, "
            f"avg. train loss: {epoch_train_loss:.3f}, "
            f"avg. valid. loss: {epoch_valid_loss:.3f}"
        )
        graphs = [[epoch_list, train_losses], [epoch_list, valid_losses]]
        master_bar.update_graph(graphs, x_bounds)

    return train_losses, valid_losses

In [None]:
def mixdata(dataBS, dataMW, quoteBS, quoteMW):
    """Combine two datasets according to the given shares — not implemented yet.

    NOTE(review): presumably ``dataBS`` / ``dataMW`` are the Barred Spiral and
    Milky Way datasets and ``quoteBS`` / ``quoteMW`` their respective shares of
    the mixed result. TODO: confirm the intended semantics and implement.

    Raises:
        NotImplementedError: always, until the mixing logic is written.
    """
    raise NotImplementedError("mixdata is not implemented yet")

In [None]:
# Repeatedly train; each round is presumably meant to swap part of the Barred Spiral data
# for Milky Way data (TODO: not implemented yet).
# The dataset is bound to `train_data` rather than `train`: rebinding `train` would shadow
# the `train()` function that `train_model` calls ("'Subset' object is not callable").
train_data = train_datasetBS
valid_data = valid_datasetBS
# round() makes this work whether `roundsofReplace` is an int or an integer-valued float.
for i in range(round(roundsofReplace)):
    train_model(model, optimizer, autoencoder_loss, device, epochs, train_data, valid_data)
    # NOTE(review): `train_datasetMW` only appears in commented-out code, so these lines
    # raise NameError until that split is restored. lowtrain/hightrain are not used yet.
    lowtrain = round(i * addnumbertrain) % len(train_datasetMW)
    hightrain = round((i + 1) * addnumbertrain) % len(train_datasetMW)
