In [7]:
import torch
import torch.nn as F
from torch_geometric.loader import DataLoader, ImbalancedSampler

import sys
sys.path.append("../")

from src.features.dataset import DepressionDataset
from src.models.CNN import CNN

In [8]:
BATCH_SIZE = 128
EPOCHS = 100
LEARNING_RATE = 0.001
WEIGHT_DECAY = 5e-4
# NOTE(review): a later cell chooses a lowercase `device` (which also considers
# MPS) and the visible cells use that one; this constant may be unused.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
FILTERS = 100
NUMBER_OF_CLASSES = 3
ENCODER_TYPE = "bert"

# Embedding width associated with each supported encoder type.
EMBEDDING_SIZES = {"bert": 768, "w2v": 300}
if ENCODER_TYPE not in EMBEDDING_SIZES:
    # Fail fast: an unknown encoder must stop the notebook here rather than
    # leak an exception object into the model constructor.
    raise ValueError("Invalid encoder type")
embedding_size = EMBEDDING_SIZES[ENCODER_TYPE]

In [9]:
# Build the train and validation splits with identical encoder settings
# (constructed in order: 'train' first, then 'valid').
train, valid = (
    DepressionDataset(split_name, ENCODER_TYPE, "other", root_path="..")
    for split_name in ("train", "valid")
)

In [10]:
# Oversample under-represented classes (inverse class-frequency weights) so
# training batches are closer to class-balanced.
# `num_samples` is passed by keyword: the second positional parameter of
# ImbalancedSampler is `input_nodes`, not the number of samples to draw.
sampler = ImbalancedSampler(train, num_samples=len(train))

train_loader = DataLoader(train, batch_size=BATCH_SIZE, sampler=sampler)
# Validation keeps a fixed order so per-epoch metrics are comparable.
valid_loader = DataLoader(valid, batch_size=BATCH_SIZE, shuffle=False)

In [11]:
# Pick the best available backend: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    backend_name = "mps"
elif torch.cuda.is_available():
    backend_name = "cuda"
else:
    backend_name = "cpu"
device = torch.device(backend_name)
print("Using", device)

Using mps


In [12]:
# CORN ordinal regression (coral_pytorch) expects NUMBER_OF_CLASSES - 1 logits.
corn_output_size = NUMBER_OF_CLASSES - 1
model = CNN(corn_output_size, FILTERS, embedding_size)
model = model.to(device)

In [13]:
def collate_fn(b, target_device=None):
    """Turn a PyG mini-batch of sentence graphs into a zero-padded dense tensor.

    Each graph in `b` is one sentence, and its nodes (presumably token
    embeddings) are copied, in order, into row `i` of a
    `(num_sentences, max_len, feature_dim)` tensor. Slots past a sentence's
    length stay zero.

    Args:
        b: batch exposing `x` (2-D node features), `batch` (sentence index of
            each node) and `y` (labels).
        target_device: device for the outputs. Defaults to the notebook-global
            `device`, so existing `collate_fn(b)` calls behave as before.

    Returns:
        (x_new, y): the padded float tensor and `b.y`, both on `target_device`.
    """
    if target_device is None:
        target_device = device

    batch = b.batch
    # Nodes per sentence, computed once (it was previously recomputed per row).
    lengths = batch.bincount()
    # Assumes `batch` is sorted (PyG stores nodes graph-by-graph), so its last
    # entry is the highest sentence index.
    num_sentences = batch[-1].item() + 1
    max_len = int(lengths.max())

    # Slot of each node inside its own sentence: global node index minus the
    # offset at which that sentence begins. Done on the loader's device.
    offsets = torch.cumsum(lengths, dim=0) - lengths
    slots = torch.arange(batch.numel(), device=batch.device) - offsets[batch]

    x = b.x.to(target_device)
    x_new = torch.zeros((num_sentences, max_len, x.shape[1]), device=target_device)
    # One scatter replaces the per-sentence masked-copy loop; cast so the
    # advanced-index assignment never hits a dtype mismatch.
    x_new[batch.to(target_device), slots.to(target_device)] = x.to(x_new.dtype)

    return x_new, b.y.to(target_device)

