In [1]:
# Move to the parent directory so the relative paths used below (e.g. "data_processed") and the
# local modules `evaluation` / `utils` presumably resolve from there.
# NOTE(review): depends on the notebook's starting directory and is not idempotent —
# re-running this cell moves up one more level.
%cd ../

/Users/macos/Uni/1st_year/period_3/DSProj/code/models


In [2]:
import math
from typing import Literal
from pathlib import Path

import numpy as np
import torch
import torch.nn as torch_nn
import pandas as pd
from torch import Tensor
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm

import evaluation
import utils

In [3]:
NUM_SITES_EACH_PACK = 2
NUM_GENERA_EACH_PACK = 4
BATCH_SIZE = 25
LR = 1e-4
EPS = 1e-6
N_EPOCHS = 30
# NOTE(review): the model cell passes use_regularization=False explicitly, so this flag does
# not currently control the run.
USE_REGULARIZATION = True
USE_REGULARIZATIOn = USE_REGULARIZATION  # legacy misspelled name, kept as an alias
SEED = 42

# Seed the RNGs the notebook relies on (embedding init, DataLoader shuffling) for reproducible runs.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use Apple-silicon MPS when available, otherwise fall back to the CPU.
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"
PATH_DIR_DATA_PROCESS = Path("data_processed")

# 1. Load train/val data and encoders

In [4]:
# Pre-split train/validation records, presumably written by an earlier processing step.
path_dir_data = PATH_DIR_DATA_PROCESS / "trainval"
path_data_train = path_dir_data.joinpath("data_train.npy")
path_data_val = path_dir_data.joinpath("data_val.npy")

# allow_pickle=True is needed because each record is a Python object indexed by key
# (presumably an object array of dicts). NOTE(review): unpickling can execute code, so only
# load files you produced yourself.
data_train, data_val = (np.load(p, allow_pickle=True) for p in (path_data_train, path_data_val))

In [5]:
# Ordinal encoders for genera and sites (presumably label -> integer index), stored as JSON.
path_dir_encode = PATH_DIR_DATA_PROCESS / "encoder"
path_enc_genera = path_dir_encode.joinpath("ordinal_enc_genera.json")
path_enc_site = path_dir_encode.joinpath("ordinal_enc_site.json")

# Their .size() later sets the number of rows of the model's embedding tables.
enc_genera, enc_site = map(utils.CategoryDict.from_file, (path_enc_genera, path_enc_site))

# 2. Create train/val data loaders

In [6]:
class FossilNOW(Dataset):
    """Map-style dataset returning ``(occurence, sites, genera)`` tensors for one record.

    Args:
        data: Indexable collection of records; each record is indexed with the keys
            ``'occurence'``, ``'sites'`` and ``'genera'``.
        device: Device the returned tensors are created on. ``None`` (default) uses the
            notebook-level ``DEVICE`` constant, matching the previous behaviour; pass e.g.
            ``"cpu"`` to build items on the host and move batches to the accelerator later.
    """

    def __init__(self, data: list, device=None) -> None:
        super().__init__()

        self.data = data
        # DEVICE is only looked up when no explicit device is supplied.
        self.device = DEVICE if device is None else device

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        x = self.data[index]

        occurence = torch.tensor(x['occurence'], device=self.device, dtype=torch.float32)
        # int32 index tensors are accepted by torch.nn.Embedding.
        sites = torch.tensor(x['sites'], device=self.device, dtype=torch.int32)
        genera = torch.tensor(x['genera'], device=self.device, dtype=torch.int32)

        return occurence, sites, genera
    
dataset_train = FossilNOW(data_train)
dataset_val = FossilNOW(data_val)

# Shuffle only the training split; validation keeps a fixed order.
loader_train = DataLoader(dataset=dataset_train, batch_size=BATCH_SIZE, shuffle=True)
loader_val = DataLoader(dataset=dataset_val, batch_size=BATCH_SIZE, shuffle=False)
    

# 3. Model

