In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as torch_nn
from torch import Tensor
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import OrdinalEncoder
from tqdm.contrib.itertools import product
from tqdm.notebook import tqdm

In [14]:
NUM_SITES_EACH_PACK = 2
NUM_generaS_EACH_PACK = 4
BATCH_SIZE = 25
LR = 1e-4
EPS = 1e-6
N_EPOCHS = 50

DEVICE = "mps"

## 0. Load data

In [3]:
path = "/Users/macos/Uni/1st_year/period_3/DSProj/data/AllSites_SiteOccurrences_AllGenera_26.1.24.csv"

df = pd.read_csv(path)

df.head()

Unnamed: 0,SITE_NAME,Equus,Coelodonta,Bos,Gazella,Ursus,Vulpes,Cervus,Canis,Sus,...,Total_Gen_Count,Large_GenCount,Small_GenCount,smallperlarge,smallprop,Herb_GenCount,Nonherb_GenCount,DietRatio,HerbProp,mid_age
0,Aba Zawei,1,1,1,1,0,0,0,0,0,...,4,4,0,0.0,0.0,4,0,,1.0,0.0265
1,Abric Romani,1,0,1,0,1,1,1,1,1,...,12,12,0,0.0,0.0,6,5,1.2,0.5,0.055
2,Acheng_Jiaojie,0,0,0,0,0,0,1,0,0,...,7,5,2,0.4,0.285714,5,2,2.5,0.714286,0.21
3,Adler cave,1,0,0,0,0,1,0,1,0,...,10,5,5,1.0,0.5,6,4,1.5,0.6,0.0275
4,Adyrgan,1,0,0,1,0,0,0,0,0,...,11,5,6,1.2,0.545455,11,0,,1.0,2.2


## 1. Preprocess

### 1.1. Remove redundant columns

In [4]:
cols_redundant = ['LAT',
 'LONG',
 'ALTITUDE',
 'MAX_AGE',
 'BFA_MAX',
 'BFA_MAX_ABS',
 'MIN_AGE',
 'BFA_MIN',
 'BFA_MIN_ABS',
 'COUNTRY',
 'age_range',
 'Total_Gen_Count',
 'Large_GenCount',
 'Small_GenCount',
 'smallperlarge',
 'smallprop',
 'Herb_GenCount',
 'Nonherb_GenCount',
 'DietRatio',
 'HerbProp',
 'mid_age'
 ]

df = df.drop(columns=cols_redundant).set_index('SITE_NAME')

df.head()

Unnamed: 0_level_0,Equus,Coelodonta,Bos,Gazella,Ursus,Vulpes,Cervus,Canis,Sus,Homo,...,Euarctos,Paracervulus,Eostyloceros,Cervocerus,Antispiroides,Sinoryx,Prospalax,Pliopetaurista,Predicrostonyx,Boocercus
SITE_NAME,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Aba Zawei,1,1,1,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
Abric Romani,1,0,1,0,1,1,1,1,1,1,...,0,0,0,0,0,0,0,0,0,0
Acheng_Jiaojie,0,0,0,0,0,0,1,0,0,0,...,0,0,0,0,0,0,0,0,0,0
Adler cave,1,0,0,0,0,1,0,1,0,0,...,0,0,0,0,0,0,0,0,0,0
Adyrgan,1,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


### 1.2. Do OrdinalEncode

In [5]:
list_sites = df.index
list_generas = df.columns

In [6]:
enc_genera, enc_site = OrdinalEncoder(), OrdinalEncoder()

enc_genera.fit(df.columns.to_numpy().reshape(-1, 1))
enc_site.fit(df.index.to_numpy().reshape(-1, 1))

## 2. Create data

In [7]:
def iterate_pack(values: list, n: int):
    assert len(values) % n == 0

    for i in range(0, len(values), n):
        yield values[i:i+n]


data = []
for sites, generas in product(
    iterate_pack(list_sites, NUM_SITES_EACH_PACK),
    iterate_pack(list_generas, NUM_generaS_EACH_PACK)
):
    occurence = df.loc[sites, generas].to_numpy().astype(np.float32)
    sites_encoded = enc_site.transform(sites.to_numpy().reshape(-1, 1)).squeeze().astype(np.int32)
    generas_encoded = enc_genera.transform(generas.to_numpy().reshape(-1, 1)).squeeze().astype(np.int32)

    data.append({
        'occurence': occurence,
        'sites': sites_encoded,
        'generas': generas_encoded
    })


0it [00:00, ?it/s]

In [8]:
data_train, data_val = data[:38000], data[38000:]

In [9]:
class FossilNOW(Dataset):
    def __init__(self, data: list) -> None:
        super().__init__()

        self.data = data

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        x = self.data[index]    
    
        occurence = torch.tensor(x['occurence'], device=DEVICE, dtype=torch.float32)
        sites = torch.tensor(x['sites'], device=DEVICE, dtype=torch.int32)
        generas = torch.tensor(x['generas'], device=DEVICE, dtype=torch.int32)

        return occurence, sites, generas
    
