In [None]:
import torch
from torch.utils.data import DataLoader

import mlflow
from sklearn.metrics import confusion_matrix as sk_cm

import sys
sys.path.append("../")

from src.models.LSTM import LSTM
from src.features.dataset_seq import DepressionDataset, collate_fn, get_sampler
from src.train_seq import run_epoch

from main import plot_cm, get_metrics

# Send runs to the mlruns/ store one directory up, under the "LSTM" experiment.
MLFLOW_TRACKING_URI = "../mlruns/"
MLFLOW_EXPERIMENT = "LSTM"
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
mlflow.set_experiment(MLFLOW_EXPERIMENT)

In [None]:
# Run configuration; the whole dict is logged to MLflow as run params in the training cell.
params = {
    "encoder_type": "w2v",
    # 3 ordinal classes, encoded with 3 - 1 outputs because of the CORAL formulation.
    "num_classes": 3 - 1,
    "batch_size": 32,
    "epochs": 100,
    "hidden_dim": 128,
    "learning_rate": 0.001,
    "weight_decay": 5e-4,
    # Seed for torch's RNGs (model init; presumably also the train sampler) so that
    # Restart & Run All reproduces a run. NOTE(review): if src.features uses numpy/random
    # RNGs, seed those as well — TODO confirm.
    "seed": 42,
}

# Must run before the model is built so weight initialisation is deterministic.
torch.manual_seed(params["seed"])

# EMBEDDING DIMENSION: feature vector size produced by each supported encoder.
EMBEDDING_DIMS = {"bert": 768, "w2v": 300}
if params["encoder_type"] not in EMBEDDING_DIMS:
    raise ValueError("Invalid encoder type")
params["embedding_dim"] = EMBEDDING_DIMS[params["encoder_type"]]

# DEVICE: prefer Apple MPS, then CUDA, falling back to CPU.
if torch.backends.mps.is_available():
    params["device"] = torch.device("mps")
else:
    params["device"] = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using", params["device"])

In [None]:
# Load the train / valid / test splits for the configured encoder type (in that order).
encoder_type = params["encoder_type"]
train, valid, test = (
    DepressionDataset(split_name, encoder_type, root_path="..")
    for split_name in ("train", "valid", "test")
)

In [None]:
# All loaders share batch size and collate function. Training draws batches through the
# sampler from get_sampler; valid/test keep dataset order so metrics are comparable.
shared_loader_kwargs = dict(batch_size=params["batch_size"], collate_fn=collate_fn)
train_loader = DataLoader(train, sampler=get_sampler(train), **shared_loader_kwargs)
valid_loader = DataLoader(valid, shuffle=False, **shared_loader_kwargs)
test_loader = DataLoader(test, shuffle=False, **shared_loader_kwargs)

In [None]:
# Build the LSTM from the config (positional order: num_classes, embedding_dim,
# hidden_dim) and move it to the selected device.
lstm_dims = (params["num_classes"], params["embedding_dim"], params["hidden_dim"])
model = LSTM(*lstm_dims).to(params["device"])

# Adam over all model parameters; learning rate and weight decay come from the config.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=params["learning_rate"],
    weight_decay=params["weight_decay"],
)

In [None]:
# Train with early stopping, logging per-epoch metrics and confusion matrices to MLflow.
# A run left active by an interrupted/failed earlier execution of this cell (or by not
# running the test cell) would make start_run() raise, so close it first.
if mlflow.active_run() is not None:
    mlflow.end_run(status="KILLED")
mlflow.start_run()
mlflow.log_params(params)

# num_classes holds K - 1 (CORAL), so the class labels are 0..num_classes.
cm_labels = list(range(params["num_classes"] + 1))


def _run_split(loader, split, epoch):
    """Run one `run_epoch` pass over `loader` and log its metrics to MLflow.

    `split` is both the mode passed to run_epoch and the metric-name prefix.
    Returns (row-normalised confusion matrix, loss returned by run_epoch).
    """
    y_true, y_pred, loss = run_epoch(model, loader, optimizer, params["device"], epoch, split)
    cm = sk_cm(y_true, y_pred, labels=cm_labels, normalize='true')
    mlflow.log_metrics(get_metrics(y_true, y_pred, split), step=epoch)
    mlflow.log_metric(f"{split}.loss", loss, step=epoch)
    return cm, loss


valid_loss_hist = []
best_valid_loss, best_epoch, best_state = float("inf"), None, None
for epoch in range(params["epochs"]):
    model.train()
    train_cm, train_loss = _run_split(train_loader, "train", epoch)

    # Valid
    model.eval()
    valid_cm, valid_loss = _run_split(valid_loader, "valid", epoch)

    # Plot confusion matrix
    cm_path = plot_cm([
        [train_cm, "Train"],
        [valid_cm, "Valid"]
    ], epoch, root="../reports/figures")
    mlflow.log_artifact(cm_path, artifact_path="confusion_matrix")

    # Keep a copy of the weights with the lowest validation loss so far; clone because
    # state_dict tensors share storage with the live parameters.
    if valid_loss < best_valid_loss:
        best_valid_loss, best_epoch = valid_loss, epoch
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

    # Early stopping: from epoch index 5 on, stop once the validation loss exceeds
    # the worst of the previous 3 epochs.
    if epoch > 4 and valid_loss > max(valid_loss_hist[-3:]):
        print("Early stopping")
        break
    valid_loss_hist.append(valid_loss)

# Hand the best-validation weights to later cells instead of the last epoch, which
# after early stopping is the one whose validation loss got worse.
if best_state is not None:
    model.load_state_dict(best_state)
    mlflow.log_metric("best_epoch", best_epoch)


In [None]:
# Evaluate on test; metrics are logged at the last training `epoch` step of the active run.
model.eval()
y_true, y_pred, test_loss = run_epoch(model, test_loader, optimizer, params["device"], epoch, 'test')
# num_classes holds K - 1 (CORAL), so the class labels are 0..num_classes.
test_labels = list(range(params["num_classes"] + 1))
test_cm = sk_cm(y_true, y_pred, labels=test_labels, normalize='true')
mlflow.log_metrics(get_metrics(y_true, y_pred, "test"), step=epoch)
mlflow.log_metric("test.loss", test_loss, step=epoch)

# Plot confusion matrix. 999 is a sentinel id passed where plot_cm takes the epoch,
# presumably so the test figure does not collide with per-epoch figures — confirm in plot_cm.
TEST_CM_ID = 999
cm_path = plot_cm([
    [test_cm, "Test"]
], TEST_CM_ID, root="../reports/figures")
mlflow.log_artifact(cm_path, artifact_path="confusion_matrix")

# Save model to mlflow
mlflow.pytorch.log_model(model, "model")

mlflow.end_run()