# Test for the CNN Model

This notebook will be used to create and train a dummy version of the model that will be used for classifying the transit light curves.

In [1]:
# conda install pytorch torchvision -c pytorch


# Imports
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from bokeh.plotting import figure, show
from bokeh.io import output_notebook

import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from torch import nn

In [2]:
# Select the device the model will run on: CUDA when PyTorch can see a GPU,
# the CPU otherwise
gpu_available = torch.cuda.is_available()
if gpu_available:
    device_name = "cuda"
else:
    device_name = "cpu"
device = torch.device(device_name)

print(f"Using {device_name}")

Using cuda


We will need to create a dummy dataset that resembles our final data.


In [3]:
def normal(dimensions, mean=0., stddev=1.):
    """Draw samples from a normal distribution and wrap them in a torch tensor.

    The samples come from NumPy's global random state (``np.random.randn``).
    The scaling and shifting are evaluated by NumPy before the conversion, so
    array-valued ``mean`` / ``stddev`` broadcast against the sampled array.

    Args:
        dimensions (tuple): shape of the output tensor
        mean (float, optional): mean of the sampled normal distribution
        stddev (float, optional): standard deviation of the normal distribution

    Returns:
        torch.Tensor: PyTorch multi-dimensional matrix with the samples

    """
    # Standard normal draws (zero mean, unit deviation) with the given shape
    standard_samples = np.random.randn(*dimensions)

    # Rescale to the requested deviation and shift to the requested mean
    rescaled_samples = stddev * standard_samples + mean

    return torch.from_numpy(rescaled_samples)

def uniform(dimensions, mini=0., maxi=1.):
    """Draw samples from a uniform distribution and wrap them in a torch tensor.

    The samples come from NumPy's global random state (``np.random.random``)
    and are rescaled from the unit interval to the requested range.

    Args:
        dimensions (tuple): shape of the output tensor
        mini (float, optional): lower limit of the sampled uniform distribution
        maxi (float, optional): upper limit of the sampled uniform distribution

    Returns:
        torch.Tensor: PyTorch multi-dimensional matrix with the samples

    """
    # Draws in the unit interval with the given shape
    unit_samples = np.random.random(dimensions)

    # Stretch to the width of the target range and move to its lower limit
    range_width = maxi - mini
    return torch.from_numpy(unit_samples * range_width + mini)

In [4]:
def transit_model(transit_duration, contact_ratio, time):
    """Models a transit lightcurve for the given parameters and normalized time.

    Normalization means that the curve outside the transit is 0 and the depth
    of the transit is -1. The transit is centered at time 0.5. Ingress and
    egress are half-cosine ramps joining both levels.

    The parameters and ``time`` are combined with regular broadcasting, so
    per-lightcurve parameter tensors can be evaluated against a shared time
    tensor.

    Args:
        transit_duration (float or torch.Tensor): normalized duration of the
            transit (second to third contact) with respect to the period
        contact_ratio (float or torch.Tensor): ratio between the duration of
            ingress (or egress) and the transit duration. A value of 0 gives a
            box-shaped transit.
        time (torch.Tensor): normalized time values, from 0 to 1, for which
            the model will be evaluated

    Returns:
        torch.Tensor: normalized light value for each time in the input, in
        order.

    """
    # Calculate the times of the contact points at each lightcurve, from 0 to 1
    contact_1 = 0.5 - transit_duration/2 - contact_ratio * transit_duration
    contact_2 = 0.5 - transit_duration/2
    contact_3 = 0.5 + transit_duration/2
    contact_4 = 0.5 + transit_duration/2 + contact_ratio * transit_duration

    # Calculate masks for each section of the light curve
    mask_ingress = (time > contact_1) & (time <= contact_2)
    mask_transit = (time > contact_2) & (time <= contact_3)
    mask_egress  = (time > contact_3) & (time <= contact_4)

    # Calculate normalized light values by section of the lightcurve. These are
    # evaluated for every time, and may be inf/NaN outside of their own section
    # when ingress/egress have zero length (division by zero).
    ingress = torch.cos((time-contact_1)/(contact_2-contact_1)
                        *np.pi) * 0.5 - 0.5
    egress  = torch.cos((time-contact_3)/(contact_4-contact_3)
                        *np.pi) *-0.5 - 0.5

    # Select each section with torch.where instead of multiplying by the mask:
    # a product would propagate NaN/inf from the unselected sections
    # (NaN * 0 is NaN) and spoil the whole lightcurve.
    lightcurve = (torch.where(mask_ingress, ingress, torch.zeros_like(ingress))
                  + torch.where(mask_egress, egress, torch.zeros_like(egress))
                  - mask_transit.to(ingress.dtype))  # Transit floor is -1

    return lightcurve
    
def create_transit_lightcurve(len_global_lightcurve, len_local_lightcurve,
                              transit_duration, contact_ratio,
                              local_ratio = 4.,
                              noise_power=0., time_view=(-1,)):
    """Creates a local and global view for a transit-like normalized lightcurve.

    Normalization means that the curve outside the transit is centered around 0
    and the depth of the transit is set to -1. The transit is in the middle of
    the vector, with the "greatest transit" exactly in the center. The global
    view shows the full period (normalized time from 0 to 1), and the local
    view zooms on a window centered on the transit whose half-width is
    ``transit_duration * (0.5 + contact_ratio) * local_ratio``.

    Global and local views are concatenated in the same vector, along the time
    dimension.

    It can work with broadcasting.

    Args:
        len_global_lightcurve (int): number of points in the global lightcurve
        len_local_lightcurve (int): number of points in the local lightcurve
        transit_duration (float or torch.Tensor): normalized duration of the
            transit with respect to the orbital period
        contact_ratio (float or torch.Tensor): ratio between ingress or egress
            and the transit
        local_ratio (float, optional): ratio between the half-width of the
            local view and half of the transit including contact
        noise_power (float or torch.Tensor, optional): variance of gaussian
            noise to add to the lightcurve. No noise by default
        time_view (tuple, optional): shape given to the time vectors, with -1
            marking the dimension along which time runs. 1D by default

    Returns:
        torch.Tensor: PyTorch multi-dimensional matrix. Along the time
        dimension, the values corresponding to the global view go first, and
        are followed by the local view.

    """
    assert isinstance(len_global_lightcurve, int)
    assert isinstance(len_local_lightcurve, int)

    # Calculate times for the local window
    local_start = 0.5 - (transit_duration * (0.5+contact_ratio)) * local_ratio
    local_end   = 0.5 + (transit_duration * (0.5+contact_ratio)) * local_ratio

    # Normalized time tensor, from 0 to 1 inclusive
    global_time = torch.linspace(0., 1., len_global_lightcurve
                                 ).view(*time_view)
    local_time  = torch.linspace(0, 1, len_local_lightcurve
                                 ).view(*time_view
                                 ) * (local_end - local_start) + local_start

    # Apply the transit model
    global_lightcurve = transit_model(transit_duration,
                                      contact_ratio,
                                      global_time)

    local_lightcurve  = transit_model(transit_duration,
                                      contact_ratio,
                                      local_time)

    # Standard deviation of the noise. Only tensors need to be converted to
    # numpy; a plain float (like the default 0.) has no .numpy() method.
    noise_stddev = noise_power**0.5
    if torch.is_tensor(noise_stddev):
        noise_stddev = noise_stddev.numpy()

    # Calculate random noise (global view is sampled first, then local view)
    global_noise = normal(tuple(global_lightcurve.size()),
                          stddev=noise_stddev)

    local_noise  = normal(tuple(local_lightcurve .size()),
                          stddev=noise_stddev)

    # Locate the time dimension counting from the end, so that it stays valid
    # when broadcasting against the parameters adds leading dimensions. A
    # hard-coded dimension 2 fails for the default 1D time_view.
    time_view = tuple(time_view)
    if -1 in time_view:
        time_dim = time_view.index(-1) - len(time_view)
    else:
        time_dim = -1

    return torch.cat((global_lightcurve + global_noise,
                      local_lightcurve + local_noise), time_dim)


