In [1]:
import random
import math

import numpy as np
import pandas as pd
import torch
import torch.nn as torch_nn
from torch import Tensor
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import OrdinalEncoder
from tqdm.contrib.itertools import product
from tqdm.notebook import tqdm

In [2]:
NUM_SITES_EACH_PACK = 2
NUM_GENERA_EACH_PACK = 4
BATCH_SIZE = 25
LR = 1e-4
EPS = 1e-6
N_EPOCHS = 30
USE_REGULARIZATIOn = True

DEVICE = "mps"

## 0. Load data

In [3]:
path = "/Users/macos/Uni/1st_year/period_3/DSProj/data/AllSites_SiteOccurrences_AllGenera_26.1.24.csv"

df = pd.read_csv(path)

df.head()

Unnamed: 0,SITE_NAME,Equus,Coelodonta,Bos,Gazella,Ursus,Vulpes,Cervus,Canis,Sus,...,Total_Gen_Count,Large_GenCount,Small_GenCount,smallperlarge,smallprop,Herb_GenCount,Nonherb_GenCount,DietRatio,HerbProp,mid_age
0,Aba Zawei,1,1,1,1,0,0,0,0,0,...,4,4,0,0.0,0.0,4,0,,1.0,0.0265
1,Abric Romani,1,0,1,0,1,1,1,1,1,...,12,12,0,0.0,0.0,6,5,1.2,0.5,0.055
2,Acheng_Jiaojie,0,0,0,0,0,0,1,0,0,...,7,5,2,0.4,0.285714,5,2,2.5,0.714286,0.21
3,Adler cave,1,0,0,0,0,1,0,1,0,...,10,5,5,1.0,0.5,6,4,1.5,0.6,0.0275
4,Adyrgan,1,0,0,1,0,0,0,0,0,...,11,5,6,1.2,0.545455,11,0,,1.0,2.2


## 1. Preprocess

### 1.1. Remove redundant columns

In [4]:
cols_redundant = ['LAT',
 'LONG',
 'ALTITUDE',
 'MAX_AGE',
 'BFA_MAX',
 'BFA_MAX_ABS',
 'MIN_AGE',
 'BFA_MIN',
 'BFA_MIN_ABS',
 'COUNTRY',
 'age_range',
 'Total_Gen_Count',
 'Large_GenCount',
 'Small_GenCount',
 'smallperlarge',
 'smallprop',
 'Herb_GenCount',
 'Nonherb_GenCount',
 'DietRatio',
 'HerbProp',
 'mid_age'
 ]

df = df.drop(columns=cols_redundant).set_index('SITE_NAME')

df.head()

Unnamed: 0_level_0,Equus,Coelodonta,Bos,Gazella,Ursus,Vulpes,Cervus,Canis,Sus,Homo,...,Euarctos,Paracervulus,Eostyloceros,Cervocerus,Antispiroides,Sinoryx,Prospalax,Pliopetaurista,Predicrostonyx,Boocercus
SITE_NAME,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Aba Zawei,1,1,1,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
Abric Romani,1,0,1,0,1,1,1,1,1,1,...,0,0,0,0,0,0,0,0,0,0
Acheng_Jiaojie,0,0,0,0,0,0,1,0,0,0,...,0,0,0,0,0,0,0,0,0,0
Adler cave,1,0,0,0,0,1,0,1,0,0,...,0,0,0,0,0,0,0,0,0,0
Adyrgan,1,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


### 1.2. Do OrdinalEncode

In [5]:
list_sites = df.index
list_genera = df.columns

In [6]:
enc_genera, enc_site = OrdinalEncoder(), OrdinalEncoder()

enc_genera.fit(df.columns.to_numpy().reshape(-1, 1))
enc_site.fit(df.index.to_numpy().reshape(-1, 1))

## 2. Create data

