In [1]:
import time
import datetime
import json
import os
import random
import numpy as np
import webbrowser
from PIL import Image
import matplotlib.pyplot as plt
from rich.progress import track
import subprocess
import pandas as pd
import torch
import torch.nn as nn
import torchvision
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchinfo import summary
from torch.utils.tensorboard import SummaryWriter


In [2]:
def create_dl_links(folder="combined"):
    """Build download URLs for NASA EPIC "enhanced" PNG images.

    Every (non-hidden) file in ``folder`` is parsed as JSON. The code indexes
    the parsed value as a sequence of records, each with a ``"date"`` field
    (a date such as ``"2015-06-13"``, optionally followed by a space and a
    time) and an ``"image"`` field (image base name without extension). Each
    record is mapped to
    ``https://epic.gsfc.nasa.gov/archive/enhanced/YYYY/MM/DD/png/<image>.png``.

    Args:
        folder: Directory holding the JSON metadata files. Defaults to
            ``"combined"``.

    Returns:
        A list of URL strings, ordered by sorted file name, then record order.
    """
    dl_list = []
    # Sorted so the resulting list is deterministic across machines/runs.
    for name in sorted(os.listdir(folder)):
        # Skip hidden files such as macOS' .DS_Store: they are not JSON and
        # would make json.load raise.
        if name.startswith("."):
            continue
        with open(os.path.join(folder, name), "r") as file:
            records = json.load(file)
        for record in records:
            # "2015-06-13 00:31:45" -> "2015/06/13". Taking the first
            # space-separated token also works when there is no time part.
            date = record["date"].split(" ")[0].replace("-", "/")
            img = record["image"] + ".png"
            dl_list.append(f"https://epic.gsfc.nasa.gov/archive/enhanced/{date}/png/{img}")
    return dl_list
# Module-level list of EPIC image URLs; train() samples from it when
# hparam_dict["use_batch_download"] is True.
# NOTE(review): this name shadows the built-in `list` for the rest of the
# notebook; it is kept because train() refers to it by this name.
list = create_dl_links()

In [3]:
# Create SummaryWriter instance: one TensorBoard log directory per notebook run.
# The timestamp is formatted explicitly because str(datetime.now()) contains a
# space and ':' characters, and ':' is not allowed in Windows paths.
writer = SummaryWriter(f"runs/autoencoder/{datetime.datetime.now():%Y-%m-%d_%H-%M-%S}")

In [4]:

# Define the autoencoder architecture
class Autoencoder(nn.Module):
    """Small fully-convolutional autoencoder for 3-channel images.

    Encoder: three stride-2 convolutions (kernel 8, padding 3), each followed
    by ReLU; each halves height and width (512x512 -> 64x64 in the summary
    printed below). Decoder: three matching stride-2 transposed convolutions,
    with no activation between or after them, that restore the input
    resolution. Every layer maps 3 channels to 3 channels.
    """

    def __init__(self):
        super().__init__()

        # Encoder: (Conv2d -> ReLU) x 3. Layers are constructed in the same
        # order as a hand-written Sequential would, so parameter
        # initialisation and state_dict keys (encoder.0/2/4) are unchanged.
        encoder_layers = []
        for _ in range(3):
            encoder_layers.append(self._conv(nn.Conv2d))
            encoder_layers.append(nn.ReLU())
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: three transposed convolutions reversing the downsampling
        # (state_dict keys decoder.0/1/2).
        self.decoder = nn.Sequential(*[self._conv(nn.ConvTranspose2d) for _ in range(3)])

    @staticmethod
    def _conv(layer_cls):
        # Shared geometry of every layer: 3 -> 3 channels, kernel 8, stride 2, padding 3.
        return layer_cls(in_channels=3, out_channels=3, kernel_size=8, stride=2, padding=3)

    def forward(self, x):
        encoded = self.encoder(x)
        return self.decoder(encoded)

In [5]:
# Create dictionary with metrics and hyperparameters here:
hparam_dict = {
    # Pick the best available accelerator: Apple MPS, then Nvidia CUDA, else CPU.
    "device": torch.device(
        "mps" if torch.backends.mps.is_available()
        else "cuda" if torch.cuda.is_available()
        else "cpu"
    ),
    "batch_size": 1,
    "learning_rate": 0.001,  # Read once, when the optimizer is created
    "num_epochs": 5,
    # DANGER! If True, train() downloads random images into download/earth/
    # every epoch and deletes every file in that folder at the end of each epoch.
    "use_batch_download": False,
    "loader_workers": 1,  # DataLoader num_workers
    # Resize every image to 512x512 and convert it to a tensor.
    "transform": transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor()
    ])
}
# Filled in by train() with timing/loss metrics that are logged via add_hparams.
metric_dict = {}

In [6]:

# Instantiate the model on the configured device, plus loss and optimizer.
autoencoder = Autoencoder().to(hparam_dict["device"])
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=hparam_dict["learning_rate"])

