In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import sys
sys.path.append("../")

from src.models.CNN import CNN

In [None]:
BATCH_SIZE = 128
EPOCHS = 100
LEARNING_RATE = 0.001   # Adam learning rate
WEIGHT_DECAY = 5e-4     # Adam weight decay
# NOTE(review): the visible cells move the model to `device` (chosen in a
# later cell, which also considers MPS); DEVICE is kept for compatibility.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
FILTERS = 64
NUMBER_OF_CLASSES = 3
ENCODER_TYPE = "bert"   # must be a key of EMBEDDING_SIZES

# Embedding width passed to the CNN for each supported encoder type.
EMBEDDING_SIZES = {"bert": 768, "w2v": 300}

# Fail fast here instead of silently storing an exception object as the size.
if ENCODER_TYPE not in EMBEDDING_SIZES:
    raise ValueError("Invalid encoder type")
embedding_size = EMBEDDING_SIZES[ENCODER_TYPE]

In [None]:
import os
import numpy as np
import pandas as pd
import pickle
from torch.utils.data import Dataset

class DepressionDataset(Dataset):
    """Per-sample features and labels stored on disk, one directory per sample.

    Expected layout (relative to ``root_path``)::

        data/gold/<set_type>/raw/<idx>/features_<encoder_type>.npy
        data/gold/<set_type>/raw/<idx>/label.pkl

    Sample directories are assumed to be named ``0 .. len(self) - 1``.

    Attributes:
        raw_dir: Directory holding the per-sample sub-directories.
        sample_weights: numpy float array with one weight per sample, equal to
            ``N / (number of samples sharing that sample's label)`` (inverse
            class frequency), suitable for ``WeightedRandomSampler``.
    """

    def __init__(self, set_type, encoder_type, root_path="."):
        self.set_type = set_type
        self.encoder_type = encoder_type
        self.root_path = root_path

        self.raw_dir = os.path.join(root_path, "data", "gold", set_type, "raw")

        assert os.path.exists(self.raw_dir), f"Path {self.raw_dir} does not exist"

        self.sample_weights = self.get_sample_weights()

    def __len__(self):
        """Return the number of sample sub-directories inside ``raw_dir``."""
        dirs = os.listdir(self.raw_dir)
        dirs = [d for d in dirs if os.path.isdir(os.path.join(self.raw_dir, d))]
        return len(dirs)

    def _sample_dir(self, idx):
        # Sample folders are named after their integer index.
        return os.path.join(self.raw_dir, str(idx))

    def _load_label(self, idx):
        """Unpickle and return the label of sample ``idx``.

        NOTE: pickle can execute arbitrary code on load; only use trusted data.
        """
        label_path = os.path.join(self._sample_dir(idx), "label.pkl")
        with open(label_path, "rb") as f:  # close the handle deterministically
            return pickle.load(f)

    def __getitem__(self, idx):
        """Return ``(features, label)`` as (float32 tensor, int64 tensor)."""
        features_path = os.path.join(self._sample_dir(idx), f"features_{self.encoder_type}.npy")

        features = np.load(features_path)
        label = self._load_label(idx)

        features = torch.from_numpy(features).float()
        label = torch.tensor(label).long()

        return features, label

    def get_sample_weights(self):
        """Compute inverse-class-frequency weights, one per sample.

        Returns:
            numpy float array ``w`` with ``w[i] = N / count(label_i)``.
        """
        labels = np.array([self._load_label(idx) for idx in range(len(self))])

        # np.unique maps every sample to its class position and that class's
        # count, so weights stay correct even when label ids are not
        # contiguous (e.g. a class is absent from this split).
        _, inverse, counts = np.unique(labels, return_inverse=True, return_counts=True)

        return len(labels) / counts[inverse]

In [None]:
# Load both splits from ../data/gold/<split>/raw using the configured encoder.
train, valid = (
    DepressionDataset(split, ENCODER_TYPE, root_path="..")
    for split in ("train", "valid")
)

In [None]:
def collate_fn(batch):
    """Merge ``(features, label)`` samples into one padded batch.

    Feature tensors may differ in their first dimension; they are zero-padded
    to the longest one and stacked batch-first. Labels are stacked as-is.
    """
    feature_seq, label_seq = zip(*batch)
    padded_features = torch.nn.utils.rnn.pad_sequence(feature_seq, batch_first=True)
    stacked_labels = torch.stack(label_seq)
    return padded_features, stacked_labels

# Rebalance classes during training: samples are drawn with replacement with
# probability proportional to their per-sample weight (inverse class frequency).
sampler = torch.utils.data.WeightedRandomSampler(
    train.sample_weights,
    len(train.sample_weights),
    replacement=True,
)

train_loader = DataLoader(
    train,
    batch_size=BATCH_SIZE,
    sampler=sampler,
    collate_fn=collate_fn,
)
# Validation keeps the on-disk order so metrics are comparable across epochs.
valid_loader = DataLoader(
    valid,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
)

