## Training

In [None]:
# Automatically reload edited local modules (e.g. `config`, `scripts.*`)
# before each cell runs, so changes there are picked up without a kernel restart.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline. The only plot in this section uses plotly,
# which does not depend on this setting.
%matplotlib inline

In [None]:
import pandas as pd
from config import conf
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from torchvision import transforms as T
from scripts.transforms import normalize, rescale
from scripts.dataset import PlanetDataSet
import plotly.express as px

In [None]:
# Load the tile metadata and replace missing tag values with the literal
# string "none", so every row carries a land-cover tag and sub-tag label.
data_df = pd.read_csv(conf.data_file).assign(
    lc_sub_tags=lambda frame: frame["lc_sub_tags"].fillna("none"),
    lc_tags=lambda frame: frame["lc_tags"].fillna("none"),
)
data_df.head(2)

In [None]:
# Keep only rows that carry a `degraded_forest` annotation and treat that value
# as a string label. `.assign` returns a fresh frame instead of writing a
# column into a boolean-filtered slice, which can raise SettingWithCopyWarning.
data_df = data_df[data_df["degraded_forest"].notna()].assign(
    degraded_forest=lambda frame: frame["degraded_forest"].astype(str)
)
len(data_df)

In [None]:
# Target size handed to the project `rescale` transform.
# NOTE(review): presumably (height, width) — confirm against scripts.transforms.rescale.
IMAGE_SIZE = (32, 32)

# NOTE(review): normalization is currently disabled. To enable it, append
# `normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])`
# (the standard ImageNet channel statistics) to the list below.
transforms = T.Compose([rescale(IMAGE_SIZE)])

In [None]:
# Wrap the filtered metadata and the training images in the project dataset,
# labelling each sample by its `lc_tags` value, then show the dataset's
# `class_to_idx` mapping.
planet_dataset = PlanetDataSet(
    data_df=data_df,
    root=conf.train_imgs_path,
    transforms=transforms,
    label_col="lc_tags",
)
all_tags_dict = planet_dataset.class_to_idx
all_tags_dict

In [None]:
# Non-null counts of every column per land-cover tag, largest first.
# `lc_sub_tags` has no missing values (filled on load), so its count is the
# number of rows per tag.
# NOTE(review): despite the name, this frame holds per-`lc_tags` counts,
# not degraded-forest data.
degraded_df = (
    data_df.groupby(by=["lc_tags"])
    .count()
    .reset_index()
    .sort_values(by=["lc_sub_tags"], ascending=False)
)
degraded_df

In [None]:
# Downsample the Grassland, Forest and Otherland tags to at most
# SAMPLES_PER_CLASS rows each (presumably to balance the class distribution);
# every other tag is kept in full.
SAMPLES_PER_CLASS = 90  # upper bound on rows kept per downsampled tag
RANDOM_STATE = 42  # fixed seed so the sampled rows are reproducible across runs
DOWNSAMPLED_TAGS = ["Grassland", "Forest", "Otherland"]


def sample_tag(df, tag, n=SAMPLES_PER_CLASS, random_state=RANDOM_STATE):
    """Return up to ``n`` randomly chosen rows of ``df`` whose ``lc_tags`` equals ``tag``.

    Capping ``n`` at the number of matching rows avoids the ValueError that
    ``DataFrame.sample`` raises when asked for more rows than exist
    (sampling is without replacement).
    """
    rows = df[df["lc_tags"] == tag]
    return rows.sample(n=min(n, len(rows)), random_state=random_state)


grassland = sample_tag(data_df, "Grassland")
forest = sample_tag(data_df, "Forest")
other_land = sample_tag(data_df, "Otherland")
rest = data_df[~data_df["lc_tags"].isin(DOWNSAMPLED_TAGS)]

# Keep the downsampled rows under their own name; `lc_df` below holds only
# the per-tag counts used for plotting.
balanced_df = pd.concat([grassland, forest, other_land, rest])

lc_df = (
    balanced_df.groupby(by=["lc_tags"])
    .count()
    .reset_index()
    .sort_values(by=["lc_sub_tags"], ascending=False)
)
lc_df

In [None]:
# Horizontal bar chart of rows per land-cover tag in the downsampled frame.
# Bars use the `lc_sub_tags` count. That column was filled with "none" on load,
# so its non-null count equals the group size and matches the column `lc_df`
# is sorted by. A column with missing values (e.g. "multiple", if it has any)
# would under-count.
fig = px.bar(
    lc_df,
    y="lc_tags",
    x="lc_sub_tags",
    title="land cover classes distribution",
    orientation="h",
    labels={"lc_tags": "lc_tags", "lc_sub_tags": "Count number"},
)
fig.update_traces(
    marker=dict(
        color="rgba(164, 163, 204, 0.85)",
        line=dict(color="rgb(248, 248, 249)", width=1),
    )
)
# Plotly places horizontal-bar categories bottom-up in row order. Ordering by
# total makes the largest class appear at the top, independent of frame order.
fig.update_yaxes(categoryorder="total ascending")
fig.show()