In [1]:
import os
from typing import Callable, Optional

from astropy.io import fits

from matplotlib.colors import LogNorm
import matplotlib.pyplot as plt

import numpy as np

import pandas as pd

import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, progress
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.model_selection import train_test_split

import torch
from torch import nn, Tensor
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader, Dataset, RandomSampler, random_split

import torchmetrics

import torchvision
import torchvision.transforms.v2 as T

In [2]:
# Seed torch (weight init, DataLoader shuffling, torchvision augmentations) and
# numpy's global RNG (which sklearn falls back to when no random_state is given).
SEED = 123
torch.manual_seed(SEED)
np.random.seed(SEED)

In [18]:
# Hardware
gpu = True                      # try MPS/CUDA before falling back to the CPU

# Model and optimisation
model_name = "regnet"           # "regnet" or "efficientnet"
optimiser = optim.AdamW         # optimiser class, instantiated in configure_optimizers
lr = 1e-4
num_channels = 1
num_outputs = 1

# Data and augmentation
augmentation_strategy = T.TrivialAugmentWide()   # set to None to disable
xy_size = 380                   # images are resized to xy_size x xy_size
batch_size = 32
test_split = 0.2                # fraction of all data held out for testing
val_split = 0.2                 # fraction of the remaining data used for validation

# Training
num_epochs = 30

In [6]:
# Prefer Apple MPS, then CUDA, and fall back to the CPU when `gpu` is disabled
# or no GPU backend is available. `accelerator` is the matching Lightning name.
if gpu and torch.backends.mps.is_available():
    device, accelerator = torch.device("mps"), "gpu"
elif gpu and torch.cuda.is_available():
    device, accelerator = torch.device("cuda:0"), "gpu"
else:
    device, accelerator = torch.device("cpu"), "cpu"
print(f"Using {device}")

Using mps


In [7]:
def normalise_channels(arr: np.ndarray) -> np.ndarray:
    """Min-max scale ``arr`` to the range [0, 1], in place.

    ``arr`` must have a floating-point dtype, because the division is done in
    place. A constant array (max == min) is mapped to all zeros instead of
    being filled with NaN by a 0/0 division.

    Returns:
        The same ``arr`` object, for convenience.
    """
    arr -= np.min(arr)
    peak = np.max(arr)
    # After the shift the minimum is 0, so peak == 0 means a constant input:
    # the array is already all zeros and dividing would only produce NaN.
    if peak > 0:
        arr /= peak
    return arr

In [8]:
def load_fits(
    folder_path: os.PathLike,
    labels_dict: dict[str, np.ndarray],
    normalise: bool = True,
):
    """Load every FITS image in ``folder_path`` together with a binary label.

    File names must contain ``planet<run>_`` (the run number is parsed from
    it). Files are ordered by run number. For each file, the first plane of
    the squeezed primary-HDU data is kept.

    Args:
        folder_path: directory containing the ``.fits`` files.
        labels_dict: mapping with ``"runs"`` (run numbers) and ``"n"``
            (per-run counts), aligned element-wise. It is only read, never
            modified.
        normalise: min-max scale each image to [0, 1] with
            ``normalise_channels``.

    Returns:
        ``(X, y)`` as float32 arrays. ``X`` has shape ``(N, 1, H, W)``. ``y``
        has shape ``(N, 1)`` and holds 1.0 where the run's ``n`` > 0, else 0.0.

    Raises:
        FileNotFoundError: if the folder contains no ``.fits`` files.
        KeyError: if a file's run number has no entry in ``labels_dict``.
    """
    data_names = np.array([x for x in os.listdir(folder_path) if ".fits" in x])
    if data_names.size == 0:
        raise FileNotFoundError(f"No .fits files found in {folder_path!r}")

    run_nums = np.array([int(x.split("planet")[1].split("_")[0]) for x in data_names])
    order = np.argsort(run_nums)
    run_nums = run_nums[order]
    data_names = data_names[order]

    runs = np.asarray(labels_dict["runs"])
    counts = np.asarray(labels_dict["n"])

    images = []
    targets = []
    for name, run in zip(data_names, run_nums):
        with fits.open(os.path.join(folder_path, name)) as hdul:
            # Copy into a float array while the file is open: the copy stays
            # valid after closing, and the in-place normalisation needs floats.
            image = np.array(hdul[0].data.squeeze()[0], dtype=np.float64)
        if normalise:
            normalise_channels(image)

        matches = np.flatnonzero(runs == run)
        if matches.size == 0:
            raise KeyError(f"No label for run {run} (file {name!r})")
        images.append(image)
        # Binary target: the run has at least one planet.
        targets.append(int(counts[matches[0]] > 0))

    X = np.stack(images)[:, np.newaxis]
    y = np.asarray(targets)[:, np.newaxis]
    return (X.astype(np.float32), y.astype(np.float32))

