In [1]:
%cd ../

%load_ext autoreload
%autoreload 2

/Users/macos/Uni/1st_year/period_3/DSProj/code/models


In [2]:
import math
from pathlib import Path
from datetime import datetime

import numpy as np
import torch
import torch.nn as torch_nn
import pandas as pd
from torch.optim import AdamW, lr_scheduler
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm

from MF.model import MF
import evaluation
import utils

In [3]:
NUM_SITES_PER_PACK = 2
NUM_SPECIES_PER_PACK = 4
BATCH_SIZE = 25
LR = 5e-3
EPS = 1e-6
N_EPOCHS = 70
USE_REGULARIZATION = True
PROBABILITY_OUTPUT = True
D_HID = 10

DEVICE = "mps"
PATH_DIR_DATA_PROCESS = Path("data_processed")

# 1. Load data train/val and encoder

In [4]:
path_dir_data = PATH_DIR_DATA_PROCESS / "trainval"

path_data_train = path_dir_data / "data_train.npy"
path_data_val = path_dir_data / "data_val.npy"

data_train = np.load(path_data_train, allow_pickle=True)
data_val = np.load(path_data_val, allow_pickle=True)

In [5]:
path_dir_encode = PATH_DIR_DATA_PROCESS / "encoder"

path_enc_species = path_dir_encode / "ordinal_enc_species.json"
path_enc_site = path_dir_encode / "ordinal_enc_site.json"

enc_species = utils.CategoryDict.from_file(path_enc_species)
enc_site = utils.CategoryDict.from_file(path_enc_site)

# 2. Create train/val data_loader

In [6]:
class FossilNOW(Dataset):
    def __init__(self, data: list) -> None:
        super().__init__()

        self.data = data

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        x = self.data[index]    
    
        occurence = torch.tensor(x['occurence'], device=DEVICE, dtype=torch.float32)
        sites = torch.tensor(x['sites'], device=DEVICE, dtype=torch.int32)
        species = torch.tensor(x['genera'], device=DEVICE, dtype=torch.int32)

        return occurence, sites, species
    
dataset_train, dataset_val = FossilNOW(data_train), FossilNOW(data_val)
loader_train = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
loader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=False)
    

# 3. Model

In [7]:
mf = MF(
    enc_site.size(),
    enc_species.size(),
    d_hid=D_HID,
    prob_output=PROBABILITY_OUTPUT,
)
# mf.load_state_dict(torch.load("model_best_PROBABILITY_OUTPUT=False_0.1747.pt"))
mf = mf.to(device=torch.device(DEVICE))

optimizer = AdamW(mf.parameters(), lr=LR, weight_decay=2e-5 if USE_REGULARIZATION else 0)
criterion = torch_nn.MSELoss(reduction="none")
scheduler = lr_scheduler.LinearLR(optimizer, 1.0, 5e-2, BATCH_SIZE)

## 4. Training

In [8]:
def calc_loss(occ: torch.Tensor, pred: torch.Tensor, alpha: float = 10):
    loss = criterion(occ, pred)

    confidence = 1 + alpha * occ
    loss = torch.mean(confidence * loss)

    return loss

def train(model, loader, optimizer):
    model.train()

    for x in tqdm(loader, total=math.ceil(len(dataset_train) / BATCH_SIZE)):
        optimizer.zero_grad()

        occurence, idx_sites, idx_species = x

        pred = model(idx_sites, idx_species)
        loss = calc_loss(occurence, pred)

        loss.backward()
        optimizer.step()

    scheduler.step()

def val(model, loader) -> tuple:
    preds = []

    model.eval()
    with torch.no_grad():
        losses = []
        for x in tqdm(loader, total=int(math.ceil(len(dataset_val) / BATCH_SIZE))):
            occurence, idx_sites, idx_species = x

            pred = model(idx_sites, idx_species)
            loss = calc_loss(occurence, pred)

            preds.append({
                'sites': idx_sites.detach().cpu().numpy(),
                'species': idx_species.detach().cpu().numpy(), 
                'occurence': occurence.detach().cpu().numpy(),
                'prediction': pred.detach().cpu().numpy()
            })

            losses.append(loss.item())

    loss = sum(losses)/len(losses)

    return loss, preds


