In [None]:
import os
import numpy as np
import astropy.io.fits as pyfits
from scipy.ndimage import gaussian_filter
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [None]:
### CBAM Block
class ChannelAttention(nn.Module):
    """Channel-attention gate of CBAM.

    Each channel is squeezed to a scalar twice (global average pooling and
    global max pooling). Both squeezed vectors go through one shared bottleneck
    MLP, built from 1x1 convolutions. The two results are summed and passed
    through a sigmoid.

    Args:
        in_channels: number of channels of the input feature map.
        reduction_ratio: bottleneck reduction factor of the shared MLP. The
            hidden width is ``in_channels // reduction_ratio``, clamped to at
            least 1 so that narrow inputs still get a working bottleneck.

    Returns (forward):
        Per-channel weights in [0, 1] of shape (N, C, 1, 1), meant to be
        multiplied with the input.
    """

    def __init__(self, in_channels, reduction_ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # global average pooling
        self.max_pool = nn.AdaptiveMaxPool2d(1)  # global max pooling

        # Without the clamp, in_channels < reduction_ratio yields a 0-wide hidden
        # layer. The MLP output would then no longer depend on the input, and the
        # gate would degenerate to a constant.
        hidden = max(1, in_channels // reduction_ratio)
        self.mlp = nn.Sequential(
            nn.Conv2d(in_channels, hidden, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, in_channels, 1),
        )

        # Kaiming weights / zero biases for both 1x1 convs (same order as before).
        for layer in (self.mlp[0], self.mlp[2]):
            nn.init.kaiming_normal_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        avg_out = self.mlp(self.avg_pool(x))
        max_out = self.mlp(self.max_pool(x))
        out = avg_out + max_out  # fuse the two channel descriptors
        return torch.sigmoid(out)  # normalized per-channel weights

class SpatialAttention(nn.Module):
    """Spatial-attention gate of CBAM.

    At every pixel, the channels are summarized by their mean and by their
    maximum. That 2-channel map is convolved with a single
    ``kernel_size`` x ``kernel_size`` filter (size-preserving padding, no bias)
    and squashed with a sigmoid.

    Args:
        kernel_size: 3 or 7; any other value fails an assertion.

    Returns (forward):
        Per-pixel weights in [0, 1] of shape (N, 1, H, W).
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size in (3, 7), "kernel size should be 3 or 7"
        # For the odd sizes allowed above, kernel_size // 2 keeps H and W unchanged.
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        stacked = torch.cat((channel_mean, channel_max), dim=1)
        return self.sigmoid(self.conv(stacked))

class CBAM(nn.Module):
    """Convolutional Block Attention Module: a channel gate, then a spatial gate.

    Args:
        in_channels: channels of the feature map being refined.
        reduction_ratio: forwarded to ``ChannelAttention``.
        kernel_size: forwarded to ``SpatialAttention`` (3 or 7).

    Returns (forward):
        A tensor with the same shape as the input.
    """

    def __init__(self, in_channels, reduction_ratio=16, kernel_size=7):
        super().__init__()
        self.ca = ChannelAttention(in_channels, reduction_ratio)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        # (N, C, 1, 1) channel weights broadcast over H and W.
        channel_refined = self.ca(x) * x
        # (N, 1, H, W) spatial weights broadcast over channels.
        spatial_gate = self.sa(channel_refined)
        return spatial_gate * channel_refined

### ResNet Convolution Block
class ResidualConv(nn.Module):
    """Residual block with BN -> ReLU -> 3x3 conv ordering and a learned skip.

    Main branch: BN, ReLU, 3x3 conv (``stride``/``padding``), BN, ReLU,
    3x3 conv (padding 1). Skip branch: 3x3 conv (``stride``/``padding``) + BN.
    The two branches are summed.

    Args:
        input_dim: input channels.
        output_dim: output channels.
        stride: stride of the first main-branch conv and of the skip conv.
        padding: padding of the first main-branch conv. The skip conv uses the
            same value, so both branches always produce the same spatial size.
    """

    def __init__(self, input_dim, output_dim, stride=1, padding=1):
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.BatchNorm2d(input_dim),
            nn.ReLU(),
            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=padding),
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
        )
        # The skip conv must mirror the first main-branch conv's stride AND
        # padding. With a fixed padding=1 here, any other `padding` would make
        # the two branches differ in size and the addition would fail.
        self.conv_skip = nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=padding),
            nn.BatchNorm2d(output_dim),
        )

    def forward(self, x):
        return self.conv_block(x) + self.conv_skip(x)


class MyModel(nn.Module):
    """Residual U-Net with CBAM attention that predicts one logit per pixel.

    Structure:
    - Five ``ResidualConv`` encoder stages of widths d, 2d, 4d, 8d and 16d,
      separated by 2x2 max pooling.
    - A bottleneck ``conv_block`` of width 32d.
    - Five decoder levels. Each upsamples with a transposed conv, concatenates
      the matching encoder output, and runs a ``conv_block``.
    - CBAM refines the decoder outputs of widths 8d, 2d and d.

    The input is pooled five times, so H and W must be divisible by 32 for the
    skip concatenations to line up.

    Args:
        in_channels: channels of the input image.
        d: base channel width.
        dp: dropout probability used inside every ``conv_block``.

    Returns (forward):
        Logits of shape (N, 1, H, W) for an input of shape (N, in_channels, H, W).
    """

    def __init__(self, in_channels=1, d=32, dp=0.5):
        super().__init__()
        self.d = d
        self.dp = dp  # read by conv_block, so it must be set before any block is built

        # Widths per depth: w1 = d, w2 = 2d, ..., w6 = 32d (bottleneck).
        w1, w2, w3, w4, w5, w6 = (d * 2 ** k for k in range(6))

        # Contracting path.
        self.encoder1 = ResidualConv(in_channels, w1)
        self.encoder2 = ResidualConv(w1, w2)
        self.encoder3 = ResidualConv(w2, w3)
        self.encoder4 = ResidualConv(w3, w4)
        self.encoder5 = ResidualConv(w4, w5)

        self.pool = nn.MaxPool2d(2)

        # Bottleneck.
        self.mid = self.conv_block(w5, w6)

        # Expanding path. Each up-conv halves the width. The decoder block then
        # takes [upsampled, skip] (twice that width) and halves it again.
        self.up5 = nn.ConvTranspose2d(w6, w5, 2, stride=2)
        self.dec5 = self.conv_block(2 * w5, w5)
        self.up4 = nn.ConvTranspose2d(w5, w4, 2, stride=2)
        self.dec4 = self.conv_block(2 * w4, w4)
        self.up3 = nn.ConvTranspose2d(w4, w3, 2, stride=2)
        self.dec3 = self.conv_block(2 * w3, w3)
        self.up2 = nn.ConvTranspose2d(w3, w2, 2, stride=2)
        self.dec2 = self.conv_block(2 * w2, w2)
        self.up1 = nn.ConvTranspose2d(w2, w1, 2, stride=2)
        self.dec1 = self.conv_block(2 * w1, w1)

        # 1x1 head producing a single logit map.
        self.final = nn.Conv2d(w1, 1, kernel_size=1)

        self.cbam4 = CBAM(w4)
        self.cbam2 = CBAM(w2)
        self.cbam0 = CBAM(w1)

    def conv_block(self, in_ch, out_ch):
        """Two (3x3 conv -> BN -> LeakyReLU -> Dropout) units; the first maps in_ch -> out_ch."""
        layers = []
        for src_ch in (in_ch, out_ch):
            layers += [
                nn.Conv2d(src_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(inplace=True),
                nn.Dropout(self.dp),
            ]
        return nn.Sequential(*layers)

    @staticmethod
    def _decode(up, dec, below, skip):
        """Upsample ``below``, concatenate it with ``skip`` on channels, and apply ``dec``."""
        return dec(torch.cat([up(below), skip], dim=1))

    def forward(self, x):
        # Encoder: every stage after the first sees the max-pooled previous output.
        skips = []
        feat = x
        for stage in (self.encoder1, self.encoder2, self.encoder3,
                      self.encoder4, self.encoder5):
            feat = stage(self.pool(feat) if skips else feat)
            skips.append(feat)
        e1, e2, e3, e4, e5 = skips

        feat = self.mid(self.pool(e5))

        # Decoder, with CBAM after the 8d, 2d and d levels.
        feat = self._decode(self.up5, self.dec5, feat, e5)
        feat = self.cbam4(self._decode(self.up4, self.dec4, feat, e4))
        feat = self._decode(self.up3, self.dec3, feat, e3)
        feat = self.cbam2(self._decode(self.up2, self.dec2, feat, e2))
        feat = self.cbam0(self._decode(self.up1, self.dec1, feat, e1))

        return self.final(feat)

### Dataset and DataLoader


In [None]:
class AstroDataset(Dataset):
    """Map/catalog pairs stored as FITS files.

    Each item starts with a (1, H, W) float tensor. It holds the primary-HDU
    image, Gaussian-smoothed with the module-level ``SMOOTHING`` sigma.

    When catalog paths are given, the item is a 3-tuple:
    - the image tensor;
    - an (H, W) binary target with 1.0 at each catalog position (catalog
      columns 1 and 2 are read as y and x);
    - the catalog path.

    Args:
        map_paths: sequence of FITS image paths.
        cat_paths: optional sequence of FITS catalog paths, parallel to
            ``map_paths``; ``None`` for unlabeled data.
    """

    def __init__(self, map_paths, cat_paths=None):
        self.map_paths = map_paths
        self.cat_paths = cat_paths

    def __len__(self):
        return len(self.map_paths)

    def __getitem__(self, idx):
        # `with` closes the FITS handle. gaussian_filter returns a new array, so
        # nothing keeps referencing the HDU data after the file is closed.
        with pyfits.open(self.map_paths[idx]) as hdul:
            map_data = gaussian_filter(hdul[0].data, sigma=SMOOTHING)
        image = torch.FloatTensor(map_data).unsqueeze(0)

        if self.cat_paths is None:
            return image

        with pyfits.open(self.cat_paths[idx]) as hdul:
            cat = np.array(hdul[0].data)  # copy so the data outlives the handle

        # Size the target from the map instead of assuming 1024x1024.
        target = np.zeros(map_data.shape[-2:], dtype=np.float32)
        height, width = target.shape
        for y, x in cat[:, 1:3]:
            row, col = int(y), int(x)
            # Skip positions outside the map. A negative index would otherwise
            # wrap silently to the opposite edge, and row/col == size would raise.
            if 0 <= row < height and 0 <= col < width:
                target[row, col] = 1.0

        return image, torch.FloatTensor(target), self.cat_paths[idx]

### Model and Training Functions

In [None]:
# Default device for the helper functions below. Fall back to CPU when no GPU is
# visible so the helpers still run.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Apply a trained model to a single map and return the detected pixel coordinates
def detect_objects(model, image_path, confidence_threshold, device=DEVICE):
    """Run ``model`` on one FITS map and return the pixels above a threshold.

    The map is read from the primary HDU and smoothed with the module-level
    ``SMOOTHING`` sigma. The model's output goes through a sigmoid, and every
    pixel whose probability exceeds ``confidence_threshold`` is reported.

    Args:
        model: network applied to a (1, 1, H, W) tensor; its output is assumed
            to squeeze to an (H, W) logit map. It must already live on ``device``.
        image_path: path to the FITS map.
        confidence_threshold: strict lower bound on the sigmoid probability.
        device: device the input tensor is moved to.

    Returns:
        List of ``(x, y, confidence)`` tuples in row-major order (by y, then x).
        The model is left in eval mode.
    """
    model.eval()
    with pyfits.open(image_path) as hdul:
        map_data = gaussian_filter(hdul[0].data, sigma=SMOOTHING)
    image = torch.FloatTensor(map_data).unsqueeze(0).unsqueeze(0).to(device)

    with torch.no_grad():
        probabilities = torch.sigmoid(model(image)).cpu().squeeze().numpy()

    # np.nonzero yields indices in row-major order, like scanning pixel by pixel.
    ys, xs = np.nonzero(probabilities > confidence_threshold)
    return list(zip(xs, ys, probabilities[ys, xs]))

# Visualizing model predictions
def visualize_results(model, image_path, label_path, confidence_threshold, device=DEVICE):
    """Overlay predicted and catalog source positions on the smoothed map.

    Args:
        model, image_path, confidence_threshold, device: forwarded to
            ``detect_objects``.
        label_path: FITS catalog. After transposing, row 2 is plotted as x and
            row 1 as y.

    Catalog positions are drawn as red circles and predictions as green
    circles. All predictions share one color regardless of confidence.
    """
    detections = detect_objects(model, image_path, confidence_threshold, device)
    # Shape (3, K): rows are x, y, confidence. It is (3, 0) when nothing passes the threshold.
    results = np.array(detections).T if detections else np.array([[], [], []])

    # Read inside `with` so both FITS handles are closed; copy the catalog so it
    # stays valid after the file is closed.
    with pyfits.open(image_path) as hdul:
        smoothed = gaussian_filter(hdul[0].data, sigma=SMOOTHING)
    with pyfits.open(label_path) as hdul:
        labels = np.transpose(np.array(hdul[0].data))

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(smoothed, vmin=-0.1, vmax=0.2, cmap='binary')
    ax.scatter(labels[2], labels[1], facecolors='none', edgecolors='red', s=100, label="True")
    if results.shape[1] > 0:
        ax.scatter(results[0], results[1], facecolors='none', edgecolors='green', s=100, label="Predicted")
    else:
        print("result length 0")
    ax.set_title("Catalog (red) vs. predicted (green) sources")
    ax.legend()
    plt.show()


'''
Metric function for calculating PR-AUC.
This exact function will be used for evaluation
'''
def calculate_precision_recall_curve(predictions, labels, verbose = True):
    """Pixel-level precision-recall curve and PR-AUC for point-source detection.

    Every pixel of every image is one prediction, ranked by score in
    descending order.
    - A pixel is a true positive if it falls inside a roughly circular window
      of radius 15 pixels around at least one label of the same image.
    - A label counts as recalled once any pixel ranked so far falls inside its
      window.
    - After each pixel: precision = true positives / pixels seen, and
      recall = recalled labels / total labels.

    Args:
        predictions: array of per-pixel scores. Flat indices are decoded as
            image * 1024*1024 + y * 1024 + x, so this assumes a stack of
            1024x1024 maps, e.g. (n_images, 1024, 1024).
            NOTE(review): other map sizes are not supported.
        labels: one sequence of ``(y, x)`` label coordinates per image, in the
            same image order as ``predictions``.
        verbose: if true, print the shape of ``predictions``.

    Returns:
        ``(precisions, recalls, pr_auc)``: two lists with one entry per pixel,
        plus the trapezoidal area of precision over recall.

    Notes:
        - ``labels_within_distance`` holds one Python list per pixel, so memory
          grows with n_images * 1024 * 1024.
        - ``np.argsort`` uses an unstable sort by default, so equal scores may
          be visited in any order, which can change the curve. This is the
          evaluation metric: keep the code unchanged.
        - Raises ZeroDivisionError when ``labels`` holds no coordinates but
          ``predictions`` is non-empty.
    """

    if verbose:
        print("shape of predictions: ", predictions.shape)

    # Flatten the predictions and get the indices of the sorted predictions
    flat_predictions = predictions.flatten()
    sorted_indices = np.argsort(-flat_predictions)  # Sort in descending order

    precisions = []
    recalls = []

    true_preds = 0
    num_preds = 0
    predicted_labels = 0
    num_labels = sum(len(l) for l in labels)

    # For each flat pixel index, the global ids of labels whose window covers it.
    labels_within_distance = [[] for _ in range(len(flat_predictions))]

    # i is a global label id across all images; it indexes label_predicted.
    i = 0
    for image_idx, image_labels in enumerate(labels):
        for y_true, x_true in image_labels:
            # Rows within 15 px of the label, clipped to the 1024-pixel image.
            for y in range(max(0, int(y_true) - 15), min(1024, int(y_true) + 16)):
                # Calculate the maximum x distance for the current y
                max_x_dist = int((max(0, 15**2 - (y - y_true)**2))**0.5)
                # Calculate the range of x-coordinates
                for x in range(max(0, int(x_true) - max_x_dist), min(1024, int(x_true) + max_x_dist + 1)):
                    coord_idx = image_idx * 1024 * 1024 + y * 1024 + x
                    labels_within_distance[coord_idx].append(i)
            i += 1
    
    label_predicted = [False] * num_labels

    # Iterate over sorted predictions
    for idx in sorted_indices:

        num_preds += 1

        # Determine the image index and the coordinate within the image
        # (these decoded values are not used below)
        image_idx = idx // (1024 * 1024)
        coord_idx = idx % (1024 * 1024)
        y, x = divmod(coord_idx, 1024)

        # A pixel near any label is a true positive; every label it covers
        # is marked as recalled the first time it is reached.
        if len(labels_within_distance[idx]) > 0:
            true_preds += 1
            for label in labels_within_distance[idx]:
                if label_predicted[label] is False:
                    label_predicted[label] = True
                    predicted_labels += 1

        # Calculate precision and recall
        precision = true_preds / num_preds
        recall = predicted_labels / num_labels

        # Append precision and recall to the lists
        precisions.append(precision)
        recalls.append(recall)

    # Calculate PR-AUC using the trapezoidal rule
    # NOTE(review): NumPy 2.x deprecates np.trapz in favor of np.trapezoid;
    # check the installed version before relying on this call long-term.
    pr_auc = np.trapz(precisions, x=recalls)

    return precisions, recalls, pr_auc

# Evaluate model
def get_pr(model, test_loader, device=DEVICE, verbose=True):
    """Compute the precision-recall curve and PR-AUC of ``model`` over a loader.

    Args:
        model: network returning logits for a batch of images; dimension 1 of
            its output is squeezed away (assumed to be a single channel).
        test_loader: DataLoader yielding ``(images, _, catalog_paths)`` batches.
        device: device the images are moved to.
        verbose: forwarded to ``calculate_precision_recall_curve``.

    Returns:
        ``(precisions, recalls, pr_auc)`` from ``calculate_precision_recall_curve``.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    outputs_list = []
    labels_list = []
    with torch.no_grad():
        for images, _, paths in test_loader:
            probs = torch.sigmoid(model(images.to(device)))
            outputs_list.append(probs.cpu().numpy().squeeze(1))
            for path in paths:
                # Copy inside `with` so the catalog outlives the closed handle.
                with pyfits.open(path) as hdul:
                    cat = np.transpose(np.array(hdul[0].data))
                # Catalog rows 1 and 2 form the (y, x) pairs the metric expects.
                labels_list.append(list(zip(cat[1], cat[2])))
    if not outputs_list:
        raise ValueError("test_loader yielded no batches; cannot compute a PR curve")
    # np.concatenate works on every NumPy version; np.concat exists only from 2.0.
    outputs = np.concatenate(outputs_list, axis=0)
    return calculate_precision_recall_curve(outputs, labels_list, verbose=verbose)

In [None]:
# Path to training data *** do not change *** 
DATA_DIR = "/bohr/training-lg02/v1/"

# Fall back to CPU so the helpers still run without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_SAVE_PATH = "model.pth"

SMOOTHING = 0.0   # Gaussian sigma applied to every map (0 = no smoothing)
BATCH_SIZE = 4
NUM_TRAIN = 56    # maps 1..NUM_TRAIN train the model; the remaining maps form the test set
SEED = 0

torch.manual_seed(SEED)

# Assumes files are named 1.fits .. N.fits in both map/ and cat/.
data_size = len(os.listdir(os.path.join(DATA_DIR, 'map')))
assert data_size > NUM_TRAIN, f"only {data_size} maps found; none left for the test split"

train_ids = range(1, NUM_TRAIN + 1)
test_ids = range(NUM_TRAIN + 1, data_size + 1)

dataset = AstroDataset(
    [os.path.join(DATA_DIR, f'map/{i}.fits') for i in train_ids],
    [os.path.join(DATA_DIR, f'cat/{i}.fits') for i in train_ids],
)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Evaluation does not need random order. A fixed order keeps the concatenated
# prediction array, and therefore the metric's handling of tied scores,
# identical across runs.
test_set = AstroDataset(
    [os.path.join(DATA_DIR, f'map/{i}.fits') for i in test_ids],
    [os.path.join(DATA_DIR, f'cat/{i}.fits') for i in test_ids],
)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Rebuild the model from scratch on every run of this cell. Test the *name*:
# `model in locals()` evaluates `model` itself and raises NameError on a fresh kernel.
if "model" in globals():
    del model

NUM_EPOCHS = 5
LEARNING_RATE = 0.003
POS_WEIGHT = 1000.0  # BCE weight on positive (source) pixels, which are rare in the targets
LOG_EVERY = 1        # print metrics every LOG_EVERY epochs

# Move the model before building the optimizer, as PyTorch recommends.
model = MyModel().to(DEVICE)
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([POS_WEIGHT], device=DEVICE))
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

train_losses = []

for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0

    for images, targets, _ in dataloader:
        images = images.to(DEVICE)
        targets = targets.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(images).squeeze(1)  # (B, 1, H, W) -> (B, H, W), matching targets
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    train_loss = total_loss / len(dataloader)
    train_losses.append(train_loss)

    # get_pr switches the model to eval mode itself. Each AUC is computed on its
    # own split (the two were previously swapped).
    _, _, train_auc = get_pr(model, dataloader, verbose=False)
    _, _, test_auc = get_pr(model, test_loader, verbose=False)

    if (epoch + 1) % LOG_EVERY == 0:
        print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Training Loss: {train_loss:.6f}, Train AUC: {train_auc:.6f}, Test AUC: {test_auc:.6f}')

### Inference

In [None]:
i = 30  # index into the training dataset
visualize_results(
    model,
    image_path=dataset.map_paths[i],
    label_path=dataset.cat_paths[i],
    confidence_threshold=0.5,
)