# Simple Classification Model

In [1]:
from river import metrics
from river.datasets import Phishing
from river_torch.classification import Classifier
from torch import nn

In [8]:
dataset = Phishing()

metric = metrics.Accuracy()

class MyModule(nn.Module):
    def __init__(self, n_features):
        super(MyModule, self).__init__()
        self.dense0 = nn.Linear(n_features,5)
        self.nonlin = nn.ReLU()
        self.dense1 = nn.Linear(5, 2)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, X, **kwargs):
        X = self.nonlin(self.dense0(X))
        X = self.nonlin(self.dense1(X))
        X = self.softmax(X)
        return X

model = Classifier(module=MyModule,loss_fn="binary_cross_entropy",optimizer_fn='sgd')

In [10]:
for x,y in dataset:
    y_pred = model.predict_one(x)      # make a prediction
    metric = metric.update(y, y_pred)  # update the metric
    model = model.learn_one(x, y)    # make the model learn
print(f'Accuracy: {metric.get()}')

Accuracy: 0.4396


# Variable Classifier
This Classifier has a variable output, as the number of labels might vary over time.

In [11]:
from river import metrics
from river.datasets import Phishing
from river_torch.classification import VariableClassifier
from torch import nn

In [14]:
dataset = Phishing()
metric = metrics.Accuracy()

class MyVariableModule(nn.Module):
    def __init__(self, n_features):
        super(MyVariableModule, self).__init__()
        self.dense0 = nn.Linear(n_features,5)
        self.nonlin = nn.ReLU()
        self.dense1 = nn.Linear(5, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, X, **kwargs):
        X = self.nonlin(self.dense0(X))
        X = self.nonlin(self.dense1(X))
        X = self.sigmoid(X)
        return X


model = VariableClassifier(module=MyModule,loss_fn="binary_cross_entropy",optimizer_fn='sgd')

In [15]:
for x,y in dataset:
    y_pred = model.predict_one(x)      # make a prediction
    metric = metric.update(y, y_pred)  # update the metric
    model = model.learn_one(x, y)    # make the model learn
print(f'Accuracy: {metric.get()}')

Accuracy: 0.4376
