In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix, f1_score

from dna_classification.models import DNASequenceClassifier
from dna_classification.tokenization import DNATokenizer

from tqdm import tqdm

# Render figures at 300 dpi; with matplotlib's default savefig.dpi="figure", saved PNGs inherit it.
plt.rcParams.update({"figure.dpi": 300})

In [None]:
# Load the tab-separated (sequence, label) table; the file's own header row is
# skipped and replaced by explicit column names.
data = pd.read_csv(
    "data/virus.txt",
    sep="\t",
    header=None,
    skiprows=1,
    names=["sequence", "label"],
)

In [4]:
# Fit the tokenizer vocabulary on the same raw file that `data` was read from.
tokenizer = DNATokenizer()
tokenizer.build_vocab("data/virus.txt")

# Architecture hyperparameters. The output layer size is taken from the number
# of distinct labels in `data`.
# NOTE(review): assumes labels are encoded as 0..K-1; nunique() alone does not
# guarantee that mapping — confirm against DNASequenceClassifier's loss.
classifier_kwargs = {
    "vocab_size": tokenizer.vocab_size,
    "embedding_dim": 384,
    "hidden_dim": 64,
    "num_layers": 2,
    "num_classes": data["label"].nunique(),
    "dropout": 0.05,
}
model = DNASequenceClassifier(**classifier_kwargs)

# Attach the tokenizer to the model; presumably predict() uses it to encode the
# raw sequence strings passed to it later in this notebook.
model.add_tokenizer(tokenizer)

In [None]:
# Train on the full frame. train_model returns two loss histories that are
# plotted against "Epoch" in the next plotting cell.
optimizer_params = {"lr": 1e-4}
train_loss, val_loss = model.train_model(
    data=data,
    epochs=50,
    batch_size=256,
    device="cuda",
    optimizer_params=optimizer_params,
)

In [None]:
# Persist the trained model; the on-disk layout under this path is defined by
# DNASequenceClassifier.export.
CHECKPOINT_DIR = "checkpoints"
model.export(CHECKPOINT_DIR)

In [None]:
# create loss curves: training vs. validation loss per recorded step,
# saved to loss_curves.png
fig, ax = plt.subplots(figsize=(8, 6))

sns.lineplot(x=range(len(train_loss)), y=train_loss, ax=ax, label="train")
sns.lineplot(x=range(len(val_loss)), y=val_loss, ax=ax, label="val")

# A title lets the figure stand alone; it is set before the font loop so it is
# enlarged along with the axis labels.
ax.set_title("Training and validation loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")

# increase font
for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
             ax.get_xticklabels() + ax.get_yticklabels()):
    item.set_fontsize(15)

# increase line width
for line in ax.lines:
    line.set_linewidth(2)

# (Re)build the legend after widening the lines so its handles match the
# plotted lines, and with the same font size as the axis labels.
ax.legend(fontsize=15)

fig.tight_layout()
fig.savefig("loss_curves.png")

In [None]:
# get accuracy on a random subset of the data
# Move the model to CPU (training used device="cuda") and switch to eval mode
# (presumably torch's eval(), disabling dropout — confirm).
model.cpu()
model.eval()

# NOTE(review): `data` is the same frame passed to train_model, so this sample
# likely overlaps the training rows; treat the accuracy as in-sample unless
# train_model is known to hold these rows out.

# subset: fixed seed so the evaluated rows are reproducible, and capped at the
# dataset size so sample() does not raise on small inputs.
data_sample = data.sample(n=min(1000, len(data)), random_state=0)

# Iterate over the sampled rows themselves. The previous loop indexed
# `data.iloc[i]`, which scored the first 1000 rows of the full frame instead
# of the random sample.
correct = 0
total = 0
for sequence, label in tqdm(
    zip(data_sample["sequence"], data_sample["label"]), total=len(data_sample)
):
    pred = model.predict(sequence)
    if pred == label:
        correct += 1
    total += 1

print(f"Accuracy: {correct / total * 100}%")

In [None]:
# assign prediction to each sequence in the sample
data_sample["prediction"] = data_sample["sequence"].apply(model.predict)

# Fix the class order explicitly so matrix rows/columns and tick labels line up
# (this is the same sorted union of true and predicted labels sklearn uses).
class_labels = sorted(set(data_sample["label"]) | set(data_sample["prediction"]))

# get confusion matrix and macro F1.
# Results use new names: rebinding `confusion_matrix` / `f1_score` would shadow
# the imported sklearn functions and make this cell fail on a second run.
conf_mat = confusion_matrix(
    data_sample["label"], data_sample["prediction"], labels=class_labels
)
macro_f1 = f1_score(data_sample["label"], data_sample["prediction"], average="macro")
print(f"Macro F1: {macro_f1:.4f}")

# plot the confusion matrix
fig, ax = plt.subplots(figsize=(8, 6))

sns.heatmap(
    conf_mat,
    annot=True,
    ax=ax,
    cmap="Blues",
    fmt="g",
    xticklabels=class_labels,
    yticklabels=class_labels,
)

# Title set before the font loop so it is enlarged too.
ax.set_title(f"Confusion matrix (macro F1 = {macro_f1:.3f})")
ax.set_xlabel("Predicted")
ax.set_ylabel("True")

# increase font
for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
             ax.get_xticklabels() + ax.get_yticklabels()):
    item.set_fontsize(15)

fig.tight_layout()

fig.savefig("confusion_matrix.png")