In [7]:
class MF(torch_nn.Module):
    """Matrix-factorisation model scoring (site, genus) pairs by embedding dot products.

    Each batch item carries a small pack of site indices and genus indices. The model
    predicts a score for every (site, genus) pair in the pack and returns the loss against
    ``occurence`` together with the predictions.

    Args:
        n_sites: Number of rows in the site embedding table.
        n_genera: Number of rows in the genus embedding table.
        d_hid: Embedding dimension.
        use_regularization: If True, add the Frobenius norms of both embedding tables to the loss.
        training_strategy: ``'classification'`` treats predictions as logits and uses
            ``BCEWithLogitsLoss``; any other value applies a ReLU (``'regression'``) and uses
            ``MSELoss``.
        eps: Constant guarding the loss normalisation against division by zero. ``None``
            (default) falls back to the notebook-level ``EPS`` constant at call time.
    """

    def __init__(
        self,
        n_sites: int,
        n_genera: int,
        d_hid: int = 64,
        use_regularization: bool = True,
        training_strategy: Literal['regression', 'classification'] = 'classification',
        eps=None,
    ) -> None:
        super().__init__()

        self.use_reg = use_regularization
        self.strategy = training_strategy
        self.eps = eps

        self.emd_site = torch_nn.Embedding(n_sites, d_hid)
        self.emd_genera = torch_nn.Embedding(n_genera, d_hid)
        # Single-channel BatchNorm over the [bz, 1, n_sites_pack, n_genera_pack] score map.
        self.batchnorm = torch_nn.BatchNorm2d(1)

        self.act_relu = torch_nn.ReLU()
        # reduction='none' keeps a per-pair loss so forward() can weight it by the targets.
        if training_strategy == "classification":
            self.criterion = torch_nn.BCEWithLogitsLoss(reduction='none')
        else:
            self.criterion = torch_nn.MSELoss(reduction='none')

    def forward(self, idx_sites: Tensor, idx_genera: Tensor, occurence: Tensor):
        """Score every (site, genus) pair in each pack and compute the weighted loss.

        Args:
            idx_sites: Site indices, shape [bz, n_sites_pack].
            idx_genera: Genus indices, shape [bz, n_genera_pack].
            occurence: Targets, shape [bz, n_sites_pack, n_genera_pack]; also used as the
                per-pair loss weight.

        Returns:
            ``(loss, occurence_pred)``: a scalar loss and predictions of shape
            [bz, n_sites_pack, n_genera_pack] (logits for classification, ReLU output otherwise).
        """
        embedding_sites = self.emd_site(idx_sites)
        # [bz, n_sites_pack, d_hid]
        embedding_genera = self.emd_genera(idx_genera)
        # [bz, n_genera_pack, d_hid]

        occurence_pred = torch.bmm(embedding_sites, embedding_genera.transpose(1, 2))
        # [bz, n_sites_pack, n_genera_pack]

        occurence_pred = self.batchnorm(occurence_pred.unsqueeze(1)).squeeze(1)
        # [bz, n_sites_pack, n_genera_pack]

        if self.strategy == "regression":
            occurence_pred = self.act_relu(occurence_pred)

        # PyTorch losses take (input, target): the prediction must come first. MSE is
        # symmetric, but BCEWithLogitsLoss applies the sigmoid to its first argument.
        loss = self.criterion(occurence_pred, occurence)
        # [bz, n_sites_pack, n_genera_pack]

        # NOTE(review): occurence doubles as the per-pair weight, so pairs with occurence == 0
        # contribute nothing. If targets are 0/1 this is a positives-only objective that is
        # minimised by predicting every pair as present — consider sampling negatives.
        eps = EPS if self.eps is None else self.eps
        loss = torch.sum(occurence * loss) / (torch.sum(occurence) + eps)

        if self.use_reg:
            loss_reg = torch.norm(self.emd_site.weight) + torch.norm(self.emd_genera.weight)
            loss = loss + loss_reg

        return loss, occurence_pred

In [8]:
training_strategy = "regression"

# NOTE(review): the explicit loss regularization is switched off here, so the embeddings are
# only penalised through AdamW's weight decay below.
mf = MF(
    n_sites=enc_site.size(),
    n_genera=enc_genera.size(),
    training_strategy=training_strategy,
    use_regularization=False,
)
mf = mf.to(device=DEVICE)

optimizer = AdamW(params=mf.parameters(), weight_decay=2e-5, lr=LR)

# 4. Training

In [9]:
def train(model, loader, optimizer) -> float:
    """Run one training epoch over ``loader`` and return the mean batch loss.

    Args:
        model: Module whose forward takes ``(idx_sites, idx_genera, occurence)`` and returns
            ``(loss, prediction)``.
        loader: Iterable of ``(occurence, idx_sites, idx_genera)`` batches.
        optimizer: Optimizer over ``model``'s parameters.

    Returns:
        Mean of the per-batch training losses, or ``nan`` if ``loader`` yields no batches.
    """
    model.train()

    batch_losses = []
    # tqdm reads the bar length from len(loader), which counts the final partial batch and
    # does not depend on any global dataset.
    for occurence, idx_sites, idx_genera in tqdm(loader):
        optimizer.zero_grad()

        loss, _ = model(idx_sites, idx_genera, occurence)

        loss.backward()
        optimizer.step()

        # Keep the detached tensor; calling .item() here would force a device sync every step.
        batch_losses.append(loss.detach())

    if not batch_losses:
        return float("nan")
    return torch.stack(batch_losses).mean().item()

def val(model, loader) -> list:
    """Evaluate ``model`` on ``loader`` without gradients, print the mean loss, collect predictions.

    Args:
        model: Module whose forward takes ``(idx_sites, idx_genera, occurence)`` and returns
            ``(loss, prediction)``.
        loader: Iterable of ``(occurence, idx_sites, idx_genera)`` batches.

    Returns:
        One dict per batch with numpy arrays under ``'sites'``, ``'genera'``,
        ``'occurence'`` and ``'prediction'``.
    """
    preds = []
    losses = []

    model.eval()
    with torch.no_grad():
        # tqdm reads the bar length from len(loader), which includes the final partial batch.
        for occurence, idx_sites, idx_genera in tqdm(loader):
            loss, pred = model(idx_sites, idx_genera, occurence)

            preds.append({
                'sites': idx_sites.cpu().numpy(),
                'genera': idx_genera.cpu().numpy(),
                'occurence': occurence.cpu().numpy(),
                'prediction': pred.cpu().numpy(),
            })

            losses.append(loss.item())

    if losses:
        # Unweighted mean over batches: a smaller final batch counts as much as a full one.
        print(f"Loss val: {sum(losses)/len(losses)}")
    else:
        print("Loss val: no batches")

    return preds


