In [16]:
import os
import h5py

import torch
from torch import nn
import torchvision
from torchvision.transforms import v2

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from pathlib import Path

In [2]:
#!pip install astronn equinox einops

from astroNN.datasets import load_galaxy10
from astroNN.datasets.galaxy10 import galaxy10cls_lookup, galaxy10_confusion

2024-05-23 20:53:26.125299: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# Just to make the session somewhat determinate
def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
set_seed(0)

In [4]:
print(f"CUDA is available: {torch.cuda.is_available()}")

if not torch.cuda.is_available():
    print("If you want, you might want to switch to a GPU-accelerated session!")
    device = torch.device('cpu')
else:
    device = torch.device('cuda')

CUDA is available: False
If you want, you might want to switch to a GPU-accelerated session!


In [7]:
# To load images and labels (will download automatically at the first time)
# First time downloading location will be ~/.astroNN/datasets/
load_galaxy10()

/home/tristan/.astroNN/datasets/Galaxy10_DECals.h5 was found!


: 

In [None]:
useful_images = images[labels == 5]

In [None]:
train_split = 0.6
valid_split = 0.2

full_dataset = useful_images

test_split = 1 - train_split - valid_split

train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(
        full_dataset, [train_split, valid_split, test_split]
)

In [None]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=128, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False)

In [39]:
import fastprogress


def train(dataloader, optimizer, model, loss_fn, device, master_bar,
          transform_common=None, transform_input=None):
    """Run one training epoch.

    Args:
        dataloader (DataLoader): Torch DataLoader object to load data
        optimizer: Torch optimizer object
        model (nn.Module): Torch model to train
        loss_fn: Torch loss function
        device (torch.device): Torch device to use for training
        master_bar (fastprogress.master_bar): Will be iterated over for each
            epoch to draw batches and display training progress
        transform_common (function): Transform to apply to input and target
        transform_input (function): Transform to apply to the input for de-noising.
            By default, no transform is carried out

    Returns:
        float: Mean loss of this epoch
    """
    epoch_loss = []

    for x, _ in fastprogress.progress_bar(dataloader, parent=master_bar):
        optimizer.zero_grad()
        model.train()

        x = transform_common(x) if transform_common else x
        x_inp = transform_input(x) if transform_input else x

        # Forward pass
        x = x.to(device)
        x_inp = x_inp.to(device)
        x_hat, mu, logvar = model(x_inp)

        # Compute loss
        loss = loss_fn(x_hat, x, mu, logvar)

        # Backward pass
        loss.backward()
        optimizer.step()

        # For plotting the train loss, save it for each sample
        epoch_loss.append(loss.item())
        master_bar.child.comment = f"Train Loss: {epoch_loss[-1]:.3f}"

    # Return the mean loss and the accuracy of this epoch
    return np.mean(epoch_loss)


def validate(dataloader, model, loss_fn, device, master_bar,
             transform_common=None, transform_input=None):
    """Compute loss on validation set.

    Args:
        dataloader (DataLoader): Torch DataLoader object to load data
        model (nn.Module): Torch model to train
        loss_fn: Torch loss function
        device (torch.device): Torch device to use for training
        master_bar (fastprogress.master_bar): Will be iterated over to draw
            batches and show validation progress
        transform_common (function): Transform to apply to input and target
        transform_input (function): Transform to apply to the input for de-noising.
            By default, no transform is carried out

    Returns:
        float: Mean loss on validation set
    """
    epoch_loss = []

    model.eval()
    with torch.no_grad():
        for x, _ in fastprogress.progress_bar(dataloader, parent=master_bar):
            x = transform_common(x) if transform_common else x

            x_inp = transform_input(x) if transform_input else x

            # make a prediction on test set
            x = x.to(device)
            x_inp = x_inp.to(device)
            x_hat, mu, logvar = model(x_inp)

            # Compute loss
            loss = loss_fn(x_hat, x, mu, logvar)

            # For plotting the train loss, save it for each sample
            epoch_loss.append(loss.item())
            master_bar.child.comment = f"Valid. Loss: {epoch_loss[-1]:.3f}"

    # Return the mean loss, the accuracy and the confusion matrix
    return np.mean(epoch_loss)




def train_model(model, optimizer, loss_function, device, num_epochs,
                train_dataloader, valid_dataloader,
                transform_common=None, transform_input=None):
    """Run model training.

    Args:
        model (nn.Module): Torch model to train
        optimizer: Torch optimizer object
        loss_fn: Torch loss function for training
        device (torch.device): Torch device to use for training
        num_epochs (int): Max. number of epochs to train
        train_dataloader (DataLoader): Torch DataLoader object to load the
            training data
        valid_dataloader (DataLoader): Torch DataLoader object to load the
            test data
        transform_common (function): Transform to apply to input and target
        transform_input (function): Transform to apply to the input for de-noising.
            By default, no transform is carried out

    Returns:
        list, list: Return list of train losses, test losses.
    """
    master_bar = fastprogress.master_bar(range(num_epochs))
    epoch_list, train_losses, valid_losses = [], [], []

    master_bar.names = ["Train", "Valid."]

    for epoch in master_bar:
        # Train the model
        epoch_train_loss = train(train_dataloader, optimizer, model, loss_function, device, master_bar, transform_common, transform_input)
        # Validate the model
        epoch_valid_loss = validate(valid_dataloader, model, loss_function, device, master_bar, transform_common, transform_input)

        # Save loss and acc for plotting
        epoch_list.append(epoch + 1)
        train_losses.append(epoch_train_loss)
        valid_losses.append(epoch_valid_loss)

        graphs = [[epoch_list, train_losses], [epoch_list, valid_losses]]
        x_bounds = [1, num_epochs]

        master_bar.write(
            f"Epoch {epoch + 1}, "
            f"avg. train loss: {epoch_train_loss:.3f}, "
            f"avg. valid. loss: {epoch_valid_loss:.3f}"
        )
        master_bar.update_graph(graphs, x_bounds)


    return train_losses, valid_losses

