In [None]:
import os
import pickle
import random
import sys
import warnings

sys.path.append(os.path.join(".."))

import pytorch_lightning as pl
import torch
# from src.model_utils import custom_multiclass_report, CroplandDataModule_LSTM, Crop_LSTM, Crop_PL
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, Dataset

In [None]:
# Load the pickled feature (X) and label (y) dictionaries from the processed_files directory.
# NOTE: pickle.load can execute arbitrary code while unpickling; only open files you trust.
pkl_dir = os.path.join("..", "data", "processed_files", "pkls")

with open(os.path.join(pkl_dir, "X_FR_RUS_ROS_lstm.pkl"), "rb") as features_file:
    X = pickle.load(features_file)

with open(os.path.join(pkl_dir, "y_FR_RUS_ROS_lstm.pkl"), "rb") as labels_file:
    y = pickle.load(labels_file)


In [None]:
class CroplandDataset(Dataset):
    """
    Map-style dataset that pairs monthly and static features with a target.

    Args:
    X: Indexable pair; ``X[0]`` holds the monthly features and ``X[1]`` the static features.
    y: Indexable collection of targets; its length defines the dataset length.

    Each item is returned as ``((x_monthly, x_static), target)``.
    """

    def __init__(self, X, y):
        self.X_monthly, self.X_static = X[0], X[1]
        self.y = y

    def __len__(self):
        # One sample per target.
        return len(self.y)

    def __getitem__(self, idx):
        features = (self.X_monthly[idx], self.X_static[idx])
        return features, self.y[idx]

In [None]:
class CroplandDataModule_LSTM(pl.LightningDataModule):
    """
    LightningDataModule that serves the Train / Val / Test splits to the LSTM cropland classifier.

    Args:
    X (dict): Maps "Train", "Val" and "Test" to an indexable pair whose element 0 is the
        monthly feature array and element 1 the static feature array.
    y (dict): Maps "Train", "Val" and "Test" to the corresponding targets.
    batch_size (int): Batch size used by every dataloader. Default is 128.
    """

    def __init__(self, X: dict, y: dict, batch_size: int = 128):
        super().__init__()
        self.batch_size = batch_size
        splits = ("Train", "Val", "Test")

        # Monthly features as float tensors, one per split.
        monthly = [torch.FloatTensor(X[split][0]) for split in splits]
        self.X_monthly_train, self.X_monthly_val, self.X_monthly_test = monthly

        # Static features as float tensors, one per split.
        static = [torch.FloatTensor(X[split][1]) for split in splits]
        self.X_static_train, self.X_static_val, self.X_static_test = static

        # Targets as integer (long) tensors, one per split.
        targets = [torch.LongTensor(y[split]) for split in splits]
        self.y_train, self.y_val, self.y_test = targets

        # Keyword arguments shared by all three dataloaders.
        self.dl_dict = {"batch_size": self.batch_size}

    def setup(self, stage=None):
        """Build the datasets needed for ``stage``; ``None`` builds all of them."""
        if stage is None or stage == "fit":
            train_features = (self.X_monthly_train, self.X_static_train)
            val_features = (self.X_monthly_val, self.X_static_val)
            self.dataset_train = CroplandDataset(train_features, self.y_train)
            self.dataset_val = CroplandDataset(val_features, self.y_val)

        if stage is None or stage == "test":
            test_features = (self.X_monthly_test, self.X_static_test)
            self.dataset_test = CroplandDataset(test_features, self.y_test)

    def train_dataloader(self):
        """Return the training dataloader (the only one that shuffles)."""
        return DataLoader(self.dataset_train, shuffle=True, **self.dl_dict)

    def val_dataloader(self):
        """Return the validation dataloader."""
        return DataLoader(self.dataset_val, **self.dl_dict)

    def test_dataloader(self):
        """Return the test dataloader."""
        return DataLoader(self.dataset_test, **self.dl_dict)

