In [None]:
!pip install pandas numpy torchvision 

In [None]:
%reload_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset, random_split
import numpy as np
import torch.optim.lr_scheduler as lr_scheduler
from Mars.generate_dataset import CustomDataset
import matplotlib.pyplot as plt
from CNNs.models import CNN, ResNet
from CNNs.Baseline import LinearReg
from analysis import plot_circles_on_image,pixel_distance,Metrics,epoch_metrics,error_calc
import copy


# Train static models to find the CoB and CoM of a celestial object
## Objects:
- Mars
## Models:
- Linear Regression 
- CNN
- ResNet

## Load the datasets:

In [None]:
# Load the pre-built train / validation / test splits. These are whole pickled
# Dataset objects (presumably written with torch.save — TODO confirm).
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted files.
# On torch >= 2.6 the default weights_only=True rejects non-tensor objects, so
# these calls would need weights_only=False there.
train_dataset, val_dataset, test_dataset = (
    torch.load(path)
    for path in ('train_dataset.pth', 'val_dataset.pth', 'test_dataset.pth')
)

## Create data loaders

In [None]:
# Create data loaders. Only the training set is shuffled. Shuffling the
# validation/test sets adds nothing to evaluation, and a fixed test order keeps
# the example images plotted by Test() the same from run to run.
train_loader = DataLoader(train_dataset, batch_size=50, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=20, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=20, shuffle=False)

In [None]:
# plotting function
def plot_results(num_epochs, train_losses, val_losses, val_distances,
                 loss_path=None, distance_path=None):
    """Plot per-epoch training/validation loss and validation pixel distance.

    Two figures are shown: the training and validation losses on one figure,
    and the validation distance on a second figure.

    Args:
        num_epochs: number of epochs to plot. The x values are 1..num_epochs,
            so each list below must have exactly this many entries.
        train_losses: per-epoch mean training loss.
        val_losses: per-epoch mean validation loss.
        val_distances: per-epoch mean validation pixel distance.
        loss_path: optional file path. If given, the loss figure is also saved
            there, with the format taken from the extension (e.g. ".pdf").
        distance_path: optional file path for the distance figure, same rules.
    """
    epochs = range(1, num_epochs + 1)
    dark_grey = '#3b3b3b'
    # Apply the 24pt font only to these figures. Updating the global rcParams
    # would silently restyle every later plot in the kernel session.
    with plt.rc_context({'font.size': 24}):
        # Training vs. validation loss.
        fig, ax = plt.subplots(figsize=(12, 6))
        ax.scatter(epochs, train_losses, label='Training Loss')
        ax.scatter(epochs, val_losses, label='Validation Loss')
        _finish_epoch_figure(fig, ax, 'Loss', 'Training and Validation Loss',
                             dark_grey, loss_path)

        # Validation pixel distance.
        fig, ax = plt.subplots(figsize=(12, 6))
        ax.scatter(epochs, val_distances, label='Validation Distance', color='green')
        _finish_epoch_figure(fig, ax, 'Distance', 'Pixel Distance',
                             dark_grey, distance_path)


def _finish_epoch_figure(fig, ax, ylabel, title, color, save_path=None):
    """Label an epoch-indexed plot, optionally save it to `save_path`, then show it."""
    ax.set_xlabel('Epochs', color=color)
    ax.set_ylabel(ylabel, color=color)
    ax.legend(labelcolor=color)
    ax.set_title(title, color=color)
    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path)
    plt.show()