In [5]:
class Autoencoder(nn.Module):

    def __init__(self, image_size=64,num_channels=3, latent_dims=128, num_filters=32, do_sampling=False):
        super(Autoencoder, self).__init__()

        self.latent_dims  = latent_dims
        self.image_size   = image_size
        self.num_channels = num_channels
        self.num_filters  = num_filters
        self.do_sampling  = do_sampling

        # Encoder
        self.conv_encoder = nn.Sequential(
            # TODO: Build the convolutional layers (torch.nn.Conv2d) here
            torch.nn.Conv2d(self.num_channels, self.num_channels, (4,4), 2, 1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(self.num_channels,self.num_channels, (4,4), 2, 1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(self.num_channels,self.num_channels, (4,4), 2, 1),
            torch.nn.ReLU(),
        )

        # Linear Encoder
        # TODO: Match the dimensionality of the first and last layer here!
        self.fc_lin_down = nn.Linear(64*self.num_filters, 8 * self.num_filters)
        self.fc_mu       = nn.Linear(8 * self.num_filters, self.latent_dims)
        self.fc_logvar   = nn.Linear(self.latent_dims, self.latent_dims)
        self.fc_z        = nn.Linear(self.latent_dims, 8 * self.num_filters)
        self.fc_lin_up   = nn.Linear(8 * self.num_filters, 64*self.num_filters)

        # Decoder
        self.conv_decoder = nn.Sequential(
            # TODO: Implement the reverse of the encoder here using torch.nn.ConvTranspose2d layers
            # The last activation here should be a sigmoid to keep the pixel values clipped in [0, 1)
            torch.nn.Conv2d(self.num_channels, self.num_channels, (4,4), 2, 1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(self.num_channels,self.num_channels, (4,4), 2, 1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(self.num_channels,self.num_channels, (4,4), 2, 1),
            nn.Sigmoid(),
        )

    def encode(self, x):
        ''' Encoder: output is (mean, log(variance))'''
        x       = self.conv_encoder(x)
        # Here, we resize the convolutional output appropriately for a linear layer
        # TODO: Fill in the correct dimensionality for the reordering
        x       = x.view(-1, self.num_filters * 8 * 8)
        x       = self.fc_lin_down(x)
        x       = nn.functional.relu(x)
        mu      = self.fc_mu(x)
        logvar  = self.fc_logvar(x)
        return mu, logvar

    def sample(self, mu, logvar):
        ''' Sample from Gaussian with mean `mu` and SD `sqrt(exp(logvarz))`'''
        # Only use the full mean/stddev procedure if we want to later do sampling
        # And only reparametrise if we are in training mode
        if self.training and self.do_sampling:
            std = torch.exp(logvar * 0.5)
            eps = torch.randn_like(std)
            sample = mu + (eps * std)
            return sample
        else:
            return mu

    def decode(self, z):
        '''Decoder: produces reconstruction from sample of latent z'''
        z = self.fc_z(z)
        z = nn.functional.relu(z)
        z = self.fc_lin_up(z)
        z = nn.functional.relu(z)
        # TODO: Fill in the correct dimensionality for the reordering here again
        z = z.view(-1, self.num_filters, 8, 8)
        z = self.conv_decoder(z)
        return z

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.sample(mu, logvar)
        x_hat = self.decode(z)
        if self.do_sampling:
            return x_hat, mu, logvar
        else:
            return x_hat, None, None

In [53]:
def autoencoder_loss(recon_x, x, mu=None, logvar=None):
    mse_loss = torch.nn.functional.mse_loss(recon_x, x, reduction='sum') / x.size(dim=0)

    if mu is not None and logvar is not None:
        raise NotImplementedError("Looks like you still need to implement the KL divergence loss!")
    else:
        return mse_loss

In [37]:
model = Autoencoder()
torch.save(model,Path('model1'))
model_test = torch.load(Path('model1'))


In [38]:
model_test

Autoencoder(
  (conv_encoder): Sequential(
    (0): Conv2d(3, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(3, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(3, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (5): ReLU()
  )
  (fc_lin_down): Linear(in_features=2048, out_features=256, bias=True)
  (fc_mu): Linear(in_features=256, out_features=128, bias=True)
  (fc_logvar): Linear(in_features=128, out_features=128, bias=True)
  (fc_z): Linear(in_features=128, out_features=256, bias=True)
  (fc_lin_up): Linear(in_features=256, out_features=2048, bias=True)
  (conv_decoder): Sequential(
    (0): Conv2d(3, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(3, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(3, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (5): Sigmoid()
  )
)