# Record the loss and optimizer class names (e.g. "MSELoss", "Adam") so they are
# logged with the other hyperparameters.
# NOTE: train() reuses these objects, so changing hparam_dict["learning_rate"]
# after this cell has no effect on the optimizer.
hparam_dict["criterion"] = type(criterion).__name__
hparam_dict["optimizer"] = type(optimizer).__name__

In [7]:
# Print a static layer/shape summary of the model.
print(summary(autoencoder, input_size=(1, 3, 512, 512)))  # Static

# Log the model graph to TensorBoard. summary() may move the model (e.g. to the
# CPU, depending on the torchinfo version), so create the example input on
# whatever device the weights are on now to avoid an input/weight device mismatch.
graph_device = next(autoencoder.parameters()).device
writer.add_graph(autoencoder, torch.rand(1, 3, 512, 512, device=graph_device))  # Static

# Start tensorboard in the background with its output discarded.
tensorboard_process = subprocess.Popen(
    ["tensorboard", "--logdir=runs", "--port=6006"],
    stdout=subprocess.DEVNULL,
    stderr=subprocess.DEVNULL,
)

Layer (type:depth-idx)                   Output Shape              Param #
Autoencoder                              [1, 3, 512, 512]          --
├─Sequential: 1-1                        [1, 3, 64, 64]            --
│    └─Conv2d: 2-1                       [1, 3, 256, 256]          579
│    └─ReLU: 2-2                         [1, 3, 256, 256]          --
│    └─Conv2d: 2-3                       [1, 3, 128, 128]          579
│    └─ReLU: 2-4                         [1, 3, 128, 128]          --
│    └─Conv2d: 2-5                       [1, 3, 64, 64]            579
│    └─ReLU: 2-6                         [1, 3, 64, 64]            --
├─Sequential: 1-2                        [1, 3, 512, 512]          --
│    └─ConvTranspose2d: 2-7              [1, 3, 128, 128]          579
│    └─ConvTranspose2d: 2-8              [1, 3, 256, 256]          579
│    └─ConvTranspose2d: 2-9              [1, 3, 512, 512]          579
Total params: 3,474
Trainable params: 3,474
Non-trainable params: 0
Total mult-

In [8]:
# Open the TensorBoard UI (port 6006, as passed to the tensorboard process) in
# the default browser; the cell displays whether a browser was launched.
webbrowser.open(url="http://localhost:6006")

True

In [9]:
# Checking for truncated images if use_batch_download is False
if not hparam_dict["use_batch_download"]:
    def is_image_truncated(file_path):
        """Return True if PIL cannot open and verify ``file_path`` as an image.

        Files PIL cannot identify at all (non-image files) also count as
        truncated here, so they are reported and removed as well.
        """
        try:
            # ``with`` closes the underlying file handle; a bare Image.open()
            # keeps it open until garbage collection.
            with Image.open(file_path) as img:
                img.verify()  # Verify the integrity of the image file
            return False  # Image is not truncated
        except (OSError, SyntaxError):
            # PIL reports broken/unidentifiable files via OSError (IOError is
            # an alias of OSError) or SyntaxError.
            return True  # Image is truncated

    def find_truncated_images(folder_path):
        """Return the names of regular files in ``folder_path`` that fail verification."""
        truncated_images = []

        # Iterate over files in the folder
        for filename in os.listdir(folder_path):
            file_path = os.path.join(folder_path, filename)

            # Only check regular files, not sub-directories
            if os.path.isfile(file_path) and is_image_truncated(file_path):
                truncated_images.append(filename)

        return truncated_images

    # Folder the images are downloaded into (train() passes root='download' to ImageFolder).
    folder_path = 'download/earth/'
    truncated_images = find_truncated_images(folder_path)

    if truncated_images:
        print("Truncated Images:")
        for image in truncated_images:
            print(image)
    else:
        print("No truncated images found.")

    # Remove truncated images so the DataLoader does not hit them during training
    for image in truncated_images:
        os.remove(os.path.join(folder_path, image))


No truncated images found.


In [10]:
def _download_random_images(count, urls, dest="download/earth"):
    """Download ``count`` randomly chosen URLs from ``urls`` into ``dest`` using wget."""
    for _ in range(count):
        # random.choice always picks a valid index. random.randint(0, len(urls))
        # includes len(urls) itself and could raise IndexError.
        url = random.choice(urls)
        # Arguments passed as a list (no shell), so URL text built from the JSON
        # metadata cannot be interpreted by a shell.
        result = subprocess.run(["wget", "-P", dest, url, "--quiet"])
        if result.returncode != 0:
            print(f"Warning: wget exited with code {result.returncode} for {url}")


def _clear_folder(folder="download/earth"):
    """Delete every non-hidden regular file in ``folder`` (like ``rm folder/*``)."""
    for name in os.listdir(folder):
        path = os.path.join(folder, name)
        if not name.startswith(".") and os.path.isfile(path):
            os.remove(path)


