In [None]:
%load_ext autoreload
%autoreload 2

In [48]:
import torch
import pandas as pd
from torchmetrics import Accuracy, Precision, Recall, F1Score
from src.models.ann import Ann
from src.models.fish import Fish
from src.models.autoencoder import Autoencoder
from src.data.load import load_data
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from functools import partial
from src.data.utils import collate_batch

Load data:

In [15]:
vocab = torch.load("../data/vocab.pt")
tokenizer = get_tokenizer("spacy", language="en_core_web_sm")
train_dataset, val_dataset, test_dataset = load_data("../data/processed/")
vocab_size = len(vocab)
batch_size = 32


In [16]:
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    collate_fn=partial(collate_batch, vocab=vocab, tokenizer=tokenizer),
)

Load models

In [None]:
fish_128_48 = Fish(vocab_size, 128, n4=48)
fish_128_48.load_state_dict(torch.load(f"../data/models/fish-128-48.pt"))
fish_128_48.eval()

fish_128_64 = Fish(vocab_size, 128, n4=64)
fish_128_64.load_state_dict(torch.load(f"../data/models/fish-128-64.pt"))
fish_128_64.eval()

fish_128_88 = Fish(vocab_size, 128, n4=88)
fish_128_88.load_state_dict(torch.load(f"../data/models/fish-128-88.pt"))
fish_128_88.eval()

fish_256_64 = Fish(vocab_size, 256, n4=64)
fish_256_64.load_state_dict(torch.load(f"../data/models/fish-256-64.pt"))
fish_256_64.eval()

fish_256_128 = Fish(vocab_size, 256, n4=128)
fish_256_128.load_state_dict(torch.load(f"../data/models/fish-256-128.pt"))
fish_256_128.eval()

fish_256_192 = Fish(vocab_size, 256, n4=192)
fish_256_192.load_state_dict(torch.load(f"../data/models/fish-256-192.pt"))
fish_256_192.eval()

fish_512_64 = Fish(vocab_size, 512, n4=64)
fish_512_64.load_state_dict(torch.load(f"../data/models/fish-512-64.pt"))
fish_512_64.eval()

fish_512_128 = Fish(vocab_size, 512, n4=128)
fish_512_128.load_state_dict(torch.load(f"../data/models/fish-512-128.pt"))
fish_512_128.eval()

fish_512_256 = Fish(vocab_size, 512, n4=256)
fish_512_256.load_state_dict(torch.load(f"../data/models/fish-512-256.pt"))
fish_512_256.eval()

In [40]:
autoencoder_128 = Autoencoder(vocab_size, 128)
autoencoder_128.load_state_dict(torch.load(f"../data/models/autoencoder-128.pt"))
autoencoder_128.eval()

autoencoder_256 = Autoencoder(vocab_size, 256)
autoencoder_256.load_state_dict(torch.load(f"../data/models/autoencoder-256.pt"))
autoencoder_256.eval()

autoencoder_512 = Autoencoder(vocab_size, 512)
autoencoder_512.load_state_dict(torch.load(f"../data/models/autoencoder-512.pt"))
autoencoder_512.eval()

Autoencoder(
  (embedding): EmbeddingBag(196674, 512, mode=mean)
  (encoder): Sequential(
    (0): Linear(in_features=512, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=32, bias=True)
    (5): ReLU()
  )
  (decoder): Sequential(
    (0): Linear(in_features=32, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=512, bias=True)
    (5): ReLU()
  )
)

In [42]:
ann_128_48 = Ann(vocab_size, 128, autoencoder_128.encoder, n4=48)
ann_128_48.load_state_dict(torch.load(f"../data/models/ann-128-48.pt"))
ann_128_48.eval()

ann_128_64 = Ann(vocab_size, 128, autoencoder_128.encoder, n4=64)
ann_128_64.load_state_dict(torch.load(f"../data/models/ann-128-64.pt"))
ann_128_64.eval()

ann_128_88 = Ann(vocab_size, 128, autoencoder_128.encoder, n4=88)
ann_128_88.load_state_dict(torch.load(f"../data/models/ann-128-88.pt"))
ann_128_88.eval()

ann_256_64 = Ann(vocab_size, 256, autoencoder_256.encoder, n4=64)
ann_256_64.load_state_dict(torch.load(f"../data/models/ann-256-64.pt"))
ann_256_64.eval()

ann_256_128 = Ann(vocab_size, 256, autoencoder_256.encoder, n4=128)
ann_256_128.load_state_dict(torch.load(f"../data/models/ann-256-128.pt"))
ann_256_128.eval()

ann_256_192 = Ann(vocab_size, 256, autoencoder_256.encoder, n4=192)
ann_256_192.load_state_dict(torch.load(f"../data/models/ann-256-192.pt"))
ann_256_192.eval()

ann_512_64 = Ann(vocab_size, 512, autoencoder_512.encoder, n4=64)
ann_512_64.load_state_dict(torch.load(f"../data/models/ann-512-64.pt"))
ann_512_64.eval()

