<a href="https://colab.research.google.com/github/issmythe/ccai_crop_mapping/blob/main/tutorial_cleaned.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# OpenMapFlow Tutorial

### Sections
1. Installing OpenMapFlow
2. Exploring labeled earth observation data
3. Training a model
4. Running inference over a small region
5. Deploying the best model

### Prerequisites:
- GitHub account
- GitHub access token (obtained [here](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token))
- Forked OpenMapFlow repository
- Basic Python knowledge  

### Editable Google Doc for Q&A:
https://docs.google.com/document/d/1Kp6MphER1G5tdLYeAzl4n19S10TweIxiYT64rXsjKm4/edit?usp=sharing

## 1. Clone the GitHub repo and install OpenMapFlow


In [None]:
# Pin ipywidgets to the 7.x series; presumably works around the Colab widget issue linked below.
!pip install "ipywidgets>=7,<8" -q # https://github.com/googlecolab/colabtools/issues/3020

In [None]:
#@title Git credentials
from ipywidgets import HTML, Password, Text, Textarea, VBox

# Form fields, read by position in the next cell: 0 = token (masked), 1 = email, 2 = username.
inputs = [Password(description="Github Token:")]
inputs += [Text(description=label) for label in ('Github Email:', 'Github User:')]
VBox(inputs)

In [None]:
#@title Clone directory
# Read the credentials typed into the form above.
token = inputs[0].value
email = inputs[1].value
username = inputs[2].value

# Editable clone URL, pre-filled with the user's fork; the Config cell does the cloning.
# Nothing is cloned here: cloning the upstream repo into the same "openmapflow" directory
# would make the fork clone fail, and any statement after VBox(...) would stop the widget
# from rendering (only a cell's last expression is displayed).
github_url_input = Textarea(value=f'https://github.com/{username}/openmapflow.git')
VBox([HTML(value="<b>Github Clone URL</b>"), github_url_input])

In [None]:
#@title Config
from pathlib import Path

github_url = github_url_input.value
project_name = "crop-mask-example" # maize-example
country_name = "Togo" # Kenya

for input_value in [token, email, username, github_url]:
  if input_value.strip() == "":
    raise ValueError("Found input with blank value.")

path_to_project = f"{Path(github_url).stem}/{project_name}"

!git config --global user.email $username
!git config --global user.name $email
!git clone {github_url.replace("https://", f"https://{username}:{token}@")}

%cd {path_to_project}

In [None]:
#@title Installs
# openmapflow with all extras, DVC with its "gs" extra (presumably for the GCS remote), and cmocean.
# All output, including errors, is discarded; drop `&> /dev/null` if later imports fail.
!pip install openmapflow[all] -q &> /dev/null
!pip install dvc[gs] cmocean -q &> /dev/null

In [None]:
#@title Download GDAL
%%shell
# Install a pinned GDAL build (Python bindings, CLI tools, dev headers) from the
# ubuntugis-unstable PPA. All output is discarded, so a failed install is silent here.
GDAL_VERSION="3.6.4+dfsg-1~jammy0"
add-apt-repository -y ppa:ubuntugis/ubuntugis-unstable &> /dev/null
apt-get -qq update &> /dev/null
apt-get -qq install python3-gdal=$GDAL_VERSION gdal-bin=$GDAL_VERSION libgdal-dev=$GDAL_VERSION &> /dev/null

In [None]:
# CLI
# Sanity check that the openmapflow command was installed; presumably prints its usage.
!openmapflow

## 2. Exploring labeled earth observation data 🛰️



### Setup

In [None]:
# A Google Cloud Account is required to access the data
# Interactive: sets up Application Default Credentials used by the data pull below.
!gcloud auth application-default login

In [None]:
# Pull in data already available
# Output is discarded, so use the report in the next cell to confirm the pull worked.
! dvc pull &> /dev/null

In [None]:
# See report of data already available
# Presumably summarises each dataset registered for this project.
! openmapflow datasets

### Exploring labels

In [None]:
#@title Imports + read data
import pandas as pd
import geopandas as gpd
from shapely.geometry import Point
from datasets import datasets, label_col
from openmapflow.constants import LAT, LON, DATASET, SUBSET

df = pd.concat(
    [dataset.load_df(to_np=True) for dataset in datasets[:1]]  # Global only (first dataset)
)