In [None]:
class Crop_Conv_LSTM(nn.Module):
    """
    LSTM encoder over the monthly sequence combined with static features.

    The monthly sequence is passed through a multi-layer LSTM whose hidden state has
    ``2 * input_size`` features. The LSTM output at the last time step is concatenated
    with the static feature vector and fed to a fully connected head
    (Linear -> BatchNorm1d -> LeakyReLU -> Dropout, three times, then a final Linear).
    The head output is converted to log-probabilities with ``log_softmax``.

    Args:
    input_size (int): Number of features per time step of the monthly input (default: 12).
    hidden_size (int): Width of the concatenated vector fed to the head, i.e.
        ``2 * input_size`` plus the number of static features (default: 68).
        Note that this is NOT the LSTM hidden size, which is fixed to ``2 * input_size``.
    num_layers (int): Number of stacked LSTM layers (default: 2).
    output_size (int): Number of output classes (default: 4).
    dropout (float): Dropout probability passed to ``nn.LSTM`` (default: 0.2).

    Inputs:
    X (sequence of two tensors): ``(x_monthly, x_static)`` where ``x_monthly`` has shape
        (batch_size, sequence_length, input_size) and ``x_static`` has shape
        (batch_size, hidden_size - 2 * input_size).

    Outputs:
    out (torch.Tensor): A tensor of shape (batch_size, output_size) containing log-probabilities.

    """

    def __init__(
        self,
        input_size=12,
        hidden_size=68,
        num_layers=2,
        output_size=4,
        dropout=0.2,
    ) -> None:
        # Zero-argument super(): the previous call named `Crop_LSTM`, which is not
        # defined in this notebook and raised NameError on construction.
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=input_size * 2,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
        )
        # The first Linear consumes the LSTM output concatenated with the static
        # features, so `hidden_size` must equal 2 * input_size + n_static_features.
        self.net = nn.Sequential(
            nn.Linear(hidden_size, 2 * hidden_size),
            nn.BatchNorm1d(2 * hidden_size),
            nn.LeakyReLU(),
            nn.Dropout(0.2),
            nn.Linear(2 * hidden_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.LeakyReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.BatchNorm1d(hidden_size // 2),
            nn.LeakyReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_size // 2, output_size),
        )

    def forward(self, X):
        x_monthly, x_static = X[0], X[1]
        seq_out, _ = self.lstm(x_monthly)
        # Keep only the representation at the final time step.
        last_step = seq_out[:, -1, :]
        scores = self.net(torch.cat((last_step, x_static), dim=1))
        # `nn.functional` is reachable through the existing `from torch import nn`;
        # the bare name `F` was never imported.
        return nn.functional.log_softmax(scores, dim=1)

In [None]:
# Initialize the data module
dm = CroplandDataModule_LSTM(X=X, y=y, batch_size=128)

# Initialize the model
warnings.filterwarnings("ignore")  # NOTE: silences every warning from here on
# Seed the RNGs before building the network so weight initialization is repeatable.
torch.manual_seed(123)
random.seed(123)

# Use the network defined in this notebook; the name `Crop_LSTM` is never defined
# or imported here, so calling it raised NameError on a fresh kernel.
network = Crop_Conv_LSTM()
# NOTE(review): `Crop_PL` is not defined or imported in this notebook either. The
# commented-out import near the top suggests it lives in src.model_utils -- restore
# that import (or define the wrapper) before running this cell.
model = Crop_PL(net=network)

# Initialize the trainer
# Stop training once "val/loss" has failed to improve by at least 1e-4
# for 30 consecutive validation checks.
early_stop_callback = EarlyStopping(
    monitor="val/loss", min_delta=1e-4, patience=30, verbose=True, mode="min"
)
lr_monitor = LearningRateMonitor(logging_interval="epoch")

# NOTE(review): devices=[3] pins training to GPU index 3; adjust on other machines.
trainer = pl.Trainer(
    max_epochs=500,
    accelerator="gpu",
    precision=16,
    devices=[3],
    benchmark=True,
    check_val_every_n_epoch=1,
    callbacks=[early_stop_callback, lr_monitor],
)
trainer.fit(model, dm)