In [9]:
def load_labels(path) -> dict[str, np.ndarray]:
    """Read the run -> planet-count table from the CSV at ``path``.

    The CSV must have ``run`` and ``n`` columns. A leftover pandas index
    column (``"Unnamed: 0"``) is dropped if present.

    Returns:
        ``{"runs": <run numbers>, "n": <counts>}`` as aligned numpy arrays.
    """
    # Previously this ignored `path` and always read "train_info.csv".
    label_df = pd.read_csv(path).drop(columns="Unnamed: 0", errors="ignore")
    return {
        "runs": label_df["run"].to_numpy(),
        "n": label_df["n"].to_numpy(),
    }

In [10]:
# Run-number -> planet-count table, then the images with their binary labels.
labels = load_labels(path="train_info.csv")
X, y = load_fits(folder_path="Full_Train_Data", labels_dict=labels)

In [11]:
# Hold out a test set, then carve a validation set out of the remainder.
# Stratify on the binary label so every split keeps the class balance and
# contains both classes, which roc_auc_score requires. A fixed random_state
# makes re-running this cell reproduce the same split.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=test_split, stratify=y.ravel(), random_state=123
)
X_train, X_val, y_train, y_val = train_test_split(
    X_train, y_train, test_size=val_split, stratify=y_train.ravel(), random_state=123
)

In [12]:
class DiskDataset(Dataset):
    """Map-style dataset over in-memory image/label arrays.

    Each item is an ``(image, label)`` pair of tensors built with
    ``torch.from_numpy``. The image is optionally passed through ``transform``,
    and both tensors are moved to ``device`` when one is given.

    Args:
        X: images, indexed along the first axis.
        y: labels, indexed along the first axis; must match ``len(X)``.
        transform: callable applied to each image tensor. ``None`` (or any
            falsy value) skips it.
        device: device each sample is moved to. ``None`` leaves samples on the
            CPU, e.g. so the DataLoader / Lightning can do the transfer.
        num_workers: stored on the instance only. This class does not use it;
            pass ``num_workers`` to ``DataLoader`` to get parallel loading.

    Raises:
        ValueError: if ``X`` and ``y`` have different lengths.
    """

    def __init__(
        self,
        X: np.ndarray,
        y: np.ndarray,
        transform: Optional[Callable[[Tensor], Tensor]],
        device: Optional[torch.device],
        num_workers: int = 13,
    ) -> None:
        if len(X) != len(y):
            raise ValueError(
                f"X and y must have the same length, got {len(X)} and {len(y)}"
            )
        self.X = X
        self.y = y
        self.transform = transform
        self.num_workers = num_workers
        self.device = device

    def __len__(self) -> int:
        return len(self.X)

    def __getitem__(self, idx) -> tuple[Tensor, Tensor]:
        x_ = torch.from_numpy(self.X[idx])
        y_ = torch.from_numpy(self.y[idx])

        if self.transform:
            x_ = self.transform(x_)
        if self.device is not None:
            x_, y_ = x_.to(self.device), y_.to(self.device)
        return x_, y_

