In [6]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision

from tqdm import tqdm
from copy import deepcopy
from timeit import default_timer
from sklearn.model_selection import train_test_split
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix, log_loss
from sklearn.metrics import RocCurveDisplay, roc_curve, auc


import torch.nn as nn
import torch.nn.functional as F

from torchsummary import summary
from torch.utils.data import DataLoader
from torch.utils.data import random_split

In [None]:
# Load the raw arrays. X is later used as the model input and y as the
# regression target. The paths point into a Google Drive folder (presumably
# the Colab mount at /content/drive — TODO confirm when running elsewhere).
X = np.load('/content/drive/MyDrive/simulations.npy')
y = np.load('/content/drive/MyDrive/dataset.npy')

# Run on the GPU when PyTorch can see CUDA; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
def train(model, criterion, optimizer, dataset, n_epochs, n_stop=10,
          batch_size=200, device=None):
    """Train ``model`` on a random 70/20/10 split with early stopping.

    ``dataset`` is split at random into a training subset (70%), a validation
    subset (20%) and a held-out remainder (~10%) that this function does not
    use. Each epoch makes one pass over the shuffled training subset and then
    computes the mean validation loss without gradient tracking. Training
    stops early once the validation loss, rounded to 3 decimals, has not
    improved for ``n_stop`` consecutive epochs.

    Args:
        model: ``nn.Module`` to train; its parameters are updated in place.
        criterion: loss callable, invoked as ``criterion(prediction, target)``.
        optimizer: optimizer already bound to ``model.parameters()``.
        dataset: map-style dataset supporting ``len()`` whose items are
            ``(input, target)`` pairs (e.g. a ``TensorDataset``).
        n_epochs: maximum number of epochs to run.
        n_stop: early-stopping patience, in epochs.
        batch_size: mini-batch size for the train and validation loaders.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter, or the CPU if the model has none.

    Returns:
        ``(best_model, train_loss, val_loss)``. ``best_model`` is a deep copy of
        ``model`` taken at the epoch with the lowest rounded validation loss
        (a copy of the initial model if no epoch ran). ``train_loss`` and
        ``val_loss`` hold the mean loss of every completed epoch.

    Raises:
        ValueError: if the dataset is too small for non-empty train and
            validation subsets.
    """
    if device is None:
        # Follow the model: send batches to wherever its parameters live.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    # random_split requires the lengths to sum to len(dataset). Three
    # independent int() truncations generally do not, so size the held-out
    # part as the remainder.
    total_size = len(dataset)
    n_train = int(total_size * 0.7)
    n_val = int(total_size * 0.2)
    n_test = total_size - n_train - n_val
    if n_train == 0 or n_val == 0:
        raise ValueError(
            f'dataset of size {total_size} is too small for a 70/20/10 split '
            '(train and validation subsets must be non-empty)')
    # The third (test) subset is held out and not used inside this function.
    train_dataset, val_dataset, _test_dataset = random_split(
        dataset, [n_train, n_val, n_test])

    # Reshuffle the training batches every epoch; validation order is irrelevant.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    train_loss = list()
    val_loss = list()
    min_val_loss = np.inf
    epochs_without_improvement = 0
    # Start with a copy of the initial model so there is always something to
    # return, e.g. when n_epochs == 0.
    best_model = deepcopy(model)
    t0 = default_timer()
    for epoch in range(n_epochs):
        t1 = default_timer()

        model.train()
        train_batch_loss = list()
        for _X, _y in train_loader:
            _X, _y = _X.to(device), _y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(_X), _y)
            loss.backward()
            optimizer.step()
            train_batch_loss.append(loss.item())

        # Evaluation only: no autograd graph is needed, which saves memory.
        model.eval()
        val_batch_loss = list()
        with torch.no_grad():
            for _X, _y in val_loader:
                _X, _y = _X.to(device), _y.to(device)
                val_batch_loss.append(criterion(model(_X), _y).item())

        t2 = default_timer()
        mean_train = np.mean(train_batch_loss)
        mean_val = np.mean(val_batch_loss)

        if epoch % 10 == 0:
            print(f'Epoch: {epoch} ({round(t2-t1, 3)}s, {round(t2-t0, 3)}s), \tTrain loss: {mean_train.round(3)}, \tValidation loss: {mean_val.round(3)}')
        train_loss.append(mean_train)
        val_loss.append(mean_val)

        # Early stopping compares the validation loss rounded to 3 decimals,
        # so changes below about 1e-3 may not count as an improvement.
        rounded_val = round(val_loss[-1], 3)
        if rounded_val >= min_val_loss:
            epochs_without_improvement += 1
        else:
            epochs_without_improvement = 0
            best_model = deepcopy(model)
            min_val_loss = rounded_val
        if epochs_without_improvement >= n_stop:
            print(f'epoch: {epoch}, val loss did not decrease for {epochs_without_improvement} epoch(s)')
            break

    return best_model, train_loss, val_loss