In [None]:
#@title Convert pandas dataframe to geopandas dataframe
gdf = gpd.GeoDataFrame(df)
# One shapely Point per label, built from its (lon, lat) columns.
gdf["geometry"] = [Point(lon, lat) for lon, lat in zip(gdf[LON], gdf[LAT])]

In [None]:
#@title Plot labels
# NOTE(review): geopandas.datasets / 'naturalearth_lowres' was deprecated in geopandas 0.14
# and removed in 1.0 -- pin geopandas<1 or read a local Natural Earth file instead.
world = gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))
ax = world.plot(figsize=(20, 20), facecolor="lightgray")
ax.set_title("Label Locations")
ax.set_axis_off()

# One colour per source dataset (categorical DATASET column), legend in the lower left.
gdf.plot(
    column=DATASET,
    categorical=True,
    legend=True,
    legend_kwds={'loc': 'lower left'},
    marker='o',
    markersize=1,
    ax=ax,
);

## Similarity with Mexico

In [None]:
#@title Setup
import geopandas as gpd
import numpy as np

from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
#@title Read country shapefile and make dataframes
countries = gpd.read_file('/content/world_adm.geojson')
# NASA Harvest labels carry no CRS; adopt the country layer's CRS so the spatial join runs.
# This appears to work as expected but assumes both use the same lon/lat coordinates.
gdf.crs = countries.crs
gdf_lab = countries[['name', 'geometry']].sjoin(gdf, how='inner', predicate='intersects')
gdf_lab = gdf_lab.drop('geometry', axis=1)

mex_train = pd.DataFrame(gdf_lab[gdf_lab['name'] == 'Mexico'])

# Shuffle the Mexico labels and re-split them 70/20/10 into training/validation/testing.
# Every row is reset to 'training' first so subset values inherited from the source
# dataset cannot survive in the "training" 70%.
np.random.seed(123)
mex_train = mex_train.sample(frac=1).reset_index(drop=True)
mex_train['subset'] = 'training'
mex_train.loc[mex_train.index > int(len(mex_train) * 0.7), 'subset'] = 'validation'
mex_train.loc[mex_train.index > int(len(mex_train) * 0.9), 'subset'] = 'testing'

df = pd.concat([gdf_lab[gdf_lab['name'] != 'Mexico'], mex_train])


In [None]:
#@title Get principal components
ss = StandardScaler()

# Flatten every sample's eo_data array into one feature vector, then z-score each
# feature across samples; the scaled rows are kept as lists in 'pca_data'.
flat_features = df['eo_data'].apply(np.ravel)
df['pca_data'] = ss.fit_transform(np.vstack(flat_features.to_list())).tolist()

pca = PCA(n_components=10)
pcs = pca.fit_transform(np.vstack(df['pca_data'].to_list()))
print(pca.explained_variance_ratio_)

# Only the first five components become columns (pc0 ... pc4).
for component_idx, component in enumerate(pcs[:, :5].T):
    df[f'pc{component_idx}'] = component

In [None]:
#@title Map principal component values
col = 'pc2'
vmin = df[col].quantile(0.01)
vmax = df[col].quantile(0.99)

# df has no geometry column (it was dropped after the country join), so build point
# geometries from lon/lat here rather than using a frame that is only defined further down.
pca_gdf = gpd.GeoDataFrame(
    df[[LON, LAT, col]],
    geometry=[Point(xy) for xy in zip(df[LON], df[LAT])],
)

sns.set_style('white')
f, ax = plt.subplots(figsize=[18, 9])
ax.set_axis_off()
ax.set_title(f'{col} (colour range clipped to 1st-99th percentile)')

countries.plot(color='none', edgecolor='darkgray', ax=ax)
pca_gdf.plot(column=col, vmin=vmin, vmax=vmax, markersize=2, cmap='viridis', ax=ax);


In [None]:
#@title Filter global training data based on Mexico PC values
mex = df.loc[df['name'] == 'Mexico']
filtered = df.loc[df['name'] != 'Mexico'].copy()
print(len(filtered))

# A global point keeps flag = 1 only if each of pc0..pc4 lies inside the matching
# Mexico quantile band:
#   inc1 -> [1st, 99th] percentile (central 98% of Mexico data)
#   inc5 -> [5th, 95th] percentile (central 90% of Mexico data)
quantile_bands = {'inc1': (0.01, 0.99), 'inc5': (0.05, 0.95)}
for flag in quantile_bands:
    filtered[flag] = 1