def sample_transit_lightcurves(nof_lightcurves, 
                               len_global_lightcurve, len_local_lightcurve, 
                               transit_duration_range = (0.001, 0.01),
                               contact_ratio_range = (0.1, 1.0),
                               noise_power_range = (0.001, 0.01),
                               ):
    """Samples a batch of transit-like normalized lightcurves.

    For every lightcurve a transit duration, a contact ratio and a noise power
    are drawn from uniform distributions, in that order, and the batch is then
    built by ``create_transit_lightcurve``.

    Args:
        nof_lightcurves (int): number of lightcurves that will be created,
            which will be stacked along the first dimension of the tensor
        len_global_lightcurve (int): number of points in the global view,
            which will be set along the third dimension of the tensor
        len_local_lightcurve (int): number of points in the local view,
            which follows the global view along the third dimension
        transit_duration_range (tuple): limits of the uniform distribution of
            the transit duration (second to third contact, over the period)
        contact_ratio_range (tuple): limits of the uniform distribution of the
            ratio between ingress (or egress) duration and transit duration
        noise_power_range (tuple): limits of the uniform distribution of the
            variance of the gaussian noise

    Returns:
        torch.Tensor: result of ``create_transit_lightcurve`` for the sampled
            parameters, with one lightcurve per index of the first dimension
            and time along the third one

    """
    for int_argument in (nof_lightcurves,
                         len_global_lightcurve,
                         len_local_lightcurve):
        assert type(int_argument) is int

    # One value per lightcurve, ready to broadcast against the time dimension
    parameter_shape = (nof_lightcurves, 1, 1)

    transit_duration = uniform(parameter_shape, *transit_duration_range)
    contact_ratio = uniform(parameter_shape, *contact_ratio_range)
    noise_power = uniform(parameter_shape, *noise_power_range)

    return create_transit_lightcurve(len_global_lightcurve,
                                     len_local_lightcurve,
                                     transit_duration,
                                     contact_ratio,
                                     noise_power=noise_power,
                                     time_view=(1, 1, -1))

In [5]:
def binary_model(transit_duration, contact_ratio, time):
    """Models a binary lightcurve for the given parameters and normalized time.

    Normalization means that the curve outside the eclipse is 0 and the
    deepest point of the eclipse is -1. The main eclipse is centered at time
    0.5. Ingress and egress are linear ramps between 0 and an intermediate
    ``contact_depth``; between second and third contact the curve follows a
    cosine from ``contact_depth`` down to -1.

    The parameters and ``time`` are combined with regular broadcasting.

    Args:
        transit_duration (float or torch.Tensor): normalized duration of the
            eclipse (second to third contact) with respect to the period
        contact_ratio (float or torch.Tensor): ratio between the duration of
            ingress (or egress) and the transit duration
        time (torch.Tensor): normalized time values, from 0 to 1, for which
            the model will be evaluated

    Returns:
        torch.Tensor: normalized light value for each time in the input, in
        order.

    """
    # Calculate the times of the contact points at each lightcurve, from 0 to 1
    contact_1 = 0.5 - transit_duration/2 - contact_ratio * transit_duration
    contact_2 = 0.5 - transit_duration/2
    contact_3 = 0.5 + transit_duration/2
    contact_4 = 0.5 + transit_duration/2 + contact_ratio * transit_duration

    # Parameter for the model: light level at second and third contact.
    # Tensors use torch.exp so that no NumPy ufunc is applied to a tensor.
    if torch.is_tensor(contact_ratio):
        contact_decay = torch.exp(-contact_ratio)
    else:
        contact_decay = np.exp(-contact_ratio)
    contact_depth = contact_ratio * contact_decay*0.2-1

    # Calculate masks for each section of the light curve
    mask_ingress = (time > contact_1) & (time <= contact_2)
    mask_transit = (time > contact_2) & (time <= contact_3)
    mask_egress  = (time > contact_3) & (time <= contact_4)

    # Calculate normalized light values by section of the lightcurve. These are
    # evaluated for every time, and may be inf/NaN outside of their own section
    # when a section has zero length (division by zero).
    ingress = (time-contact_1) / (contact_2-contact_1) * contact_depth
    transit = -(1+contact_depth)*torch.cos(
               (time-0.5)*np.pi/(contact_3-contact_2))+contact_depth
    egress  = contact_depth - (time-contact_3) / (
                               contact_4-contact_3) * contact_depth

    # Select each section with torch.where instead of multiplying by the mask:
    # a product would propagate NaN/inf from the unselected sections
    # (NaN * 0 is NaN) and spoil the whole lightcurve.
    lightcurve = (torch.where(mask_ingress, ingress, torch.zeros_like(ingress))
                  + torch.where(mask_transit, transit, torch.zeros_like(transit))
                  + torch.where(mask_egress, egress, torch.zeros_like(egress)))

    return lightcurve