dataset_train, dataset_val = FossilNOW(data_train), FossilNOW(data_val)
loader_train = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
loader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=False)
    

## 3. Model

In [10]:
class MF(torch_nn.Module):
    def __init__(self, n_sites: int, n_generas: int, d_hid: int = 64,) -> None:
        super().__init__()

        self.emd_site = torch_nn.Embedding(n_sites, d_hid)
        self.emd_genera = torch_nn.Embedding(n_generas, d_hid)

        self.criterion = torch_nn.MSELoss(reduction='none')

    def forward(self, idx_sites: Tensor, idx_generas: Tensor, occurence: Tensor):
        embedding_sites = self.emd_site(idx_sites)
        # [bz, n_sites_pack, d_hid]
        embedding_generas = self.emd_genera(idx_generas)
        # [bz, n_generas_pack, d_hid]

        embedding_generas_T = torch.permute(embedding_generas, (0, 2, 1))
        # [bz, d_hid, n_generas_pack]

        occurence_pred = torch.bmm(embedding_sites, embedding_generas_T)
        # [bz, n_sites_pack, n_generas_pack]

        # occurence_pred = occurence_pred.flatten(start_dim=1)
        # # [bz, n_sites_pack * n_generas_pack]

        loss = self.criterion(occurence, occurence_pred)
        # [bz, n_sites_pack * n_generas_pack]

        loss = torch.sum(occurence * loss) / (torch.sum(occurence) + EPS)

        return loss

In [11]:
mf = MF(len(list_sites), len(list_generas)).to(device=DEVICE)
optimizer = AdamW(mf.parameters(), lr=LR)

## 4. Training

In [12]:
def train(model, loader, optimizer):
    model.train()

    for x in tqdm(loader, total=len(dataset_train) // BATCH_SIZE):
        optimizer.zero_grad()

        occurence, idx_sites, idx_generas = x

        loss = model(idx_sites, idx_generas, occurence)

        loss.backward()
        optimizer.step()

def val(model, loader):
    model.eval()
    with torch.no_grad():
        losses = []
        for x in tqdm(loader, total=len(dataset_val) // BATCH_SIZE):
            occurence, idx_sites, idx_generas = x

            loss = model(idx_sites, idx_generas, occurence)

            losses.append(loss.item())

    print(f"Loss val: {sum(losses)/len(losses)}")


In [15]:
for n in range(N_EPOCHS):
    print(f"== Epoch: {n:02d}")
    
    train(mf, loader_train, optimizer)
    val(mf, loader_val)

== Epoch: 00


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 30.18978347962053
== Epoch: 01


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 29.870493760748396
== Epoch: 02


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 29.545680114438813
== Epoch: 03


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 29.233518315316402
== Epoch: 04


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 28.925259153819777
== Epoch: 05


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 28.62131202372822
== Epoch: 06


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 28.310444618891744
== Epoch: 07


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 28.016283929492662
== Epoch: 08


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 27.722449957456405
== Epoch: 09


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 27.435957693576235
== Epoch: 10


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 27.150413992046154
== Epoch: 11


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 26.877484294106658
== Epoch: 12


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 26.609598339399668
== Epoch: 13


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 26.34268581089609
== Epoch: 14


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 26.078809367981062
== Epoch: 15


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 25.81915492781447
== Epoch: 16


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 25.571543840691447
== Epoch: 17


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 25.315710928707013
== Epoch: 18


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 25.071594297650805
== Epoch: 19


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 24.816915701043982
== Epoch: 20


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 24.57853559134257
== Epoch: 21


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 24.338432277805957
== Epoch: 22


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 24.104145991823916
== Epoch: 23


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 23.87320751522315
== Epoch: 24


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 23.64912395047711
== Epoch: 25


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 23.433713316467582
== Epoch: 26


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 23.21331663893364
== Epoch: 27


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 22.990511013368977
== Epoch: 28


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 22.788573561252086
== Epoch: 29


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 22.57212326747198
== Epoch: 30


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 22.37902294731548
== Epoch: 31


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 22.181447520962287
== Epoch: 32


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 21.981618014927186
== Epoch: 33


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 21.781282101754435
== Epoch: 34


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 21.585372545277174
== Epoch: 35


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 21.38975885772888
== Epoch: 36


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 21.19009603072383
== Epoch: 37


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 21.00596527840481
== Epoch: 38


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 20.82283411292106
== Epoch: 39


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 20.635441954161756
== Epoch: 40


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 20.45613807112937
== Epoch: 41


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 20.28210969924818
== Epoch: 42


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 20.10484166759817
== Epoch: 43


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 19.928293146129732
== Epoch: 44


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 19.75892073619496
== Epoch: 45


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 19.587362138138523
== Epoch: 46


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 19.421513236625103
== Epoch: 47


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 19.256390579243742
== Epoch: 48


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 19.10706268420931
== Epoch: 49


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 18.9435747805802