for i in range(5):
    pc_values = filtered[f'pc{i}']
    for flag, (q_low, q_high) in quantile_bands.items():
        lower = mex[f'pc{i}'].quantile(q_low)
        upper = mex[f'pc{i}'].quantile(q_high)
        filtered.loc[(pc_values < lower) | (pc_values > upper), flag] = 0
    # Running count of points still inside each band after applying pc0..pc{i}.
    print(filtered['inc1'].sum(), filtered['inc5'].sum())

# Expect to have n = 14639 for middle 98%, n = 7656 for middle 90%


In [None]:
#@title Map which locations are included in filtered dataset
plot_df = gpd.GeoDataFrame(filtered)
plot_df['geometry'] = [Point(lon, lat) for lon, lat in zip(plot_df[LON], plot_df[LAT])]

sns.set_style('white')
f, ax = plt.subplots(figsize=[18, 9])
ax.set_axis_off()
countries.plot(color='none', edgecolor='darkgray', ax=ax)

# Layers paint over each other: every global point in grey, then the central-98%
# set (inc1) in blue, then the stricter central-90% set (inc5) in red.
plot_df.plot(ax=ax, color='gray', markersize=2)
plot_df[plot_df['inc1'] == 1].plot(ax=ax, color='blue', markersize=2)
plot_df[plot_df['inc5'] == 1].plot(ax=ax, color='red', markersize=2)


## Model training

### Setup

In [None]:
# Set metaparams
model_name = 'm1'  # saved-model name; '' lets the Init model cell generate one

start_month = 'February'  # month whose MONTHS index starts the time slice in get_x_y
input_months = 12  # number of consecutive time steps kept from start_month
batch_size = 32  # DataLoader batch size (training and validation)
upsample_minority_ratio = 0.5  # passed to openmapflow's upsample_df for the training set
lr = 1e-4  # Adam learning rate
num_epochs = 25  # epochs per trained model


In [None]:
#@title Imports
import warnings
from argparse import ArgumentParser

import numpy as np
import geopandas as gpd
import pandas as pd
import torch
import yaml
from datasets import datasets, label_col
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)
from torch.utils.data import DataLoader
from tsai.models.TransformerModel import TransformerModel

from openmapflow.bands import BANDS_MAX
from openmapflow.constants import SUBSET
from openmapflow.pytorch_dataset import PyTorchDataset
from openmapflow.train_utils import (
    generate_model_name,
    get_x_y,
    model_path_from_name,
    upsample_df,
)
from openmapflow.utils import tqdm

# Detect Colab: google.colab is normally only importable there. IN_COLAB is used below
# to disable the nested per-batch tqdm bars.
try:
    import google.colab  # noqa

    IN_COLAB = True
except ImportError:
    IN_COLAB = False


warnings.simplefilter("ignore", UserWarning)  # TorchScript throws excessive warnings

In [None]:
#@title Overwrite get_x_y
from typing import List, Tuple
from openmapflow.constants import CLASS_PROB, EO_DATA, MONTHS
from openmapflow.utils import str_to_np

def get_x_y(
    df: pd.DataFrame,
    label_col: str = CLASS_PROB,
    start_month: str = "February",
    input_months: int = 12,
) -> Tuple[List[np.ndarray], List[float]]:
    """Get the X and y data from a dataframe.

    Local replacement for openmapflow.train_utils.get_x_y.

    Args:
        df: Frame with an EO_DATA column (arrays, or strings decoded via str_to_np)
            and a ``label_col`` column.
        label_col: Column holding the targets.
        start_month: Month name; its index in MONTHS is the first row kept.
        input_months: Number of consecutive rows (time steps) kept.

    Returns:
        (X, y): X is a list of ``x[i : i + input_months, :]`` slices, y the label column
        as a list, both in ``df`` row order.
    """
    i = MONTHS.index(start_month)

    def to_numpy(x):
        # isinstance (rather than type(x) == str) so str subclasses such as numpy.str_
        # are decoded too instead of being sliced as text.
        if isinstance(x, str):
            x = str_to_np(x)
        return x[i : i + input_months, :]  # noqa

    tqdm.pandas()
    return df[EO_DATA].progress_apply(to_numpy).to_list(), df[label_col].to_list()