In [19]:
from tqdm import tqdm
from coral_pytorch.losses import corn_loss
from coral_pytorch.dataset import corn_label_from_logits

def run_epoch(model, loader, optimizer, device, epoch, set_type, num_classes=3):
    """Run one pass over `loader`, collecting labels, predictions and mean loss.

    Args:
        model: network producing CORN logits (`num_classes - 1` per sample).
        loader: iterable of PyG batches, densified with `collate_fn`.
        optimizer: stepped only when `set_type == 'train'`.
        device: not used inside this function; `collate_fn` places tensors.
        epoch: epoch index, used only in the progress-bar label.
        set_type: 'train' (backprop + optimizer step) or 'valid' (no grads).
        num_classes: number of ordinal classes passed to `corn_loss`.

    Returns:
        (y_true, y_pred, epoch_loss): true and predicted labels as lists and
        the loss averaged over batches (0.0 for an empty loader).

    Raises:
        ValueError: if `set_type` is neither 'train' nor 'valid'.
    """
    if set_type == 'train':
        is_train = True
        desc = f'Epoch {epoch:3d} ┬ Train'
    elif set_type == 'valid':
        is_train = False
        desc = '          └ Valid'
    else:
        raise ValueError(f"set_type must be 'train' or 'valid', got {set_type!r}")

    y_true = []
    y_pred = []
    epoch_loss = 0.0
    num_batches = 0

    for b in tqdm(loader, desc=desc):
        x, y = collate_fn(b)
        # Only record the autograd graph when we will backpropagate through it.
        with torch.set_grad_enabled(is_train):
            out = model(x)
            loss = corn_loss(out, y, num_classes=num_classes)

        if is_train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Decoding predictions never needs gradients.
        pred = corn_label_from_logits(out.detach())

        y_true += y.tolist()
        y_pred += pred.cpu().tolist()
        epoch_loss += loss.item()
        num_batches += 1

    # Avoid ZeroDivisionError when the loader yields no batches.
    if num_batches:
        epoch_loss /= num_batches

    return y_true, y_pred, epoch_loss

In [20]:
from main import get_metrics, plot_cm
import mlflow

from sklearn.metrics import (
    confusion_matrix as sk_cm
)

# Adam with L2 weight decay over all model parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Confusion matrices always cover every class, derived from the config constant
# so they stay consistent if NUMBER_OF_CLASSES changes.
CLASS_LABELS = list(range(NUMBER_OF_CLASSES))

mlflow.set_experiment("CNN")

with mlflow.start_run():
    mlflow.log_params({
        "batch_size": BATCH_SIZE,
        "epochs": EPOCHS,
        "learning_rate": LEARNING_RATE,
        "weight_decay": WEIGHT_DECAY,
        "filters": FILTERS,
        "encoder_type": ENCODER_TYPE,
        "number_of_classes": NUMBER_OF_CLASSES,
    })

    for epoch in range(EPOCHS):
        # Train
        model.train()
        y_true, y_pred, train_loss = run_epoch(model, train_loader, optimizer, device, epoch, 'train')
        train_cm = sk_cm(y_true, y_pred, labels=CLASS_LABELS, normalize='true')
        mlflow.log_metrics(get_metrics(y_true, y_pred, "train"), step=epoch)
        mlflow.log_metric("train_loss", train_loss, step=epoch)

        # Valid
        model.eval()
        y_true, y_pred, valid_loss = run_epoch(model, valid_loader, optimizer, device, epoch, 'valid')
        valid_cm = sk_cm(y_true, y_pred, labels=CLASS_LABELS, normalize='true')
        mlflow.log_metrics(get_metrics(y_true, y_pred, "valid"), step=epoch)
        mlflow.log_metric("valid_loss", valid_loss, step=epoch)

        # Plot confusion matrices side by side and attach them to the run.
        cm_path = plot_cm(train_cm, valid_cm, epoch, root="../reports/figures")
        mlflow.log_artifact(cm_path, artifact_path="confusion_matrix")

Epoch   0 ┬ Train:   0%|          | 0/49 [00:00<?, ?it/s]

Epoch   0 ┬ Train:   2%|▏         | 1/49 [00:25<20:29, 25.61s/it]