In [None]:
!pip install pandas numpy torchvision 

In [None]:
%reload_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset, random_split
import numpy as np
import torch.optim.lr_scheduler as lr_scheduler
import matplotlib.pyplot as plt
from CNNs.modelsLSTM import AstroNet,RNN
from analysis import plot_circles_on_image,pixel_distance,Metrics,epoch_metrics,error_calc
import copy
from tqdm import tqdm

# Train dynamic models to find the CoB and CoM of a celestial object
## Objects:
- Mars
- Asteroid Itokawa
## Models:
- CNN - RNN
- AstroNet (CNN-LSTM) 

## For Itokawa: 
Uncomment the `CustomDataset` import in the next cell.

In [None]:
# from asteroids.generate_datasetSQ import CustomDataset

## For Mars:
Uncomment the `CustomDataset` import in the next cell.

In [None]:
# from Mars.generate_datasetSQ import CustomDataset

## Load the datasets:

In [None]:
# Load the pre-built train / validation / test splits.
# SECURITY: torch.load unpickles arbitrary Python objects, so only load files
# that you created yourself or otherwise trust.
# weights_only=False is passed explicitly because these files are used as whole
# dataset objects (presumably pickled CustomDataset instances - confirm), which
# the weights_only=True default of PyTorch >= 2.6 refuses to load.
train_dataset = torch.load('train_datasetSQ.pth', weights_only=False)
val_dataset = torch.load('val_datasetSQ.pth', weights_only=False)
test_dataset = torch.load('test_datasetSQ.pth', weights_only=False)

## Create data loaders:

In [None]:
# Wrap every split in a shuffling DataLoader:
# batches of 20 for training and validation, batches of 22 for testing.
# NOTE(review): shuffling is only needed for training; it is kept for the
# validation and test loaders to preserve the existing behaviour.
train_loader = DataLoader(dataset=train_dataset, batch_size=20, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=20, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=22, shuffle=True)

In [None]:
def plot_results(num_epochs, train_losses, val_losses, val_distances):
    """Show two scatter plots that summarise a training run.

    The first figure overlays the per-epoch training and validation loss, the
    second one shows the per-epoch validation pixel distance.

    Args:
        num_epochs: Number of epochs to draw; the x axis runs from 1 to this
            value, so each series is expected to hold that many entries.
        train_losses: Training loss per epoch.
        val_losses: Validation loss per epoch.
        val_distances: Validation distance per epoch.

    Note:
        The global matplotlib font size is set to 24 and stays that way after
        the call.
    """
    label_colour = '#3b3b3b'
    x_axis = range(1, num_epochs + 1)
    plt.rcParams.update({'font.size': 24})

    def _scatter_figure(series, y_label, title):
        # Draw one 12x6 figure; `series` is a list of (values, scatter kwargs).
        plt.figure(figsize=(12, 6))
        for values, scatter_kwargs in series:
            plt.scatter(x_axis, values, **scatter_kwargs)
        plt.xlabel('Epochs', color=label_colour)
        plt.ylabel(y_label, color=label_colour)
        plt.legend(labelcolor=label_colour)
        plt.title(title, color=label_colour)
        plt.tight_layout()
        # To export a figure, call plt.savefig(<path>, format="pdf") here,
        # before plt.show().
        plt.show()

    # Figure 1: training vs. validation loss.
    _scatter_figure(
        [
            (train_losses, {'label': 'Training Loss'}),
            (val_losses, {'label': 'Validation Loss'}),
        ],
        'Loss',
        'Training and Validation Loss',
    )

    # Figure 2: validation pixel distance.
    _scatter_figure(
        [(val_distances, {'label': 'Validation Distance', 'color': 'green'})],
        'Distance',
        'Pixel Distance',
    )



In [None]:
def calculate_accuracy(outputs, targets):
    """Return the mean Euclidean distance between predictions and targets.

    Despite the name, the result is an error measure: smaller is better.
    Only components 0 and 1 of every item are used. Pairs are formed with
    zip(), while the sum is divided by len(outputs).

    Args:
        outputs: Sequence of predicted points (indexable, at least 2 entries each).
        targets: Sequence of ground-truth points, paired with `outputs`.

    Returns:
        Sum of the per-pair distances divided by the number of outputs.

    Raises:
        ZeroDivisionError: If `outputs` is an empty plain Python sequence.
    """
    pair_distances = (
        ((truth[0] - pred[0]) ** 2 + (truth[1] - pred[1]) ** 2) ** 0.5
        for pred, truth in zip(outputs, targets)
    )
    return sum(pair_distances) / len(outputs)