In [None]:
#@title Get country boundaries
countries = gpd.read_file('/content/world_adm.geojson')
# Tag each label point with the country it intersects (inner join drops points outside
# every polygon). Relies on gdf.crs having been set to countries.crs in the earlier
# "Read country shapefile" cell.
gdf_lab = countries[['name', 'geometry']].sjoin(gdf, how='inner', predicate='intersects')


In [None]:
#@title Data setup
keep_cols = ['name', 'lon', 'lat', 'class_probability', 'subset', 'eo_data']
# Global (non-Mexico) labels, reduced to the modelling columns (geometry is dropped).
gdf_lab = gdf_lab.loc[~(gdf_lab['name'] == 'Mexico'), keep_cols]

# Mexico labels are joined to first-level admin units instead (shapeISO renamed to adm1).
mex_adm1 = gpd.read_file('/content/mex_adm1.geojson').rename(columns={'shapeISO': 'adm1'})
mex_gdf = (
    mex_adm1[['adm1', 'geometry']]
    .sjoin(gdf, how='inner', predicate='intersects')
    .assign(name='Mexico')
    [keep_cols + ['adm1']]
)

# Distinct adm1 codes: the units resampled by the bagging helper.
admins = mex_gdf[['adm1']].drop_duplicates()


In [None]:
#@title Make Mexico training data
# gdf_lab was reduced to non-Mexico rows without a geometry column in "Data setup",
# so the Mexico frame is built from mex_gdf instead.
mex_train = pd.DataFrame(mex_gdf)

# Shuffle, then split 70/20/10 into training/validation/testing. Every row is reset to
# 'training' first so subset values from the source dataset don't survive the split.
np.random.seed(123)
mex_train = mex_train.sample(frac=1).reset_index(drop=True)
mex_train['subset'] = 'training'
mex_train.loc[mex_train.index > int(len(mex_train) * 0.7), 'subset'] = 'validation'
mex_train.loc[mex_train.index > int(len(mex_train) * 0.9), 'subset'] = 'testing'

# Optional: also train on global points within Mexico's central-90% PC band.
# mex_train = pd.concat([mex_train, filtered[filtered['inc5'] == 1]])
print(len(mex_train))


### Bagging

In [None]:
#@title Get train/test helper
def get_train_test(fold, use_global=False, use_pc5=False):
    """Build one bagging fold from resampled Mexico admin units.

    Args:
        fold: Fold index; numpy's global RNG is seeded with ``fold + 1`` so folds are
            reproducible.
        use_global: Also train on every non-Mexico label in ``gdf_lab``.
        use_pc5: Also train on the non-Mexico points flagged ``inc5`` in ``filtered``.
            Mutually exclusive with ``use_global``.

    Returns:
        (train_fold, test_fold). train_fold can repeat Mexico rows because adm1 units
        are drawn with replacement. test_fold holds only Mexico points whose (lat, lon)
        does not occur anywhere in train_fold, one row per location.

    Raises:
        ValueError: if both ``use_global`` and ``use_pc5`` are set.
    """
    if use_global and use_pc5:
        raise ValueError("use_global and use_pc5 are mutually exclusive")

    np.random.seed(fold + 1)
    # Draw adm1 units with replacement; the merge repeats a unit's points once per draw.
    fold_admins = admins.sample(frac=1, replace=True)
    train_fold = mex_gdf.merge(fold_admins)
    if use_global:
        train_fold = pd.concat([train_fold, gdf_lab])
    if use_pc5:
        train_fold = pd.concat([train_fold, filtered.loc[filtered['inc5'] == 1, keep_cols]])

    # Out-of-bag test set, restricted to Mexico. A concat + drop_duplicates(keep=False)
    # shortcut would also keep every non-Mexico training row (each occurs once),
    # leaking training data into validation whenever extra data is added.
    train_locations = pd.MultiIndex.from_frame(train_fold[['lat', 'lon']])
    in_train = pd.MultiIndex.from_frame(mex_gdf[['lat', 'lon']]).isin(train_locations)
    test_fold = mex_gdf[~in_train].drop_duplicates(subset=['lat', 'lon'])
    return train_fold, test_fold


