In [1]:
from dae_model import ConvDAE
from dataGenImg import DataGen
from utils import utils
import pandas as pd
import os
from torch.utils.data import Dataset, DataLoader
import torch
from tqdm import tqdm

import matplotlib.pyplot as plt
import numpy as np

def plot_images_side_by_side(inputs, outputs, num_images=4):
    """Show input images next to the corresponding model outputs.

    Draws a two-column grid with one row per image: ``inputs[i, 0]`` on the
    left and ``outputs[i, 0]`` on the right, both rendered in grayscale with
    the axes hidden.

    Args:
        inputs: Batch of input images, indexable as ``inputs[i, 0]``.
        outputs: Batch of output images, indexed the same way as ``inputs``.
        num_images: Maximum number of rows to draw. It is capped at the number
            of available input/output pairs, so a short batch does not raise.

    Returns:
        None. The figure is displayed with ``plt.show()``. Nothing is drawn
        when there are no images to show.
    """
    # Never index past the end of either batch (e.g. a short final batch).
    num_rows = min(num_images, len(inputs), len(outputs))
    if num_rows <= 0:
        return

    # squeeze=False keeps `axes` two-dimensional even for a single row, so
    # the axes[row, col] indexing below works for every row count.
    _, axes = plt.subplots(num_rows, 2, figsize=(6, num_rows * 5), squeeze=False)

    columns = ((inputs, 'Input'), (outputs, 'Output'))
    for row in range(num_rows):
        for col, (batch, label) in enumerate(columns):
            ax = axes[row, col]
            ax.imshow(batch[row, 0], cmap='gray')
            ax.set_title(f'{label} {row + 1}')
            ax.axis('off')

    plt.tight_layout()
    plt.show()


  from .autonotebook import tqdm as notebook_tqdm


## Datafeeder

In [2]:
# Read the run configuration; the same JSON path is also handed to the dataset.
config_path = 'config.json'
config = utils.Configuration(config_path)

# Enumerate the entries under the configured data path in sorted order.
ids = sorted(os.listdir(config.datapath))
print(len(ids))

# Wrap the entries in the project dataset and batch them with the configured
# batch size (DataLoader defaults: no shuffling).
train_dataset = DataGen(ids, config_path=config_path)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size)


4890


## Model

In [3]:
# Rebuild the autoencoder and restore the weights saved after epoch 100.
model = ConvDAE()
# map_location forces every tensor onto the CPU regardless of where it was saved.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
state_dict = torch.load(
    './train-results/elemodel-1/epoch_100_model.pth',
    map_location=torch.device('cpu'),
)
model.load_state_dict(state_dict)
# eval() switches the model to inference mode; as the last expression it also
# displays the module structure.
model.eval()

ConvDAE(
  (encoder): Encoder(
    (encoder): Sequential(
      (0): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05

In [7]:
# Run the model over every sample and collect one row per sample:
# cell id, latitude, longitude and the embedding vector.
embedding_records = []
with torch.no_grad():
    # The loop iterates over batches, so the progress-bar total is the number
    # of batches (len(train_loader)), not the number of samples.
    for inputs, label, sample_ids in tqdm(train_loader, total=len(train_loader)):
        outputs, embedding = model(inputs)
        embedding_np = embedding.cpu().numpy()

        # Record every sample of the batch, not just the first one, so the
        # table stays complete when batch_size > 1.
        for sample_id, sample_embedding in zip(sample_ids, embedding_np):
            # Assumes names shaped like
            # "<3 chars><4-char cell id>..._<2-char tag><lat>_<2-char tag><lon><4-char extension>"
            # -- TODO confirm against the files DataGen serves.
            coord = sample_id[:-4].split('_')
            embedding_records.append({
                'cell-id': sample_id[3:7],
                'lat': float(coord[1][2:]),
                'lon': float(coord[2][2:]),
                'embedding': list(sample_embedding),
            })

# Build the frame once from the collected records instead of concatenating a
# one-row frame per iteration (which is quadratic in the number of samples).
embedding_df = pd.DataFrame(
    embedding_records, columns=['cell-id', 'lat', 'lon', 'embedding']
)

embedding_df.to_csv('elevation-crops-embedding.csv', index=False)
embedding_df

100%|██████████| 4890/4890 [02:22<00:00, 34.28it/s]


Unnamed: 0,cell-id,lat,lon,embedding
0,0001,35.064,32.263,"[-2.521943, -0.5231237, -0.40994224, 0.5443020..."
1,0002,35.076,32.263,"[-2.521943, -0.5231237, -0.40994224, 0.5443020..."
2,0003,35.040,32.274,"[-2.0220041, -1.1064456, -0.87099844, 0.289591..."
3,0004,35.052,32.274,"[-0.6397823, 0.16071472, -0.08685991, 0.452104..."
4,0005,35.064,32.274,"[0.7283832, -1.438596, -0.32949847, 0.23495379..."
...,...,...,...,...
4885,4886,34.980,34.078,"[-1.7025188, 0.42279524, -1.245933, 0.48889568..."
4886,4887,34.992,34.078,"[-2.1138537, -0.45107505, -0.47971603, 0.37905..."
4887,4888,35.004,34.078,"[-2.521943, -0.5231237, -0.40994224, 0.5443020..."
4888,4889,34.956,34.089,"[-2.030232, 0.10630721, -0.11729323, 0.1701394..."


In [39]:
# Re-display the embeddings table built by the extraction cell.
# NOTE(review): the saved output of this cell (In [39]) differs from the
# In [7] output at row 3, so the table was recomputed with a different result
# in between -- confirm whether the data pipeline or model is non-deterministic.
embedding_df

Unnamed: 0,cell-id,lat,lon,embedding
0,0001,35.064,32.263,"[-2.521943, -0.5231237, -0.40994224, 0.5443020..."
1,0002,35.076,32.263,"[-2.521943, -0.5231237, -0.40994224, 0.5443020..."
2,0003,35.040,32.274,"[-2.0220041, -1.1064456, -0.87099844, 0.289591..."
3,0004,35.052,32.274,"[-0.25739184, -1.6701397, -1.008265, -0.126257..."
4,0005,35.064,32.274,"[0.7283832, -1.438596, -0.32949847, 0.23495379..."
...,...,...,...,...
4885,4886,34.980,34.078,"[-1.7025188, 0.42279524, -1.245933, 0.48889568..."
4886,4887,34.992,34.078,"[-2.1138537, -0.45107505, -0.47971603, 0.37905..."
4887,4888,35.004,34.078,"[-2.521943, -0.5231237, -0.40994224, 0.5443020..."
4888,4889,34.956,34.089,"[-2.030232, 0.10630721, -0.11729323, 0.1701394..."


In [None]:
# plot_images_side_by_side(inputs, outputs, num_images=8)