In [13]:
# Deterministic preprocessing shared by every split: resize to the model's input
# size, then Normalize computes (x - 0.5) / 0.5 per channel.
transform = T.Compose(
    [
        T.Resize((xy_size, xy_size), antialias=True),
        T.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Training additionally applies the random augmentation (when configured)
# before the shared preprocessing.
train_transform = (
    transform
    if augmentation_strategy is None
    else T.Compose([augmentation_strategy, transform])
)

In [14]:
# Only the training split is augmented; validation and test use the deterministic
# preprocessing so their metrics are comparable across epochs.
train_data = DiskDataset(X_train, y_train, transform=train_transform, device=device)
val_data = DiskDataset(X_val, y_val, transform=transform, device=device)
test_data = DiskDataset(X_test, y_test, transform=transform, device=device)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size)

In [22]:
class CustomModel(pl.LightningModule):
    """Shared training/evaluation logic for the binary image classifiers below.

    Subclasses must set ``self.model`` (the network called by ``forward``) and
    take an ``lr`` argument in their own ``__init__``.
    ``save_hyperparameters`` records the constructor arguments, and
    ``configure_optimizers`` reads ``self.hparams.lr`` back. The optimiser
    class comes from the notebook-level ``optimiser`` variable.

    Args:
        num_channels: number of input image channels.
        xy_dim: height/width of the square example input used for
            ``example_input_array`` (the model summary's In/Out sizes).
    """

    def __init__(
        self,
        num_channels: int,
        xy_dim: int,
    ):
        super().__init__()
        self.save_hyperparameters()

        # The model outputs raw logits; the sigmoid is folded into the loss.
        self.criterion = nn.BCEWithLogitsLoss()

        # Per-batch outputs accumulated during val/test and consumed (then
        # cleared) at epoch end, so AUC is computed over the whole epoch.
        self.validation_outputs = []
        self.test_outputs = []

        self.example_input_array = torch.randn((1, num_channels, xy_dim, xy_dim)).float()

    def forward(self, x):
        return self.model(x)

    def _process_batch(self, batch, when: str = "train"):
        """Compute and log ``{when}_loss``.

        Returns the loss for training. For val/test it returns a dict that
        also carries the logits and targets.
        """
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log(f"{when}_loss", loss)
        if when != "train":
            return {f"{when}_loss": loss, "y_hat": y_hat, "y": y}
        return loss

    def training_step(self, batch, batch_idx):
        return self._process_batch(batch, when="train")

    def validation_step(self, batch, batch_idx):
        outputs = self._process_batch(batch, when="val")
        self.validation_outputs.append(outputs)
        return outputs

    def test_step(self, batch, batch_idx):
        outputs = self._process_batch(batch, when="test")
        self.test_outputs.append(outputs)
        return outputs

    def configure_optimizers(self):
        return optimiser(
            self.parameters(),
            lr=self.hparams.lr,
        )

    def _roc_epoch_end(self, outputs, when: str = "val"):
        """Log ``{when}_auc`` over every batch collected this epoch.

        Logging is skipped when nothing was collected. It is also skipped,
        with a warning, when the targets hold a single class, because ROC AUC
        is undefined there (this can happen on a short sanity-check pass).
        """
        import warnings

        if not outputs:
            return
        # .float() so reduced-precision outputs can be converted to numpy.
        y_hat = torch.cat([x["y_hat"] for x in outputs]).detach().float().cpu().numpy()
        y = torch.cat([x["y"] for x in outputs]).detach().float().cpu().numpy()
        auc = self.calculate_auc(y_hat, y)
        if np.isnan(auc):
            warnings.warn(
                f"Skipping {when}_auc: targets contain a single class, so ROC AUC is undefined.",
                stacklevel=2,
            )
            return
        self.log(f"{when}_auc", float(auc))

    def on_validation_epoch_end(self,):
        self._roc_epoch_end(self.validation_outputs, when="val")
        self.validation_outputs.clear()

    def on_test_epoch_end(self,):
        self._roc_epoch_end(self.test_outputs, when="test")
        self.test_outputs.clear()

    def calculate_auc(self, y_hat, y):
        """ROC AUC of logits ``y_hat`` against binary targets ``y``.

        Logits go through a sigmoid first. AUC depends only on ranking, so
        this does not change the value. Returns ``np.float32('nan')`` instead
        of raising when ``y`` contains fewer than two classes.
        """
        y = np.asarray(y)
        if np.unique(y).size < 2:
            return np.float32(np.nan)
        probs = torch.sigmoid(torch.as_tensor(y_hat, dtype=torch.float32)).numpy()
        return np.float32(roc_auc_score(y, probs))
    
        

class CustomEfficientNetv2(CustomModel):
    """EfficientNetV2-S built without pretrained weights.

    The stem conv is rebuilt for ``num_channels`` inputs, and the final linear
    layer is resized to emit ``num_outputs`` logits. ``lr`` is not used here;
    it is captured by the base class's ``save_hyperparameters`` for
    ``configure_optimizers``.
    """

    def __init__(self,
                 num_channels: int,
                 num_outputs: int,
                 lr: float,
                 xy_dim: int,
                ):
        super().__init__(num_channels, xy_dim)
        backbone = torchvision.models.efficientnet_v2_s()

        # torchvision's stem conv expects 3 channels; for any other count,
        # swap in a conv with the same output width and geometry.
        if num_channels != 3:
            stem_conv = backbone.features[0][0]
            backbone.features[0][0] = nn.Conv2d(
                num_channels,
                stem_conv.out_channels,
                kernel_size=stem_conv.kernel_size,
                stride=stem_conv.stride,
                padding=stem_conv.padding,
                bias=False,
            )

        head = backbone.classifier
        head[1] = nn.Linear(head[1].in_features, num_outputs)
        self.model = backbone

class CustomRegNet(CustomModel):
    """RegNetY-16GF built without pretrained weights.

    The stem conv is rebuilt for ``num_channels`` inputs, and the ``fc`` head
    is resized to emit ``num_outputs`` logits. ``lr`` is not used here; it is
    captured by the base class's ``save_hyperparameters`` for
    ``configure_optimizers``.
    """

    def __init__(
        self,
        num_channels: int,
        num_outputs: int,
        lr: float,
        xy_dim: int,
    ):
        super().__init__(num_channels, xy_dim)
        backbone = torchvision.models.regnet_y_16gf()

        # torchvision's stem conv expects 3 channels; for any other count,
        # rebuild it keeping out_channels, kernel/stride/padding and bias usage.
        if num_channels != 3:
            stem_conv = backbone.stem[0]
            backbone.stem[0] = nn.Conv2d(
                in_channels=num_channels,
                out_channels=stem_conv.out_channels,
                kernel_size=stem_conv.kernel_size,
                stride=stem_conv.stride,
                padding=stem_conv.padding,
                bias=stem_conv.bias is not None,
            )

        backbone.fc = nn.Linear(backbone.fc.in_features, num_outputs)
        self.model = backbone

    

In [23]:
# Map each supported architecture name to its LightningModule class. Fail fast
# on an unknown name rather than leaving `model` undefined, or silently reusing
# a `model` left in the kernel by an earlier run.
model_classes = {
    "efficientnet": CustomEfficientNetv2,
    "regnet": CustomRegNet,
}
if model_name not in model_classes:
    raise ValueError(
        f"Unknown model_name {model_name!r}; expected one of {sorted(model_classes)}"
    )
model = model_classes[model_name](
    num_channels=num_channels,
    num_outputs=num_outputs,
    lr=lr,
    xy_dim=xy_size,
)

In [24]:
trainer = pl.Trainer(
    devices=1,
    accelerator=accelerator,
    max_epochs=num_epochs,
    log_every_n_steps=1,
    callbacks=[
        LearningRateMonitor("epoch"),
        progress.TQDMProgressBar(refresh_rate=1),
        # AUC is higher-is-better, so stop when it stops *increasing*.
        # mode="min" treated every AUC improvement as a regression.
        EarlyStopping(
            monitor="val_auc",
            min_delta=0,
            patience=20,
            verbose=False,
            mode="max",
        ),
    ],
)
# NOTE(review): these set private logger attributes. With the CSVLogger fallback
# that Lightning reports when TensorBoard is missing, they may have no effect.
trainer.logger._log_graph = True
trainer.logger._default_hp_metric = None

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/jamesmitchell-white/Documents/GitHub/DeepLearnHackathon/ExoplanetSearchChallenge/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [None]:
# Train; the validation loader feeds val_loss / val_auc, which early stopping monitors.
model = model.to(device)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)


  | Name      | Type              | Params | Mode  | In sizes         | Out sizes
