In [1]:
import glob
import os
import urllib
import urllib.request

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.onnx

In [2]:
# Configuration: constants a reader might want to tune.

DATA_PATH = 'data'

# The model's biases and the ONNX export's dummy input are drawn from torch's
# RNG; seeding here keeps outputs reproducible across "Restart & Run All".
SEED = 0
torch.manual_seed(SEED);

In [84]:
# Download a 10x10 grid of MODIS true-colour tiles from NASA GIBS into
# DATA_PATH, saved as '<row>_<col>.jpg'.
# https://wiki.earthdata.nasa.gov/display/GIBS/GIBS+API+for+Developers
# Tiles already on disk are skipped so a full re-run doesn't re-fetch every
# image. Each tile is written to a '.part' file first and renamed only once
# the download finishes, so an interrupted run never leaves a truncated .jpg
# that a later cell would try to read.

GIBS_TILE_URL = (
    'https://gibs.earthdata.nasa.gov/wmts/epsg4326/best/'
    'MODIS_Terra_CorrectedReflectance_TrueColor/default/2012-07-09/250m/6/{row}/{col}.jpg'
)
TILE_ROWS = range(1, 11)
TILE_COLS = range(1, 11)

os.makedirs(DATA_PATH, exist_ok=True)

for i in TILE_ROWS:
    for j in TILE_COLS:
        tile_path = os.path.join(DATA_PATH, f'{i}_{j}.jpg')
        if os.path.exists(tile_path):
            continue
        tile_url = GIBS_TILE_URL.format(row=i, col=j)
        part_path = tile_path + '.part'
        urllib.request.urlretrieve(tile_url, part_path)
        os.replace(part_path, tile_path)

In [74]:
# Load the tiles into a torch.Tensor laid out as (N, C, H, W).
# cv2 decodes images as BGR, so the channel axis is reversed to get RGB, and
# uint8 pixel values are scaled into [0, 1].
# Tiles are ordered by the (row, col) encoded in their '<row>_<col>.jpg'
# filenames rather than by mtime, so the order does not change if files are
# copied, touched, or re-downloaded.

def _tile_sort_key(path):
    """Return the (row, col) integers parsed from a '<row>_<col>.jpg' filename."""
    row, col = os.path.splitext(os.path.basename(path))[0].split('_')
    return int(row), int(col)