In [None]:
#@title Make data loaders helper
def make_data_loaders(train_fold, val_fold):
    """Binarise labels, upsample the training fold and wrap both folds in DataLoaders.

    The input frames are left untouched (labels are binarised on copies), so the caller
    can still use ``val_fold`` afterwards and no SettingWithCopyWarning is raised for
    frames that are slices of another frame.

    Args:
        train_fold: Training rows with ``label_col`` and ``eo_data`` columns.
        val_fold: Validation rows with the same columns.

    Returns:
        (train_data, train_dataloader, val_data, val_dataloader). Only the training
        loader shuffles, so validation outputs follow ``val_fold`` row order.
    """
    # Threshold the class probability at 0.5 into {0, 1}, on copies.
    train_fold = train_fold.assign(**{label_col: (train_fold[label_col] > 0.5).astype(int)})
    val_fold = val_fold.assign(**{label_col: (val_fold[label_col] > 0.5).astype(int)})

    train_df = upsample_df(train_fold, label_col, upsample_minority_ratio)

    x_train, y_train = get_x_y(train_df, label_col, start_month, input_months)
    x_val, y_val = get_x_y(val_fold, label_col, start_month, input_months)

    # Convert to tensors
    train_data = PyTorchDataset(x=x_train, y=y_train)
    val_data = PyTorchDataset(x=x_val, y=y_val)
    train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

    return train_data, train_dataloader, val_data, val_dataloader



In [None]:
#@title Init model helper
def init_model(train_data):
    """Create the transformer classifier together with its optimizer and loss.

    Args:
        train_data: Dataset whose first item is ``(x, y)`` with ``x`` of shape
            ``(num_timesteps, num_bands)``; ``num_bands`` sets the model's input channels.

    Returns:
        (model, device, optimizer, criterion): the model on ``device`` (cuda:0 when
        available, else CPU), Adam over all parameters with the notebook-level ``lr``,
        and ``BCELoss`` (the model already applies a sigmoid).
    """
    num_timesteps, num_bands = train_data[0][0].shape
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    class Model(torch.nn.Module):
        def __init__(self, normalization_vals=BANDS_MAX):
            super().__init__()
            self.model = TransformerModel(c_in=num_bands, c_out=1)
            # A buffer moves with .to(device); a plain tensor attribute would stay on
            # the CPU and the division in forward() would fail for CUDA inputs.
            self.register_buffer("normalization_vals", torch.tensor(normalization_vals))

        def forward(self, x):
            with torch.no_grad():
                # Divide by the per-band normalisation values, then swap the last two
                # axes so bands become the channel dimension (c_in).
                x = x / self.normalization_vals
                x = x.transpose(2, 1)
            x = self.model(x).squeeze(dim=1)
            return torch.sigmoid(x)

    model = Model().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.BCELoss()

    return model, device, optimizer, criterion


In [None]:
#@title Train model

from collections import defaultdict

def train_model(model, device, optimizer, criterion,
                train_data, train_dataloader, val_data, val_dataloader):
    """Train ``model`` for the notebook-level ``num_epochs``, validating after each epoch.

    Per-batch progress bars are hidden when ``IN_COLAB`` is true.

    Args:
        model: Module returning one probability per sample (paired with ``criterion``).
        device: Device each batch is moved to; the model is assumed to be on it already.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss called as ``criterion(outputs, labels)``.
        train_data, val_data: Datasets; their lengths normalise the summed losses.
        train_dataloader, val_dataloader: Batch iterators over those datasets.

    Returns:
        (metrics, y_pred): a DataFrame with one row per epoch and columns accuracy, f1,
        precision, recall, roc_auc, train_loss, val_loss; and the 0/1 validation
        predictions of the final epoch, in ``val_dataloader`` order.
    """
    metrics = defaultdict(list)
    y_pred = []  # stays defined (empty) if num_epochs == 0

    with tqdm(range(num_epochs), desc="Epoch") as tqdm_epoch:
        for epoch in tqdm_epoch:

            # ------------------------ Training ----------------------------------------
            total_train_loss = 0.0
            model.train()
            # len(dataloader) is the exact number of batches.
            for x in tqdm(
                train_dataloader,
                total=len(train_dataloader),
                desc="Train",
                leave=False,
                disable=IN_COLAB,
            ):
                inputs, labels = x[0].to(device), x[1].to(device)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # Assumes a batch-mean loss (BCELoss default); re-weight by batch size
                # so dividing by len(train_data) gives a per-sample average.
                total_train_loss += loss.item() * len(inputs)

            # ------------------------ Validation --------------------------------------
            total_val_loss = 0.0
            y_true, y_score, y_pred = [], [], []
            model.eval()
            with torch.no_grad():
                for x in tqdm(
                    val_dataloader,
                    total=len(val_dataloader),
                    desc="Validate",
                    leave=False,
                    disable=IN_COLAB,
                ):
                    inputs, labels = x[0].to(device), x[1].to(device)

                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    total_val_loss += loss.item() * len(inputs)

                    y_true += labels.tolist()
                    y_score += outputs.tolist()
                    y_pred += (outputs > 0.5).long().tolist()

            # ------------------------ Metrics + Logging -------------------------------
            train_loss = total_train_loss / len(train_data)
            val_loss = total_val_loss / len(val_data)

            metrics['accuracy'].append(accuracy_score(y_true, y_pred))
            metrics['f1'].append(f1_score(y_true, y_pred))
            metrics['precision'].append(precision_score(y_true, y_pred))
            metrics['recall'].append(recall_score(y_true, y_pred))
            metrics['roc_auc'].append(roc_auc_score(y_true, y_score))
            metrics['train_loss'].append(train_loss)
            metrics['val_loss'].append(val_loss)

            tqdm_epoch.set_postfix(loss=val_loss)

    return pd.DataFrame(metrics), y_pred