ann_512_128 = Ann(vocab_size, 512, autoencoder_512.encoder, n4=128)
ann_512_128.load_state_dict(torch.load(f"../data/models/ann-512-128.pt"))
ann_512_128.eval()

ann_512_256 = Ann(vocab_size, 512, autoencoder_512.encoder, n4=256)
ann_512_256.load_state_dict(torch.load(f"../data/models/ann-512-256.pt"))
ann_512_256.eval()

Ann(
  (embedding): EmbeddingBag(196674, 512, mode=mean)
  (encoder): Sequential(
    (0): Linear(in_features=512, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=32, bias=True)
    (5): ReLU()
  )
  (decoder): Sequential(
    (0): Linear(in_features=32, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=32, bias=True)
    (3): ReLU()
    (4): Linear(in_features=32, out_features=16, bias=True)
    (5): ReLU()
    (6): Linear(in_features=16, out_features=2, bias=True)
  )
  (softmax): Softmax(dim=1)
)

In [44]:
fish_models = {
    "fish-128-48": fish_128_48,
    "fish-128-64": fish_128_64,
    "fish-128-88": fish_128_88,
    "fish-256-64": fish_256_64,
    "fish-256-128": fish_256_128,
    "fish-256-192": fish_256_192,
    "fish-512-64": fish_512_64,
    "fish-512-128": fish_512_128,
    "fish-512-256": fish_512_256,
}

ann_models = {
    "ann-128-48": ann_128_48,
    "ann-128-64": ann_128_64,
    "ann-128-88": ann_128_88,
    "ann-256-64": ann_256_64,
    "ann-256-128": ann_256_128,
    "ann-256-192": ann_256_192,
    "ann-512-64": ann_512_64,
    "ann-512-128": ann_512_128,
    "ann-512-256": ann_512_256,
}

Get metrics:

In [51]:
def get_metrics(model, dataloader, show=False):
    
    acc = Accuracy(task="binary",num_classes=2)
    prec = Precision(task="binary",num_classes=2)
    rec = Recall(task="binary",num_classes=2)
    f1_score = F1Score(task="binary",num_classes=2)

    for idx, (text, label, offsets) in enumerate(test_dataloader):
        y_hat = model(text, offsets).argmax(1)
        acc(y_hat, label)
        prec(y_hat, label)
        rec(y_hat, label)
        f1_score(y_hat, label)

    if(show):
        print(f"Accuracy: {acc.compute():.4f}")
        print(f"Precision: {prec.compute():.4f}")
        print(f"Recall: {rec.compute():.4f}")
        print(f"F1 Score: {f1_score.compute():.4f}")
    return acc.compute().item(), prec.compute().item(), rec.compute().item(), f1_score.compute().item()

In [52]:
report = []
for name,model in fish_models.items():
    accuracy, precision, recall, f1_score = get_metrics(model, test_dataloader, show=False)
    report.append(
        {"model":name,
         "accuracy":accuracy,
         "precision":precision,
         "recall":recall,
         "f1_score":f1_score})

for name,model in ann_models.items():
    accuracy, precision, recall, f1_score = get_metrics(model, test_dataloader, show=False)
    report.append(
        {"model":name,
         "accuracy":accuracy,
         "precision":precision,
         "recall":recall,
         "f1_score":f1_score})

In [57]:
pd.DataFrame(report).sort_values(by="f1_score", ascending=False)

Unnamed: 0,model,accuracy,precision,recall,f1_score
10,ann-128-64,0.944263,0.713101,0.756376,0.734102
0,fish-128-48,0.940908,0.686351,0.771768,0.726558
4,fish-256-128,0.940371,0.682577,0.773527,0.725211
6,fish-512-64,0.941713,0.697117,0.755057,0.724931
9,ann-128-48,0.93724,0.655148,0.808707,0.723873
7,fish-512-128,0.94225,0.707296,0.737467,0.722067
14,ann-256-192,0.947573,0.786085,0.665787,0.720952
12,ann-256-64,0.947976,0.79345,0.66051,0.720902
1,fish-128-64,0.938314,0.668044,0.782322,0.720681
8,fish-512-256,0.941624,0.702805,0.738347,0.720137


In [58]:
pd.DataFrame(report)

Unnamed: 0,model,accuracy,precision,recall,f1_score
0,fish-128-48,0.940908,0.686351,0.771768,0.726558
1,fish-128-64,0.938314,0.668044,0.782322,0.720681
2,fish-128-88,0.936882,0.660827,0.779683,0.715352
3,fish-256-64,0.938135,0.669714,0.773087,0.717698
4,fish-256-128,0.940371,0.682577,0.773527,0.725211
5,fish-256-192,0.937553,0.66807,0.76737,0.714286
6,fish-512-64,0.941713,0.697117,0.755057,0.724931
7,fish-512-128,0.94225,0.707296,0.737467,0.722067
8,fish-512-256,0.941624,0.702805,0.738347,0.720137
9,ann-128-48,0.93724,0.655148,0.808707,0.723873
