In [None]:
! pip install scikit-learn
! pip install datasets
! pip install wandb
! pip install seaborn 

In [None]:
import pandas as pd
from datasets import load_dataset
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import SGDClassifier
from sklearn import metrics
import wandb

import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Draw matplotlib figures directly in the cell output.
%matplotlib inline


# Log in to Weights & Biases


In [None]:
# Authenticate with Weights & Biases before starting a run
# (presumably prompts for an API key when none is stored — confirm against the wandb docs).
wandb.login()


In [None]:
# Start a W&B run under the "sutd-mlops-project" project; the wandb.log and wandb.finish
# calls further down in this notebook operate on this run.
wandb.init(project="sutd-mlops-project")


In [None]:
# Load the "rotten_tomatoes" dataset through the Hugging Face `datasets` library.
dataset = load_dataset("rotten_tomatoes")
# Display the first training example to see what a single record looks like.
dataset["train"][0]

In [None]:
# List the columns available in the training split.
print(dataset["train"].column_names)


In [None]:
# Collect the distinct label values that occur in the training split.
# sorted() yields a stable ascending order; iterating a bare set has no guaranteed order,
# so list(set(...)) could in principle present the labels differently between runs.
labels = sorted(set(dataset['train']['label']))
print("Labels:", labels)

In [None]:
# Class-balance check: bar chart of how many training examples carry each label.
# countplot returns the Axes it drew on, so the x-axis label is set on that object directly.
label_axes = sns.countplot(x=dataset['train']['label'])
label_axes.set_xlabel('label');

In [None]:
# Pull the text and label columns of the train and test splits out as plain Python lists.
# Reading a whole column at once avoids building one dict per row, and matches how the
# label column is already accessed in the cells above. list() keeps the result a real
# list regardless of what container the installed `datasets` version returns for a column.
train_text = list(dataset['train']['text'])
train_labels = list(dataset['train']['label'])

test_text = list(dataset['test']['text'])
test_labels = list(dataset['test']['label'])

In [None]:
# Bag-of-words features: each document becomes a vector of token counts.
count_vect = CountVectorizer()
# Fit the vectorizer on the training text only, and build the training feature matrix.
train_features = count_vect.fit_transform(train_text)
# Reuse the already-fitted vectorizer for the test text (transform, not fit_transform),
# so the test split has no influence on what the vectorizer learns.
test_features = count_vect.transform(test_text)


In [None]:
# Linear classifier trained with stochastic gradient descent on the log loss
# (i.e. logistic regression), so the fitted model can also return class probabilities
# through predict_proba. The learning rate is held constant at eta0=0.01.
# random_state fixes the shuffling of the training data, so re-running the notebook
# reproduces the same model and the same reported accuracy.
model = SGDClassifier(
    loss="log_loss",
    learning_rate='constant',
    eta0=0.01,
    random_state=42,
).fit(train_features, train_labels)


In [None]:
# Evaluate the fitted model on the test split.
test_proba = model.predict_proba(test_features)  # per-class probabilities
test_predicted = model.predict(test_features)    # hard label predictions
# Fraction of test examples whose predicted label equals the true label.
accuracy = metrics.accuracy_score(y_true=test_labels, y_pred=test_predicted)
print(accuracy)

In [None]:
# Record the test accuracy under the key "accuracy" on the active W&B run.
wandb.log({"accuracy": accuracy})

In [None]:
# Close the W&B run started earlier in this notebook.
wandb.finish()


# What to try next

- experiment with different training parameters (iterations, learning rate, regularization, ...)
- experiment with different training set sizes
- the dataset also has a validation set; what is the accuracy on it?
- analyze the learning curves on the training, validation and test sets; can you see any signs of over- or underfitting?
- use Weights & Biases to get more insights into the model behavior (confusion matrix, precision-recall curve, ...)