In [7]:
def iterate_pack(values: list, n: int):
    assert len(values) % n == 0

    for i in range(0, len(values), n):
        yield values[i:i+n]


data = []
for sites, genera in product(
    iterate_pack(list_sites, NUM_SITES_EACH_PACK),
    iterate_pack(list_genera, NUM_GENERA_EACH_PACK)
):
    occurence = df.loc[sites, genera].to_numpy().astype(np.float32)
    sites_encoded = enc_site.transform(sites.to_numpy().reshape(-1, 1)).squeeze().astype(np.int32)
    genera_encoded = enc_genera.transform(genera.to_numpy().reshape(-1, 1)).squeeze().astype(np.int32)

    data.append({
        'occurence': occurence,
        'sites': sites_encoded,
        'genera': genera_encoded
    })


0it [00:00, ?it/s]

In [8]:
random.shuffle(data)

data_train, data_val = data[:38000], data[38000:]

In [9]:
class FossilNOW(Dataset):
    def __init__(self, data: list) -> None:
        super().__init__()

        self.data = data

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        x = self.data[index]    
    
        occurence = torch.tensor(x['occurence'], device=DEVICE, dtype=torch.float32)
        sites = torch.tensor(x['sites'], device=DEVICE, dtype=torch.int32)
        genera = torch.tensor(x['genera'], device=DEVICE, dtype=torch.int32)

        return occurence, sites, genera
    
dataset_train, dataset_val = FossilNOW(data_train), FossilNOW(data_val)
loader_train = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
loader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=False)
    

## 3. Model

In [10]:
class MF(torch_nn.Module):
    def __init__(
        self, 
        n_sites: int, 
        n_genera: int, 
        d_hid: int = 64,
        use_regularization: bool = True
    ) -> None:
        super().__init__()

        self.use_reg = use_regularization

        self.emd_site = torch_nn.Embedding(n_sites, d_hid)
        self.emd_genera = torch_nn.Embedding(n_genera, d_hid)
        self.batchnorm = torch_nn.BatchNorm2d(1)
        # self.act_relu = torch_nn.ReLU()

        self.criterion = torch_nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, idx_sites: Tensor, idx_genera: Tensor, occurence: Tensor):
        embedding_sites = self.emd_site(idx_sites)
        # [bz, n_sites_pack, d_hid]
        embedding_genera = self.emd_genera(idx_genera)
        # [bz, n_genera_pack, d_hid]

        embedding_genera_T = torch.permute(embedding_genera, (0, 2, 1))
        # [bz, d_hid, n_genera_pack]

        occurence_pred = torch.bmm(embedding_sites, embedding_genera_T)
        # [bz, n_sites_pack, n_genera_pack]

        occurence_pred = self.batchnorm(occurence_pred.unsqueeze(1)).squeeze(1)
        # [bz, n_sites_pack, n_genera_pack]
        # occurence_pred = self.act_relu(occurence_pred)
        # [bz, n_sites_pack, n_genera_pack]


        loss = self.criterion(occurence, occurence_pred)
        # [bz, n_sites_pack * n_genera_pack]

        loss = torch.sum(occurence * loss) / (torch.sum(occurence) + EPS)

        if self.use_reg is True:
            loss_reg = torch.norm(self.emd_site._parameters['weight']) + torch.norm(self.emd_genera._parameters['weight'])
            loss += loss_reg

        return loss, occurence_pred

In [11]:
mf = MF(len(list_sites), len(list_genera), use_regularization=False).to(device=DEVICE)
optimizer = AdamW(mf.parameters(), lr=LR, weight_decay=2e-5)

## 4. Training