In [None]:
def train(valloader, dataloader, model, device='cpu', num_epochs=80, lr=0.001, patience=7):
    """Train `model` with Adam and an MSE loss, validating after every epoch.

    After each epoch the model is evaluated on `valloader`. The average
    validation loss drives a ReduceLROnPlateau scheduler (factor 0.2,
    patience 3) and early stopping. Before returning, the weights with the
    lowest validation loss are loaded back into the model and the loss and
    distance curves are shown with `plot_results`.

    Args:
        valloader: Loader yielding (inputs, targets) validation batches. It
            must support len() and expose `.dataset`. Note that it comes
            BEFORE the training loader in the signature.
        dataloader: Loader yielding (inputs, targets) training batches.
        model: Network to train. Its outputs and the targets are indexed as
            [:, -1, :] for the metrics, so both must be at least 3-D
            (presumably (batch, sequence, coordinates) - TODO confirm).
        device: Device the model and the batches are moved to.
        num_epochs: Maximum number of epochs.
        lr: Initial learning rate for Adam.
        patience: Number of consecutive epochs without a new best validation
            loss after which training stops early.

    Returns:
        Tuple (model, train_losses, val_losses, val_distances, metrics_history):
        the model with the best weights loaded, three per-epoch lists, and a
        dict mapping 'MAE', 'MSE', 'RMSE' and 'R-squared' to per-epoch lists.
    """
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=3)

    best_loss = float('inf')
    best_model_wts = copy.deepcopy(model.state_dict())
    no_improvement_counter = 0

    train_losses = []
    val_losses = []
    val_distances = []
    metrics_history = {'MAE': [], 'MSE': [], 'RMSE': [], 'R-squared': []}

    for epoch in range(num_epochs):
        print("##### epoch {} : #####".format(epoch + 1))

        # ---- training pass ----
        model.train()
        running_loss = 0.0

        for batch_idx, (inputs, targets) in enumerate(tqdm(dataloader), start=1):
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)

            # The loss covers the full output sequence; only the validation
            # metrics below are restricted to the last step.
            loss = criterion(outputs, targets)
            loss.backward()
            # Clip the gradient norm to keep the recurrent updates stable.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)

            optimizer.step()
            running_loss += loss.item()
            if batch_idx % 1000 == 0:
                print("batch: {}: Running Loss  {:.4f}".format(batch_idx, running_loss / batch_idx))

        train_loss = running_loss / len(dataloader)
        train_losses.append(train_loss)

        # ---- validation pass ----
        model.eval()

        validation_loss = 0.0
        total_distance = 0.0
        counter = 0.0
        error_sum = 0.0
        metrics = {'MAE': 0.0, 'MSE': 0.0, 'RMSE': 0.0, 'R-squared': 0.0, 'count': 0}

        with torch.no_grad():
            for x, y in valloader:
                x_val, y_val = x.to(device), y.to(device)
                val_result = model(x_val)
                validation_loss += criterion(val_result, y_val).item()

                # Distance, error and metrics use the last sequence step only.
                last_pred = val_result[:, -1, :]
                last_true = y_val[:, -1, :]
                # pixel_distance returns (d, c); c is presumably the number of
                # samples counted as correct - confirm in analysis.py.
                d, c = pixel_distance(last_pred, last_true)
                error_sum += error_calc(last_pred, last_true)
                total_distance += d.item()
                counter += c
                Metrics(metrics, last_pred, last_true)

        # Per-batch averages, except accuracy which is per sample.
        avg_val_loss = validation_loss / len(valloader)
        avg_val_distance = total_distance / len(valloader)
        accuracy = 100 * (counter / len(valloader.dataset))
        avg_error = error_sum / len(valloader)
        val_losses.append(avg_val_loss)
        val_distances.append(avg_val_distance)

        print(f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}\n"
              f"Validation Distance: {avg_val_distance:.4f}, Validation Accuracy: {accuracy:.4f}, Error: {avg_error:.4f}")
        epoch_metric = epoch_metrics(metrics)
        for metric, value in epoch_metric.items():
            metrics_history[metric].append(value)

        # Lower the learning rate when the validation loss plateaus.
        scheduler.step(avg_val_loss)

        # Keep the best weights seen so far; stop after `patience` bad epochs.
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            no_improvement_counter = 0
        else:
            no_improvement_counter += 1
            if no_improvement_counter >= patience:
                print("Early stopping triggered")
                break

    model.load_state_dict(best_model_wts)
    # One loss entry is stored per completed epoch, so this is the number of
    # epochs that actually ran (smaller than num_epochs after early stopping).
    plot_results(len(train_losses), train_losses, val_losses, val_distances)
    return model, train_losses, val_losses, val_distances, metrics_history


