In [1]:
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from sklearn.metrics import confusion_matrix, f1_score
from tqdm import tqdm

from dna_classification.models import DNASequenceClassifier
from dna_classification.tokenization import DNATokenizer

# Use 300 DPI for every figure created in this notebook.
plt.rcParams.update({"figure.dpi": 300})

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Labelled sequences from a tab-separated file. The first line is skipped
# (presumably a header row) and the two columns get explicit names.
data = pd.read_csv(
    "data/virus.txt",
    sep="\t",
    skiprows=1,
    header=None,
    names=["sequence", "label"],
)

In [4]:
# Sweep the tokenizer's k parameter over 1..10 and record the resulting
# vocabulary size (number of entries in tokenizer.token_to_id) for each k.
ks = list(range(1, 11))
vocab_sizes = []

for k in tqdm(ks, desc="k", total=len(ks)):
    # Each k gets its own freshly constructed tokenizer.
    tokenizer = DNATokenizer()
    tokenizer.build_vocab("data/virus.txt", k=k)
    vocab_sizes.append(len(tokenizer.token_to_id))

# Make the sweep result the cell's output instead of leaving it unseen.
pd.Series(vocab_sizes, index=pd.Index(ks, name="k"), name="vocab_size")

Building vocab: 69352it [00:03, 18948.66it/s]
100%|██████████| 24/24 [00:00<00:00, 898779.43it/s]
Building vocab: 69352it [00:05, 13414.89it/s]/it]
100%|██████████| 500/500 [00:00<00:00, 3934619.14it/s]
Building vocab: 69352it [00:05, 12928.76it/s]/it]
100%|██████████| 8492/8492 [00:00<00:00, 5116062.85it/s]
Building vocab: 69352it [00:05, 12238.94it/s]/it]
100%|██████████| 45611/45611 [00:00<00:00, 4889245.55it/s]
Building vocab: 69352it [00:05, 12370.23it/s]/it]
100%|██████████| 70572/70572 [00:00<00:00, 5883881.41it/s]
Building vocab: 69352it [00:05, 12316.42it/s]/it]
100%|██████████| 86323/86323 [00:00<00:00, 5280685.26it/s]
Building vocab: 69352it [00:05, 12014.53it/s]/it]
100%|██████████| 100918/100918 [00:00<00:00, 4634173.48it/s]
Building vocab: 69352it [00:05, 12364.79it/s]/it]
100%|██████████| 115337/115337 [00:00<00:00, 4268770.71it/s]
Building vocab: 69352it [00:05, 12279.05it/s]/it]
100%|██████████| 129589/129589 [00:00<00:00, 4928464.08it/s]
Building vocab: 69352it [00:05

In [4]:
# Tokenizer / model hyperparameters, kept together so they can be tuned in one place.
DROPOUT = 0.1
K = 1  # `k` passed to build_vocab (presumably the k-mer length)
EMBEDDING_DIM = 384
HIDDEN_DIM = 64
NUM_LAYERS = 2

tokenizer = DNATokenizer()
tokenizer.build_vocab("data/virus.txt", k=K)

model = DNASequenceClassifier(
    vocab_size=tokenizer.vocab_size,
    embedding_dim=EMBEDDING_DIM,
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
    # One class per distinct value in the label column.
    num_classes=data["label"].nunique(),
    dropout=DROPOUT,
)

# Attach the tokenizer to the model; presumably this is what lets
# model.predict() accept raw sequence strings later on — confirm in
# dna_classification.models.
model.add_tokenizer(tokenizer)

In [None]:
# Training hyperparameters (also embedded in the checkpoint / figure file names).
EPOCHS = 20
BATCH_SIZE = 256
LR = 0.0001

# Use the GPU when one is available, otherwise fall back to CPU so the
# notebook still runs end-to-end on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# train_model returns the training and validation loss histories
# (plotted against epochs further down).
train_loss, val_loss = model.train_model(
    data=data,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    device=DEVICE,
    optimizer_params={"lr": LR},
)

In [None]:
# Save the trained model; the checkpoint name records epochs, batch size,
# learning rate and dropout.
model.export("checkpoints/flu_covid_e{}_bs{}_lr{}_d{}".format(EPOCHS, BATCH_SIZE, LR, DROPOUT))

In [None]:
# Plot training vs. validation loss and save the figure under figs/.
fig, ax = plt.subplots(figsize=(8, 6))

sns.lineplot(x=range(len(train_loss)), y=train_loss, ax=ax, label="train")
sns.lineplot(x=range(len(val_loss)), y=val_loss, ax=ax, label="val")

ax.set_title("Training and validation loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")

# Larger text. tick_params also applies to ticks matplotlib creates later
# (e.g. on relayout), unlike resizing the Text objects from get_xticklabels().
for text in (ax.title, ax.xaxis.label, ax.yaxis.label):
    text.set_fontsize(15)
ax.tick_params(labelsize=15)

# Thicker lines; rebuild the legend afterwards so its handles match them.
for line in ax.lines:
    line.set_linewidth(2)
ax.legend(fontsize=15)

plt.tight_layout()

# savefig does not create missing directories.
Path("figs").mkdir(parents=True, exist_ok=True)
fig.savefig(f"figs/loss_e{EPOCHS}_bs{BATCH_SIZE}_lr{LR}_d{DROPOUT}.png")

In [None]:
# Accuracy of the trained model on a random subset of the data, scored one
# sequence at a time after moving the model to CPU.
# NOTE(review): the subset is drawn from the same `data` frame passed to
# train_model, so it presumably overlaps the training rows — treat the number
# as a sanity check, not a held-out estimate. Confirm how train_model splits.
model.cpu()
model.eval()  # presumably switches off dropout for inference — confirm

N_EVAL = 1000
EVAL_SEED = 0  # fixed seed so the subset, and therefore the accuracy, is reproducible

# DataFrame.sample raises if asked for more rows than exist, so cap at len(data).
data_sample = data.sample(min(N_EVAL, len(data)), random_state=EVAL_SEED)

correct = 0
total = 0

# Score the sampled rows themselves (not the first rows of `data`).
for sequence, label in tqdm(
    zip(data_sample["sequence"], data_sample["label"]), total=len(data_sample)
):
    pred = model.predict(sequence)
    if pred == label:
        correct += 1
    total += 1

print(f"Accuracy: {correct / total * 100}%")

In [None]:
# Confusion matrix and macro-F1 on the random subset drawn for the accuracy check.
data_sample["prediction"] = data_sample["sequence"].apply(model.predict)

# Fix the class order explicitly so the heatmap axes can be labelled with it.
# This is the order sklearn uses by default: the sorted union of true and
# predicted labels.
class_labels = sorted(set(data_sample["label"]) | set(data_sample["prediction"]))

# Store the results under their own names. Assigning them to
# `confusion_matrix` / `f1_score` would shadow the sklearn functions and make
# this cell fail with a TypeError when it is re-run.
cm = confusion_matrix(data_sample["label"], data_sample["prediction"], labels=class_labels)
macro_f1 = f1_score(data_sample["label"], data_sample["prediction"], average="macro")
print(f"Macro F1: {macro_f1:.4f}")

fig, ax = plt.subplots(figsize=(8, 6))

sns.heatmap(
    cm,
    annot=True,
    ax=ax,
    cmap="Blues",
    fmt="g",
    xticklabels=class_labels,
    yticklabels=class_labels,
)

ax.set_title(f"Confusion matrix (macro F1 = {macro_f1:.3f})")
ax.set_xlabel("Predicted")
ax.set_ylabel("True")

# Larger text; tick_params also covers ticks created after this call.
for text in (ax.title, ax.xaxis.label, ax.yaxis.label):
    text.set_fontsize(15)
ax.tick_params(labelsize=15)

plt.tight_layout()

# savefig does not create missing directories.
Path("figs").mkdir(parents=True, exist_ok=True)
fig.savefig(f"figs/cm_e{EPOCHS}_bs{BATCH_SIZE}_lr{LR}_d{DROPOUT}.png")