In [9]:
best_loss_val = 10e10

for n in range(N_EPOCHS):
    print(f"== Epoch: {n:02d}")
    
    train(mf, loader_train, optimizer)

    loss_val, preds_val = val(mf, loader_val)

    if loss_val < best_loss_val:
        best_loss_val = loss_val
        torch.save(mf.state_dict(), f"model_best_PROBABILITY_OUTPUT={PROBABILITY_OUTPUT}_{loss_val:.4f}.pt")
    else:
        break

    print(f"Loss val: {loss_val}")

== Epoch: 00


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/103 [00:00<?, ?it/s]

Loss val: 0.3306711742889534
== Epoch: 01


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/103 [00:00<?, ?it/s]

Loss val: 0.17894929068759807
== Epoch: 02


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/103 [00:00<?, ?it/s]

Loss val: 0.15959250688263513
== Epoch: 03


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/103 [00:00<?, ?it/s]

Loss val: 0.1424806959999418
== Epoch: 04


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/103 [00:00<?, ?it/s]

Loss val: 0.12674113959797378
== Epoch: 05


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/103 [00:00<?, ?it/s]

Loss val: 0.11550664959601986
== Epoch: 06


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/103 [00:00<?, ?it/s]

Loss val: 0.10930152475219039
== Epoch: 07


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/103 [00:00<?, ?it/s]

Loss val: 0.10447654705856321
== Epoch: 08


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/103 [00:00<?, ?it/s]

Loss val: 0.10223171049819409
== Epoch: 09


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/103 [00:00<?, ?it/s]

Loss val: 0.1000989092224576
== Epoch: 10


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/103 [00:00<?, ?it/s]

# 5. Evaluate

In [10]:
list_preds = []

_, preds = val(mf, loader_val)
for pred in preds:
    for sites, species, occu, pre in zip(pred['sites'], pred['species'], pred['occurence'], pred['prediction']):
        for i in range(NUM_SITES_PER_PACK):
            for j in range(NUM_SPECIES_PER_PACK):
                list_preds.append({
                    'site': sites[i],
                    'species': species[j],
                    'occurence': occu[i, j],
                    'pred': pre[i, j]
                })

df_pred = pd.DataFrame.from_records(list_preds)

df_pred.head()

  0%|          | 0/103 [00:00<?, ?it/s]

Unnamed: 0,site,species,occurence,pred
0,198,116,0.0,0.003366
1,198,117,0.0,0.117846
2,198,118,0.0,0.051285
3,198,119,0.0,0.020756
4,199,116,0.0,0.008086


In [11]:
print(f"Expected Percentile Ranking : {evaluation.calc_expected_percentile_rank(df_pred):.6f}")
if PROBABILITY_OUTPUT is True:
    print(f"TPR : {evaluation.calc_tpr(df_pred):.6f}")

Expected Percentile Ranking : 0.160329
TPR : 0.022789


# 6. Save model and embeddings

In [12]:
tag = datetime.now().strftime("%b%d_%H-%M-%S")

path_dir_weights = PATH_DIR_DATA_PROCESS / f"mf_PROBABILITY_OUTPUT={PROBABILITY_OUTPUT}/{tag}"

path_dir_weights.mkdir(exist_ok=True, parents=True)

## 6.1. Save embeddings

In [13]:
emb_sites, emb_species = mf.get_embds()

path_embd_sites = path_dir_weights / "emb_sites.npy"
path_embd_species = path_dir_weights / "emb_species.npy"

np.save(path_embd_sites, emb_sites)
np.save(path_embd_species, emb_species)

## 6.2. Save models

In [14]:
torch.save(mf.state_dict(), path_dir_weights / "model.pt")