In [10]:
import numpy as np
import torch
import os
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
import pytorch_lightning as pl
from torchmetrics import Accuracy

## set seed
np.random.seed(121)
torch.manual_seed(121)
pl.seed_everything(121)

X,y = load_breast_cancer(return_X_y = True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify = y)
print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)
X_train = torch.from_numpy(X_train).type(torch.float32)
y_train = torch.from_numpy(y_train).type(torch.float32)
X_test = torch.from_numpy(X_test).type(torch.float32)
y_test = torch.from_numpy(y_test).type(torch.float32)
y_train = y_train.view(y_train.shape[0],1)
y_test = y_test.view(y_test.shape[0],1)
print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)

class MyDataset(Dataset):
    
    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.n_samples = x.shape[0]

    def __len__(self):
        return self.n_samples
    
    def __getitem__(self,index):
        return self.x[index],self.y[index]

class MyModel(pl.LightningModule):

    def __init__(self, in_features, out_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features = in_features, out_features= 16, bias=True)
        self.fc2 = nn.Linear(in_features = 16, out_features= out_features, bias=True)
        self.loss_fn = nn.BCELoss()
        self.accuracy = Accuracy(task = "binary", num_classes = 2)       
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        return x
    
    def training_step(self,batch, batch_idx):
        x, y = batch
        yhat = self(x)
        loss = self.loss_fn(yhat, y)
        acc = self.accuracy(yhat,y)
        self.log_dict({'train_loss': loss, 'train_acc': acc}, on_step = True, on_epoch = True,prog_bar = True)
        return loss

    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        yhat = self(x)
        yhat = yhat.round()
        loss = self.loss_fn(yhat, y)
        acc = self.accuracy(yhat,y)
        self.log_dict({'val_loss': loss, 'val_acc': acc}, on_step = True, on_epoch = True,prog_bar = True)
        return loss
        
    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr = 1e-4)
        return optimizer

    
    
## Dataset
train_ds = MyDataset(X_train, y_train)
train_dl = DataLoader(train_ds, batch_size = 32, shuffle = True, num_workers= 16)
test_ds = MyDataset(X_test, y_test)
test_dl = DataLoader(test_ds, batch_size = 32, shuffle = False, num_workers= 16)


## checkpoints
checkpoint_callback  = pl.callbacks.ModelCheckpoint(
                                                filename='{epoch}-{val_loss:.2f}-{val_accuracy:.2f}', 
                                                every_n_epochs = 10,
                                                save_top_k = -1,
                                                monitor='val_loss_epoch'
                                                )


model = MyModel(X_train.shape[1], 1)
trainer = pl.Trainer(accelerator="cpu",
                     max_epochs = 100,
                     check_val_every_n_epoch=50,
                    callbacks=[checkpoint_callback],
                    )
## Train the Model
trainer.fit(model, train_dl, test_dl)

Global seed set to 121
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


(455, 30) (455,)
(114, 30) (114,)
torch.Size([455, 30]) torch.Size([455, 1])
torch.Size([114, 30]) torch.Size([114, 1])


  rank_zero_warn(


## Load Model

In [21]:
model = MyModel.load_from_checkpoint("lightning_logs/version_42/checkpoints/epoch=99-val_loss=7.89-val_accuracy=0.00.ckpt", 
                                     in_features = X_train.shape[1], 
                                     out_features = 1)

## Predict

In [24]:
yhat = model(X_test)
yhat = yhat.round()

In [27]:
yhat.eq(y_test).sum()/len(y_test)

tensor(0.9211)