In [None]:
# train function
def train(valloader, dataloader, model, device='cpu', num_epochs=80, lr=0.001, patience=10):
    """Train `model` on `dataloader`, validating on `valloader` after every epoch.

    Uses Adam with an MSE loss, a ReduceLROnPlateau scheduler driven by the
    validation loss, and early stopping. Before returning, it restores the
    weights with the lowest validation loss and plots the loss/distance curves
    with `plot_results`.

    Args:
        valloader: DataLoader of (inputs, targets) validation batches.
        dataloader: DataLoader of (inputs, targets) training batches.
        model: the torch.nn.Module to train; it is moved to `device`.
        device: device to train on (string or torch.device).
        num_epochs: maximum number of training epochs.
        lr: Adam learning rate.
        patience: consecutive epochs without a new best validation loss
            before training stops early.

    Returns:
        Tuple ``(model, train_losses, val_losses, val_distances, metrics_history)``.
        The three lists hold one value per completed epoch. ``metrics_history``
        maps 'MAE', 'MSE', 'RMSE' and 'R-squared' to per-epoch lists built from
        ``epoch_metrics``.
    """
    model.to(device)
    # Build the optimizer after .to(device) so it references the moved parameters.
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    # Cut the LR 5x after 3 stagnant epochs, well before early stopping gives up.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=3)

    best_loss = float('inf')
    best_model_wts = copy.deepcopy(model.state_dict())
    no_improvement_counter = 0

    train_losses, val_losses, val_distances = [], [], []
    metrics_history = {'MAE': [], 'MSE': [], 'RMSE': [], 'R-squared': []}

    for epoch in range(num_epochs):
        print("##### epoch {} : #####".format(epoch + 1))

        train_loss = _train_one_epoch(model, dataloader, optimizer, criterion, device)
        train_losses.append(train_loss)

        avg_val_loss, avg_val_distance, accuracy, error, metrics = _validate_one_epoch(
            model, valloader, criterion, device)
        val_losses.append(avg_val_loss)
        val_distances.append(avg_val_distance)

        print(f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}\n"
              f"Validation Distance: {avg_val_distance:.4f}, Validation Accuracy: {accuracy:.4f}, Error: {error:.4f} ")
        for metric, value in epoch_metrics(metrics).items():
            metrics_history[metric].append(value)

        scheduler.step(avg_val_loss)

        # Keep a copy of the best weights; stop after `patience` epochs without improvement.
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            no_improvement_counter = 0
        else:
            no_improvement_counter += 1
            if no_improvement_counter >= patience:
                print("Early stopping triggered")
                break

    model.load_state_dict(best_model_wts)
    # Plot exactly the epochs that ran (fewer than num_epochs after early stopping).
    plot_results(len(train_losses), train_losses, val_losses, val_distances)
    return model, train_losses, val_losses, val_distances, metrics_history


def _train_one_epoch(model, dataloader, optimizer, criterion, device, log_every=500):
    """Run one optimisation pass over `dataloader` and return the mean batch loss."""
    model.train()
    running_loss = 0.0
    for batch_idx, (inputs, targets) in enumerate(dataloader, start=1):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if batch_idx % log_every == 0:
            print("batch: {}: Running Loss  {:.4f}".format(batch_idx, running_loss / batch_idx))
    return running_loss / len(dataloader)


def _validate_one_epoch(model, valloader, criterion, device):
    """Evaluate `model` on `valloader` without gradients.

    Returns ``(avg_loss, avg_distance, accuracy_pct, avg_error, metrics)``. The
    loss, distance and error are averaged per batch. The accuracy is the summed
    `pixel_distance` count divided by the dataset size, as a percentage.
    `metrics` is the accumulator filled in by `Metrics`.
    """
    model.eval()
    validation_loss = 0.0
    total_distance = 0.0
    counter = 0.0
    error = 0.0
    metrics = {'MAE': 0.0, 'MSE': 0.0, 'RMSE': 0.0, 'R-squared': 0.0, 'count': 0}
    with torch.no_grad():
        for x, y in valloader:
            x_val, y_val = x.to(device), y.to(device)
            val_result = model(x_val)
            validation_loss += criterion(val_result, y_val).item()
            d, c = pixel_distance(val_result, y_val)
            error += error_calc(val_result, y_val)
            total_distance += d.item()
            counter += c
            Metrics(metrics, val_result, y_val)
    n_batches = len(valloader)
    return (validation_loss / n_batches,
            total_distance / n_batches,
            100 * (counter / len(valloader.dataset)),
            error / n_batches,
            metrics)


