In [1]:
import os
import pickle
import random
import sys
import warnings

sys.path.append(os.path.join(".."))

import pytorch_lightning as pl
import torch
from src.model_utils import custom_multiclass_report, CroplandDataModule_MLP, Crop_MLP, Crop_PL
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F

torch.set_float32_matmul_precision('medium')

In [39]:
# Display the shape of the first element of the training split.
# NOTE(review): `X` is only loaded by the pickle-reading cell further down in this
# notebook, so this cell fails on a fresh kernel when run top to bottom.
X['Train'][0].shape

(10764884, 10, 12)

In [13]:
# Convert the first element of the training split to a tensor, swap its last two axes
# with permute(0, 2, 1), and display the resulting shape.
# NOTE(review): depends on `X`, which is loaded by a later cell in this notebook.
torch.tensor(X['Train'][0]).permute(0, 2, 1).shape

torch.Size([10764884, 12, 10])

In [38]:
# A single self-attention block: 10-dimensional tokens split across 5 heads, a
# 128-unit feed-forward layer, GELU activation and 20% dropout. Inputs are
# batch-first, i.e. (batch, sequence, d_model).
encoder_layer = nn.TransformerEncoderLayer(
    d_model=10,
    nhead=5,
    dim_feedforward=128,
    dropout=0.2,
    activation="gelu",
    batch_first=True,
)

# Stack four copies of the block above into the full encoder.
transformer_enc = nn.TransformerEncoder(encoder_layer, num_layers=4)

In [25]:
# Build a float tensor from the first element of the training split and swap its last
# two axes; per the shapes recorded in the cells above this turns (N, 10, 12) into
# (N, 12, 10), so the last axis has the size the encoder was built with (d_model=10).
# NOTE(review): depends on `X`, which is loaded by a later cell in this notebook.
embedding = torch.FloatTensor(X['Train'][0]).permute(0, 2, 1)

In [29]:
# Smoke-test the encoder on the first 128 samples. Only the output shape is displayed
# (matching the recorded cell output) instead of dumping the full 128x12x10 tensor.
transformer_enc(embedding[:128]).shape

torch.Size([128, 12, 10])

In [30]:
# Run the first 128 samples through the encoder, collapse every non-batch axis into
# one feature axis, and display the resulting shape.
transformer_enc(embedding[:128]).flatten(start_dim=1).shape

torch.Size([128, 120])

In [2]:
# Report how many CUDA devices PyTorch can see.
num_cuda_devices = torch.cuda.device_count()
print(f"Number of CUDA devices available: {num_cuda_devices}")

# Print the name and compute capability (major.minor) of each visible device; the
# index printed here is the one used to address a device as `cuda:<index>`.
for i in range(num_cuda_devices):
    device = torch.device(f'cuda:{i}')
    properties = torch.cuda.get_device_properties(device)
    print(f"Device {i}: {properties.name}, Compute Capability: {properties.major}.{properties.minor}")

Number of CUDA devices available: 4
Device 0: Quadro RTX 8000, Compute Capability: 7.5
Device 1: Quadro RTX 8000, Compute Capability: 7.5
Device 2: NVIDIA GeForce RTX 3060, Compute Capability: 8.6
Device 3: NVIDIA GeForce RTX 3060, Compute Capability: 8.6


