### Imports

In [1]:
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

In [2]:
%load_ext autoreload
%autoreload 2

In [50]:
import sys
from pathlib import Path
scripts_path = Path("../Data-Preprocessing/").resolve()
sys.path.append(str(scripts_path))

scripts_path = Path("../Evaluation/").resolve()
sys.path.append(str(scripts_path))

In [51]:
import pickle
from sklearn.cluster import KMeans
from sklearn.neighbors import NearestCentroid
from scripts.data_visualiser import *
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from Modeling.model_scripts.subpatch_extraction import *
from scripts.data_loader import *
from scripts.data_preprocessor import *
from scripts.temporal_data_preprocessor import *
from scripts.temporal_data_loader import *
from scripts.temporal_visualiser import *
from scripts.temporal_chanel_refinement import *
from model_scripts.get_statistics import *
from model_scripts.pre_trained_temporal import *
from model_scripts.dataset_creation import *
from model_scripts.train_model_ae import *
from model_scripts.model_visualiser import *
from model_scripts.clustering import *
from evaluation_scripts.result_visualiser import *
from evaluation_scripts.label_helper import *
from Pipeline.pre_processing_pipeline import *
from Pipeline.temporal_preprocessing_pipeline import *
import numpy as np
import preprocessing_config as config
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score
from sklearn.cluster import DBSCAN
from sklearn.decomposition import PCA
import skimage.measure
import torch
import torch.nn as nn
import torch.optim as optim

### Loading the data

In [5]:
# temp_pipeline = PreProcessingPipelineTemporal()

In [6]:
# train_fn, dataloader_train = temp_pipeline.get_processed_trainloader(64, 'indexbands', vi_type='msi')

Train

In [7]:
temporal_images = load_field_images_temporal(config.base_directory_temporal_train1)
border_removed_images_train = blacken_field_borders_temporal(temporal_images)
field_numbers_train, indices_images_train = allbands_temporal_cubes(border_removed_images_train)

len(indices_images_train), indices_images_train[0][0].shape

(2425, (64, 64, 10))

Evaluation

In [8]:
temporal_images_eval = load_field_images_temporal(config.base_directory_temporal_test1)
border_removed_images = blacken_field_borders_temporal(temporal_images_eval)
field_numbers_eval, indices_images_eval = allbands_temporal_cubes(border_removed_images)

(len(indices_images_eval), indices_images_eval[0][0].shape)

(48, (64, 64, 10))

In [9]:
image_tensor_train = np.stack(indices_images_train)  # Shape: (N x 7 x 64 x 64 x 6)
image_tensor_eval = np.stack(indices_images_eval)   # Shape: (N x 7 x 64 x 64 x 6)

image_tensor_train.shape, image_tensor_eval.shape

((2425, 7, 64, 64, 10), (48, 7, 64, 64, 10))

In [10]:
image_tensor_train = torch.tensor(image_tensor_train, dtype=torch.float32).permute(0, 1, 4, 2, 3)  # (N, T, H, W, C) -> (N, T, C, H, W)
image_tensor_eval = torch.tensor(image_tensor_eval, dtype=torch.float32).permute(0, 1, 4, 2, 3)  # (N, T, H, W, C) -> (N, T, C, H, W)
image_tensor_train.shape, image_tensor_eval.shape

(torch.Size([2425, 7, 10, 64, 64]), torch.Size([48, 7, 10, 64, 64]))

### Mini-patches

In [11]:
train_patches, train_patch_coordinates = non_overlapping_sliding_window(image_tensor_train, field_numbers_train, patch_size=5)
eval_patches, eval_patch_coordinates = non_overlapping_sliding_window(image_tensor_eval, field_numbers_eval, patch_size=5)
len(eval_patches), eval_patches[0].shape

(867, torch.Size([7, 10, 5, 5]))

In [12]:
len(train_patches), len(train_patch_coordinates)

(39042, 39042)

In [24]:
train_patches = np.stack(train_patches)  # Shape: (N x 7 x 6 x 5 x 5)
eval_patches = np.stack(eval_patches)   # Shape: (N x 7 x 6 x 5 x 5)

In [14]:
train_patch_coordinates[0]

('1167134.0', 25, 25)

In [15]:
train_coord_dataloader = field_nos_dataloader(train_patch_coordinates)
eval_coord_dataloader = field_nos_dataloader(eval_patch_coordinates)
train_coord_dataloader[0]

'1167134.0_25_25'

In [16]:
len(train_coord_dataloader)

39042

### Data Loaders

In [17]:
train_ratio = 0.8  

# Split patches and corresponding field numbers
train_patches, test_patches, train_field_numbers, test_field_numbers = train_test_split(
    train_patches, train_coord_dataloader, test_size=1-train_ratio, random_state=42
)

# Create train and test dataloaders
batch_size = 64
dataloader_train = create_data_loader(train_patches, train_field_numbers, batch_size=batch_size, shuffle=True)
dataloader_test = create_data_loader(test_patches, test_field_numbers, batch_size=batch_size, shuffle=False)

for batch_inputs, batch_field_numbers in dataloader_train:
    print("Train Batch Inputs Shape:", batch_inputs.shape)
    print("Train Batch Field Numbers:", batch_field_numbers)
    break  

for batch_inputs, batch_field_numbers in dataloader_test:
    print("Test Batch Inputs Shape:", batch_inputs.shape)
    print("Test Batch Field Numbers:", batch_field_numbers)
    break  