---------------------------------------------------------------------------------------
0 | criterion | BCEWithLogitsLoss | 0      | train | ?                | ?        
1 | model     | RegNet            | 80.6 M | train | [1, 1, 380, 380] | [1, 1]   
---------------------------------------------------------------------------------------
80.6 M    Trainable params
0         Non-trainable params
80.6 M    Total params
322.270   Total estimated model params size (MB)
385       Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                            …

/Users/jamesmitchell-white/Documents/GitHub/DeepLearnHackathon/ExoplanetSearchChallenge/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=13` in the `DataLoader` to improve performance.
/Users/jamesmitchell-white/Documents/GitHub/DeepLearnHackathon/ExoplanetSearchChallenge/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=13` in the `DataLoader` to improve performance.


Training: |                                                                                                   …

Validation: |                                                                                                 …

In [None]:
model.eval()
model.to(device)
test_loader = DataLoader(test_data, batch_size=batch_size)

results = []
with torch.no_grad():
    for X_batch, _ in test_loader:  # labels are taken from y_test below
        probs = torch.sigmoid(model(X_batch)).cpu()
        # reshape(-1) instead of squeeze(): a final batch holding one sample
        # would otherwise become 0-d and break np.concatenate.
        results.append(probs.numpy().reshape(-1))

y_pred = np.concatenate(results, axis=0)
y_true = y_test.ravel()
fpr, tpr, _ = roc_curve(y_true, y_pred)
auc = roc_auc_score(y_true, y_pred)
# p > 0.5 -> positive; this matches the previous round(p), which maps 0.5 to 0.
accuracy = np.mean((y_pred > 0.5) == y_true)

print(f"Accuracy of {accuracy:.2}. AUC of {auc}")

In [None]:
fig, ax = plt.subplots(figsize=(10.0, 7.5))

ax.plot(fpr, tpr, lw=3, c="steelblue", label=f"Model (AUC = {auc:.3f})")
# The diagonal is the ROC curve of a random classifier, shown for reference.
ax.plot([0, 1], [0, 1], c="gray", ls="--", alpha=0.5, lw=3, label="Chance")

ax.set_xlabel("FPR", fontsize=14)
ax.set_ylabel("TPR", fontsize=14)
ax.set_title("Test-set ROC curve", fontsize=16)
ax.tick_params(labelsize=12)
ax.legend(fontsize=12)

plt.show()