In [None]:
import os
import math
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import astropy.io.fits as pyfits
from torchsummary import summary
import argparse
import os
import math
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import astropy.io.fits as pyfits
from torchsummary import summary
import argparse
import pdb

# Colab only: mount Google Drive so the /content/drive/MyDrive/NN_deconv/...
# paths used below for input data, previews and checkpoints resolve.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
class Atanh(nn.Module):
    """Trainable scaled hyperbolic tangent: ``amp * tanh(slope * x)``.

    Despite the name this is *not* the inverse hyperbolic tangent: the
    original formula ``(exp(2*s*x) - 1) / (exp(2*s*x) + 1)`` is exactly
    ``tanh(s*x)``. Both the amplitude and the slope are learnable scalars
    initialised from a standard normal draw.
    """

    def __init__(self):
        super(Atanh, self).__init__()
        # Attribute names are part of the state_dict keys; keep them stable.
        self._amp = nn.Parameter(torch.randn(1))
        self._slope = nn.Parameter(torch.randn(1))

    def forward(self, x):
        # torch.tanh is evaluated stably. The explicit exp-ratio overflowed to
        # inf/inf = NaN once |2 * slope * x| exceeded ~89 (float32), and it
        # computed exp() twice.
        return self._amp * torch.tanh(self._slope * x)


In [None]:
class AReLU(nn.Module):
    """Leaky-ReLU-like activation with a single learnable scalar ``c``.

    Computes ``max(c * x, x)`` element-wise; ``c`` starts as one
    standard-normal draw.
    """

    def __init__(self):
        super().__init__()
        # `_const` is the state_dict key, so the attribute name is fixed.
        self._const = nn.Parameter(torch.randn(1))

    def forward(self, x):
        scaled = self._const * x
        return torch.maximum(x, scaled)

In [None]:
class AReLU2(nn.Module):
    def __init__(self, init_const=None):
        """
        Trainable activation computing ``max(c * x, x)`` element-wise.

        Args:
            init_const (float, optional): Initial value for the trainable
                constant ``c``. When ``None`` (the default, matching the
                original behaviour) it is drawn from a standard normal.
        """
        super(AReLU2, self).__init__()

        # Create a trainable float32 parameter (registered in parameters()).
        # Build the tensor directly: wrapping an existing tensor in
        # torch.tensor(...) copy-constructs it and emits a UserWarning.
        if init_const is None:
            initial = torch.randn(1, dtype=torch.float32)
        else:
            initial = torch.tensor([float(init_const)], dtype=torch.float32)
        self._const = nn.Parameter(initial)

    def forward(self, x):
        """
        Forward pass with trainable activation

        Args:
            x (torch.Tensor): Input tensor

        Returns:
            torch.Tensor: ``max(c * x, x)``, broadcast against the scalar ``c``
        """
        return torch.maximum(self._const * x, x)


In [None]:
class Automap(nn.Module):
    """Decoder mapping a flat input vector to a single-channel 45x45 image.

    A bias-free linear layer lifts the input to a 256x6x6 feature map; three
    transposed convolutions upsample it (6 -> 12 -> 24 -> 46) and a final
    4x4 convolution trims it to 45x45, followed by the trainable ``AReLU2``.

    Args:
        input_shape (int or tuple/list of int): number of input features per
            sample, or a shape whose product is that number. Defaults to
            45*45 = 2025 (note ``(45*45)`` is the int 2025, not a tuple).

    Input:  tensor of shape (batch, in_features).
    Output: tensor of shape (batch, 1, 45, 45).
    """

    def __init__(self, input_shape=(45*45)):
        super(Automap, self).__init__()
        self.input_shape = input_shape
        # The original hard-coded 45*45 for the Linear layer and silently
        # ignored input_shape.
        if isinstance(input_shape, (tuple, list)):
            in_features = math.prod(input_shape)
        else:
            in_features = int(input_shape)
        # Layer order is unchanged so state_dict keys (model.0.weight, ...)
        # still match previously saved checkpoints.
        self.model = nn.Sequential(
            nn.Linear(in_features, 6*6*256, bias=False),
            nn.PReLU(),
            nn.Unflatten(1, (256, 6, 6)),

            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),  # -> 12x12
            nn.PReLU(),
            nn.Dropout(0.1),

            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),   # -> 24x24
            nn.PReLU(),
            nn.Dropout(0.1),

            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=2, bias=False),    # -> 46x46
            nn.PReLU(),
            nn.Dropout(0.1),

            nn.Conv2d(32, 1, kernel_size=4, stride=1, padding=1, bias=False),              # -> 45x45
            AReLU2()
        )

    def forward(self, x):
        """Run the decoder on ``x`` of shape (batch, in_features)."""
        return self.model(x)

