# VisNIR LSSM for exchangeable potassium (Kex)

> Ensembling ResNet and other deep-learning models (TODO: list the remaining ensemble members) to predict exchangeable potassium in soil from OSSL Vis-NIR spectra

In [1]:
# Reload all imported modules before each cell executes, so edits to the
# local `lssm` package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Imports

In [4]:
from functools import partial
from pathlib import Path

import timm
import torch
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from torch import nn, optim
from torch.optim import lr_scheduler
from torcheval.metrics import R2Score

from lssm.loading import load_ossl
from lssm.preprocessing import ToAbsorbance, ContinuumRemoval, Log1p
from lssm.dataloaders import SpectralDataset, get_dls
from lssm.callbacks import (MetricsCB, BatchSchedCB, BatchTransformCB,
                            DeviceCB, TrainCB, ProgressCB)
from lssm.transforms import GADFTfm, _resizeTfm, StatsTfm
from lssm.learner import Learner


## Data loading & preprocessing

In [None]:
# Target analyte: exchangeable K as named in OSSL (units per the name: cmolc/kg).
analytes = 'k.ext_usda.a725_cmolc.kg'
data = load_ossl(analytes, spectra_type='visnir')
(X, y,              # spectra and target values
 X_names,           # spectral feature names, needed by continuum removal
 smp_idx,
 ds_name,           # used below to stratify the train/valid split
 ds_label) = data

# Spectral preprocessing: convert to absorbance, then apply continuum removal.
spectra_pipe = Pipeline(steps=[
    ('to_abs', ToAbsorbance()),
    ('cr', ContinuumRemoval(X_names)),
])
X = spectra_pipe.fit_transform(X)

# Target is modelled on a log1p scale (presumably log(1 + y); verify in lssm),
# so losses/metrics downstream are on that scale, not in cmolc/kg.
y = Log1p().fit_transform(y)

In [10]:
# Train/valid split, stratified on `ds_name` (presumably the source dataset of
# each sample) so every collection appears in both sets.
n_smp = None  # cap on sample count for quick demo runs; None keeps all samples
# NOTE: slicing keeps the FIRST n_smp rows, not a random subset.
subset = slice(None, n_smp)
X_train, X_valid, y_train, y_valid = train_test_split(
    X[subset, :],
    y[subset],
    test_size=0.1,
    stratify=ds_name[subset],
    random_state=41,
)

# Wrap each split in a PyTorch dataset
train_ds = SpectralDataset(X_train, y_train)
valid_ds = SpectralDataset(X_valid, y_valid)

# ... and build the train/valid dataloaders
dls = get_dls(train_ds, valid_ds, bs=32)


100%|██████████| 44489/44489 [00:15<00:00, 2859.72it/s]


## DL model ensembling

First ensemble member: a pretrained `resnet18` from `timm`; further members are still to be added.

In [12]:
# Seed torch's global RNG so the newly initialised 1-output head (and,
# presumably, batch shuffling inside `dls`) are reproducible across runs.
# Same value as the split's `random_state` for consistency.
SEED = 41
torch.manual_seed(SEED)

# Pretrained timm backbone adapted to single-channel input and one regression output.
model_name = 'resnet18'
model = timm.create_model(model_name, pretrained=True, in_chans=1, num_classes=1)

# Define modelling pipeline & Train
epochs = 1
lr = 5e-3

# R² is computed on the transformed (log1p) target, not in cmolc/kg.
metrics = MetricsCB(r2=R2Score())

# One-cycle schedule; total_steps must equal the number of scheduler steps.
# BatchSchedCB presumably steps once per training batch -- TODO confirm.
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)

xtra = [BatchSchedCB(sched)]

# Batch transforms: GADF encoding of each spectrum, resize, then normalise with
# the backbone's pretrained stats from `model.default_cfg`.
# Assumes callbacks run in list order -- order matters here.
gadf = BatchTransformCB(GADFTfm())
resize = BatchTransformCB(_resizeTfm)
stats = BatchTransformCB(StatsTfm(model.default_cfg))

cbs = [DeviceCB(), gadf, resize, stats, TrainCB(),
       metrics, ProgressCB(plot=False)]

# `nn` and `optim` come from the torch imports in the import cell.
learn = Learner(model, dls, nn.MSELoss(), lr=lr, cbs=cbs+xtra, opt_func=optim.AdamW)

learn.fit(epochs)

r2,loss,epoch,train
0.36,0.095,0,train
0.539,0.068,0,eval