In [None]:
#@title Bagging loop
from pathlib import Path

N_FOLDS = 20
use_global, use_pc5 = False, True
out_dir = Path('/content/bagging_pc5')
out_dir.mkdir(parents=True, exist_ok=True)  # safe to re-run, unlike `! mkdir`

for fold in range(N_FOLDS):
    train_df, val_df = get_train_test(fold, use_global=use_global, use_pc5=use_pc5)
    print(len(train_df))
    train_data, train_dataloader, val_data, val_dataloader = make_data_loaders(train_df, val_df)
    model, device, optimizer, criterion = init_model(train_data)
    metrics, preds = train_model(
        model, device, optimizer, criterion, train_data, train_dataloader, val_data, val_dataloader)

    # The validation loader is unshuffled, so preds line up with val_df rows.
    val_df.drop('eo_data', axis=1).assign(pred_class=preds, fold=fold).to_csv(
        out_dir / f'preds{fold}.csv', index=False)
    # Tag each epoch's metrics with the fold that actually produced them.
    metrics.assign(fold=fold, iter=list(range(num_epochs))).to_csv(
        out_dir / f'metrics{fold}.csv', index=False)

    print('*' * 20, f'fold {fold + 1} / {N_FOLDS} complete', '*' * 20)


In [None]:
#@title Read in results
# Alternative: restore a results folder from a zip archive first.
# ! unzip /content/bagging_mex.zip
# ! mv bagging_mex /content
folder = 'bagging_pc5'

# Stack the per-fold CSVs written by the bagging loop.
metrics = pd.concat(pd.read_csv(f'/content/{folder}/metrics{i}.csv') for i in range(N_FOLDS))
preds = pd.concat(pd.read_csv(f'/content/{folder}/preds{i}.csv') for i in range(N_FOLDS))

# Per-point means across folds, and per-epoch means of every metric across folds.
preds_agg = (
    preds
    .groupby(['lon', 'lat', 'adm1'])[['class_probability', 'pred_class']]
    .mean()
    .reset_index()
)
metrics = metrics.groupby('iter').mean().reset_index()

In [None]:
#@title Plot results
import plotly
import plotly.graph_objects as go
from plotly import subplots

fig = plotly.subplots.make_subplots(rows=1, cols=2)

# Epoch axis; losses go in the left panel, classification scores in the right.
x = list(range(len(metrics)))
for name, column, colour, panel in [
    ('Train Loss', 'train_loss', 'cornflowerblue', 1),
    ('Val Loss', 'val_loss', 'orange', 1),
    ('Accuracy', 'accuracy', 'blue', 2),
    ('F1', 'f1', 'green', 2),
    ('Precision', 'precision', 'purple', 2),
    ('Recall', 'recall', 'red', 2),
]:
    fig.add_trace(go.Scatter(name=name, x=x, y=metrics[column], line_color=colour), row=1, col=panel)

fig.update_layout(height=500, width=1200)

fig.show()


