## Compare the trained models

In [1]:
# Import necessary libraries
import torch
from mimo.models.mimo_unet import MimoUnetModel
from mimo.tasks.sen12tp.sen12tp_datamodule import get_datamodule

# Inference configuration. These values should mirror the training run whose
# checkpoint is loaded below, so the rebuilt model matches the saved one.
# NOTE(review): 'loss' is 'mse' while num_loss_function_params is 2 and the
# run directory is named "..._Gauss"; confirm which loss the checkpoint was
# actually trained with.
args = dict(
    seed=42,
    # Checkpoint to evaluate.
    checkpoint_path='MIMO_NDVI_Prediction_Gauss/1yt2l40t/checkpoints/last.ckpt',
    # Output head size: out_channels = number of targets * this value.
    num_loss_function_params=2,
    # Architecture.
    num_subnetworks=2,
    filter_base_count=64,
    # Dropout rates per network section.
    center_dropout_rate=0.1,
    final_dropout_rate=0.1,
    encoder_dropout_rate=0.1,
    core_dropout_rate=0.1,
    decoder_dropout_rate=0.1,
    # Loss-buffer and input-repetition settings.
    loss_buffer_size=1024,
    loss_buffer_temperature=0.5,
    input_repetition_probability=0.0,
    batch_repetitions=1,
    # Loss and optimiser settings.
    loss='mse',
    weight_decay=0.0001,
    learning_rate=0.0001,
)

# Build the datamodule first: its input and target lists fix the model's
# channel counts.
# NOTE(review): ``args`` is a plain dict here; confirm get_datamodule accepts a
# mapping rather than an argparse.Namespace as a training script might pass.
dm = get_datamodule(args)

# Constructor arguments copied unchanged from ``args``. Keep them identical to
# the training configuration so the rebuilt network matches the checkpoint.
_passthrough_keys = (
    'num_subnetworks',
    'filter_base_count',
    'center_dropout_rate',
    'final_dropout_rate',
    'encoder_dropout_rate',
    'core_dropout_rate',
    'decoder_dropout_rate',
    'loss_buffer_size',
    'loss_buffer_temperature',
    'input_repetition_probability',
    'batch_repetitions',
    'loss',
    'weight_decay',
    'learning_rate',
    'seed',
)
model = MimoUnetModel(
    in_channels=len(dm.model_inputs),
    # One output channel per (target, loss-function parameter) pair.
    out_channels=len(dm.model_targets) * args['num_loss_function_params'],
    **{key: args[key] for key in _passthrough_keys},
)

# Load the trained weights into the freshly constructed model.
# map_location='cpu' lets a checkpoint saved from a GPU run be opened on a
# CPU-only machine (the model is never moved to a GPU in this notebook).
# NOTE(review): since torch 2.6, torch.load defaults to weights_only=True; if
# this raises an UnpicklingError because the checkpoint stores non-tensor
# hyperparameters, pass weights_only=False -- only for checkpoints you trust.
checkpoint = torch.load(args['checkpoint_path'], map_location='cpu')

# Lightning ``.ckpt`` files (as written by its ModelCheckpoint callback, e.g.
# ``checkpoints/last.ckpt``) keep the module weights under 'state_dict';
# hand-written torch.save checkpoints commonly use 'model_state_dict'.
# Accept either and fail with a readable message otherwise.
_weights_key = next(
    (key for key in ('state_dict', 'model_state_dict') if key in checkpoint),
    None,
)
if _weights_key is None:
    raise KeyError(
        f"No model weights found in {args['checkpoint_path']!r}; "
        f"available keys: {sorted(checkpoint)}"
    )
model.load_state_dict(checkpoint[_weights_key])

# Switch to evaluation mode (changes the behaviour of layers such as dropout).
model.eval()

# Run inference over the test split.
# NOTE(review): each batch is passed straight to ``model(...)``; confirm the
# test dataloader yields bare input tensors rather than (inputs, targets)
# tuples or dicts, which would need unpacking first.
# NOTE(review): if the datamodule builds its datasets in ``setup()``, call
# ``dm.setup('test')`` before requesting the dataloader.
test_data = dm.test_dataloader()

with torch.no_grad():  # inference only: no autograd graph is needed
    predictions = [model(inputs) for inputs in test_data]

# Persist each batch's raw output tensor. No image-writing helper is imported
# in this notebook, and the outputs carry num_loss_function_params channels
# per target rather than a plain image, so store them losslessly with
# torch.save; reload later with torch.load.
for i, pred in enumerate(predictions):
    torch.save(pred, f'prediction_{i}.pt')



ModuleNotFoundError: No module named 'lightning'

In [None]:
from mimo.models.mimo_unet import MimoUnetModel

In [None]:
import torch

# Reload trained weights into an existing ``model``. This cell does not build
# the model: ``model`` must already have been constructed (e.g. by the first
# cell) with the training hyperparameters before this cell runs.
checkpoint_path = 'MIMO_NDVI_Prediction_Gauss/1yt2l40t/checkpoints/last.ckpt'

# map_location='cpu' lets a checkpoint saved from a GPU run load on a
# CPU-only machine.
# NOTE(review): since torch 2.6, torch.load defaults to weights_only=True; if
# this raises an UnpicklingError because the checkpoint stores non-tensor
# hyperparameters, pass weights_only=False -- only for checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location='cpu')

# Lightning ``.ckpt`` files keep the module weights under 'state_dict';
# hand-written torch.save checkpoints commonly use 'model_state_dict'.
# Accept either and fail with a readable message otherwise.
weights_key = next(
    (key for key in ('state_dict', 'model_state_dict') if key in checkpoint),
    None,
)
if weights_key is None:
    raise KeyError(
        f"No model weights found in {checkpoint_path!r}; "
        f"available keys: {sorted(checkpoint)}"
    )
model.load_state_dict(checkpoint[weights_key])

# Switch to evaluation mode (changes the behaviour of layers such as dropout).
model.eval()