In [None]:
def Test(model, testloader, device='cpu', num_epochs=5):
    """Evaluate `model` on `testloader`, printing loss, pixel distance, accuracy and error.

    The evaluation is repeated `num_epochs` times. For batches 0, 500, 1000, …
    whose `pixel_distance` count covers at least 95% of the batch, the first
    sample is drawn with `plot_circles_on_image`.

    Args:
        model: trained torch.nn.Module; moved to `device` and put in eval mode.
        testloader: DataLoader of (inputs, targets) test batches.
        device: device to evaluate on.
        num_epochs: number of passes over `testloader`.

    Returns:
        Tuple ``(test_losses, test_distances, accuracies)`` with one entry per pass.
    """
    model.to(device)
    model.eval()
    criterion = nn.MSELoss()
    test_losses = []
    test_distances = []
    accuracies = []

    for epoch in range(num_epochs):
        print("##### epoch {} : #####".format(epoch + 1))
        # All accumulators are reset per pass so each pass reports its own averages.
        metrics = {'MAE': 0.0, 'MSE': 0.0, 'RMSE': 0.0, 'R-squared': 0.0, 'count': 0}
        test_loss = 0.0
        total_distance = 0.0
        counter = 0.0
        error = 0.0
        with torch.no_grad():
            for batch_idx, (x, y) in enumerate(testloader):
                x_test, y_test = x.to(device), y.to(device)
                test_result = model(x_test)
                test_loss += criterion(test_result, y_test).item()
                d, c = pixel_distance(test_result, y_test)
                total_distance += float(d)
                counter += c  # exactly once per batch
                error += error_calc(test_result, y_test)
                Metrics(metrics, test_result, y_test)
                # Use the real batch length so a short final batch is judged correctly.
                if batch_idx % 500 == 0 and c / len(y_test) >= 0.95:
                    plot_circles_on_image(x_test[0].cpu(), test_result[0].cpu(), y_test[0].cpu())

        n_batches = len(testloader)
        error /= n_batches
        avg_test_loss = test_loss / n_batches
        avg_test_distance = total_distance / n_batches
        accuracy = 100 * (counter / len(testloader.dataset))
        test_losses.append(avg_test_loss)
        test_distances.append(avg_test_distance)
        accuracies.append(accuracy)
        print(f"Test Loss: {avg_test_loss:.4f}, Test Distance: {avg_test_distance:.4f}, Error: {error:.4f} ")
        print(f"Test Accuracy: {accuracy:.4f}")
        print(epoch_metrics(metrics))

    return test_losses, test_distances, accuracies

## Train the model:

In [None]:
# Seed torch's global RNG so weight initialisation and training-batch shuffling
# are reproducible across runs. GPU kernels may still add nondeterminism.
SEED = 0
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("** GPU **" if device.type == "cuda" else "** CPU **")

model = CNN()
model, train_losses, val_losses, val_distances, val_metrics = train(val_loader, train_loader, model, device)


## Save the model:

In [None]:
# Set to True to persist the freshly trained model as a whole pickled nn.Module.
# NOTE(review): this file name differs from the one loaded further down
# ('model_CNN2_CoM.pth'). Confirm which checkpoint is intended.
SAVE_MODEL = False
if SAVE_MODEL:
    torch.save(model, 'model_CNN2_COB_COMM.pth')
    print("Model paths is saved")

## Load and test the model:

In [None]:
# map_location='cpu' lets a checkpoint saved from a GPU session load on a
# CPU-only machine. Test() moves the model to its target device anyway.
# NOTE(review): this unpickles a whole nn.Module, so only load trusted files. On
# torch >= 2.6 the default weights_only=True rejects it unless
# weights_only=False is passed.
test_model = torch.load('model_CNN2_CoM.pth', map_location='cpu')

In [None]:
# Evaluate the loaded checkpoint on the held-out test split, on CPU.
Test(model=test_model, testloader=test_loader, device='cpu')