In [None]:
import lmdb
from bigearthnet_patch_interface.s2_interface import BigEarthNet_S2_Patch
from pathlib import Path
import importlib

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch import Tensor, cat, stack
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn

from sklearn.metrics import accuracy_score


#### Data Loading

In [None]:
import os
from dotenv import load_dotenv

# Read dataset and output locations from a local .env file so that no
# machine-specific absolute paths are hard-coded in the notebook.
load_dotenv('./.env')

BEN_LMDB_PATH = os.environ.get("BEN_LMDB_PATH")

TRAIN_CSV_FILE = os.environ.get("TRAIN_CSV")
# TEST_CSV is read for completeness but not used by any other cell of this
# notebook, so it is not required below.
TEST_CSV_FILE = os.environ.get("TEST_CSV")
VAL_CSV_FILE = os.environ.get("VAL_CSV")
PATH_TO_MODELS = os.environ.get("PATH_TO_MODELS")

# Fail fast with a readable message: os.environ.get returns None for unset
# variables, and Path(None) would otherwise raise an opaque TypeError.
_required_env = {
    "BEN_LMDB_PATH": BEN_LMDB_PATH,
    "TRAIN_CSV": TRAIN_CSV_FILE,
    "VAL_CSV": VAL_CSV_FILE,
    "PATH_TO_MODELS": PATH_TO_MODELS,
}
_missing_env = [name for name, value in _required_env.items() if not value]
assert not _missing_env, f"Missing environment variables (check ./.env): {', '.join(_missing_env)}"

assert Path(BEN_LMDB_PATH).exists(), f"LMDB database not found: {BEN_LMDB_PATH}"
assert Path(TRAIN_CSV_FILE).exists(), f"Training CSV not found: {TRAIN_CSV_FILE}"
# The validation CSV is used to build the validation dataset, so check it too.
assert Path(VAL_CSV_FILE).exists(), f"Validation CSV not found: {VAL_CSV_FILE}"
assert os.path.isdir(PATH_TO_MODELS), f"Model output directory not found: {PATH_TO_MODELS}"

In [None]:
# Open the BigEarthNet LMDB read-only, without read-ahead and without file
# locking, and start a read transaction plus a cursor over it.
# NOTE(review): `txn` and `cur` are not referenced by any later cell of this
# notebook, and `env` is never closed — confirm whether this cell is still needed.
env = lmdb.open(
    BEN_LMDB_PATH,
    readonly=True,
    readahead=False,
    lock=False,
)
txn = env.begin()
cur = txn.cursor()

In [None]:
import dataset_class as ds_class
# Re-import so edits to dataset_class.py are picked up without restarting the kernel.
importlib.reload(ds_class)

# NOTE(review): the meaning of BenDataset's third positional argument (2) is
# defined in dataset_class.py — TODO name it once confirmed.
val_ds = ds_class.BenDataset(VAL_CSV_FILE, BEN_LMDB_PATH, 2)
train_ds = ds_class.BenDataset(TRAIN_CSV_FILE, BEN_LMDB_PATH, 2)

BATCH_SIZE = 16

# Only the training data needs shuffling. A fixed validation order keeps
# evaluation reproducible; presumably validate_batch aggregates over the whole
# loader, so the metrics themselves do not depend on order — confirm.
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)

## Training

#### Set hyperparameters and prepare output paths/folders.

In [None]:
import conv_mixer
import training_utils as train_utils
# Re-import so edits to conv_mixer.py take effect without a kernel restart.
importlib.reload(conv_mixer)

# ConvMixer hyperparameters: hidden width `h` and number of layers `depth`.
# kernel_size / patch_size are not passed, so ConvMixer's own defaults (if any) apply.
h, depth = 256, 8

# File-name stem for checkpoints, the final model and the training-curve plot.
model_name = f'ConvMx-{h}-{depth}'

