## Compare the trained models
Load the checkpoints, set up the test datamodule, and access the model's two-channel output.

In [1]:
import torch
from mimo.models.mimo_unet import MimoUnetModel
from mimo.tasks.sen12tp.sen12tp_datamodule import get_datamodule

In [None]:
# Define the run configuration inside the notebook instead of parsing CLI flags,
# build the datamodule and the model from it, then restore the trained weights.
#
# Training command kept for reference (presumably the run that produced the
# checkpoint below — confirm):
#   python scripts/train/train_ndvi.py --max_epochs 40 --batch_size 32 \
#       -t NDVI -i VV_sigma0 -i VH_sigma0 --project "MIMO_NDVI_Prediction_Gauss"
from argparse import Namespace

config = {
    'seed': 1,
    'checkpoint_path': 'MIMO_NDVI_Prediction_Gauss/1yt2l40t/checkpoints/last.ckpt',
    'dataset_dir': '/deepskieslab/rnevin/zenodo_data',
    'batch_size': 32,
    'num_loss_function_params': 2,
    'num_subnetworks': 2,
    'filter_base_count': 30,
    'center_dropout_rate': 0.1,
    'final_dropout_rate': 0.1,
    'encoder_dropout_rate': 0.0,
    'core_dropout_rate': 0.0,
    'decoder_dropout_rate': 0.0,
    'loss_buffer_size': 10,
    'loss_buffer_temperature': 0.3,
    'input_repetition_probability': 0.0,
    'batch_repetitions': 1,
    'patch_size': 256,
    'stride': 249,
    'loss': 'gaussian_nll',  # Adjust based on your actual loss function
    'weight_decay': 0.0001,
    'learning_rate': 0.0001,
    'num_workers': 30,
    'training_set_percentage': 1.0,
    'input': ["VV_sigma0", "VH_sigma0"],
    'target': ["NDVI"],  # Example target
}

# get_datamodule is handed a Namespace, so every later access to the settings
# uses attribute syntax (a Namespace does not support args['key'] subscripting).
args = Namespace(**config)

# Initialize the datamodule once (this can be modified based on your setup)
dm = get_datamodule(args)

# Instantiate the model (ensure to use the same parameters as during training)
model = MimoUnetModel(
    in_channels=len(dm.model_inputs),
    # One output channel per target and per loss-function parameter.
    out_channels=len(dm.model_targets) * args.num_loss_function_params,
    num_subnetworks=args.num_subnetworks,
    filter_base_count=args.filter_base_count,
    center_dropout_rate=args.center_dropout_rate,
    final_dropout_rate=args.final_dropout_rate,
    encoder_dropout_rate=args.encoder_dropout_rate,
    core_dropout_rate=args.core_dropout_rate,
    decoder_dropout_rate=args.decoder_dropout_rate,
    loss_buffer_size=args.loss_buffer_size,
    loss_buffer_temperature=args.loss_buffer_temperature,
    input_repetition_probability=args.input_repetition_probability,
    batch_repetitions=args.batch_repetitions,
    loss=args.loss,
    weight_decay=args.weight_decay,
    learning_rate=args.learning_rate,
    seed=args.seed,
)

# Load the checkpoint onto the CPU so this works regardless of the device it
# was saved from; load_state_dict copies the values into the model's own tensors.
checkpoint = torch.load(args.checkpoint_path, map_location='cpu')

# NOTE(review): the `.ckpt` path suggests a PyTorch Lightning checkpoint, which
# stores the weights under 'state_dict' — confirm. Plain PyTorch checkpoints
# that use 'model_state_dict' are still accepted as a fallback.
if 'state_dict' in checkpoint:
    state_dict = checkpoint['state_dict']
else:
    state_dict = checkpoint['model_state_dict']
model.load_state_dict(state_dict)

# Set the model to evaluation mode
model.eval()


In [None]:
# Show the printed representation of the model built in the previous cell.
print(model)

In [None]:

# Run the loaded model over the test dataloader and store the raw outputs.
# Relies on `dm` and `model` created in the earlier cell.
test_data = dm.test_dataloader()  # Ensure the dataloader is set up for your test dataset

# Make predictions without tracking gradients.
# NOTE(review): assumes each batch from the dataloader can be passed to the
# model as-is; if batches are dicts or (input, target) tuples, unpack them
# first — confirm against the datamodule.
with torch.no_grad():
    predictions = [model(inputs) for inputs in test_data]

# Optionally save the predictions (or process them further).
# The original call used `save_image`, which is never imported or defined in
# this notebook, so each batch is written as a tensor file with torch.save
# instead; reload with torch.load('prediction_<i>.pt').
for i, pred in enumerate(predictions):
    torch.save(pred, f'prediction_{i}.pt')



In [None]:
from mimo.models.mimo_unet import MimoUnetModel

In [None]:
import torch

# Restore trained weights into `model`, which must already exist (it is built
# in an earlier cell of this notebook).
checkpoint_path = 'MIMO_NDVI_Prediction_Gauss/1yt2l40t/checkpoints/last.ckpt'

# Load the checkpoint onto the CPU so this works regardless of the device it
# was saved from.
checkpoint = torch.load(checkpoint_path, map_location='cpu')

# NOTE(review): the `.ckpt` path suggests a PyTorch Lightning checkpoint, which
# stores the weights under 'state_dict' — confirm. Plain PyTorch checkpoints
# that use 'model_state_dict' are still accepted as a fallback.
if 'state_dict' in checkpoint:
    state_dict = checkpoint['state_dict']
else:
    state_dict = checkpoint['model_state_dict']
model.load_state_dict(state_dict)

# Set the model to evaluation mode
model.eval()