def create_binary_lightcurve(len_global_lightcurve, len_local_lightcurve,
                              transit_duration, contact_duration,
                              local_ratio = 4.,
                              noise_power=0., time_view=(-1,)):
    """Creates a local and global view for a binary eclipse-like normalized 
    lightcurve.

    Normalization means that the curve outside the eclipse is centered around 0
    and the depth of the eclipse is set to -1. The main eclipse is in the
    middle of the vector, with the "greatest occultation" exactly in the center.

    The global view shows the full period (normalized time from 0 to 1), and
    the local view zooms on a window centered on the eclipse whose half-width
    is ``transit_duration * (0.5 + contact_ratio) * local_ratio``, where
    ``contact_ratio = contact_duration / transit_duration``.

    Global and local views are concatenated in the same vector, along the time
    dimension.

    It can work with broadcasting.

    Args:
        len_global_lightcurve (int): number of points in the global lightcurve
        len_local_lightcurve (int): number of points in the local lightcurve
        transit_duration (float or torch.Tensor): normalized duration of the
            eclipse with respect to the orbital period
        contact_duration (float or torch.Tensor): normalized duration of
            ingress or egress with respect to the orbital period
        local_ratio (float, optional): ratio between the half-width of the
            local view and half of the eclipse including contact
        noise_power (float or torch.Tensor, optional): variance of gaussian
            noise to add to the lightcurve. No noise by default
        time_view (tuple, optional): shape given to the time vectors, with -1
            marking the dimension along which time runs. 1D by default

    Returns:
        torch.Tensor: PyTorch multi-dimensional matrix. Along the time
        dimension, the values corresponding to the global view go first, and
        are followed by the local view.

    """
    assert isinstance(len_global_lightcurve, int)
    assert isinstance(len_local_lightcurve, int)

    # Parameters for the model
    contact_ratio = contact_duration/transit_duration

    # Calculate times for the local window
    local_start = 0.5 - (transit_duration * (0.5+contact_ratio)) * local_ratio
    local_end   = 0.5 + (transit_duration * (0.5+contact_ratio)) * local_ratio

    # Normalized time tensor, from 0 to 1 inclusive
    global_time = torch.linspace(0., 1., len_global_lightcurve
                                 ).view(*time_view)
    local_time  = torch.linspace(0, 1, len_local_lightcurve
                                 ).view(*time_view
                                 ) * (local_end - local_start) + local_start

    # Apply the binary eclipse model
    global_lightcurve = binary_model(transit_duration,
                                     contact_ratio,
                                     global_time)

    local_lightcurve  = binary_model(transit_duration,
                                     contact_ratio,
                                     local_time)

    # Standard deviation of the noise. Only tensors need to be converted to
    # numpy; a plain float (like the default 0.) has no .numpy() method.
    noise_stddev = noise_power**0.5
    if torch.is_tensor(noise_stddev):
        noise_stddev = noise_stddev.numpy()

    # Calculate random noise (global view is sampled first, then local view)
    global_noise = normal(tuple(global_lightcurve.size()),
                          stddev=noise_stddev)

    local_noise  = normal(tuple(local_lightcurve .size()),
                          stddev=noise_stddev)

    # Locate the time dimension counting from the end, so that it stays valid
    # when broadcasting against the parameters adds leading dimensions. A
    # hard-coded dimension 2 fails for the default 1D time_view.
    time_view = tuple(time_view)
    if -1 in time_view:
        time_dim = time_view.index(-1) - len(time_view)
    else:
        time_dim = -1

    return torch.cat((global_lightcurve + global_noise,
                      local_lightcurve + local_noise), time_dim)


def sample_binary_lightcurves(nof_lightcurves, 
                              len_global_lightcurve, len_local_lightcurve, 
                              transit_duration_range = (0.001, 0.01),
                              contact_duration_range = (0.001, 0.01),
                              noise_power_range = (0.001, 0.01),
                              ):
    """Samples a batch of binary eclipse-like normalized lightcurves.

    For every lightcurve a transit duration, a contact duration and a noise
    power are drawn from uniform distributions, in that order, and the batch
    is then built by ``create_binary_lightcurve``.

    Args:
        nof_lightcurves (int): number of lightcurves that will be created,
            which will be stacked along the first dimension of the tensor
        len_global_lightcurve (int): number of points in the global view,
            which will be set along the third dimension of the tensor
        len_local_lightcurve (int): number of points in the local view,
            which follows the global view along the third dimension
        transit_duration_range (tuple): limits of the uniform distribution of
            the eclipse duration (second to third contact, over the period)
        contact_duration_range (tuple): limits of the uniform distribution of
            the ingress (or egress) duration, over the period
        noise_power_range (tuple): limits of the uniform distribution of the
            variance of the gaussian noise

    Returns:
        torch.Tensor: result of ``create_binary_lightcurve`` for the sampled
            parameters, with one lightcurve per index of the first dimension
            and time along the third one

    """
    for int_argument in (nof_lightcurves,
                         len_global_lightcurve,
                         len_local_lightcurve):
        assert type(int_argument) is int

    # One value per lightcurve, ready to broadcast against the time dimension
    parameter_shape = (nof_lightcurves, 1, 1)

    transit_duration = uniform(parameter_shape, *transit_duration_range)
    # Absolute (period-normalized) duration; the ratio to the transit duration
    # is derived inside create_binary_lightcurve
    contact_duration = uniform(parameter_shape, *contact_duration_range)
    noise_power = uniform(parameter_shape, *noise_power_range)

    return create_binary_lightcurve(len_global_lightcurve,
                                    len_local_lightcurve,
                                    transit_duration,
                                    contact_duration,
                                    noise_power=noise_power,
                                    time_view=(1, 1, -1))

In [6]:
# Create dummy tensors for input

# The samplers draw from NumPy's global random state, so it is seeded here to
# get the same dummy dataset on every run of the notebook
DATASET_SEED = 0
np.random.seed(DATASET_SEED)

nof_transit_like = 5000  # Number of samples of transit-like curves
nof_binary_like = 5000  # Number of samples of binary eclipse-like curves

len_global_lightcurves = 2049  # Length of first column of input array, should be odd
len_local_lightcurves  =  257  # Length of second column of input array, should be odd
len_extra_parameters   =    0  # Length of extra parameters for the model input

# Labels will be a categorical variable directly: 0 for transit-like curves
# and 1 for binary eclipse-like curves
transit_lightcurves = sample_transit_lightcurves(
                          nof_transit_like, 
                          len_global_lightcurves,
                          len_local_lightcurves,
                          transit_duration_range = (0.001, 0.01),
                          contact_ratio_range = (0.1, 1.0),
                          noise_power_range = (0.0001, 0.001),
                          )