In [None]:
#@title Tabular results
# Average each point over folds, then round to 0/1 (round-half-to-even, so an exact
# 0.5 tie becomes 0).
preds_agg = preds.groupby(['lon', 'lat', 'adm1'])[['pred_class', 'class_probability']].mean().reset_index()
preds_agg['pred_class'] = preds_agg['pred_class'].round().astype(int)
preds_agg['class_probability'] = preds_agg['class_probability'].round().astype(int)
y_true, y_pred = preds_agg['class_probability'], preds_agg['pred_class']

for metric_name, score_fn in [
    ('accuracy', accuracy_score),
    ('f1', f1_score),
    ('precision', precision_score),
    ('recall', recall_score),
]:
    print(metric_name, round(score_fn(y_true, y_pred), 3))


In [None]:
#@title Map results
import matplotlib.gridspec as gridspec

sns.set_style('white')
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=[12, 7])
plt.subplots_adjust(wspace=0, hspace=0)

# preds_agg only carries lon/lat, so build point geometries for plotting.
plot_preds = gpd.GeoDataFrame(
    preds_agg.copy(),
    geometry=[Point(xy) for xy in zip(preds_agg['lon'], preds_agg['lat'])],
)

for i in range(2):
    mex_adm1.plot(edgecolor='darkgray', color='none', ax=ax[i])
    ax[i].set_axis_off()

# Left: labels; right: bagged predictions.
plot_preds.plot(column='class_probability', cmap='bwr', markersize=5, ax=ax[0])
plot_preds.plot(column='pred_class', cmap='bwr', markersize=5, ax=ax[1])
ax[0].set_title('Label (class_probability)')
ax[1].set_title('Bagged prediction (pred_class)');


### Experiment model code

This code largely duplicates the bagging code above, but makes it easier to experiment with training a single model.

In [None]:
#@title Dataloaders
# Binarise labels on mex_train itself (threshold 0.5), then split by subset.
mex_train[label_col] = (mex_train[label_col] > 0.5).astype(int)
train_df = upsample_df(
    mex_train[mex_train[SUBSET] == "training"], label_col, upsample_minority_ratio
)
val_df = mex_train[mex_train[SUBSET] == "validation"]

# Slice each sample's time series and collect labels (training first, then validation).
(x_train, y_train), (x_val, y_val) = (
    get_x_y(split_df, label_col, start_month, input_months) for split_df in (train_df, val_df)
)

# Convert to tensors; only the training loader shuffles.
train_data, val_data = PyTorchDataset(x=x_train, y=y_train), PyTorchDataset(x=x_val, y=y_val)
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False)


In [None]:
# @title Init model
num_timesteps, num_bands = train_data[0][0].shape

class Model(torch.nn.Module):
    """Wraps tsai's TransformerModel: normalises inputs and returns probabilities."""

    def __init__(self, normalization_vals=BANDS_MAX):
        super().__init__()
        self.model = TransformerModel(c_in=num_bands, c_out=1)
        # A buffer moves with model.to(device); a plain tensor attribute would stay on
        # the CPU and the division in forward() would fail for CUDA inputs.
        self.register_buffer("normalization_vals", torch.tensor(normalization_vals))

    def forward(self, x):
        with torch.no_grad():
            # Per-band normalisation, then bands become the channel axis (c_in).
            x = x / self.normalization_vals
            x = x.transpose(2, 1)
        x = self.model(x).squeeze(dim=1)
        x = torch.sigmoid(x)
        return x

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Model().to(device)

# ------------ Model hyperparameters -------------------------------------
params_to_update = model.parameters()
optimizer = torch.optim.Adam(params_to_update, lr=lr)
criterion = torch.nn.BCELoss()

# An empty model_name means: derive one from the validation data and start month.
# (Training counters are initialised in the Train model cell.)
if model_name == "":
    model_name = generate_model_name(val_df=val_df, start_month=start_month)

In [None]:
#@title Train model

# Inline variant of train_model() that also checkpoints the model (TorchScript) whenever
# the validation loss reaches a new minimum. State is reset here so the cell can be re-run.
lowest_validation_loss = None
metrics = {}
# Only used as tqdm totals; over-counts by one when the length is a multiple of batch_size.
train_batches = 1 + len(train_data) // batch_size
val_batches = 1 + len(val_data) // batch_size

# Per-epoch history consumed by the "Plot results" cell below.
train_loss_arr, val_loss_arr = [], []
acc_arr, f1_arr, recall_arr, prec_arr = [], [], [], []