def _load_rgb(path):
    """Read an image file as an RGB float array scaled to [0, 1].

    Raises:
        IOError: if OpenCV cannot decode the file (cv2.imread returns None).
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise IOError(f'Could not read image: {path}')
    return bgr[..., ::-1] / 255.0


tile_paths = sorted(glob.glob(f'{DATA_PATH}/*.jpg'), key=_tile_sort_key)
if not tile_paths:
    raise FileNotFoundError(f'No .jpg tiles found in {DATA_PATH!r}; run the download cell first')

data = np.array([_load_rgb(p) for p in tile_paths])
input_data = torch.Tensor(np.transpose(data, (0, 3, 1, 2)))

In [78]:
# Define a simple torch model.

class ImageCompress(nn.Module):
    """Toy model mapping a batch of 3-channel images to 10-dim vectors.

    A 3x3 unpadded convolution collapses the 3 input channels to 1. The
    result is flattened and divided by ``SCALE``, then a linear layer
    projects it to 10 outputs. Both layers' weights are fixed to ones. Their
    biases keep PyTorch's default random initialisation, so outputs depend on
    the torch RNG state when the model is constructed.

    Args:
        height: Input image height in pixels. The default of 512 reproduces
            the original hard-coded 510 * 510 linear input.
        width: Input image width in pixels (default 512).

    Raises:
        ValueError: if height or width is smaller than the convolution kernel.
    """

    KERNEL_SIZE = 3
    # Divisor applied after flattening, keeping the all-ones linear layer's
    # output small (the flattened conv output has 510 * 510 = 260100 entries
    # at the default size).
    SCALE = 10**9

    def __init__(self, height=512, width=512):
        super().__init__()
        k = self.KERNEL_SIZE
        if height < k or width < k:
            raise ValueError(f'input must be at least {k}x{k}, got {height}x{width}')
        self.conv = nn.Conv2d(3, 1, k)
        nn.init.ones_(self.conv.weight)
        # An unpadded convolution shrinks each spatial dim by (kernel_size - 1).
        self.lin = nn.Linear((height - k + 1) * (width - k + 1), 10)
        nn.init.ones_(self.lin.weight)

    def forward(self, x):
        """Map x of shape (N, 3, height, width) to an output of shape (N, 10)."""
        x = self.conv(x)
        x = torch.flatten(x, 1) / self.SCALE
        return self.lin(x)

In [80]:
# Run inference on every loaded tile. Only forward values are needed, so
# no_grad() skips building an autograd graph over the whole input batch;
# eval() is kept for correctness should the model gain dropout/batch-norm.
model = ImageCompress()
model.eval()
with torch.no_grad():
    output = model(input_data)
output

tensor([[ 5.7454e-03,  4.6327e-03,  6.2564e-03,  4.8822e-03,  4.5342e-03,
          6.9192e-03,  6.7266e-03,  4.2350e-03,  5.7192e-03,  4.9231e-03],
        [ 5.7737e-03,  4.6609e-03,  6.2847e-03,  4.9105e-03,  4.5625e-03,
          6.9474e-03,  6.7549e-03,  4.2633e-03,  5.7475e-03,  4.9514e-03],
        [ 5.9659e-03,  4.8531e-03,  6.4768e-03,  5.1026e-03,  4.7546e-03,
          7.1396e-03,  6.9471e-03,  4.4554e-03,  5.9396e-03,  5.1436e-03],
        [ 5.7265e-03,  4.6137e-03,  6.2375e-03,  4.8633e-03,  4.5152e-03,
          6.9002e-03,  6.7077e-03,  4.2160e-03,  5.7003e-03,  4.9042e-03],
        [ 5.7271e-03,  4.6143e-03,  6.2381e-03,  4.8639e-03,  4.5158e-03,
          6.9008e-03,  6.7083e-03,  4.2166e-03,  5.7009e-03,  4.9048e-03],
        [ 5.6914e-03,  4.5787e-03,  6.2024e-03,  4.8282e-03,  4.4802e-03,
          6.8652e-03,  6.6726e-03,  4.1810e-03,  5.6652e-03,  4.8691e-03],
        [ 5.7482e-03,  4.6355e-03,  6.2592e-03,  4.8850e-03,  4.5370e-03,
          6.9220e-03,  6.7294e-0

In [92]:
# Export the model to ONNX by tracing it on a dummy input.
# See the documentation here: https://pytorch.org/docs/stable/onnx.html

ONNX_MODEL_PATH = "image_compression_model.onnx"

# Dummy input with the same per-sample shape as the real tiles. The batch
# dimension is marked dynamic below, so a batch of 1 is enough for tracing,
# and no gradient tracking is needed for export.
sample_input = torch.randn(1, *input_data.shape[1:])
with torch.no_grad():
    sample_output = model(sample_input)
# Sanity check before exporting: one output row per input sample.
assert sample_output.shape[0] == sample_input.shape[0]

torch.onnx.export(
    model,
    sample_input,
    ONNX_MODEL_PATH,
    export_params=True,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'},
    },
)

tensor([[[[ -2.8951,  -1.4935,  -3.8754,  ...,  -4.3602,  -1.9693,  -6.4406],
          [  3.3443,   9.6906,   4.9512,  ...,  -4.3874,  -2.8110,  -6.7794],
          [  5.0252,   9.5156,   6.4991,  ...,   1.5959,  -0.8888,  -4.1813],
          ...,
          [-12.0944,  -9.7453,  -8.1732,  ...,  11.6588,   8.8020,   6.6534],
          [ -3.1033,  -0.6967,  -4.6561,  ...,   8.3263,   4.7642,   3.8585],
          [ -2.4585,  -1.5185,  -7.4515,  ...,   3.2237,   0.6477,  -1.0087]]]],
       grad_fn=<MkldnnConvolutionBackward>)
tensor([[[[ -2.8951,  -1.4935,  -3.8754,  ...,  -4.3602,  -1.9693,  -6.4406],
          [  3.3443,   9.6906,   4.9512,  ...,  -4.3874,  -2.8110,  -6.7794],
          [  5.0252,   9.5156,   6.4991,  ...,   1.5959,  -0.8888,  -4.1813],
          ...,
          [-12.0944,  -9.7453,  -8.1732,  ...,  11.6588,   8.8020,   6.6534],
          [ -3.1033,  -0.6967,  -4.6561,  ...,   8.3263,   4.7642,   3.8585],
          [ -2.4585,  -1.5185,  -7.4515,  ...,   3.2237,   0.6477