In [3]:
class Crop_MLP(nn.Module):
    """
    A multi-layer perceptron (MLP) used for crop classification.

    The network consists of three hidden blocks (Linear -> BatchNorm1d -> LeakyReLU ->
    Dropout) of widths ``2 * input_size``, ``input_size`` and ``input_size // 2``,
    followed by a final linear layer with ``output_size`` outputs.

    Args:
        input_size (int): The number of input features (default: 164).
        output_size (int): The number of output classes (default: 4).

    Inputs:
        X (torch.Tensor): A tensor of shape (batch_size, input_size) containing input data.

    Returns:
        torch.Tensor: A tensor of shape (batch_size, output_size) containing the
        log-probabilities of each class (``log_softmax`` of the logits over dim 1).
    """

    def __init__(self, input_size=164, output_size=4) -> None:
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(input_size, 2 * input_size),
            nn.BatchNorm1d(2 * input_size),
            nn.LeakyReLU(),
            nn.Dropout(0.7),
            nn.Linear(2 * input_size, input_size),
            nn.BatchNorm1d(input_size),
            nn.LeakyReLU(),
            nn.Dropout(0.3),
            nn.Linear(input_size, input_size // 2),
            nn.BatchNorm1d(input_size // 2),
            nn.LeakyReLU(),
            nn.Dropout(0.2),
            nn.Linear(input_size // 2, output_size),
        )

    def initialize_bias_weights(self, y_train):
        """
        Initialize the bias of the final linear layer from the class distribution.

        The bias of class ``c`` is set to ``log(p_c)``, where ``p_c`` is the fraction of
        samples in ``y_train`` labelled ``c``. With that bias, ``softmax(bias)`` equals
        the empirical class prior, so an untrained network starts out predicting the
        prior instead of its inverse.

        Args:
            y_train (torch.Tensor): A tensor of shape (num_samples,) containing the
                integer target labels in ``[0, output_size)``.

        Raises:
            ValueError: If ``y_train`` is empty or contains a label outside
                ``[0, output_size)``.
        """
        final_layer = self.net[-1]
        num_classes = final_layer.out_features

        labels = torch.as_tensor(y_train).reshape(-1).long()
        if labels.numel() == 0:
            raise ValueError("y_train must contain at least one label")
        if labels.min() < 0 or labels.max() >= num_classes:
            raise ValueError(
                f"labels must lie in [0, {num_classes}), got range "
                f"[{int(labels.min())}, {int(labels.max())}]"
            )

        # bincount with minlength keeps one slot per output class even when a class is
        # absent from y_train, so the bias always matches the layer's output size.
        class_counts = torch.bincount(labels, minlength=num_classes)
        class_distribution = class_counts.float() / labels.numel()

        # Clamp so that an absent class gets a large negative (but finite) bias
        # instead of log(0) = -inf, which would poison the gradients.
        eps = torch.finfo(class_distribution.dtype).eps
        bias_weights = torch.log(class_distribution.clamp_min(eps))

        # copy_ keeps the parameter's dtype/device and its registration with the module.
        with torch.no_grad():
            final_layer.bias.copy_(bias_weights)

    def forward(self, X) -> torch.Tensor:
        output = self.net(X)
        return F.log_softmax(output, dim=1)

### Read from file

In [7]:
# Load the pickled feature dictionary `X` and target dictionary `y`.
# NOTE: pickle.load can execute arbitrary code while unpickling; only open files that
# were produced by this project's own preprocessing.
pkl_dir = os.path.join("..", "data", "processed_files", "pkls")

with open(os.path.join(pkl_dir, "X_FR_RUS_ROS_lstm.pkl"), "rb") as fp:
    X = pickle.load(fp)

with open(os.path.join(pkl_dir, "y_FR_RUS_ROS_lstm.pkl"), "rb") as fp:
    y = pickle.load(fp)

In [5]:
# Silence library warnings for the rest of the notebook.
warnings.filterwarnings("ignore")

# Seed every RNG (Python `random`, NumPy, PyTorch) *before* anything is constructed,
# so that any randomness in the data module as well as the network's weight
# initialization is reproducible on a fresh kernel.
pl.seed_everything(123, workers=True)

# initialize data module
dm = CroplandDataModule_MLP(X=X, y=y, batch_size=256)

# initialize model
network = Crop_MLP()
# Optional: start the final layer from the class prior of the training labels.
# network.initialize_bias_weights(dm.y_train.argmax(dim=1))
model = Crop_PL(net=network)

# initialize trainer callbacks: stop once "val/loss" has not improved by at least 1e-4
# for 20 consecutive validation checks, and log the learning rate once per epoch.
early_stop_callback = EarlyStopping(
    monitor="val/loss", min_delta=1e-4, patience=20, verbose=True, mode="min"
)
lr_monitor = LearningRateMonitor(logging_interval="epoch")


In [None]:
# Trainer settings, collected in one place so they are easy to scan and tweak.
# `devices=[3]` pins training to CUDA device index 3 (see the device listing printed
# earlier in this notebook); `precision=16` requests 16-bit training.
trainer_kwargs = dict(
    accelerator='gpu',
    devices=[3],
    precision=16,
    max_epochs=500,
    check_val_every_n_epoch=1,
    benchmark=True,
    callbacks=[early_stop_callback, lr_monitor],
)

trainer = pl.Trainer(**trainer_kwargs)
trainer.fit(model, dm)


In [None]:
# NOTE: this import repeats the one in the notebook's first cell.
from src.model_utils import custom_multiclass_report

# check metrics
# Predict on the test features in batches and stitch the per-batch outputs together.
test_loader = DataLoader(dm.X_test, batch_size=2048)
batch_outputs = trainer.predict(model, test_loader)
predictions = torch.cat(batch_outputs, dim=0)

# Cast the outputs to float32, normalize them into class probabilities, and take the
# most likely class per sample.
softmax = nn.Softmax(dim=1)
yprob = softmax(predictions.float())
ypred = yprob.argmax(dim=1)

# The targets are reduced with argmax over dim 1 to get one class index per sample.
ytest = dm.y_test.argmax(dim=1).cpu().numpy()


print(custom_multiclass_report(ytest, ypred, yprob))

In [None]:
# Save the module to a file.
# torch.save pickles the whole module object, so loading it later requires the same
# class definitions to be importable.
# NOTE(review): the target directory is hard-coded and may not match the run that was
# just trained; confirm against the trainer's actual log directory.
save_path = 'lightning_logs/version_100/my_module.pth'

# torch.save does not create missing parent directories, so make sure it exists first.
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(model, save_path)