# HyperResNet: Super-Resolution Using Machine Learning

This notebook runs the HyperResNet model from the `itsitgroup/HyperResNet` GitHub repository. Set the training parameters in the form below, then run the cells in order from top to bottom.


## Setup

### Clone the Repository

In [None]:
#@title Clone the GitHub repository {display-mode: "form"}
import os
import subprocess

REPO_URL = "https://github.com/itsitgroup/HyperResNet.git"
REPO_DIR = "HyperResNet"

# Idempotent: re-running this cell after we have already changed into the repo must
# not clone a second, nested copy (HyperResNet/HyperResNet).
if os.path.basename(os.getcwd()) != REPO_DIR:
    if not os.path.isdir(REPO_DIR):
        clone = subprocess.run(
            ["git", "clone", REPO_URL, REPO_DIR],
            capture_output=True,
            text=True,
        )
        # Echo git's messages into the cell; raw subprocess output bypasses the notebook.
        print(clone.stdout + clone.stderr, end="")
        clone.check_returncode()
    os.chdir(REPO_DIR)

print(os.getcwd())

### Install Dependencies

In [None]:
#@title Install required dependencies {display-mode: "form"}
!pip install -r requirements.txt

### Input Parameters

Use the form below to set the parameters for model training and evaluation.

In [None]:
#@title Input Parameters {display-mode: "form"}

model_path = "my_model.h5" #@param {type:"string"}
batch_size = 32 #@param {type:"integer"}
epochs = 10 #@param {type:"integer"}
learning_rate = 0.0001 #@param {type:"number"}
filters = 64 #@param {type:"integer"}
blocks = 3 #@param {type:"integer"}
save_every = 0 #@param {type:"integer"}

# Gather the form values under their command-line option names; the run cell
# turns each entry into a `--<name> <value>` argument for main.py.
params = dict(
    model_path=model_path,
    batch_size=batch_size,
    epochs=epochs,
    learning_rate=learning_rate,
    filters=filters,
    blocks=blocks,
    save_every=save_every,
)

### Run the Script

In [None]:
#@title Run the HyperResNet script {display-mode: "form"}

import os
import subprocess
import sys

# Directory the results cell reads plots from.
# NOTE(review): this path is not passed to main.py — confirm main.py writes its plots to 'plots'.
save_path = 'plots'
os.makedirs(save_path, exist_ok=True)

# Build argv as a list so each value (e.g. a model_path containing spaces) reaches
# main.py verbatim and is never re-parsed by a shell. sys.executable runs main.py with
# the same interpreter as this kernel.
command = [sys.executable, "main.py"]
for option in ("model_path", "batch_size", "epochs", "learning_rate", "filters", "blocks", "save_every"):
    command += [f"--{option}", str(params[option])]

# os.system output goes to the kernel's terminal, not this cell, so stream the
# child's output here instead. PYTHONUNBUFFERED makes the logs appear as they are written.
child_env = {**os.environ, "PYTHONUNBUFFERED": "1"}
with subprocess.Popen(
    command,
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
    text=True,
    env=child_env,
) as proc:
    for line in proc.stdout:
        print(line, end="")

# Fail loudly so a crashed training run is not mistaken for success by later cells.
if proc.returncode != 0:
    raise subprocess.CalledProcessError(proc.returncode, command)

### Display Results

In [None]:
#@title Display Results {display-mode: "form"}

import matplotlib.pyplot as plt
import os


def show_saved_plot(path, title, figsize=(10, 5)):
    """Display a previously saved image file as an axis-free figure.

    Args:
        path: Path to the image file.
        title: Title drawn above the image.
        figsize: Matplotlib figure size in inches.

    Returns:
        True if the file existed and was shown, False if it was missing.
    """
    if not os.path.exists(path):
        # Report missing output instead of skipping silently, so a failed or
        # partial training run is visible here.
        print(f"{title}: no file found at {path}")
        return False
    img = plt.imread(path)
    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(img)
    ax.axis('off')
    ax.set_title(title)
    plt.show()
    return True


# Paths to the plots main.py is expected to write (save_path is set by the run cell).
loss_plot_path = os.path.join(save_path, 'loss.png')
accuracy_plot_path = os.path.join(save_path, 'accuracy.png')
predictions_plot_path = os.path.join(save_path, 'predictions.png')

for plot_path, plot_title in (
    (loss_plot_path, 'Loss Plot'),
    (accuracy_plot_path, 'Accuracy Plot'),
    (predictions_plot_path, 'Predictions Plot'),
):
    show_saved_plot(plot_path, plot_title)