In [None]:
# Wrap the loaded NumPy arrays as PyTorch tensors. torch.from_numpy does not
# copy or cast: each tensor shares memory with its array and keeps its dtype.
X_tensor, y_tensor = torch.from_numpy(X), torch.from_numpy(y)

# Define the CNN architecture
class CNN(torch.nn.Module):
    """Two conv/pool stages followed by a two-layer fully connected head.

    Each stage is a 3x3 convolution with padding=1 (spatial size preserved),
    a ReLU, and a 2x2 max-pool, so a square ``in_size`` input is reduced to
    ``in_size // 4`` per side. The flattened features pass through ``fc1``
    (-> ``hidden`` units, ReLU) and ``fc2``. The output is a flat vector of
    ``out_size * out_size`` values per sample, which callers reshape
    themselves.

    The defaults reproduce the original hard-coded architecture: 2 input
    channels, a 513-pixel input side (pooled to 128, i.e. ``64 * 128 * 128``
    features into ``fc1``), 128 hidden units and a 513*513 output. With these
    defaults ``fc1`` alone has about 134M weights.

    Args:
        in_channels: number of input channels.
        in_size: side length of the square input image.
        out_size: side length of the square image the flat output represents.
        hidden: width of the hidden fully connected layer.
    """

    def __init__(self, in_channels=2, in_size=513, out_size=513, hidden=128):
        super().__init__()
        # Two floor-halvings: (in_size // 2) // 2 == in_size // 4.
        pooled = in_size // 4
        self.conv1 = torch.nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool2 = torch.nn.MaxPool2d(kernel_size=2)
        self.fc1 = torch.nn.Linear(64 * pooled * pooled, hidden)
        self.fc2 = torch.nn.Linear(hidden, out_size * out_size)

    def forward(self, x):
        """Map ``(batch, in_channels, H, W)`` to ``(batch, out_size * out_size)``.

        An unbatched ``(in_channels, H, W)`` input is treated as a batch of one
        and returns shape ``(1, out_size * out_size)``.
        """
        if x.dim() == 3:
            x = x.unsqueeze(0)
        x = self.pool1(torch.relu(self.conv1(x)))
        x = self.pool2(torch.relu(self.conv2(x)))
        # Flatten per sample (keep dim 0). Unlike view(-1, N), an input whose
        # spatial size does not match in_size now fails in fc1 instead of being
        # silently regrouped into a different number of samples.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Fix the weight initialisation so reruns are reproducible.
torch.manual_seed(0)

# Instantiate the CNN on the selected device
cnn = CNN().to(device)

# Define the loss function and optimizer
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(cnn.parameters(), lr=0.001)

# Train the CNN one sample at a time (batch_size = 1)
num_epochs = 10
batch_size = 1
num_batches = X.shape[0]

for epoch in range(num_epochs):
    cnn.train()
    epoch_loss = 0.0
    for batch in range(num_batches):
        # Get a pair of images as float32 batches of one on the model's device.
        # Casting the target too avoids a float32-vs-float64 dtype mismatch in
        # the MSE loss/backward when the source arrays are float64.
        X_batch = X_tensor[batch].unsqueeze(0).float().to(device)
        y_batch = y_tensor[batch].unsqueeze(0).float().to(device)

        # Forward pass.
        # NOTE(review): the target y_batch is concatenated into the model input,
        # so the network sees the answer it is asked to predict (label leakage).
        # Confirm this is intended; otherwise feed X_batch alone and size
        # conv1's input channels to X's channel count.
        xy_pred = cnn(torch.cat((X_batch, y_batch), dim=1))
        # Shape the prediction like the target so MSELoss compares element-wise
        # instead of broadcasting (513, 513) against the target's shape.
        y_pred = xy_pred.view_as(y_batch)

        # Calculate loss and update weights
        loss = criterion(y_pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    # Print the mean per-sample loss of the epoch (not just the last sample's)
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss / max(num_batches, 1):.4f}')

# Predict one image using another image.
# NOTE(review): sample 0 was also used for training, so this is an in-sample
# check, not an evaluation on unseen data.
cnn.eval()
with torch.no_grad():
    X_test = X_tensor[0].unsqueeze(0).float().to(device)
    y_test = y_tensor[0].unsqueeze(0).float().to(device)
    xy_pred = cnn(torch.cat((X_test, y_test), dim=1))
y_pred = xy_pred.view(513, 513)

# Print the predicted image
print(y_pred)