In [None]:
def combine_images(generated_images):
    """Tile a batch of single-channel images into one 2-D mosaic.

    Images are placed row-major on a grid ``isqrt(N)`` columns wide and
    ``ceil(N / columns)`` rows tall; unused cells stay zero.

    Args:
        generated_images: array-like of shape (N, C, H, W); only channel 0 of
            each image is used. A CPU torch tensor is accepted as well.

    Returns:
        numpy.ndarray: float32 mosaic of shape (rows * H, columns * W).

    Raises:
        ValueError: if the batch is empty (the grid would have no columns).
    """
    images = np.asarray(generated_images)
    num = images.shape[0]
    if num == 0:
        raise ValueError("combine_images() needs at least one image")

    width = math.isqrt(num)
    height = (num + width - 1) // width  # integer ceil(num / width)
    tile_h, tile_w = images.shape[2:4]
    mosaic = np.zeros((height * tile_h, width * tile_w), dtype=np.float32)

    for index, img in enumerate(images):
        row, col = divmod(index, width)
        mosaic[row * tile_h:(row + 1) * tile_h, col * tile_w:(col + 1) * tile_w] = img[0]

    return mosaic

In [None]:
def _load_training_data(data_dir):
    """Load (targets, inputs) float32 tensors of shape (N, 1, 45, 45) and (N, features)."""
    # Targets ("Deconvolved ims"). The context manager closes the .npz handle.
    with np.load(os.path.join(data_dir, 'io_outputs_10_2025.npz')) as archive:
        cube = archive['cube']
    # Same axis permutation as the original three swapaxes calls:
    # (1, 45, 45, N) -> (N, 1, 45, 45). N is inferred instead of hard-coded 30000.
    # NOTE(review): assumes 'cube' is laid out as 45x45 pixels by N samples — confirm.
    targets = np.transpose(cube.reshape(1, 45, 45, -1), (3, 0, 1, 2))

    # Inputs ("Convolved ims"); transposed so samples run along axis 0.
    with np.load(os.path.join(data_dir, 'io_models_10_2025.npz')) as archive:
        inputs = archive['cube'].T

    if targets.shape[0] == 0:
        raise ValueError("training data is empty")
    if inputs.shape[0] != targets.shape[0]:
        raise ValueError(
            f"input/target sample counts differ: {inputs.shape[0]} vs {targets.shape[0]}")

    return torch.from_numpy(targets).float(), torch.from_numpy(inputs).float()


def _save_preview(model, batch_inputs, preview_dir, epoch, index):
    """Save PNG and FITS mosaics of the model's predictions for one batch."""
    was_training = model.training
    model.eval()  # disable dropout so previews show the deterministic model
    try:
        with torch.no_grad():
            predictions = model(batch_inputs)
    finally:
        model.train(was_training)

    image = combine_images(predictions.cpu().numpy())
    stem = os.path.join(preview_dir, f"{epoch}_Deconv_{index}")
    fig, ax1 = plt.subplots(1, 1)
    try:
        ax1.imshow(image, cmap='jet', origin='lower')
        fig.savefig(f"{stem}.png")
        pyfits.writeto(f"{stem}.fits", image, overwrite=True)
    finally:
        plt.close(fig)