Train Batch Inputs Shape: torch.Size([64, 10, 7, 5, 5])
Train Batch Field Numbers: ('1222816.0_25_30', '1175368.0_35_10', '1223522.0_1223526.0_1223533.0_20_15', '1222460.0_1222465.0_1229355.0_30_15', '1228883.0_25_45', '1187379.0_25_30', '1224172.0_30_35', '1196307.0_1218729.0_1218733.0_1218746.0_1223915.0_40_35', '1223509.0_25_50', '1219440.0_1219441.0_1219442.0_1219443.0_1219444.0_45_45', '1219886.0_1219889.0_50_20', '1187577.0_25_25', '1225347.0_40_15', '1176302.0_1176303.0_1193921.0_1195992.0_30_25', '1218979.0_30_25', '1223174.0_1223178.0_25_45', '1187375.0_1226275.0_45_40', '1225197.0_30_25', '1222258.0_1222259.0_1228169.0_30_30', '1226145.0_30_45', '1185910.0_1228312.0_1228611.0_15_50', '1223690.0_1229444.0_25_30', '1226165.0_25_30', '1219845.0_1224190.0_25_45', '1224345.0_30_30', '1222370.0_25_25', '1216702.0_1216710.0_1216716.0_1216720.0_35_30', '1222237.0_1222240.0_1223306.0_30_50', '1175685.0_1175686.0_40_15', '1195751.0_30_30', '1226150.0_1226372.0_25_30', '1168695.0_122032

In [18]:
batch_size = 64
dataloader_eval = create_data_loader(eval_patches, eval_coord_dataloader, batch_size=batch_size, shuffle=False)

for batch_inputs, batch_field_numbers in dataloader_eval:
    print("Eval Batch Inputs Shape:", batch_inputs.shape) 
    print("Eval Batch Field Numbers:", batch_field_numbers)
    break  

Eval Batch Inputs Shape: torch.Size([64, 10, 7, 5, 5])
Eval Batch Field Numbers: ('1168039.0_20_25', '1168039.0_20_30', '1168039.0_20_35', '1168039.0_25_25', '1168039.0_25_30', '1168039.0_25_35', '1168039.0_30_25', '1168039.0_30_30', '1168039.0_30_35', '1168039.0_35_25', '1168039.0_35_30', '1168039.0_35_35', '1228889.0_25_15', '1228889.0_25_20', '1228889.0_25_25', '1228889.0_25_30', '1228889.0_25_35', '1228889.0_25_40', '1228889.0_30_15', '1228889.0_30_20', '1228889.0_30_25', '1228889.0_30_30', '1228889.0_30_35', '1228889.0_30_40', '1228889.0_35_15', '1228889.0_35_20', '1228889.0_35_25', '1228889.0_35_30', '1168663.0_1176271.0_25_35', '1168663.0_1176271.0_25_40', '1168663.0_1176271.0_30_20', '1168663.0_1176271.0_30_25', '1168663.0_1176271.0_30_30', '1168663.0_1176271.0_30_35', '1168663.0_1176271.0_30_40', '1168663.0_1176271.0_35_20', '1168663.0_1176271.0_35_25', '1168663.0_1176271.0_35_30', '1168663.0_1176271.0_35_35', '1168663.0_1176271.0_35_40', '1168692.0_1220431.0_15_20', '1168692.

### Pre-trained Models: Resnet3D

In [28]:
device = 'cuda'
train_patches.shape

(31233, 7, 10, 5, 5)

In [42]:
resnet3d_extractor = ResNet3DFeatureExtractor()
resnet3d_features_train, train_coord_dl = extract_features(resnet3d_extractor, dataloader_train, device)
print("ResNet3D Extracted Features Shape:", resnet3d_features_train.shape)  

ResNet3D Extracted Features Shape: torch.Size([31233, 512])


In [43]:
resnet3d_features_test, test_coord_dl = extract_features(resnet3d_extractor, dataloader_test, device)
print("ResNet3D Extracted Features Shape:", resnet3d_features_test.shape)  

ResNet3D Extracted Features Shape: torch.Size([7809, 512])


In [44]:
resnet3d_features_eval, eval_coord_dl = extract_features(resnet3d_extractor, dataloader_eval, device)
print("ResNet3D Extracted Features Shape:", resnet3d_features_eval.shape)  

ResNet3D Extracted Features Shape: torch.Size([867, 512])


### Evaluation

In [52]:
kmeans = train_kmeans_patches(resnet3d_features_train.cpu(), n_clusters=2, random_state=12)

train_patch_predictions = kmeans.predict(resnet3d_features_train.reshape(resnet3d_features_train.size(0), -1).numpy().astype(np.float32))
test_patch_predictions = kmeans.predict(resnet3d_features_test.reshape(resnet3d_features_test.size(0), -1).numpy().astype(np.float32))
eval_patch_predictions = kmeans.predict(resnet3d_features_eval.reshape(resnet3d_features_eval.size(0), -1).numpy().astype(np.float32))

  super()._check_params_vs_input(X, default_n_init=10)


In [53]:
# Assign field labels
threshold = 0.5
train_field_labels = assign_field_labels_ae(train_coord_dl, train_patch_predictions, threshold)
test_field_labels = assign_field_labels_ae(test_coord_dl, test_patch_predictions, threshold)
eval_field_labels = assign_field_labels_ae(eval_coord_dl, eval_patch_predictions, threshold)

In [54]:
accuracy, report, x_y_coords = evaluate_test_labels_ae(eval_field_labels, config.labels_path)
print(f"Test Accuracy: {accuracy}")
print(report)

Test Accuracy: 0.5245901639344263
              precision    recall  f1-score   support

           0       0.40      0.23      0.29        26
           1       0.57      0.74      0.64        35

    accuracy                           0.52        61
   macro avg       0.48      0.49      0.47        61
weighted avg       0.49      0.52      0.49        61