with tqdm(range(num_epochs), desc="Epoch") as tqdm_epoch:
    for epoch in tqdm_epoch:

        # ------------------------ Training ----------------------------------------
        total_train_loss = 0.0
        model.train()
        for x in tqdm(
            train_dataloader,
            total=train_batches,
            desc="Train",
            leave=False,
            disable=IN_COLAB,
        ):
            inputs, labels = x[0].to(device), x[1].to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # Get model outputs and calculate loss
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # BCELoss returns a batch mean; re-weight by batch size for a per-sample average.
            total_train_loss += loss.item() * len(inputs)

        # ------------------------ Validation --------------------------------------
        total_val_loss = 0.0
        y_true = []
        y_score = []
        y_pred = []
        model.eval()
        with torch.no_grad():
            for x in tqdm(
                val_dataloader,
                total=val_batches,
                desc="Validate",
                leave=False,
                disable=IN_COLAB,
            ):
                inputs, labels = x[0].to(device), x[1].to(device)

                # Get model outputs and calculate loss
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                total_val_loss += loss.item() * len(inputs)

                y_true += labels.tolist()
                y_score += outputs.tolist()
                y_pred += (outputs > 0.5).long().tolist()

        # ------------------------ Metrics + Logging -------------------------------
        train_loss = total_train_loss / len(train_data)
        val_loss = total_val_loss / len(val_data)

        if lowest_validation_loss is None or val_loss < lowest_validation_loss:
            lowest_validation_loss = val_loss

        # NOTE(review): `metrics` is overwritten every epoch, so after the loop it holds the
        # LAST epoch's scores, while the saved checkpoint is from the lowest-val-loss epoch.
        # The final cell prints them together as if they matched -- confirm this is intended.
        metrics = {
            "accuracy": accuracy_score(y_true, y_pred),
            "f1": f1_score(y_true, y_pred),
            "precision": precision_score(y_true, y_pred),
            "recall": recall_score(y_true, y_pred),
            "roc_auc": roc_auc_score(y_true, y_score),
        }
        metrics = {k: round(float(v), 4) for k, v in metrics.items()}
        print(round(train_loss, 3), round(val_loss, 3), metrics)

        train_loss_arr.append(train_loss)
        val_loss_arr.append(val_loss)
        acc_arr.append(metrics['accuracy'])
        f1_arr.append(metrics['f1'])
        recall_arr.append(metrics['recall'])
        prec_arr.append(metrics['precision'])

        tqdm_epoch.set_postfix(loss=val_loss)

        # ------------------------ Model saving --------------------------
        # True only when this epoch set (or tied) the lowest validation loss so far.
        if lowest_validation_loss == val_loss:
            # Some models in tsai need to be modified to be TorchScriptable
            # https://github.com/timeseriesAI/tsai/issues/561
            sm = torch.jit.script(model)
            model_path = model_path_from_name(model_name=model_name)
            # Replace any previous checkpoint; create the parent directory on first save.
            if model_path.exists():
                model_path.unlink()
            else:
                model_path.parent.mkdir(parents=True, exist_ok=True)
            sm.save(str(model_path))


In [None]:
#@title Plot results
import plotly
import plotly.graph_objects as go
from plotly import subplots

fig = plotly.subplots.make_subplots(rows=1, cols=2)

# Epoch axis; losses go in the left panel, classification scores in the right.
x = list(range(num_epochs))
for name, series, colour, panel in [
    ('Train Loss', train_loss_arr, 'cornflowerblue', 1),
    ('Val Loss', val_loss_arr, 'orange', 1),
    ('Accuracy', acc_arr, 'blue', 2),
    ('F1', f1_arr, 'green', 2),
    ('Precision', prec_arr, 'purple', 2),
    ('Recall', recall_arr, 'red', 2),
]:
    fig.add_trace(go.Scatter(name=name, x=x, y=series, line_color=colour), row=1, col=panel)

fig.update_layout(height=500, width=1200)

fig.show()


In [None]:
# Summary: model name, its checkpoint path, and the training cell's metrics as YAML.
print(
    f"MODEL_NAME={model_name}",
    model_path_from_name(model_name=model_name),
    yaml.dump(metrics, allow_unicode=True, default_flow_style=False),
    sep="\n",
)
