# Pre-training with Learning from Randomness (LFR) on the Parihaka dataset

## Imports

In [None]:
import torch
from torch import nn
from minerva.models.ssl import LearnFromRandomnessModel, RepeatedModuleList
from minerva.models.nets.image.deeplabv3 import DeepLabV3Backbone
from minerva.pipelines.lightning_pipeline import SimpleLightningPipeline
from minerva.data.data_modules.seismic_image import SeismicImageDataModule
from functools import partial
from common import get_trainer

## Variables

In [None]:
# Data
dataset_name = "seam_ai"        # Dataset name (just an identifier)
data_dir = "/shared/datasets/seam_ai"  # Root directory of the dataset
batch_size = 1                  # Batch size (defined once, used by the data module)
image_size = (1024, 1024)       # Passed as `resize` to the data module

# Model
model_name = "lfr_5"            # Model name (just an identifier)
n_prediction_heads = 5          # Number of projector/predictor pairs

# Training
learning_rate = 1e-4            # Learning rate handed to the LFR model
log_dir = "./logs"              # Directory to save logs
seed = 42                       # Seed for reproducibility
num_epochs = 100                # Number of epochs to train
is_debug = True                 # If True, only 3 batches will be processed for 3 epochs
accelerator = "gpu"             # "cpu" or "gpu"
devices = 1                     # Number of devices (GPUs) to use

## Data Module

In [None]:
# Data module over the seismic images found under `data_dir`.
# `labels=False` presumably skips loading annotation masks, which this
# pre-training stage does not use -- TODO confirm against SeismicImageDataModule.
data_module = SeismicImageDataModule(
    root_dirs=data_dir,
    resize=image_size,
    batch_size=batch_size,
    labels=False,
)

## Model

In [None]:
def Projector(*extra_modules: nn.Module) -> nn.Sequential:
    """Build one projector network with freshly initialised layers.

    The previous ``partial(nn.Sequential, nn.Conv2d(...), ...)`` form created
    the layer objects a single time, so every call returned a new
    ``nn.Sequential`` wrapping the *same* layers and all projector heads
    shared one set of weights. Constructing the layers inside the function
    gives each call its own independently initialised parameters.

    Args:
        *extra_modules: Optional modules appended after the built-in layers
            (kept so positional arguments accepted by the old ``partial``
            still work).

    Returns:
        An ``nn.Sequential`` mapping a 3-channel input to a 1-channel output.
        For even spatial sizes divisible by 8 the output has the same height
        and width as the input.
    """
    return nn.Sequential(
        # Encoder: three stride-2 convolutions, each halving height and width.
        nn.Conv2d(3, 32, 3, 2, 1),
        nn.ReLU(True),
        nn.Conv2d(32, 256, 3, 2, 1),
        nn.ReLU(True),
        nn.Conv2d(256, 2048, 3, 2, 1),
        nn.ReLU(True),
        # Decoder: three stride-2 transposed convolutions (output_padding=1),
        # each doubling height and width.
        nn.ConvTranspose2d(2048, 256, 3, 2, 1, 1),
        nn.ReLU(True),
        nn.ConvTranspose2d(256, 32, 3, 2, 1, 1),
        nn.ReLU(True),
        nn.ConvTranspose2d(32, 1, 3, 2, 1, 1),
        *extra_modules,
    )

def Predictor(*extra_modules: nn.Module) -> nn.Sequential:
    """Build one predictor head with freshly initialised layers.

    The previous ``partial(nn.Sequential, nn.ConvTranspose2d(...), ...)`` form
    created the layer objects a single time, so every call returned a new
    ``nn.Sequential`` wrapping the *same* layers and all predictor heads
    shared (and trained) one set of weights. Constructing the layers inside
    the function gives each head its own parameters.

    Args:
        *extra_modules: Optional modules appended after the built-in layers
            (kept so positional arguments accepted by the old ``partial``
            still work).

    Returns:
        An ``nn.Sequential`` mapping a 2048-channel feature map to a
        1-channel map whose height and width are 8x those of the input.
    """
    return nn.Sequential(
        # Three stride-2 transposed convolutions (output_padding=1), each
        # doubling height and width.
        nn.ConvTranspose2d(2048, 256, 3, 2, 1, 1),
        nn.ReLU(True),
        nn.ConvTranspose2d(256, 32, 3, 2, 1, 1),
        nn.ReLU(True),
        nn.ConvTranspose2d(32, 1, 3, 2, 1, 1),
        *extra_modules,
    )


# Seed PyTorch's global RNG before any layer is created so the randomly
# initialised projectors, predictors and backbone are reproducible across
# runs. `seed` was previously declared but never applied.
torch.manual_seed(seed)

# One projector and one predictor are requested per prediction head; all
# heads are compared with a mean-squared-error loss.
LFR_model = LearnFromRandomnessModel(
    backbone=DeepLabV3Backbone(),
    projectors=RepeatedModuleList(n_prediction_heads, Projector),
    predictors=RepeatedModuleList(n_prediction_heads, Predictor),
    loss_fn=nn.MSELoss(),
    learning_rate=learning_rate,
)

# Bare expression: as the last statement of the cell it displays the model.
LFR_model

## Pipeline

In [None]:
# Build the trainer with the helper imported from the local `common` module.
# Arguments are positional, so their order must match `get_trainer`'s
# signature (not visible here -- verify against `common.get_trainer`).
trainer = get_trainer(
    model_name,
    dataset_name,
    log_dir,
    num_epochs,
    accelerator,
    devices,
    is_debug,
)

# Bundle the model and trainer into a pipeline; `log_dir` is passed as the
# third positional argument.
pipeline = SimpleLightningPipeline(LFR_model, trainer, log_dir)

## Run!

In [None]:
# Run the pipeline on the data module with the "fit" task (training).
pipeline.run(data_module, task="fit")