transit_label = torch.zeros(nof_transit_like)

binary_lightcurves = sample_binary_lightcurves(
                          nof_binary_like,  
                          len_global_lightcurves,
                          len_local_lightcurves,
                          transit_duration_range = (0.001, 0.01),
                          contact_duration_range = (0.001, 0.01),
                          noise_power_range = (0.0001, 0.001),
                          )
binary_label  = torch.ones(nof_binary_like)

Let's plot from the planetary transit dataset.

In [7]:
# Scatter plots of the first transit-like sample: global view first, then the
# local view stored right after it along the last dimension
output_notebook()

transit_figure_options = dict(x_axis_label=("Normalized time"),
                              y_axis_label=("Normalized light intensity"),
                              plot_width=800, plot_height=350)

global_transit_plot = figure(**transit_figure_options)
global_transit_values = transit_lightcurves[0, 0, :len_global_lightcurves]
global_transit_plot.scatter(
    x=np.linspace(0, 1, len_global_lightcurves),
    y=global_transit_values.cpu().numpy().flatten(),
)

local_transit_plot = figure(**transit_figure_options)
local_transit_values = transit_lightcurves[0, 0, len_global_lightcurves:]
local_transit_plot.scatter(
    x=np.linspace(0, 1, len_local_lightcurves),
    y=local_transit_values.cpu().numpy().flatten(),
)

for transit_plot in (global_transit_plot, local_transit_plot):
    show(transit_plot)

And now from the binary eclipse dataset.

In [8]:
# Scatter plots (in red) of the first binary eclipse-like sample: global view
# first, then the local view stored right after it along the last dimension
output_notebook()

binary_figure_options = dict(x_axis_label=("Normalized time"),
                             y_axis_label=("Normalized light intensity"),
                             plot_width=800, plot_height=350)

global_binary_plot = figure(**binary_figure_options)
global_binary_values = binary_lightcurves[0, 0, :len_global_lightcurves]
global_binary_plot.scatter(
    x=np.linspace(0, 1, len_global_lightcurves),
    y=global_binary_values.cpu().numpy().flatten(),
    color="red",
)

local_binary_plot = figure(**binary_figure_options)
local_binary_values = binary_lightcurves[0, 0, len_global_lightcurves:]
local_binary_plot.scatter(
    x=np.linspace(0, 1, len_local_lightcurves),
    y=local_binary_values.cpu().numpy().flatten(),
    color="red",
)

for binary_plot in (global_binary_plot, local_binary_plot):
    show(binary_plot)

In [9]:
# Concatenate tensors from both types and store them in GPU if available.
# Samples and labels MUST be concatenated in the same class order (transit
# first, binary second); otherwise every sample gets the label of the other
# class.
X = torch.cat((transit_lightcurves,
               binary_lightcurves,
               ), dim=0).to(device, dtype=torch.float)
y = torch.cat((transit_label,
               binary_label,
               ), dim=0).to(device, dtype=torch.long)

# One label per sample
assert X.size(0) == y.size(0)

# Create an iterable dataset from the input and label tensors
dataset = TensorDataset(X, y)

In [10]:
# Random split of the dataset: 70% for training, 20% for testing, and the
# remaining samples for validation
nof_samples = len(y)
train_size = int(0.7 * nof_samples)
test_size  = int(0.2 * nof_samples)
val_size = nof_samples - (train_size + test_size)

split_sizes = [train_size, test_size, val_size]
train_dataset, test_dataset, val_dataset = torch.utils.data.random_split(
    dataset, split_sizes)

In [11]:
# Choose a batch size and wrap each split of the dataset in a data loader

bs = 50  # Batch size for training
eval_bs = bs * 2  # Test and validation loaders use twice the training batch size

# Only the training loader reshuffles the samples
train_dl = DataLoader(train_dataset, batch_size=bs, shuffle=True)
test_dl  = DataLoader(test_dataset, batch_size=eval_bs)
valid_dl = DataLoader(val_dataset, batch_size=eval_bs)