# NOTE(review): the leading 10 is presumably the number of input bands — TODO confirm
# against conv_mixer.ConvMixer's signature. 19 is the number of output classes.
model = conv_mixer.ConvMixer(10, h, depth, n_classes=19)
model = model.cuda() if torch.cuda.is_available() else model

# BCEWithLogitsLoss applies an independent sigmoid per class output.
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

In [None]:
from datetime import datetime

# Each run gets its own output directory, e.g. "<PATH_TO_MODELS>/05-17_1432-ConvMx-256-8".
# The timestamp has minute resolution, so starting two runs of the same model
# within one minute trips the assertion below instead of overwriting results.
timestamp = datetime.now().strftime('%m-%d_%H%M')
model_dir = f'{timestamp}-{model_name}'

run_path = PATH_TO_MODELS + '/' + model_dir
assert not os.path.isdir(run_path)
os.mkdir(run_path)

In [None]:
# Re-import so edits to training_utils.py take effect without a kernel restart.
importlib.reload(train_utils)

n_epochs = 5
# Best validation loss seen so far; a checkpoint is written whenever it improves.
val_loss_min = np.inf

# Per-epoch metric histories, plotted in the evaluation section.
train_loss_hist = []
train_acc_hist = []
val_loss_hist = []
val_acc_hist = []

try:
    for e in range(n_epochs):
        print(f'{e+1:3d}/{n_epochs} : ', end="")

        # One pass over the TRAINING data: inference, backpropagation, weight updates.
        # (Previously this was fed val_loader, i.e. the model was trained on the validation set.)
        model.train()
        train_loss, train_acc = train_utils.train_batch(train_loader, model, optimizer, loss_fn)

        train_loss_hist.append(train_loss)
        train_acc_hist.append(train_acc)

        # Evaluate on the held-out VALIDATION data (previously train_loader).
        model.eval()
        val_loss, val_acc = train_utils.validate_batch(val_loader, model, loss_fn)

        val_loss_hist.append(val_loss)
        val_acc_hist.append(val_acc)

        print(f'train_loss={train_loss:.4f}', f'train_acc={train_acc:.4f}', end=" ")
        print(f'val_loss={val_loss:.4f}', f'val_acc={val_acc:.4f}')

        # Save checkpoint model if validation loss improves
        if val_loss < val_loss_min:
            print(f'val_loss decreased ({val_loss_min:.6f} --> {val_loss:.6f}). Saving model weights ...')

            p = PATH_TO_MODELS + f'/{model_dir}/e{e+1}_{model_name}.pt'
            train_utils.save_complete_model(p, model)

            val_loss_min = val_loss
except KeyboardInterrupt:
    # Interrupting the kernel stops training early but still falls through to
    # saving the final model below, so the work done so far is kept.
    # If interrupted mid-epoch, the train and val histories may differ in length by one.
    print('\nTraining interrupted by user.')


print('Finished Training')
print('Saving final model ...')
p = PATH_TO_MODELS + f'/{model_dir}/{model_name}.pt'
train_utils.save_complete_model(p, model)

## Evaluation

#### Plot training and validation loss/accuracy curves

In [None]:
# Side-by-side training curves (loss on the left, accuracy on the right),
# saved as a PDF into this run's model directory.
fig = plt.figure(figsize=(16,4))

# (subplot position, title, val history, train history, legend location, y-limits)
# The loss axis is left auto-scaled; accuracy is pinned to [0, 1].
panels = (
    (121, "loss", val_loss_hist, train_loss_hist, "upper right", None),
    (122, "accuracy", val_acc_hist, train_acc_hist, "lower right", [0, 1]),
)

for position, title, val_hist, train_hist, legend_loc, y_limits in panels:
    ax = fig.add_subplot(position)
    ax.plot(val_hist, label='val')
    ax.plot(train_hist, label='train')
    ax.legend(loc=legend_loc)
    if y_limits is not None:
        ax.set_ylim(y_limits)
    ax.set_title(title)
    ax.set_xlabel("epochs")

plt.savefig(PATH_TO_MODELS +  f'/{model_dir}/{model_name}.pdf')