In [12]:
def train(model, loader, optimizer):
    model.train()

    for x in tqdm(loader, total=len(dataset_train) // BATCH_SIZE):
        optimizer.zero_grad()

        occurence, idx_sites, idx_genera = x

        loss, _ = model(idx_sites, idx_genera, occurence)

        loss.backward()
        optimizer.step()

def val(model, loader) -> list:
    preds = []

    model.eval()
    with torch.no_grad():
        losses = []
        for x in tqdm(loader, total=int(math.ceil(len(dataset_val) // BATCH_SIZE))):
            occurence, idx_sites, idx_genera = x

            loss, pred = model(idx_sites, idx_genera, occurence)

            preds.append({
                'sites': idx_sites.detach().cpu().numpy(),
                'genera': idx_genera.detach().cpu().numpy(), 
                'occurence': occurence.detach().cpu().numpy(),
                'prediction': pred.detach().cpu().numpy()
            })

            losses.append(loss.item())

            break

    print(f"Loss val: {sum(losses)/len(losses)}")

    return preds


In [13]:
for n in range(N_EPOCHS):
    print(f"== Epoch: {n:02d}")
    
    train(mf, loader_train, optimizer)
    val(mf, loader_val)

== Epoch: 00


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 1.328658938407898
== Epoch: 01


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 1.165469765663147
== Epoch: 02


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 1.0306644439697266
== Epoch: 03


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 0.8851628303527832
== Epoch: 04


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 0.7080410122871399
== Epoch: 05


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 0.5688007473945618
== Epoch: 06


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 0.3960501551628113
== Epoch: 07


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 0.2934022843837738
== Epoch: 08


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: 0.10272525250911713
== Epoch: 09


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: -0.03810456395149231
== Epoch: 10


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: -0.16479533910751343
== Epoch: 11


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: -0.3422618806362152
== Epoch: 12


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: -0.478246808052063
== Epoch: 13


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: -0.6063790321350098
== Epoch: 14


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: -0.7770514488220215
== Epoch: 15


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: -0.9590171575546265
== Epoch: 16


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: -1.0544472932815552
== Epoch: 17


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: -1.2230991125106812
== Epoch: 18


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: -1.3814195394515991
== Epoch: 19


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: -1.4975454807281494
== Epoch: 20


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: -1.6582144498825073
== Epoch: 21


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: -1.7595473527908325
== Epoch: 22


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: -1.9559135437011719
== Epoch: 23


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: -2.115933418273926
== Epoch: 24


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: -2.2064719200134277
== Epoch: 25


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: -2.36443829536438
== Epoch: 26


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: -2.4882264137268066
== Epoch: 27


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: -2.6722822189331055
== Epoch: 28


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: -2.7862932682037354
== Epoch: 29


  0%|          | 0/1520 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: -2.98730731010437


## Save necessary things

In [14]:
list_preds = []

preds = val(mf, loader_val)
for pred in preds:
    break

for sites, genera, occu, pre in zip(pred['sites'], pred['genera'], pred['occurence'], pred['prediction']):
    for i in range(NUM_SITES_EACH_PACK):
        for j in range(NUM_GENERA_EACH_PACK):
            site = sites[i]
            gen = genera[j]
            oc = occu[i, j]
            p = pre[i, j]

            list_preds.append({
                'site': site,
                'genera': gen,
                'occurence': oc,
                'pred': p
            })


pd.DataFrame.from_records(list_preds).to_csv("val_preds.csv", index=False)

  0%|          | 0/102 [00:00<?, ?it/s]

Loss val: -2.98730731010437


In [15]:
emb_site = mf.emd_site._parameters['weight'].detach().cpu().numpy()
emb_species = mf.emd_genera._parameters['weight'].detach().cpu().numpy()

np.save("embd_site_mf.np", emb_site)
np.save("embd_species_mf.np", emb_species)

In [16]:
ordinal_enc_species = enc_genera.categories_
ordinal_enc_site = enc_site.categories_

np.save("ordinal_enc_site_mf.npy", ordinal_enc_site)
np.save("ordinal_enc_species_mf.npy", ordinal_enc_species)

In [17]:
# np.save("data_train.npy", data_train)
# np.save("data_val.npy", data_val)