In [2]:
import torch
import numpy as np
from sklearn import datasets
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

In [3]:
# Load scikit-learn's breast-cancer dataset (binary 0/1 targets).
bc = datasets.load_breast_cancer()

X, y = bc.data, bc.target

n_samples, n_features = X.shape

# Fixed seed for the train/test split so Restart & Run All reproduces the
# same partition (and therefore the same reported accuracy).
SPLIT_SEED = 42

# Hold out 20% of the samples as a test set.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=SPLIT_SEED)

# Standardize features to zero mean / unit variance.
# The scaler is fit on the TRAINING split only and its statistics are reused
# for the test split. Refitting on X_test would leak test-set information and
# evaluate the model under a different scaling than it was trained with.
sc = StandardScaler()

X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)

# Convert to float32 tensors (torch.nn.Linear parameters default to float32).
X_train = torch.from_numpy(X_train.astype(np.float32))
X_test = torch.from_numpy(X_test.astype(np.float32))
y_train = torch.from_numpy(y_train.astype(np.float32))
y_test = torch.from_numpy(y_test.astype(np.float32))

# Reshape targets from (n,) to column vectors (n, 1): BCELoss requires the
# prediction and target tensors to have the same shape.
y_train = y_train.view(y_train.shape[0], 1)
y_test = y_test.view(y_test.shape[0], 1)


class LogRegModel(torch.nn.Module):
  """Logistic-regression classifier: a single linear layer followed by a sigmoid.

  Args:
    n_input_features: number of features per input sample; the last dimension
      of the tensor passed to ``forward`` must have this size.

  ``forward`` returns probabilities in (0, 1) with a trailing dimension of 1.
  """

  def __init__(self, n_input_features):
    super().__init__()
    # One output unit: the logit of the positive class.
    self.linear = torch.nn.Linear(n_input_features, 1)

  def forward(self, x):
    logits = self.linear(x)
    # Squash logits into probabilities (pairs with BCELoss, not BCEWithLogitsLoss).
    return torch.sigmoid(logits)

# Optimization hyperparameters (full-batch gradient descent).
l_rate = 0.01
n_epoch = 1000

# Seed torch's RNG so the model's initial weights (the only random step in
# this cell) are identical on every Restart & Run All.
TORCH_SEED = 42
torch.manual_seed(TORCH_SEED)

model = LogRegModel(n_features)

# Binary cross-entropy on the sigmoid probabilities that LogRegModel outputs.
criterion = torch.nn.BCELoss()

optimizer = torch.optim.SGD(model.parameters(), lr=l_rate)

for epoch in range(n_epoch):
  # .backward() accumulates into .grad, so clear the previous step's
  # gradients before computing new ones.
  optimizer.zero_grad()

  # Forward pass over the whole training set (no mini-batching).
  y_pred = model(X_train)

  #calculate loss
  loss = criterion(y_pred, y_train)

  # Backpropagate, then apply one SGD update.
  loss.backward()
  optimizer.step()

  # Log every 100 epochs; `epoch` is 0-based, so this prints 99, 199, ...
  if (epoch + 1) % 100 == 0:
    print(f'epoch: {epoch} loss: {loss.item()}')


# Evaluate on the held-out split; no_grad avoids building an autograd graph.
with torch.no_grad():
  y_pred = model(X_test)
  # Threshold the probabilities: round() maps p > 0.5 to 1 and p < 0.5 to 0.
  y_pred_cls = y_pred.round()

  n_test = float(y_test.shape[0])
  n_correct = (y_pred_cls == y_test).sum()
  accuracy = n_correct / n_test

  print(f'Accuracy: {accuracy}')

epoch: 99 loss: 0.24010290205478668
epoch: 199 loss: 0.18158398568630219
epoch: 299 loss: 0.15417730808258057
epoch: 399 loss: 0.13792960345745087
epoch: 499 loss: 0.12704522907733917
epoch: 599 loss: 0.11916771531105042
epoch: 699 loss: 0.11315488815307617
epoch: 799 loss: 0.10838492959737778
epoch: 899 loss: 0.10448922961950302
epoch: 999 loss: 0.101234570145607
Accuracy: 0.9561403393745422