def train(batch_size=128, epochs=200, data_dir='/content/drive/MyDrive/NN_deconv',
          lr=0.00009, device=None, shuffle=True):
    """Train an ``Automap`` to map convolved inputs to deconvolved images.

    Reads ``io_models_10_2025.npz`` (network inputs) and
    ``io_outputs_10_2025.npz`` (targets) from ``data_dir``. It fits the model with
    Adam + MSE and writes these files under ``data_dir``:

    * ``deconv_10nn/{epoch}_Deconv_{index}.png`` / ``.fits`` every 50 batches,
    * ``losses80_Deconv_10_2500.png`` (per-epoch mean +/- std loss), refreshed
      after every epoch,
    * ``deconv_10nn/model_Deconv_{epoch}.pth`` every 50 epochs,
    * ``model_Deconv_10nn_final.pth`` and ``metrics_Deconv_10nn_2500.npz``
      at the end.

    Args:
        batch_size (int): samples per optimisation step. The last, smaller
            batch of each epoch is used too; it was previously dropped.
        epochs (int): number of passes over the data.
        data_dir (str): directory holding the inputs and receiving outputs.
        lr (float): Adam learning rate.
        device (str or torch.device, optional): defaults to CUDA when
            available, else CPU.
        shuffle (bool): reshuffle sample order each epoch. ``False`` keeps the
            original fixed batch order.

    Raises:
        ValueError: on a non-positive ``batch_size`` or inconsistent data.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    preview_dir = os.path.join(data_dir, 'deconv_10nn')
    os.makedirs(preview_dir, exist_ok=True)  # savefig/torch.save fail on a missing dir

    targets, inputs = _load_training_data(data_dir)
    n_samples = targets.shape[0]

    # Summarise the real model on CPU before moving it, instead of building a
    # second throwaway 18M-parameter model just for the summary.
    model = Automap()
    summary(model, (45 * 45,), device='cpu')
    model = model.to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    fig2, ax11 = plt.subplots(1, 1, figsize=(9, 6))
    ax11.set_xlabel('Epoch')
    ax11.set_ylabel('Model MSE')
    loss_plot_path = os.path.join(data_dir, 'losses80_Deconv_10_2500.png')
    metrics = []

    try:
        model.train()
        for epoch in range(epochs):
            print(f"Epoch: {epoch}")
            order = torch.randperm(n_samples) if shuffle else torch.arange(n_samples)
            epoch_losses = []

            for index, start in enumerate(range(0, n_samples, batch_size)):
                batch_idx = order[start:start + batch_size]
                image_batch = targets[batch_idx].to(device)   # deconvolved ims (targets)
                oidata_batch = inputs[batch_idx].to(device)   # convolved ims (inputs)

                optimizer.zero_grad()
                outputs = model(oidata_batch)
                loss = criterion(outputs, image_batch)
                loss.backward()
                optimizer.step()
                epoch_losses.append(loss.item())

                if index % 50 == 0:
                    _save_preview(model, oidata_batch, preview_dir, epoch, index)

            mean_loss = np.mean(epoch_losses)
            metrics.append(mean_loss)
            ax11.errorbar(epoch, mean_loss, yerr=np.std(epoch_losses), fmt='o', color='red')
            fig2.savefig(loss_plot_path, bbox_inches='tight')

            if epoch % 50 == 0:
                torch.save(model.state_dict(),
                           os.path.join(preview_dir, f'model_Deconv_{epoch}.pth'))
    finally:
        # Close the loss figure once, after the loop (it was closed every epoch).
        plt.close(fig2)

    torch.save(model.state_dict(), os.path.join(data_dir, 'model_Deconv_10nn_final.pth'))
    np.savez(os.path.join(data_dir, 'metrics_Deconv_10nn_2500.npz'), metrics=metrics)


In [None]:
def generate(oifilename,
             model_path='/content/drive/MyDrive/NN_deconv/model_Deconv_10nn_final.pth',
             output_path='/content/drive/MyDrive/NN_deconv/predicted_images_10_4.fits',
             device=None, batch_size=None):
    """Apply a trained ``Automap`` to a cube of inputs and save predictions as FITS.

    Args:
        oifilename (str): ``.npz`` file whose ``'cube'`` entry holds the model
            inputs. NOTE(review): assumed to already be (N, 45*45); the
            original had a ``.T`` commented out here — confirm.
        model_path (str): state_dict to load. Defaults to the final checkpoint
            that ``train()`` writes. The original hard-coded
            ``model_Deconv_10_final.pth``, a name ``train()`` never produces.
        output_path (str): FITS file to (over)write with the predictions.
        device (str or torch.device, optional): defaults to CUDA when
            available, else CPU.
        batch_size (int, optional): run inference in chunks of this size to
            bound device memory; ``None`` uses a single forward pass.

    Returns:
        int: always 0 (kept for compatibility).
    """
    if batch_size is not None and batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load data; the context manager closes the .npz file handle.
    with np.load(oifilename) as archive:
        data = archive['cube'].astype(np.float32)

    model = Automap().to(device)
    # map_location lets a GPU-saved checkpoint load on a CPU-only host.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    data_tensor = torch.from_numpy(data)
    with torch.no_grad():
        if batch_size is None:
            predictions = model(data_tensor.to(device)).cpu()
        else:
            predictions = torch.cat([model(chunk.to(device)).cpu()
                                     for chunk in torch.split(data_tensor, batch_size)])
    generated_images = predictions.squeeze().numpy()

    pyfits.writeto(output_path, generated_images, overwrite=True)
    return 0

In [None]:
def get_args(argv=None):
    """Parse command-line options for script-style use.

    Args:
        argv (list of str, optional): arguments to parse; ``None`` means
            ``sys.argv[1:]``. Pass an explicit list inside Jupyter/Colab,
            where ``sys.argv`` carries the kernel's own flags (e.g. ``-f``)
            and makes argparse exit with an error.

    Returns:
        argparse.Namespace with ``mode``, ``batch_size``, ``nice``, ``epoch``
        and ``filename``.
    """
    parser = argparse.ArgumentParser(
        description="Train the Automap deconvolution network or generate predictions.")
    parser.add_argument("--mode", type=str, choices=["train", "generate"],
                        help="'train' a model or 'generate' predictions")
    parser.add_argument("--batch_size", type=int, default=128,
                        help="samples per optimisation step")
    parser.add_argument("--nice", dest="nice", action="store_true",
                        help="flag accepted for compatibility; no code shown here reads it")
    parser.add_argument("--epoch", type=int, default=10,
                        help="number of training epochs")
    parser.add_argument("--filename", type=str,
                        help=".npz input cube for --mode generate")
    parser.set_defaults(nice=False)
    return parser.parse_args(argv)

In [None]:
# @title Run training or generation
# "train" reproduces the original run; "generate" runs generate() on n_cube_2025_4.npz.
RUN_MODE = "train"  # @param ["train", "generate"]

if __name__ == "__main__":
    if RUN_MODE == "train":
        train(batch_size=2048, epochs=2500)
    elif RUN_MODE == "generate":
        generate(oifilename='/content/drive/MyDrive/NN_deconv/n_cube_2025_4.npz')
    else:
        raise ValueError(f"Unknown RUN_MODE {RUN_MODE!r}; expected 'train' or 'generate'")
    # For script-style use, parse options with get_args() instead
    # (pass an explicit argv list such as ["--mode", "train"] inside Jupyter).

  self._const = nn.Parameter(torch.tensor(torch.randn(1), dtype=torch.float32))


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                 [-1, 9216]      18,662,400
             PReLU-2                 [-1, 9216]               1
         Unflatten-3            [-1, 256, 6, 6]               0
   ConvTranspose2d-4          [-1, 128, 12, 12]         524,288
             PReLU-5          [-1, 128, 12, 12]               1
           Dropout-6          [-1, 128, 12, 12]               0
   ConvTranspose2d-7           [-1, 64, 24, 24]         131,072
             PReLU-8           [-1, 64, 24, 24]               1
           Dropout-9           [-1, 64, 24, 24]               0
  ConvTranspose2d-10           [-1, 32, 46, 46]          32,768
            PReLU-11           [-1, 32, 46, 46]               1
          Dropout-12           [-1, 32, 46, 46]               0
           Conv2d-13            [-1, 1, 45, 45]             512
           AReLU2-14            [-1, 1,