In [None]:
def Test(model, testloader, device='cpu', num_epochs=5):
    """Evaluate `model` on `testloader` and report loss, distance and accuracy.

    The loader is traversed `num_epochs` times with gradients disabled and the
    model in eval mode. On every 50th batch of a pass (starting with the
    first) the first sample of the batch is visualised with
    `plot_circles_on_image`.

    Args:
        model: Trained network. Outputs and targets are indexed as [:, -1, :],
            so both must be at least 3-D (presumably
            (batch, sequence, coordinates) - TODO confirm).
        testloader: Loader yielding (inputs, targets) batches. It must support
            len() and expose `.dataset`.
        device: Device the model and the batches are moved to.
        num_epochs: Number of evaluation passes over the loader.

    Returns:
        Tuple (test_losses, test_distances, accuracies) with one entry per
        pass. Losses and distances are Python floats averaged over batches;
        accuracy is 100 * counter / number of samples in the dataset.
    """
    model.to(device)
    model.eval()
    criterion = nn.MSELoss()
    test_losses = []
    test_distances = []
    accuracies = []

    for epoch in range(num_epochs):
        print("##### epoch {} : #####".format(epoch + 1))
        metrics = {'MAE': 0.0, 'MSE': 0.0, 'RMSE': 0.0, 'R-squared': 0.0, 'count': 0}

        # All accumulators are reset at the start of every pass so that one
        # pass cannot leak into the statistics of the next.
        test_loss = 0.0
        total_distance = 0.0
        counter = 0.0
        error_sum = 0.0

        with torch.no_grad():
            for batch_idx, (x, y) in enumerate(testloader):
                x_test, y_test = x.to(device), y.to(device)
                test_result = model(x_test)
                test_loss += criterion(test_result, y_test).item()

                # Distance, error and metrics use the last sequence step only.
                last_pred = test_result[:, -1, :]
                last_true = y_test[:, -1, :]
                # pixel_distance returns (d, c); c is presumably the number of
                # samples counted as correct - confirm in analysis.py.
                d, c = pixel_distance(last_pred, last_true)
                total_distance += d.item()
                counter += c  # counted exactly once per batch
                Metrics(metrics, last_pred, last_true)
                if batch_idx % 50 == 0:
                    plot_circles_on_image(x_test[0, -1, :], test_result[0, -1, :], y_test[0, -1, :])

                error_sum += error_calc(last_pred, last_true)

        avg_error = error_sum / len(testloader)
        avg_test_loss = test_loss / len(testloader)
        avg_test_distance = total_distance / len(testloader)
        accuracy = 100 * (counter / len(testloader.dataset))
        test_losses.append(avg_test_loss)
        test_distances.append(avg_test_distance)
        accuracies.append(accuracy)
        print(f"Test Loss: {avg_test_loss:.4f}, Test Distance: {avg_test_distance:.4f}, Error: {avg_error:.4f}")
        # The result is not used here; the call is kept in case epoch_metrics
        # reports the aggregated metrics itself - TODO confirm.
        epoch_metrics(metrics)

    return test_losses, test_distances, accuracies

## Train the model:

In [None]:
# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("** GPU **")
else:
    device = torch.device("cpu")
    print("** CPU **")

# AstroNet is the CNN-LSTM model imported from CNNs.modelsLSTM.
# For the CNN-RNN variant, use the imported RNN class instead:
# LSTM_model = RNN()
# NOTE(review): assumes both constructors take no required arguments - confirm.
LSTM_model = AstroNet()

# train() takes the VALIDATION loader first. The test split is deliberately
# kept out of training so it remains an unbiased final check.
model, train_losses, val_losses, val_distances, val_metrics = train(val_loader, train_loader, LSTM_model, device)

## Save the model:

In [None]:
# torch.save(model, 'mars_models//model_LSTM_COB.pth')
# print("Model paths is saved")

## Load and test the model:

In [None]:
# SECURITY: torch.load unpickles arbitrary Python objects - only load model
# files you created yourself or otherwise trust.
# The save cell above stores the whole model object (torch.save(model, ...)),
# so weights_only=False is required on PyTorch >= 2.6.
# map_location='cpu' lets a model saved on a GPU load on a CPU-only machine;
# the evaluation below runs on the CPU anyway.
model_test = torch.load('mars_models//model_LSTM_COB.pth', map_location='cpu', weights_only=False)

In [None]:
# Evaluate the reloaded model on the test split using the CPU. The returned
# (test_losses, test_distances, accuracies) tuple is the cell's output.
Test(model=model_test, testloader=test_loader, device='cpu')