In [None]:
# Pick the best available accelerator: Apple MPS first, then CUDA, else CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using", device)

In [None]:
# CORN ordinal regression uses NUMBER_OF_CLASSES - 1 output logits.
model = CNN(NUMBER_OF_CLASSES - 1, FILTERS, embedding_size)
model = model.to(device)

In [None]:
from tqdm import tqdm
from coral_pytorch.losses import corn_loss
from coral_pytorch.dataset import corn_label_from_logits

def run_epoch(model, loader, optimizer, device, epoch, set_type, num_classes=3):
    """Run one pass over ``loader`` and collect CORN predictions.

    In ``'train'`` mode the optimizer is stepped after every batch. In
    ``'valid'`` mode autograd is disabled and ``optimizer`` is not touched.
    The caller is responsible for calling ``model.train()`` / ``model.eval()``.

    Args:
        model: Module mapping a batch ``x`` to CORN logits.
        loader: Iterable of ``(x, y)`` batches.
        optimizer: Optimizer, used only when ``set_type == 'train'``.
        device: Device the batches are moved to.
        epoch: Epoch index, used only for the progress-bar label.
        set_type: ``'train'`` or ``'valid'``.
        num_classes: Number of ordinal classes passed to ``corn_loss``.

    Returns:
        ``(y_true, y_pred, epoch_loss)``: flat lists of true/predicted labels
        and the mean of the per-batch losses (``0.0`` for an empty loader).

    Raises:
        ValueError: if ``set_type`` is not ``'train'`` or ``'valid'``.
    """
    if set_type == 'train':
        desc = f'Epoch {epoch:3d} ┬ Train'
    elif set_type == 'valid':
        desc = '          └ Valid'
    else:
        # Previously this fell through and crashed later with UnboundLocalError.
        raise ValueError(f"set_type must be 'train' or 'valid', got {set_type!r}")

    is_train = set_type == 'train'
    y_true = []
    y_pred = []
    epoch_loss = 0.0
    n_batches = 0

    # Only build the autograd graph when we are going to backpropagate.
    with torch.set_grad_enabled(is_train):
        for x, y in tqdm(loader, desc=desc):
            x = x.to(device)
            y = y.to(device)

            if is_train:
                optimizer.zero_grad()
            out = model(x)
            loss = corn_loss(out, y, num_classes=num_classes)

            pred = corn_label_from_logits(out)

            y_true += y.cpu().tolist()
            y_pred += pred.cpu().tolist()
            epoch_loss += loss.item()
            n_batches += 1

            if is_train:
                loss.backward()
                optimizer.step()

    # Guard against an empty loader instead of raising ZeroDivisionError.
    mean_loss = epoch_loss / n_batches if n_batches else 0.0

    return y_true, y_pred, mean_loss

In [None]:
import mlflow
from sklearn.metrics import (
    confusion_matrix as sk_cm
)
from main import get_metrics, plot_cm

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Label ids for the confusion matrices, derived from the config cell so they
# cannot drift from NUMBER_OF_CLASSES.
CM_LABELS = list(range(NUMBER_OF_CLASSES))

# Alternative local file store: mlflow.set_tracking_uri("../mlruns")
mlflow.set_tracking_uri("http://127.0.0.1:5000")
mlflow.set_experiment("CNN")


def run_and_log_phase(loader, set_type, epoch):
    """Run one train/valid pass, log its metrics and loss to MLflow.

    Returns the row-normalized confusion matrix for that pass.
    """
    # model.train(False) is equivalent to model.eval().
    model.train(set_type == "train")
    y_true, y_pred, phase_loss = run_epoch(model, loader, optimizer, device, epoch, set_type)
    mlflow.log_metrics(get_metrics(y_true, y_pred, set_type), step=epoch)
    mlflow.log_metric(f"{set_type}_loss", phase_loss, step=epoch)
    return sk_cm(y_true, y_pred, labels=CM_LABELS, normalize='true')


with mlflow.start_run():
    mlflow.log_params({
        "batch_size": BATCH_SIZE,
        "epochs": EPOCHS,
        "learning_rate": LEARNING_RATE,
        "weight_decay": WEIGHT_DECAY,
        "filters": FILTERS,
        "encoder_type": ENCODER_TYPE,
        "number_of_classes": NUMBER_OF_CLASSES,
    })

    for epoch in range(EPOCHS):
        train_cm = run_and_log_phase(train_loader, "train", epoch)
        valid_cm = run_and_log_phase(valid_loader, "valid", epoch)

        # Plot both confusion matrices side by side and attach them to the run.
        cm_path = plot_cm([
            [train_cm, "Train"],
            [valid_cm, "Valid"]
        ], epoch, root="../reports/figures")
        mlflow.log_artifact(cm_path, artifact_path="confusion_matrix")