In [10]:
for epoch in range(N_EPOCHS):
    print(f"== Epoch: {epoch:02d}")

    # One optimisation pass over the training split, then a validation pass that prints
    # the mean validation loss (its returned predictions are discarded here).
    train(mf, loader_train, optimizer)
    val(mf, loader_val)

== Epoch: 00


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 0.6494366146117738
== Epoch: 01


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 0.6301417766699513
== Epoch: 02


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 0.5948479043193233
== Epoch: 03


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 0.5566928951630314
== Epoch: 04


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 0.5002825493806774
== Epoch: 05


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 0.42392012444370003
== Epoch: 06


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 0.32024901554675644
== Epoch: 07


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 0.19823644954928207
== Epoch: 08


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 0.09073816713209047
== Epoch: 09


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 0.032230840876834616
== Epoch: 10


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 0.008466766893160666
== Epoch: 11


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 0.0009762679379774163
== Epoch: 12


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 1.6985722807519446e-05
== Epoch: 13


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 2.558513229975775e-09
== Epoch: 14


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 7.688482575658605e-13
== Epoch: 15


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 4.946378901068735e-13
== Epoch: 16


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 8.648433929186901e-14
== Epoch: 17


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 1.3832872997635246e-14
== Epoch: 18


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 3.4492358617928856e-15
== Epoch: 19


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 0.0
== Epoch: 20


  0%|          | 0/1520 [00:00<?, ?it/s]

KeyboardInterrupt: 

# 5. Evaluate

In [11]:
def flatten_predictions(batches: list) -> pd.DataFrame:
    """Turn the per-batch output of ``val`` into one row per observed (site, genus) pair.

    Only pairs with ``occurence > 0`` are kept. Pack sizes are taken from the arrays
    themselves rather than from NUM_SITES_EACH_PACK / NUM_GENERA_EACH_PACK, so packs of
    any shape are handled.

    Returns:
        DataFrame with columns ``site``, ``genera``, ``occurence``, ``pred``.
    """
    frames = []
    for batch in batches:
        # np.nonzero yields (item, site-slot, genus-slot) indices in row-major order, i.e. the
        # same order as nested loops over items, then sites, then genera.
        idx_item, idx_site, idx_genus = np.nonzero(batch['occurence'] > 0)
        frames.append(pd.DataFrame({
            'site': batch['sites'][idx_item, idx_site],
            'genera': batch['genera'][idx_item, idx_genus],
            'occurence': batch['occurence'][idx_item, idx_site, idx_genus],
            'pred': batch['prediction'][idx_item, idx_site, idx_genus],
        }))

    if not frames:
        return pd.DataFrame(columns=['site', 'genera', 'occurence', 'pred'])
    return pd.concat(frames, ignore_index=True)


preds = val(mf, loader_val)
df_pred = flatten_predictions(preds)

df_pred.head()

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 0.0


Unnamed: 0,site,genera,occurence,pred
0,707,96,1.0,1.0
1,200,148,1.0,1.0
2,200,150,1.0,1.0
3,201,148,1.0,1.0
4,201,150,1.0,1.0


In [12]:
# The evaluation helpers receive only the predictions; df_pred holds only pairs with
# occurence > 0. NOTE(review): because the training loss is weighted by occurence, scores like
# MSE 0 / TPR 1 may be reachable by predicting every pair as present — interpret accordingly.
pred_scores = df_pred['pred']
print(f"MSE : {evaluation.calc_mse(pred_scores):.6f}")
print(f"RMSE: {evaluation.calc_rmse(pred_scores):.6f}")
print(f"TPR : {evaluation.calc_tpr(pred_scores):.6f}")

MSE : 0.000000
RMSE: 0.000000
TPR : 1.000000


# 6. Save embeddings

In [15]:
# Output directory for the learned embedding tables (created if missing; no error if it exists).
path_dir_emb = PATH_DIR_DATA_PROCESS.joinpath("embedding_mf")
path_dir_emb.mkdir(parents=True, exist_ok=True)

In [16]:
# Embedding tables as (num_embeddings, d_hid) numpy arrays, copied off the training device.
emb_site = mf.emd_site.weight.detach().cpu().numpy()
# NOTE(review): these are the *genera* embeddings; the variable and file keep the "species"
# name so anything already reading these .npy files is unaffected.
emb_species = mf.emd_genera.weight.detach().cpu().numpy()

np.save(path_dir_emb / f"embd_site_mf_{training_strategy}.npy", emb_site)
np.save(path_dir_emb / f"embd_species_mf_{training_strategy}.npy", emb_species)