In [12]:
class Test_CNN(nn.Module):
    """Small 1D CNN: three convolutions, one max pooling and two dense layers.

    The size of the first dense layer is derived from the notebook-level
    variables ``len_global_lightcurves`` and ``len_local_lightcurves``, so the
    input is expected to have their combined length along the last dimension.
    """

    def __init__(self):
        super().__init__()

        # Length-preserving convolutions (kernel 3, stride 1, padding 1)
        self.conv1 = nn.Conv1d(1, 8, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv1d(8, 8, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv1d(8, 4, kernel_size=3, stride=1, padding=1)

        # 4 channels times the pooled length floor(L/4)
        pooled_length = int((len_global_lightcurves + len_local_lightcurves) / 4)
        self.fc1 = nn.Linear(pooled_length * 4, 50)
        self.fc2 = nn.Linear(50, 2)

    def forward(self, xb):  # xb is of size (1, L) for each sample
        """Returns the two output scores for each sample of the batch."""
        # Convolutions keep the length L: channels go 1 -> 8 -> 8 -> 4
        for convolution in (self.conv1, self.conv2, self.conv3):
            xb = F.relu(convolution(xb))

        # Pooling reduces the length to floor(L/4)
        xb = F.max_pool1d(xb, 4)

        # Flatten channels and length for the dense layers; -1 infers the
        # remaining (batch) dimension
        pooled_length = int((len_global_lightcurves + len_local_lightcurves) / 4)
        xb = xb.view(-1, pooled_length * 4)

        # Fully connected layers: 50 hidden units, then 2 outputs
        hidden = F.relu(self.fc1(xb))
        return self.fc2(hidden)

In [13]:
class AstroNET_v1(nn.Module):
    """Two-column 1D CNN for lightcurves with a global and a local view.

    The input holds, along its last dimension, the global view, the local view
    and the extra parameters, in that order. Each view goes through its own
    convolutional column (five conv-conv-pool stages for the global view, two
    for the local view). The flattened outputs of both columns and the extra
    parameters are concatenated and fed to four fully connected layers and an
    output layer.

    Args:
        len_global_lightcurves (int): number of points of the global view
        len_local_lightcurves (int): number of points of the local view
        len_extra_parameters (int): number of extra values after both views
        len_fully_connected (int): width of the hidden fully connected layers
        input_channels (int): number of channels of the input
        output_classes (int): number of outputs of the last layer
        pooling_type (str): 'max' or 'avg', type of pooling used by the columns

    Raises:
        ValueError: if ``pooling_type`` is neither 'max' nor 'avg'
    """

    # Pooling layer class for each supported value of ``pooling_type``
    _POOLING_LAYERS = {'max': nn.MaxPool1d, 'avg': nn.AvgPool1d}

    @staticmethod
    def _pooling_reduction(input_dim, order):
        """Length of a vector of ``input_dim`` points after ``order`` poolings.

        Assumes that 2*padding = kernel_size-1 and stride = 2, for which each
        pooling gives floor((length - 1)/2) + 1 points.
        """
        output_dim = input_dim
        for _ in range(order):
            output_dim = (output_dim - 1) // 2 + 1
        return output_dim

    def __init__(self, 
                 len_global_lightcurves = 2049, 
                 len_local_lightcurves = 257, 
                 len_extra_parameters = 0,
                 len_fully_connected = 512,
                 input_channels = 1, 
                 output_classes = 2,
                 pooling_type='max'):

        super().__init__()

        # Fail early on unknown pooling types: otherwise the pooling layers
        # would be silently left undefined and forward() would be the one
        # failing, with an unrelated AttributeError
        if pooling_type not in self._POOLING_LAYERS:
            raise ValueError(
                f"pooling_type must be 'max' or 'avg', got {pooling_type!r}")

        # General configuration:
        self.input_channels = input_channels

        self.len_global_lightcurves = len_global_lightcurves
        self.len_local_lightcurves = len_local_lightcurves
        self.len_extra_parameters = len_extra_parameters

        self.len_total_input = (len_global_lightcurves + 
                                len_local_lightcurves + 
                                len_extra_parameters)

        self.len_fully_connected = len_fully_connected
        self.output_classes = output_classes

        # Calculate the length of the vectors after the convolutional columns:
        # pooled length times the channels of the last convolution
        self.len_global_col = self._pooling_reduction(
                                  self.len_global_lightcurves, 5) * 256
        self.len_local_col  = self._pooling_reduction(
                                  self.len_local_lightcurves, 2) * 32

        # Calculate the input size for the first fully connected layer
        self.len_fc_input = (self.len_global_col + 
                             self.len_local_col + 
                             self.len_extra_parameters * self.input_channels)

        # Layers for convolutional columns of the model
        # Layers with same config must be repeated because they will need different weights
        c = self.input_channels

        # Convolutions for global view column
        self.conv_5_16_g_a  = nn.Conv1d(  c,  16, kernel_size=5, stride=1, padding=2)
        self.conv_5_16_g_b  = nn.Conv1d( 16,  16, kernel_size=5, stride=1, padding=2)

        self.conv_5_32_g_a  = nn.Conv1d( 16,  32, kernel_size=5, stride=1, padding=2)
        self.conv_5_32_g_b  = nn.Conv1d( 32,  32, kernel_size=5, stride=1, padding=2)

        self.conv_5_64_g_a  = nn.Conv1d( 32,  64, kernel_size=5, stride=1, padding=2)
        self.conv_5_64_g_b  = nn.Conv1d( 64,  64, kernel_size=5, stride=1, padding=2)

        self.conv_5_128_g_a = nn.Conv1d( 64, 128, kernel_size=5, stride=1, padding=2)
        self.conv_5_128_g_b = nn.Conv1d(128, 128, kernel_size=5, stride=1, padding=2)

        self.conv_5_256_g_a = nn.Conv1d(128, 256, kernel_size=5, stride=1, padding=2)
        self.conv_5_256_g_b = nn.Conv1d(256, 256, kernel_size=5, stride=1, padding=2)

        # Convolutions for local view column
        self.conv_5_16_l_a  = nn.Conv1d(  c,  16, kernel_size=5, stride=1, padding=2)
        self.conv_5_16_l_b  = nn.Conv1d( 16,  16, kernel_size=5, stride=1, padding=2)

        self.conv_5_32_l_a  = nn.Conv1d( 16,  32, kernel_size=5, stride=1, padding=2)
        self.conv_5_32_l_b  = nn.Conv1d( 32,  32, kernel_size=5, stride=1, padding=2)

        # Pooling layers
        # These contain no parameters, so they can be shared
        pooling_layer = self._POOLING_LAYERS[pooling_type]
        self.pool_5_2     = pooling_layer(5, stride=2, padding=2)
        self.pool_7_2     = pooling_layer(7, stride=2, padding=3)

        # Dense layers for classification of extracted features
        self.fc_512_a     = nn.Linear(self.len_fc_input, self.len_fully_connected)
        self.fc_512_b     = nn.Linear(self.len_fully_connected, self.len_fully_connected)
        self.fc_512_c     = nn.Linear(self.len_fully_connected, self.len_fully_connected)
        self.fc_512_d     = nn.Linear(self.len_fully_connected, self.len_fully_connected)

        self.fc_out       = nn.Linear(self.len_fully_connected, self.output_classes)

    def forward(self, xb):  # xb is of size (batch_size, input_channels, len_total_input)
        """Returns a tensor of size (batch_size, output_classes)."""

        # Extract input of different columns for whole batch and all channels
        gb, lb, eb = torch.split(xb, [self.len_global_lightcurves,
                                      self.len_local_lightcurves,
                                      self.len_extra_parameters],
                                 dim=2)

        # Convolutions for global view
        gb = F.relu(self.conv_5_16_g_a (gb))
        gb = F.relu(self.conv_5_16_g_b (gb))
        gb = self.pool_5_2(gb)

        gb = F.relu(self.conv_5_32_g_a (gb))
        gb = F.relu(self.conv_5_32_g_b (gb))
        gb = self.pool_5_2(gb)

        gb = F.relu(self.conv_5_64_g_a (gb))
        gb = F.relu(self.conv_5_64_g_b (gb))
        gb = self.pool_5_2(gb)

        gb = F.relu(self.conv_5_128_g_a(gb))
        gb = F.relu(self.conv_5_128_g_b(gb))
        gb = self.pool_5_2(gb)

        gb = F.relu(self.conv_5_256_g_a(gb))
        gb = F.relu(self.conv_5_256_g_b(gb))
        gb = self.pool_5_2(gb)

        gb = torch.flatten(gb, 1, 2)  # Flatten channels and features but NOT batches

        # Convolutions for local view
        lb = F.relu(self.conv_5_16_l_a (lb))
        lb = F.relu(self.conv_5_16_l_b (lb))
        lb = self.pool_7_2(lb)

        lb = F.relu(self.conv_5_32_l_a (lb))
        lb = F.relu(self.conv_5_32_l_b (lb))
        lb = self.pool_7_2(lb)

        lb = torch.flatten(lb, 1, 2)

        # Reshape extra features
        eb = torch.flatten(eb, 1, 2)

        # Concatenate results maintaining batch positioning in first dimension
        fb = torch.cat((gb, lb, eb), dim=1)

        # Apply fully connected layers
        fb = F.relu(self.fc_512_a(fb))
        fb = F.relu(self.fc_512_b(fb))
        fb = F.relu(self.fc_512_c(fb))
        fb = F.relu(self.fc_512_d(fb))

        # Output layer
        fb = self.fc_out(fb)

        return fb

In [63]:
class ExoplaNET_v1(nn.Module):
    """Multi-column 1D CNN that classifies concatenated light-curve views.

    The input has shape (batch_size, input_channels, len_total_input). Its last
    dimension is the concatenation, in this order, of the global view, the local
    view, the secondary view and the extra parameters. Each view is processed by
    its own convolutional column, the column outputs and the extra parameters are
    flattened and concatenated, and the result goes through four fully connected
    layers followed by the output layer. A view of length 0 skips its column.

    Args:
        len_global_lightcurves (int, optional): length of the global view
        len_local_lightcurves (int, optional): length of the local view
        len_secondary_lightcurves (int, optional): length of the secondary view
        len_extra_parameters (int, optional): number of extra parameters appended
            after the views
        len_fully_connected (int, optional): width of the hidden dense layers
        input_channels (int, optional): number of channels of the input
        output_classes (int, optional): number of outputs of the last layer
        pooling_type (str, optional): 'max' or 'avg'

    Raises:
        ValueError: if pooling_type is neither 'max' nor 'avg'

    """

    def __init__(self, 
                 len_global_lightcurves = 2049, 
                 len_local_lightcurves = 257, 
                 len_secondary_lightcurves = 0, 
                 len_extra_parameters = 0,
                 len_fully_connected = 512,
                 input_channels = 1, 
                 output_classes = 2,
                 pooling_type='max'):
        
        super().__init__()
        
        # Reject unknown pooling types here; otherwise no pooling layer would be
        # created and the failure would only surface inside forward()
        if pooling_type not in ('max', 'avg'):
            raise ValueError(
                f"pooling_type must be 'max' or 'avg', got {pooling_type!r}")
        
        # General configuration:
        self.len_global_lightcurves    = len_global_lightcurves
        self.len_local_lightcurves     = len_local_lightcurves
        self.len_secondary_lightcurves = len_secondary_lightcurves
        
        self.len_extra_parameters = len_extra_parameters
        self.input_channels = input_channels
        
        self.len_total_input = (len_global_lightcurves + 
                                len_local_lightcurves + 
                                len_secondary_lightcurves +
                                len_extra_parameters)
        
        self.len_fully_connected = len_fully_connected
        self.output_classes = output_classes
        
        # Length of the flattened vectors after the convolutional columns:
        # (length after all poolings of the column) * (channels of its last convolution)
        self.len_global_col     = self._pooled_length(self.len_global_lightcurves,    5) * 256
        self.len_local_col      = self._pooled_length(self.len_local_lightcurves,     2) *  32
        self.len_secondary_col  = self._pooled_length(self.len_secondary_lightcurves, 2) *  32
                                
        # Input size for the first fully connected layer
        self.len_fc_input = (self.len_global_col + 
                             self.len_local_col + 
                             self.len_secondary_col +
                             self.len_extra_parameters * self.input_channels)
        
        # Layers for convolutional columns of the model
        # Layers with same config must be repeated because they will need different weights
        c = self.input_channels
        
        # Convolutions for global view column
        self.conv_5_16_g_a  = nn.Conv1d(  c,  16, kernel_size=5, stride=1, padding=2)
        self.conv_5_16_g_b  = nn.Conv1d( 16,  16, kernel_size=5, stride=1, padding=2)
        
        self.conv_5_32_g_a  = nn.Conv1d( 16,  32, kernel_size=5, stride=1, padding=2)
        self.conv_5_32_g_b  = nn.Conv1d( 32,  32, kernel_size=5, stride=1, padding=2)
        
        self.conv_5_64_g_a  = nn.Conv1d( 32,  64, kernel_size=5, stride=1, padding=2)
        self.conv_5_64_g_b  = nn.Conv1d( 64,  64, kernel_size=5, stride=1, padding=2)
        
        self.conv_5_128_g_a = nn.Conv1d( 64, 128, kernel_size=5, stride=1, padding=2)
        self.conv_5_128_g_b = nn.Conv1d(128, 128, kernel_size=5, stride=1, padding=2)
        
        self.conv_5_256_g_a = nn.Conv1d(128, 256, kernel_size=5, stride=1, padding=2)
        self.conv_5_256_g_b = nn.Conv1d(256, 256, kernel_size=5, stride=1, padding=2)
                                
        # Convolutions for local view column   
        self.conv_5_16_l_a  = nn.Conv1d(  c,  16, kernel_size=5, stride=1, padding=2)
        self.conv_5_16_l_b  = nn.Conv1d( 16,  16, kernel_size=5, stride=1, padding=2)
        
        self.conv_5_32_l_a  = nn.Conv1d( 16,  32, kernel_size=5, stride=1, padding=2)
        self.conv_5_32_l_b  = nn.Conv1d( 32,  32, kernel_size=5, stride=1, padding=2)

        # Convolutions for secondary view column   
        self.conv_5_16_s_a  = nn.Conv1d(  c,  16, kernel_size=5, stride=1, padding=2)
        self.conv_5_16_s_b  = nn.Conv1d( 16,  16, kernel_size=5, stride=1, padding=2)
        
        self.conv_5_32_s_a  = nn.Conv1d( 16,  32, kernel_size=5, stride=1, padding=2)
        self.conv_5_32_s_b  = nn.Conv1d( 32,  32, kernel_size=5, stride=1, padding=2)
                                
        # Pooling layers
        # These contain no parameters, so they can be shared
        pool_class = nn.MaxPool1d if pooling_type == 'max' else nn.AvgPool1d
        self.pool_5_2     = pool_class(5, stride=2, padding=2)
        self.pool_7_2     = pool_class(7, stride=2, padding=3)
        
        # Dense layers for classification of extracted features
        self.fc_512_a     = nn.Linear(self.len_fc_input, self.len_fully_connected)
        self.fc_512_b     = nn.Linear(self.len_fully_connected, self.len_fully_connected)
        self.fc_512_c     = nn.Linear(self.len_fully_connected, self.len_fully_connected)
        self.fc_512_d     = nn.Linear(self.len_fully_connected, self.len_fully_connected)
        
        self.fc_out       = nn.Linear(self.len_fully_connected, self.output_classes)

    @staticmethod
    def _pooled_length(input_dim, order):
        """Length of a vector after `order` poolings.

        Assumes that 2*padding = kernel_size-1 and stride = 2 for every pooling,
        so each pooling maps a length L to floor((L - 1) / 2) + 1.
        """
        for _ in range(order):
            input_dim = (input_dim - 1) // 2 + 1
        return int(input_dim)

    @staticmethod
    def _run_column(xb, conv_pairs, pool):
        """Apply (conv, relu, conv, relu, pool) stages and flatten the result.

        The stages are skipped when the input is empty, but the result is always
        flattened to (batch_size, features) so that it can be concatenated with
        the other columns.
        """
        if xb.numel() > 0:  # Only run the column if the input exists
            for conv_a, conv_b in conv_pairs:
                xb = F.relu(conv_a(xb))
                xb = F.relu(conv_b(xb))
                xb = pool(xb)
        
        return torch.flatten(xb, 1, 2)  # Flatten channels and features but NOT batches
        
    def forward(self, xb):  # xb is of size (batch_size, input_channels, len_total_input)
        """Compute the raw (un-normalized) class scores for a batch.

        Args:
            xb (torch.Tensor): batch of size (batch_size, input_channels, len_total_input)

        Returns:
            torch.Tensor: tensor of size (batch_size, output_classes)

        """
        
        # Extract input of different columns for whole batch and all channels
        gb, lb, sb, eb = torch.split(xb, [self.len_global_lightcurves,
                                          self.len_local_lightcurves,
                                          self.len_secondary_lightcurves,
                                          self.len_extra_parameters],
                                     dim=2)
        
        # Convolutions for global view
        gb = self._run_column(gb,
                              ((self.conv_5_16_g_a,  self.conv_5_16_g_b),
                               (self.conv_5_32_g_a,  self.conv_5_32_g_b),
                               (self.conv_5_64_g_a,  self.conv_5_64_g_b),
                               (self.conv_5_128_g_a, self.conv_5_128_g_b),
                               (self.conv_5_256_g_a, self.conv_5_256_g_b)),
                              self.pool_5_2)
        
        # Convolutions for local view
        lb = self._run_column(lb,
                              ((self.conv_5_16_l_a, self.conv_5_16_l_b),
                               (self.conv_5_32_l_a, self.conv_5_32_l_b)),
                              self.pool_7_2)
        
        # Convolutions for secondary view (uses its own weights, not the local ones)
        sb = self._run_column(sb,
                              ((self.conv_5_16_s_a, self.conv_5_16_s_b),
                               (self.conv_5_32_s_a, self.conv_5_32_s_b)),
                              self.pool_7_2)
                                
        # Reshape extra features
        eb = torch.flatten(eb, 1, 2)
                                
        # Concatenate results maintaining batch positioning in first dimension
        fb = torch.cat((gb, lb, sb, eb), dim=1)

        # Apply fully connected layers
        fb = F.relu(self.fc_512_a(fb))
        fb = F.relu(self.fc_512_b(fb))
        fb = F.relu(self.fc_512_c(fb))
        fb = F.relu(self.fc_512_d(fb))

        # Output layer
        fb = self.fc_out(fb)
        
        return fb

In [64]:
def get_model():
    """Create a fresh Test_CNN model together with an SGD optimizer for it.

    NOTE(review): relies on `Test_CNN` and on a learning rate named `lr` being
    defined at notebook level when this function is called; neither is defined
    in the cells near this one -- confirm before use.

    Returns:
        tuple: (model, optimizer) where the optimizer updates the model parameters

    """
    model = Test_CNN()
    # The notebook imports `torch` but never binds a bare `optim` name, so the
    # optimizer is reached through the package, as in the training setup cell
    return model, torch.optim.SGD(model.parameters(), lr=lr)

In [65]:
def loss_batch(model, loss_func, xb, yb, opt=None):
    """Evaluate the loss of one batch and, if an optimizer is given, train on it.

    Args:
        model (callable): maps the batch inputs to predictions
        loss_func (callable): called as loss_func(predictions, yb)
        xb: batch inputs
        yb: batch targets
        opt (optional): optimizer; when given, backpropagation and one
            optimization step are performed and the gradients are reset

    Returns:
        tuple: (loss of the batch as a Python number, number of items in xb)

    """
    predictions = model(xb)
    batch_loss = loss_func(predictions, yb)

    # Without an optimizer the batch is only evaluated (testing);
    # with one, the model is also updated (training)
    training = opt is not None
    if training:
        batch_loss.backward()
        opt.step()
        opt.zero_grad()

    batch_len = len(xb)
    return batch_loss.item(), batch_len

In [66]:
def acc_batch(model, xb, yb):
    """Fraction of items of a batch for which the highest output matches the target.

    Args:
        model (callable): maps the batch inputs to one row of scores per item
        xb: batch inputs
        yb: batch targets (index of the expected highest output)

    Returns:
        tuple: (accuracy of the batch as a float, number of items in yb)

    """
    scores = model(xb)
    # Index of the maximum along the output dimension (dimension 0 is the batch)
    predicted = torch.max(scores, 1)[1]
    n_correct = (predicted == yb).sum().float()
    n_items = len(yb)
    return float(n_correct / n_items), n_items

In [67]:
def fit(epochs, model, loss_func, opt, train_dl, test_dl, verbose=True):
    """Train a model for several epochs, evaluating it on test data after each one.

    Args:
        epochs (int): number of passes over the training data
        model: model to train; must provide train() and eval()
        loss_func (callable): loss handed to loss_batch
        opt: optimizer handed to loss_batch during training
        train_dl (iterable): yields (xb, yb) training batches
        test_dl (iterable): yields (xb, yb) test batches
        verbose (bool, optional): print one summary line per epoch

    Returns:
        tuple: four lists with one entry per epoch: epoch indices (starting
        at 0), training losses, test losses and test accuracies

    """

    def batch_weighted_mean(values, sizes):
        # Per-sample mean: every batch value counts as many times as the batch has items
        return np.sum(np.multiply(values, sizes)) / np.sum(sizes)

    # Lists where the progress of every epoch is recorded
    epoch_record = []
    train_loss_record = []
    test_loss_record = []
    test_acc_record = []

    for epoch in range(epochs):

        # Training: loss evaluation, backpropagation and optimizer step for each batch
        model.train()
        train_results = [loss_batch(model, loss_func, xb, yb, opt) for xb, yb in train_dl]
        train_losses, train_sizes = zip(*train_results)
        train_loss = batch_weighted_mean(train_losses, train_sizes)

        # Evaluation: switch the model to evaluation mode (matters for layers such
        # as dropout or batchnorm) and disable autograd to save memory
        model.eval()
        with torch.no_grad():
            # No optimizer is passed, so nothing is backpropagated here
            loss_results = [loss_batch(model, loss_func, xb, yb) for xb, yb in test_dl]
            test_losses, n_loss = zip(*loss_results)
            acc_results = [acc_batch(model, xb, yb) for xb, yb in test_dl]
            test_accs, n_acc = zip(*acc_results)

        test_loss = batch_weighted_mean(test_losses, n_loss)
        test_acc = batch_weighted_mean(test_accs, n_acc)

        # Record results for this epoch
        epoch_record.append(epoch)
        train_loss_record.append(train_loss)
        test_loss_record.append(test_loss)
        test_acc_record.append(test_acc)

        # Report results for this epoch
        if verbose:
            print(f"Epoch: {(epoch+1):3}    Train loss: {train_loss:7.5f}    Test loss: {test_loss:7.5f}    Test accuracy: {test_acc: 5.3f}")

    return epoch_record, train_loss_record, test_loss_record, test_acc_record

In [68]:
# Training hyperparameters
epochs        = 100
learning_rate = 0.05
momentum      = 0.1

# Loss function, model and optimizer used for the training run.
# Alternatives kept from earlier experiments: AstroNET_v1() as the model, or
# ExoplaNET_v1(len_extra_parameters = 1) to also feed extra parameters.
loss_func = F.cross_entropy
model = ExoplaNET_v1()
model = model.to(device)
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=momentum,
)

In [69]:
# Run the training loop; each returned list holds one value per epoch
epoch_record, train_loss_record, test_loss_record, test_acc_record = fit(
    epochs=epochs,
    model=model,
    loss_func=loss_func,
    opt=optimizer,
    train_dl=train_dl,
    test_dl=test_dl,
)

Epoch:   1    Train loss: 0.69331    Test loss: 0.69322    Test accuracy:  0.482
Epoch:   2    Train loss: 0.69333    Test loss: 0.69323    Test accuracy:  0.482
Epoch:   3    Train loss: 0.69336    Test loss: 0.69401    Test accuracy:  0.482
Epoch:   4    Train loss: 0.69333    Test loss: 0.69516    Test accuracy:  0.482
Epoch:   5    Train loss: 0.69342    Test loss: 0.69400    Test accuracy:  0.482
Epoch:   6    Train loss: 0.69327    Test loss: 0.69268    Test accuracy:  0.518
Epoch:   7    Train loss: 0.69329    Test loss: 0.69267    Test accuracy:  0.518
Epoch:   8    Train loss: 0.69340    Test loss: 0.69383    Test accuracy:  0.482
Epoch:   9    Train loss: 0.69336    Test loss: 0.69392    Test accuracy:  0.482
Epoch:  10    Train loss: 0.69334    Test loss: 0.69381    Test accuracy:  0.482
Epoch:  11    Train loss: 0.69332    Test loss: 0.69329    Test accuracy:  0.482
Epoch:  12    Train loss: 0.69339    Test loss: 0.69309    Test accuracy:  0.518
Epoch:  13    Train loss: 0.

In [70]:
output_notebook()

# Figure showing how the recorded quantities evolve over the epochs
loss_plot = figure(
    x_axis_label=("Epoch"),
    y_axis_label=("Cross-entropy loss"),
    plot_width=800,
    plot_height=350,
)

# One curve per recorded quantity: (values, line colour, legend entry)
curves = (
    (test_loss_record,        "blue",   "Test Loss"),
    (list(train_loss_record), "orange", "Training Loss"),
    (list(test_acc_record),   "green",  "Accuracy"),
)
for values, colour, legend in curves:
    loss_plot.line(
        x=epoch_record,
        y=values,
        color=colour,
        line_width=2,
        alpha=0.8,
        legend_label=legend,
    )

show(loss_plot)

In [None]:
# Evaluate the trained model on the validation data.
# Evaluation mode and no_grad mirror the evaluation step of fit(): no autograd
# graph is built while the validation batches are processed.
model.eval()
with torch.no_grad():
    val_losses, val_loss_sizes = zip(
        *[loss_batch(model, loss_func, xb, yb, opt=None) for xb, yb in valid_dl]
        )
    val_accurs, val_acc_sizes = zip(*[acc_batch(model, xb, yb) for xb, yb in valid_dl])

# Weight each batch by its size so that a smaller last batch does not bias the
# averages (same averaging as in fit())
val_loss = float(np.sum(np.multiply(val_losses, val_loss_sizes)) / np.sum(val_loss_sizes))
val_acc  = float(np.sum(np.multiply(val_accurs, val_acc_sizes))  / np.sum(val_acc_sizes))

print(f"Validation cross entropy loss: {val_loss}")
print(f"Validation accuracy: {val_acc}")

In [49]:
#torch.save(model.state_dict(), "astronet_trained_w_sim_v01.pt")
#torch.save(optimizer.state_dict(), "astronet_trained_w_sim_v01.opt")

In [50]:
def pooling_reduction(input_dim, order):
    """Length of a vector after a number of consecutive poolings.

    Assumes that every pooling uses stride = 2 and 2*padding = kernel_size-1,
    so each pooling maps a length L to floor((L - 1) / 2) + 1.

    Args:
        input_dim (int): length of the vector before pooling
        order (int): number of poolings applied; 0 leaves the length unchanged

    Returns:
        int: length of the vector after the poolings

    Raises:
        ValueError: if order is negative

    """
    if order < 0:
        raise ValueError(f"order must be non-negative, got {order}")

    # Iterating with integer floor division avoids both the unbounded recursion
    # for order < 1 and the float round trip of the recursive formulation
    output_dim = input_dim
    for _ in range(order):
        output_dim = (output_dim - 1) // 2 + 1
    return int(output_dim)

In [61]:
# Scratch check: flattening channels and features keeps the batch dimension,
# also for a tensor whose last dimension is empty
lb = torch.flatten(torch.ones([50, 1, 5]), 1, 2)
sb = torch.flatten(torch.zeros([50, 1, 0]), 1, 2)