def _log_hparams(hparam_dict, metric_dict, epoch_time_list, last_loss):
    """Fill ``metric_dict`` in place, then log it with ``hparam_dict`` via add_hparams."""
    total_time = sum(epoch_time_list)
    metric_dict["Average epoch time (minutes)"] = total_time / len(epoch_time_list)
    metric_dict["Total training time (minutes)"] = total_time
    metric_dict["Last reconstruction loss"] = last_loss

    # add_hparams wants simple value types: stringify hyperparameters (device,
    # transform, ...) and cast metrics to float. New dicts are built, so the
    # caller's hparam_dict is not modified.
    writer.add_hparams(
        {k: str(v) for k, v in hparam_dict.items()},
        {k: float(v) for k, v in metric_dict.items()},
    )


def train(hparam_dict, metric_dict):
    """Train the global ``autoencoder`` and log progress to TensorBoard.

    Uses the module-level ``autoencoder``, ``criterion``, ``optimizer`` and
    ``writer`` (and the URL list ``list`` when downloading). The model and
    optimizer are not re-initialised, so repeated calls keep training the
    same weights.

    Each epoch:
    - optionally downloads ``batch_size`` random images into ``download/earth``;
    - runs one pass over ``ImageFolder(root='download')``;
    - logs the last batch's loss, its input/reconstruction images and the epoch
      time;
    - optionally clears the download folder.

    After the last epoch it logs all hyperparameters with summary metrics.

    Args:
        hparam_dict: Reads "device", "num_epochs", "use_batch_download",
            "batch_size", "loader_workers" and "transform"; every entry is
            logged (stringified) via add_hparams.
        metric_dict: Updated in place with "Average epoch time (minutes)",
            "Total training time (minutes)" and "Last reconstruction loss".

    Raises:
        ValueError: if ``hparam_dict["num_epochs"]`` is less than 1 (no loss
            or timing metrics would exist).
    """
    num_epochs = hparam_dict["num_epochs"]
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")

    device = hparam_dict["device"]
    # Make sure the model is on the right device; an earlier cell may have moved it.
    # Else this happens: Input type (MPSFloatType) and weight type (torch.FloatTensor) should be the same (RuntimeError)
    autoencoder.to(device)
    epoch_time_list = []

    for epoch in track(range(num_epochs), description='Starting training loop... '):
        if hparam_dict["use_batch_download"]:
            # Fetch this epoch's images before the epoch timer starts.
            _download_random_images(hparam_dict["batch_size"], list)
        epoch_start = time.time()

        # Rebuilt every epoch because the folder contents can change between epochs.
        dataset = ImageFolder(root='download', transform=hparam_dict["transform"])
        dataloader = DataLoader(dataset, batch_size=hparam_dict["batch_size"], shuffle=True, num_workers=hparam_dict["loader_workers"])

        for images, _ in dataloader:
            images = images.to(device)

            # Forward pass: the autoencoder reconstructs its own input.
            output = autoencoder(images)
            loss = criterion(output, images)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Log the last batch's loss (with wall time) to tensorboard
        writer.add_scalar("Reconstruction loss", loss.item(), epoch, time.time())

        # Save the last batch's input and reconstructed images
        writer.add_image("Input", make_grid(images.detach().cpu()), epoch)
        writer.add_image("Output", make_grid(output.detach().cpu()), epoch)

        # Clear download folder
        if hparam_dict["use_batch_download"]:
            _clear_folder("download/earth")

        # Epoch time in minutes
        epoch_time = (time.time() - epoch_start) / 60
        epoch_time_list.append(epoch_time)
        writer.add_scalar("Epoch time (minutes)", epoch_time, epoch)

    _log_hparams(hparam_dict, metric_dict, epoch_time_list, loss.item())


In [11]:
# Sweep: first train with hparam_dict as defined at the top of the file, then
# with matching batch_size / loader_workers of 2, 4, 6 and 8.
# NOTE: train() keeps training the same global model and optimizer, so each run
# starts from the previous run's weights rather than from scratch.
try:
    train(hparam_dict=hparam_dict, metric_dict=metric_dict)

    for sweep_size in (2, 4, 6, 8):
        hparam_dict["batch_size"] = sweep_size
        hparam_dict["loader_workers"] = sweep_size
        train(hparam_dict=hparam_dict, metric_dict=metric_dict)
finally:
    # Close tensorboard writer, even if one of the runs fails
    writer.close()

Output()

In [None]:
# Persist the final weights (the state_dict, not the whole module object).
torch.save(obj=autoencoder.state_dict(), f="autoencoder.pth")

In [None]:
# Stop the background tensorboard server and report how it exited.
tensorboard_process.terminate()
# wait() blocks until the process exits and returns its returncode.
return_code = tensorboard_process.wait()
print(f"Return code of tensorboard subprocess: {return_code}")