In [None]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.feature_extraction.text import CountVectorizer
from tqdm import tqdm
from fnn import FNN
import joblib

In [None]:
# Read in data used to train / evaluate the router

python = pd.read_csv('data/python.csv')

medical = pd.read_csv('data/medical.csv')

education = pd.read_parquet("hf://datasets/kaitchup/qa-chat-persona-education/data/train-00000-of-00001.parquet")

In [None]:
# Create both questions and the labels for the router 

texts = list(python['Question'])

labels = [0] * len(python)

labels.extend([1] * len(medical))

texts.extend(list(medical['question']))

texts.extend(list(education['question']))

labels.extend([2] * len(education))

In [None]:
# transform the the texts into bag of words
vectorizer = CountVectorizer(max_features=1000)
X = vectorizer.fit_transform(texts).toarray()
y = torch.tensor(labels, dtype=torch.long)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [None]:
# create dataset used for torch data loader
class TextDataset(Dataset):
    def __init__(self, data, labels):
        self.data = torch.tensor(data, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

train_dataset = TextDataset(X_train, y_train)
test_dataset = TextDataset(X_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

  self.labels = torch.tensor(labels, dtype=torch.long)


In [None]:
input_dim = X_train.shape[1]
model = FNN(input_dim)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [42]:
def compute_accuracy(logits, labels):
    predictions = torch.argmax(logits, dim=1)
    correct = (predictions == labels).sum().item()
    return correct / labels.size(0)

In [None]:
# train the router for num_epochs 
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    epoch_acc = 0

    for batch_data, batch_labels in tqdm(train_loader):
        optimizer.zero_grad()
        outputs = model(batch_data)

        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item() * batch_data.size(0)
        epoch_acc += compute_accuracy(outputs, batch_labels) * batch_data.size(0)

    epoch_loss /= len(train_loader.dataset)
    epoch_acc /= len(train_loader.dataset)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 8118/8118 [00:07<00:00, 1067.54it/s]


Epoch 1/5, Loss: 0.0164, Accuracy: 0.9943


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 8118/8118 [00:08<00:00, 926.06it/s]


Epoch 2/5, Loss: 0.0028, Accuracy: 0.9992


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 8118/8118 [00:09<00:00, 875.81it/s]


Epoch 3/5, Loss: 0.0012, Accuracy: 0.9996


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 8118/8118 [00:08<00:00, 1010.25it/s]


Epoch 4/5, Loss: 0.0017, Accuracy: 0.9995


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 8118/8118 [00:06<00:00, 1280.40it/s]

Epoch 5/5, Loss: 0.0005, Accuracy: 0.9998





In [None]:
# evaluate the router on the held out test set
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for batch_data, batch_labels in test_loader:
        outputs = model(batch_data)
        predicted = torch.argmax(outputs, dim=1)
        total += batch_labels.size(0)
        correct += (predicted == batch_labels).sum().item()

    print(f"Test Accuracy: {100 * correct / total:.2f}%")

Test Accuracy: 99.91%


In [None]:
# Save both the router and the bag of words vectorizer

torch.save(model.state_dict(), 'router.safetensors')

joblib.dump(vectorizer, 'count_vectorizer.joblib')
print("Vectorizer has been saved to 'count_vectorizer.joblib'.")

Vectorizer has been saved to 'count_vectorizer.joblib'.
